Source code for lightning_ir.base.module

"""LightningModule for Lightning IR.

This module contains the main module class deriving from a LightningModule_.

.. _LightningModule: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html
"""
  7
  8from pathlib import Path
  9from typing import Any, Dict, List, Mapping, Sequence, Tuple, Type
 10
 11import pandas as pd
 12import torch
 13from lightning import LightningModule
 14from lightning.pytorch.trainer.states import RunningStage
 15from transformers import BatchEncoding, PreTrainedModel
 16
 17from ..data import IRDataset, RankBatch, RunDataset, SearchBatch, TrainBatch
 18from ..loss import InBatchLossFunction, LossFunction
 19from .config import LightningIRConfig
 20from .model import LightningIRModel, LightningIROutput
 21from .tokenizer import LightningIRTokenizer
 22from .validation_utils import create_qrels_from_dicts, create_run_from_scores, evaluate_run
 23
 24
class LightningIRModule(LightningModule):
    """LightningIRModule base class. It derives from a LightningModule_. LightningIRModules contain a
    LightningIRModel and a LightningIRTokenizer and implement the training, validation, and testing steps for the
    model. Derived classes must implement the forward method for the model.

    .. _LightningModule: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html
    """

    def __init__(
        self,
        model_name_or_path: str | None = None,
        config: LightningIRConfig | None = None,
        model: LightningIRModel | None = None,
        BackboneModel: Type[PreTrainedModel] | None = None,
        loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None,
        evaluation_metrics: Sequence[str] | None = None,
        model_kwargs: Mapping[str, Any] | None = None,
    ):
        """Initializes the LightningIRModule.

        .. _ir-measures: https://ir-measur.es/en/latest/index.html

        Args:
            model_name_or_path (str | None): Name or path of backbone model or fine-tuned Lightning IR model.
                Defaults to None.
            config (LightningIRConfig | None): LightningIRConfig to apply when loading from backbone model.
                Defaults to None.
            model (LightningIRModel | None): Already instantiated Lightning IR model. Defaults to None.
            BackboneModel (Type[PreTrainedModel] | None): Huggingface PreTrainedModel class to use as backbone
                instead of the default AutoModel. Defaults to None.
            loss_functions (Sequence[LossFunction | Tuple[LossFunction, float]] | None):
                Loss functions to apply during fine-tuning. Optional loss weights can be provided per loss function;
                loss functions given without a weight get a weight of 1.0. Defaults to None.
            evaluation_metrics (Sequence[str] | None): Metrics corresponding to ir-measures_ measure strings
                to apply during validation or testing. The special value ``"loss"`` additionally evaluates the
                configured (non in-batch) loss functions. Defaults to None.
            model_kwargs (Mapping[str, Any] | None): Additional keyword arguments to pass to `from_pretrained` when
                loading a model. Defaults to None.
        Raises:
            ValueError: If both model and model_name_or_path are provided.
            ValueError: If neither model nor model_name_or_path are provided.
        """
        super().__init__()
        model_kwargs = model_kwargs if model_kwargs is not None else {}
        # record the constructor arguments as hyperparameters
        self.save_hyperparameters()
        if model is not None and model_name_or_path is not None:
            raise ValueError("Only one of model or model_name_or_path must be provided.")
        if model is None:
            if model_name_or_path is None:
                raise ValueError("Either model or model_name_or_path must be provided.")
            model = LightningIRModel.from_pretrained(
                model_name_or_path, config=config, BackboneModel=BackboneModel, **model_kwargs
            )

        self.model: LightningIRModel = model
        self.config = self.model.config
        # normalize every entry to a (loss_function, weight) pair; bare loss functions get weight 1.0
        self.loss_functions: List[Tuple[LossFunction, float]] | None = None
        if loss_functions is not None:
            self.loss_functions = [
                (loss_function, 1.0) if isinstance(loss_function, LossFunction) else loss_function
                for loss_function in loss_functions
            ]
        self.evaluation_metrics = evaluation_metrics
        self._optimizer: torch.optim.Optimizer | None = None
        self.tokenizer = LightningIRTokenizer.from_pretrained(self.config.name_or_path, config=self.config)
        # extra metrics merged into the results table printed by on_validation_end
        self._additional_log_metrics: Dict[str, float] = {}

    def on_train_start(self) -> None:
        """Called at the beginning of training after sanity check. Switches the wrapped model to train mode."""
        super().on_train_start()
        # NOTE huggingface models are in eval mode by default
        self.model = self.model.train()

    def on_validation_start(self) -> None:
        """Called at the beginning of validation.

        Replaces the trainer's evaluation-loop result printing with a no-op so that only the table printed by
        :meth:`on_validation_end` is shown. Does nothing when the module is not attached to a trainer.
        """
        try:
            trainer = self.trainer
        except RuntimeError:
            # not attached to a trainer
            trainer = None
        if trainer is None:
            return
        # NOTE monkey patch result printing of the trainer (relies on the private `_evaluation_loop` attribute)
        trainer._evaluation_loop._print_results = lambda *args, **kwargs: None

    def on_test_start(self) -> None:
        """Called at the beginning of testing. Applies the same setup as :meth:`on_validation_start`."""
        self.on_validation_start()

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Configures the optimizer for fine-tuning. This method is ignored when using the CLI. When using Lightning IR
        programmatically, the optimizer must be set using :meth:`set_optimizer`.

        Returns:
            torch.optim.Optimizer: The optimizer set for the model.
        Raises:
            ValueError: If optimizer is not set. Call `set_optimizer`.
        """
        if self._optimizer is None:
            raise ValueError("Optimizer is not set. Call `set_optimizer`.")
        return self._optimizer

    def set_optimizer(self, optimizer: Type[torch.optim.Optimizer], **optimizer_kwargs: Any) -> "LightningIRModule":
        """Sets the optimizer for the model. Necessary for fine-tuning when not using the CLI.

        Args:
            optimizer (Type[torch.optim.Optimizer]): Torch optimizer class, instantiated with all module parameters.
            **optimizer_kwargs (Any): Keyword arguments to initialize the optimizer (e.g. ``lr``).
        Returns:
            LightningIRModule: Self with the optimizer set, allowing call chaining.
        """
        self._optimizer = optimizer(self.parameters(), **optimizer_kwargs)
        return self

    def score(self, queries: Sequence[str] | str, docs: Sequence[Sequence[str]] | Sequence[str]) -> LightningIROutput:
        """Computes relevance scores for queries and documents without tracking gradients.

        Args:
            queries (Sequence[str] | str): Queries to score. A single query string is wrapped into a one-element
                tuple.
            docs (Sequence[Sequence[str]] | Sequence[str]): Documents to score, one sequence of documents per query.
                A flat sequence of document strings is treated as the documents of a single query.
        Returns:
            LightningIROutput: Model output containing the scores.
        Raises:
            ValueError: If ``docs`` is empty.
        """
        if isinstance(queries, str):
            queries = (queries,)
        if len(docs) == 0:
            raise ValueError("At least one document must be provided.")
        if isinstance(docs[0], str):
            docs = (docs,)
        batch = RankBatch(queries, docs, None, None)
        with torch.no_grad():
            return self.forward(batch)

    def forward(self, batch: TrainBatch | RankBatch | SearchBatch) -> LightningIROutput:
        """Handles the forward pass of the model.

        Args:
            batch (TrainBatch | RankBatch | SearchBatch): Batch of training or ranking data.
        Returns:
            LightningIROutput: Model output.
        Raises:
            NotImplementedError: Must be implemented by derived class.
        """
        raise NotImplementedError

    def prepare_input(
        self,
        queries: Sequence[str] | None,
        docs: Sequence[str] | None,
        num_docs: Sequence[int] | int | None = None,
    ) -> Dict[str, BatchEncoding]:
        """Tokenizes queries and documents and returns the tokenized BatchEncoding_ moved to the module's device.

        .. _BatchEncoding: https://huggingface.co/transformers/main_classes/tokenizer#transformers.BatchEncoding

        Args:
            queries (Sequence[str] | None): Queries to tokenize.
            docs (Sequence[str] | None): Documents to tokenize.
            num_docs (Sequence[int] | int | None): Number of documents per query, if None num_docs is inferred by
                `len(docs) // len(queries)`. Defaults to None.
        Returns:
            Dict[str, BatchEncoding]: Tokenized queries and documents, format depends on the tokenizer.
        """
        encodings = self.tokenizer.tokenize(
            queries, docs, return_tensors="pt", padding=True, truncation=True, num_docs=num_docs
        )
        # only values are replaced, so iterating the keys while assigning is safe
        for key in encodings:
            encodings[key] = encodings[key].to(self.device)
        return encodings

    def _compute_losses(self, batch: TrainBatch, output: LightningIROutput) -> List[torch.Tensor]:
        """Computes the losses for a training batch, one loss per entry in ``self.loss_functions``.

        Raises:
            NotImplementedError: Must be implemented by derived class.
        """
        raise NotImplementedError

    def training_step(self, batch: TrainBatch, batch_idx: int) -> Dict[str, Any]:
        """Handles the training step for the model.

        Logs each loss under its loss function's class name and the weighted total under ``"loss"``.

        Args:
            batch (TrainBatch): Batch of training data.
            batch_idx (int): Index of the batch.
        Returns:
            Dict[str, Any]: The sum of the losses weighted by the loss weights under ``"loss"`` and the model output
            under ``"output"``.
        Raises:
            ValueError: If no loss functions are set.
            ValueError: If the number of computed losses does not match the number of loss functions.
        """
        if self.loss_functions is None:
            raise ValueError("Loss functions are not set")
        output = self.forward(batch)
        losses = self._compute_losses(batch, output)
        if len(losses) != len(self.loss_functions):
            raise ValueError(
                f"Expected {len(self.loss_functions)} losses (one per loss function), got {len(losses)}."
            )
        # float accumulator on the module's device so the sum does not rely on int -> float / CPU -> GPU promotion
        total_loss = torch.zeros((), device=self.device)
        for (loss_function, loss_weight), loss in zip(self.loss_functions, losses):
            self.log(loss_function.__class__.__name__, loss)
            total_loss = total_loss + loss * loss_weight
        self.log("loss", total_loss, prog_bar=True)
        return {"loss": total_loss, "output": output}

    def validation_step(
        self, batch: TrainBatch | RankBatch | SearchBatch, batch_idx: int, dataloader_idx: int = 0
    ) -> LightningIROutput:
        """Handles the validation step for the model.

        If evaluation metrics are configured, each metric is logged under ``"<dataset_id>/<metric>"``, where the
        dataset id falls back to the dataloader index when no dataset can be resolved.

        Args:
            batch (TrainBatch | RankBatch | SearchBatch): Batch of validation or testing data.
            batch_idx (int): Index of the batch.
            dataloader_idx (int): Index of the dataloader. Defaults to 0.
        Returns:
            LightningIROutput: Model output.
        """
        output = self.forward(batch)

        if self.evaluation_metrics is None:
            return output

        dataset = self.get_dataset(dataloader_idx)
        dataset_id = str(dataloader_idx) if dataset is None else self.get_dataset_id(dataset)
        metrics = self.validate(output, batch)
        for key, value in metrics.items():
            self.log(f"{dataset_id}/{key}", value, batch_size=len(batch.queries))
        return output

    def test_step(
        self,
        batch: TrainBatch | RankBatch,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> LightningIROutput:
        """Handles the testing step for the model. Passes the batch to the validation step.

        Args:
            batch (TrainBatch | RankBatch): Batch of testing data.
            batch_idx (int): Index of the batch.
            dataloader_idx (int): Index of the dataloader. Defaults to 0.
        Returns:
            LightningIROutput: Model output.
        """
        return self.validation_step(batch, batch_idx, dataloader_idx)

    def get_dataset(self, dataloader_idx: int) -> IRDataset | None:
        """Gets the dataset instance from the dataloader index. Returns None if no dataset is found.

        The dataloaders are looked up on the trainer according to its current running stage (validating, sanity
        checking, testing, or predicting).

        Args:
            dataloader_idx (int): Index of the dataloader.
        Returns:
            IRDataset | None: Inference dataset, or None if the module is not attached to a trainer, the stage has no
            associated dataloaders, or the trainer has no dataloaders for the stage.
        """
        try:
            trainer = self.trainer
        except RuntimeError:
            trainer = None
        if trainer is None:
            return None
        stage_to_dataloader = {
            RunningStage.VALIDATING: "val_dataloaders",
            RunningStage.TESTING: "test_dataloaders",
            RunningStage.PREDICTING: "predict_dataloaders",
            RunningStage.SANITY_CHECKING: "val_dataloaders",
        }
        stage = trainer.state.stage
        if stage is None:
            return None
        # unmapped stages (e.g. training) have no inference dataloaders
        dataloader_attr = stage_to_dataloader.get(stage)
        if dataloader_attr is None:
            return None
        dataloaders = getattr(trainer, dataloader_attr, None)
        if dataloaders is None:
            return None
        if isinstance(dataloaders, torch.utils.data.DataLoader):
            dataloaders = [dataloaders]
        return dataloaders[dataloader_idx].dataset

    def get_dataset_id(self, dataset: IRDataset) -> str:
        """Gets an identifier for a dataset, used as the prefix of logged metric names.

        .. _ir-datasets: https://ir-datasets.com/

        Args:
            dataset (IRDataset): Dataset instance.
        Returns:
            str: File name of the run file for a RunDataset with a run path, otherwise the dataset's ir-datasets_
            dataset id.
        """
        if isinstance(dataset, RunDataset) and dataset.run_path is not None:
            return dataset.run_path.name
        return dataset.dataset_id

    def validate(
        self,
        output: LightningIROutput,
        batch: TrainBatch | RankBatch | SearchBatch,
    ) -> Dict[str, float]:
        """Validates the model output with the evaluation metrics and loss functions.

        Args:
            output (LightningIROutput): Model output.
            batch (TrainBatch | RankBatch | SearchBatch): Batch of validation or testing data.
        Returns:
            Dict[str, float]: Dictionary of evaluation metrics; empty if no metrics are configured or the output has
            no scores.
        """
        metrics: Dict[str, float] = {}
        if self.evaluation_metrics is None or output.scores is None:
            return metrics
        metrics.update(self.validate_metrics(output, batch))
        metrics.update(self.validate_loss(output, batch))
        return metrics

    def validate_metrics(
        self,
        output: LightningIROutput,
        batch: TrainBatch | RankBatch | SearchBatch,
    ) -> Dict[str, float]:
        """Validates the model output with the (non-loss) evaluation metrics.

        Args:
            output (LightningIROutput): Model output.
            batch (TrainBatch | RankBatch | SearchBatch): Batch of validation or testing data.
        Returns:
            Dict[str, float]: Dictionary of evaluation metrics; empty if there are no qrels, no scores, or no
            metrics other than ``"loss"``.
        Raises:
            ValueError: If query_ids or doc_ids are not set in the batch but metrics need to be computed.
        """
        metrics: Dict[str, float] = {}
        qrels = batch.qrels
        if self.evaluation_metrics is None or qrels is None or output.scores is None:
            return metrics
        # "loss" is handled by validate_loss
        evaluation_metrics = [metric for metric in self.evaluation_metrics if metric != "loss"]
        if not evaluation_metrics:
            return metrics
        query_ids = batch.query_ids
        doc_ids = batch.doc_ids
        if query_ids is None:
            raise ValueError("query_ids must be set")
        if doc_ids is None:
            raise ValueError("doc_ids must be set")
        ir_measures_qrels = create_qrels_from_dicts(qrels)
        run = create_run_from_scores(query_ids, doc_ids, output.scores)
        metrics.update(evaluate_run(run, ir_measures_qrels, evaluation_metrics))
        return metrics

    def validate_loss(
        self,
        output: LightningIROutput,
        batch: TrainBatch | RankBatch | SearchBatch,
    ) -> Dict[str, float]:
        """Validates the model output with the loss functions.

        Only runs when ``"loss"`` is among the evaluation metrics and the batch has targets. In-batch loss functions
        are skipped. NOTE: ``output.scores`` is reshaped in place to ``(len(query_ids), -1)``.

        Args:
            output (LightningIROutput): Model output.
            batch (TrainBatch | RankBatch | SearchBatch): Batch of validation or testing data.
        Returns:
            Dict[str, float]: Losses keyed by ``"validation-<LossFunctionClassName>"``.
        Raises:
            ValueError: If the loss needs to be computed but query_ids are not set in the batch.
        """
        metrics: Dict[str, float] = {}
        if (
            self.evaluation_metrics is None
            or "loss" not in self.evaluation_metrics
            or getattr(batch, "targets", None) is None
            or self.loss_functions is None
            or output.scores is None
        ):
            return metrics
        query_ids = batch.query_ids
        if query_ids is None:
            raise ValueError("query_ids must be set")
        output.scores = output.scores.view(len(query_ids), -1)
        for loss_function, _ in self.loss_functions:
            # NOTE skip in-batch losses because they can use a lot of memory
            if isinstance(loss_function, InBatchLossFunction):
                continue
            metrics[f"validation-{loss_function.__class__.__name__}"] = loss_function.compute_loss(output, batch)
        return metrics

    def on_validation_end(self) -> None:
        """Prints the validation results as a table with one row per dataset and one column per metric.

        Only prints on the global-zero process and only if the evaluation loop is verbose.

        Raises:
            ValueError: If the datamodule's number of inference datasets does not match the number of result rows.
        """
        trainer = self.trainer
        if not (trainer.is_global_zero and trainer._evaluation_loop.verbose):
            return
        results = trainer.callback_metrics

        data = []
        for key, value in {**results, **self._additional_log_metrics}.items():
            # drop the trailing dataloader_idx component of the key
            if "dataloader_idx" in key:
                key = "/".join(key.split("/")[:-1])
            # the last component is the metric; the rest is re-joined in case the dataset id contains "/"
            # NOTE(review): keys without a "/" end up under an empty dataset name — confirm intended
            *dataset_parts, metric = key.split("/")
            metric = metric.removeprefix("validation-")
            dataset = "/".join(dataset_parts)
            if isinstance(value, torch.Tensor):
                value = value.item()
            data.append({"dataset": dataset, "metric": metric, "value": value})
        if not data:
            return
        df = pd.DataFrame(data)
        df = df.pivot(index="dataset", columns="metric", values="value")
        df.columns.name = None

        # bring rows into the order of the datamodule's inference datasets
        datamodule = getattr(trainer, "datamodule", None)
        if datamodule is not None and hasattr(datamodule, "inference_datasets"):
            inference_datasets = datamodule.inference_datasets
            if len(inference_datasets) != df.shape[0]:
                raise ValueError(
                    "Number of inference datasets does not match number of dataloaders. "
                    "Check if the dataloaders are correctly configured."
                )
            dataset_ids = [self.get_dataset_id(dataset) for dataset in inference_datasets]
            df = df.reindex(dataset_ids)

        trainer.print(df)

    def on_test_end(self) -> None:
        """Prints the accumulated metrics for each dataloader, like :meth:`on_validation_end`."""
        self.on_validation_end()

    def save_pretrained(self, save_path: str | Path) -> None:
        """Saves the model and tokenizer to the save path.

        Args:
            save_path (str | Path): Path to save the model and tokenizer.
        """
        self.model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Saves the model and tokenizer to ``<log_dir>/huggingface_checkpoint``.

        Only the rank-zero process writes, and it records the current global step in ``config.save_step``. Does
        nothing when the module is not attached to a trainer or the trainer has no log directory.

        Args:
            checkpoint (Dict[str, Any]): Lightning checkpoint dictionary (unused).
        """
        try:
            trainer = self.trainer
        except RuntimeError:
            trainer = None
        if trainer is None:
            return
        # read log_dir exactly once and on every rank: the Trainer property may broadcast across processes, so
        # evaluating it only on rank 0 (after the early return below) could leave the collective unmatched
        log_dir = trainer.log_dir
        if log_dir is None:
            return
        if trainer.global_rank != 0:
            return
        self.config.save_step = trainer.global_step
        save_path = Path(log_dir) / "huggingface_checkpoint"
        self.save_pretrained(save_path)