1"""
2Module module for bi-encoder models.
3
4This module defines the Lightning IR module class used to implement bi-encoder models.
5"""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any, List, Mapping, Sequence, Tuple

import torch
from transformers import BatchEncoding

from ..base import LightningIRModule, LightningIROutput
from ..data import IndexBatch, RankBatch, SearchBatch, TrainBatch
from ..loss.base import EmbeddingLossFunction, LossFunction, ScoringLossFunction
from ..loss.in_batch import InBatchLossFunction
from .bi_encoder_config import BiEncoderConfig
from .bi_encoder_model import BiEncoderEmbedding, BiEncoderModel, BiEncoderOutput
from .bi_encoder_tokenizer import BiEncoderTokenizer

if TYPE_CHECKING:
    from ..retrieve import SearchConfig, Searcher

class BiEncoderModule(LightningIRModule):
    def __init__(
        self,
        model_name_or_path: str | None = None,
        config: BiEncoderConfig | None = None,
        model: BiEncoderModel | None = None,
        loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None,
        evaluation_metrics: Sequence[str] | None = None,
        index_dir: Path | None = None,
        search_config: SearchConfig | None = None,
        model_kwargs: Mapping[str, Any] | None = None,
    ):
39 """:class:`.LightningIRModule` for bi-encoder models. It contains a :class:`.BiEncoderModel` and a
40 :class:`.BiEncoderTokenizer` and implements the training, validation, and testing steps for the model.
41
42 .. _ir-measures: https://ir-measur.es/en/latest/index.html
43
44 Args:
45 model_name_or_path (str | None): Name or path of backbone model or fine-tuned Lightning IR model.
46 Defaults to None.
47 config (BiEncoderConfig | None): BiEncoderConfig to apply when loading from backbone model.
48 Defaults to None.
49 model (BiEncoderModel | None): Already instantiated BiEncoderModel. Defaults to None.
50 loss_functions (Sequence[LossFunction | Tuple[LossFunction, float]] | None):
51 Loss functions to apply during fine-tuning, optional loss weights can be provided per loss function
52 Defaults to None.
53 evaluation_metrics (Sequence[str] | None): Metrics corresponding to ir-measures_ measure strings
54 to apply during validation or testing. Defaults to None.
55 index_dir (Path | None): Path to an index used for retrieval. Defaults to None.
56 search_config (SearchConfig | None): Configuration to use during retrieval. Defaults to None.
57 model_kwargs (Mapping[str, Any] | None): Additional keyword arguments to pass to `from_pretrained`
58 when loading a model. Defaults to None.
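
        Example (illustrative; ``"bert-base-uncased"`` stands in for any supported backbone checkpoint)::

            >>> module = BiEncoderModule("bert-base-uncased", config=BiEncoderConfig())
            >>> output = module.score("what is ir?", ["a relevant passage", "an irrelevant passage"])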
59 """
        super().__init__(model_name_or_path, config, model, loss_functions, evaluation_metrics, model_kwargs)
        self.model: BiEncoderModel
        self.config: BiEncoderConfig
        self.tokenizer: BiEncoderTokenizer
        if len(self.tokenizer) > self.config.vocab_size:
            # the tokenizer may have added special tokens; grow the embedding matrix to match,
            # padding its size to a multiple of 8 for efficiency
            self.model.resize_token_embeddings(len(self.tokenizer), 8)
        self._searcher = None
        self.search_config = search_config
        self.index_dir = index_dir

    @property
    def searcher(self) -> Searcher | None:
        """Searcher used for retrieval if `index_dir` and `search_config` are set.

        Returns:
            Searcher | None: The searcher instance, or None if no searcher has been set.
        """
        return self._searcher

    @searcher.setter
    def searcher(self, searcher: Searcher):
        self._searcher = searcher

    def _init_searcher(self) -> None:
        if self.search_config is not None and self.index_dir is not None:
            self.searcher = self.search_config.search_class(self.index_dir, self.search_config, self)

    def on_test_start(self) -> None:
        """Called at the beginning of testing. Initializes the searcher if `index_dir` and `search_config` are set."""
        self._init_searcher()
        return super().on_test_start()

    def forward(self, batch: RankBatch | IndexBatch | SearchBatch) -> BiEncoderOutput:
        """Runs a forward pass of the model on a batch of data. The output varies depending on the type of batch. If
        the batch is a :class:`.RankBatch`, query and document embeddings are computed and the relevance score is the
        similarity between the two embeddings. If the batch is an :class:`.IndexBatch`, only document embeddings
        are computed. If the batch is a :class:`.SearchBatch`, only query embeddings are computed and
        the model will additionally retrieve documents if :attr:`.searcher` is set.

        Args:
            batch (RankBatch | IndexBatch | SearchBatch): Input batch containing queries and/or documents.
        Returns:
            BiEncoderOutput: Output of the model.
        Raises:
            ValueError: If the input batch contains neither queries nor documents.
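
        Example (illustrative sketch; assumes :class:`.RankBatch` accepts these fields as keyword arguments)::

            >>> module = BiEncoderModule("bert-base-uncased")  # hypothetical backbone checkpoint
            >>> batch = RankBatch(
            ...     queries=["what is ir?"],
            ...     docs=[["first passage", "second passage"]],
            ...     query_ids=["q1"],
            ...     doc_ids=[["d1", "d2"]],
            ... )
            >>> output = module.forward(batch)  # output.scores holds one score per query-document pair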
105 """
        queries = getattr(batch, "queries", None)
        docs = getattr(batch, "docs", None)
        num_docs = None
        if isinstance(batch, RankBatch):
            # flatten the nested per-query document lists and remember how many documents belong to each query
            num_docs = None if docs is None else [len(d) for d in docs]
            docs = [d for nested in docs for d in nested] if docs is not None else None
        encodings = self.prepare_input(queries, docs, num_docs)

        if not encodings:
            raise ValueError("No encodings were generated.")
        output = self.model.forward(
            encodings.get("query_encoding", None), encodings.get("doc_encoding", None), num_docs
        )
        doc_ids = getattr(batch, "doc_ids", None)
        if doc_ids is not None and output.doc_embeddings is not None:
            output.doc_embeddings.ids = doc_ids
        query_ids = getattr(batch, "query_ids", None)
        if query_ids is not None and output.query_embeddings is not None:
            output.query_embeddings.ids = query_ids
        if isinstance(batch, SearchBatch) and self.searcher is not None:
            # retrieve documents for the query embeddings and attach the retrieved ids to the batch
            scores, doc_ids = self.searcher.search(output)
            output.scores = scores
            if output.doc_embeddings is not None:
                output.doc_embeddings.ids = [doc_id for _doc_ids in doc_ids for doc_id in _doc_ids]
            batch.doc_ids = doc_ids
        return output

    def score(self, queries: Sequence[str] | str, docs: Sequence[Sequence[str]] | Sequence[str]) -> BiEncoderOutput:
        """Computes relevance scores for queries and documents.

        Args:
            queries (Sequence[str] | str): Queries to score.
            docs (Sequence[Sequence[str]] | Sequence[str]): Documents to score.
        Returns:
            BiEncoderOutput: Output of the model.
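
        Example (illustrative; one nested list of documents per query)::

            >>> module = BiEncoderModule("bert-base-uncased")  # hypothetical backbone checkpoint
            >>> output = module.score(["q1", "q2"], [["passage for q1"], ["passage for q2"]])
            >>> output.scores  # one relevance score per query-document pair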
141 """
142 return super().score(queries, docs)
143
    def _compute_losses(self, batch: TrainBatch, output: BiEncoderOutput) -> List[torch.Tensor]:
        """Computes the losses for a training batch."""
        if self.loss_functions is None:
            raise ValueError("Loss function is not set")

        if (
            batch.targets is None
            or output.query_embeddings is None
            or output.doc_embeddings is None
            or output.scores is None
        ):
            raise ValueError(
                "targets, scores, query_embeddings, and doc_embeddings must be set in the output and batch"
            )

        num_queries = len(batch.queries)
        output.scores = output.scores.view(num_queries, -1)
        batch.targets = batch.targets.view(*output.scores.shape, -1)
        losses = []
        for loss_function, _ in self.loss_functions:
            if isinstance(loss_function, InBatchLossFunction):
                # rescore the query embeddings against in-batch positive and negative documents
                pos_idcs, neg_idcs = loss_function.get_ib_idcs(output, batch)
                ib_doc_embeddings = self._get_ib_doc_embeddings(output.doc_embeddings, pos_idcs, neg_idcs, num_queries)
                ib_scores = self.model.score(
                    BiEncoderOutput(query_embeddings=output.query_embeddings, doc_embeddings=ib_doc_embeddings)
                ).scores
                if ib_scores is None:
                    raise ValueError("In-batch scores cannot be None")
                ib_scores = ib_scores.view(num_queries, -1)
                losses.append(loss_function.compute_loss(LightningIROutput(ib_scores)))
            elif isinstance(loss_function, EmbeddingLossFunction):
                losses.append(loss_function.compute_loss(output))
            elif isinstance(loss_function, ScoringLossFunction):
                losses.append(loss_function.compute_loss(output, batch))
            else:
                raise ValueError(f"Unknown loss function type {loss_function.__class__.__name__}")
        if self.config.sparsification is not None:
            # log the average number of non-zero entries per query and document embedding
            query_num_nonzero = (
                torch.nonzero(output.query_embeddings.embeddings).shape[0] / output.query_embeddings.embeddings.shape[0]
            )
            doc_num_nonzero = (
                torch.nonzero(output.doc_embeddings.embeddings).shape[0] / output.doc_embeddings.embeddings.shape[0]
            )
            self.log("query_num_nonzero", query_num_nonzero)
            self.log("doc_num_nonzero", doc_num_nonzero)
        return losses

    def _get_ib_doc_embeddings(
        self,
        embeddings: BiEncoderEmbedding,
        pos_idcs: torch.Tensor,
        neg_idcs: torch.Tensor,
        num_queries: int,
    ) -> BiEncoderEmbedding:
        """Gets the in-batch document embeddings for a training batch."""
        _, num_embs, emb_dim = embeddings.embeddings.shape
        ib_embeddings = torch.cat(
            [
                embeddings.embeddings[pos_idcs].view(num_queries, -1, num_embs, emb_dim),
                embeddings.embeddings[neg_idcs].view(num_queries, -1, num_embs, emb_dim),
            ],
            dim=1,
        ).view(-1, num_embs, emb_dim)
        if embeddings.scoring_mask is None:
            ib_scoring_mask = None
        else:
            ib_scoring_mask = torch.cat(
                [
                    embeddings.scoring_mask[pos_idcs].view(num_queries, -1, num_embs),
                    embeddings.scoring_mask[neg_idcs].view(num_queries, -1, num_embs),
                ],
                dim=1,
            ).view(-1, num_embs)
        if embeddings.encoding is None:
            ib_encoding = None
        else:
            ib_encoding = {}
            for key, value in embeddings.encoding.items():
                seq_len = value.shape[-1]
                ib_encoding[key] = torch.cat(
                    [value[pos_idcs].view(num_queries, -1, seq_len), value[neg_idcs].view(num_queries, -1, seq_len)],
                    dim=1,
                ).view(-1, seq_len)
            ib_encoding = BatchEncoding(ib_encoding)
        return BiEncoderEmbedding(ib_embeddings, ib_scoring_mask, ib_encoding)

    def validation_step(
        self,
        batch: TrainBatch | IndexBatch | SearchBatch | RankBatch,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> BiEncoderOutput:
        """Handles the validation step for the model.

        Args:
            batch (TrainBatch | IndexBatch | SearchBatch | RankBatch): Batch of validation or testing data.
            batch_idx (int): Index of the batch.
            dataloader_idx (int, optional): Index of the dataloader. Defaults to 0.
        Returns:
            BiEncoderOutput: Output of the model.
        """
        if isinstance(batch, IndexBatch):
            return self.forward(batch)
        if isinstance(batch, (RankBatch, TrainBatch, SearchBatch)):
            return super().validation_step(batch, batch_idx, dataloader_idx)
        raise ValueError(f"Unknown batch type {type(batch)}")