Source code for lightning_ir.cross_encoder.cross_encoder_module
1"""
Module for cross-encoder models.
3
4This module defines the Lightning IR module class used to implement cross-encoder models.
5"""
6
7from collections.abc import Mapping, Sequence
8from typing import Any
9
10import torch
11from transformers import PreTrainedModel
12
13from ..base.module import LightningIRModule
14from ..data import RankBatch, SearchBatch, TrainBatch
15from ..loss.base import LossFunction, ScoringLossFunction
16from .cross_encoder_config import CrossEncoderConfig
17from .cross_encoder_model import CrossEncoderModel, CrossEncoderOutput
18from .cross_encoder_tokenizer import CrossEncoderTokenizer
19
20
class CrossEncoderModule(LightningIRModule):
    def __init__(
        self,
        model_name_or_path: str | None = None,
        config: CrossEncoderConfig | None = None,
        model: CrossEncoderModel | None = None,
        BackboneModel: type[PreTrainedModel] | None = None,
        loss_functions: Sequence[LossFunction | tuple[LossFunction, float]] | None = None,
        evaluation_metrics: Sequence[str] | None = None,
        model_kwargs: Mapping[str, Any] | None = None,
    ):
        """:class:`.LightningIRModule` for cross-encoder models. Wraps a :class:`.CrossEncoderModel` and a
        :class:`.CrossEncoderTokenizer` and provides the training, validation, and testing steps for the model.

        .. _ir-measures: https://ir-measur.es/en/latest/index.html

        Args:
            model_name_or_path (str | None): Name or path of a backbone model or a fine-tuned Lightning IR model.
                Defaults to None.
            config (CrossEncoderConfig | None): CrossEncoderConfig applied when loading from a backbone model.
                Defaults to None.
            model (CrossEncoderModel | None): Already instantiated CrossEncoderModel. Defaults to None.
            BackboneModel (type[PreTrainedModel] | None): Huggingface PreTrainedModel class used as the backbone
                in place of the default AutoModel. Defaults to None.
            loss_functions (Sequence[LossFunction | tuple[LossFunction, float]] | None):
                Loss functions applied during fine-tuning; an optional weight may accompany each loss function.
                Defaults to None.
            evaluation_metrics (Sequence[str] | None): ir-measures_ measure strings evaluated during validation
                or testing. Defaults to None.
            model_kwargs (Mapping[str, Any] | None): Extra keyword arguments forwarded to `from_pretrained` when
                loading a model. Defaults to None.
        """
        # Everything is forwarded unchanged to the base LightningIRModule.
        base_kwargs = dict(
            model_name_or_path=model_name_or_path,
            config=config,
            model=model,
            BackboneModel=BackboneModel,
            loss_functions=loss_functions,
            evaluation_metrics=evaluation_metrics,
            model_kwargs=model_kwargs,
        )
        super().__init__(**base_kwargs)
        # Narrow the inherited attribute annotations to the cross-encoder variants.
        self.model: CrossEncoderModel
        self.config: CrossEncoderConfig
        self.tokenizer: CrossEncoderTokenizer

    def forward(self, batch: RankBatch | TrainBatch | SearchBatch) -> CrossEncoderOutput:
        """Run a forward pass of the model on a batch of data, producing the contextualized embeddings from the
        backbone model together with the relevance scores.

        Args:
            batch (RankBatch | TrainBatch | SearchBatch): Batch of data to run the forward pass on.
        Returns:
            CrossEncoderOutput: Output of the model.
        Raises:
            ValueError: If the batch is a SearchBatch.
        """
        # Cross-encoders score (query, doc) pairs; they cannot retrieve, so searching is unsupported.
        if isinstance(batch, SearchBatch):
            raise ValueError("Searching is not available for cross-encoders")
        # Flatten the per-query document lists while remembering how many docs each query has.
        num_docs = [len(doc_group) for doc_group in batch.docs]
        flattened_docs = [doc for doc_group in batch.docs for doc in doc_group]
        encoding = self.prepare_input(batch.queries, flattened_docs, num_docs)
        return self.model.forward(encoding["encoding"])

    def _compute_losses(self, batch: TrainBatch, output: CrossEncoderOutput) -> list[torch.Tensor]:
        """Compute one loss tensor per configured loss function for a training batch."""
        if self.loss_functions is None:
            raise ValueError("loss_functions must be set in the module")

        # Reshape to one row of scores per query; targets mirror the score layout with a trailing axis.
        output.scores = output.scores.view(len(batch.query_ids), -1)
        batch.targets = batch.targets.view(*output.scores.shape, -1)

        losses: list[torch.Tensor] = []
        for fn, _weight in self.loss_functions:
            # Cross-encoders only produce scores, so every loss must be score-based.
            if not isinstance(fn, ScoringLossFunction):
                raise RuntimeError(f"Loss function {fn} is not a scoring loss function")
            losses.append(fn.compute_loss(output, batch))
        return losses