1"""
2Model module for bi-encoder models.
3
4This module defines the model class used to implement bi-encoder models.
5"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from string import punctuation
from typing import Iterable, List, Literal, Self, Sequence, Tuple, Type, overload

import torch
from transformers import BatchEncoding

from ..base import LightningIRModel, LightningIROutput
from ..modeling_utils.batching import _batch_elementwise_scoring
from .bi_encoder_config import BiEncoderConfig, MultiVectorBiEncoderConfig, SingleVectorBiEncoderConfig


@dataclass
class BiEncoderEmbedding:
    """Dataclass containing embeddings and the encoding for bi-encoder models."""

    embeddings: torch.Tensor
    """Embedding tensor generated by a bi-encoder model of shape [batch_size x seq_len x hidden_size]. The sequence
    length varies depending on the pooling strategy and the hidden size varies depending on the projection settings."""
    scoring_mask: torch.Tensor | None = None
    """Mask tensor designating which vectors should be ignored during scoring."""
    encoding: BatchEncoding | None = None
    """Tokenizer encodings used to generate the embeddings."""
    ids: List[str] | None = None
    """List of ids for the embeddings, e.g., query or document ids."""

    @overload
    def to(self, device: torch.device, /) -> Self: ...

    @overload
    def to(self, other: Self, /) -> Self: ...

    def to(self, device) -> Self:
        """Moves the embeddings to the specified device.

        Args:
            device (torch.device | BiEncoderEmbedding): Device to move the embeddings to or another instance to move
                to the same device.
        Returns:
            Self: The instance with embeddings moved to the specified device.
        """
        if isinstance(device, BiEncoderEmbedding):
            device = device.device
        self.embeddings = self.embeddings.to(device)
        self.scoring_mask = self.scoring_mask.to(device) if self.scoring_mask is not None else None
        self.encoding = self.encoding.to(device) if self.encoding is not None else None
        return self
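
    # Usage sketch (illustrative, not part of the API): embeddings can be moved
    # to a device directly or aligned with another embedding's device:
    #
    #     doc_embedding = doc_embedding.to(torch.device("cuda:0"))
    #     query_embedding = query_embedding.to(doc_embedding)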

    @property
    def device(self) -> torch.device:
        """Returns the device of the embeddings.

        Returns:
            torch.device: Device of the embeddings.
        """
        return self.embeddings.device

    def items(self) -> Iterable[Tuple[str, torch.Tensor]]:
        """Iterates over the dataclass fields and their values like `dict.items()`.

        Yields:
            Tuple[str, torch.Tensor]: Tuple of field name and its value.
        """
        for field in self.__dataclass_fields__:
            yield field, getattr(self, field)


@dataclass
class BiEncoderOutput(LightningIROutput):
    """Dataclass containing the output of a bi-encoder model."""

    query_embeddings: BiEncoderEmbedding | None = None
    """Query embeddings generated by the model."""
    doc_embeddings: BiEncoderEmbedding | None = None
    """Document embeddings generated by the model."""
    similarity: torch.Tensor | None = None
    """Similarity scores between query and document embeddings."""


class BiEncoderModel(LightningIRModel, ABC):
    """A bi-encoder model that encodes queries and documents separately and computes a relevance score between them.
    See :class:`.BiEncoderConfig` for configuration options."""

    config_class: Type[BiEncoderConfig] = BiEncoderConfig
    """Configuration class for the bi-encoder model."""

    def __init__(self, config: BiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a bi-encoder model given a :class:`.BiEncoderConfig`.

        Args:
            config (BiEncoderConfig): Configuration for the bi-encoder model.
        Raises:
            ValueError: If the similarity function is not supported.
        """
        super().__init__(config, *args, **kwargs)
        self.config: BiEncoderConfig

        if self.config.similarity_function == "cosine":
            self.similarity_function = self._cosine_similarity
        elif self.config.similarity_function == "l2":
            self.similarity_function = self._l2_similarity
        elif self.config.similarity_function == "dot":
            self.similarity_function = self._dot_similarity
        else:
            raise ValueError(f"Unknown similarity function {self.config.similarity_function}")

    def forward(
        self,
        query_encoding: BatchEncoding | None,
        doc_encoding: BatchEncoding | None,
        num_docs: Sequence[int] | int | None = None,
    ) -> BiEncoderOutput:
        """Embeds queries and/or documents and computes relevance scores between them if both are provided.

        Args:
            query_encoding (BatchEncoding | None): Tokenizer encodings for the queries. Defaults to None.
            doc_encoding (BatchEncoding | None): Tokenizer encodings for the documents. Defaults to None.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence
                of integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to
                the number of documents, i.e., the sequence contains one value per query specifying the number of
                documents for that query. If an integer, assumes an equal number of documents per query. If None,
                tries to infer the number of documents by dividing the number of documents by the number of queries.
                Defaults to None.
        Returns:
            BiEncoderOutput: Output of the model containing query and document embeddings and relevance scores.
        """
        query_embeddings = None
        if query_encoding is not None:
            query_embeddings = self.encode_query(query_encoding)
        doc_embeddings = None
        if doc_encoding is not None:
            doc_embeddings = self.encode_doc(doc_encoding)
        output = BiEncoderOutput(query_embeddings=query_embeddings, doc_embeddings=doc_embeddings)
        if doc_embeddings is not None and query_embeddings is not None:
            output = self.score(output, num_docs)
        return output
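
    # Usage sketch (illustrative; `query_encoding` and `doc_encoding` are
    # BatchEncodings produced by the model's tokenizer, and `model` is a
    # concrete subclass):
    #
    #     output = model(query_encoding, doc_encoding, num_docs=[2, 1])
    #     output.scores      # one relevance score per query-document pair
    #     output.similarity  # raw similarity scores before aggregation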

    def encode_query(self, encoding: BatchEncoding) -> BiEncoderEmbedding:
        """Encodes tokenized queries.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the queries.
        Returns:
            BiEncoderEmbedding: Query embeddings and scoring mask.
        """
        return self.encode(encoding=encoding, input_type="query")

    def encode_doc(self, encoding: BatchEncoding) -> BiEncoderEmbedding:
        """Encodes tokenized documents.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the documents.
        Returns:
            BiEncoderEmbedding: Document embeddings and scoring mask.
        """
        return self.encode(encoding=encoding, input_type="doc")

    def _parse_num_docs(
        self, query_shape: int, doc_shape: int, num_docs: int | Sequence[int] | None, device: torch.device | None = None
    ) -> torch.Tensor:
        """Helper function to parse the number of documents per query."""
        if isinstance(num_docs, int):
            num_docs = [num_docs] * query_shape
        if isinstance(num_docs, list):
            if sum(num_docs) != doc_shape or len(num_docs) != query_shape:
                raise ValueError("num_docs does not match the number of document embeddings")
        if num_docs is None:
            if doc_shape % query_shape != 0:
                raise ValueError("Documents are not evenly distributed in the batch, but no num_docs provided")
            num_docs = [doc_shape // query_shape] * query_shape
        return torch.tensor(num_docs, device=device)
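
    # Worked example (illustrative): for 2 queries and 5 documents,
    # _parse_num_docs(2, 5, [3, 2]) -> tensor([3, 2]). With num_docs=None and
    # 4 documents, _parse_num_docs(2, 4, None) -> tensor([2, 2]); with 5
    # documents it raises a ValueError because 5 documents cannot be split
    # evenly across 2 queries.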

    def compute_similarity(
        self,
        query_embeddings: BiEncoderEmbedding,
        doc_embeddings: BiEncoderEmbedding,
        num_docs: Sequence[int] | int | None = None,
    ) -> torch.Tensor:
        """Computes the similarity score between all query and document embedding vector pairs.

        Args:
            query_embeddings (BiEncoderEmbedding): Embeddings of the queries.
            doc_embeddings (BiEncoderEmbedding): Embeddings of the documents.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence
                of integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to
                the number of documents, i.e., the sequence contains one value per query specifying the number of
                documents for that query. If an integer, assumes an equal number of documents per query. If None,
                tries to infer the number of documents by dividing the number of documents by the number of queries.
                Defaults to None.
        Returns:
            torch.Tensor: Similarity scores between all query and document embedding vector pairs.
        """
        num_docs_t = self._parse_num_docs(
            query_embeddings.embeddings.shape[0], doc_embeddings.embeddings.shape[0], num_docs, query_embeddings.device
        )
        query_emb = query_embeddings.embeddings.repeat_interleave(num_docs_t, dim=0).unsqueeze(2)
        doc_emb = doc_embeddings.embeddings.unsqueeze(1)
        similarity = self.similarity_function(query_emb, doc_emb)
        return similarity
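
    # Shape sketch (illustrative): with query embeddings [2, q_len, H], doc
    # embeddings [5, d_len, H], and num_docs=[3, 2], the queries are repeated
    # per document and the returned similarity tensor has shape
    # [5, q_len, d_len], i.e., one entry per query-document vector pair.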

    @staticmethod
    @_batch_elementwise_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def _cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.cosine_similarity(x, y, dim=-1)

    @staticmethod
    @_batch_elementwise_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def _l2_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return -1 * torch.cdist(x, y).squeeze(-2)

    @staticmethod
    @_batch_elementwise_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def _dot_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.matmul(x, y.transpose(-1, -2)).squeeze(-2)
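
    # Underlying math of the similarity functions (illustrative, ignoring the
    # batching decorator): for x of shape [..., 1, H] and y of shape
    # [..., N, H],
    #   cosine: F.cosine_similarity(x, y, dim=-1)       -> [..., N]
    #   l2:     -torch.cdist(x, y).squeeze(-2)          -> [..., N]
    #   dot:    (x @ y.transpose(-1, -2)).squeeze(-2)   -> [..., N]
    # Autocast is disabled so scores are computed in full precision.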

    @abstractmethod
    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes batched tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequences.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask for the text sequences.
        Raises:
            NotImplementedError: If the method is not implemented.
        """
        raise NotImplementedError

    @abstractmethod
    def score(self, output: BiEncoderOutput, num_docs: Sequence[int] | int | None = None) -> BiEncoderOutput:
        """Compute relevance scores between queries and documents.

        Args:
            output (BiEncoderOutput): Output containing embeddings and scoring mask.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence
                of integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to
                the number of documents, i.e., the sequence contains one value per query specifying the number of
                documents for that query. If an integer, assumes an equal number of documents per query. If None,
                tries to infer the number of documents by dividing the number of documents by the number of queries.
                Defaults to None.
        Returns:
            BiEncoderOutput: Output containing relevance scores.
        Raises:
            NotImplementedError: If the method is not implemented.
        """
        raise NotImplementedError


class SingleVectorBiEncoderModel(BiEncoderModel):
    """A bi-encoder model that encodes queries and documents separately, pools the contextualized embeddings into a
    single vector, and computes a relevance score based on the similarity between the two vectors. See
    :class:`.SingleVectorBiEncoderConfig` for configuration options."""

    config_class: Type[SingleVectorBiEncoderConfig] = SingleVectorBiEncoderConfig
    """Configuration class for the single-vector bi-encoder model."""

    def __init__(self, config: SingleVectorBiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a single-vector bi-encoder model given a :class:`.SingleVectorBiEncoderConfig`.

        Args:
            config (SingleVectorBiEncoderConfig): Configuration for the single-vector bi-encoder model.
        """
        super().__init__(config, *args, **kwargs)

    def score(
        self,
        output: BiEncoderOutput,
        num_docs: Sequence[int] | int | None = None,
    ) -> BiEncoderOutput:
        """Compute relevance scores between queries and documents.

        Args:
            output (BiEncoderOutput): Output containing embeddings and scoring mask.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence
                of integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to
                the number of documents, i.e., the sequence contains one value per query specifying the number of
                documents for that query. If an integer, assumes an equal number of documents per query. If None,
                tries to infer the number of documents by dividing the number of documents by the number of queries.
                Defaults to None.
        Returns:
            BiEncoderOutput: Output containing relevance scores.
        Raises:
            ValueError: If query or document embeddings are not provided in the output.
        """
        if output.query_embeddings is None or output.doc_embeddings is None:
            raise ValueError("Expected query and document embeddings in BiEncoderOutput")
        similarity = self.compute_similarity(output.query_embeddings, output.doc_embeddings, num_docs)
        output.scores = similarity.view(-1)
        output.similarity = similarity
        return output
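
    # Example (illustrative): for 2 queries with num_docs=[2, 1], the
    # similarity tensor has shape [3, 1, 1] (single-vector embeddings have one
    # vector each) and output.scores is a flat tensor of 3 relevance scores.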


class MultiVectorBiEncoderModel(BiEncoderModel):
    """A bi-encoder model that encodes queries and documents separately into multiple vectors and computes a relevance
    score by aggregating the similarities between all query and document vector pairs. See
    :class:`.MultiVectorBiEncoderConfig` for configuration options."""

    config_class: Type[MultiVectorBiEncoderConfig] = MultiVectorBiEncoderConfig
    """Configuration class for the multi-vector bi-encoder model."""

    supports_retrieval_models = []

    def __init__(self, config: MultiVectorBiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a multi-vector bi-encoder model given a :class:`.MultiVectorBiEncoderConfig`.

        Args:
            config (MultiVectorBiEncoderConfig): Configuration for the multi-vector bi-encoder model.
        Raises:
            ValueError: If mask scoring tokens are specified in the configuration but the tokenizer is not available.
            ValueError: If the specified mask scoring tokens are not in the tokenizer vocab.
        """
        super().__init__(config, *args, **kwargs)

        self.query_mask_scoring_input_ids: torch.Tensor | None = None
        self.doc_mask_scoring_input_ids: torch.Tensor | None = None

        # Add the mask scoring input ids to the model if they are specified in the configuration.
        for sequence in ("query", "doc"):
            mask_scoring_tokens = getattr(self.config, f"{sequence}_mask_scoring_tokens")
            if mask_scoring_tokens is None:
                continue
            if mask_scoring_tokens == "punctuation":
                mask_scoring_tokens = list(punctuation)
            try:
                from transformers import AutoTokenizer

                tokenizer = AutoTokenizer.from_pretrained(self.config.name_or_path)
            except OSError:
                raise ValueError("Can't use mask scoring tokens if the checkpoint does not have a tokenizer.")
            mask_scoring_input_ids = []
            for token in mask_scoring_tokens:
                if token not in tokenizer.vocab:
                    raise ValueError(f"Token {token} not in tokenizer vocab")
                mask_scoring_input_ids.append(tokenizer.vocab[token])
            setattr(
                self,
                f"{sequence}_mask_scoring_input_ids",
                torch.tensor(mask_scoring_input_ids, dtype=torch.long),
            )

    def _expand_mask(self, shape: torch.Size, mask: torch.Tensor, dim: int) -> torch.Tensor:
        """Helper function to expand the mask to the shape of the similarity scores."""
        if mask.ndim == len(shape):
            return mask
        if mask.ndim > len(shape):
            raise ValueError("Mask has too many dimensions")
        fill_values = len(shape) - mask.ndim + 1
        new_shape = [*mask.shape[:-1]] + [1] * fill_values
        new_shape[dim] = mask.shape[-1]
        return mask.view(*new_shape)
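
    # Worked example (illustrative): a document mask of shape [batch, d_len]
    # expanded to score shape [batch, q_len, d_len] with dim=-1 becomes a view
    # of shape [batch, 1, d_len]; with dim=-2 a query mask becomes
    # [batch, q_len, 1]. Both broadcast against the score tensor.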

    def _aggregate(
        self,
        scores: torch.Tensor,
        mask: torch.Tensor,
        aggregation_function: Literal["max", "sum", "mean"],
        dim: int,
    ) -> torch.Tensor:
        """Helper function to aggregate similarity scores over query and document embeddings."""
        mask = self._expand_mask(scores.shape, mask, dim)
        if aggregation_function == "max":
            scores = scores.masked_fill(~mask, float("-inf"))
            return scores.amax(dim, keepdim=True)
        scores = scores.masked_fill(~mask, 0)
        if aggregation_function == "sum":
            return scores.sum(dim, keepdim=True)
        num_non_masked = mask.sum(dim, keepdim=True)
        if aggregation_function == "mean":
            return scores.masked_fill(num_non_masked == 0, 0).sum(dim, keepdim=True) / num_non_masked.clamp(min=1)
        raise ValueError(f"Unknown aggregation {aggregation_function}")
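
    # Worked example (illustrative): for scores [[1.0, 3.0, 2.0]] with mask
    # [[True, True, False]] and dim=-1, "max" yields 3.0, "sum" yields 4.0,
    # and "mean" yields 2.0; masked positions never contribute to the result.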

    def scoring_mask(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> torch.Tensor:
        """Computes a scoring mask for batched tokenized text sequences. The mask is used in the scoring function to
        mask out vectors during scoring.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequences.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            torch.Tensor: Scoring mask.
        """
        input_ids = encoding["input_ids"]
        attention_mask = encoding["attention_mask"]
        scoring_mask = attention_mask
        if scoring_mask is None:
            scoring_mask = torch.ones_like(input_ids, dtype=torch.bool)
        scoring_mask = scoring_mask.bool()
        mask_scoring_input_ids = getattr(self, f"{input_type}_mask_scoring_input_ids")
        if mask_scoring_input_ids is not None:
            ignore_mask = input_ids[..., None].eq(mask_scoring_input_ids.to(input_ids.device)).any(-1)
            scoring_mask = scoring_mask & ~ignore_mask
        return scoring_mask
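
    # Example (illustrative): with doc_mask_scoring_tokens="punctuation", the
    # vectors of tokens such as "," and "!" receive a False entry in the
    # scoring mask and are ignored when similarities are aggregated.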

    def aggregate_similarity(
        self,
        similarity: torch.Tensor,
        query_scoring_mask: torch.Tensor,
        doc_scoring_mask: torch.Tensor,
        num_docs: int | Sequence[int] | None = None,
    ) -> torch.Tensor:
        """Aggregates the matrix of query-document similarities into a single score based on the configured
        aggregation strategy.

        Args:
            similarity (torch.Tensor): Query-document similarity matrix.
            query_scoring_mask (torch.Tensor): Which query vectors should be masked out during scoring.
            doc_scoring_mask (torch.Tensor): Which document vectors should be masked out during scoring.
            num_docs (int | Sequence[int] | None): Specifies how many documents are passed per query. If an integer,
                assumes an equal number of documents per query. If None, tries to infer the number of documents by
                dividing the number of documents by the number of queries. Defaults to None.
        Returns:
            torch.Tensor: Aggregated similarity scores.
        """
        scores = similarity
        if scores.shape[-1] > 1:
            scores = self._aggregate(scores, doc_scoring_mask, self.config.doc_aggregation_function, -1)
        if scores.shape[-2] > 1:
            num_docs_t = self._parse_num_docs(
                query_scoring_mask.shape[0], doc_scoring_mask.shape[0], num_docs, similarity.device
            )
            repeated_query_scoring_mask = query_scoring_mask.repeat_interleave(num_docs_t, dim=0)
            scores = self._aggregate(scores, repeated_query_scoring_mask, self.config.query_aggregation_function, -2)
        return scores.view(scores.shape[0])
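
    # Shape walk-through (illustrative): a similarity tensor
    # [total_docs, q_len, d_len] is first reduced over document vectors (e.g.,
    # "max") to [total_docs, q_len, 1], then over query vectors (e.g., "sum")
    # to [total_docs, 1, 1], and finally flattened to one score per
    # query-document pair. Max over documents plus sum over query terms
    # corresponds to ColBERT-style late interaction.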

    def score(self, output: BiEncoderOutput, num_docs: Sequence[int] | int | None = None) -> BiEncoderOutput:
        """Compute relevance scores between queries and documents.

        Args:
            output (BiEncoderOutput): Output containing embeddings and scoring mask.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence
                of integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to
                the number of documents, i.e., the sequence contains one value per query specifying the number of
                documents for that query. If an integer, assumes an equal number of documents per query. If None,
                tries to infer the number of documents by dividing the number of documents by the number of queries.
                Defaults to None.
        Returns:
            BiEncoderOutput: Output containing relevance scores.
        Raises:
            ValueError: If query or document embeddings are not provided in the output.
            ValueError: If scoring masks are not provided for the embeddings.
        """
        query_embeddings = output.query_embeddings
        doc_embeddings = output.doc_embeddings
        if query_embeddings is None or doc_embeddings is None:
            raise ValueError("Embeddings expected for scoring multi-vector embeddings")
        if query_embeddings.scoring_mask is None or doc_embeddings.scoring_mask is None:
            raise ValueError("Scoring masks expected for scoring multi-vector embeddings")

        similarity = self.compute_similarity(query_embeddings, doc_embeddings, num_docs)
        scores = self.aggregate_similarity(
            similarity, query_embeddings.scoring_mask, doc_embeddings.scoring_mask, num_docs
        )
        output.scores = scores
        output.similarity = similarity
        return output