1"""
2Model module for bi-encoder models.
3
4This module defines the model class used to implement bi-encoder models.
5"""
6
7from abc import ABC, abstractmethod
8from dataclasses import dataclass
9from string import punctuation
10from typing import Iterable, List, Literal, Self, Sequence, Tuple, Type, overload
11
12import torch
13from transformers import BatchEncoding
14
15from ..base import LightningIRModel, LightningIROutput
16from ..modeling_utils.batching import _batch_elementwise_scoring
17from .bi_encoder_config import BiEncoderConfig, MultiVectorBiEncoderConfig, SingleVectorBiEncoderConfig
18
19
@dataclass
class BiEncoderEmbedding:
    """Dataclass containing embeddings and the encoding for bi-encoder models."""

    embeddings: torch.Tensor
    """Embedding tensor generated by a bi-encoder model of shape [batch_size x seq_len x hidden_size]. The sequence
    length varies depending on the pooling strategy and the hidden size varies depending on the projection settings."""
    scoring_mask: torch.Tensor | None
    """Mask tensor designating which vectors should be ignored during scoring."""
    encoding: BatchEncoding | None
    """Tokenizer encodings used to generate the embeddings."""
    ids: List[str] | None = None
    """List of ids for the embeddings, e.g., query or document ids."""

    @overload
    def to(self, device: torch.device, /) -> Self: ...

    @overload
    def to(self, other: Self, /) -> Self: ...

    def to(self, device) -> Self:
        """Moves the embeddings to the specified device.

        :param device: Device to move the embeddings to, or another instance whose device should be matched
        :type device: torch.device | BiEncoderEmbedding
        :return: Self
        :rtype: BiEncoderEmbedding
        """
        if isinstance(device, BiEncoderEmbedding):
            device = device.device
        self.embeddings = self.embeddings.to(device)
        self.scoring_mask = self.scoring_mask.to(device) if self.scoring_mask is not None else None
        self.encoding = self.encoding.to(device) if self.encoding is not None else None
        return self

    @property
    def device(self) -> torch.device:
        """Returns the device of the embeddings.

        :return: The device of the embeddings
        :rtype: torch.device
        """
        return self.embeddings.device

    def items(self) -> Iterable[Tuple[str, torch.Tensor]]:
        """Iterates over the embedding attributes and their values like `dict.items()`.

        :yield: Tuple of attribute name and its value
        :rtype: Iterator[Tuple[str, torch.Tensor]]
        """
        for field in self.__dataclass_fields__:
            yield field, getattr(self, field)


@dataclass
class BiEncoderOutput(LightningIROutput):
    """Dataclass containing the output of a bi-encoder model."""

    query_embeddings: BiEncoderEmbedding | None = None
    """Query embeddings generated by the model."""
    doc_embeddings: BiEncoderEmbedding | None = None
    """Document embeddings generated by the model."""


class BiEncoderModel(LightningIRModel, ABC):
    """A bi-encoder model that encodes queries and documents separately and computes a relevance score between them.
    See :class:`.BiEncoderConfig` for configuration options."""

    config_class: Type[BiEncoderConfig] = BiEncoderConfig
    """Configuration class for the bi-encoder model."""

    def __init__(self, config: BiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a bi-encoder model given a :class:`.BiEncoderConfig`.

        :param config: Configuration for the bi-encoder model
        :type config: BiEncoderConfig
        :raises ValueError: If the configured similarity function is unknown
        """
        super().__init__(config, *args, **kwargs)
        self.config: BiEncoderConfig

        if self.config.similarity_function == "cosine":
            self.similarity_function = self._cosine_similarity
        elif self.config.similarity_function == "l2":
            self.similarity_function = self._l2_similarity
        elif self.config.similarity_function == "dot":
            self.similarity_function = self._dot_similarity
        else:
            raise ValueError(f"Unknown similarity function {self.config.similarity_function}")

    def forward(
        self,
        query_encoding: BatchEncoding | None,
        doc_encoding: BatchEncoding | None,
        num_docs: Sequence[int] | int | None = None,
    ) -> BiEncoderOutput:
        """Embeds queries and/or documents and computes relevance scores between them if both are provided.

        :param query_encoding: Tokenizer encodings for the queries
        :type query_encoding: BatchEncoding | None
        :param doc_encoding: Tokenizer encodings for the documents
        :type doc_encoding: BatchEncoding | None
        :param num_docs: Specifies how many documents are passed per query. If a sequence of integers, `len(num_docs)`
            must equal the number of queries and `sum(num_docs)` the number of documents, i.e., the sequence contains
            one value per query specifying the number of documents for that query. If an integer, assumes an equal
            number of documents per query. If None, tries to infer the number of documents by dividing the number of
            documents by the number of queries, defaults to None
        :type num_docs: Sequence[int] | int | None, optional
        :return: Output of the model
        :rtype: BiEncoderOutput
        """
        query_embeddings = None
        if query_encoding is not None:
            query_embeddings = self.encode_query(query_encoding)
        doc_embeddings = None
        if doc_encoding is not None:
            doc_embeddings = self.encode_doc(doc_encoding)
        scores = None
        if doc_embeddings is not None and query_embeddings is not None:
            scores = self.score(query_embeddings, doc_embeddings, num_docs)
        return BiEncoderOutput(scores=scores, query_embeddings=query_embeddings, doc_embeddings=doc_embeddings)

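    # Example use of forward (hypothetical shapes): with 2 queries and 3
    # documents where the first query is paired with 2 documents and the second
    # with 1, pass num_docs=[2, 1]; the returned BiEncoderOutput.scores then
    # holds one score per query-document pair, i.e., 3 values.
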
    def encode_query(self, encoding: BatchEncoding) -> BiEncoderEmbedding:
        """Encodes tokenized queries.

        :param encoding: Tokenizer encodings for the queries
        :type encoding: BatchEncoding
        :return: Query embeddings and scoring mask
        :rtype: BiEncoderEmbedding
        """
        return self.encode(encoding=encoding, input_type="query")

    def encode_doc(self, encoding: BatchEncoding) -> BiEncoderEmbedding:
        """Encodes tokenized documents.

        :param encoding: Tokenizer encodings for the documents
        :type encoding: BatchEncoding
        :return: Document embeddings and scoring mask
        :rtype: BiEncoderEmbedding
        """
        return self.encode(encoding=encoding, input_type="doc")

    def _parse_num_docs(
        self, query_shape: int, doc_shape: int, num_docs: int | Sequence[int] | None, device: torch.device | None = None
    ) -> torch.Tensor:
        """Helper function to parse the number of documents per query."""
        if isinstance(num_docs, int):
            num_docs = [num_docs] * query_shape
        if isinstance(num_docs, list):
            if sum(num_docs) != doc_shape or len(num_docs) != query_shape:
                raise ValueError("num_docs does not match the number of queries and documents")
        if num_docs is None:
            if doc_shape % query_shape != 0:
                raise ValueError("Documents are not evenly distributed in the batch, but no num_docs provided")
            num_docs = [doc_shape // query_shape] * query_shape
        return torch.tensor(num_docs, device=device)

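    # Example for _parse_num_docs (hypothetical values):
    #   _parse_num_docs(2, 4, None)   -> tensor([2, 2])  (4 docs split evenly)
    #   _parse_num_docs(2, 3, [2, 1]) -> tensor([2, 1])
    #   _parse_num_docs(2, 3, None)   -> ValueError, since 3 documents cannot
    #   be evenly distributed over 2 queries
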
    def compute_similarity(
        self,
        query_embeddings: BiEncoderEmbedding,
        doc_embeddings: BiEncoderEmbedding,
        num_docs: Sequence[int] | int | None = None,
    ) -> torch.Tensor:
        """Computes the similarity score between all query and document embedding vector pairs.

        :param query_embeddings: Embeddings of the queries
        :type query_embeddings: BiEncoderEmbedding
        :param doc_embeddings: Embeddings of the documents
        :type doc_embeddings: BiEncoderEmbedding
        :param num_docs: Specifies how many documents are passed per query. If a sequence of integers, `len(num_docs)`
            must equal the number of queries and `sum(num_docs)` the number of documents, i.e., the sequence contains
            one value per query specifying the number of documents for that query. If an integer, assumes an equal
            number of documents per query. If None, tries to infer the number of documents by dividing the number of
            documents by the number of queries, defaults to None
        :type num_docs: Sequence[int] | int | None, optional
        :return: Similarity scores between all query and document embedding vector pairs
        :rtype: torch.Tensor
        """
        num_docs_t = self._parse_num_docs(
            query_embeddings.embeddings.shape[0], doc_embeddings.embeddings.shape[0], num_docs, query_embeddings.device
        )
        query_emb = query_embeddings.embeddings.repeat_interleave(num_docs_t, dim=0).unsqueeze(2)
        doc_emb = doc_embeddings.embeddings.unsqueeze(1)
        similarity = self.similarity_function(query_emb, doc_emb)
        return similarity

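    # Shape sketch for the similarity helpers below (a hedged reading; the
    # _batch_elementwise_scoring decorator may additionally chunk the inputs):
    # query vectors arrive as [num_pairs, query_len, 1, hidden] and document
    # vectors as [num_pairs, 1, doc_len, hidden], so each helper reduces the
    # hidden dimension and yields per-pair similarity matrices of shape
    # [num_pairs, query_len, doc_len]. Autocast is disabled so that the scores
    # are computed in full precision.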
    @staticmethod
    @_batch_elementwise_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def _cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.cosine_similarity(x, y, dim=-1)

    @staticmethod
    @_batch_elementwise_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def _l2_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return -1 * torch.cdist(x, y).squeeze(-2)

    @staticmethod
    @_batch_elementwise_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def _dot_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.matmul(x, y.transpose(-1, -2)).squeeze(-2)

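    # Example (toy vectors): for x = [1, 2] and y = [3, 4], _dot_similarity
    # yields 1*3 + 2*4 = 11, _cosine_similarity yields 11 / (sqrt(5) * 5)
    # ~= 0.98, and _l2_similarity yields -sqrt((1-3)^2 + (2-4)^2) ~= -2.83
    # (the distance is negated so that larger values always mean more similar).
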
    @abstractmethod
    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes batched tokenized text sequences and returns the embeddings and scoring mask.

        :param encoding: Tokenizer encodings for the text sequences
        :type encoding: BatchEncoding
        :param input_type: Type of input, either "query" or "doc"
        :type input_type: Literal["query", "doc"]
        :return: Embeddings and scoring mask
        :rtype: BiEncoderEmbedding
        """
        pass

    @abstractmethod
    def score(
        self,
        query_embeddings: BiEncoderEmbedding,
        doc_embeddings: BiEncoderEmbedding,
        num_docs: Sequence[int] | int | None = None,
    ) -> torch.Tensor:
        """Compute relevance scores between queries and documents.

        :param query_embeddings: Embeddings and scoring mask for the queries
        :type query_embeddings: BiEncoderEmbedding
        :param doc_embeddings: Embeddings and scoring mask for the documents
        :type doc_embeddings: BiEncoderEmbedding
        :param num_docs: Specifies how many documents are passed per query. If a sequence of integers, `len(num_docs)`
            must equal the number of queries and `sum(num_docs)` the number of documents, i.e., the sequence contains
            one value per query specifying the number of documents for that query. If an integer, assumes an equal
            number of documents per query. If None, tries to infer the number of documents by dividing the number of
            documents by the number of queries, defaults to None
        :type num_docs: Sequence[int] | int | None, optional
        :return: Relevance scores
        :rtype: torch.Tensor
        """
        pass


class SingleVectorBiEncoderModel(BiEncoderModel):
    """A bi-encoder model that encodes queries and documents separately, pools the contextualized embeddings into a
    single vector, and computes a relevance score based on the similarity between the two vectors. See
    :class:`.SingleVectorBiEncoderConfig` for configuration options."""

    config_class: Type[SingleVectorBiEncoderConfig] = SingleVectorBiEncoderConfig
    """Configuration class for the single-vector bi-encoder model."""

    def __init__(self, config: SingleVectorBiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a single-vector bi-encoder model given a :class:`.SingleVectorBiEncoderConfig`.

        :param config: Configuration for the single-vector bi-encoder model
        :type config: SingleVectorBiEncoderConfig
        """
        super().__init__(config, *args, **kwargs)

    def score(
        self,
        query_embeddings: BiEncoderEmbedding,
        doc_embeddings: BiEncoderEmbedding,
        num_docs: Sequence[int] | int | None = None,
    ) -> torch.Tensor:
        """Compute relevance scores between queries and documents.

        :param query_embeddings: Embeddings and scoring mask for the queries
        :type query_embeddings: BiEncoderEmbedding
        :param doc_embeddings: Embeddings and scoring mask for the documents
        :type doc_embeddings: BiEncoderEmbedding
        :param num_docs: Specifies how many documents are passed per query. If a sequence of integers, `len(num_docs)`
            must equal the number of queries and `sum(num_docs)` the number of documents, i.e., the sequence contains
            one value per query specifying the number of documents for that query. If an integer, assumes an equal
            number of documents per query. If None, tries to infer the number of documents by dividing the number of
            documents by the number of queries, defaults to None
        :type num_docs: Sequence[int] | int | None, optional
        :return: Relevance scores
        :rtype: torch.Tensor
        """
        return self.compute_similarity(query_embeddings, doc_embeddings, num_docs).view(-1)


class MultiVectorBiEncoderModel(BiEncoderModel):
    """A bi-encoder model that encodes queries and documents separately into multiple vectors per sequence and
    aggregates the pairwise vector similarities into a relevance score. See :class:`.MultiVectorBiEncoderConfig` for
    configuration options."""

    config_class: Type[MultiVectorBiEncoderConfig] = MultiVectorBiEncoderConfig
    """Configuration class for the multi-vector bi-encoder model."""

    supports_retrieval_models = []

    def __init__(self, config: MultiVectorBiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a multi-vector bi-encoder model given a :class:`.MultiVectorBiEncoderConfig`.

        :param config: Configuration for the multi-vector bi-encoder model
        :type config: MultiVectorBiEncoderConfig
        :raises ValueError: If mask scoring tokens are specified in the configuration but the tokenizer is not available
        :raises ValueError: If the specified mask scoring tokens are not in the tokenizer vocab
        """
        super().__init__(config, *args, **kwargs)

        self.query_mask_scoring_input_ids: torch.Tensor | None = None
        self.doc_mask_scoring_input_ids: torch.Tensor | None = None

        # Add the mask scoring input ids to the model if they are specified in the configuration.
        for sequence in ("query", "doc"):
            mask_scoring_tokens = getattr(self.config, f"{sequence}_mask_scoring_tokens")
            if mask_scoring_tokens is None:
                continue
            if mask_scoring_tokens == "punctuation":
                mask_scoring_tokens = list(punctuation)
            try:
                from transformers import AutoTokenizer

                tokenizer = AutoTokenizer.from_pretrained(self.config.name_or_path)
            except OSError:
                raise ValueError("Cannot use mask scoring tokens if the checkpoint does not have a tokenizer.")
            mask_scoring_input_ids = []
            for token in mask_scoring_tokens:
                if token not in tokenizer.vocab:
                    raise ValueError(f"Token {token} not in tokenizer vocab")
                mask_scoring_input_ids.append(tokenizer.vocab[token])
            setattr(
                self,
                f"{sequence}_mask_scoring_input_ids",
                torch.tensor(mask_scoring_input_ids, dtype=torch.long),
            )

    def _expand_mask(self, shape: torch.Size, mask: torch.Tensor, dim: int) -> torch.Tensor:
        """Helper function to expand the mask to the shape of the similarity scores."""
        if mask.ndim == len(shape):
            return mask
        if mask.ndim > len(shape):
            raise ValueError("Mask has too many dimensions")
        fill_values = len(shape) - mask.ndim + 1
        new_shape = [*mask.shape[:-1]] + [1] * fill_values
        new_shape[dim] = mask.shape[-1]
        return mask.view(*new_shape)

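    # Example for _expand_mask (hypothetical shapes): a document scoring mask of
    # shape [batch_size, doc_len] expanded to similarity scores of shape
    # [batch_size, query_len, doc_len] with dim=-1 becomes a view of shape
    # [batch_size, 1, doc_len], which broadcasts over the query dimension.
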
    def _aggregate(
        self,
        scores: torch.Tensor,
        mask: torch.Tensor,
        query_aggregation_function: Literal["max", "sum", "mean", "harmonic_mean"],
        dim: int,
    ) -> torch.Tensor:
        """Helper function to aggregate similarity scores over query and document embeddings."""
        mask = self._expand_mask(scores.shape, mask, dim)
        if query_aggregation_function == "max":
            scores = scores.masked_fill(~mask, float("-inf"))
            return scores.amax(dim, keepdim=True)
        if query_aggregation_function == "sum":
            scores = scores.masked_fill(~mask, 0)
            return scores.sum(dim, keepdim=True)
        num_non_masked = mask.sum(dim, keepdim=True)
        if query_aggregation_function == "mean":
            # Zero out masked positions so they do not contribute to the sum.
            scores = scores.masked_fill(~mask, 0)
            return torch.where(num_non_masked == 0, 0, scores.sum(dim, keepdim=True) / num_non_masked)
        if query_aggregation_function == "harmonic_mean":
            # Exclude masked positions from the reciprocal sum.
            reciprocal = (1 / scores).masked_fill(~mask, 0)
            return torch.where(
                num_non_masked == 0,
                0,
                num_non_masked / reciprocal.sum(dim, keepdim=True),
            )
        raise ValueError(f"Unknown aggregation {query_aggregation_function}")

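    # Note (a reading, not guaranteed by this module alone): with
    # doc_aggregation_function="max" and query_aggregation_function="sum", the
    # two _aggregate calls in aggregate_similarity reproduce ColBERT-style late
    # interaction, where each query vector keeps only its best-matching
    # document vector and the maxima are summed over the query vectors.
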
    def scoring_mask(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> torch.Tensor:
        """Computes a scoring mask for batched tokenized text sequences which is used in the scoring function to mask
        out vectors during scoring.

        :param encoding: Tokenizer encodings for the text sequences
        :type encoding: BatchEncoding
        :param input_type: Type of input, either "query" or "doc"
        :type input_type: Literal["query", "doc"]
        :return: Scoring mask
        :rtype: torch.Tensor
        """
        input_ids = encoding["input_ids"]
        # The attention mask may be absent, e.g., when no padding was applied.
        attention_mask = encoding.get("attention_mask")
        scoring_mask = attention_mask
        if scoring_mask is None:
            scoring_mask = torch.ones_like(input_ids, dtype=torch.bool)
        scoring_mask = scoring_mask.bool()
        mask_scoring_input_ids = getattr(self, f"{input_type}_mask_scoring_input_ids")
        if mask_scoring_input_ids is not None:
            ignore_mask = input_ids[..., None].eq(mask_scoring_input_ids.to(input_ids.device)).any(-1)
            scoring_mask = scoring_mask & ~ignore_mask
        return scoring_mask

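    # Example (hedged): if doc_mask_scoring_tokens="punctuation" is set in the
    # config, __init__ collects the input ids of tokens such as "." and "," in
    # doc_mask_scoring_input_ids, and the corresponding document vectors are
    # masked out of scoring here.
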
    def aggregate_similarity(
        self,
        similarity: torch.Tensor,
        query_scoring_mask: torch.Tensor,
        doc_scoring_mask: torch.Tensor,
        num_docs: int | Sequence[int] | None = None,
    ) -> torch.Tensor:
        """Aggregates the matrix of query-document similarities into a single score based on the configured aggregation
        strategy.

        :param similarity: Query-document similarity matrix
        :type similarity: torch.Tensor
        :param query_scoring_mask: Which query vectors should be masked out during scoring
        :type query_scoring_mask: torch.Tensor
        :param doc_scoring_mask: Which document vectors should be masked out during scoring
        :type doc_scoring_mask: torch.Tensor
        :param num_docs: Specifies how many documents are passed per query, defaults to None
        :type num_docs: int | Sequence[int] | None, optional
        :return: Aggregated similarity scores
        :rtype: torch.Tensor
        """
        num_docs_t = self._parse_num_docs(
            query_scoring_mask.shape[0], doc_scoring_mask.shape[0], num_docs, similarity.device
        )
        scores = self._aggregate(similarity, doc_scoring_mask, self.config.doc_aggregation_function, -1)
        repeated_query_scoring_mask = query_scoring_mask.repeat_interleave(num_docs_t, dim=0)
        scores = self._aggregate(scores, repeated_query_scoring_mask, self.config.query_aggregation_function, -2)
        return scores.view(scores.shape[0])

    def score(
        self,
        query_embeddings: BiEncoderEmbedding,
        doc_embeddings: BiEncoderEmbedding,
        num_docs: Sequence[int] | int | None = None,
    ) -> torch.Tensor:
        """Compute relevance scores between queries and documents.

        :param query_embeddings: Embeddings and scoring mask for the queries
        :type query_embeddings: BiEncoderEmbedding
        :param doc_embeddings: Embeddings and scoring mask for the documents
        :type doc_embeddings: BiEncoderEmbedding
        :param num_docs: Specifies how many documents are passed per query. If a sequence of integers, `len(num_docs)`
            must equal the number of queries and `sum(num_docs)` the number of documents, i.e., the sequence contains
            one value per query specifying the number of documents for that query. If an integer, assumes an equal
            number of documents per query. If None, tries to infer the number of documents by dividing the number of
            documents by the number of queries, defaults to None
        :type num_docs: Sequence[int] | int | None, optional
        :raises ValueError: If either embedding is missing a scoring mask
        :return: Relevance scores
        :rtype: torch.Tensor
        """
        similarity = self.compute_similarity(query_embeddings, doc_embeddings, num_docs)
        if query_embeddings.scoring_mask is None or doc_embeddings.scoring_mask is None:
            raise ValueError("Scoring masks expected for scoring multi-vector embeddings")
        # Pass num_docs through so non-uniform query-document pairings are aggregated correctly.
        return self.aggregate_similarity(
            similarity, query_embeddings.scoring_mask, doc_embeddings.scoring_mask, num_docs
        )
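
# A minimal sketch (hypothetical, for illustration only) of the contract a
# concrete subclass has to fulfill. `backbone` is an assumed attribute holding
# a Hugging Face encoder; real subclasses may differ:
#
#     class MyMultiVectorModel(MultiVectorBiEncoderModel):
#         def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
#             embeddings = self.backbone(**encoding).last_hidden_state
#             scoring_mask = self.scoring_mask(encoding, input_type)
#             return BiEncoderEmbedding(embeddings, scoring_mask, encoding)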