Source code for lightning_ir.bi_encoder.bi_encoder_model

  1"""
  2Model module for bi-encoder models.
  3
  4This module defines the model class used to implement bi-encoder models.
  5"""
  6
  7from abc import ABC, abstractmethod
  8from dataclasses import dataclass
  9from string import punctuation
 10from typing import Iterable, List, Literal, Self, Sequence, Tuple, Type, overload
 11
 12import torch
 13from transformers import BatchEncoding
 14
 15from ..base import LightningIRModel, LightningIROutput
 16from ..modeling_utils.batching import _batch_elementwise_scoring
 17from .bi_encoder_config import BiEncoderConfig, MultiVectorBiEncoderConfig, SingleVectorBiEncoderConfig
 18
 19
@dataclass
class BiEncoderEmbedding:
    """Dataclass containing embeddings and the encoding for bi-encoder models."""

    embeddings: torch.Tensor
    """Embedding tensor generated by a bi-encoder model of shape [batch_size x seq_len x hidden_size]. The sequence
    length varies depending on the pooling strategy and the hidden size varies depending on the projection settings."""
    scoring_mask: torch.Tensor | None
    """Mask tensor designating which vectors should be ignored during scoring."""
    encoding: BatchEncoding | None
    """Tokenizer encodings used to generate the embeddings."""
    ids: List[str] | None = None
    """List of ids for the embeddings, e.g., query or document ids."""

    @overload
    def to(self, device: torch.device, /) -> Self: ...

    @overload
    def to(self, other: Self, /) -> Self: ...

    def to(self, device) -> Self:
        """Moves the embeddings to the specified device.

        :param device: Device to move the embeddings to or another instance to move to the same device
        :type device: torch.device | BiEncoderEmbedding
        :return: Self
        :rtype: BiEncoderEmbedding
        """
        if isinstance(device, BiEncoderEmbedding):
            device = device.device
        self.embeddings = self.embeddings.to(device)
        self.scoring_mask = self.scoring_mask.to(device) if self.scoring_mask is not None else None
        self.encoding = self.encoding.to(device) if self.encoding is not None else None
        return self

    @property
    def device(self) -> torch.device:
        """Returns the device of the embeddings.

        :return: The device of the embeddings
        :rtype: torch.device
        """
        return self.embeddings.device

    def items(self) -> Iterable[Tuple[str, torch.Tensor]]:
        """Iterates over the embedding attributes and their values like `dict.items()`.

        :yield: Tuple of attribute name and its value
        :rtype: Iterable[Tuple[str, torch.Tensor]]
        """
        for field in self.__dataclass_fields__:
            yield field, getattr(self, field)


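# Illustrative sketch (not part of the module): constructing a BiEncoderEmbedding by
# hand, iterating over its fields, and moving it between devices. The import path
# follows this module; the shapes and ids below are toy values.
import torch
from lightning_ir.bi_encoder.bi_encoder_model import BiEncoderEmbedding

embedding = BiEncoderEmbedding(
    embeddings=torch.randn(2, 1, 128),  # [batch_size x seq_len x hidden_size]
    scoring_mask=None,
    encoding=None,
    ids=["q1", "q2"],
)
for name, value in embedding.items():  # iterates the dataclass fields like dict.items()
    print(name, type(value))
embedding = embedding.to(torch.device("cpu"))  # also accepts another embedding instance
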
@dataclass
class BiEncoderOutput(LightningIROutput):
    """Dataclass containing the output of a bi-encoder model."""

    query_embeddings: BiEncoderEmbedding | None = None
    """Query embeddings generated by the model."""
    doc_embeddings: BiEncoderEmbedding | None = None
    """Document embeddings generated by the model."""


class BiEncoderModel(LightningIRModel, ABC):
    """A bi-encoder model that encodes queries and documents separately and computes a relevance score between them.
    See :class:`.BiEncoderConfig` for configuration options."""

    config_class: Type[BiEncoderConfig] = BiEncoderConfig
    """Configuration class for the bi-encoder model."""

    def __init__(self, config: BiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a bi-encoder model given a :class:`.BiEncoderConfig`.

        :param config: Configuration for the bi-encoder model
        :type config: BiEncoderConfig
        """
        super().__init__(config, *args, **kwargs)
        self.config: BiEncoderConfig

        if self.config.similarity_function == "cosine":
            self.similarity_function = self._cosine_similarity
        elif self.config.similarity_function == "l2":
            self.similarity_function = self._l2_similarity
        elif self.config.similarity_function == "dot":
            self.similarity_function = self._dot_similarity
        else:
            raise ValueError(f"Unknown similarity function {self.config.similarity_function}")

    def forward(
        self,
        query_encoding: BatchEncoding | None,
        doc_encoding: BatchEncoding | None,
        num_docs: Sequence[int] | int | None = None,
    ) -> BiEncoderOutput:
        """Embeds queries and/or documents and computes relevance scores between them if both are provided.

        :param query_encoding: Tokenizer encodings for the queries
        :type query_encoding: BatchEncoding | None
        :param doc_encoding: Tokenizer encodings for the documents
        :type doc_encoding: BatchEncoding | None
        :param num_docs: Specifies how many documents are passed per query. If a sequence of integers, `len(num_docs)`
            should be equal to the number of queries and `sum(num_docs)` equal to the number of documents, i.e., the
            sequence contains one value per query specifying the number of documents for that query. If an integer,
            assumes an equal number of documents per query. If None, tries to infer the number of documents by dividing
            the number of documents by the number of queries, defaults to None
        :type num_docs: Sequence[int] | int | None, optional
        :return: Output of the model
        :rtype: BiEncoderOutput
        """
        query_embeddings = None
        if query_encoding is not None:
            query_embeddings = self.encode_query(query_encoding)
        doc_embeddings = None
        if doc_encoding is not None:
            doc_embeddings = self.encode_doc(doc_encoding)
        scores = None
        if doc_embeddings is not None and query_embeddings is not None:
            scores = self.score(query_embeddings, doc_embeddings, num_docs)
        return BiEncoderOutput(scores=scores, query_embeddings=query_embeddings, doc_embeddings=doc_embeddings)

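# Illustrative sketch (not part of the module): calling ``forward`` with two queries
# and a variable number of documents per query. ``model`` and ``tokenizer`` are
# placeholders for a concrete BiEncoderModel subclass and its matching tokenizer,
# not objects defined in this module.
#
#     query_encoding = tokenizer(["query a", "query b"], padding=True, return_tensors="pt")
#     doc_encoding = tokenizer(["doc 1", "doc 2", "doc 3"], padding=True, return_tensors="pt")
#     output = model(query_encoding, doc_encoding, num_docs=[1, 2])
#     output.scores  # one relevance score per query-document pair (3 here)
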
    def encode_query(self, encoding: BatchEncoding) -> BiEncoderEmbedding:
        """Encodes tokenized queries.

        :param encoding: Tokenizer encodings for the queries
        :type encoding: BatchEncoding
        :return: Query embeddings and scoring mask
        :rtype: BiEncoderEmbedding
        """
        return self.encode(encoding=encoding, input_type="query")

    def encode_doc(self, encoding: BatchEncoding) -> BiEncoderEmbedding:
        """Encodes tokenized documents.

        :param encoding: Tokenizer encodings for the documents
        :type encoding: BatchEncoding
        :return: Document embeddings and scoring mask
        :rtype: BiEncoderEmbedding
        """
        return self.encode(encoding=encoding, input_type="doc")

    def _parse_num_docs(
        self, query_shape: int, doc_shape: int, num_docs: int | Sequence[int] | None, device: torch.device | None = None
    ) -> torch.Tensor:
        """Helper function to parse the number of documents per query."""
        if isinstance(num_docs, int):
            num_docs = [num_docs] * query_shape
        if isinstance(num_docs, list):
            if sum(num_docs) != doc_shape or len(num_docs) != query_shape:
                raise ValueError("Num docs does not match doc embeddings")
        if num_docs is None:
            if doc_shape % query_shape != 0:
                raise ValueError("Docs are not evenly distributed in batch, but no num_docs provided")
            num_docs = [doc_shape // query_shape] * query_shape
        return torch.tensor(num_docs, device=device)

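# Illustrative sketch (not part of the module): the three ways ``_parse_num_docs``
# resolves ``num_docs`` for 2 queries and 4 documents. ``parse_num_docs`` below is a
# standalone copy of the logic for demonstration, not a library function.
import torch

def parse_num_docs(num_queries: int, num_total_docs: int, num_docs) -> torch.Tensor:
    if isinstance(num_docs, int):
        num_docs = [num_docs] * num_queries
    if isinstance(num_docs, list):
        if sum(num_docs) != num_total_docs or len(num_docs) != num_queries:
            raise ValueError("num_docs does not match the documents")
    if num_docs is None:
        if num_total_docs % num_queries != 0:
            raise ValueError("docs are not evenly distributed, but no num_docs provided")
        num_docs = [num_total_docs // num_queries] * num_queries
    return torch.tensor(num_docs)

print(parse_num_docs(2, 4, None))    # tensor([2, 2]), inferred by even division
print(parse_num_docs(2, 4, 2))       # tensor([2, 2]), same count for every query
print(parse_num_docs(2, 4, [1, 3]))  # tensor([1, 3]), one count per query
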
    def compute_similarity(
        self,
        query_embeddings: BiEncoderEmbedding,
        doc_embeddings: BiEncoderEmbedding,
        num_docs: Sequence[int] | int | None = None,
    ) -> torch.Tensor:
        """Computes the similarity score between all query and document embedding vector pairs.

        :param query_embeddings: Embeddings of the queries
        :type query_embeddings: BiEncoderEmbedding
        :param doc_embeddings: Embeddings of the documents
        :type doc_embeddings: BiEncoderEmbedding
        :param num_docs: Specifies how many documents are passed per query. If a sequence of integers, `len(num_docs)`
            should be equal to the number of queries and `sum(num_docs)` equal to the number of documents, i.e., the
            sequence contains one value per query specifying the number of documents for that query. If an integer,
            assumes an equal number of documents per query. If None, tries to infer the number of documents by dividing
            the number of documents by the number of queries, defaults to None
        :type num_docs: Sequence[int] | int | None, optional
        :return: Similarity scores between all query and document embedding vector pairs
        :rtype: torch.Tensor
        """
        num_docs_t = self._parse_num_docs(
            query_embeddings.embeddings.shape[0], doc_embeddings.embeddings.shape[0], num_docs, query_embeddings.device
        )
        query_emb = query_embeddings.embeddings.repeat_interleave(num_docs_t, dim=0).unsqueeze(2)
        doc_emb = doc_embeddings.embeddings.unsqueeze(1)
        similarity = self.similarity_function(query_emb, doc_emb)
        return similarity

    @staticmethod
    @_batch_elementwise_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def _cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.cosine_similarity(x, y, dim=-1)

    @staticmethod
    @_batch_elementwise_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def _l2_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return -1 * torch.cdist(x, y).squeeze(-2)

    @staticmethod
    @_batch_elementwise_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def _dot_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.matmul(x, y.transpose(-1, -2)).squeeze(-2)

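# Illustrative sketch (not part of the module): how ``compute_similarity`` pairs
# query and document vectors. Each query embedding is repeated once per associated
# document before the element-wise similarity is applied. Toy shapes only.
import torch

query_emb = torch.randn(2, 1, 8)  # [num_queries x seq_len x hidden_size]
doc_emb = torch.randn(3, 1, 8)    # [num_total_docs x seq_len x hidden_size]
num_docs = torch.tensor([1, 2])   # first query has 1 document, second has 2

paired_query = query_emb.repeat_interleave(num_docs, dim=0).unsqueeze(2)
paired_doc = doc_emb.unsqueeze(1)
similarity = torch.nn.functional.cosine_similarity(paired_query, paired_doc, dim=-1)
print(similarity.shape)  # torch.Size([3, 1, 1]): one entry per query-document pair
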
    @abstractmethod
    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes batched tokenized text sequences and returns the embeddings and scoring mask.

        :param encoding: Tokenizer encodings for the text sequence
        :type encoding: BatchEncoding
        :param input_type: Type of input, either "query" or "doc"
        :type input_type: Literal["query", "doc"]
        :return: Embeddings and scoring mask
        :rtype: BiEncoderEmbedding
        """
        pass

    @abstractmethod
    def score(
        self,
        query_embeddings: BiEncoderEmbedding,
        doc_embeddings: BiEncoderEmbedding,
        num_docs: Sequence[int] | int | None = None,
    ) -> torch.Tensor:
        """Compute relevance scores between queries and documents.

        :param query_embeddings: Embeddings and scoring mask for the queries
        :type query_embeddings: BiEncoderEmbedding
        :param doc_embeddings: Embeddings and scoring mask for the documents
        :type doc_embeddings: BiEncoderEmbedding
        :param num_docs: Specifies how many documents are passed per query. If a sequence of integers, `len(num_docs)`
            should be equal to the number of queries and `sum(num_docs)` equal to the number of documents, i.e., the
            sequence contains one value per query specifying the number of documents for that query. If an integer,
            assumes an equal number of documents per query. If None, tries to infer the number of documents by dividing
            the number of documents by the number of queries, defaults to None
        :type num_docs: Sequence[int] | int | None, optional
        :return: Relevance scores
        :rtype: torch.Tensor
        """
        pass


class SingleVectorBiEncoderModel(BiEncoderModel):
    """A bi-encoder model that encodes queries and documents separately, pools the contextualized embeddings into a
    single vector, and computes a relevance score based on the similarity between the two vectors. See
    :class:`.SingleVectorBiEncoderConfig` for configuration options."""

    config_class: Type[SingleVectorBiEncoderConfig] = SingleVectorBiEncoderConfig
    """Configuration class for the single-vector bi-encoder model."""

    def __init__(self, config: SingleVectorBiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a single-vector bi-encoder model given a :class:`.SingleVectorBiEncoderConfig`.

        :param config: Configuration for the single-vector bi-encoder model
        :type config: SingleVectorBiEncoderConfig
        """
        super().__init__(config, *args, **kwargs)

    def score(
        self,
        query_embeddings: BiEncoderEmbedding,
        doc_embeddings: BiEncoderEmbedding,
        num_docs: Sequence[int] | int | None = None,
    ) -> torch.Tensor:
        """Compute relevance scores between queries and documents.

        :param query_embeddings: Embeddings and scoring mask for the queries
        :type query_embeddings: BiEncoderEmbedding
        :param doc_embeddings: Embeddings and scoring mask for the documents
        :type doc_embeddings: BiEncoderEmbedding
        :param num_docs: Specifies how many documents are passed per query. If a sequence of integers, `len(num_docs)`
            should be equal to the number of queries and `sum(num_docs)` equal to the number of documents, i.e., the
            sequence contains one value per query specifying the number of documents for that query. If an integer,
            assumes an equal number of documents per query. If None, tries to infer the number of documents by dividing
            the number of documents by the number of queries, defaults to None
        :type num_docs: Sequence[int] | int | None, optional
        :return: Relevance scores
        :rtype: torch.Tensor
        """
        return self.compute_similarity(query_embeddings, doc_embeddings, num_docs).view(-1)


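# Illustrative sketch (not part of the module): with single-vector embeddings the
# pairwise similarity has shape [num_pairs x 1 x 1], so ``score`` reduces to a
# flatten. Toy values only.
import torch

similarity = torch.randn(3, 1, 1)  # one entry per query-document pair
print(similarity.view(-1).shape)   # torch.Size([3])
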
class MultiVectorBiEncoderModel(BiEncoderModel):
    """A bi-encoder model that encodes queries and documents separately into multiple vectors and computes a relevance
    score by aggregating the similarities between them. See :class:`.MultiVectorBiEncoderConfig` for configuration
    options."""

    config_class: Type[MultiVectorBiEncoderConfig] = MultiVectorBiEncoderConfig
    """Configuration class for the multi-vector bi-encoder model."""

    supports_retrieval_models = []

    def __init__(self, config: MultiVectorBiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a multi-vector bi-encoder model given a :class:`.MultiVectorBiEncoderConfig`.

        :param config: Configuration for the multi-vector bi-encoder model
        :type config: MultiVectorBiEncoderConfig
        :raises ValueError: If mask scoring tokens are specified in the configuration but the tokenizer is not
            available
        :raises ValueError: If the specified mask scoring tokens are not in the tokenizer vocab
        """
        super().__init__(config, *args, **kwargs)

        self.query_mask_scoring_input_ids: torch.Tensor | None = None
        self.doc_mask_scoring_input_ids: torch.Tensor | None = None

        # Adds the mask scoring input ids to the model if they are specified in the configuration.
        for sequence in ("query", "doc"):
            mask_scoring_tokens = getattr(self.config, f"{sequence}_mask_scoring_tokens")
            if mask_scoring_tokens is None:
                continue
            if mask_scoring_tokens == "punctuation":
                mask_scoring_tokens = list(punctuation)
            try:
                from transformers import AutoTokenizer

                tokenizer = AutoTokenizer.from_pretrained(self.config.name_or_path)
            except OSError:
                raise ValueError("Can't use token scoring masking if the checkpoint does not have a tokenizer.")
            mask_scoring_input_ids = []
            for token in mask_scoring_tokens:
                if token not in tokenizer.vocab:
                    raise ValueError(f"Token {token} not in tokenizer vocab")
                mask_scoring_input_ids.append(tokenizer.vocab[token])
            setattr(
                self,
                f"{sequence}_mask_scoring_input_ids",
                torch.tensor(mask_scoring_input_ids, dtype=torch.long),
            )

    def _expand_mask(self, shape: torch.Size, mask: torch.Tensor, dim: int) -> torch.Tensor:
        """Helper function to expand the mask to the shape of the similarity scores."""
        if mask.ndim == len(shape):
            return mask
        if mask.ndim > len(shape):
            raise ValueError("Mask has too many dimensions")
        fill_values = len(shape) - mask.ndim + 1
        new_shape = [*mask.shape[:-1]] + [1] * fill_values
        new_shape[dim] = mask.shape[-1]
        return mask.view(*new_shape)

    def _aggregate(
        self,
        scores: torch.Tensor,
        mask: torch.Tensor,
        query_aggregation_function: Literal["max", "sum", "mean", "harmonic_mean"],
        dim: int,
    ) -> torch.Tensor:
        """Helper function to aggregate similarity scores over query and document embeddings."""
        mask = self._expand_mask(scores.shape, mask, dim)
        if query_aggregation_function == "max":
            scores = scores.masked_fill(~mask, float("-inf"))
            return scores.amax(dim, keepdim=True)
        if query_aggregation_function == "sum":
            scores = scores.masked_fill(~mask, 0)
            return scores.sum(dim, keepdim=True)
        num_non_masked = mask.sum(dim, keepdim=True)
        if query_aggregation_function == "mean":
            return torch.where(num_non_masked == 0, 0, scores.sum(dim, keepdim=True) / num_non_masked)
        if query_aggregation_function == "harmonic_mean":
            return torch.where(
                num_non_masked == 0,
                0,
                num_non_masked / (1 / scores).sum(dim, keepdim=True),
            )
        raise ValueError(f"Unknown aggregation {query_aggregation_function}")

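# Illustrative sketch (not part of the module): masked "max" aggregation over
# document vectors followed by "sum" over query vectors, the late-interaction
# (MaxSim) pattern the helpers above support. Toy tensors only.
import torch

similarity = torch.tensor([[[0.9, 0.1], [0.2, 0.8]]])  # [pairs x query_len x doc_len]
doc_mask = torch.tensor([[True, False]])   # second document vector is masked out
query_mask = torch.tensor([[True, True]])

masked = similarity.masked_fill(~doc_mask[:, None, :], float("-inf"))
per_query_vector = masked.amax(-1)  # max over document vectors -> [pairs x query_len]
score = per_query_vector.masked_fill(~query_mask, 0).sum(-1)
print(score)  # tensor([1.1000]): 0.9 + 0.2, the masked column never wins
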
    def scoring_mask(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> torch.Tensor:
        """Computes a scoring mask for batched tokenized text sequences which is used in the scoring function to mask
        out vectors during scoring.

        :param encoding: Tokenizer encodings for the text sequence
        :type encoding: BatchEncoding
        :param input_type: Type of input, either "query" or "doc"
        :type input_type: Literal["query", "doc"]
        :return: Scoring mask
        :rtype: torch.Tensor
        """
        input_ids = encoding["input_ids"]
        attention_mask = encoding["attention_mask"]
        scoring_mask = attention_mask
        if scoring_mask is None:
            scoring_mask = torch.ones_like(input_ids, dtype=torch.bool)
        scoring_mask = scoring_mask.bool()
        mask_scoring_input_ids = getattr(self, f"{input_type}_mask_scoring_input_ids")
        if mask_scoring_input_ids is not None:
            ignore_mask = input_ids[..., None].eq(mask_scoring_input_ids.to(input_ids.device)).any(-1)
            scoring_mask = scoring_mask & ~ignore_mask
        return scoring_mask

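# Illustrative sketch (not part of the module): how mask-scoring input ids remove
# vectors from the scoring mask. Token id 5 stands in for a punctuation token;
# all ids are toy values.
import torch

input_ids = torch.tensor([[101, 7, 5, 102]])
attention_mask = torch.tensor([[1, 1, 1, 1]])
mask_scoring_input_ids = torch.tensor([5])

scoring_mask = attention_mask.bool()
ignore_mask = input_ids[..., None].eq(mask_scoring_input_ids).any(-1)
print(scoring_mask & ~ignore_mask)  # tensor([[ True,  True, False,  True]])
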
    def aggregate_similarity(
        self,
        similarity: torch.Tensor,
        query_scoring_mask: torch.Tensor,
        doc_scoring_mask: torch.Tensor,
        num_docs: int | Sequence[int] | None = None,
    ) -> torch.Tensor:
        """Aggregates the matrix of query-document similarities into a single score based on the configured aggregation
        strategy.

        :param similarity: Query-document similarity matrix
        :type similarity: torch.Tensor
        :param query_scoring_mask: Which query vectors should be masked out during scoring
        :type query_scoring_mask: torch.Tensor
        :param doc_scoring_mask: Which document vectors should be masked out during scoring
        :type doc_scoring_mask: torch.Tensor
        :param num_docs: Specifies how many documents are passed per query. See :meth:`.compute_similarity` for
            details, defaults to None
        :type num_docs: int | Sequence[int] | None, optional
        :return: Aggregated similarity scores
        :rtype: torch.Tensor
        """
        num_docs_t = self._parse_num_docs(
            query_scoring_mask.shape[0], doc_scoring_mask.shape[0], num_docs, similarity.device
        )
        scores = self._aggregate(similarity, doc_scoring_mask, self.config.doc_aggregation_function, -1)
        repeated_query_scoring_mask = query_scoring_mask.repeat_interleave(num_docs_t, dim=0)
        scores = self._aggregate(scores, repeated_query_scoring_mask, self.config.query_aggregation_function, -2)
        return scores.view(scores.shape[0])

    def score(
        self,
        query_embeddings: BiEncoderEmbedding,
        doc_embeddings: BiEncoderEmbedding,
        num_docs: Sequence[int] | int | None = None,
    ) -> torch.Tensor:
        """Compute relevance scores between queries and documents.

        :param query_embeddings: Embeddings and scoring mask for the queries
        :type query_embeddings: BiEncoderEmbedding
        :param doc_embeddings: Embeddings and scoring mask for the documents
        :type doc_embeddings: BiEncoderEmbedding
        :param num_docs: Specifies how many documents are passed per query. If a sequence of integers, `len(num_docs)`
            should be equal to the number of queries and `sum(num_docs)` equal to the number of documents, i.e., the
            sequence contains one value per query specifying the number of documents for that query. If an integer,
            assumes an equal number of documents per query. If None, tries to infer the number of documents by dividing
            the number of documents by the number of queries, defaults to None
        :type num_docs: Sequence[int] | int | None, optional
        :raises ValueError: If the query or document embeddings do not have a scoring mask
        :return: Relevance scores
        :rtype: torch.Tensor
        """
        similarity = self.compute_similarity(query_embeddings, doc_embeddings, num_docs)
        if query_embeddings.scoring_mask is None or doc_embeddings.scoring_mask is None:
            raise ValueError("Scoring masks expected for scoring multi-vector embeddings")
        return self.aggregate_similarity(
            similarity, query_embeddings.scoring_mask, doc_embeddings.scoring_mask, num_docs
        )
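
# Illustrative sketch (not part of the module): the multi-vector scoring flow end to
# end. ``model`` is a placeholder for a concrete MultiVectorBiEncoderModel subclass;
# the encodings are assumed to come from its matching tokenizer.
#
#     query_embeddings = model.encode_query(query_encoding)
#     doc_embeddings = model.encode_doc(doc_encoding)
#     scores = model.score(query_embeddings, doc_embeddings, num_docs=[2, 2])
#
# Internally, ``score`` computes all vector-pair similarities, masks them with the
# scoring masks, aggregates over document vectors (e.g. max), then over query
# vectors (e.g. sum), yielding one relevance score per query-document pair.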