1"""
2Module module for bi-encoder models.
3
4This module defines the Lightning IR module class used to implement bi-encoder models.
5"""

from __future__ import annotations

from collections.abc import Mapping, Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Any

import torch
from transformers import BatchEncoding, PreTrainedModel

from ..base import LightningIRModule, LightningIROutput
from ..data import IndexBatch, RankBatch, SearchBatch, TrainBatch
from ..loss.base import EmbeddingLossFunction, LossFunction, ScoringLossFunction
from ..loss.in_batch import InBatchLossFunction
from .bi_encoder_config import BiEncoderConfig
from .bi_encoder_model import BiEncoderEmbedding, BiEncoderModel, BiEncoderOutput
from .bi_encoder_tokenizer import BiEncoderTokenizer

if TYPE_CHECKING:
    from ..retrieve import SearchConfig, Searcher


class BiEncoderModule(LightningIRModule):
    def __init__(
        self,
        model_name_or_path: str | None = None,
        config: BiEncoderConfig | None = None,
        model: BiEncoderModel | None = None,
        BackboneModel: type[PreTrainedModel] | None = None,
        loss_functions: Sequence[LossFunction | tuple[LossFunction, float]] | None = None,
        evaluation_metrics: Sequence[str] | None = None,
        index_dir: Path | None = None,
        search_config: SearchConfig | None = None,
        model_kwargs: Mapping[str, Any] | None = None,
    ):
41 """:class:`.LightningIRModule` for bi-encoder models. It contains a :class:`.BiEncoderModel` and a
42 :class:`.BiEncoderTokenizer` and implements the training, validation, and testing steps for the model.
43
44 .. _ir-measures: https://ir-measur.es/en/latest/index.html
45
46 Args:
47 model_name_or_path (str | None): Name or path of backbone model or fine-tuned Lightning IR model.
48 Defaults to None.
49 config (BiEncoderConfig | None): BiEncoderConfig to apply when loading from backbone model.
50 Defaults to None.
51 model (BiEncoderModel | None): Already instantiated BiEncoderModel. Defaults to None.
52 BackboneModel (type[PreTrainedModel] | None): Huggingface PreTrainedModel class to use as backbone
53 instead of the default AutoModel. Defaults to None.
54 loss_functions (Sequence[LossFunction | tuple[LossFunction, float]] | None):
55 Loss functions to apply during fine-tuning, optional loss weights can be provided per loss function
56 Defaults to None.
57 evaluation_metrics (Sequence[str] | None): Metrics corresponding to ir-measures_ measure strings
58 to apply during validation or testing. Defaults to None.
59 index_dir (Path | None): Path to an index used for retrieval. Defaults to None.
60 search_config (SearchConfig | None): Configuration to use during retrieval. Defaults to None.
61 model_kwargs (Mapping[str, Any] | None): Additional keyword arguments to pass to `from_pretrained`
62 when loading a model. Defaults to None.
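
        Example (a minimal sketch; ``bert-base-uncased`` is an assumed backbone, not one prescribed by
        this module)::

            module = BiEncoderModule(
                model_name_or_path="bert-base-uncased",  # assumption: any Hugging Face encoder checkpoint
                config=BiEncoderConfig(),
                evaluation_metrics=["nDCG@10"],
            )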
63 """
64 super().__init__(
65 model_name_or_path=model_name_or_path,
66 config=config,
67 model=model,
68 BackboneModel=BackboneModel,
69 loss_functions=loss_functions,
70 evaluation_metrics=evaluation_metrics,
71 model_kwargs=model_kwargs,
72 )
        self.model: BiEncoderModel
        self.config: BiEncoderConfig
        self.tokenizer: BiEncoderTokenizer
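        # The bi-encoder tokenizer may add tokens (e.g., query/document marker tokens) to the vocabulary;
        # if so, grow the embedding matrix to match, padded to a multiple of 8 for efficiency.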
        if len(self.tokenizer) > self.config.vocab_size:
            self.model.resize_token_embeddings(len(self.tokenizer), 8)
        self._searcher = None
        self.search_config = search_config
        self.index_dir = index_dir

    @property
    def searcher(self) -> Searcher | None:
        """Searcher used for retrieval if `index_dir` and `search_config` are set.

        Returns:
            Searcher | None: The searcher, or None if it has not been initialized.
        """
        return self._searcher

    @searcher.setter
    def searcher(self, searcher: Searcher):
        self._searcher = searcher

    def _init_searcher(self) -> None:
        if self.search_config is not None and self.index_dir is not None:
            self.searcher = self.search_config.search_class(self.index_dir, self.search_config, self)

    def on_test_start(self) -> None:
        """Called at the beginning of testing. Initializes the searcher if `index_dir` and `search_config` are set."""
        self._init_searcher()
        return super().on_test_start()

    def forward(self, batch: RankBatch | IndexBatch | SearchBatch) -> BiEncoderOutput:
        """Runs a forward pass of the model on a batch of data. The output varies depending on the type of batch.
        If the batch is a :class:`.RankBatch`, query and document embeddings are computed and the relevance score
        is the similarity between the two embeddings. If the batch is an :class:`.IndexBatch`, only document
        embeddings are computed. If the batch is a :class:`.SearchBatch`, only query embeddings are computed and
        the model will additionally retrieve documents if :attr:`.searcher` is set.

        Args:
            batch (RankBatch | IndexBatch | SearchBatch): Input batch containing queries and/or documents.
        Returns:
            BiEncoderOutput: Output of the model.
        Raises:
            ValueError: If the input batch contains neither queries nor documents.
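
        Example (a sketch; the hypothetical ``rank_batch``, ``index_batch``, and ``search_batch`` are
        assumed to come from a Lightning IR datamodule)::

            output = module.forward(rank_batch)    # query + doc embeddings and similarity scores
            output = module.forward(index_batch)   # doc embeddings only
            output = module.forward(search_batch)  # query embeddings, plus retrieval if a searcher is set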
117 """
118 queries = getattr(batch, "queries", None)
119 docs = getattr(batch, "docs", None)
120 num_docs = None
121 if isinstance(batch, RankBatch):
122 num_docs = None if docs is None else [len(d) for d in docs]
123 docs = [d for nested in docs for d in nested] if docs is not None else None
124 encodings = self.prepare_input(queries, docs, num_docs)
125
126 if not encodings:
127 raise ValueError("No encodings were generated.")
128 output = self.model.forward(
129 encodings.get("query_encoding", None), encodings.get("doc_encoding", None), num_docs
130 )
131 doc_ids = getattr(batch, "doc_ids", None)
132 if doc_ids is not None and output.doc_embeddings is not None:
133 output.doc_embeddings.ids = doc_ids
134 query_ids = getattr(batch, "query_ids", None)
135 if query_ids is not None and output.query_embeddings is not None:
136 output.query_embeddings.ids = query_ids
137 if isinstance(batch, SearchBatch) and self.searcher is not None:
138 scores, doc_ids = self.searcher.search(output)
139 output.scores = scores
140 if output.doc_embeddings is not None:
141 output.doc_embeddings.ids = [doc_id for _doc_ids in doc_ids for doc_id in _doc_ids]
142 batch.doc_ids = doc_ids
143 return output
144
    def score(self, queries: Sequence[str] | str, docs: Sequence[Sequence[str]] | Sequence[str]) -> BiEncoderOutput:
        """Computes relevance scores for queries and documents.

        Args:
            queries (Sequence[str] | str): Queries to score.
            docs (Sequence[Sequence[str]] | Sequence[str]): Documents to score.
        Returns:
            BiEncoderOutput: Output of the model.
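
        Example (a sketch; the query and document strings are illustrative)::

            output = module.score("what is ir?", ["ir is information retrieval", "ir is infrared"])
            print(output.scores)  # one relevance score per query-document pair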
153 """
154 return super().score(queries, docs)
155
    def _compute_losses(self, batch: TrainBatch, output: BiEncoderOutput) -> list[torch.Tensor]:
        """Computes the losses for a training batch."""
        if self.loss_functions is None:
            raise ValueError("Loss function is not set")

        if (
            batch.targets is None
            or output.query_embeddings is None
            or output.doc_embeddings is None
            or output.scores is None
        ):
            raise ValueError(
                "targets, scores, query_embeddings, and doc_embeddings must be set in the output and batch"
            )

        num_queries = len(batch.queries)
        output.scores = output.scores.view(num_queries, -1)
        batch.targets = batch.targets.view(*output.scores.shape, -1)
        losses = []
        for loss_function, _ in self.loss_functions:
            if isinstance(loss_function, InBatchLossFunction):
                pos_idcs, neg_idcs = loss_function.get_ib_idcs(output, batch)
                ib_doc_embeddings = self._get_ib_doc_embeddings(
                    output.doc_embeddings, pos_idcs, neg_idcs, num_queries
                )
                ib_scores = self.model.score(
                    BiEncoderOutput(query_embeddings=output.query_embeddings, doc_embeddings=ib_doc_embeddings)
                ).scores
                if ib_scores is None:
                    raise ValueError("In-batch scores cannot be None")
                ib_scores = ib_scores.view(num_queries, -1)
                losses.append(loss_function.compute_loss(LightningIROutput(ib_scores)))
            elif isinstance(loss_function, EmbeddingLossFunction):
                losses.append(loss_function.compute_loss(output))
            elif isinstance(loss_function, ScoringLossFunction):
                losses.append(loss_function.compute_loss(output, batch))
            else:
                raise ValueError(f"Unknown loss function type {loss_function.__class__.__name__}")
        if self.config.sparsification_strategy is not None:
            query_num_nonzero = (
                torch.nonzero(output.query_embeddings.embeddings).shape[0]
                / output.query_embeddings.embeddings.shape[0]
            )
            doc_num_nonzero = (
                torch.nonzero(output.doc_embeddings.embeddings).shape[0] / output.doc_embeddings.embeddings.shape[0]
            )
            self.log("query_num_nonzero", query_num_nonzero)
            self.log("doc_num_nonzero", doc_num_nonzero)
        return losses

    def _get_ib_doc_embeddings(
        self,
        embeddings: BiEncoderEmbedding,
        pos_idcs: torch.Tensor,
        neg_idcs: torch.Tensor,
        num_queries: int,
    ) -> BiEncoderEmbedding:
        """Gets the in-batch document embeddings for a training batch."""
        _, num_embs, emb_dim = embeddings.embeddings.shape
        ib_embeddings = torch.cat(
            [
                embeddings.embeddings[pos_idcs].view(num_queries, -1, num_embs, emb_dim),
                embeddings.embeddings[neg_idcs].view(num_queries, -1, num_embs, emb_dim),
            ],
            dim=1,
        ).view(-1, num_embs, emb_dim)
        if embeddings.scoring_mask is None:
            ib_scoring_mask = None
        else:
            ib_scoring_mask = torch.cat(
                [
                    embeddings.scoring_mask[pos_idcs].view(num_queries, -1, num_embs),
                    embeddings.scoring_mask[neg_idcs].view(num_queries, -1, num_embs),
                ],
                dim=1,
            ).view(-1, num_embs)
        if embeddings.encoding is None:
            ib_encoding = None
        else:
            ib_encoding = {}
            for key, value in embeddings.encoding.items():
                seq_len = value.shape[-1]
                ib_encoding[key] = torch.cat(
                    [value[pos_idcs].view(num_queries, -1, seq_len), value[neg_idcs].view(num_queries, -1, seq_len)],
                    dim=1,
                ).view(-1, seq_len)
            ib_encoding = BatchEncoding(ib_encoding)
        return BiEncoderEmbedding(ib_embeddings, ib_scoring_mask, ib_encoding)

    def validation_step(
        self,
        batch: TrainBatch | IndexBatch | SearchBatch | RankBatch,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> BiEncoderOutput:
        """Handles the validation step for the model.

        Args:
            batch (TrainBatch | IndexBatch | SearchBatch | RankBatch): Batch of validation or testing data.
            batch_idx (int): Index of the batch.
            dataloader_idx (int): Index of the dataloader. Defaults to 0.
        Returns:
            BiEncoderOutput: Output of the model.
        """
        if isinstance(batch, IndexBatch):
            return self.forward(batch)
        if isinstance(batch, (RankBatch, TrainBatch, SearchBatch)):
            return super().validation_step(batch, batch_idx, dataloader_idx)
        raise ValueError(f"Unknown batch type {type(batch)}")