Source code for lightning_ir.models.bi_encoders.col

  1"""Configuration, model, and tokenizer for Col (Contextualized Late Interaction) type models.
  2
  3Col models implement a multi-vector late-interaction retrieval approach where queries and documents
  4are encoded separately into multiple token-level embeddings. Relevance is computed through element-wise
  5similarity matching between query and document token embeddings, aggregated to produce a final relevance score.
  6This approach enables fine-grained matching while maintaining computational efficiency compared to dense
  7cross-encoders.
  8
  9Originally proposed in
 10`ColBERT: Efficient and Effective Passage Search via Contextualized Late Interaction over BERT \
 11<https://dl.acm.org/doi/abs/10.1145/3397271.3401075>`_ as the ColBERT model. This implementation generalizes the model
 12to work with any transformer backbone model.
 13
  14Additionally supports XTR (conteXtualized Token Retrieval), an information retrieval model that dramatically speeds up
 15multi-vector architectures like ColBERT by rethinking how documents are scored. In traditional multi-vector models, the
 16system finds candidate documents using a few matching tokens but then must computationally gather every single word
 17vector from those candidates to calculate a final score. XTR simplifies this by training the model to prioritize
 18retrieving the most important document tokens right away, allowing it to calculate the final relevance score using only
 19those initially retrieved tokens. This entirely eliminates the expensive gathering step, making the search process up
 20to a thousand times cheaper and faster while still achieving state-of-the-art accuracy.
 21
 22Usage with XTR:
 23    To enable XTR in-batch token retrieval during training, set the `k_train` parameter in the configuration
 24    to the desired number of top-k document tokens to retrieve (e.g., 128):
 25
 26    >>> from lightning_ir.models.bi_encoders import ColConfig, ColModel
 27    >>> config = ColConfig(k_train=128)
 28    >>> model = ColModel(config) # doctest: +SKIP
 29
 30    During training, the model will automatically use XTR scoring when `k_train` is set and the model is in
 31    training mode. This enables efficient computation by retrieving only the most relevant document tokens
 32    rather than processing all tokens.
 33
 34Originally proposed in
 35`Rethinking the Role of Token Retrieval in Multi-Vector Retrieval \
 36<https://arxiv.org/abs/2304.01982>`_.
 37"""
 38
 39from collections.abc import Sequence
 40from typing import Literal
 41
 42import torch
 43from transformers import BatchEncoding
 44
 45from ...bi_encoder import (
 46    BiEncoderEmbedding,
 47    BiEncoderOutput,
 48    BiEncoderTokenizer,
 49    MultiVectorBiEncoderConfig,
 50    MultiVectorBiEncoderModel,
 51)
 52
 53
class ColConfig(MultiVectorBiEncoderConfig):
    """Configuration class for a Col model."""

    model_type = "col"
    """Model type for a Col model."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalization_strategy: Literal["l2"] | None = None,
        add_marker_tokens: bool = False,
        query_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        doc_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        query_aggregation_function: Literal["sum", "mean", "max"] = "sum",
        doc_aggregation_function: Literal["sum", "mean", "max"] = "max",
        embedding_dim: int = 128,
        projection: Literal["linear", "linear_no_bias"] = "linear",
        query_expansion: bool = False,
        attend_to_query_expanded_tokens: bool = False,
        doc_expansion: bool = False,
        attend_to_doc_expanded_tokens: bool = False,
        k_train: int | None = None,
        **kwargs,
    ):
        """A Col model encodes queries and documents separately and computes a late interaction score between the query
        and document embeddings. The aggregation behavior of the late-interaction function can be parameterized with
        the `aggregation_function` arguments. The dimensionality of the token embeddings is down-projected using a
        linear layer. Queries and documents can optionally be expanded with mask tokens. Optionally, a set of tokens can
        be ignored during scoring.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
            similarity_function (Literal["cosine", "dot"]): Similarity function to compute scores between query and
                document embeddings. Defaults to "dot".
            normalization_strategy (Literal['l2'] | None): Whether to normalize query and document embeddings.
                Defaults to None.
            add_marker_tokens (bool): Whether to add extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            query_mask_scoring_tokens (Sequence[str] | Literal["punctuation"] | None): Whether and which query tokens
                to ignore during scoring. Defaults to None.
            doc_mask_scoring_tokens (Sequence[str] | Literal["punctuation"] | None): Whether and which document tokens
                to ignore during scoring. Defaults to None.
            query_aggregation_function (Literal["sum", "mean", "max"]): How to aggregate
                similarity scores over query tokens. Defaults to "sum".
            doc_aggregation_function (Literal["sum", "mean", "max"]): How to aggregate
                similarity scores over document tokens. Defaults to "max".
            embedding_dim (int): The output embedding dimension. Defaults to 128.
            projection (Literal["linear", "linear_no_bias"]): Whether and how to project the output embeddings.
                Defaults to "linear". If set to "linear_no_bias", the projection layer will not have a bias term.
            query_expansion (bool): Whether to expand queries with mask tokens. Defaults to False.
            attend_to_query_expanded_tokens (bool): Whether to allow query tokens to attend to mask expanded query
                tokens. Defaults to False.
            doc_expansion (bool): Whether to expand documents with mask tokens. Defaults to False.
            attend_to_doc_expanded_tokens (bool): Whether to allow document tokens to attend to mask expanded document
                tokens. Defaults to False.
            k_train (int | None): Whether to use XTR_'s in-batch token retrieval during training and how many top-k
                document tokens to use. If None, XTR in-batch token retrieval is disabled. Defaults to None.

        .. _XTR: \
https://proceedings.neurips.cc/paper_files/paper/2023/file/31d997278ee9069d6721bc194174bb4c-Paper-Conference.pdf
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            normalization_strategy=normalization_strategy,
            add_marker_tokens=add_marker_tokens,
            query_mask_scoring_tokens=query_mask_scoring_tokens,
            doc_mask_scoring_tokens=doc_mask_scoring_tokens,
            query_aggregation_function=query_aggregation_function,
            doc_aggregation_function=doc_aggregation_function,
            **kwargs,
        )
        self.embedding_dim = embedding_dim
        self.projection = projection
        self.query_expansion = query_expansion
        self.attend_to_query_expanded_tokens = attend_to_query_expanded_tokens
        self.doc_expansion = doc_expansion
        self.attend_to_doc_expanded_tokens = attend_to_doc_expanded_tokens
        self.k_train = k_train
class ColModel(MultiVectorBiEncoderModel):
    """Multi-vector late-interaction Col model. See :class:`.ColConfig` for configuration options."""

    config_class = ColConfig
    """Configuration class for the Col model."""

    def __init__(self, config: ColConfig, *args, **kwargs) -> None:
        """Initializes a Col model given a :class:`.ColConfig`.

        Args:
            config (ColConfig): Configuration for the Col model.
        Raises:
            ValueError: If the embedding dimension is not specified in the configuration.
        """
        super().__init__(config, *args, **kwargs)
        if config.embedding_dim is None:
            raise ValueError("Embedding dimension must be specified in the configuration.")
        # Down-project the backbone's hidden states to the configured token embedding dimension.
        # The bias term is dropped when the projection is configured as "linear_no_bias".
        self.projection = torch.nn.Linear(
            config.hidden_size, config.embedding_dim, bias="no_bias" not in config.projection
        )

    def score(
        self,
        output: BiEncoderOutput,
        num_docs: Sequence[int] | int | None = None,
    ) -> BiEncoderOutput:
        """Compute relevance scores between queries and documents.

        Args:
            output (BiEncoderOutput): Output containing embeddings and scoring mask.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence of
                integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the
                number of documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
                the number of documents by dividing the number of documents by the number of queries. Defaults to None.
        Returns:
            BiEncoderOutput: Output containing relevance scores.
        """
        # XTR in-batch token retrieval is only used while training and only when k_train is configured;
        # otherwise fall back to the standard late-interaction scoring of the superclass.
        if self.training and self.config.k_train is not None:
            return self._score_xtr_in_batch(output, num_docs)

        return super().score(output, num_docs)

    def scoring_mask(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> torch.Tensor:
        """Computes a scoring mask for batched tokenized text sequences which is used in the scoring function to mask
        out vectors during scoring.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): type of input, either "query" or "doc".
        Returns:
            torch.Tensor: Scoring mask.
        """
        input_ids = encoding["input_ids"]
        attention_mask = encoding["attention_mask"]
        # Start from the attention mask; all attended tokens are scored by default.
        scoring_mask = attention_mask
        expansion = getattr(self.config, f"{input_type}_expansion")
        # With mask expansion enabled (or no attention mask available), every position participates in
        # scoring, including the mask-expanded tokens.
        if expansion or scoring_mask is None:
            scoring_mask = torch.ones_like(input_ids, dtype=torch.bool)
        scoring_mask = scoring_mask.bool()
        # Optionally exclude a configured set of token ids (e.g. punctuation) from scoring.
        mask_scoring_input_ids = getattr(self, f"{input_type}_mask_scoring_input_ids")
        if mask_scoring_input_ids is not None:
            ignore_mask = input_ids[..., None].eq(mask_scoring_input_ids.to(input_ids.device)).any(-1)
            scoring_mask = scoring_mask & ~ignore_mask
        return scoring_mask

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes a batched tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): type of input, either "query" or "doc".
        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask.
        """
        # Backbone forward pass, then down-projection to the configured embedding dimension.
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self.projection(embeddings)
        # Optional L2 normalization so that dot-product similarity equals cosine similarity.
        if self.config.normalization_strategy == "l2":
            embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
        scoring_mask = self.scoring_mask(encoding, input_type)
        return BiEncoderEmbedding(embeddings, scoring_mask, encoding)

    def _score_xtr_in_batch(
        self, output: BiEncoderOutput, num_docs: Sequence[int] | int | None = None
    ) -> BiEncoderOutput:
        """XTR in-batch token retrieval scoring.

        Keeps only the top `k_train` token-level similarities per query across all of its in-batch
        document tokens and scores documents using those retrieved tokens alone.

        Args:
            output (BiEncoderOutput): Output containing embeddings and scoring mask.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence of
                integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the
                number of documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
                the number of documents by dividing the number of documents by the number of queries. Defaults to None.
        Returns:
            BiEncoderOutput: Output containing relevance scores.
        """
        query_embeddings = output.query_embeddings
        doc_embeddings = output.doc_embeddings
        if query_embeddings is None or doc_embeddings is None:
            raise ValueError("Both query and document embeddings must be provided for scoring.")
        # Token-level similarities; based on the .view() below this is laid out as
        # (total_num_docs, query_length, doc_length) — one slice per query-document pair.
        similarities = self.compute_similarity(query_embeddings, doc_embeddings, num_docs)

        query_mask = query_embeddings.scoring_mask
        doc_mask = doc_embeddings.scoring_mask

        # Normalize num_docs into a per-query tensor of document counts.
        num_docs_t = self._parse_num_docs(
            query_embeddings.embeddings.shape[0],
            doc_embeddings.embeddings.shape[0],
            num_docs,
            query_embeddings.device,
        )

        # Broadcast masks onto the (doc, query_token, doc_token) similarity layout:
        # each query's mask is repeated once per document of that query.
        query_mask_expanded = query_mask.repeat_interleave(num_docs_t, dim=0).unsqueeze(-1)
        doc_mask_expanded = doc_mask.unsqueeze(1)

        # Masked positions get -inf so they can never enter the per-query top-k selection.
        similarities = similarities.masked_fill(~doc_mask_expanded, float("-inf"))
        similarities = similarities.masked_fill(~query_mask_expanded, float("-inf"))

        batch_size = query_embeddings.embeddings.shape[0]
        q_len = query_embeddings.embeddings.shape[1]
        doc_len = doc_embeddings.embeddings.shape[1]
        max_docs = torch.max(num_docs_t)

        # Regroup the flat document dimension per query and pad to the largest per-query document
        # count with -inf so padded documents never win the top-k.
        sim_list = similarities.split(num_docs_t.tolist(), dim=0)
        sim_padded = torch.nn.utils.rnn.pad_sequence(sim_list, batch_first=True, padding_value=float("-inf"))

        # valid_mask marks which (query, doc-slot) pairs correspond to real (non-padded) documents.
        valid_mask = torch.arange(max_docs, device=num_docs_t.device).unsqueeze(0) < num_docs_t.unsqueeze(1)

        # Per-query retrieval threshold: the k_train-th largest similarity over all of the query's
        # in-batch document tokens (k capped by the number of available similarities).
        sim_flat = sim_padded.view(batch_size, -1)
        k_train = min(self.config.k_train, sim_flat.size(-1))
        minimum_values = torch.topk(sim_flat, k=k_train, dim=-1).values[:, -1].unsqueeze(-1)

        # Zero out every similarity below the threshold — only "retrieved" tokens contribute.
        sim_padded = sim_padded.view(batch_size, -1)
        sim_padded = sim_padded.masked_fill(sim_padded < minimum_values, 0.0)
        sim_padded = sim_padded.view(batch_size, max_docs, q_len, doc_len)

        # Late interaction over retrieved tokens: max over document tokens, sum over query tokens.
        scores = sim_padded.max(dim=-1).values.sum(dim=-1)
        # Z counts query tokens that retrieved at least one document token; clamp avoids division by
        # zero when a document retrieved nothing.
        Z = (sim_padded.max(dim=-1).values > 0).sum(dim=-1).float()
        Z = Z.clamp(min=1.0)
        scores = scores / Z

        # Drop padded document slots, yielding one score per real document.
        scores = scores[valid_mask]

        output.scores = scores
        output.similarity = sim_padded[valid_mask]

        return output
class ColTokenizer(BiEncoderTokenizer):
    """:class:`.LightningIRTokenizer` for Col models."""

    config_class = ColConfig
    """Configuration class for the tokenizer."""

    def __init__(
        self,
        *args,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        add_marker_tokens: bool = False,
        query_expansion: bool = False,
        attend_to_query_expanded_tokens: bool = False,
        doc_expansion: bool = False,
        attend_to_doc_expanded_tokens: bool = False,
        **kwargs,
    ):
        """Initializes a Col model's tokenizer. Encodes queries and documents separately. Optionally adds marker tokens
        to encoded input sequences and expands queries and documents with mask tokens.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
            add_marker_tokens (bool): Whether to add extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            query_expansion (bool): Whether to expand queries with mask tokens. Defaults to False.
            attend_to_query_expanded_tokens (bool): Whether to allow query tokens to attend to mask expanded query
                tokens. Defaults to False.
            doc_expansion (bool): Whether to expand documents with mask tokens. Defaults to False.
            attend_to_doc_expanded_tokens (bool): Whether to allow document tokens to attend to mask expanded document
                tokens. Defaults to False.
        Raises:
            ValueError: If `add_marker_tokens` is True and a non-supported tokenizer is used.
        """
        super().__init__(
            *args,
            query_length=query_length,
            doc_length=doc_length,
            add_marker_tokens=add_marker_tokens,
            query_expansion=query_expansion,
            attend_to_query_expanded_tokens=attend_to_query_expanded_tokens,
            doc_expansion=doc_expansion,
            attend_to_doc_expanded_tokens=attend_to_doc_expanded_tokens,
            **kwargs,
        )
        # Keep the expansion settings on the tokenizer so tokenize_input_sequence can look them up
        # by input type via getattr.
        self.query_expansion = query_expansion
        self.doc_expansion = doc_expansion
        self.attend_to_query_expanded_tokens = attend_to_query_expanded_tokens
        self.attend_to_doc_expanded_tokens = attend_to_doc_expanded_tokens

    def _expand(self, encoding: BatchEncoding, attend_to_expanded_tokens: bool) -> BatchEncoding:
        """Applies mask expansion to the input encoding by turning padding tokens into mask tokens."""
        token_ids = encoding["input_ids"]
        # In-place replacement: every padding position becomes a mask token.
        token_ids[token_ids == self.pad_token_id] = self.mask_token_id
        encoding["input_ids"] = token_ids
        if attend_to_expanded_tokens:
            # Let all tokens attend to the expanded (former padding) positions as well.
            encoding["attention_mask"].fill_(1)
        return encoding

    def tokenize_input_sequence(
        self, text: Sequence[str] | str, input_type: Literal["query", "doc"], *args, **kwargs
    ) -> BatchEncoding:
        """Tokenizes an input sequence. This method is used to tokenize both queries and documents.

        Args:
            text (Sequence[str] | str): Input text to tokenize.
            input_type (Literal["query", "doc"]): type of input, either "query" or "doc".
        Returns:
            BatchEncoding: Tokenized input sequences.
        """
        do_expand = getattr(self, f"{input_type}_expansion")
        if not do_expand:
            return super().tokenize_input_sequence(text, input_type, *args, **kwargs)
        # Expansion requires padding up to the maximum length so there are padding positions to
        # convert into mask tokens.
        kwargs["padding"] = "max_length"
        encoding = super().tokenize_input_sequence(text, input_type, *args, **kwargs)
        attend = getattr(self, f"attend_to_{input_type}_expanded_tokens")
        return self._expand(encoding, attend)