Source code for lightning_ir.cross_encoder.cross_encoder_module

 1"""
 2Lightning IR module for cross-encoder models.
 3
 4This module defines the Lightning IR module class used to implement cross-encoder models.
 5"""
 6
 7from typing import Any, List, Mapping, Sequence, Tuple
 8
 9import torch
10
11from ..base.module import LightningIRModule
12from ..data import RankBatch, SearchBatch, TrainBatch
13from ..loss.loss import LossFunction, ScoringLossFunction
14from .cross_encoder_config import CrossEncoderConfig
15from .cross_encoder_model import CrossEncoderModel, CrossEncoderOutput
16from .cross_encoder_tokenizer import CrossEncoderTokenizer
17
18
class CrossEncoderModule(LightningIRModule):
    def __init__(
        self,
        model_name_or_path: str | None = None,
        config: CrossEncoderConfig | None = None,
        model: CrossEncoderModel | None = None,
        loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None,
        evaluation_metrics: Sequence[str] | None = None,
        model_kwargs: Mapping[str, Any] | None = None,
    ):
        """:class:`.LightningIRModule` specialised for cross-encoder models.

        Bundles a :class:`.CrossEncoderModel` with a :class:`.CrossEncoderTokenizer` and provides the training,
        validation, and testing steps for the model.

        .. _ir-measures: https://ir-measur.es/en/latest/index.html

        :param model_name_or_path: Name or path of a backbone model or of an already fine-tuned Lightning IR model,
            defaults to None
        :type model_name_or_path: str | None, optional
        :param config: CrossEncoderConfig applied when the model is loaded from a backbone model, defaults to None
        :type config: CrossEncoderConfig | None, optional
        :param model: Already-instantiated CrossEncoderModel, defaults to None
        :type model: CrossEncoderModel | None, optional
        :param loss_functions: Loss functions used during fine-tuning; each entry may optionally be paired with a
            loss weight, defaults to None
        :type loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None, optional
        :param evaluation_metrics: ir-measures_ measure strings evaluated during validation or testing,
            defaults to None
        :type evaluation_metrics: Sequence[str] | None, optional
        :param model_kwargs: Extra keyword arguments forwarded to `from_pretrained` when a model is loaded,
            defaults to None
        :type model_kwargs: Mapping[str, Any] | None, optional
        """
        super().__init__(model_name_or_path, config, model, loss_functions, evaluation_metrics, model_kwargs)
        # Annotation-only statements: they declare the concrete types of these attributes for type checkers
        # and assign nothing at runtime.
        self.model: CrossEncoderModel
        self.config: CrossEncoderConfig
        self.tokenizer: CrossEncoderTokenizer

    def forward(self, batch: RankBatch | TrainBatch | SearchBatch) -> CrossEncoderOutput:
        """Run the cross-encoder over the query-document pairs contained in ``batch``.

        The nested document lists in ``batch.docs`` are flattened into a single list. That list is passed to
        ``self.prepare_input`` together with the queries and the number of documents in each nested list. The
        ``"encoding"`` entry of the prepared input is then fed to the model.

        :param batch: Batch to run the forward pass on
        :type batch: RankBatch | TrainBatch | SearchBatch
        :raises ValueError: If ``batch`` is a SearchBatch, because searching is not supported by cross-encoders
        :return: Output of the model, whose ``scores`` hold the relevance scores
        :rtype: CrossEncoderOutput
        """
        if isinstance(batch, SearchBatch):
            raise ValueError("Searching is not available for cross-encoders")
        queries = batch.queries
        flat_docs = []
        docs_per_query = []
        # Single pass: collect every document and remember how many came from each nested list.
        for query_docs in batch.docs:
            flat_docs.extend(query_docs)
            docs_per_query.append(len(query_docs))
        encoding = self.prepare_input(queries, flat_docs, docs_per_query)
        return self.model.forward(encoding["encoding"])

    def _compute_losses(self, batch: TrainBatch, output: CrossEncoderOutput) -> List[torch.Tensor]:
        """Compute one loss value per configured loss function for a training batch.

        Both arguments are modified in place before any loss is evaluated:

        * ``output.scores`` is reshaped to ``(len(batch.query_ids), -1)``.
        * ``batch.targets`` is reshaped to ``(*output.scores.shape, -1)``.

        :param batch: Training batch providing the query ids and the targets
        :type batch: TrainBatch
        :param output: Model output whose ``scores`` are passed to the loss functions
        :type output: CrossEncoderOutput
        :raises ValueError: If no loss functions are configured on the module
        :raises RuntimeError: If a configured loss function is not a :class:`.ScoringLossFunction`
        :return: Losses in the same order as ``self.loss_functions``
        :rtype: List[torch.Tensor]
        """
        if self.loss_functions is None:
            raise ValueError("loss_functions must be set in the module")

        # One row of scores per query id; the targets follow the score shape plus a trailing dimension.
        output.scores = output.scores.view(len(batch.query_ids), -1)
        batch.targets = batch.targets.view(*output.scores.shape, -1)

        def score_loss(loss_function: LossFunction) -> torch.Tensor:
            # Only ScoringLossFunction instances are supported by this module.
            if not isinstance(loss_function, ScoringLossFunction):
                raise RuntimeError(f"Loss function {loss_function} is not a scoring loss function")
            return loss_function.compute_loss(output, batch)

        # Each entry is unpacked as a (loss_function, weight) pair. The pairs are presumably normalised by the
        # base class (TODO confirm). The weight is not applied here.
        return [score_loss(loss_function) for loss_function, _ in self.loss_functions]