Source code for lightning_ir.base.module

  1"""LightningModule for Lightning IR.
  2
  3This module contains the main module class deriving from a LightningModule_.
  4
  5.. _LightningModule: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html
  6"""
  7
  8from pathlib import Path
  9from typing import Any, Dict, List, Mapping, Sequence, Tuple, Type
 10
 11import pandas as pd
 12import torch
 13from lightning import LightningModule
 14from lightning.pytorch.trainer.states import RunningStage
 15from transformers import BatchEncoding
 16
 17from ..data import IRDataset, RankBatch, RunDataset, SearchBatch, TrainBatch
 18from ..loss import InBatchLossFunction, LossFunction
 19from .config import LightningIRConfig
 20from .model import LightningIRModel, LightningIROutput
 21from .tokenizer import LightningIRTokenizer
 22from .validation_utils import create_qrels_from_dicts, create_run_from_scores, evaluate_run
 23
 24
class LightningIRModule(LightningModule):
    """LightningIRModule base class. It derives from a LightningModule_. LightningIRModules contain a
    LightningIRModel and a LightningIRTokenizer and implement the training, validation, and testing steps for the
    model. Derived classes must implement the forward method for the model.

    .. _LightningModule: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html
    """
    def __init__(
        self,
        model_name_or_path: str | None = None,
        config: LightningIRConfig | None = None,
        model: LightningIRModel | None = None,
        loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None,
        evaluation_metrics: Sequence[str] | None = None,
        model_kwargs: Mapping[str, Any] | None = None,
    ):
        """Initializes the LightningIRModule.

        .. _ir-measures: https://ir-measur.es/en/latest/index.html

        Args:
            model_name_or_path (str | None): Name or path of backbone model or fine-tuned Lightning IR model.
                Defaults to None.
            config (LightningIRConfig | None): LightningIRConfig to apply when loading from backbone model.
                Defaults to None.
            model (LightningIRModel | None): Already instantiated Lightning IR model. Defaults to None.
            loss_functions (Sequence[LossFunction | Tuple[LossFunction, float]] | None): Loss functions to apply
                during fine-tuning; optional loss weights can be provided per loss function. Defaults to None.
            evaluation_metrics (Sequence[str] | None): Metrics corresponding to ir-measures_ measure strings to
                apply during validation or testing. Defaults to None.
            model_kwargs (Mapping[str, Any] | None): Additional keyword arguments to pass to `from_pretrained` when
                loading a model. Defaults to None.

        Raises:
            ValueError: If both model and model_name_or_path are provided.
            ValueError: If neither model nor model_name_or_path are provided.
        """
        super().__init__()
        model_kwargs = model_kwargs if model_kwargs is not None else {}
        self.save_hyperparameters()
        if model is not None and model_name_or_path is not None:
            raise ValueError("Only one of model or model_name_or_path must be provided.")
        if model is None:
            if model_name_or_path is None:
                raise ValueError("Either model or model_name_or_path must be provided.")
            model = LightningIRModel.from_pretrained(model_name_or_path, config=config, **model_kwargs)

        self.model: LightningIRModel = model
        self.config = self.model.config
        self.loss_functions: List[Tuple[LossFunction, float]] | None = None
        if loss_functions is not None:
            self.loss_functions = []
            for loss_function in loss_functions:
                if isinstance(loss_function, LossFunction):
                    self.loss_functions.append((loss_function, 1.0))
                else:
                    self.loss_functions.append(loss_function)
        self.evaluation_metrics = evaluation_metrics
        self._optimizer: torch.optim.Optimizer | None = None
        self.tokenizer = LightningIRTokenizer.from_pretrained(self.config.name_or_path, config=self.config)
        self._additional_log_metrics: Dict[str, float] = {}
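    # Usage sketch (illustrative; the subclass and loss class names below are assumptions,
    # not defined in this module): losses may be passed plain or as (loss_function, weight)
    # tuples; plain losses default to a weight of 1.0.
    #
    #     module = MyLightningIRModule(  # hypothetical concrete subclass
    #         model_name_or_path="bert-base-uncased",
    #         loss_functions=[RankNet(), (KLDivergence(), 0.5)],
    #         evaluation_metrics=["nDCG@10", "loss"],
    #     )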
    def on_train_start(self) -> None:
        """Called at the beginning of training after sanity check."""
        super().on_train_start()
        # NOTE huggingface models are in eval mode by default
        self.model = self.model.train()

    def on_validation_start(self) -> None:
        """Called at the beginning of validation."""
        # NOTE monkey patch result printing of the trainer
        try:
            trainer = self.trainer
        except RuntimeError:
            trainer = None
        if trainer is None:
            return

        trainer._evaluation_loop._print_results = lambda *args, **kwargs: None

    def on_test_start(self) -> None:
        """Called at the beginning of testing."""
        self.on_validation_start()

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Configures the optimizer for fine-tuning. This method is ignored when using the CLI. When using
        Lightning IR programmatically, the optimizer must be set using :meth:`set_optimizer`.

        Returns:
            torch.optim.Optimizer: The optimizer set for the model.

        Raises:
            ValueError: If the optimizer is not set. Call `set_optimizer`.
        """
        if self._optimizer is None:
            raise ValueError("Optimizer is not set. Call `set_optimizer`.")
        return self._optimizer
    def set_optimizer(
        self, optimizer: Type[torch.optim.Optimizer], **optimizer_kwargs: Dict[str, Any]
    ) -> "LightningIRModule":
        """Sets the optimizer for the model. Necessary for fine-tuning when not using the CLI.

        Args:
            optimizer (Type[torch.optim.Optimizer]): Torch optimizer class.
            optimizer_kwargs (Dict[str, Any]): Arguments to initialize the optimizer.

        Returns:
            LightningIRModule: Self with the optimizer set.
        """
        self._optimizer = optimizer(self.parameters(), **optimizer_kwargs)
        return self
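    # Usage sketch (illustrative): the optimizer class is passed uninstantiated and is
    # constructed with the module's parameters plus the given keyword arguments.
    #
    #     module.set_optimizer(torch.optim.AdamW, lr=1e-5, weight_decay=0.01)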
    def score(self, queries: Sequence[str] | str, docs: Sequence[Sequence[str]] | Sequence[str]) -> LightningIROutput:
        """Computes relevance scores for queries and documents.

        Args:
            queries (Sequence[str] | str): Queries to score.
            docs (Sequence[Sequence[str]] | Sequence[str]): Documents to score.

        Returns:
            LightningIROutput: Model output containing the scores.
        """
        if isinstance(queries, str):
            queries = (queries,)
        if isinstance(docs[0], str):
            docs = (docs,)
        batch = RankBatch(queries, docs, None, None)
        with torch.no_grad():
            return self.forward(batch)
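    # Usage sketch (illustrative): a single query string with a flat list of documents is
    # normalized to one query and one nested document list before scoring.
    #
    #     output = module.score("what is ir?", ["a doc about ir", "an unrelated doc"])
    #     output.scores  # tensor with one score per document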
    def forward(self, batch: TrainBatch | RankBatch | SearchBatch) -> LightningIROutput:
        """Handles the forward pass of the model.

        Args:
            batch (TrainBatch | RankBatch | SearchBatch): Batch of training or ranking data.

        Returns:
            LightningIROutput: Model output.

        Raises:
            NotImplementedError: Must be implemented by derived class.
        """
        raise NotImplementedError
    def prepare_input(
        self, queries: Sequence[str] | None, docs: Sequence[str] | None, num_docs: Sequence[int] | int | None
    ) -> Dict[str, BatchEncoding]:
        """Tokenizes queries and documents and returns the tokenized BatchEncoding_.

        .. _BatchEncoding: https://huggingface.co/transformers/main_classes/tokenizer#transformers.BatchEncoding

        Args:
            queries (Sequence[str] | None): Queries to tokenize.
            docs (Sequence[str] | None): Documents to tokenize.
            num_docs (Sequence[int] | int | None): Number of documents per query; if None, num_docs is inferred by
                `len(docs) // len(queries)`. Defaults to None.

        Returns:
            Dict[str, BatchEncoding]: Tokenized queries and documents, format depends on the tokenizer.
        """
        encodings = self.tokenizer.tokenize(
            queries, docs, return_tensors="pt", padding=True, truncation=True, num_docs=num_docs
        )
        for key in encodings:
            encodings[key] = encodings[key].to(self.device)
        return encodings
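    # Usage sketch (illustrative): three documents for one query; with num_docs left as
    # None it is inferred as len(docs) // len(queries) = 3.
    #
    #     encodings = module.prepare_input(
    #         queries=["what is ir?"],
    #         docs=["doc one", "doc two", "doc three"],
    #         num_docs=None,
    #     )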
    def _compute_losses(self, batch: TrainBatch, output: LightningIROutput) -> List[torch.Tensor]:
        """Computes the losses for a training batch."""
        raise NotImplementedError
    def training_step(self, batch: TrainBatch, batch_idx: int) -> torch.Tensor:
        """Handles the training step for the model.

        Args:
            batch (TrainBatch): Batch of training data.
            batch_idx (int): Index of the batch.

        Returns:
            torch.Tensor: Sum of the losses weighted by the loss weights.

        Raises:
            ValueError: If no loss functions are set.
        """
        if self.loss_functions is None:
            raise ValueError("Loss functions are not set")
        output = self.forward(batch)
        losses = self._compute_losses(batch, output)
        total_loss = torch.tensor(0)
        assert len(losses) == len(self.loss_functions)
        for (loss_function, loss_weight), loss in zip(self.loss_functions, losses):
            self.log(loss_function.__class__.__name__, loss)
            total_loss = total_loss + loss * loss_weight
        self.log("loss", total_loss, prog_bar=True)
        return total_loss
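    # Worked example (illustrative; loss class names are assumptions): with
    # loss_functions=[(RankNet(), 1.0), (KLDivergence(), 0.5)] and per-loss values
    # [0.8, 0.4], the logged total is
    #
    #     total_loss = 0.8 * 1.0 + 0.4 * 0.5 = 1.0
    #
    # Each individual loss is additionally logged under its class name.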
    def validation_step(
        self, batch: TrainBatch | RankBatch | SearchBatch, batch_idx: int, dataloader_idx: int = 0
    ) -> LightningIROutput:
        """Handles the validation step for the model.

        Args:
            batch (TrainBatch | RankBatch | SearchBatch): Batch of validation or testing data.
            batch_idx (int): Index of the batch.
            dataloader_idx (int, optional): Index of the dataloader. Defaults to 0.

        Returns:
            LightningIROutput: Model output.
        """
        output = self.forward(batch)

        if self.evaluation_metrics is None:
            return output

        dataset = self.get_dataset(dataloader_idx)
        dataset_id = str(dataloader_idx) if dataset is None else self.get_dataset_id(dataset)
        metrics = self.validate(output, batch)
        for key, value in metrics.items():
            key = f"{dataset_id}/{key}"
            self.log(key, value, batch_size=len(batch.queries))
        return output

    def test_step(
        self,
        batch: TrainBatch | RankBatch,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> LightningIROutput:
        """Handles the testing step for the model. Passes the batch to the validation step.

        Args:
            batch (TrainBatch | RankBatch): Batch of testing data.
            batch_idx (int): Index of the batch.
            dataloader_idx (int, optional): Index of the dataloader. Defaults to 0.

        Returns:
            LightningIROutput: Model output.
        """
        return self.validation_step(batch, batch_idx, dataloader_idx)
    def get_dataset(self, dataloader_idx: int) -> IRDataset | None:
        """Gets the dataset instance from the dataloader index. Returns None if no dataset is found.

        Args:
            dataloader_idx (int): Index of the dataloader.

        Returns:
            IRDataset | None: Inference dataset or None if no dataset is found.
        """
        try:
            trainer = self.trainer
        except RuntimeError:
            trainer = None
        if trainer is None:
            return None
        STAGE_TO_DATALOADER = {
            RunningStage.VALIDATING: "val_dataloaders",
            RunningStage.TESTING: "test_dataloaders",
            RunningStage.PREDICTING: "predict_dataloaders",
            RunningStage.SANITY_CHECKING: "val_dataloaders",
        }
        if trainer.state.stage is None:
            return None
        dataloaders = getattr(trainer, STAGE_TO_DATALOADER[trainer.state.stage], None)
        if dataloaders is None:
            return None
        if isinstance(dataloaders, torch.utils.data.DataLoader):
            dataloaders = [dataloaders]
        return dataloaders[dataloader_idx].dataset

    def get_dataset_id(self, dataset: IRDataset) -> str:
        """Gets the dataset id from a dataset instance for logging.

        .. _ir-datasets: https://ir-datasets.com/

        Args:
            dataset (IRDataset): Dataset instance.

        Returns:
            str: Run file name or ir-datasets_ dataset id.
        """
        if isinstance(dataset, RunDataset) and dataset.run_path is not None:
            dataset_id = dataset.run_path.name
        else:
            dataset_id = dataset.dataset_id
        return dataset_id
    def validate(
        self,
        output: LightningIROutput,
        batch: TrainBatch | RankBatch | SearchBatch,
    ) -> Dict[str, float]:
        """Validates the model output with the evaluation metrics and loss functions.

        Args:
            output (LightningIROutput): Model output.
            batch (TrainBatch | RankBatch | SearchBatch): Batch of validation or testing data.

        Returns:
            Dict[str, float]: Dictionary of evaluation metrics.
        """
        metrics: Dict[str, float] = {}
        if self.evaluation_metrics is None or output.scores is None:
            return metrics
        metrics.update(self.validate_metrics(output, batch))
        metrics.update(self.validate_loss(output, batch))
        return metrics

    def validate_metrics(
        self,
        output: LightningIROutput,
        batch: TrainBatch | RankBatch | SearchBatch,
    ) -> Dict[str, float]:
        """Validates the model output with the evaluation metrics.

        Args:
            output (LightningIROutput): Model output.
            batch (TrainBatch | RankBatch | SearchBatch): Batch of validation or testing data.

        Returns:
            Dict[str, float]: Dictionary of evaluation metrics.

        Raises:
            ValueError: If query_ids or doc_ids are not set in the batch.
        """
        metrics: Dict[str, float] = {}
        qrels = batch.qrels
        if self.evaluation_metrics is None or qrels is None:
            return metrics
        query_ids = batch.query_ids
        doc_ids = batch.doc_ids
        if query_ids is None:
            raise ValueError("query_ids must be set")
        if doc_ids is None:
            raise ValueError("doc_ids must be set")
        evaluation_metrics = [metric for metric in self.evaluation_metrics if metric != "loss"]
        ir_measures_qrels = create_qrels_from_dicts(qrels)
        if evaluation_metrics and qrels is not None and output.scores is not None:
            run = create_run_from_scores(query_ids, doc_ids, output.scores)
            metrics.update(evaluate_run(run, ir_measures_qrels, evaluation_metrics))
        return metrics

    def validate_loss(
        self,
        output: LightningIROutput,
        batch: TrainBatch | RankBatch | SearchBatch,
    ) -> Dict[str, float]:
        """Validates the model output with the loss functions.

        Args:
            output (LightningIROutput): Model output.
            batch (TrainBatch | RankBatch | SearchBatch): Batch of validation or testing data.

        Returns:
            Dict[str, float]: Dictionary of loss metrics.

        Raises:
            ValueError: If query_ids are not set in the batch.
        """
        metrics: Dict[str, float] = {}
        query_ids = batch.query_ids
        if query_ids is None:
            raise ValueError("query_ids must be set")
        if (
            self.evaluation_metrics is None
            or "loss" not in self.evaluation_metrics
            or getattr(batch, "targets", None) is None
            or self.loss_functions is None
            or output.scores is None
        ):
            return metrics
        output.scores = output.scores.view(len(query_ids), -1)
        for loss_function, _ in self.loss_functions:
            # NOTE skip in-batch losses because they can use a lot of memory
            if isinstance(loss_function, InBatchLossFunction):
                continue
            metrics[f"validation-{loss_function.__class__.__name__}"] = loss_function.compute_loss(output, batch)
        return metrics
    def on_validation_end(self) -> None:
        """Prints the validation results for each dataloader."""
        trainer = self.trainer
        if not (trainer.is_global_zero and trainer._evaluation_loop.verbose):
            return
        results = trainer.callback_metrics

        data = []
        for key, value in {**results, **self._additional_log_metrics}.items():
            if "dataloader_idx" in key:
                key = "/".join(key.split("/")[:-1])
            *dataset_parts, metric = key.split("/")
            if metric.startswith("validation-"):
                metric = metric[len("validation-") :]
            dataset = "/".join(dataset_parts)
            if isinstance(value, torch.Tensor):
                value = value.item()
            data.append({"dataset": dataset, "metric": metric, "value": value})
        if not data:
            return
        df = pd.DataFrame(data)
        df = df.pivot(index="dataset", columns="metric", values="value")
        df.columns.name = None

        # bring into correct order when skipping inference datasets
        datamodule = getattr(self.trainer, "datamodule", None)
        if datamodule is not None and hasattr(datamodule, "inference_datasets"):
            inference_datasets = datamodule.inference_datasets
            if len(inference_datasets) != df.shape[0]:
                raise ValueError(
                    "Number of inference datasets does not match number of dataloaders. "
                    "Check if the dataloaders are correctly configured."
                )
            dataset_ids = [self.get_dataset_id(dataset) for dataset in inference_datasets]
            df = df.reindex(dataset_ids)

        trainer.print(df)

    def on_test_end(self) -> None:
        """Prints the accumulated metrics for each dataloader."""
        self.on_validation_end()
    def save_pretrained(self, save_path: str | Path) -> None:
        """Saves the model and tokenizer to the save path.

        Args:
            save_path (str | Path): Path to save the model and tokenizer.
        """
        self.model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)
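    # Usage sketch (illustrative): the model and tokenizer are written side by side so
    # the directory can later be reloaded with `from_pretrained`.
    #
    #     module.save_pretrained("models/my-finetuned-ranker")  # hypothetical path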
    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Saves the model and tokenizer to the trainer's log directory."""
        if self.trainer is not None and self.trainer.log_dir is not None:
            if self.trainer.global_rank != 0:
                return
            _step = self.trainer.global_step
            self.config.save_step = _step
            log_dir = Path(self.trainer.log_dir)
            save_path = log_dir / "huggingface_checkpoint"
            self.save_pretrained(save_path)