Source code for lightning_ir.base.module

  1"""LightningModule for Lightning IR.
  2
  3This module contains the main module class deriving from a LightningModule_.
  4
  5.. _LightningModule: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html
  6"""
  7
  8from pathlib import Path
  9from typing import Any, Dict, List, Mapping, Sequence, Tuple, Type
 10
 11import pandas as pd
 12import torch
 13from lightning import LightningModule
 14from lightning.pytorch.trainer.states import RunningStage
 15from transformers import BatchEncoding
 16
 17from ..data import IRDataset, RankBatch, RunDataset, SearchBatch, TrainBatch
 18from ..loss.loss import InBatchLossFunction, LossFunction
 19from .config import LightningIRConfig
 20from .model import LightningIRModel, LightningIROutput
 21from .tokenizer import LightningIRTokenizer
 22from .validation_utils import create_qrels_from_dicts, create_run_from_scores, evaluate_run
 23
 24
class LightningIRModule(LightningModule):
    """LightningIRModule base class. It derives from a LightningModule_. LightningIRModules contain a
    LightningIRModel and a LightningIRTokenizer and implement the training, validation, and testing steps for the
    model. Derived classes must implement the forward method for the model.

    .. _LightningModule: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html
    """

    def __init__(
        self,
        model_name_or_path: str | None = None,
        config: LightningIRConfig | None = None,
        model: LightningIRModel | None = None,
        loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None,
        evaluation_metrics: Sequence[str] | None = None,
        model_kwargs: Mapping[str, Any] | None = None,
    ):
        """Initializes the LightningIRModule.

        .. _ir-measures: https://ir-measur.es/en/latest/index.html

        :param model_name_or_path: Name or path of backbone model or fine-tuned Lightning IR model, defaults to None
        :type model_name_or_path: str | None, optional
        :param config: LightningIRConfig to apply when loading from backbone model, defaults to None
        :type config: LightningIRConfig | None, optional
        :param model: Already instantiated Lightning IR model, defaults to None
        :type model: LightningIRModel | None, optional
        :param loss_functions: Loss functions to apply during fine-tuning, optional loss weights can be provided per
            loss function, defaults to None
        :type loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None, optional
        :param evaluation_metrics: Metrics corresponding to ir-measures_ measure strings to apply during validation or
            testing, defaults to None
        :type evaluation_metrics: Sequence[str] | None, optional
        :param model_kwargs: Additional keyword arguments to pass to `from_pretrained` when loading a model,
            defaults to None
        :type model_kwargs: Mapping[str, Any] | None, optional
        :raises ValueError: If both model and model_name_or_path are provided
        :raises ValueError: If neither model nor model_name_or_path is provided
        """
        super().__init__()
        model_kwargs = model_kwargs if model_kwargs is not None else {}
        self.save_hyperparameters()
        if model is not None and model_name_or_path is not None:
            raise ValueError("Only one of model or model_name_or_path must be provided.")
        if model is None:
            if model_name_or_path is None:
                raise ValueError("Either model or model_name_or_path must be provided.")
            model = LightningIRModel.from_pretrained(model_name_or_path, config=config, **model_kwargs)

        self.model: LightningIRModel = model
        self.config = self.model.config
        self.loss_functions: List[Tuple[LossFunction, float]] | None = None
        if loss_functions is not None:
            self.loss_functions = []
            for loss_function in loss_functions:
                if isinstance(loss_function, LossFunction):
                    self.loss_functions.append((loss_function, 1.0))
                else:
                    self.loss_functions.append(loss_function)
        self.evaluation_metrics = evaluation_metrics
        self._optimizer: torch.optim.Optimizer | None = None
        self.tokenizer = LightningIRTokenizer.from_pretrained(self.config.name_or_path, config=self.config)
        self._additional_log_metrics: Dict[str, float] = {}
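
    # Example (a minimal sketch, not part of the class): constructing a module from a
    # backbone checkpoint, with evaluation metrics given as ir-measures measure strings.
    # `BiEncoderModule` stands in for any concrete subclass that implements `forward`;
    # the checkpoint and metric names are illustrative assumptions.
    #
    #     module = BiEncoderModule(
    #         model_name_or_path="bert-base-uncased",
    #         evaluation_metrics=["nDCG@10", "RR@10"],
    #     )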

    def on_train_start(self) -> None:
        """Called at the beginning of training after sanity check."""
        super().on_train_start()
        # NOTE huggingface models are in eval mode by default
        self.model = self.model.train()

    def on_validation_start(self) -> None:
        """Called at the beginning of validation."""
        # NOTE monkey patch result printing of the trainer
        try:
            trainer = self.trainer
        except RuntimeError:
            trainer = None
        if trainer is None:
            return

        trainer._evaluation_loop._print_results = lambda *args, **kwargs: None

    def on_test_start(self) -> None:
        """Called at the beginning of testing."""
        self.on_validation_start()

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Configures the optimizer for fine-tuning. This method is ignored when using the CLI. When using
        Lightning IR programmatically, the optimizer must be set using :meth:`set_optimizer`.

        :raises ValueError: If optimizer is not set
        :return: Optimizer
        :rtype: torch.optim.Optimizer
        """
        if self._optimizer is None:
            raise ValueError("Optimizer is not set. Call `set_optimizer`.")
        return self._optimizer

    def set_optimizer(
        self, optimizer: Type[torch.optim.Optimizer], **optimizer_kwargs: Any
    ) -> "LightningIRModule":
        """Sets the optimizer for the model. Necessary for fine-tuning when not using the CLI.

        :param optimizer: Torch optimizer class
        :type optimizer: Type[torch.optim.Optimizer]
        :param optimizer_kwargs: Arguments to initialize the optimizer
        :type optimizer_kwargs: Any
        :return: self
        :rtype: LightningIRModule
        """
        self._optimizer = optimizer(self.parameters(), **optimizer_kwargs)
        return self
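
    # Sketch: when fine-tuning programmatically (outside the CLI), pass the optimizer
    # class and its keyword arguments; the hyperparameter values are illustrative.
    #
    #     module.set_optimizer(torch.optim.AdamW, lr=1e-5, weight_decay=0.01)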

    def score(self, queries: Sequence[str] | str, docs: Sequence[Sequence[str]] | Sequence[str]) -> LightningIROutput:
        """Computes relevance scores for queries and documents.

        :param queries: Query or queries to score
        :type queries: Sequence[str] | str
        :param docs: Documents to score, either one sequence of documents per query or a single sequence of documents
            for a single query
        :type docs: Sequence[Sequence[str]] | Sequence[str]
        :return: Model output
        :rtype: LightningIROutput
        """
        if isinstance(queries, str):
            queries = (queries,)
        if isinstance(docs[0], str):
            docs = (docs,)
        batch = RankBatch(queries, docs, None, None)
        with torch.no_grad():
            return self.forward(batch)
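
    # Sketch: scoring a single query against multiple documents. The strings are
    # illustrative; relevance scores are returned on `output.scores`.
    #
    #     output = module.score(
    #         "What is the capital of France?",
    #         ["Paris is the capital of France.", "Berlin is the capital of Germany."],
    #     )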

    def forward(self, batch: TrainBatch | RankBatch | SearchBatch) -> LightningIROutput:
        """Handles the forward pass of the model.

        :param batch: Batch of training, ranking, or search data
        :type batch: TrainBatch | RankBatch | SearchBatch
        :raises NotImplementedError: Must be implemented by derived class
        :return: Model output
        :rtype: LightningIROutput
        """
        raise NotImplementedError

    def prepare_input(
        self, queries: Sequence[str] | None, docs: Sequence[str] | None, num_docs: Sequence[int] | int | None
    ) -> Dict[str, BatchEncoding]:
        """Tokenizes queries and documents and returns the tokenized BatchEncoding_.

        .. _BatchEncoding: https://huggingface.co/transformers/main_classes/tokenizer#transformers.BatchEncoding

        :param queries: Queries to tokenize
        :type queries: Sequence[str] | None
        :param docs: Documents to tokenize
        :type docs: Sequence[str] | None
        :param num_docs: Number of documents per query; if None, num_docs is inferred as `len(docs) // len(queries)`,
            defaults to None
        :type num_docs: Sequence[int] | int | None
        :return: Tokenized queries and documents, format depends on the tokenizer
        :rtype: Dict[str, BatchEncoding]
        """
        encodings = self.tokenizer.tokenize(
            queries, docs, return_tensors="pt", padding=True, truncation=True, num_docs=num_docs
        )
        for key in encodings:
            encodings[key] = encodings[key].to(self.device)
        return encodings
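
    # Sketch: tokenizing two documents for a single query. With num_docs=None, the
    # number of documents per query is inferred as len(docs) // len(queries).
    #
    #     encodings = module.prepare_input(
    #         queries=["capital of France"],
    #         docs=["Paris is the capital of France.", "Berlin is the capital of Germany."],
    #         num_docs=None,
    #     )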

    def _compute_losses(self, batch: TrainBatch, output: LightningIROutput) -> List[torch.Tensor]:
        """Computes the losses for a training batch."""
        raise NotImplementedError
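
    # A sketch of how a derived class might implement _compute_losses, assuming each
    # configured loss function exposes the compute_loss(output, batch) interface used
    # in validate_loss below (in-batch losses may require additional handling):
    #
    #     def _compute_losses(self, batch, output):
    #         assert self.loss_functions is not None
    #         return [lf.compute_loss(output, batch) for lf, _ in self.loss_functions]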

    def training_step(self, batch: TrainBatch, batch_idx: int) -> torch.Tensor:
        """Handles the training step for the model.

        :param batch: Batch of training data
        :type batch: TrainBatch
        :param batch_idx: Index of the batch
        :type batch_idx: int
        :raises ValueError: If no loss functions are set
        :return: Sum of the losses weighted by the loss weights
        :rtype: torch.Tensor
        """
        if self.loss_functions is None:
            raise ValueError("Loss functions are not set")
        output = self.forward(batch)
        losses = self._compute_losses(batch, output)
        total_loss = torch.tensor(0)
        assert len(losses) == len(self.loss_functions)
        for (loss_function, loss_weight), loss in zip(self.loss_functions, losses):
            self.log(loss_function.__class__.__name__, loss)
            total_loss = total_loss + loss * loss_weight
        self.log("loss", total_loss, prog_bar=True)
        return total_loss

    def validation_step(
        self, batch: TrainBatch | RankBatch | SearchBatch, batch_idx: int, dataloader_idx: int = 0
    ) -> LightningIROutput:
        """Handles the validation step for the model.

        :param batch: Batch of validation or testing data
        :type batch: TrainBatch | RankBatch | SearchBatch
        :param batch_idx: Index of the batch
        :type batch_idx: int
        :param dataloader_idx: Index of the dataloader, defaults to 0
        :type dataloader_idx: int, optional
        :return: Model output
        :rtype: LightningIROutput
        """
        output = self.forward(batch)

        if self.evaluation_metrics is None:
            return output

        dataset = self.get_dataset(dataloader_idx)
        dataset_id = str(dataloader_idx) if dataset is None else self.get_dataset_id(dataset)
        metrics = self.validate(output, batch)
        for key, value in metrics.items():
            key = f"{dataset_id}/{key}"
            self.log(key, value, batch_size=len(batch.queries))
        return output

    def test_step(
        self,
        batch: TrainBatch | RankBatch,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> LightningIROutput:
        """Handles the testing step for the model. Passes the batch to the validation step.

        :param batch: Batch of testing data
        :type batch: TrainBatch | RankBatch
        :param batch_idx: Index of the batch
        :type batch_idx: int
        :param dataloader_idx: Index of the dataloader, defaults to 0
        :type dataloader_idx: int, optional
        :return: Model output
        :rtype: LightningIROutput
        """
        return self.validation_step(batch, batch_idx, dataloader_idx)

    def get_dataset(self, dataloader_idx: int) -> IRDataset | None:
        """Gets the dataset instance from the dataloader index. Returns None if no dataset is found.

        :param dataloader_idx: Index of the dataloader
        :type dataloader_idx: int
        :return: Inference dataset
        :rtype: IRDataset | None
        """
        try:
            trainer = self.trainer
        except RuntimeError:
            trainer = None
        if trainer is None:
            return None
        STAGE_TO_DATALOADER = {
            RunningStage.VALIDATING: "val_dataloaders",
            RunningStage.TESTING: "test_dataloaders",
            RunningStage.PREDICTING: "predict_dataloaders",
            RunningStage.SANITY_CHECKING: "val_dataloaders",
        }
        if trainer.state.stage is None:
            return None
        dataloaders = getattr(trainer, STAGE_TO_DATALOADER[trainer.state.stage], None)
        if dataloaders is None:
            return None
        if isinstance(dataloaders, torch.utils.data.DataLoader):
            dataloaders = [dataloaders]
        return dataloaders[dataloader_idx].dataset

    def get_dataset_id(self, dataset: IRDataset) -> str:
        """Gets the dataset id from a dataset for logging.

        .. _ir-datasets: https://ir-datasets.com/

        :param dataset: Inference dataset
        :type dataset: IRDataset
        :return: Run file name or ir-datasets_ dataset id
        :rtype: str
        """
        if isinstance(dataset, RunDataset) and dataset.run_path is not None:
            dataset_id = dataset.run_path.name
        else:
            dataset_id = dataset.dataset_id
        return dataset_id

    def validate(
        self,
        output: LightningIROutput,
        batch: TrainBatch | RankBatch | SearchBatch,
    ) -> Dict[str, float]:
        """Validates the model output with the evaluation metrics and loss functions.

        :param output: Model output
        :type output: LightningIROutput
        :param batch: Batch of validation or testing data
        :type batch: TrainBatch | RankBatch | SearchBatch
        :return: Dictionary of evaluation metrics
        :rtype: Dict[str, float]
        """
        metrics: Dict[str, float] = {}
        if self.evaluation_metrics is None or output.scores is None:
            return metrics
        metrics.update(self.validate_metrics(output, batch))
        metrics.update(self.validate_loss(output, batch))
        return metrics

    def validate_metrics(
        self,
        output: LightningIROutput,
        batch: TrainBatch | RankBatch | SearchBatch,
    ) -> Dict[str, float]:
        """Validates the model output with the evaluation metrics.

        :param output: Model output
        :type output: LightningIROutput
        :param batch: Batch of validation or testing data
        :type batch: TrainBatch | RankBatch | SearchBatch
        :return: Evaluation metrics
        :rtype: Dict[str, float]
        """
        metrics: Dict[str, float] = {}
        qrels = batch.qrels
        if self.evaluation_metrics is None or qrels is None:
            return metrics
        query_ids = batch.query_ids
        doc_ids = batch.doc_ids
        if query_ids is None:
            raise ValueError("query_ids must be set")
        if doc_ids is None:
            raise ValueError("doc_ids must be set")
        evaluation_metrics = [metric for metric in self.evaluation_metrics if metric != "loss"]
        ir_measures_qrels = create_qrels_from_dicts(qrels)
        if evaluation_metrics and qrels is not None and output.scores is not None:
            run = create_run_from_scores(query_ids, doc_ids, output.scores)
            metrics.update(evaluate_run(run, ir_measures_qrels, evaluation_metrics))
        return metrics

    def validate_loss(
        self,
        output: LightningIROutput,
        batch: TrainBatch | RankBatch | SearchBatch,
    ) -> Dict[str, float]:
        """Validates the model output with the loss functions.

        :param output: Model output
        :type output: LightningIROutput
        :param batch: Batch of validation or testing data
        :type batch: TrainBatch | RankBatch | SearchBatch
        :return: Evaluation metrics
        :rtype: Dict[str, float]
        """
        metrics: Dict[str, float] = {}
        query_ids = batch.query_ids
        if query_ids is None:
            raise ValueError("query_ids must be set")
        if (
            self.evaluation_metrics is None
            or "loss" not in self.evaluation_metrics
            or getattr(batch, "targets", None) is None
            or self.loss_functions is None
            or output.scores is None
        ):
            return metrics
        output.scores = output.scores.view(len(query_ids), -1)
        for loss_function, _ in self.loss_functions:
            # NOTE skip in-batch losses because they can use a lot of memory
            if isinstance(loss_function, InBatchLossFunction):
                continue
            metrics[f"validation-{loss_function.__class__.__name__}"] = loss_function.compute_loss(output, batch)
        return metrics

    def on_validation_end(self) -> None:
        """Prints the validation results for each dataloader."""
        trainer = self.trainer
        if not (trainer.is_global_zero and trainer._evaluation_loop.verbose):
            return
        results = trainer.callback_metrics

        data = []
        for key, value in {**results, **self._additional_log_metrics}.items():
            if "dataloader_idx" in key:
                key = "/".join(key.split("/")[:-1])
            *dataset_parts, metric = key.split("/")
            if metric.startswith("validation-"):
                metric = metric[len("validation-") :]
            dataset = "/".join(dataset_parts)
            if isinstance(value, torch.Tensor):
                value = value.item()
            data.append({"dataset": dataset, "metric": metric, "value": value})
        if not data:
            return
        df = pd.DataFrame(data)
        df = df.pivot(index="dataset", columns="metric", values="value")
        df.columns.name = None

        # bring into correct order when skipping inference datasets
        datamodule = getattr(self.trainer, "datamodule", None)
        if datamodule is not None and hasattr(datamodule, "inference_datasets"):
            inference_datasets = datamodule.inference_datasets
            if len(inference_datasets) != df.shape[0]:
                raise ValueError(
                    "Number of inference datasets does not match number of dataloaders. "
                    "Check if the dataloaders are correctly configured."
                )
            dataset_ids = [self.get_dataset_id(dataset) for dataset in inference_datasets]
            df = df.reindex(dataset_ids)

        trainer.print(df)

    def on_test_end(self) -> None:
        """Prints the accumulated metrics for each dataloader."""
        self.on_validation_end()

    def save_pretrained(self, save_path: str | Path) -> None:
        """Saves the model and tokenizer to the save path.

        :param save_path: Path to save the model and tokenizer
        :type save_path: str | Path
        """
        self.model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)
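
    # Sketch: persisting a fine-tuned module and reloading it later; the directory and
    # subclass name are placeholders.
    #
    #     module.save_pretrained("models/my-fine-tuned-model")
    #     reloaded = BiEncoderModule("models/my-fine-tuned-model")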

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Saves the model and tokenizer to the trainer's log directory."""
        if self.trainer is not None and self.trainer.log_dir is not None:
            if self.trainer.global_rank != 0:
                return
            _step = self.trainer.global_step
            self.config.save_step = _step
            log_dir = Path(self.trainer.log_dir)
            save_path = log_dir / "huggingface_checkpoint"
            self.save_pretrained(save_path)