1"""
2Module module for bi-encoder models.
3
4This module defines the Lightning IR module class used to implement bi-encoder models.
5"""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any, List, Mapping, Sequence, Tuple

import torch
from transformers import BatchEncoding

from ..base import LightningIRModule, LightningIROutput
from ..data import IndexBatch, RankBatch, SearchBatch, TrainBatch
from ..loss.loss import EmbeddingLossFunction, InBatchLossFunction, LossFunction, ScoringLossFunction
from .bi_encoder_config import BiEncoderConfig
from .bi_encoder_model import BiEncoderEmbedding, BiEncoderModel, BiEncoderOutput
from .bi_encoder_tokenizer import BiEncoderTokenizer

if TYPE_CHECKING:
    from ..retrieve import SearchConfig, Searcher


class BiEncoderModule(LightningIRModule):
    def __init__(
        self,
        model_name_or_path: str | None = None,
        config: BiEncoderConfig | None = None,
        model: BiEncoderModel | None = None,
        loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None,
        evaluation_metrics: Sequence[str] | None = None,
        index_dir: Path | None = None,
        search_config: SearchConfig | None = None,
        model_kwargs: Mapping[str, Any] | None = None,
    ):
38 """:class:`.LightningIRModule` for bi-encoder models. It contains a :class:`.BiEncoderModel` and a
39 :class:`.BiEncoderTokenizer` and implements the training, validation, and testing steps for the model.
40
41 .. _ir-measures: https://ir-measur.es/en/latest/index.html
42
43 :param model_name_or_path: Name or path of backbone model or fine-tuned Lightning IR model, defaults to None
44 :type model_name_or_path: str | None, optional
45 :param config: BiEncoderConfig to apply when loading from backbone model, defaults to None
46 :type config: BiEncoderConfig | None, optional
47 :param model: Already instantiated BiEncoderModel, defaults to None
48 :type model: BiEncoderModel | None, optional
49 :param loss_functions: Loss functions to apply during fine-tuning, optional loss weights can be provided per
50 loss function, defaults to None
51 :type loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None, optional
52 :param evaluation_metrics: Metrics corresponding to ir-measures_ measure strings to apply during validation or
53 testing, defaults to None
54 :type evaluation_metrics: Sequence[str] | None, optional
55 :param index_dir: Path to an index used for retrieval, defaults to None
56 :type index_dir: Path | None, optional
57 :param search_config: Configuration to use during retrieval, defaults to None
58 :type search_config: SearchConfig | None, optional
59 :param model_kwargs: Additional keyword arguments to pass to `from_pretrained` when loading a model,
60 defaults to None
61 :type model_kwargs: Mapping[str, Any] | None, optional
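
        A minimal usage sketch (the checkpoint name and metric are illustrative, not prescribed by this module):

        .. code-block:: python

            module = BiEncoderModule(
                model_name_or_path="bert-base-uncased",  # any backbone or fine-tuned checkpoint
                config=BiEncoderConfig(),
                evaluation_metrics=["nDCG@10"],  # ir-measures measure string
            )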
62 """
        super().__init__(model_name_or_path, config, model, loss_functions, evaluation_metrics, model_kwargs)
        self.model: BiEncoderModel
        self.config: BiEncoderConfig
        self.tokenizer: BiEncoderTokenizer
        if self.config.add_marker_tokens and len(self.tokenizer) > self.config.vocab_size:
            # Marker tokens were added to the tokenizer; grow the embedding matrix to match,
            # padding its size to a multiple of 8 for hardware efficiency.
            self.model.resize_token_embeddings(len(self.tokenizer), 8)
        self._searcher = None
        self.search_config = search_config
        self.index_dir = index_dir

    @property
    def searcher(self) -> Searcher | None:
        """Searcher used for retrieval if `index_dir` and `search_config` are set.

        :return: Searcher instance if set, otherwise None
        :rtype: Searcher | None
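
        An illustrative setup sketch (the paths are placeholders, and ``FaissSearchConfig`` is assumed to be one of
        the available :class:`.SearchConfig` implementations):

        .. code-block:: python

            from pathlib import Path

            from lightning_ir.retrieve import FaissSearchConfig

            module = BiEncoderModule(
                model_name_or_path="bert-base-uncased",  # placeholder checkpoint
                index_dir=Path("path/to/index"),  # placeholder index directory
                search_config=FaissSearchConfig(k=10),
            )
            # the searcher is initialized from index_dir and search_config when testing starts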
79 """
        return self._searcher

    @searcher.setter
    def searcher(self, searcher: Searcher):
        self._searcher = searcher

    def _init_searcher(self) -> None:
        if self.search_config is not None and self.index_dir is not None:
            self.searcher = self.search_config.search_class(self.index_dir, self.search_config, self)

    def on_test_start(self) -> None:
        """Called at the beginning of testing. Initializes the searcher if `index_dir` and `search_config` are set."""
        self._init_searcher()
        return super().on_test_start()

    def forward(self, batch: RankBatch | IndexBatch | SearchBatch) -> BiEncoderOutput:
96 """Runs a forward pass of the model on a batch of data. The output will vary depending on the type of batch. If
97 the batch is a :class`.RankBatch`, query and document embeddings are computed and the relevance score is the
98 similarity between the two embeddings. If the batch is an :class:`.IndexBatch`, only document embeddings
99 are comuputed. If the batch is a :class:`.SearchBatch`, only query embeddings are computed and
100 the model will additionally retrieve documents if :attr:`.searcher` is set.
101
102 :param batch: Input batch containg
103 :type batch: RankBatch | IndexBatch | SearchBatch
104 :raises ValueError: If the input batch contains neither queries nor documents
105 :return: Output of the model
106 :rtype: BiEncoderOutput
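
        An illustrative sketch (assuming :class:`.RankBatch` accepts these fields as keyword arguments):

        .. code-block:: python

            batch = RankBatch(
                query_ids=["q1"],
                queries=["what is information retrieval"],
                doc_ids=[["d1", "d2"]],
                docs=[["IR finds relevant documents.", "Unrelated text."]],
            )
            output = module.forward(batch)  # output.scores holds one score per document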
107 """
        queries = getattr(batch, "queries", None)
        docs = getattr(batch, "docs", None)
        num_docs = None
        if isinstance(batch, RankBatch):
            # Flatten the nested per-query document lists and remember how many documents belong to each query.
            num_docs = None if docs is None else [len(d) for d in docs]
            docs = [d for nested in docs for d in nested] if docs is not None else None
        encodings = self.prepare_input(queries, docs, num_docs)

        if not encodings:
            raise ValueError("No encodings were generated.")
        output = self.model.forward(
            encodings.get("query_encoding", None), encodings.get("doc_encoding", None), num_docs
        )
        # Attach query and document ids to the embeddings if the batch provides them.
        doc_ids = getattr(batch, "doc_ids", None)
        if doc_ids is not None and output.doc_embeddings is not None:
            output.doc_embeddings.ids = doc_ids
        query_ids = getattr(batch, "query_ids", None)
        if query_ids is not None and output.query_embeddings is not None:
            output.query_embeddings.ids = query_ids
        if isinstance(batch, SearchBatch) and self.searcher is not None:
            # Retrieve documents for the query embeddings and record the retrieved ids on the batch.
            scores, doc_ids = self.searcher.search(output)
            output.scores = scores
            if output.doc_embeddings is not None:
                output.doc_embeddings.ids = [doc_id for _doc_ids in doc_ids for doc_id in _doc_ids]
            batch.doc_ids = doc_ids
        return output

    def score(self, queries: Sequence[str] | str, docs: Sequence[Sequence[str]] | Sequence[str]) -> BiEncoderOutput:
136 """Computes relevance scores for queries and documents.
137
138 :param queries: Queries to score
139 :type queries: Sequence[str]
140 :param docs: Documents to score
141 :type docs: Sequence[Sequence[str]]
142 :return: Model output
143 :rtype: BiEncoderOutput
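
        An illustrative sketch:

        .. code-block:: python

            output = module.score(
                "what is information retrieval",
                ["IR finds relevant documents.", "Unrelated text."],
            )
            print(output.scores)  # one relevance score per document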
144 """
        return super().score(queries, docs)

    def _compute_losses(self, batch: TrainBatch, output: BiEncoderOutput) -> List[torch.Tensor]:
        """Computes the losses for a training batch."""
        if self.loss_functions is None:
            raise ValueError("Loss functions are not set")

        if (
            batch.targets is None
            or output.query_embeddings is None
            or output.doc_embeddings is None
            or output.scores is None
        ):
            raise ValueError(
                "targets, scores, query_embeddings, and doc_embeddings must be set in the output and batch"
            )

        num_queries = len(batch.queries)
        output.scores = output.scores.view(num_queries, -1)
        batch.targets = batch.targets.view(*output.scores.shape, -1)
        losses = []
        for loss_function, _ in self.loss_functions:
            if isinstance(loss_function, InBatchLossFunction):
                # Sample in-batch positives and negatives and re-score them for the in-batch loss.
                pos_idcs, neg_idcs = loss_function.get_ib_idcs(output, batch)
                ib_doc_embeddings = self._get_ib_doc_embeddings(output.doc_embeddings, pos_idcs, neg_idcs, num_queries)
                ib_scores = self.model.score(output.query_embeddings, ib_doc_embeddings)
                ib_scores = ib_scores.view(num_queries, -1)
                losses.append(loss_function.compute_loss(LightningIROutput(ib_scores)))
            elif isinstance(loss_function, EmbeddingLossFunction):
                losses.append(loss_function.compute_loss(output))
            elif isinstance(loss_function, ScoringLossFunction):
                losses.append(loss_function.compute_loss(output, batch))
            else:
                raise ValueError(f"Unknown loss function type {loss_function.__class__.__name__}")
        if self.config.sparsification is not None:
            # Log the average number of non-zero entries per query and document embedding.
            query_num_nonzero = (
                torch.nonzero(output.query_embeddings.embeddings).shape[0] / output.query_embeddings.embeddings.shape[0]
            )
            doc_num_nonzero = (
                torch.nonzero(output.doc_embeddings.embeddings).shape[0] / output.doc_embeddings.embeddings.shape[0]
            )
            self.log("query_num_nonzero", query_num_nonzero)
            self.log("doc_num_nonzero", doc_num_nonzero)
        return losses

    def _get_ib_doc_embeddings(
        self,
        embeddings: BiEncoderEmbedding,
        pos_idcs: torch.Tensor,
        neg_idcs: torch.Tensor,
        num_queries: int,
    ) -> BiEncoderEmbedding:
        """Gets the in-batch document embeddings for a training batch."""
        _, num_embs, emb_dim = embeddings.embeddings.shape
        # Gather the sampled positive and negative document embeddings and concatenate them per query.
        ib_embeddings = torch.cat(
            [
                embeddings.embeddings[pos_idcs].view(num_queries, -1, num_embs, emb_dim),
                embeddings.embeddings[neg_idcs].view(num_queries, -1, num_embs, emb_dim),
            ],
            dim=1,
        ).view(-1, num_embs, emb_dim)
        if embeddings.scoring_mask is None:
            ib_scoring_mask = None
        else:
            ib_scoring_mask = torch.cat(
                [
                    embeddings.scoring_mask[pos_idcs].view(num_queries, -1, num_embs),
                    embeddings.scoring_mask[neg_idcs].view(num_queries, -1, num_embs),
                ],
                dim=1,
            ).view(-1, num_embs)
        if embeddings.encoding is None:
            ib_encoding = None
        else:
            # Rebuild the batch encoding so that it matches the gathered embeddings.
            ib_encoding = {}
            for key, value in embeddings.encoding.items():
                seq_len = value.shape[-1]
                ib_encoding[key] = torch.cat(
                    [value[pos_idcs].view(num_queries, -1, seq_len), value[neg_idcs].view(num_queries, -1, seq_len)],
                    dim=1,
                ).view(-1, seq_len)
            ib_encoding = BatchEncoding(ib_encoding)
        return BiEncoderEmbedding(ib_embeddings, ib_scoring_mask, ib_encoding)

    def validation_step(
        self,
        batch: TrainBatch | IndexBatch | SearchBatch | RankBatch,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> BiEncoderOutput:
        """Handles the validation step for the model.

        :param batch: Batch of validation or testing data
        :type batch: TrainBatch | IndexBatch | SearchBatch | RankBatch
        :param batch_idx: Index of the batch
        :type batch_idx: int
        :param dataloader_idx: Index of the dataloader, defaults to 0
        :type dataloader_idx: int, optional
        :return: Model output
        :rtype: BiEncoderOutput
        """
        if isinstance(batch, IndexBatch):
            return self.forward(batch)
        if isinstance(batch, (RankBatch, TrainBatch, SearchBatch)):
            return super().validation_step(batch, batch_idx, dataloader_idx)
        raise ValueError(f"Unknown batch type {type(batch)}")