1"""LightningModule for Lightning IR.
2
3This module contains the main module class deriving from a LightningModule_.
4
5.. _LightningModule: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html
6"""
7
8from collections.abc import Mapping, Sequence
9from pathlib import Path
10from typing import Any
11
12import pandas as pd
13import torch
14from lightning import LightningModule
15from lightning.pytorch.trainer.states import RunningStage
16from transformers import BatchEncoding, PreTrainedModel
17
18from ..data import IRDataset, RankBatch, RunDataset, SearchBatch, TrainBatch
19from ..loss import InBatchLossFunction, LossFunction
20from .config import LightningIRConfig
21from .model import LightningIRModel, LightningIROutput
22from .tokenizer import LightningIRTokenizer
23from .validation_utils import create_qrels_from_dicts, create_run_from_scores, evaluate_run
24
25
class LightningIRModule(LightningModule):
    """LightningIRModule base class. It derives from a LightningModule_. A LightningIRModule contains a
    LightningIRModel and a LightningIRTokenizer and implements the training, validation, and testing steps for the
    model. Derived classes must implement the forward method for the model.

    .. _LightningModule: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html
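
    Example:
        A minimal fine-tuning sketch; assumes ``module`` (a derived LightningIRModule) and ``datamodule``
        have been instantiated elsewhere, and the step count is illustrative::

            from lightning import Trainer

            trainer = Trainer(max_steps=100_000)
            trainer.fit(module, datamodule)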
32 """
33
    def __init__(
        self,
        model_name_or_path: str | None = None,
        config: LightningIRConfig | None = None,
        model: LightningIRModel | None = None,
        BackboneModel: type[PreTrainedModel] | None = None,
        loss_functions: Sequence[LossFunction | tuple[LossFunction, float]] | None = None,
        evaluation_metrics: Sequence[str] | None = None,
        model_kwargs: Mapping[str, Any] | None = None,
    ):
        """Initializes the LightningIRModule.

        .. _ir-measures: https://ir-measur.es/en/latest/index.html

        Args:
            model_name_or_path (str | None): Name or path of backbone model or fine-tuned Lightning IR model.
                Defaults to None.
            config (LightningIRConfig | None): LightningIRConfig to apply when loading from backbone model.
                Defaults to None.
            model (LightningIRModel | None): Already instantiated Lightning IR model. Defaults to None.
            BackboneModel (type[PreTrainedModel] | None): Huggingface PreTrainedModel class to use as backbone
                instead of the default AutoModel. Defaults to None.
            loss_functions (Sequence[LossFunction | tuple[LossFunction, float]] | None):
                Loss functions to apply during fine-tuning. An optional loss weight can be provided per loss
                function. Defaults to None.
            evaluation_metrics (Sequence[str] | None): Metrics corresponding to ir-measures_ measure strings
                to apply during validation or testing. Defaults to None.
            model_kwargs (Mapping[str, Any] | None): Additional keyword arguments to pass to `from_pretrained` when
                loading a model. Defaults to None.
        Raises:
            ValueError: If both model and model_name_or_path are provided.
            ValueError: If neither model nor model_name_or_path are provided.
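
        Example:
            A minimal sketch; in practice a derived module (e.g., for bi- or cross-encoders) is
            instantiated, and the backbone name and metric string are illustrative::

                module = LightningIRModule(
                    model_name_or_path="bert-base-uncased",  # illustrative backbone
                    evaluation_metrics=["nDCG@10"],  # illustrative ir-measures measure string
                )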
66 """
67 super().__init__()
68 model_kwargs = model_kwargs if model_kwargs is not None else {}
69 self.save_hyperparameters()
70 if model is not None and model_name_or_path is not None:
71 raise ValueError("Only one of model or model_name_or_path must be provided.")
72 if model is None:
73 if model_name_or_path is None:
74 raise ValueError("Either model or model_name_or_path must be provided.")
75 model = LightningIRModel.from_pretrained(
76 model_name_or_path, config=config, BackboneModel=BackboneModel, **model_kwargs
77 )
78
79 self.model: LightningIRModel = model
80 self.config = self.model.config
81 self.loss_functions: list[tuple[LossFunction, float]] | None = None
82 if loss_functions is not None:
83 self.loss_functions = []
84 for loss_function in loss_functions:
85 if isinstance(loss_function, LossFunction):
86 self.loss_functions.append((loss_function, 1.0))
87 else:
88 self.loss_functions.append(loss_function)
89 self.evaluation_metrics = evaluation_metrics
90 self._optimizer: torch.optim.Optimizer | None = None
91 self.tokenizer = LightningIRTokenizer.from_pretrained(self.config.name_or_path, config=self.config)
92 self._additional_log_metrics: dict[str, float] = {}
93
    def on_train_start(self) -> None:
        """Called at the beginning of training after sanity check."""
        super().on_train_start()
        # NOTE huggingface models are in eval mode by default
        self.model = self.model.train()

    def on_validation_start(self) -> None:
        """Called at the beginning of validation."""
        # NOTE monkey patch result printing of the trainer
        try:
            trainer = self.trainer
        except RuntimeError:
            trainer = None
        if trainer is None:
            return

        trainer._evaluation_loop._print_results = lambda *args, **kwargs: None

    def on_test_start(self) -> None:
        """Called at the beginning of testing."""
        self.on_validation_start()

    def set_optimizer(
        self, optimizer: type[torch.optim.Optimizer], **optimizer_kwargs: Any
    ) -> "LightningIRModule":
        """Sets the optimizer for the model. Necessary for fine-tuning when not using the CLI.

        Args:
            optimizer (type[torch.optim.Optimizer]): Torch optimizer class.
            optimizer_kwargs (dict[str, Any]): Arguments to initialize the optimizer.
        Returns:
            LightningIRModule: Self with the optimizer set.
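
        Example:
            A minimal sketch; the optimizer class and learning rate are illustrative::

                module = module.set_optimizer(torch.optim.AdamW, lr=1e-5)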
139 """
140 self._optimizer = optimizer(self.parameters(), **optimizer_kwargs)
141 return self
142
    def score(self, queries: Sequence[str] | str, docs: Sequence[Sequence[str]] | Sequence[str]) -> LightningIROutput:
        """Computes relevance scores for queries and documents.

        Args:
            queries (Sequence[str] | str): Queries to score.
            docs (Sequence[Sequence[str]] | Sequence[str]): Documents to score.
        Returns:
            LightningIROutput: Model output containing the scores.
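
        Example:
            A minimal sketch; the query and document texts are illustrative::

                # scores one query against two documents
                output = module.score("green energy", ["solar and wind power", "fossil fuels"])
                print(output.scores)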
151 """
152 if isinstance(queries, str):
153 queries = (queries,)
154 if isinstance(docs[0], str):
155 docs = (docs,)
156 batch = RankBatch(queries, docs, None, None)
157 with torch.no_grad():
158 return self.forward(batch)
159
    def forward(self, batch: TrainBatch | RankBatch | SearchBatch) -> LightningIROutput:
        """Handles the forward pass of the model.

        Args:
            batch (TrainBatch | RankBatch | SearchBatch): Batch of training or ranking data.
        Returns:
            LightningIROutput: Model output.
        Raises:
            NotImplementedError: Must be implemented by derived class.
        """
        raise NotImplementedError

    def _compute_losses(self, batch: TrainBatch, output: LightningIROutput) -> list[torch.Tensor]:
        """Computes the losses for a training batch."""
        raise NotImplementedError

    def training_step(self, batch: TrainBatch, batch_idx: int) -> dict[str, Any]:
        """Handles the training step for the model.

        Args:
            batch (TrainBatch): Batch of training data.
            batch_idx (int): Index of the batch.
        Returns:
            dict[str, Any]: Dictionary containing the sum of the losses weighted by the loss weights under
                the key "loss" and the model output under the key "output".
        Raises:
            ValueError: If no loss functions are set.
        """
        if self.loss_functions is None:
            raise ValueError("Loss functions are not set")
        output = self.forward(batch)
        losses = self._compute_losses(batch, output)
        total_loss = torch.tensor(0)
        assert len(losses) == len(self.loss_functions)
        for (loss_function, loss_weight), loss in zip(self.loss_functions, losses):
            self.log(loss_function.__class__.__name__, loss)
            total_loss = total_loss + loss * loss_weight
        self.log("loss", total_loss, prog_bar=True)
        return {"loss": total_loss, "output": output}

    def validation_step(
        self, batch: TrainBatch | RankBatch | SearchBatch, batch_idx: int, dataloader_idx: int = 0
    ) -> LightningIROutput:
        """Handles the validation step for the model.

        Args:
            batch (TrainBatch | RankBatch | SearchBatch): Batch of validation or testing data.
            batch_idx (int): Index of the batch.
            dataloader_idx (int): Index of the dataloader. Defaults to 0.
        Returns:
            LightningIROutput: Model output.
        """
        output = self.forward(batch)

        if self.evaluation_metrics is None:
            return output

        dataset = self.get_dataset(dataloader_idx)
        dataset_id = str(dataloader_idx) if dataset is None else self.get_dataset_id(dataset)
        metrics = self.validate(output, batch)
        for key, value in metrics.items():
            key = f"{dataset_id}/{key}"
            self.log(key, value, batch_size=len(batch.queries))
        return output

    def test_step(
        self,
        batch: TrainBatch | RankBatch,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> LightningIROutput:
        """Handles the testing step for the model. Passes the batch to the validation step.

        Args:
            batch (TrainBatch | RankBatch): Batch of testing data.
            batch_idx (int): Index of the batch.
            dataloader_idx (int): Index of the dataloader. Defaults to 0.
        Returns:
            LightningIROutput: Model output.
        """
        return self.validation_step(batch, batch_idx, dataloader_idx)

    def get_dataset(self, dataloader_idx: int) -> IRDataset | None:
        """Gets the dataset instance from the dataloader index. Returns None if no dataset is found.

        Args:
            dataloader_idx (int): Index of the dataloader.
        Returns:
            IRDataset | None: Inference dataset or None if no dataset is found.
        """
        try:
            trainer = self.trainer
        except RuntimeError:
            trainer = None
        if trainer is None:
            return None
        STAGE_TO_DATALOADER = {
            RunningStage.VALIDATING: "val_dataloaders",
            RunningStage.TESTING: "test_dataloaders",
            RunningStage.PREDICTING: "predict_dataloaders",
            RunningStage.SANITY_CHECKING: "val_dataloaders",
        }
        if trainer.state.stage is None:
            return None
        dataloaders = getattr(trainer, STAGE_TO_DATALOADER[trainer.state.stage], None)
        if dataloaders is None:
            return None
        if isinstance(dataloaders, torch.utils.data.DataLoader):
            dataloaders = [dataloaders]
        return dataloaders[dataloader_idx].dataset

    def get_dataset_id(self, dataset: IRDataset) -> str:
        """Gets the dataset id from a dataset instance for logging.

        .. _ir-datasets: https://ir-datasets.com/

        Args:
            dataset (IRDataset): Dataset instance.
        Returns:
            str: Run file name or ir-datasets_ dataset id.
        """
        if isinstance(dataset, RunDataset) and dataset.run_path is not None:
            dataset_id = dataset.run_path.name
        else:
            dataset_id = dataset.dataset_id
        return dataset_id

    def validate(
        self,
        output: LightningIROutput,
        batch: TrainBatch | RankBatch | SearchBatch,
    ) -> dict[str, float]:
        """Validates the model output with the evaluation metrics and loss functions.

        Args:
            output (LightningIROutput): Model output.
            batch (TrainBatch | RankBatch | SearchBatch): Batch of validation or testing data.
        Returns:
            dict[str, float]: Dictionary of evaluation metrics.
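
        Example:
            A minimal sketch; assumes ``evaluation_metrics`` is set and ``batch`` carries qrels::

                output = module.forward(batch)
                metrics = module.validate(output, batch)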
320 """
321 metrics: dict[str, float] = {}
322 if self.evaluation_metrics is None or output.scores is None:
323 return metrics
324 metrics.update(self.validate_metrics(output, batch))
325 metrics.update(self.validate_loss(output, batch))
326 return metrics
327
    def validate_metrics(
        self,
        output: LightningIROutput,
        batch: TrainBatch | RankBatch | SearchBatch,
    ) -> dict[str, float]:
        """Validates the model output with the evaluation metrics.

        Args:
            output (LightningIROutput): Model output.
            batch (TrainBatch | RankBatch | SearchBatch): Batch of validation or testing data.
        Returns:
            dict[str, float]: Dictionary of evaluation metrics.
        Raises:
            ValueError: If query_ids or doc_ids are not set in the batch.
        """
        metrics: dict[str, float] = {}
        qrels = batch.qrels
        if self.evaluation_metrics is None or qrels is None:
            return metrics
        query_ids = batch.query_ids
        doc_ids = batch.doc_ids
        if query_ids is None:
            raise ValueError("query_ids must be set")
        if doc_ids is None:
            raise ValueError("doc_ids must be set")
        evaluation_metrics = [metric for metric in self.evaluation_metrics if metric != "loss"]
        ir_measures_qrels = create_qrels_from_dicts(qrels)
        if evaluation_metrics and qrels is not None and output.scores is not None:
            run = create_run_from_scores(query_ids, doc_ids, output.scores)
            metrics.update(evaluate_run(run, ir_measures_qrels, evaluation_metrics))
        return metrics

    def validate_loss(
        self,
        output: LightningIROutput,
        batch: TrainBatch | RankBatch | SearchBatch,
    ) -> dict[str, float]:
        """Validates the model output with the loss functions.

        Args:
            output (LightningIROutput): Model output.
            batch (TrainBatch | RankBatch | SearchBatch): Batch of validation or testing data.
        Returns:
            dict[str, float]: Dictionary of validation losses.
        Raises:
            ValueError: If query_ids are not set in the batch.
        """
        metrics: dict[str, float] = {}
        query_ids = batch.query_ids
        if query_ids is None:
            raise ValueError("query_ids must be set")
        if (
            self.evaluation_metrics is None
            or "loss" not in self.evaluation_metrics
            or getattr(batch, "targets", None) is None
            or self.loss_functions is None
            or output.scores is None
        ):
            return metrics
        output.scores = output.scores.view(len(query_ids), -1)
        for loss_function, _ in self.loss_functions:
            # NOTE skip in-batch losses because they can use a lot of memory
            if isinstance(loss_function, InBatchLossFunction):
                continue
            metrics[f"validation-{loss_function.__class__.__name__}"] = loss_function.compute_loss(output, batch)
        return metrics

    def on_validation_end(self) -> None:
        """Prints the validation results for each dataloader."""
        trainer = self.trainer
        if not (trainer.is_global_zero and trainer._evaluation_loop.verbose):
            return
        results = trainer.callback_metrics

        data = []
        for key, value in {**results, **self._additional_log_metrics}.items():
            if "dataloader_idx" in key:
                key = "/".join(key.split("/")[:-1])
            *dataset_parts, metric = key.split("/")
            if metric.startswith("validation-"):
                metric = metric[len("validation-") :]
            dataset = "/".join(dataset_parts)
            if isinstance(value, torch.Tensor):
                value = value.item()
            data.append({"dataset": dataset, "metric": metric, "value": value})
        if not data:
            return
        df = pd.DataFrame(data)
        df = df.pivot(index="dataset", columns="metric", values="value")
        df.columns.name = None

        # bring into correct order when skipping inference datasets
        datamodule = getattr(self.trainer, "datamodule", None)
        if datamodule is not None and hasattr(datamodule, "inference_datasets"):
            inference_datasets = datamodule.inference_datasets
            if len(inference_datasets) != df.shape[0]:
                raise ValueError(
                    "Number of inference datasets does not match number of dataloaders. "
                    "Check if the dataloaders are correctly configured."
                )
            dataset_ids = [self.get_dataset_id(dataset) for dataset in inference_datasets]
            df = df.reindex(dataset_ids)

        trainer.print(df)

    def on_test_end(self) -> None:
        """Prints the accumulated metrics for each dataloader."""
        self.on_validation_end()

    def save_pretrained(self, save_path: str | Path) -> None:
        """Saves the model and tokenizer to the save path.

        Args:
            save_path (str | Path): Path to save the model and tokenizer.
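
        Example:
            A minimal sketch; the output directory is illustrative::

                module.save_pretrained("models/my-finetuned-model")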
440 """
441 self.model.save_pretrained(save_path)
442 self.tokenizer.save_pretrained(save_path)
443
    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Saves the model and tokenizer to the trainer's log directory."""
        if self.trainer is not None and self.trainer.log_dir is not None:
            if self.trainer.global_rank != 0:
                return
            _step = self.trainer.global_step
            self.config.save_step = _step
            log_dir = Path(self.trainer.log_dir)
            save_path = log_dir / "huggingface_checkpoint"
            self.save_pretrained(save_path)