Source code for lightning_ir.models.bi_encoders.col

  1"""Configuration, model, and tokenizer for Col (Contextualized Late Interaction) type models. Originally proposed in
  2`ColBERT: Efficient and Effective Passage Search via Contextualized Late Interaction over BERT \
  3<https://dl.acm.org/doi/abs/10.1145/3397271.3401075>`_ as the ColBERT model. This implementation generalizes the model
  4to work with any transformer backbone model.
  5"""

from collections.abc import Sequence
from typing import Literal

import torch
from transformers import BatchEncoding

from ...bi_encoder import (
    BiEncoderEmbedding,
    BiEncoderOutput,
    BiEncoderTokenizer,
    MultiVectorBiEncoderConfig,
    MultiVectorBiEncoderModel,
)

class ColConfig(MultiVectorBiEncoderConfig):
    """Configuration class for a Col model."""

    model_type = "col"
    """Model type for a Col model."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalization_strategy: Literal["l2"] | None = None,
        add_marker_tokens: bool = False,
        query_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        doc_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        query_aggregation_function: Literal["sum", "mean", "max"] = "sum",
        doc_aggregation_function: Literal["sum", "mean", "max"] = "max",
        embedding_dim: int = 128,
        projection: Literal["linear", "linear_no_bias"] = "linear",
        query_expansion: bool = False,
        attend_to_query_expanded_tokens: bool = False,
        doc_expansion: bool = False,
        attend_to_doc_expanded_tokens: bool = False,
        k_train: int | None = None,
        **kwargs,
    ):
        """A Col model encodes queries and documents separately and computes a late-interaction score between the
        query and document embeddings. The aggregation behavior of the late-interaction function can be parameterized
        with the `query_aggregation_function` and `doc_aggregation_function` arguments. The dimensionality of the
        token embeddings is down-projected using a linear layer. Queries and documents can optionally be expanded
        with mask tokens, and a set of tokens can optionally be ignored during scoring.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None, does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None, does not truncate.
                Defaults to 512.
            similarity_function (Literal["cosine", "dot"]): Similarity function to compute scores between query and
                document embeddings. Defaults to "dot".
            normalization_strategy (Literal["l2"] | None): Whether to normalize query and document embeddings.
                Defaults to None.
            add_marker_tokens (bool): Whether to add extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            query_mask_scoring_tokens (Sequence[str] | Literal["punctuation"] | None): Whether and which query tokens
                to ignore during scoring. Defaults to None.
            doc_mask_scoring_tokens (Sequence[str] | Literal["punctuation"] | None): Whether and which document tokens
                to ignore during scoring. Defaults to None.
            query_aggregation_function (Literal["sum", "mean", "max"]): How to aggregate similarity scores over query
                tokens. Defaults to "sum".
            doc_aggregation_function (Literal["sum", "mean", "max"]): How to aggregate similarity scores over document
                tokens. Defaults to "max".
            embedding_dim (int): The output embedding dimension. Defaults to 128.
            projection (Literal["linear", "linear_no_bias"]): How to project the output embeddings. If set to
                "linear_no_bias", the projection layer has no bias term. Defaults to "linear".
            query_expansion (bool): Whether to expand queries with mask tokens. Defaults to False.
            attend_to_query_expanded_tokens (bool): Whether to allow query tokens to attend to mask-expanded query
                tokens. Defaults to False.
            doc_expansion (bool): Whether to expand documents with mask tokens. Defaults to False.
            attend_to_doc_expanded_tokens (bool): Whether to allow document tokens to attend to mask-expanded
                document tokens. Defaults to False.
            k_train (int | None): Whether to use XTR_'s in-batch token retrieval during training and how many top-k
                document tokens to use. Defaults to None.

        .. _XTR: \
https://proceedings.neurips.cc/paper_files/paper/2023/file/31d997278ee9069d6721bc194174bb4c-Paper-Conference.pdf
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            normalization_strategy=normalization_strategy,
            add_marker_tokens=add_marker_tokens,
            query_mask_scoring_tokens=query_mask_scoring_tokens,
            doc_mask_scoring_tokens=doc_mask_scoring_tokens,
            query_aggregation_function=query_aggregation_function,
            doc_aggregation_function=doc_aggregation_function,
            **kwargs,
        )
        self.embedding_dim = embedding_dim
        self.projection = projection
        self.query_expansion = query_expansion
        self.attend_to_query_expanded_tokens = attend_to_query_expanded_tokens
        self.doc_expansion = doc_expansion
        self.attend_to_doc_expanded_tokens = attend_to_doc_expanded_tokens
        self.k_train = k_train
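
# Illustrative only, not part of the module: a minimal sketch of assembling a
# ColBERT-style configuration with the options documented above. All keyword
# arguments shown exist in `ColConfig.__init__`; the chosen values are examples.
#
#     config = ColConfig(
#         query_length=32,
#         doc_length=512,
#         similarity_function="cosine",
#         normalization_strategy="l2",
#         query_expansion=True,
#         embedding_dim=128,
#     )
#     assert config.model_type == "col"
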

class ColModel(MultiVectorBiEncoderModel):
    """Multi-vector late-interaction Col model. See :class:`.ColConfig` for configuration options."""

    config_class = ColConfig
    """Configuration class for the Col model."""

    def __init__(self, config: ColConfig, *args, **kwargs) -> None:
        """Initializes a Col model given a :class:`.ColConfig`.

        Args:
            config (ColConfig): Configuration for the Col model.
        Raises:
            ValueError: If the embedding dimension is not specified in the configuration.
        """
        super().__init__(config, *args, **kwargs)
        if config.embedding_dim is None:
            raise ValueError("Embedding dimension must be specified in the configuration.")
        self.projection = torch.nn.Linear(
            config.hidden_size, config.embedding_dim, bias="no_bias" not in config.projection
        )
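
    # Shape sketch for the projection above (illustrative, assuming a hypothetical
    # backbone with hidden_size=768 and the default embedding_dim=128):
    #
    #     projection = torch.nn.Linear(768, 128)   # bias=True when config.projection == "linear"
    #     hidden_states = torch.randn(2, 32, 768)  # (batch, tokens, hidden_size)
    #     embeddings = projection(hidden_states)   # -> (2, 32, 128)
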
    def score(
        self,
        output: BiEncoderOutput,
        num_docs: Sequence[int] | int | None = None,
    ) -> BiEncoderOutput:
        """Compute relevance scores between queries and documents.

        Args:
            output (BiEncoderOutput): Output containing embeddings and scoring mask.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence
                of integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to
                the number of documents, i.e., the sequence contains one value per query specifying the number of
                documents for that query. If an integer, assumes an equal number of documents per query. If None,
                tries to infer the number of documents per query by dividing the total number of documents by the
                number of queries. Defaults to None.
        Returns:
            BiEncoderOutput: Output containing relevance scores.
        """
        if self.training and self.config.k_train is not None:
            return self._score_xtr_in_batch(output, num_docs)

        return super().score(output, num_docs)
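
    # Illustrative `num_docs` contract for `score` (toy values, not part of the module):
    #
    #     num_docs = [2, 3]  # 2 queries; the first scored against 2 documents, the second against 3 (5 total)
    #     num_docs = 2       # alternatively, every query is paired with exactly 2 documents
    #     num_docs = None    # inferred; the total number of documents must divide evenly by the number of queries
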
    def scoring_mask(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> torch.Tensor:
        """Computes a scoring mask for batched tokenized text sequences which is used in the scoring function to mask
        out vectors during scoring.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            torch.Tensor: Scoring mask.
        """
        input_ids = encoding["input_ids"]
        attention_mask = encoding["attention_mask"]
        scoring_mask = attention_mask
        expansion = getattr(self.config, f"{input_type}_expansion")
        if expansion or scoring_mask is None:
            scoring_mask = torch.ones_like(input_ids, dtype=torch.bool)
        scoring_mask = scoring_mask.bool()
        mask_scoring_input_ids = getattr(self, f"{input_type}_mask_scoring_input_ids")
        if mask_scoring_input_ids is not None:
            ignore_mask = input_ids[..., None].eq(mask_scoring_input_ids.to(input_ids.device)).any(-1)
            scoring_mask = scoring_mask & ~ignore_mask
        return scoring_mask
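
    # Plain-torch sketch of the ignore-mask construction used in `scoring_mask` above
    # (toy token ids; with a bert-base-uncased vocabulary, 1029 happens to be "?"):
    #
    #     input_ids = torch.tensor([[101, 2054, 1029, 102]])
    #     mask_scoring_input_ids = torch.tensor([1029])
    #     ignore_mask = input_ids[..., None].eq(mask_scoring_input_ids).any(-1)
    #     # ignore_mask == tensor([[False, False, True, False]])
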
    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes batched tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask.
        """
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self.projection(embeddings)
        if self.config.normalization_strategy == "l2":
            embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
        scoring_mask = self.scoring_mask(encoding, input_type)
        return BiEncoderEmbedding(embeddings, scoring_mask, encoding)
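
    # Plain-torch sketch of the default late-interaction ("MaxSim") scoring that `score`
    # inherits from the parent class, for a single query-document pair with dot-product
    # similarity and the default aggregation (max over document tokens, sum over query tokens):
    #
    #     query_embeddings = torch.randn(32, 128)   # (query_tokens, embedding_dim)
    #     doc_embeddings = torch.randn(180, 128)    # (doc_tokens, embedding_dim)
    #     similarity = query_embeddings @ doc_embeddings.T   # (32, 180) token-level similarities
    #     score = similarity.max(dim=-1).values.sum()        # max over doc tokens, sum over query tokens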

    def _score_xtr_in_batch(
        self, output: BiEncoderOutput, num_docs: Sequence[int] | int | None = None
    ) -> BiEncoderOutput:
        """XTR in-batch token retrieval scoring.

        Args:
            output (BiEncoderOutput): Output containing embeddings and scoring mask.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence
                of integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to
                the number of documents, i.e., the sequence contains one value per query specifying the number of
                documents for that query. If an integer, assumes an equal number of documents per query. If None,
                tries to infer the number of documents per query by dividing the total number of documents by the
                number of queries. Defaults to None.
        Returns:
            BiEncoderOutput: Output containing relevance scores.
        """
        query_embeddings = output.query_embeddings
        doc_embeddings = output.doc_embeddings
        if query_embeddings is None or doc_embeddings is None:
            raise ValueError("Both query and document embeddings must be provided for scoring.")
        similarities = self.compute_similarity(query_embeddings, doc_embeddings, num_docs)

        query_mask = query_embeddings.scoring_mask
        doc_mask = doc_embeddings.scoring_mask

        num_docs_t = self._parse_num_docs(
            query_embeddings.embeddings.shape[0],
            doc_embeddings.embeddings.shape[0],
            num_docs,
            query_embeddings.device,
        )

        query_mask_expanded = query_mask.repeat_interleave(num_docs_t, dim=0).unsqueeze(-1)
        doc_mask_expanded = doc_mask.unsqueeze(1)

        similarities = similarities.masked_fill(~doc_mask_expanded, float("-inf"))
        similarities = similarities.masked_fill(~query_mask_expanded, float("-inf"))

        batch_size = query_embeddings.embeddings.shape[0]
        q_len = query_embeddings.embeddings.shape[1]
        doc_len = doc_embeddings.embeddings.shape[1]
        max_docs = torch.max(num_docs_t)

        # Group the per-document similarity matrices by query and pad to the maximum number of documents per query.
        sim_list = similarities.split(num_docs_t.tolist(), dim=0)
        sim_padded = torch.nn.utils.rnn.pad_sequence(sim_list, batch_first=True, padding_value=float("-inf"))

        valid_mask = torch.arange(max_docs, device=num_docs_t.device).unsqueeze(0) < num_docs_t.unsqueeze(1)

        # Determine the k-th largest similarity per query over all in-batch document tokens.
        sim_flat = sim_padded.view(batch_size, -1)
        k_train = min(self.config.k_train, sim_flat.size(-1))
        minimum_values = torch.topk(sim_flat, k=k_train, dim=-1).values[:, -1].unsqueeze(-1)

        # Zero out similarities that fall below the top-k threshold.
        sim_padded = sim_padded.view(batch_size, -1)
        sim_padded = sim_padded.masked_fill(sim_padded < minimum_values, 0.0)
        sim_padded = sim_padded.view(batch_size, max_docs, q_len, doc_len)

        # Max over document tokens, summed over query tokens, normalized by the number of
        # query tokens that retrieved at least one surviving document token.
        scores = sim_padded.max(dim=-1).values.sum(dim=-1)
        Z = (sim_padded.max(dim=-1).values > 0).sum(dim=-1).float()
        Z = Z.clamp(min=1.0)
        scores = scores / Z

        scores = scores[valid_mask]

        output.scores = scores
        output.similarity = sim_padded[valid_mask]

        return output
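
    # Condensed sketch of the XTR top-k truncation performed above, for one query over all
    # in-batch document tokens (toy shapes; the real method additionally handles padding,
    # masking, and per-query document counts):
    #
    #     similarities = torch.randn(1, 4 * 32 * 180)   # all token similarities, flattened
    #     k = 128
    #     threshold = torch.topk(similarities, k=k, dim=-1).values[:, -1:]
    #     similarities = similarities.masked_fill(similarities < threshold, 0.0)
    #     # per query token: max over the surviving document tokens, summed over query tokens
    #     # and normalized by how many query tokens retrieved at least one document token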

class ColTokenizer(BiEncoderTokenizer):
    """:class:`.LightningIRTokenizer` for Col models."""

    config_class = ColConfig
    """Configuration class for the tokenizer."""

    def __init__(
        self,
        *args,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        add_marker_tokens: bool = False,
        query_expansion: bool = False,
        attend_to_query_expanded_tokens: bool = False,
        doc_expansion: bool = False,
        attend_to_doc_expanded_tokens: bool = False,
        **kwargs,
    ):
        """Initializes a Col model's tokenizer. Encodes queries and documents separately. Optionally adds marker
        tokens to encoded input sequences and expands queries and documents with mask tokens.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None, does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None, does not truncate.
                Defaults to 512.
            add_marker_tokens (bool): Whether to add extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            query_expansion (bool): Whether to expand queries with mask tokens. Defaults to False.
            attend_to_query_expanded_tokens (bool): Whether to allow query tokens to attend to mask-expanded query
                tokens. Defaults to False.
            doc_expansion (bool): Whether to expand documents with mask tokens. Defaults to False.
            attend_to_doc_expanded_tokens (bool): Whether to allow document tokens to attend to mask-expanded
                document tokens. Defaults to False.
        Raises:
            ValueError: If `add_marker_tokens` is True and a non-supported tokenizer is used.
        """
        super().__init__(
            *args,
            query_length=query_length,
            doc_length=doc_length,
            query_expansion=query_expansion,
            attend_to_query_expanded_tokens=attend_to_query_expanded_tokens,
            doc_expansion=doc_expansion,
            attend_to_doc_expanded_tokens=attend_to_doc_expanded_tokens,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.query_expansion = query_expansion
        self.attend_to_query_expanded_tokens = attend_to_query_expanded_tokens
        self.doc_expansion = doc_expansion
        self.attend_to_doc_expanded_tokens = attend_to_doc_expanded_tokens

    def _expand(self, encoding: BatchEncoding, attend_to_expanded_tokens: bool) -> BatchEncoding:
        """Applies mask expansion to the input encoding."""
        input_ids = encoding["input_ids"]
        # Replace padding tokens with mask tokens; the mask tokens serve as expansion tokens.
        input_ids[input_ids == self.pad_token_id] = self.mask_token_id
        encoding["input_ids"] = input_ids
        if attend_to_expanded_tokens:
            encoding["attention_mask"].fill_(1)
        return encoding
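
    # Sketch of the replacement performed by `_expand` (toy ids; in BERT-style vocabularies
    # pad_token_id == 0 and mask_token_id == 103):
    #
    #     input_ids = torch.tensor([[101, 2054, 102, 0, 0]])
    #     input_ids[input_ids == 0] = 103
    #     # input_ids == tensor([[101, 2054, 102, 103, 103]])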

    def tokenize_input_sequence(
        self, text: Sequence[str] | str, input_type: Literal["query", "doc"], *args, **kwargs
    ) -> BatchEncoding:
        """Tokenizes an input sequence. This method is used to tokenize both queries and documents.

        Args:
            text (Sequence[str] | str): Input text to tokenize.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            BatchEncoding: Tokenized input sequences.
        """
        expansion = getattr(self, f"{input_type}_expansion")
        if expansion:
            kwargs["padding"] = "max_length"
        encoding = super().tokenize_input_sequence(text, input_type, *args, **kwargs)
        if expansion:
            encoding = self._expand(encoding, getattr(self, f"attend_to_{input_type}_expanded_tokens"))
        return encoding
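
# End-to-end tokenizer usage sketch (illustrative; "bert-base-uncased" is a stand-in
# backbone, and this assumes the usual `from_pretrained` entry point for Lightning IR
# tokenizers):
#
#     tokenizer = ColTokenizer.from_pretrained("bert-base-uncased", query_expansion=True, query_length=32)
#     encoding = tokenizer.tokenize_input_sequence("what is late interaction?", "query")
#     # queries are padded to `query_length` and padding positions are replaced with mask tokens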