Source code for lightning_ir.bi_encoder.bi_encoder_model

  1"""
  2Model module for bi-encoder models.
  3
  4This module defines the model class used to implement bi-encoder models.
  5"""
  6
  7from abc import ABC, abstractmethod
  8from dataclasses import dataclass
  9from string import punctuation
 10from typing import Iterable, List, Literal, Self, Sequence, Tuple, Type, overload
 11
 12import torch
 13from transformers import BatchEncoding
 14
 15from ..base import LightningIRModel, LightningIROutput
 16from ..modeling_utils.batching import _batch_elementwise_scoring
 17from .bi_encoder_config import BiEncoderConfig, MultiVectorBiEncoderConfig, SingleVectorBiEncoderConfig
 18
 19


@dataclass
class BiEncoderEmbedding:
    """Dataclass containing embeddings and the encoding for bi-encoder models."""

    embeddings: torch.Tensor
    """Embedding tensor generated by a bi-encoder model of shape [batch_size x seq_len x hidden_size]. The sequence
    length varies depending on the pooling strategy and the hidden size varies depending on the projection settings."""
    scoring_mask: torch.Tensor | None = None
    """Mask tensor designating which vectors should be ignored during scoring."""
    encoding: BatchEncoding | None = None
    """Tokenizer encodings used to generate the embeddings."""
    ids: List[str] | None = None
    """List of ids for the embeddings, e.g., query or document ids."""

    @overload
    def to(self, device: torch.device, /) -> Self: ...

    @overload
    def to(self, other: Self, /) -> Self: ...

    def to(self, device) -> Self:
        """Moves the embeddings to the specified device.

        Args:
            device (torch.device | BiEncoderEmbedding): Device to move the embeddings to, or another
                BiEncoderEmbedding instance whose device should be used.
        Returns:
            Self: The instance with embeddings moved to the specified device.
        """
        if isinstance(device, BiEncoderEmbedding):
            device = device.device
        self.embeddings = self.embeddings.to(device)
        self.scoring_mask = self.scoring_mask.to(device) if self.scoring_mask is not None else None
        self.encoding = self.encoding.to(device) if self.encoding is not None else None
        return self

    @property
    def device(self) -> torch.device:
        """Returns the device of the embeddings.

        Returns:
            torch.device: Device of the embeddings.
        """
        return self.embeddings.device

    def items(self) -> Iterable[Tuple[str, torch.Tensor]]:
        """Iterates over the embedding's attributes and their values like `dict.items()`.

        Yields:
            Tuple[str, torch.Tensor]: Tuple of attribute name and its value.
        """
        for field in self.__dataclass_fields__:
            yield field, getattr(self, field)
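
# Illustrative usage sketch (not part of the original module): a BiEncoderEmbedding
# acts as a tensor container that can be moved across devices either by passing a
# device or another embedding instance. The tensor shapes below are assumptions
# chosen for the example only.
#
#     emb = BiEncoderEmbedding(embeddings=torch.randn(2, 32, 128))
#     emb = emb.to(torch.device("cpu"))        # move by explicit device
#     other = BiEncoderEmbedding(embeddings=torch.randn(2, 32, 128))
#     other = other.to(emb)                    # move to the same device as `emb`
#     dict(emb.items())["embeddings"].shape    # torch.Size([2, 32, 128])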


@dataclass
class BiEncoderOutput(LightningIROutput):
    """Dataclass containing the output of a bi-encoder model."""

    query_embeddings: BiEncoderEmbedding | None = None
    """Query embeddings generated by the model."""
    doc_embeddings: BiEncoderEmbedding | None = None
    """Document embeddings generated by the model."""
    similarity: torch.Tensor | None = None
    """Similarity scores between query and document embeddings."""


class BiEncoderModel(LightningIRModel, ABC):
    """A bi-encoder model that encodes queries and documents separately and computes a relevance score between them.
    See :class:`.BiEncoderConfig` for configuration options."""

    config_class: Type[BiEncoderConfig] = BiEncoderConfig
    """Configuration class for the bi-encoder model."""

    def __init__(self, config: BiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a bi-encoder model given a :class:`.BiEncoderConfig`.

        Args:
            config (BiEncoderConfig): Configuration for the bi-encoder model.
        Raises:
            ValueError: If the similarity function is not supported.
        """
        super().__init__(config, *args, **kwargs)
        self.config: BiEncoderConfig

        if self.config.similarity_function == "cosine":
            self.similarity_function = self._cosine_similarity
        elif self.config.similarity_function == "l2":
            self.similarity_function = self._l2_similarity
        elif self.config.similarity_function == "dot":
            self.similarity_function = self._dot_similarity
        else:
            raise ValueError(f"Unknown similarity function {self.config.similarity_function}")

    def forward(
        self,
        query_encoding: BatchEncoding | None,
        doc_encoding: BatchEncoding | None,
        num_docs: Sequence[int] | int | None = None,
    ) -> BiEncoderOutput:
        """Embeds queries and/or documents and computes relevance scores between them if both are provided.

        Args:
            query_encoding (BatchEncoding | None): Tokenizer encodings for the queries. Defaults to None.
            doc_encoding (BatchEncoding | None): Tokenizer encodings for the documents. Defaults to None.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence of
                integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the
                number of documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
                the number of documents by dividing the number of documents by the number of queries. Defaults to None.
        Returns:
            BiEncoderOutput: Output of the model containing query and document embeddings and relevance scores.
        """
        query_embeddings = None
        if query_encoding is not None:
            query_embeddings = self.encode_query(query_encoding)
        doc_embeddings = None
        if doc_encoding is not None:
            doc_embeddings = self.encode_doc(doc_encoding)
        output = BiEncoderOutput(query_embeddings=query_embeddings, doc_embeddings=doc_embeddings)
        if doc_embeddings is not None and query_embeddings is not None:
            output = self.score(output, num_docs)
        return output
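
    # Usage sketch (illustrative; `model`, `query_encoding`, and `doc_encoding` are
    # assumptions: `model` is a concrete subclass instance and the encodings come
    # from the matching tokenizer). Scoring 2 documents for the first query and 3
    # for the second:
    #
    #     output = model(query_encoding, doc_encoding, num_docs=[2, 3])
    #     output.scores.shape   # torch.Size([5]), one score per query-doc pair
    #     output.similarity     # raw similarities before aggregation/flattening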

    def encode_query(self, encoding: BatchEncoding) -> BiEncoderEmbedding:
        """Encodes tokenized queries.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the queries.
        Returns:
            BiEncoderEmbedding: Query embeddings and scoring mask.
        """
        return self.encode(encoding=encoding, input_type="query")

    def encode_doc(self, encoding: BatchEncoding) -> BiEncoderEmbedding:
        """Encodes tokenized documents.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the documents.
        Returns:
            BiEncoderEmbedding: Document embeddings and scoring mask.
        """
        return self.encode(encoding=encoding, input_type="doc")

    def _parse_num_docs(
        self, query_shape: int, doc_shape: int, num_docs: int | Sequence[int] | None, device: torch.device | None = None
    ) -> torch.Tensor:
        """Helper function to parse the number of documents per query."""
        if isinstance(num_docs, int):
            num_docs = [num_docs] * query_shape
        if isinstance(num_docs, list):
            if sum(num_docs) != doc_shape or len(num_docs) != query_shape:
                raise ValueError("Num docs does not match doc embeddings")
        if num_docs is None:
            if doc_shape % query_shape != 0:
                raise ValueError("Docs are not evenly distributed in batch, but no num_docs provided")
            num_docs = [doc_shape // query_shape] * query_shape
        return torch.tensor(num_docs, device=device)
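
    # Worked example (illustrative) of how `num_docs` is resolved, following the
    # branches above:
    #
    #     _parse_num_docs(query_shape=2, doc_shape=5, num_docs=[2, 3])  # tensor([2, 3])
    #     _parse_num_docs(query_shape=2, doc_shape=6, num_docs=3)       # tensor([3, 3])
    #     _parse_num_docs(query_shape=2, doc_shape=6, num_docs=None)    # tensor([3, 3]), inferred as 6 // 2
    #     _parse_num_docs(query_shape=2, doc_shape=5, num_docs=None)    # ValueError, 5 % 2 != 0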

    def compute_similarity(
        self,
        query_embeddings: BiEncoderEmbedding,
        doc_embeddings: BiEncoderEmbedding,
        num_docs: Sequence[int] | int | None = None,
    ) -> torch.Tensor:
        """Computes the similarity score between all query and document embedding vector pairs.

        Args:
            query_embeddings (BiEncoderEmbedding): Embeddings of the queries.
            doc_embeddings (BiEncoderEmbedding): Embeddings of the documents.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence of
                integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the
                number of documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
                the number of documents by dividing the number of documents by the number of queries. Defaults to None.
        Returns:
            torch.Tensor: Similarity scores between all query and document embedding vector pairs.
        """
        num_docs_t = self._parse_num_docs(
            query_embeddings.embeddings.shape[0], doc_embeddings.embeddings.shape[0], num_docs, query_embeddings.device
        )
        query_emb = query_embeddings.embeddings.repeat_interleave(num_docs_t, dim=0).unsqueeze(2)
        doc_emb = doc_embeddings.embeddings.unsqueeze(1)
        similarity = self.similarity_function(query_emb, doc_emb)
        return similarity

    @staticmethod
    @_batch_elementwise_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def _cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.cosine_similarity(x, y, dim=-1)

    @staticmethod
    @_batch_elementwise_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def _l2_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return -1 * torch.cdist(x, y).squeeze(-2)

    @staticmethod
    @_batch_elementwise_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def _dot_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.matmul(x, y.transpose(-1, -2)).squeeze(-2)
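
    # Shape sketch (illustrative; ignores any batching done by
    # `_batch_elementwise_scoring`): with 1 query of 4 vectors, 2 documents of 6
    # vectors each, and hidden size 128, `compute_similarity` broadcasts as:
    #
    #     query_embeddings.embeddings               # [1, 4, 128]
    #     .repeat_interleave(num_docs).unsqueeze(2) # [2, 4, 1, 128]
    #     doc_embeddings.embeddings.unsqueeze(1)    # [2, 1, 6, 128]
    #     similarity_function(...)                  # [2, 4, 6], one score per vector pair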

    @abstractmethod
    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes a batch of tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask for the text sequence.
        Raises:
            NotImplementedError: If the method is not implemented.
        """
        raise NotImplementedError

    @abstractmethod
    def score(self, output: BiEncoderOutput, num_docs: Sequence[int] | int | None = None) -> BiEncoderOutput:
        """Compute relevance scores between queries and documents.

        Args:
            output (BiEncoderOutput): Output containing embeddings and scoring mask.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence of
                integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the
                number of documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
                the number of documents by dividing the number of documents by the number of queries. Defaults to None.
        Returns:
            BiEncoderOutput: Output containing relevance scores.
        Raises:
            NotImplementedError: If the method is not implemented.
        """
        raise NotImplementedError


class SingleVectorBiEncoderModel(BiEncoderModel):
    """A bi-encoder model that encodes queries and documents separately, pools the contextualized embeddings into a
    single vector, and computes a relevance score based on the similarities between the two vectors. See
    :class:`.SingleVectorBiEncoderConfig` for configuration options."""

    config_class: Type[SingleVectorBiEncoderConfig] = SingleVectorBiEncoderConfig
    """Configuration class for the single-vector bi-encoder model."""

    def __init__(self, config: SingleVectorBiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a single-vector bi-encoder model given a :class:`.SingleVectorBiEncoderConfig`.

        Args:
            config (SingleVectorBiEncoderConfig): Configuration for the single-vector bi-encoder model.
        """
        super().__init__(config, *args, **kwargs)

    def score(
        self,
        output: BiEncoderOutput,
        num_docs: Sequence[int] | int | None = None,
    ) -> BiEncoderOutput:
        """Compute relevance scores between queries and documents.

        Args:
            output (BiEncoderOutput): Output containing embeddings and scoring mask.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence of
                integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the
                number of documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
                the number of documents by dividing the number of documents by the number of queries. Defaults to None.
        Returns:
            BiEncoderOutput: Output containing relevance scores.
        Raises:
            ValueError: If query or document embeddings are not provided in the output.
        """
        if output.query_embeddings is None or output.doc_embeddings is None:
            raise ValueError("Expected query and document embeddings in BiEncoderOutput")
        similarity = self.compute_similarity(output.query_embeddings, output.doc_embeddings, num_docs)
        output.scores = similarity.view(-1)
        output.similarity = similarity
        return output
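
    # Usage sketch (illustrative; `model` and the encodings are assumptions): for a
    # single-vector model each sequence is pooled to one vector, so the similarity
    # tensor has a single entry per query-doc pair and flattening yields the scores.
    #
    #     output = model(query_encoding, doc_encoding, num_docs=[2, 3])
    #     output.similarity.shape  # e.g. torch.Size([5, 1, 1]) before .view(-1)
    #     output.scores.shape      # torch.Size([5])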


class MultiVectorBiEncoderModel(BiEncoderModel):
    """A bi-encoder model that encodes queries and documents separately into multiple vectors per sequence and
    computes a relevance score by aggregating the similarities between all query and document vector pairs. See
    :class:`.MultiVectorBiEncoderConfig` for configuration options."""

    config_class: Type[MultiVectorBiEncoderConfig] = MultiVectorBiEncoderConfig
    """Configuration class for the multi-vector bi-encoder model."""

    supports_retrieval_models = []

    def __init__(self, config: MultiVectorBiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a multi-vector bi-encoder model given a :class:`.MultiVectorBiEncoderConfig`.

        Args:
            config (MultiVectorBiEncoderConfig): Configuration for the multi-vector bi-encoder model.
        Raises:
            ValueError: If mask scoring tokens are specified in the configuration but the tokenizer is not available.
            ValueError: If the specified mask scoring tokens are not in the tokenizer vocab.
        """
        super().__init__(config, *args, **kwargs)

        self.query_mask_scoring_input_ids: torch.Tensor | None = None
        self.doc_mask_scoring_input_ids: torch.Tensor | None = None

        # Adds the mask scoring input ids to the model if they are specified in the configuration.
        for sequence in ("query", "doc"):
            mask_scoring_tokens = getattr(self.config, f"{sequence}_mask_scoring_tokens")
            if mask_scoring_tokens is None:
                continue
            if mask_scoring_tokens == "punctuation":
                mask_scoring_tokens = list(punctuation)
            try:
                from transformers import AutoTokenizer

                tokenizer = AutoTokenizer.from_pretrained(self.config.name_or_path)
            except OSError:
                raise ValueError("Can't use token scoring masking if the checkpoint does not have a tokenizer.")
            mask_scoring_input_ids = []
            for token in mask_scoring_tokens:
                if token not in tokenizer.vocab:
                    raise ValueError(f"Token {token} not in tokenizer vocab")
                mask_scoring_input_ids.append(tokenizer.vocab[token])
            setattr(
                self,
                f"{sequence}_mask_scoring_input_ids",
                torch.tensor(mask_scoring_input_ids, dtype=torch.long),
            )

    def _expand_mask(self, shape: torch.Size, mask: torch.Tensor, dim: int) -> torch.Tensor:
        """Helper function to expand the mask to the shape of the similarity scores."""
        if mask.ndim == len(shape):
            return mask
        if mask.ndim > len(shape):
            raise ValueError("Mask has too many dimensions")
        fill_values = len(shape) - mask.ndim + 1
        new_shape = [*mask.shape[:-1]] + [1] * fill_values
        new_shape[dim] = mask.shape[-1]
        return mask.view(*new_shape)

    def _aggregate(
        self,
        scores: torch.Tensor,
        mask: torch.Tensor,
        aggregation_function: Literal["max", "sum", "mean"],
        dim: int,
    ) -> torch.Tensor:
        """Helper function to aggregate similarity scores over query and document embeddings."""
        mask = self._expand_mask(scores.shape, mask, dim)
        if aggregation_function == "max":
            scores = scores.masked_fill(~mask, float("-inf"))
            return scores.amax(dim, keepdim=True)
        scores = scores.masked_fill(~mask, 0)
        if aggregation_function == "sum":
            return scores.sum(dim, keepdim=True)
        num_non_masked = mask.sum(dim, keepdim=True)
        if aggregation_function == "mean":
            return scores.masked_fill(num_non_masked == 0, 0).sum(dim, keepdim=True) / num_non_masked.clamp(min=1)
        raise ValueError(f"Unknown aggregation {aggregation_function}")

    def scoring_mask(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> torch.Tensor:
        """Computes a scoring mask for batched tokenized text sequences which is used in the scoring function to mask
        out vectors during scoring.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            torch.Tensor: Scoring mask.
        """
        input_ids = encoding["input_ids"]
        attention_mask = encoding["attention_mask"]
        scoring_mask = attention_mask
        if scoring_mask is None:
            scoring_mask = torch.ones_like(input_ids, dtype=torch.bool)
        scoring_mask = scoring_mask.bool()
        mask_scoring_input_ids = getattr(self, f"{input_type}_mask_scoring_input_ids")
        if mask_scoring_input_ids is not None:
            ignore_mask = input_ids[..., None].eq(mask_scoring_input_ids.to(input_ids.device)).any(-1)
            scoring_mask = scoring_mask & ~ignore_mask
        return scoring_mask
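
    # Illustrative example (tokens shown are assumptions): with
    # `doc_mask_scoring_tokens` set to "punctuation", the vectors of punctuation
    # tokens are excluded from scoring while attended tokens remain included.
    #
    #     tokens:          ["[CLS]", "hello", ",", "world", "!", "[SEP]"]
    #     attention_mask:  [1,       1,       1,   1,       1,   1      ]
    #     scoring_mask:    [True,    True,    False, True,  False, True ]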

    def aggregate_similarity(
        self,
        similarity: torch.Tensor,
        query_scoring_mask: torch.Tensor,
        doc_scoring_mask: torch.Tensor,
        num_docs: int | Sequence[int] | None = None,
    ) -> torch.Tensor:
        """Aggregates the matrix of query-document similarities into a single score based on the configured aggregation
        strategy.

        Args:
            similarity (torch.Tensor): Query-document similarity matrix.
            query_scoring_mask (torch.Tensor): Which query vectors should be masked out during scoring.
            doc_scoring_mask (torch.Tensor): Which document vectors should be masked out during scoring.
            num_docs (int | Sequence[int] | None): Specifies how many documents are passed per query. Defaults to None.
        Returns:
            torch.Tensor: Aggregated similarity scores.
        """
        scores = similarity
        if scores.shape[-1] > 1:
            scores = self._aggregate(scores, doc_scoring_mask, self.config.doc_aggregation_function, -1)
        if scores.shape[-2] > 1:
            num_docs_t = self._parse_num_docs(
                query_scoring_mask.shape[0], doc_scoring_mask.shape[0], num_docs, similarity.device
            )
            repeated_query_scoring_mask = query_scoring_mask.repeat_interleave(num_docs_t, dim=0)
            scores = self._aggregate(scores, repeated_query_scoring_mask, self.config.query_aggregation_function, -2)
        return scores.view(scores.shape[0])
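
    # Worked example (illustrative): ColBERT-style late interaction corresponds to
    # `doc_aggregation_function="max"` and `query_aggregation_function="sum"`. For a
    # similarity matrix `sim` of shape [num_pairs, query_len, doc_len]:
    #
    #     scores = sim.amax(-1, keepdim=True)     # max over doc vectors   -> [P, Q, 1]
    #     scores = scores.sum(-2, keepdim=True)   # sum over query vectors -> [P, 1, 1]
    #     scores.view(scores.shape[0])            # final shape [P]
    #
    # (masking of padded/ignored vectors is omitted here for brevity)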

    def score(self, output: BiEncoderOutput, num_docs: Sequence[int] | int | None = None) -> BiEncoderOutput:
        """Compute relevance scores between queries and documents.

        Args:
            output (BiEncoderOutput): Output containing embeddings and scoring mask.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence of
                integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the
                number of documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
                the number of documents by dividing the number of documents by the number of queries. Defaults to None.
        Returns:
            BiEncoderOutput: Output containing relevance scores.
        Raises:
            ValueError: If query or document embeddings are not provided in the output.
            ValueError: If scoring masks are not provided for the embeddings.
        """
        query_embeddings = output.query_embeddings
        doc_embeddings = output.doc_embeddings
        if query_embeddings is None or doc_embeddings is None:
            raise ValueError("Embeddings expected for scoring multi-vector embeddings")
        if query_embeddings.scoring_mask is None or doc_embeddings.scoring_mask is None:
            raise ValueError("Scoring masks expected for scoring multi-vector embeddings")

        similarity = self.compute_similarity(query_embeddings, doc_embeddings, num_docs)
        scores = self.aggregate_similarity(
            similarity, query_embeddings.scoring_mask, doc_embeddings.scoring_mask, num_docs
        )
        output.scores = scores
        output.similarity = similarity
        return output