Source code for lightning_ir.cross_encoder.cross_encoder_module

"""
Module for cross-encoder models.

This module defines the Lightning IR module class used to implement cross-encoder models.
"""
 6
 7from typing import Any, List, Mapping, Sequence, Tuple, Type
 8
 9import torch
10from transformers import PreTrainedModel
11
12from ..base.module import LightningIRModule
13from ..data import RankBatch, SearchBatch, TrainBatch
14from ..loss.base import LossFunction, ScoringLossFunction
15from .cross_encoder_config import CrossEncoderConfig
16from .cross_encoder_model import CrossEncoderModel, CrossEncoderOutput
17from .cross_encoder_tokenizer import CrossEncoderTokenizer
18
19
class CrossEncoderModule(LightningIRModule):
    """:class:`.LightningIRModule` for cross-encoder models.

    Holds a :class:`.CrossEncoderModel` together with its :class:`.CrossEncoderTokenizer` and
    provides the training, validation, and testing steps for the model.
    """

    def __init__(
        self,
        model_name_or_path: str | None = None,
        config: CrossEncoderConfig | None = None,
        model: CrossEncoderModel | None = None,
        BackboneModel: Type[PreTrainedModel] | None = None,
        loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None,
        evaluation_metrics: Sequence[str] | None = None,
        model_kwargs: Mapping[str, Any] | None = None,
    ):
        """Initialize the cross-encoder module.

        .. _ir-measures: https://ir-measur.es/en/latest/index.html

        Args:
            model_name_or_path (str | None): Name or path of a backbone model or a fine-tuned
                Lightning IR model. Defaults to None.
            config (CrossEncoderConfig | None): CrossEncoderConfig applied when loading from a
                backbone model. Defaults to None.
            model (CrossEncoderModel | None): Already instantiated CrossEncoderModel.
                Defaults to None.
            BackboneModel (Type[PreTrainedModel] | None): Huggingface PreTrainedModel class used
                as backbone instead of the default AutoModel. Defaults to None.
            loss_functions (Sequence[LossFunction | Tuple[LossFunction, float]] | None):
                Loss functions applied during fine-tuning; an optional weight may be attached to
                each loss function. Defaults to None.
            evaluation_metrics (Sequence[str] | None): Metrics given as ir-measures_ measure
                strings, applied during validation or testing. Defaults to None.
            model_kwargs (Mapping[str, Any] | None): Extra keyword arguments forwarded to
                `from_pretrained` when loading a model. Defaults to None.
        """
        super().__init__(
            model_name_or_path=model_name_or_path,
            config=config,
            model=model,
            BackboneModel=BackboneModel,
            loss_functions=loss_functions,
            evaluation_metrics=evaluation_metrics,
            model_kwargs=model_kwargs,
        )
        # Narrow the attribute types set up by the base class to the cross-encoder variants.
        self.model: CrossEncoderModel
        self.config: CrossEncoderConfig
        self.tokenizer: CrossEncoderTokenizer

    def forward(self, batch: RankBatch | TrainBatch | SearchBatch) -> CrossEncoderOutput:
        """Run a forward pass on a batch and return the backbone's contextualized embeddings
        together with the relevance scores.

        Args:
            batch (RankBatch | TrainBatch | SearchBatch): Batch of data to run the forward pass on.
        Returns:
            CrossEncoderOutput: Output of the model.
        Raises:
            ValueError: If the batch is a SearchBatch.
        """
        # Cross-encoders score explicit query-document pairs; there is nothing to search over.
        if isinstance(batch, SearchBatch):
            raise ValueError("Searching is not available for cross-encoders")
        # Flatten the per-query document lists while remembering how many docs each query has,
        # so the tokenizer can pair every document with its query.
        flat_docs = [doc for doc_list in batch.docs for doc in doc_list]
        doc_counts = [len(doc_list) for doc_list in batch.docs]
        encodings = self.prepare_input(batch.queries, flat_docs, doc_counts)
        return self.model.forward(encodings["encoding"])

    def _compute_losses(self, batch: TrainBatch, output: CrossEncoderOutput) -> List[torch.Tensor]:
        """Compute the configured losses for a training batch.

        Args:
            batch (TrainBatch): Training batch carrying query ids and relevance targets.
            output (CrossEncoderOutput): Model output whose scores are reshaped in place.
        Returns:
            List[torch.Tensor]: One loss tensor per configured loss function.
        Raises:
            ValueError: If no loss functions are configured on the module.
            RuntimeError: If a configured loss function is not a scoring loss function.
        """
        if self.loss_functions is None:
            raise ValueError("loss_functions must be set in the module")

        # Reshape in place to (num_queries, num_docs) scores and matching targets.
        output.scores = output.scores.view(len(batch.query_ids), -1)
        batch.targets = batch.targets.view(*output.scores.shape, -1)

        losses: List[torch.Tensor] = []
        # Each entry is a (loss_function, weight) pair; the weight is applied elsewhere.
        for loss_function, _ in self.loss_functions:
            if not isinstance(loss_function, ScoringLossFunction):
                raise RuntimeError(f"Loss function {loss_function} is not a scoring loss function")
            losses.append(loss_function.compute_loss(output, batch))
        return losses