Source code for lightning_ir.base.module

  1"""LightningModule for Lightning IR.
  2
  3This module contains the main module class deriving from a LightningModule_.
  4
  5.. _LightningModule: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html
  6"""
  7
  8from collections.abc import Mapping, Sequence
  9from pathlib import Path
 10from typing import Any
 11
 12import pandas as pd
 13import torch
 14from lightning import LightningModule
 15from lightning.pytorch.trainer.states import RunningStage
 16from transformers import BatchEncoding, PreTrainedModel
 17
 18from ..data import IRDataset, RankBatch, RunDataset, SearchBatch, TrainBatch
 19from ..loss import InBatchLossFunction, LossFunction
 20from .config import LightningIRConfig
 21from .model import LightningIRModel, LightningIROutput
 22from .tokenizer import LightningIRTokenizer
 23from .validation_utils import create_qrels_from_dicts, create_run_from_scores, evaluate_run
 24
 25
class LightningIRModule(LightningModule):
    """LightningIRModule base class. It derives from a LightningModule_. LightningIRModules contain a
    LightningIRModel and a LightningIRTokenizer and implement the training, validation, and testing steps for the
    model. Derived classes must implement the forward method for the model.

    .. _LightningModule: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html
    """

    def __init__(
        self,
        model_name_or_path: str | None = None,
        config: LightningIRConfig | None = None,
        model: LightningIRModel | None = None,
        BackboneModel: type[PreTrainedModel] | None = None,
        loss_functions: Sequence[LossFunction | tuple[LossFunction, float]] | None = None,
        evaluation_metrics: Sequence[str] | None = None,
        model_kwargs: Mapping[str, Any] | None = None,
    ):
        """Initializes the LightningIRModule.

        .. _ir-measures: https://ir-measur.es/en/latest/index.html

        Args:
            model_name_or_path (str | None): Name or path of backbone model or fine-tuned Lightning IR model.
                Defaults to None.
            config (LightningIRConfig | None): LightningIRConfig to apply when loading from backbone model.
                Defaults to None.
            model (LightningIRModel | None): Already instantiated Lightning IR model. Defaults to None.
            BackboneModel (type[PreTrainedModel] | None): Huggingface PreTrainedModel class to use as backbone
                instead of the default AutoModel. Defaults to None.
            loss_functions (Sequence[LossFunction | tuple[LossFunction, float]] | None):
                Loss functions to apply during fine-tuning, optional loss weights can be provided per loss function.
                Loss functions given without a weight are weighted with 1.0. Defaults to None.
            evaluation_metrics (Sequence[str] | None): Metrics corresponding to ir-measures_ measure strings
                to apply during validation or testing. The special value "loss" additionally computes the
                validation loss with the (non in-batch) loss functions. Defaults to None.
            model_kwargs (Mapping[str, Any] | None): Additional keyword arguments to pass to `from_pretrained` when
                loading a model. Defaults to None.
        Raises:
            ValueError: If both model and model_name_or_path are provided.
            ValueError: If neither model nor model_name_or_path are provided.
        """
        super().__init__()
        # NOTE keep this ordering: save_hyperparameters records the __init__ arguments, so the normalized
        # model_kwargs is what ends up in the saved hyperparameters
        model_kwargs = model_kwargs if model_kwargs is not None else {}
        self.save_hyperparameters()
        if model is not None and model_name_or_path is not None:
            raise ValueError("Only one of model or model_name_or_path must be provided.")
        if model is None:
            if model_name_or_path is None:
                raise ValueError("Either model or model_name_or_path must be provided.")
            model = LightningIRModel.from_pretrained(
                model_name_or_path, config=config, BackboneModel=BackboneModel, **model_kwargs
            )

        self.model: LightningIRModel = model
        self.config = self.model.config
        self.loss_functions: list[tuple[LossFunction, float]] | None = None
        if loss_functions is not None:
            # normalize to (loss_function, weight) pairs; bare loss functions get weight 1.0
            self.loss_functions = [
                (loss_function, 1.0) if isinstance(loss_function, LossFunction) else loss_function
                for loss_function in loss_functions
            ]
        self.evaluation_metrics = evaluation_metrics
        self._optimizer: torch.optim.Optimizer | None = None
        self.tokenizer = LightningIRTokenizer.from_pretrained(self.config.name_or_path, config=self.config)
        # extra metrics that are merged into the results table printed in on_validation_end
        self._additional_log_metrics: dict[str, float] = {}

    def on_train_start(self) -> None:
        """Called at the beginning of training after sanity check."""
        super().on_train_start()
        # NOTE huggingface models are in eval mode by default
        self.model = self.model.train()

    def _get_trainer_or_none(self) -> Any:
        """Returns the attached trainer, or None if the module is not attached to a trainer."""
        # NOTE accessing `self.trainer` raises a RuntimeError when no trainer is attached
        try:
            return self.trainer
        except RuntimeError:
            return None

    def on_validation_start(self) -> None:
        """Called at the beginning of validation.

        Disables the trainer's built-in result printing; the results table is printed by
        :meth:`on_validation_end` instead.
        """
        trainer = self._get_trainer_or_none()
        if trainer is None:
            return
        # NOTE monkey patch result printing of the trainer (relies on the private `_evaluation_loop` attribute)
        trainer._evaluation_loop._print_results = lambda *args, **kwargs: None

    def on_test_start(self) -> None:
        """Called at the beginning of testing. Disables the trainer's result printing like validation."""
        self.on_validation_start()

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Configures the optimizer for fine-tuning. This method is ignored when using the CLI. When using Lightning IR
        programmatically, the optimizer must be set using :meth:`set_optimizer`.

        Returns:
            torch.optim.Optimizer: The optimizer set for the model.
        Raises:
            ValueError: If optimizer is not set. Call `set_optimizer`.
        """
        if self._optimizer is None:
            raise ValueError("Optimizer is not set. Call `set_optimizer`.")
        return self._optimizer

    def set_optimizer(self, optimizer: type[torch.optim.Optimizer], **optimizer_kwargs: Any) -> "LightningIRModule":
        """Sets the optimizer for the model. Necessary for fine-tuning when not using the CLI.

        Args:
            optimizer (type[torch.optim.Optimizer]): Torch optimizer class.
            optimizer_kwargs (Any): Keyword arguments to initialize the optimizer (e.g. ``lr``).
        Returns:
            LightningIRModule: Self with the optimizer set.
        """
        self._optimizer = optimizer(self.parameters(), **optimizer_kwargs)
        return self

    def score(self, queries: Sequence[str] | str, docs: Sequence[Sequence[str]] | Sequence[str]) -> LightningIROutput:
        """Computes relevance scores for queries and documents.

        A single query string is wrapped into a one-element sequence, and a flat sequence of document strings is
        treated as the documents of a single query.

        Args:
            queries (Sequence[str] | str): Queries to score.
            docs (Sequence[Sequence[str]] | Sequence[str]): Documents to score.
        Returns:
            LightningIROutput: Model output containing the scores.
        """
        if isinstance(queries, str):
            queries = (queries,)
        if isinstance(docs[0], str):
            docs = (docs,)
        batch = RankBatch(queries, docs, None, None)
        with torch.no_grad():
            return self.forward(batch)

    def forward(self, batch: TrainBatch | RankBatch | SearchBatch) -> LightningIROutput:
        """Handles the forward pass of the model.

        Args:
            batch (TrainBatch | RankBatch | SearchBatch): Batch of training or ranking data.
        Returns:
            LightningIROutput: Model output.
        Raises:
            NotImplementedError: Must be implemented by derived class.
        """
        raise NotImplementedError

    def prepare_input(
        self,
        queries: Sequence[str] | None,
        docs: Sequence[str] | None,
        num_docs: Sequence[int] | int | None = None,
    ) -> dict[str, BatchEncoding]:
        """Tokenizes queries and documents and returns the tokenized BatchEncoding_.

        .. _BatchEncoding: https://huggingface.co/transformers/main_classes/tokenizer#transformers.BatchEncoding

        Args:
            queries (Sequence[str] | None): Queries to tokenize.
            docs (Sequence[str] | None): Documents to tokenize.
            num_docs (Sequence[int] | int | None): Number of documents per query, if None num_docs is inferred by
                `len(docs) // len(queries)`. Defaults to None.
        Returns:
            dict[str, BatchEncoding]: Tokenized queries and documents moved to the module's device, format depends
                on the tokenizer.
        """
        encodings = self.tokenizer.tokenize(
            queries, docs, return_tensors="pt", padding=True, truncation=True, num_docs=num_docs
        )
        for key in encodings:
            encodings[key] = encodings[key].to(self.device)
        return encodings

    def _compute_losses(self, batch: TrainBatch, output: LightningIROutput) -> list[torch.Tensor]:
        """Computes the losses for a training batch.

        Must be implemented by derived classes and return one loss per entry of ``self.loss_functions``, in the
        same order.
        """
        raise NotImplementedError

    def training_step(self, batch: TrainBatch, batch_idx: int) -> dict[str, Any]:
        """Handles the training step for the model.

        Every individual loss is logged under its loss function's class name and the weighted total under "loss".

        Args:
            batch (TrainBatch): Batch of training data.
            batch_idx (int): Index of the batch.
        Returns:
            dict[str, Any]: ``{"loss": total_loss, "output": output}`` where ``total_loss`` is the sum of the
                losses weighted by the loss weights and ``output`` is the model output.
        Raises:
            ValueError: If no loss functions are set.
        """
        if self.loss_functions is None:
            raise ValueError("Loss functions are not set")
        output = self.forward(batch)
        losses = self._compute_losses(batch, output)
        # _compute_losses must return exactly one loss per configured loss function
        assert len(losses) == len(self.loss_functions)
        total_loss = torch.tensor(0)
        for (loss_function, loss_weight), loss in zip(self.loss_functions, losses):
            self.log(loss_function.__class__.__name__, loss)
            total_loss = total_loss + loss * loss_weight
        self.log("loss", total_loss, prog_bar=True)
        return {"loss": total_loss, "output": output}

    def validation_step(
        self, batch: TrainBatch | RankBatch | SearchBatch, batch_idx: int, dataloader_idx: int = 0
    ) -> LightningIROutput:
        """Handles the validation step for the model.

        Metrics are logged as ``"<dataset_id>/<metric>"``, where the dataset id falls back to the dataloader index
        when no dataset can be resolved.

        Args:
            batch (TrainBatch | RankBatch | SearchBatch): Batch of validation or testing data.
            batch_idx (int): Index of the batch.
            dataloader_idx (int): Index of the dataloader. Defaults to 0.
        Returns:
            LightningIROutput: Model output.
        """
        output = self.forward(batch)

        if self.evaluation_metrics is None:
            return output

        dataset = self.get_dataset(dataloader_idx)
        dataset_id = str(dataloader_idx) if dataset is None else self.get_dataset_id(dataset)
        metrics = self.validate(output, batch)
        for key, value in metrics.items():
            self.log(f"{dataset_id}/{key}", value, batch_size=len(batch.queries))
        return output

    def test_step(
        self,
        batch: TrainBatch | RankBatch,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> LightningIROutput:
        """Handles the testing step for the model. Passes the batch to the validation step.

        Args:
            batch (TrainBatch | RankBatch): Batch of testing data.
            batch_idx (int): Index of the batch.
            dataloader_idx (int): Index of the dataloader. Defaults to 0.
        Returns:
            LightningIROutput: Model output.
        """
        return self.validation_step(batch, batch_idx, dataloader_idx)

    def get_dataset(self, dataloader_idx: int) -> IRDataset | None:
        """Gets the dataset instance from the dataloader index. Returns None if no dataset is found.

        Args:
            dataloader_idx (int): Index of the dataloader.
        Returns:
            IRDataset | None: Inference dataset or None if no dataset is found (no trainer attached, no running
                stage, a stage without evaluation dataloaders such as training, or no dataloaders set).
        """
        trainer = self._get_trainer_or_none()
        if trainer is None:
            return None
        stage_to_dataloader = {
            RunningStage.VALIDATING: "val_dataloaders",
            RunningStage.TESTING: "test_dataloaders",
            RunningStage.PREDICTING: "predict_dataloaders",
            RunningStage.SANITY_CHECKING: "val_dataloaders",
        }
        # .get instead of indexing: unmapped stages (and None) must yield None rather than a KeyError
        dataloader_attribute = stage_to_dataloader.get(trainer.state.stage)
        if dataloader_attribute is None:
            return None
        dataloaders = getattr(trainer, dataloader_attribute, None)
        if dataloaders is None:
            return None
        if isinstance(dataloaders, torch.utils.data.DataLoader):
            dataloaders = [dataloaders]
        return dataloaders[dataloader_idx].dataset

    def get_dataset_id(self, dataset: IRDataset) -> str:
        """Gets the dataset id of a dataset for logging.

        .. _ir-datasets: https://ir-datasets.com/

        Args:
            dataset (IRDataset): Dataset instance.
        Returns:
            str: Run file name for a RunDataset with a run path, otherwise the ir-datasets_ dataset id.
        """
        if isinstance(dataset, RunDataset) and dataset.run_path is not None:
            return dataset.run_path.name
        return dataset.dataset_id

    def validate(
        self,
        output: LightningIROutput,
        batch: TrainBatch | RankBatch | SearchBatch,
    ) -> dict[str, float]:
        """Validates the model output with the evaluation metrics and loss functions.

        Args:
            output (LightningIROutput): Model output.
            batch (TrainBatch | RankBatch | SearchBatch): Batch of validation or testing data.
        Returns:
            dict[str, float]: Dictionary of evaluation metrics; empty if no evaluation metrics are set or the
                output has no scores.
        """
        metrics: dict[str, float] = {}
        if self.evaluation_metrics is None or output.scores is None:
            return metrics
        metrics.update(self.validate_metrics(output, batch))
        metrics.update(self.validate_loss(output, batch))
        return metrics

    def validate_metrics(
        self,
        output: LightningIROutput,
        batch: TrainBatch | RankBatch | SearchBatch,
    ) -> dict[str, float]:
        """Validates the model output with the evaluation metrics.

        The special metric "loss" is skipped here; it is handled by :meth:`validate_loss`.

        Args:
            output (LightningIROutput): Model output.
            batch (TrainBatch | RankBatch | SearchBatch): Batch of validation or testing data.
        Returns:
            dict[str, float]: Dictionary of evaluation metrics.
        Raises:
            ValueError: If the batch has qrels but query_ids or doc_ids are not set.
        """
        metrics: dict[str, float] = {}
        qrels = batch.qrels
        if self.evaluation_metrics is None or qrels is None:
            return metrics
        query_ids = batch.query_ids
        doc_ids = batch.doc_ids
        if query_ids is None:
            raise ValueError("query_ids must be set")
        if doc_ids is None:
            raise ValueError("doc_ids must be set")
        evaluation_metrics = [metric for metric in self.evaluation_metrics if metric != "loss"]
        # only build the ir-measures qrels when there is something to evaluate
        if not evaluation_metrics or output.scores is None:
            return metrics
        ir_measures_qrels = create_qrels_from_dicts(qrels)
        run = create_run_from_scores(query_ids, doc_ids, output.scores)
        metrics.update(evaluate_run(run, ir_measures_qrels, evaluation_metrics))
        return metrics

    def validate_loss(
        self,
        output: LightningIROutput,
        batch: TrainBatch | RankBatch | SearchBatch,
    ) -> dict[str, float]:
        """Validates the model output with the loss functions.

        Only runs when "loss" is among the evaluation metrics, the batch has targets, loss functions are set, and
        the output has scores. In-batch loss functions are skipped.

        Args:
            output (LightningIROutput): Model output. NOTE: its ``scores`` are reshaped in place to
                ``(len(query_ids), -1)`` when the loss is computed.
            batch (TrainBatch | RankBatch | SearchBatch): Batch of validation or testing data.
        Returns:
            dict[str, float]: Dictionary mapping ``"validation-<LossFunctionClassName>"`` to the loss value.
        Raises:
            ValueError: If loss validation is requested but query_ids are not set in the batch.
        """
        metrics: dict[str, float] = {}
        if (
            self.evaluation_metrics is None
            or "loss" not in self.evaluation_metrics
            or getattr(batch, "targets", None) is None
            or self.loss_functions is None
            or output.scores is None
        ):
            return metrics
        # checked only after the guard so batches without query ids do not fail when no loss is requested
        query_ids = batch.query_ids
        if query_ids is None:
            raise ValueError("query_ids must be set")
        # one row of scores per query (modifies the caller's output object)
        output.scores = output.scores.view(len(query_ids), -1)
        for loss_function, _ in self.loss_functions:
            # NOTE skip in-batch losses because they can use a lot of memory
            if isinstance(loss_function, InBatchLossFunction):
                continue
            metrics[f"validation-{loss_function.__class__.__name__}"] = loss_function.compute_loss(output, batch)
        return metrics

    def on_validation_end(self) -> None:
        """Prints the validation results for each dataloader as a dataset x metric table.

        Only prints on the global zero process and when the evaluation loop is verbose.

        Raises:
            ValueError: If the datamodule's inference datasets do not match the number of result rows.
        """
        trainer = self.trainer
        if not (trainer.is_global_zero and trainer._evaluation_loop.verbose):
            return
        results = trainer.callback_metrics

        # NOTE(review): callback_metrics may also hold keys logged outside validation (e.g. training losses
        # without a "<dataset>/" prefix); those end up in a row with an empty dataset name — confirm intended.
        data = []
        for key, value in {**results, **self._additional_log_metrics}.items():
            if "dataloader_idx" in key:
                # drop the last path component (presumably Lightning's "dataloader_idx_<n>" suffix)
                key = key.rpartition("/")[0]
            *dataset_parts, metric = key.split("/")
            metric = metric.removeprefix("validation-")
            if isinstance(value, torch.Tensor):
                value = value.item()
            data.append({"dataset": "/".join(dataset_parts), "metric": metric, "value": value})
        if not data:
            return
        df = pd.DataFrame(data)
        df = df.pivot(index="dataset", columns="metric", values="value")
        df.columns.name = None

        # bring into correct order when skipping inference datasets
        datamodule = getattr(trainer, "datamodule", None)
        if datamodule is not None and hasattr(datamodule, "inference_datasets"):
            inference_datasets = datamodule.inference_datasets
            if len(inference_datasets) != df.shape[0]:
                raise ValueError(
                    "Number of inference datasets does not match number of dataloaders. "
                    "Check if the dataloaders are correctly configured."
                )
            dataset_ids = [self.get_dataset_id(dataset) for dataset in inference_datasets]
            df = df.reindex(dataset_ids)

        trainer.print(df)

    def on_test_end(self) -> None:
        """Prints the accumulated metrics for each dataloader."""
        self.on_validation_end()

    def save_pretrained(self, save_path: str | Path) -> None:
        """Saves the model and tokenizer to the save path.

        Args:
            save_path (str | Path): Path to save the model and tokenizer.
        """
        self.model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)

    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Saves the model and tokenizer to ``<log_dir>/huggingface_checkpoint`` alongside the Lightning checkpoint.

        Only the rank-0 process writes; the current global step is stored in ``config.save_step`` first.

        Args:
            checkpoint (dict[str, Any]): Lightning checkpoint dictionary (unused).
        """
        if self.trainer is not None and self.trainer.log_dir is not None:
            if self.trainer.global_rank != 0:
                return
            self.config.save_step = self.trainer.global_step
            save_path = Path(self.trainer.log_dir) / "huggingface_checkpoint"
            self.save_pretrained(save_path)