Source code for lightning_ir.bi_encoder.bi_encoder_model

  1"""
  2Model module for bi-encoder models.
  3
  4This module defines the model class used to implement bi-encoder models.
  5"""
  6
  7from abc import ABC, abstractmethod
  8from collections.abc import Iterable, Sequence
  9from dataclasses import dataclass
 10from string import punctuation
 11from typing import Literal, Self, overload
 12
 13import torch
 14from transformers import BatchEncoding
 15
 16from ..base import LightningIRModel, LightningIROutput
 17from ..modeling_utils.batching import _batch_elementwise_scoring
 18from .bi_encoder_config import BiEncoderConfig, MultiVectorBiEncoderConfig, SingleVectorBiEncoderConfig
 19
 20


@dataclass
class BiEncoderEmbedding:
    """Dataclass containing embeddings and the encoding for bi-encoder models."""

    embeddings: torch.Tensor
    """Embedding tensor generated by a bi-encoder model of shape [batch_size x seq_len x hidden_size]. The sequence
    length varies depending on the pooling strategy and the hidden size varies depending on the projection settings."""
    scoring_mask: torch.Tensor | None = None
    """Mask tensor designating which vectors should be ignored during scoring."""
    encoding: BatchEncoding | None = None
    """Tokenizer encodings used to generate the embeddings."""
    ids: list[str] | None = None
    """List of ids for the embeddings, e.g., query or document ids."""

    @overload
    def to(self, device: torch.device, /) -> Self: ...

    @overload
    def to(self, other: Self, /) -> Self: ...
    def to(self, device) -> Self:
        """Moves the embeddings to the specified device.

        Args:
            device (torch.device | BiEncoderEmbedding): Device to move the embeddings to or another instance to move
                to the same device.
        Returns:
            Self: The instance with embeddings moved to the specified device.
        """
        if isinstance(device, BiEncoderEmbedding):
            device = device.device
        self.embeddings = self.embeddings.to(device)
        self.scoring_mask = self.scoring_mask.to(device) if self.scoring_mask is not None else None
        self.encoding = self.encoding.to(device) if self.encoding is not None else None
        return self

    @property
    def device(self) -> torch.device:
        """Returns the device of the embeddings.

        Returns:
            torch.device: Device of the embeddings.
        """
        return self.embeddings.device

    def items(self) -> Iterable[tuple[str, torch.Tensor]]:
        """Iterates over the dataclass fields and their values like `dict.items()`.

        Yields:
            tuple[str, torch.Tensor]: Tuple of field name and its value.
        """
        for field in self.__dataclass_fields__:
            yield field, getattr(self, field)
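

# Illustrative sketch, not part of the original module: builds a BiEncoderEmbedding
# from random tensors (standing in for real model output) and exercises to() and
# items().
def _example_embedding_usage() -> None:
    embedding = BiEncoderEmbedding(
        embeddings=torch.randn(2, 1, 128),  # [batch_size x seq_len x hidden_size]
        scoring_mask=torch.ones(2, 1, dtype=torch.bool),
        ids=["q1", "q2"],
    )
    embedding = embedding.to(torch.device("cpu"))  # move by explicit device
    embedding = embedding.to(embedding)  # or align with another embedding's device
    fields = dict(embedding.items())  # field name -> value, like dict.items()
    assert set(fields) == {"embeddings", "scoring_mask", "encoding", "ids"}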


@dataclass
class BiEncoderOutput(LightningIROutput):
    """Dataclass containing the output of a bi-encoder model."""

    query_embeddings: BiEncoderEmbedding | None = None
    """Query embeddings generated by the model."""
    doc_embeddings: BiEncoderEmbedding | None = None
    """Document embeddings generated by the model."""
    similarity: torch.Tensor | None = None
    """Similarity scores between query and document embeddings."""


class BiEncoderModel(LightningIRModel, ABC):
    """A bi-encoder model that encodes queries and documents separately and computes a relevance score between them.
    See :class:`.BiEncoderConfig` for configuration options."""

    config_class: type[BiEncoderConfig] = BiEncoderConfig
    """Configuration class for the bi-encoder model."""

    def __init__(self, config: BiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a bi-encoder model given a :class:`.BiEncoderConfig`.

        Args:
            config (BiEncoderConfig): Configuration for the bi-encoder model.
        Raises:
            ValueError: If the similarity function is not supported.
        """
        super().__init__(config, *args, **kwargs)
        self.config: BiEncoderConfig

        if self.config.similarity_function == "cosine":
            self.similarity_function = self._cosine_similarity
        elif self.config.similarity_function == "l2":
            self.similarity_function = self._l2_similarity
        elif self.config.similarity_function == "dot":
            self.similarity_function = self._dot_similarity
        else:
            raise ValueError(f"Unknown similarity function {self.config.similarity_function}")

    def forward(
        self,
        query_encoding: BatchEncoding | None,
        doc_encoding: BatchEncoding | None,
        num_docs: Sequence[int] | int | None = None,
    ) -> BiEncoderOutput:
        """Embeds queries and/or documents and computes relevance scores between them if both are provided.

        Args:
            query_encoding (BatchEncoding | None): Tokenizer encodings for the queries.
            doc_encoding (BatchEncoding | None): Tokenizer encodings for the documents.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence of
                integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the
                number of documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
                the number of documents by dividing the number of documents by the number of queries. Defaults to None.
        Returns:
            BiEncoderOutput: Output of the model containing query and document embeddings and relevance scores.
        """
        query_embeddings = None
        if query_encoding is not None:
            query_embeddings = self.encode_query(query_encoding)
        doc_embeddings = None
        if doc_encoding is not None:
            doc_embeddings = self.encode_doc(doc_encoding)
        output = BiEncoderOutput(query_embeddings=query_embeddings, doc_embeddings=doc_embeddings)
        if doc_embeddings is not None and query_embeddings is not None:
            output = self.score(output, num_docs)
        return output
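
    # Illustrative sketch (hypothetical names, not part of the library): typical
    # forward usage with a concrete subclass and a matching tokenizer. With both
    # encodings present, the output also contains relevance scores; with only one,
    # it contains just the corresponding embeddings.
    #
    #     model = SomeConcreteBiEncoder.from_pretrained("some/checkpoint")
    #     output = model.forward(query_encoding, doc_encoding, num_docs=[2, 3])
    #     output.query_embeddings, output.doc_embeddings, output.scores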

    def encode_query(self, encoding: BatchEncoding) -> BiEncoderEmbedding:
        """Encodes tokenized queries.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the queries.
        Returns:
            BiEncoderEmbedding: Query embeddings and scoring mask.
        """
        return self.encode(encoding=encoding, input_type="query")

    def encode_doc(self, encoding: BatchEncoding) -> BiEncoderEmbedding:
        """Encodes tokenized documents.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the documents.
        Returns:
            BiEncoderEmbedding: Document embeddings and scoring mask.
        """
        return self.encode(encoding=encoding, input_type="doc")

    def _parse_num_docs(
        self, query_shape: int, doc_shape: int, num_docs: int | Sequence[int] | None, device: torch.device | None = None
    ) -> torch.Tensor:
        """Helper function to parse the number of documents per query."""
        if isinstance(num_docs, int):
            num_docs = [num_docs] * query_shape
        if isinstance(num_docs, list):
            if sum(num_docs) != doc_shape or len(num_docs) != query_shape:
                raise ValueError("Num docs does not match doc embeddings")
        if num_docs is None:
            if doc_shape % query_shape != 0:
                raise ValueError("Docs are not evenly distributed in batch, but no num_docs provided")
            num_docs = [doc_shape // query_shape] * query_shape
        return torch.tensor(num_docs, device=device)
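
    # Illustrative sketch, not part of the library: the three accepted forms of
    # num_docs for 2 queries and 6 documents, and how the parsed per-query counts
    # map each document back to its query via repeat_interleave.
    @staticmethod
    def _example_num_docs() -> None:
        num_queries, num_total_docs = 2, 6
        num_docs = [2, 4]  # sequence form: one count per query, summing to 6
        assert len(num_docs) == num_queries and sum(num_docs) == num_total_docs
        # the int form 3 and the None form both parse to [3, 3] (6 divides evenly by 2)
        parsed = torch.tensor(num_docs)
        query_index_per_doc = torch.arange(num_queries).repeat_interleave(parsed, dim=0)
        assert query_index_per_doc.tolist() == [0, 0, 1, 1, 1, 1]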

    def compute_similarity(
        self,
        query_embeddings: BiEncoderEmbedding,
        doc_embeddings: BiEncoderEmbedding,
        num_docs: Sequence[int] | int | None = None,
    ) -> torch.Tensor:
        """Computes the similarity score between all query and document embedding vector pairs.

        Args:
            query_embeddings (BiEncoderEmbedding): Embeddings of the queries.
            doc_embeddings (BiEncoderEmbedding): Embeddings of the documents.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence of
                integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the
                number of documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
                the number of documents by dividing the number of documents by the number of queries. Defaults to None.
        Returns:
            torch.Tensor: Similarity scores between all query and document embedding vector pairs.
        """
        num_docs_t = self._parse_num_docs(
            query_embeddings.embeddings.shape[0], doc_embeddings.embeddings.shape[0], num_docs, query_embeddings.device
        )
        query_emb = query_embeddings.embeddings.repeat_interleave(num_docs_t, dim=0).unsqueeze(2)
        doc_emb = doc_embeddings.embeddings.unsqueeze(1)
        similarity = self.similarity_function(query_emb, doc_emb)
        return similarity

    @staticmethod
    @_batch_elementwise_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def _cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.cosine_similarity(x, y, dim=-1)

    @staticmethod
    @_batch_elementwise_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def _l2_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return -1 * torch.cdist(x, y).squeeze(-2)

    @staticmethod
    @_batch_elementwise_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def _dot_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.matmul(x, y.transpose(-1, -2)).squeeze(-2)
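
    # Illustrative sketch, not part of the library: the single-vector-pair
    # formulas behind the three similarity functions above. The batched versions
    # apply the same math over [... x seq_len x hidden_size] tensors.
    @staticmethod
    def _example_similarity_functions() -> None:
        q, d = torch.randn(8), torch.randn(8)
        cosine = torch.dot(q, d) / (q.norm() * d.norm())  # "cosine"
        l2 = -torch.dist(q, d)  # "l2": negated euclidean distance, larger is better
        dot = torch.dot(q, d)  # "dot"
        assert torch.isclose(cosine * q.norm() * d.norm(), dot, atol=1e-5)
        assert l2 <= 0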

    @abstractmethod
    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes batched tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequences.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask for the text sequences.
        Raises:
            NotImplementedError: If the method is not implemented.
        """
        raise NotImplementedError

    @abstractmethod
    def score(self, output: BiEncoderOutput, num_docs: Sequence[int] | int | None = None) -> BiEncoderOutput:
        """Compute relevance scores between queries and documents.

        Args:
            output (BiEncoderOutput): Output containing embeddings and scoring mask.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence of
                integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the
                number of documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
                the number of documents by dividing the number of documents by the number of queries. Defaults to None.
        Returns:
            BiEncoderOutput: Output containing relevance scores.
        Raises:
            NotImplementedError: If the method is not implemented.
        """
        raise NotImplementedError


class SingleVectorBiEncoderModel(BiEncoderModel):
    """A bi-encoder model that encodes queries and documents separately, pools the contextualized embeddings into a
    single vector, and computes a relevance score based on the similarity between the two vectors. See
    :class:`.SingleVectorBiEncoderConfig` for configuration options."""

    config_class: type[SingleVectorBiEncoderConfig] = SingleVectorBiEncoderConfig
    """Configuration class for the single-vector bi-encoder model."""

    def __init__(self, config: SingleVectorBiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a single-vector bi-encoder model given a :class:`.SingleVectorBiEncoderConfig`.

        Args:
            config (SingleVectorBiEncoderConfig): Configuration for the single-vector bi-encoder model.
        """
        super().__init__(config, *args, **kwargs)

    def score(
        self,
        output: BiEncoderOutput,
        num_docs: Sequence[int] | int | None = None,
    ) -> BiEncoderOutput:
        """Compute relevance scores between queries and documents.

        Args:
            output (BiEncoderOutput): Output containing embeddings and scoring mask.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence of
                integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the
                number of documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
                the number of documents by dividing the number of documents by the number of queries. Defaults to None.
        Returns:
            BiEncoderOutput: Output containing relevance scores.
        Raises:
            ValueError: If query or document embeddings are not provided in the output.
        """
        if output.query_embeddings is None or output.doc_embeddings is None:
            raise ValueError("Expected query and document embeddings in BiEncoderOutput")
        similarity = self.compute_similarity(output.query_embeddings, output.doc_embeddings, num_docs)
        output.scores = similarity.view(-1)
        output.similarity = similarity
        return output
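

# Illustrative sketch, not part of the library: for single-vector models both
# embeddings have seq_len == 1, so compute_similarity yields one value per
# query-document pair and view(-1) flattens it into the scores vector.
def _example_single_vector_scores() -> None:
    similarity = torch.randn(6, 1, 1)  # 2 queries with 3 documents each
    scores = similarity.view(-1)
    assert scores.shape == (6,)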


class MultiVectorBiEncoderModel(BiEncoderModel):
    """A bi-encoder model that encodes queries and documents separately into multiple contextualized vectors and
    aggregates the pairwise vector similarities into a relevance score. See :class:`.MultiVectorBiEncoderConfig` for
    configuration options."""

    config_class: type[MultiVectorBiEncoderConfig] = MultiVectorBiEncoderConfig
    """Configuration class for the multi-vector bi-encoder model."""

    supports_retrieval_models = []

    def __init__(self, config: MultiVectorBiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a multi-vector bi-encoder model given a :class:`.MultiVectorBiEncoderConfig`.

        Args:
            config (MultiVectorBiEncoderConfig): Configuration for the multi-vector bi-encoder model.
        Raises:
            ValueError: If mask scoring tokens are specified in the configuration but the tokenizer is not available.
            ValueError: If the specified mask scoring tokens are not in the tokenizer vocab.
        """
        super().__init__(config, *args, **kwargs)

        self.query_mask_scoring_input_ids: torch.Tensor | None = None
        self.doc_mask_scoring_input_ids: torch.Tensor | None = None

        # Adds the mask scoring input ids to the model if they are specified in the configuration.
        for sequence in ("query", "doc"):
            mask_scoring_tokens = getattr(self.config, f"{sequence}_mask_scoring_tokens")
            if mask_scoring_tokens is None:
                continue
            if mask_scoring_tokens == "punctuation":
                mask_scoring_tokens = list(punctuation)
            try:
                from transformers import AutoTokenizer

                tokenizer = AutoTokenizer.from_pretrained(self.config.name_or_path)
            except OSError as err:
                raise ValueError(
                    "Can't use token scoring masking if the checkpoint does not have a tokenizer."
                ) from err
            mask_scoring_input_ids = []
            for token in mask_scoring_tokens:
                if token not in tokenizer.vocab:
                    raise ValueError(f"Token {token} not in tokenizer vocab")
                mask_scoring_input_ids.append(tokenizer.vocab[token])
            setattr(
                self,
                f"{sequence}_mask_scoring_input_ids",
                torch.tensor(mask_scoring_input_ids, dtype=torch.long),
            )

    def _expand_mask(self, shape: torch.Size, mask: torch.Tensor, dim: int) -> torch.Tensor:
        """Helper function to expand the mask to the shape of the similarity scores."""
        if mask.ndim == len(shape):
            return mask
        if mask.ndim > len(shape):
            raise ValueError("Mask has too many dimensions")
        fill_values = len(shape) - mask.ndim + 1
        new_shape = [*mask.shape[:-1]] + [1] * fill_values
        new_shape[dim] = mask.shape[-1]
        return mask.view(*new_shape)

    def _aggregate(
        self,
        scores: torch.Tensor,
        mask: torch.Tensor,
        aggregation_function: Literal["max", "sum", "mean"],
        dim: int,
    ) -> torch.Tensor:
        """Helper function to aggregate similarity scores over query and document embeddings."""
        mask = self._expand_mask(scores.shape, mask, dim)
        if aggregation_function == "max":
            scores = scores.masked_fill(~mask, float("-inf"))
            return scores.amax(dim, keepdim=True)
        scores = scores.masked_fill(~mask, 0)
        if aggregation_function == "sum":
            return scores.sum(dim, keepdim=True)
        num_non_masked = mask.sum(dim, keepdim=True)
        if aggregation_function == "mean":
            return scores.masked_fill(num_non_masked == 0, 0).sum(dim, keepdim=True) / num_non_masked.clamp(min=1)
        raise ValueError(f"Unknown aggregation {aggregation_function}")
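
    # Illustrative sketch, not part of the library: a ColBERT-style "max" over
    # document vectors followed by "sum" over query vectors, mirroring what
    # _aggregate computes for a toy similarity matrix of one query-document pair.
    @staticmethod
    def _example_late_interaction() -> None:
        similarity = torch.tensor([[[1.0, 0.5, 0.2], [0.1, 0.9, 0.3]]])  # [1 x 2 query vecs x 3 doc vecs]
        doc_mask = torch.tensor([[True, True, False]])  # ignore the last doc vector
        per_query_vec = similarity.masked_fill(~doc_mask[:, None, :], float("-inf")).amax(-1)
        score = per_query_vec.sum(-1)  # max over doc vectors, then sum over query vectors
        assert torch.isclose(score, torch.tensor([1.9])).all()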

    def scoring_mask(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> torch.Tensor:
        """Computes a scoring mask for batched tokenized text sequences which is used in the scoring function to mask
        out vectors during scoring.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            torch.Tensor: Scoring mask.
        """
        input_ids = encoding["input_ids"]
        attention_mask = encoding["attention_mask"]
        scoring_mask = attention_mask
        if scoring_mask is None:
            scoring_mask = torch.ones_like(input_ids, dtype=torch.bool)
        scoring_mask = scoring_mask.bool()
        mask_scoring_input_ids = getattr(self, f"{input_type}_mask_scoring_input_ids")
        if mask_scoring_input_ids is not None:
            ignore_mask = input_ids[..., None].eq(mask_scoring_input_ids.to(input_ids.device)).any(-1)
            scoring_mask = scoring_mask & ~ignore_mask
        return scoring_mask
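
    # Illustrative sketch, not part of the library: how mask-scoring input ids
    # remove vectors from scoring. The token id 5 stands in for a punctuation
    # token configured via the mask scoring tokens.
    @staticmethod
    def _example_scoring_mask() -> None:
        input_ids = torch.tensor([[101, 7, 5, 102]])
        attention_mask = torch.tensor([[1, 1, 1, 1]])
        mask_scoring_input_ids = torch.tensor([5])
        ignore_mask = input_ids[..., None].eq(mask_scoring_input_ids).any(-1)
        scoring_mask = attention_mask.bool() & ~ignore_mask
        assert scoring_mask.tolist() == [[True, True, False, True]]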

    def aggregate_similarity(
        self,
        similarity: torch.Tensor,
        query_scoring_mask: torch.Tensor,
        doc_scoring_mask: torch.Tensor,
        num_docs: int | Sequence[int] | None = None,
    ) -> torch.Tensor:
        """Aggregates the matrix of query-document similarities into a single score based on the configured
        aggregation strategy.

        Args:
            similarity (torch.Tensor): Query-document similarity matrix.
            query_scoring_mask (torch.Tensor): Which query vectors should be masked out during scoring.
            doc_scoring_mask (torch.Tensor): Which document vectors should be masked out during scoring.
            num_docs (int | Sequence[int] | None): Specifies how many documents are passed per query. If None, tries
                to infer the number of documents by dividing the number of documents by the number of queries.
                Defaults to None.
        Returns:
            torch.Tensor: Aggregated similarity scores.
        """
        scores = similarity
        if scores.shape[-1] > 1:
            scores = self._aggregate(scores, doc_scoring_mask, self.config.doc_aggregation_function, -1)
        if scores.shape[-2] > 1:
            num_docs_t = self._parse_num_docs(
                query_scoring_mask.shape[0], doc_scoring_mask.shape[0], num_docs, similarity.device
            )
            repeated_query_scoring_mask = query_scoring_mask.repeat_interleave(num_docs_t, dim=0)
            scores = self._aggregate(scores, repeated_query_scoring_mask, self.config.query_aggregation_function, -2)
        return scores.view(scores.shape[0])

    def score(self, output: BiEncoderOutput, num_docs: Sequence[int] | int | None = None) -> BiEncoderOutput:
        """Compute relevance scores between queries and documents.

        Args:
            output (BiEncoderOutput): Output containing embeddings and scoring mask.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence of
                integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the
                number of documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
                the number of documents by dividing the number of documents by the number of queries. Defaults to None.
        Returns:
            BiEncoderOutput: Output containing relevance scores.
        Raises:
            ValueError: If query or document embeddings are not provided in the output.
            ValueError: If scoring masks are not provided for the embeddings.
        """
        query_embeddings = output.query_embeddings
        doc_embeddings = output.doc_embeddings
        if query_embeddings is None or doc_embeddings is None:
            raise ValueError("Embeddings expected for scoring multi-vector embeddings")
        if query_embeddings.scoring_mask is None or doc_embeddings.scoring_mask is None:
            raise ValueError("Scoring masks expected for scoring multi-vector embeddings")

        similarity = self.compute_similarity(query_embeddings, doc_embeddings, num_docs)
        scores = self.aggregate_similarity(
            similarity, query_embeddings.scoring_mask, doc_embeddings.scoring_mask, num_docs
        )
        output.scores = scores
        output.similarity = similarity
        return output