1"""LightningModule for Lightning IR.
2
3This module contains the main module class deriving from a LightningModule_.
4
5.. _LightningModule: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html
6"""
7
8from pathlib import Path
9from typing import Any, Dict, List, Mapping, Sequence, Tuple, Type
10
11import pandas as pd
12import torch
13from lightning import LightningModule
14from lightning.pytorch.trainer.states import RunningStage
15from transformers import BatchEncoding, PreTrainedModel
16
17from ..data import IRDataset, RankBatch, RunDataset, SearchBatch, TrainBatch
18from ..loss import InBatchLossFunction, LossFunction
19from .config import LightningIRConfig
20from .model import LightningIRModel, LightningIROutput
21from .tokenizer import LightningIRTokenizer
22from .validation_utils import create_qrels_from_dicts, create_run_from_scores, evaluate_run
23
24
class LightningIRModule(LightningModule):
    """LightningIRModule base class. It derives from a LightningModule_. A LightningIRModule contains a
    LightningIRModel and a LightningIRTokenizer and implements the training, validation, and testing steps
    for the model. Derived classes must implement the forward method for the model.

    .. _LightningModule: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html
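
    Example (illustrative sketch; assumes the derived `CrossEncoderModule` from lightning_ir and a publicly
    available fine-tuned checkpoint)::

        from lightning_ir import CrossEncoderModule

        module = CrossEncoderModule("webis/monoelectra-base", evaluation_metrics=["nDCG@10"])
        output = module.score("what is ir?", ["ir stands for information retrieval", "ir is short for infrared"])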
    """

    def __init__(
        self,
        model_name_or_path: str | None = None,
        config: LightningIRConfig | None = None,
        model: LightningIRModel | None = None,
        BackboneModel: Type[PreTrainedModel] | None = None,
        loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None,
        evaluation_metrics: Sequence[str] | None = None,
        model_kwargs: Mapping[str, Any] | None = None,
    ):
        """Initializes the LightningIRModule.

        .. _ir-measures: https://ir-measur.es/en/latest/index.html

        Args:
            model_name_or_path (str | None): Name or path of a backbone model or fine-tuned Lightning IR model.
                Defaults to None.
            config (LightningIRConfig | None): LightningIRConfig to apply when loading from a backbone model.
                Defaults to None.
            model (LightningIRModel | None): Already instantiated Lightning IR model. Defaults to None.
            BackboneModel (Type[PreTrainedModel] | None): Hugging Face PreTrainedModel class to use as backbone
                instead of the default AutoModel. Defaults to None.
            loss_functions (Sequence[LossFunction | Tuple[LossFunction, float]] | None):
                Loss functions to apply during fine-tuning. An optional loss weight can be provided per loss
                function. Defaults to None.
            evaluation_metrics (Sequence[str] | None): Metrics corresponding to ir-measures_ measure strings
                to apply during validation or testing. Defaults to None.
            model_kwargs (Mapping[str, Any] | None): Additional keyword arguments to pass to `from_pretrained`
                when loading a model. Defaults to None.
        Raises:
            ValueError: If both model and model_name_or_path are provided.
            ValueError: If neither model nor model_name_or_path are provided.
        """
        super().__init__()
        model_kwargs = model_kwargs if model_kwargs is not None else {}
        self.save_hyperparameters()
        if model is not None and model_name_or_path is not None:
            raise ValueError("Only one of model or model_name_or_path must be provided.")
        if model is None:
            if model_name_or_path is None:
                raise ValueError("Either model or model_name_or_path must be provided.")
            model = LightningIRModel.from_pretrained(
                model_name_or_path, config=config, BackboneModel=BackboneModel, **model_kwargs
            )

        self.model: LightningIRModel = model
        self.config = self.model.config
        self.loss_functions: List[Tuple[LossFunction, float]] | None = None
        if loss_functions is not None:
            self.loss_functions = []
            for loss_function in loss_functions:
                if isinstance(loss_function, LossFunction):
                    self.loss_functions.append((loss_function, 1.0))
                else:
                    self.loss_functions.append(loss_function)
        self.evaluation_metrics = evaluation_metrics
        self._optimizer: torch.optim.Optimizer | None = None
        self.tokenizer = LightningIRTokenizer.from_pretrained(self.config.name_or_path, config=self.config)
        self._additional_log_metrics: Dict[str, float] = {}

    def on_train_start(self) -> None:
        """Called at the beginning of training after sanity check."""
        super().on_train_start()
        # NOTE huggingface models are in eval mode by default
        self.model = self.model.train()

    def on_validation_start(self) -> None:
        """Called at the beginning of validation."""
        # NOTE monkey patch result printing of the trainer
        try:
            trainer = self.trainer
        except RuntimeError:
            trainer = None
        if trainer is None:
            return

        trainer._evaluation_loop._print_results = lambda *args, **kwargs: None

    def on_test_start(self) -> None:
        """Called at the beginning of testing."""
        self.on_validation_start()
    def set_optimizer(
        self, optimizer: Type[torch.optim.Optimizer], **optimizer_kwargs: Dict[str, Any]
    ) -> "LightningIRModule":
        """Sets the optimizer for the model. Necessary for fine-tuning when not using the CLI.

        Args:
            optimizer (Type[torch.optim.Optimizer]): Torch optimizer class.
            optimizer_kwargs (Dict[str, Any]): Arguments to initialize the optimizer.
        Returns:
            LightningIRModule: Self with the optimizer set.
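
        Example (illustrative)::

            module.set_optimizer(torch.optim.AdamW, lr=1e-5)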
        """
        self._optimizer = optimizer(self.parameters(), **optimizer_kwargs)
        return self

    def score(self, queries: Sequence[str] | str, docs: Sequence[Sequence[str]] | Sequence[str]) -> LightningIROutput:
        """Computes relevance scores for queries and documents.

        Args:
            queries (Sequence[str] | str): Queries to score.
            docs (Sequence[Sequence[str]] | Sequence[str]): Documents to score.
        Returns:
            LightningIROutput: Model output containing the scores.
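
        Example (illustrative; scores one query against two documents)::

            output = module.score("what is ir?", ["ir stands for information retrieval", "ir is short for infrared"])
            print(output.scores)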
        """
        if isinstance(queries, str):
            queries = (queries,)
        if isinstance(docs[0], str):
            docs = (docs,)
        batch = RankBatch(queries, docs, None, None)
        with torch.no_grad():
            return self.forward(batch)

    def forward(self, batch: TrainBatch | RankBatch | SearchBatch) -> LightningIROutput:
        """Handles the forward pass of the model.

        Args:
            batch (TrainBatch | RankBatch | SearchBatch): Batch of training or ranking data.
        Returns:
            LightningIROutput: Model output.
        Raises:
            NotImplementedError: Must be implemented by derived class.
        """
        raise NotImplementedError
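
    # NOTE A helper method is elided from this excerpt at this point. Given the `BatchEncoding` import and
    # the tokenizer instantiated in `__init__`, it presumably tokenizes queries and documents for the model.
    # A minimal sketch under these assumptions (not the verbatim original; the `self.tokenizer.tokenize`
    # signature is assumed):
    def prepare_input(
        self, queries: Sequence[str] | None, docs: Sequence[str] | None, num_docs: Sequence[int] | int | None
    ) -> Dict[str, BatchEncoding]:
        """Tokenizes queries and documents and moves the encodings to the model's device.

        Args:
            queries (Sequence[str] | None): Queries to tokenize.
            docs (Sequence[str] | None): Documents to tokenize.
            num_docs (Sequence[int] | int | None): Number of documents per query.
        Returns:
            Dict[str, BatchEncoding]: Tokenized queries and documents.
        """
        encodings = self.tokenizer.tokenize(
            queries, docs, return_tensors="pt", padding=True, truncation=True, num_docs=num_docs
        )
        for key, encoding in encodings.items():
            encodings[key] = encoding.to(self.device)
        return encodings
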
    def _compute_losses(self, batch: TrainBatch, output: LightningIROutput) -> List[torch.Tensor]:
        """Computes the losses for a training batch."""
        raise NotImplementedError

    def training_step(self, batch: TrainBatch, batch_idx: int) -> Dict[str, Any]:
        """Handles the training step for the model.

        Args:
            batch (TrainBatch): Batch of training data.
            batch_idx (int): Index of the batch.
        Returns:
            Dict[str, Any]: Dictionary containing the total loss (sum of the individual losses weighted by
                the loss weights) under "loss" and the model output under "output".
        Raises:
            ValueError: If no loss functions are set.
        """
        if self.loss_functions is None:
            raise ValueError("Loss functions are not set")
        output = self.forward(batch)
        losses = self._compute_losses(batch, output)
        total_loss = torch.tensor(0)
        assert len(losses) == len(self.loss_functions)
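        # accumulate the weighted sum of the individual losses, logging each component separately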
        for (loss_function, loss_weight), loss in zip(self.loss_functions, losses):
            self.log(loss_function.__class__.__name__, loss)
            total_loss = total_loss + loss * loss_weight
        self.log("loss", total_loss, prog_bar=True)
        return {"loss": total_loss, "output": output}

    def validation_step(
        self, batch: TrainBatch | RankBatch | SearchBatch, batch_idx: int, dataloader_idx: int = 0
    ) -> LightningIROutput:
        """Handles the validation step for the model.

        Args:
            batch (TrainBatch | RankBatch | SearchBatch): Batch of validation or testing data.
            batch_idx (int): Index of the batch.
            dataloader_idx (int): Index of the dataloader. Defaults to 0.
        Returns:
            LightningIROutput: Model output.
        """
        output = self.forward(batch)

        if self.evaluation_metrics is None:
            return output

        dataset = self.get_dataset(dataloader_idx)
        dataset_id = str(dataloader_idx) if dataset is None else self.get_dataset_id(dataset)
        metrics = self.validate(output, batch)
        for key, value in metrics.items():
            key = f"{dataset_id}/{key}"
            self.log(key, value, batch_size=len(batch.queries))
        return output

    def test_step(
        self,
        batch: TrainBatch | RankBatch,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> LightningIROutput:
        """Handles the testing step for the model. Passes the batch to the validation step.

        Args:
            batch (TrainBatch | RankBatch): Batch of testing data.
            batch_idx (int): Index of the batch.
            dataloader_idx (int): Index of the dataloader. Defaults to 0.
        Returns:
            LightningIROutput: Model output.
        """
        return self.validation_step(batch, batch_idx, dataloader_idx)

    def get_dataset(self, dataloader_idx: int) -> IRDataset | None:
        """Gets the dataset instance from the dataloader index. Returns None if no dataset is found.

        Args:
            dataloader_idx (int): Index of the dataloader.
        Returns:
            IRDataset | None: Inference dataset or None if no dataset is found.
        """
        try:
            trainer = self.trainer
        except RuntimeError:
            trainer = None
        if trainer is None:
            return None
        STAGE_TO_DATALOADER = {
            RunningStage.VALIDATING: "val_dataloaders",
            RunningStage.TESTING: "test_dataloaders",
            RunningStage.PREDICTING: "predict_dataloaders",
            RunningStage.SANITY_CHECKING: "val_dataloaders",
        }
        if trainer.state.stage is None:
            return None
        dataloaders = getattr(trainer, STAGE_TO_DATALOADER[trainer.state.stage], None)
        if dataloaders is None:
            return None
        if isinstance(dataloaders, torch.utils.data.DataLoader):
            dataloaders = [dataloaders]
        return dataloaders[dataloader_idx].dataset

    def get_dataset_id(self, dataset: IRDataset) -> str:
        """Gets the dataset id from a dataset instance for logging.

        .. _ir-datasets: https://ir-datasets.com/

        Args:
            dataset (IRDataset): Dataset instance.
        Returns:
            str: Run file name or ir-datasets_ dataset id.
        """
        if isinstance(dataset, RunDataset) and dataset.run_path is not None:
            dataset_id = dataset.run_path.name
        else:
            dataset_id = dataset.dataset_id
        return dataset_id

    def validate(
        self,
        output: LightningIROutput,
        batch: TrainBatch | RankBatch | SearchBatch,
    ) -> Dict[str, float]:
        """Validates the model output with the evaluation metrics and loss functions.

        Args:
            output (LightningIROutput): Model output.
            batch (TrainBatch | RankBatch | SearchBatch): Batch of validation or testing data.
        Returns:
            Dict[str, float]: Dictionary of evaluation metrics.
        """
        metrics: Dict[str, float] = {}
        if self.evaluation_metrics is None or output.scores is None:
            return metrics
        metrics.update(self.validate_metrics(output, batch))
        metrics.update(self.validate_loss(output, batch))
        return metrics

    def validate_metrics(
        self,
        output: LightningIROutput,
        batch: TrainBatch | RankBatch | SearchBatch,
    ) -> Dict[str, float]:
        """Validates the model output with the evaluation metrics.

        Args:
            output (LightningIROutput): Model output.
            batch (TrainBatch | RankBatch | SearchBatch): Batch of validation or testing data.
        Returns:
            Dict[str, float]: Dictionary of evaluation metrics.
        Raises:
            ValueError: If query_ids or doc_ids are not set in the batch.
        """
        metrics: Dict[str, float] = {}
        qrels = batch.qrels
        if self.evaluation_metrics is None or qrels is None:
            return metrics
        query_ids = batch.query_ids
        doc_ids = batch.doc_ids
        if query_ids is None:
            raise ValueError("query_ids must be set")
        if doc_ids is None:
            raise ValueError("doc_ids must be set")
        evaluation_metrics = [metric for metric in self.evaluation_metrics if metric != "loss"]
        ir_measures_qrels = create_qrels_from_dicts(qrels)
        if evaluation_metrics and output.scores is not None:
            run = create_run_from_scores(query_ids, doc_ids, output.scores)
            metrics.update(evaluate_run(run, ir_measures_qrels, evaluation_metrics))
        return metrics

    def validate_loss(
        self,
        output: LightningIROutput,
        batch: TrainBatch | RankBatch | SearchBatch,
    ) -> Dict[str, float]:
        """Validates the model output with the loss functions.

        Args:
            output (LightningIROutput): Model output.
            batch (TrainBatch | RankBatch | SearchBatch): Batch of validation or testing data.
        Returns:
            Dict[str, float]: Dictionary of loss metrics.
        Raises:
            ValueError: If query_ids are not set in the batch.
        """
        metrics: Dict[str, float] = {}
        query_ids = batch.query_ids
        if query_ids is None:
            raise ValueError("query_ids must be set")
        if (
            self.evaluation_metrics is None
            or "loss" not in self.evaluation_metrics
            or getattr(batch, "targets", None) is None
            or self.loss_functions is None
            or output.scores is None
        ):
            return metrics
        output.scores = output.scores.view(len(query_ids), -1)
        for loss_function, _ in self.loss_functions:
            # NOTE skip in-batch losses because they can use a lot of memory
            if isinstance(loss_function, InBatchLossFunction):
                continue
            metrics[f"validation-{loss_function.__class__.__name__}"] = loss_function.compute_loss(output, batch)
        return metrics

    def on_validation_end(self) -> None:
        """Prints the validation results for each dataloader."""
        trainer = self.trainer
        if not (trainer.is_global_zero and trainer._evaluation_loop.verbose):
            return
        results = trainer.callback_metrics

        data = []
        for key, value in {**results, **self._additional_log_metrics}.items():
            if "dataloader_idx" in key:
                key = "/".join(key.split("/")[:-1])
            *dataset_parts, metric = key.split("/")
            if metric.startswith("validation-"):
                metric = metric[len("validation-") :]
            dataset = "/".join(dataset_parts)
            if isinstance(value, torch.Tensor):
                value = value.item()
            data.append({"dataset": dataset, "metric": metric, "value": value})
        if not data:
            return
        df = pd.DataFrame(data)
        df = df.pivot(index="dataset", columns="metric", values="value")
        df.columns.name = None

        # bring the rows into the correct order when inference datasets are skipped
        datamodule = getattr(self.trainer, "datamodule", None)
        if datamodule is not None and hasattr(datamodule, "inference_datasets"):
            inference_datasets = datamodule.inference_datasets
            if len(inference_datasets) != df.shape[0]:
                raise ValueError(
                    "Number of inference datasets does not match number of dataloaders. "
                    "Check if the dataloaders are correctly configured."
                )
            dataset_ids = [self.get_dataset_id(dataset) for dataset in inference_datasets]
            df = df.reindex(dataset_ids)

        trainer.print(df)

    def on_test_end(self) -> None:
        """Prints the accumulated metrics for each dataloader."""
        self.on_validation_end()

    def save_pretrained(self, save_path: str | Path) -> None:
        """Saves the model and tokenizer to the save path.

        Args:
            save_path (str | Path): Path to save the model and tokenizer.
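
        Example (illustrative)::

            module.save_pretrained("models/my-finetuned-model")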
        """
        self.model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Saves the model and tokenizer to the trainer's log directory."""
        if self.trainer is not None and self.trainer.log_dir is not None:
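            # only the global rank zero process writes the checkpoint to avoid concurrent writes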
            if self.trainer.global_rank != 0:
                return
            _step = self.trainer.global_step
            self.config.save_step = _step
            log_dir = Path(self.trainer.log_dir)
            save_path = log_dir / "huggingface_checkpoint"
            self.save_pretrained(save_path)