1"""
2Model module for bi-encoder models.
3
4This module defines the model class used to implement bi-encoder models.
5"""
6
7from abc import ABC, abstractmethod
8from collections.abc import Iterable, Sequence
9from dataclasses import dataclass
10from string import punctuation
11from typing import Literal, Self, overload
12
13import torch
14from transformers import BatchEncoding
15
16from ..base import LightningIRModel, LightningIROutput
17from ..modeling_utils.batching import _batch_elementwise_scoring
18from .bi_encoder_config import BiEncoderConfig, MultiVectorBiEncoderConfig, SingleVectorBiEncoderConfig
19
20
@dataclass
class BiEncoderEmbedding:
    """Dataclass containing embeddings and the encoding for bi-encoder models."""

    embeddings: torch.Tensor
    """Embedding tensor generated by a bi-encoder model of shape [batch_size x seq_len x hidden_size]. The sequence
    length varies depending on the pooling strategy and the hidden size varies depending on the projection settings."""
    scoring_mask: torch.Tensor | None = None
    """Mask tensor indicating which vectors are valid (True) and which should be ignored (False) during scoring."""
    encoding: BatchEncoding | None = None
    """Tokenizer encodings used to generate the embeddings."""
    ids: list[str] | None = None
    """List of ids for the embeddings, e.g., query or document ids."""

    @overload
    def to(self, device: torch.device, /) -> Self: ...

    @overload
    def to(self, other: Self, /) -> Self: ...

    def to(self, device) -> Self:
        """Moves the embeddings to the specified device.

        Args:
            device (torch.device | BiEncoderEmbedding): Device to move the embeddings to or another instance to move
                to the same device.
        Returns:
            Self: The instance with embeddings moved to the specified device.
        """
        if isinstance(device, BiEncoderEmbedding):
            device = device.device
        self.embeddings = self.embeddings.to(device)
        self.scoring_mask = self.scoring_mask.to(device) if self.scoring_mask is not None else None
        self.encoding = self.encoding.to(device) if self.encoding is not None else None
        return self

    @property
    def device(self) -> torch.device:
        """Returns the device of the embeddings.

        Returns:
            torch.device: Device of the embeddings.
        """
        return self.embeddings.device

    def items(self) -> Iterable[tuple[str, torch.Tensor]]:
        """Iterates over the embedding attributes and their values like `dict.items()`.

        Yields:
            tuple[str, torch.Tensor]: Tuple of attribute name and its value.
        """
        for field in self.__dataclass_fields__:
            yield field, getattr(self, field)


@dataclass
class BiEncoderOutput(LightningIROutput):
    """Dataclass containing the output of a bi-encoder model."""

    query_embeddings: BiEncoderEmbedding | None = None
    """Query embeddings generated by the model."""
    doc_embeddings: BiEncoderEmbedding | None = None
    """Document embeddings generated by the model."""
    similarity: torch.Tensor | None = None
    """Similarity scores between query and document embeddings."""


class BiEncoderModel(LightningIRModel, ABC):
    """A bi-encoder model that encodes queries and documents separately and computes a relevance score between them.
    See :class:`.BiEncoderConfig` for configuration options."""

    config_class: type[BiEncoderConfig] = BiEncoderConfig
    """Configuration class for the bi-encoder model."""

    def __init__(self, config: BiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a bi-encoder model given a :class:`.BiEncoderConfig`.

        Args:
            config (BiEncoderConfig): Configuration for the bi-encoder model.
        Raises:
            ValueError: If the similarity function is not supported.
        """
        super().__init__(config, *args, **kwargs)
        self.config: BiEncoderConfig

        if self.config.similarity_function == "cosine":
            self.similarity_function = self._cosine_similarity
        elif self.config.similarity_function == "l2":
            self.similarity_function = self._l2_similarity
        elif self.config.similarity_function == "dot":
            self.similarity_function = self._dot_similarity
        else:
            raise ValueError(f"Unknown similarity function {self.config.similarity_function}")

    def forward(
        self,
        query_encoding: BatchEncoding | None,
        doc_encoding: BatchEncoding | None,
        num_docs: Sequence[int] | int | None = None,
    ) -> BiEncoderOutput:
        """Embeds queries and/or documents and computes relevance scores between them if both are provided.

        Args:
            query_encoding (BatchEncoding | None): Tokenizer encodings for the queries. Defaults to None.
            doc_encoding (BatchEncoding | None): Tokenizer encodings for the documents. Defaults to None.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence of
                integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the
                number of documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
                the number of documents by dividing the number of documents by the number of queries. Defaults to None.
        Returns:
            BiEncoderOutput: Output of the model containing query and document embeddings and relevance scores.
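        Example:
            A minimal sketch of the ``num_docs`` semantics (illustrative only; assumes ``model`` is an
            instance of a concrete subclass and the encodings were produced by the matching tokenizer)::

                # 2 queries and 5 documents: 3 documents for the first query, 2 for the second
                output = model.forward(query_encoding, doc_encoding, num_docs=[3, 2])
                # output.scores then holds one relevance score per query-document pair (5 in total)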
135 """
        query_embeddings = None
        if query_encoding is not None:
            query_embeddings = self.encode_query(query_encoding)
        doc_embeddings = None
        if doc_encoding is not None:
            doc_embeddings = self.encode_doc(doc_encoding)
        output = BiEncoderOutput(query_embeddings=query_embeddings, doc_embeddings=doc_embeddings)
        if doc_embeddings is not None and query_embeddings is not None:
            output = self.score(output, num_docs)
        return output

    def encode_query(self, encoding: BatchEncoding) -> BiEncoderEmbedding:
        """Encodes tokenized queries.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the queries.
        Returns:
            BiEncoderEmbedding: Query embeddings and scoring mask.
        """
        return self.encode(encoding=encoding, input_type="query")

    def encode_doc(self, encoding: BatchEncoding) -> BiEncoderEmbedding:
        """Encodes tokenized documents.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the documents.
        Returns:
            BiEncoderEmbedding: Document embeddings and scoring mask.
        """
        return self.encode(encoding=encoding, input_type="doc")

    def _parse_num_docs(
        self, query_shape: int, doc_shape: int, num_docs: int | Sequence[int] | None, device: torch.device | None = None
    ) -> torch.Tensor:
        """Helper function to parse the number of documents per query."""
        if isinstance(num_docs, int):
            num_docs = [num_docs] * query_shape
        if isinstance(num_docs, list):
            if sum(num_docs) != doc_shape or len(num_docs) != query_shape:
                raise ValueError("Num docs does not match doc embeddings")
        if num_docs is None:
            if doc_shape % query_shape != 0:
                raise ValueError("Docs are not evenly distributed in batch, but no num_docs provided")
            num_docs = [doc_shape // query_shape] * query_shape
        return torch.tensor(num_docs, device=device)

    def compute_similarity(
        self,
        query_embeddings: BiEncoderEmbedding,
        doc_embeddings: BiEncoderEmbedding,
        num_docs: Sequence[int] | int | None = None,
    ) -> torch.Tensor:
        """Computes the similarity score between all query and document embedding vector pairs.

        Args:
            query_embeddings (BiEncoderEmbedding): Embeddings of the queries.
            doc_embeddings (BiEncoderEmbedding): Embeddings of the documents.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence of
                integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the
                number of documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
                the number of documents by dividing the number of documents by the number of queries. Defaults to None.
        Returns:
            torch.Tensor: Similarity scores between all query and document embedding vector pairs.
        """
        num_docs_t = self._parse_num_docs(
            query_embeddings.embeddings.shape[0], doc_embeddings.embeddings.shape[0], num_docs, query_embeddings.device
        )
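        # align each query with its documents: query_emb [total_docs x query_len x 1 x hidden],
        # doc_emb [total_docs x 1 x doc_len x hidden]; the similarity function broadcasts over the
        # singleton dimensions to produce a [total_docs x query_len x doc_len] similarity matrix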
        query_emb = query_embeddings.embeddings.repeat_interleave(num_docs_t, dim=0).unsqueeze(2)
        doc_emb = doc_embeddings.embeddings.unsqueeze(1)
        similarity = self.similarity_function(query_emb, doc_emb)
        return similarity

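    # Pairwise similarity functions. Autocast is disabled, so scores are computed in full
    # floating-point precision; the _batch_elementwise_scoring decorator (from
    # modeling_utils.batching) batches the elementwise computation.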
    @staticmethod
    @_batch_elementwise_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def _cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.cosine_similarity(x, y, dim=-1)

    @staticmethod
    @_batch_elementwise_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def _l2_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return -1 * torch.cdist(x, y).squeeze(-2)

    @staticmethod
    @_batch_elementwise_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def _dot_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.matmul(x, y.transpose(-1, -2)).squeeze(-2)

    @abstractmethod
    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes batched tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask for the text sequence.
        Raises:
            NotImplementedError: If the method is not implemented.
        """
        raise NotImplementedError

    @abstractmethod
    def score(self, output: BiEncoderOutput, num_docs: Sequence[int] | int | None = None) -> BiEncoderOutput:
        """Compute relevance scores between queries and documents.

        Args:
            output (BiEncoderOutput): Output containing embeddings and scoring mask.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence of
                integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the
                number of documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
                the number of documents by dividing the number of documents by the number of queries. Defaults to None.
        Returns:
            BiEncoderOutput: Output containing relevance scores.
        Raises:
            NotImplementedError: If the method is not implemented.
        """
        raise NotImplementedError


class SingleVectorBiEncoderModel(BiEncoderModel):
    """A bi-encoder model that encodes queries and documents separately, pools the contextualized embeddings into a
    single vector, and computes a relevance score based on the similarities between the two vectors. See
    :class:`.SingleVectorBiEncoderConfig` for configuration options."""

    config_class: type[SingleVectorBiEncoderConfig] = SingleVectorBiEncoderConfig
    """Configuration class for the single-vector bi-encoder model."""

    def __init__(self, config: SingleVectorBiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a single-vector bi-encoder model given a :class:`.SingleVectorBiEncoderConfig`.

        Args:
            config (SingleVectorBiEncoderConfig): Configuration for the single-vector bi-encoder model.
        """
        super().__init__(config, *args, **kwargs)

    def score(
        self,
        output: BiEncoderOutput,
        num_docs: Sequence[int] | int | None = None,
    ) -> BiEncoderOutput:
        """Compute relevance scores between queries and documents.

        Args:
            output (BiEncoderOutput): Output containing embeddings and scoring mask.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence of
                integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the
                number of documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
                the number of documents by dividing the number of documents by the number of queries. Defaults to None.
        Returns:
            BiEncoderOutput: Output containing relevance scores.
        Raises:
            ValueError: If query or document embeddings are not provided in the output.
        """
        if output.query_embeddings is None or output.doc_embeddings is None:
            raise ValueError("Expected query and document embeddings in BiEncoderOutput")
        similarity = self.compute_similarity(output.query_embeddings, output.doc_embeddings, num_docs)
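        # single-vector embeddings yield a [total_docs x 1 x 1] similarity matrix; flatten it to one
        # score per query-document pair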
        output.scores = similarity.view(-1)
        output.similarity = similarity
        return output


class MultiVectorBiEncoderModel(BiEncoderModel):
    """A bi-encoder model that encodes queries and documents separately into multiple contextualized vectors and
    computes a relevance score by aggregating the similarities between them. See
    :class:`.MultiVectorBiEncoderConfig` for configuration options."""

    config_class: type[MultiVectorBiEncoderConfig] = MultiVectorBiEncoderConfig
    """Configuration class for the multi-vector bi-encoder model."""

    supports_retrieval_models = []

    def __init__(self, config: MultiVectorBiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a multi-vector bi-encoder model given a :class:`.MultiVectorBiEncoderConfig`.

        Args:
            config (MultiVectorBiEncoderConfig): Configuration for the multi-vector bi-encoder model.
        Raises:
            ValueError: If mask scoring tokens are specified in the configuration but the tokenizer is not available.
            ValueError: If the specified mask scoring tokens are not in the tokenizer vocab.
        """
        super().__init__(config, *args, **kwargs)

        self.query_mask_scoring_input_ids: torch.Tensor | None = None
        self.doc_mask_scoring_input_ids: torch.Tensor | None = None

        # Adds the mask scoring input ids to the model if they are specified in the configuration.
        for sequence in ("query", "doc"):
            mask_scoring_tokens = getattr(self.config, f"{sequence}_mask_scoring_tokens")
            if mask_scoring_tokens is None:
                continue
            if mask_scoring_tokens == "punctuation":
                mask_scoring_tokens = list(punctuation)
            try:
                from transformers import AutoTokenizer

                tokenizer = AutoTokenizer.from_pretrained(self.config.name_or_path)
            except OSError as err:
                raise ValueError(
                    "Can't use token scoring masking if the checkpoint does not have a tokenizer."
                ) from err
            mask_scoring_input_ids = []
            for token in mask_scoring_tokens:
                if token not in tokenizer.vocab:
                    raise ValueError(f"Token {token} not in tokenizer vocab")
                mask_scoring_input_ids.append(tokenizer.vocab[token])
            setattr(
                self,
                f"{sequence}_mask_scoring_input_ids",
                torch.tensor(mask_scoring_input_ids, dtype=torch.long),
            )

    def _expand_mask(self, shape: torch.Size, mask: torch.Tensor, dim: int) -> torch.Tensor:
        """Helper function to expand the mask to the shape of the similarity scores."""
        if mask.ndim == len(shape):
            return mask
        if mask.ndim > len(shape):
            raise ValueError("Mask has too many dimensions")
        fill_values = len(shape) - mask.ndim + 1
        new_shape = [*mask.shape[:-1]] + [1] * fill_values
        new_shape[dim] = mask.shape[-1]
        return mask.view(*new_shape)

    def _aggregate(
        self,
        scores: torch.Tensor,
        mask: torch.Tensor,
        aggregation_function: Literal["max", "sum", "mean"],
        dim: int,
    ) -> torch.Tensor:
        """Helper function to aggregate similarity scores over query and document embeddings."""
        mask = self._expand_mask(scores.shape, mask, dim)
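        # masked-out vectors must not influence the aggregate: fill with -inf for max (never
        # selected) and with 0 for sum/mean (no contribution)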
369 if aggregation_function == "max":
370 scores = scores.masked_fill(~mask, float("-inf"))
371 return scores.amax(dim, keepdim=True)
372 scores = scores.masked_fill(~mask, 0)
373 if aggregation_function == "sum":
374 return scores.sum(dim, keepdim=True)
375 num_non_masked = mask.sum(dim, keepdim=True)
376 if aggregation_function == "mean":
377 return scores.masked_fill(num_non_masked == 0, 0).sum(dim, keepdim=True) / num_non_masked.clamp(min=1)
378 raise ValueError(f"Unknown aggregation {aggregation_function}")
379
    def scoring_mask(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> torch.Tensor:
        """Computes a scoring mask for batched tokenized text sequences which is used in the scoring function to mask
        out vectors during scoring.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            torch.Tensor: Scoring mask.
        """
        input_ids = encoding["input_ids"]
        attention_mask = encoding["attention_mask"]
        scoring_mask = attention_mask
        if scoring_mask is None:
            scoring_mask = torch.ones_like(input_ids, dtype=torch.bool)
        scoring_mask = scoring_mask.bool()
        mask_scoring_input_ids = getattr(self, f"{input_type}_mask_scoring_input_ids")
        if mask_scoring_input_ids is not None:
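            # drop vectors whose tokens are configured to be ignored during scoring (e.g., punctuation)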
            ignore_mask = input_ids[..., None].eq(mask_scoring_input_ids.to(input_ids.device)).any(-1)
            scoring_mask = scoring_mask & ~ignore_mask
        return scoring_mask

    def aggregate_similarity(
        self,
        similarity: torch.Tensor,
        query_scoring_mask: torch.Tensor,
        doc_scoring_mask: torch.Tensor,
        num_docs: int | Sequence[int] | None = None,
    ) -> torch.Tensor:
        """Aggregates the matrix of query-document similarities into a single score based on the configured
        aggregation strategy.

        Args:
            similarity (torch.Tensor): Query-document similarity matrix.
            query_scoring_mask (torch.Tensor): Mask indicating which query vectors are valid for scoring.
            doc_scoring_mask (torch.Tensor): Mask indicating which document vectors are valid for scoring.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence of
                integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the
                number of documents. If an integer, assumes an equal number of documents per query. If None, tries to
                infer the number of documents by dividing the number of documents by the number of queries. Defaults
                to None.
        Returns:
            torch.Tensor: Aggregated similarity scores.
        """
        scores = similarity
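        # similarity has shape [total_docs x query_len x doc_len]: aggregate over the document
        # vectors first, then over the (per-document repeated) query vectors, yielding one score per
        # query-document pair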
        if scores.shape[-1] > 1:
            scores = self._aggregate(scores, doc_scoring_mask, self.config.doc_aggregation_function, -1)
        if scores.shape[-2] > 1:
            num_docs_t = self._parse_num_docs(
                query_scoring_mask.shape[0], doc_scoring_mask.shape[0], num_docs, similarity.device
            )
            repeated_query_scoring_mask = query_scoring_mask.repeat_interleave(num_docs_t, dim=0)
            scores = self._aggregate(scores, repeated_query_scoring_mask, self.config.query_aggregation_function, -2)
        return scores.view(scores.shape[0])

    def score(self, output: BiEncoderOutput, num_docs: Sequence[int] | int | None = None) -> BiEncoderOutput:
        """Compute relevance scores between queries and documents.

        Args:
            output (BiEncoderOutput): Output containing embeddings and scoring mask.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence of
                integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the
                number of documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
                the number of documents by dividing the number of documents by the number of queries. Defaults to None.
        Returns:
            BiEncoderOutput: Output containing relevance scores.
        Raises:
            ValueError: If query or document embeddings are not provided in the output.
            ValueError: If scoring masks are not provided for the embeddings.
        """
        query_embeddings = output.query_embeddings
        doc_embeddings = output.doc_embeddings
        if query_embeddings is None or doc_embeddings is None:
            raise ValueError("Embeddings expected for scoring multi-vector embeddings")
        if query_embeddings.scoring_mask is None or doc_embeddings.scoring_mask is None:
            raise ValueError("Scoring masks expected for scoring multi-vector embeddings")

        similarity = self.compute_similarity(query_embeddings, doc_embeddings, num_docs)
        scores = self.aggregate_similarity(
            similarity, query_embeddings.scoring_mask, doc_embeddings.scoring_mask, num_docs
        )
        output.scores = scores
        output.similarity = similarity
        return output