1"""LightningModule for Lightning IR.
2
3This module contains the main module class deriving from a LightningModule_.
4
5.. _LightningModule: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html
6"""
7
8from pathlib import Path
9from typing import Any, Dict, List, Mapping, Sequence, Tuple, Type
10
11import pandas as pd
12import torch
13from lightning import LightningModule
14from lightning.pytorch.trainer.states import RunningStage
15from transformers import BatchEncoding
16
17from ..data import IRDataset, RankBatch, RunDataset, SearchBatch, TrainBatch
18from ..loss.loss import InBatchLossFunction, LossFunction
19from .config import LightningIRConfig
20from .model import LightningIRModel, LightningIROutput
21from .tokenizer import LightningIRTokenizer
22from .validation_utils import create_qrels_from_dicts, create_run_from_scores, evaluate_run
23
24
class LightningIRModule(LightningModule):
    """LightningIRModule base class. It derives from a LightningModule_. LightningIRModules contain a
    LightningIRModel and a LightningIRTokenizer and implement the training, validation, and testing steps for the
    model. Derived classes must implement the forward method for the model.

    .. _LightningModule: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html
    """

    def __init__(
        self,
        model_name_or_path: str | None = None,
        config: LightningIRConfig | None = None,
        model: LightningIRModel | None = None,
        loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None,
        evaluation_metrics: Sequence[str] | None = None,
        model_kwargs: Mapping[str, Any] | None = None,
    ):
        """Initializes the LightningIRModule.

        .. _ir-measures: https://ir-measur.es/en/latest/index.html

        :param model_name_or_path: Name or path of backbone model or fine-tuned Lightning IR model, defaults to None
        :type model_name_or_path: str | None, optional
        :param config: LightningIRConfig to apply when loading from backbone model, defaults to None
        :type config: LightningIRConfig | None, optional
        :param model: Already instantiated Lightning IR model, defaults to None
        :type model: LightningIRModel | None, optional
        :param loss_functions: Loss functions to apply during fine-tuning; optional loss weights can be provided per
            loss function, defaults to None
        :type loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None, optional
        :param evaluation_metrics: Metrics corresponding to ir-measures_ measure strings to apply during validation or
            testing, defaults to None
        :type evaluation_metrics: Sequence[str] | None, optional
        :param model_kwargs: Additional keyword arguments to pass to ``from_pretrained`` when loading a model,
            defaults to None
        :type model_kwargs: Mapping[str, Any] | None, optional
        :raises ValueError: If both model and model_name_or_path are provided
        :raises ValueError: If neither model nor model_name_or_path are provided
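
        Example (an illustrative sketch; the backbone name and metric are placeholders, and in practice a
        task-specific subclass of LightningIRModule with a matching config is instantiated)::

            module = LightningIRModule(
                model_name_or_path="bert-base-uncased",
                evaluation_metrics=["nDCG@10"],
            )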
63 """
64 super().__init__()
65 model_kwargs = model_kwargs if model_kwargs is not None else {}
66 self.save_hyperparameters()
67 if model is not None and model_name_or_path is not None:
68 raise ValueError("Only one of model or model_name_or_path must be provided.")
69 if model is None:
70 if model_name_or_path is None:
71 raise ValueError("Either model or model_name_or_path must be provided.")
72 model = LightningIRModel.from_pretrained(model_name_or_path, config=config, **model_kwargs)
73
74 self.model: LightningIRModel = model
75 self.config = self.model.config
76 self.loss_functions: List[Tuple[LossFunction, float]] | None = None
77 if loss_functions is not None:
78 self.loss_functions = []
79 for loss_function in loss_functions:
80 if isinstance(loss_function, LossFunction):
81 self.loss_functions.append((loss_function, 1.0))
82 else:
83 self.loss_functions.append(loss_function)
84 self.evaluation_metrics = evaluation_metrics
85 self._optimizer: torch.optim.Optimizer | None = None
86 self.tokenizer = LightningIRTokenizer.from_pretrained(self.config.name_or_path, config=self.config)
87 self._additional_log_metrics: Dict[str, float] = {}
88
    def on_train_start(self) -> None:
        """Called at the beginning of training after sanity check."""
        super().on_train_start()
        # NOTE huggingface models are in eval mode by default
        self.model = self.model.train()

    def on_validation_start(self) -> None:
        """Called at the beginning of validation."""
        # NOTE monkey patch result printing of the trainer
        try:
            trainer = self.trainer
        except RuntimeError:
            trainer = None
        if trainer is None:
            return

        trainer._evaluation_loop._print_results = lambda *args, **kwargs: None

    def on_test_start(self) -> None:
        """Called at the beginning of testing."""
        self.on_validation_start()

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Configures the optimizer for the model.

        :raises ValueError: If no optimizer is set
        :return: Optimizer
        :rtype: torch.optim.Optimizer
        """
        if self._optimizer is None:
            raise ValueError("Optimizer is not set. Call `set_optimizer`.")
        return self._optimizer

    def set_optimizer(
        self, optimizer: Type[torch.optim.Optimizer], **optimizer_kwargs: Dict[str, Any]
    ) -> "LightningIRModule":
        """Sets the optimizer for the model. Necessary for fine-tuning when not using the CLI.

        :param optimizer: Torch optimizer class
        :type optimizer: Type[torch.optim.Optimizer]
        :param optimizer_kwargs: Arguments to initialize the optimizer
        :type optimizer_kwargs: Dict[str, Any]
        :return: self
        :rtype: LightningIRModule
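
        Example (a minimal sketch; the optimizer class and learning rate are illustrative)::

            module.set_optimizer(torch.optim.AdamW, lr=1e-5)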
134 """
135 self._optimizer = optimizer(self.parameters(), **optimizer_kwargs)
136 return self
137
    def score(self, queries: Sequence[str] | str, docs: Sequence[Sequence[str]] | Sequence[str]) -> LightningIROutput:
        """Computes relevance scores for queries and documents.

        :param queries: Query or queries to score
        :type queries: Sequence[str] | str
        :param docs: Documents to score; a single sequence of strings is interpreted as the documents for one query
        :type docs: Sequence[Sequence[str]] | Sequence[str]
        :return: Model output
        :rtype: LightningIROutput
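
        Example (a minimal sketch on a derived module that implements ``forward``; the query and documents are
        illustrative)::

            output = module.score(
                "what is information retrieval?",
                ["Information retrieval finds relevant documents.", "Pandas eat bamboo."],
            )
            print(output.scores)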
147 """
148 if isinstance(queries, str):
149 queries = (queries,)
150 if isinstance(docs[0], str):
151 docs = (docs,)
152 batch = RankBatch(queries, docs, None, None)
153 with torch.no_grad():
154 return self.forward(batch)
155
    def forward(self, batch: TrainBatch | RankBatch | SearchBatch) -> LightningIROutput:
        """Handles the forward pass of the model.

        :param batch: Batch of training or ranking data
        :type batch: TrainBatch | RankBatch | SearchBatch
        :raises NotImplementedError: Must be implemented by derived class
        :return: Model output
        :rtype: LightningIROutput
        """
        raise NotImplementedError

    def _compute_losses(self, batch: TrainBatch, output: LightningIROutput) -> List[torch.Tensor]:
        """Computes the losses for a training batch."""
        raise NotImplementedError

    def training_step(self, batch: TrainBatch, batch_idx: int) -> torch.Tensor:
        """Handles the training step for the model.

        :param batch: Batch of training data
        :type batch: TrainBatch
        :param batch_idx: Index of the batch
        :type batch_idx: int
        :raises ValueError: If no loss functions are set
        :return: Sum of the losses weighted by the loss weights
        :rtype: torch.Tensor
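
        For example, loss weights can be set when constructing the module (``SomeLoss`` and ``OtherLoss`` are
        hypothetical placeholders for concrete ``LossFunction`` implementations)::

            module = LightningIRModule(
                model_name_or_path="bert-base-uncased",
                loss_functions=[SomeLoss(), (OtherLoss(), 0.5)],  # weights default to 1.0
            )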
205 """
206 if self.loss_functions is None:
207 raise ValueError("Loss functions are not set")
208 output = self.forward(batch)
209 losses = self._compute_losses(batch, output)
210 total_loss = torch.tensor(0)
211 assert len(losses) == len(self.loss_functions)
212 for (loss_function, loss_weight), loss in zip(self.loss_functions, losses):
213 self.log(loss_function.__class__.__name__, loss)
214 total_loss = total_loss + loss * loss_weight
215 self.log("loss", total_loss, prog_bar=True)
216 return total_loss
217
    def validation_step(
        self, batch: TrainBatch | RankBatch | SearchBatch, batch_idx: int, dataloader_idx: int = 0
    ) -> LightningIROutput:
        """Handles the validation step for the model.

        :param batch: Batch of validation or testing data
        :type batch: TrainBatch | RankBatch | SearchBatch
        :param batch_idx: Index of the batch
        :type batch_idx: int
        :param dataloader_idx: Index of the dataloader, defaults to 0
        :type dataloader_idx: int, optional
        :return: Model output
        :rtype: LightningIROutput
        """
        output = self.forward(batch)

        if self.evaluation_metrics is None:
            return output

        dataset = self.get_dataset(dataloader_idx)
        dataset_id = str(dataloader_idx) if dataset is None else self.get_dataset_id(dataset)
        metrics = self.validate(output, batch)
        for key, value in metrics.items():
            key = f"{dataset_id}/{key}"
            self.log(key, value, batch_size=len(batch.queries))
        return output

    def test_step(
        self,
        batch: TrainBatch | RankBatch,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> LightningIROutput:
        """Handles the testing step for the model. Passes the batch to the validation step.

        :param batch: Batch of testing data
        :type batch: TrainBatch | RankBatch
        :param batch_idx: Index of the batch
        :type batch_idx: int
        :param dataloader_idx: Index of the dataloader, defaults to 0
        :type dataloader_idx: int, optional
        :return: Model output
        :rtype: LightningIROutput
        """
        return self.validation_step(batch, batch_idx, dataloader_idx)

    def get_dataset(self, dataloader_idx: int) -> IRDataset | None:
        """Gets the dataset instance from the dataloader index. Returns None if no dataset is found.

        :param dataloader_idx: Index of the dataloader
        :type dataloader_idx: int
        :return: Inference dataset
        :rtype: IRDataset | None
        """
        try:
            trainer = self.trainer
        except RuntimeError:
            trainer = None
        if trainer is None:
            return None
        STAGE_TO_DATALOADER = {
            RunningStage.VALIDATING: "val_dataloaders",
            RunningStage.TESTING: "test_dataloaders",
            RunningStage.PREDICTING: "predict_dataloaders",
            RunningStage.SANITY_CHECKING: "val_dataloaders",
        }
        if trainer.state.stage is None:
            return None
        dataloaders = getattr(trainer, STAGE_TO_DATALOADER[trainer.state.stage], None)
        if dataloaders is None:
            return None
        if isinstance(dataloaders, torch.utils.data.DataLoader):
            dataloaders = [dataloaders]
        return dataloaders[dataloader_idx].dataset

    def get_dataset_id(self, dataset: IRDataset) -> str:
        """Gets the dataset id from a dataset for logging.

        .. _ir-datasets: https://ir-datasets.com/

        :param dataset: Dataset to derive the id from
        :type dataset: IRDataset
        :return: Run file name or ir-datasets_ dataset id
        :rtype: str
        """
        if isinstance(dataset, RunDataset) and dataset.run_path is not None:
            dataset_id = dataset.run_path.name
        else:
            dataset_id = dataset.dataset_id
        return dataset_id

    def validate(
        self,
        output: LightningIROutput,
        batch: TrainBatch | RankBatch | SearchBatch,
    ) -> Dict[str, float]:
        """Validates the model output with the evaluation metrics and loss functions.

        :param output: Model output
        :type output: LightningIROutput
        :param batch: Batch of validation or testing data
        :type batch: TrainBatch | RankBatch | SearchBatch
        :return: Dictionary of evaluation metrics
        :rtype: Dict[str, float]
        """
        metrics: Dict[str, float] = {}
        if self.evaluation_metrics is None or output.scores is None:
            return metrics
        metrics.update(self.validate_metrics(output, batch))
        metrics.update(self.validate_loss(output, batch))
        return metrics

    def validate_metrics(
        self,
        output: LightningIROutput,
        batch: TrainBatch | RankBatch | SearchBatch,
    ) -> Dict[str, float]:
        """Validates the model output with the evaluation metrics.

        :param output: Model output
        :type output: LightningIROutput
        :param batch: Batch of validation or testing data
        :type batch: TrainBatch | RankBatch | SearchBatch
        :return: Evaluation metrics
        :rtype: Dict[str, float]
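
        Measure strings are parsed by ir-measures (the measures shown are illustrative)::

            module.evaluation_metrics = ["nDCG@10", "MRR@10"]

        The special value ``"loss"`` is excluded here and handled by ``validate_loss`` instead.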
343 """
344 metrics: Dict[str, float] = {}
345 qrels = batch.qrels
346 if self.evaluation_metrics is None or qrels is None:
347 return metrics
348 query_ids = batch.query_ids
349 doc_ids = batch.doc_ids
350 if query_ids is None:
351 raise ValueError("query_ids must be set")
352 if doc_ids is None:
353 raise ValueError("doc_ids must be set")
354 evaluation_metrics = [metric for metric in self.evaluation_metrics if metric != "loss"]
355 ir_measures_qrels = create_qrels_from_dicts(qrels)
356 if evaluation_metrics and qrels is not None and output.scores is not None:
357 run = create_run_from_scores(query_ids, doc_ids, output.scores)
358 metrics.update(evaluate_run(run, ir_measures_qrels, evaluation_metrics))
359 return metrics
360
    def validate_loss(
        self,
        output: LightningIROutput,
        batch: TrainBatch | RankBatch | SearchBatch,
    ) -> Dict[str, float]:
        """Validates the model output with the loss functions.

        :param output: Model output
        :type output: LightningIROutput
        :param batch: Batch of validation or testing data
        :type batch: TrainBatch | RankBatch | SearchBatch
        :return: Evaluation metrics
        :rtype: Dict[str, float]
        """
        metrics: Dict[str, float] = {}
        query_ids = batch.query_ids
        if query_ids is None:
            raise ValueError("query_ids must be set")
        if (
            self.evaluation_metrics is None
            or "loss" not in self.evaluation_metrics
            or getattr(batch, "targets", None) is None
            or self.loss_functions is None
            or output.scores is None
        ):
            return metrics
        output.scores = output.scores.view(len(query_ids), -1)
        for loss_function, _ in self.loss_functions:
            # NOTE skip in-batch losses because they can use a lot of memory
            if isinstance(loss_function, InBatchLossFunction):
                continue
            metrics[f"validation-{loss_function.__class__.__name__}"] = loss_function.compute_loss(output, batch)
        return metrics

    def on_validation_end(self) -> None:
        """Prints the validation results for each dataloader."""
        trainer = self.trainer
        if not (trainer.is_global_zero and trainer._evaluation_loop.verbose):
            return
        results = trainer.callback_metrics

        data = []
        for key, value in {**results, **self._additional_log_metrics}.items():
            if "dataloader_idx" in key:
                key = "/".join(key.split("/")[:-1])
            *dataset_parts, metric = key.split("/")
            if metric.startswith("validation-"):
                metric = metric[len("validation-") :]
            dataset = "/".join(dataset_parts)
            if isinstance(value, torch.Tensor):
                value = value.item()
            data.append({"dataset": dataset, "metric": metric, "value": value})
        if not data:
            return
        df = pd.DataFrame(data)
        df = df.pivot(index="dataset", columns="metric", values="value")
        df.columns.name = None

        # bring into correct order when skipping inference datasets
        datamodule = getattr(self.trainer, "datamodule", None)
        if datamodule is not None and hasattr(datamodule, "inference_datasets"):
            inference_datasets = datamodule.inference_datasets
            if len(inference_datasets) != df.shape[0]:
                raise ValueError(
                    "Number of inference datasets does not match number of dataloaders. "
                    "Check if the dataloaders are correctly configured."
                )
            dataset_ids = [self.get_dataset_id(dataset) for dataset in inference_datasets]
            df = df.reindex(dataset_ids)

        trainer.print(df)

    def on_test_end(self) -> None:
        """Prints the accumulated metrics for each dataloader."""
        self.on_validation_end()

    def save_pretrained(self, save_path: str | Path) -> None:
        """Saves the model and tokenizer to the save path.

        :param save_path: Path to save the model and tokenizer
        :type save_path: str | Path
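
        Example (the path is illustrative)::

            module.save_pretrained("models/my-finetuned-model")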
442 """
443 self.model.save_pretrained(save_path)
444 self.tokenizer.save_pretrained(save_path)
445
    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Saves the model and tokenizer to the trainer's log directory."""
        if self.trainer is not None and self.trainer.log_dir is not None:
            if self.trainer.global_rank != 0:
                return
            _step = self.trainer.global_step
            self.config.save_step = _step
            log_dir = Path(self.trainer.log_dir)
            save_path = log_dir / "huggingface_checkpoint"
            self.save_pretrained(save_path)