1"""Configuration, model, and tokenizer for Col (Contextualized Late Interaction) type models.
2
3Col models implement a multi-vector late-interaction retrieval approach where queries and documents
4are encoded separately into multiple token-level embeddings. Relevance is computed through element-wise
5similarity matching between query and document token embeddings, aggregated to produce a final relevance score.
6This approach enables fine-grained matching while maintaining computational efficiency compared to dense
7cross-encoders.
8
9Originally proposed in
10`ColBERT: Efficient and Effective Passage Search via Contextualized Late Interaction over BERT \
11<https://dl.acm.org/doi/abs/10.1145/3397271.3401075>`_ as the ColBERT model. This implementation generalizes the model
12to work with any transformer backbone model.
13
14Additionally supports XTR (conteXtualized Token Retrieval) an information retrieval model that dramatically speeds up
15multi-vector architectures like ColBERT by rethinking how documents are scored. In traditional multi-vector models, the
16system finds candidate documents using a few matching tokens but then must computationally gather every single word
17vector from those candidates to calculate a final score. XTR simplifies this by training the model to prioritize
18retrieving the most important document tokens right away, allowing it to calculate the final relevance score using only
19those initially retrieved tokens. This entirely eliminates the expensive gathering step, making the search process up
20to a thousand times cheaper and faster while still achieving state-of-the-art accuracy.
21
22Usage with XTR:
23 To enable XTR in-batch token retrieval during training, set the `k_train` parameter in the configuration
24 to the desired number of top-k document tokens to retrieve (e.g., 128):
25
26 >>> from lightning_ir.models.bi_encoders import ColConfig, ColModel
27 >>> config = ColConfig(k_train=128)
28 >>> model = ColModel(config) # doctest: +SKIP
29
30 During training, the model will automatically use XTR scoring when `k_train` is set and the model is in
31 training mode. This enables efficient computation by retrieving only the most relevant document tokens
32 rather than processing all tokens.
33
34Originally proposed in
35`Rethinking the Role of Token Retrieval in Multi-Vector Retrieval \
36<https://arxiv.org/abs/2304.01982>`_.
37"""
38
39from collections.abc import Sequence
40from typing import Literal
41
42import torch
43from transformers import BatchEncoding
44
45from ...bi_encoder import (
46 BiEncoderEmbedding,
47 BiEncoderOutput,
48 BiEncoderTokenizer,
49 MultiVectorBiEncoderConfig,
50 MultiVectorBiEncoderModel,
51)
52
53
class ColConfig(MultiVectorBiEncoderConfig):
    """Configuration class for a Col model."""

    model_type = "col"
    """Model type for a Col model."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalization_strategy: Literal["l2"] | None = None,
        add_marker_tokens: bool = False,
        query_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        doc_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        query_aggregation_function: Literal["sum", "mean", "max"] = "sum",
        doc_aggregation_function: Literal["sum", "mean", "max"] = "max",
        embedding_dim: int = 128,
        projection: Literal["linear", "linear_no_bias"] = "linear",
        query_expansion: bool = False,
        attend_to_query_expanded_tokens: bool = False,
        doc_expansion: bool = False,
        attend_to_doc_expanded_tokens: bool = False,
        k_train: int | None = None,
        **kwargs,
    ):
        """A Col model encodes queries and documents separately and computes a late interaction score between the query
        and document embeddings. The aggregation behavior of the late-interaction function can be parameterized with
        the `aggregation_function` arguments. The dimensionality of the token embeddings is down-projected using a
        linear layer. Queries and documents can optionally be expanded with mask tokens. Optionally, a set of tokens can
        be ignored during scoring.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
            similarity_function (Literal["cosine", "dot"]): Similarity function to compute scores between query and
                document embeddings. Defaults to "dot".
            normalization_strategy (Literal['l2'] | None): Whether to normalize query and document embeddings.
                Defaults to None.
            add_marker_tokens (bool): Whether to add extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            query_mask_scoring_tokens (Sequence[str] | Literal["punctuation"] | None): Whether and which query tokens
                to ignore during scoring. Defaults to None.
            doc_mask_scoring_tokens (Sequence[str] | Literal["punctuation"] | None): Whether and which document tokens
                to ignore during scoring. Defaults to None.
            query_aggregation_function (Literal["sum", "mean", "max"]): How to aggregate
                similarity scores over query tokens. Defaults to "sum".
            doc_aggregation_function (Literal["sum", "mean", "max"]): How to aggregate
                similarity scores over document tokens. Defaults to "max".
            embedding_dim (int): The output embedding dimension. Defaults to 128.
            projection (Literal["linear", "linear_no_bias"]): Whether and how to project the output embeddings.
                Defaults to "linear". If set to "linear_no_bias", the projection layer will not have a bias term.
            query_expansion (bool): Whether to expand queries with mask tokens. Defaults to False.
            attend_to_query_expanded_tokens (bool): Whether to allow query tokens to attend to mask expanded query
                tokens. Defaults to False.
            doc_expansion (bool): Whether to expand documents with mask tokens. Defaults to False.
            attend_to_doc_expanded_tokens (bool): Whether to allow document tokens to attend to mask expanded document
                tokens. Defaults to False.
            k_train (int | None): Whether to use XTR_'s in-batch token retrieval during training and how many top-k
                document tokens to use. If None, XTR scoring is disabled. Defaults to None.

        .. _XTR: \
https://proceedings.neurips.cc/paper_files/paper/2023/file/31d997278ee9069d6721bc194174bb4c-Paper-Conference.pdf
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            normalization_strategy=normalization_strategy,
            add_marker_tokens=add_marker_tokens,
            query_mask_scoring_tokens=query_mask_scoring_tokens,
            doc_mask_scoring_tokens=doc_mask_scoring_tokens,
            query_aggregation_function=query_aggregation_function,
            doc_aggregation_function=doc_aggregation_function,
            **kwargs,
        )
        # Output dimensionality of the linear down-projection applied to token embeddings.
        self.embedding_dim = embedding_dim
        # "linear" or "linear_no_bias"; controls whether the projection layer has a bias term.
        self.projection = projection
        # Mask-token expansion settings for queries and documents.
        self.query_expansion = query_expansion
        self.attend_to_query_expanded_tokens = attend_to_query_expanded_tokens
        self.doc_expansion = doc_expansion
        self.attend_to_doc_expanded_tokens = attend_to_doc_expanded_tokens
        # Number of top-k document tokens for XTR in-batch token retrieval during training;
        # None disables XTR scoring (see ColModel.score).
        self.k_train = k_train
class ColModel(MultiVectorBiEncoderModel):
    """Multi-vector late-interaction Col model. See :class:`.ColConfig` for configuration options."""

    config_class = ColConfig
    """Configuration class for the Col model."""

    def __init__(self, config: ColConfig, *args, **kwargs) -> None:
        """Initializes a Col model given a :class:`.ColConfig`.

        Args:
            config (ColConfig): Configuration for the Col model.
        Raises:
            ValueError: If the embedding dimension is not specified in the configuration.
        """
        super().__init__(config, *args, **kwargs)
        if config.embedding_dim is None:
            raise ValueError("Embedding dimension must be specified in the configuration.")
        # Down-project the backbone's hidden states to the token embedding dimension.
        # A config.projection of "linear_no_bias" disables the bias term.
        self.projection = torch.nn.Linear(
            config.hidden_size, config.embedding_dim, bias="no_bias" not in config.projection
        )

    def score(
        self,
        output: BiEncoderOutput,
        num_docs: Sequence[int] | int | None = None,
    ) -> BiEncoderOutput:
        """Compute relevance scores between queries and documents.

        Args:
            output (BiEncoderOutput): Output containing embeddings and scoring mask.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence of
                integers, `len(num_doc)` should be equal to the number of queries and `sum(num_docs)` equal to the
                number of documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
                the number of documents by dividing the number of documents by the number of queries. Defaults to None.
        Returns:
            BiEncoderOutput: Output containing relevance scores.
        """
        # XTR in-batch token retrieval is a training-time scoring variant; outside training
        # (or when k_train is unset) fall back to standard late-interaction scoring.
        if self.training and self.config.k_train is not None:
            return self._score_xtr_in_batch(output, num_docs)

        return super().score(output, num_docs)

    def scoring_mask(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> torch.Tensor:
        """Computes a scoring mask for batched tokenized text sequences which is used in the scoring function to mask
        out vectors during scoring.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): type of input, either "query" or "doc".
        Returns:
            torch.Tensor: Scoring mask.
        """
        input_ids = encoding["input_ids"]
        attention_mask = encoding["attention_mask"]
        scoring_mask = attention_mask
        expansion = getattr(self.config, f"{input_type}_expansion")
        # When mask expansion is enabled (or no attention mask exists) every position is
        # scored, including the appended mask tokens.
        if expansion or scoring_mask is None:
            scoring_mask = torch.ones_like(input_ids, dtype=torch.bool)
        scoring_mask = scoring_mask.bool()
        mask_scoring_input_ids = getattr(self, f"{input_type}_mask_scoring_input_ids")
        if mask_scoring_input_ids is not None:
            # Exclude tokens configured to be ignored during scoring (e.g. punctuation);
            # a position is ignored if its id matches any of the mask-scoring ids.
            ignore_mask = input_ids[..., None].eq(mask_scoring_input_ids.to(input_ids.device)).any(-1)
            scoring_mask = scoring_mask & ~ignore_mask
        return scoring_mask

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes a batched tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): type of input, either "query" or "doc".
        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask.
        """
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self.projection(embeddings)
        # Optional L2 normalization so that the dot product equals cosine similarity.
        if self.config.normalization_strategy == "l2":
            embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
        scoring_mask = self.scoring_mask(encoding, input_type)
        return BiEncoderEmbedding(embeddings, scoring_mask, encoding)

    def _score_xtr_in_batch(
        self, output: BiEncoderOutput, num_docs: Sequence[int] | int | None = None
    ) -> BiEncoderOutput:
        """XTR in-batch token retrieval scoring.

        Args:
            output (BiEncoderOutput): Output containing embeddings and scoring mask.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence of
                integers, `len(num_doc)` should be equal to the number of queries and `sum(num_docs)` equal to the
                number of documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
                the number of documents by dividing the number of documents by the number of queries. Defaults to None.
        Returns:
            BiEncoderOutput: Output containing relevance scores.
        """
        query_embeddings = output.query_embeddings
        doc_embeddings = output.doc_embeddings
        if query_embeddings is None or doc_embeddings is None:
            raise ValueError("Both query and document embeddings must be provided for scoring.")
        # Token-level similarities; the masking below treats this as
        # (total_docs, query_len, doc_len) — presumably compute_similarity's contract.
        similarities = self.compute_similarity(query_embeddings, doc_embeddings, num_docs)

        query_mask = query_embeddings.scoring_mask
        doc_mask = doc_embeddings.scoring_mask

        # Per-query document counts as a tensor (one entry per query).
        num_docs_t = self._parse_num_docs(
            query_embeddings.embeddings.shape[0],
            doc_embeddings.embeddings.shape[0],
            num_docs,
        query_embeddings.device,
        )

        # Broadcast the masks over the similarity matrix: query mask is repeated per
        # document of each query, doc mask is shared across query tokens.
        query_mask_expanded = query_mask.repeat_interleave(num_docs_t, dim=0).unsqueeze(-1)
        doc_mask_expanded = doc_mask.unsqueeze(1)

        # Fill masked positions with -inf so they can never enter the top-k threshold below.
        similarities = similarities.masked_fill(~doc_mask_expanded, float("-inf"))
        similarities = similarities.masked_fill(~query_mask_expanded, float("-inf"))

        batch_size = query_embeddings.embeddings.shape[0]
        q_len = query_embeddings.embeddings.shape[1]
        doc_len = doc_embeddings.embeddings.shape[1]
        max_docs = torch.max(num_docs_t)

        # Group each query's document similarity matrices and pad queries with fewer
        # documents to max_docs (padding value -inf so padded docs are never selected).
        sim_list = similarities.split(num_docs_t.tolist(), dim=0)
        sim_padded = torch.nn.utils.rnn.pad_sequence(sim_list, batch_first=True, padding_value=float("-inf"))

        # valid_mask marks which (query, doc-slot) pairs correspond to real documents.
        valid_mask = torch.arange(max_docs, device=num_docs_t.device).unsqueeze(0) < num_docs_t.unsqueeze(1)

        # Per query, the k_train-th largest in-batch token similarity serves as the
        # retrieval threshold (XTR: only the top-k retrieved doc tokens contribute).
        sim_flat = sim_padded.view(batch_size, -1)
        k_train = min(self.config.k_train, sim_flat.size(-1))
        # NOTE(review): if fewer than k_train entries are finite (heavy masking / tiny
        # batch), this threshold is -inf and -inf entries survive the fill below — confirm
        # batch sizes make this unreachable.
        minimum_values = torch.topk(sim_flat, k=k_train, dim=-1).values[:, -1].unsqueeze(-1)

        # Zero out everything below the threshold, emulating that only the top-k document
        # tokens were retrieved; this also zeroes the -inf masked/padded entries.
        sim_padded = sim_padded.view(batch_size, -1)
        sim_padded = sim_padded.masked_fill(sim_padded < minimum_values, 0.0)
        sim_padded = sim_padded.view(batch_size, max_docs, q_len, doc_len)

        # Standard late interaction on the surviving tokens: max over doc tokens, sum over
        # query tokens, normalized by Z = number of query tokens that matched at least one
        # retrieved doc token (clamped to avoid division by zero).
        scores = sim_padded.max(dim=-1).values.sum(dim=-1)
        Z = (sim_padded.max(dim=-1).values > 0).sum(dim=-1).float()
        Z = Z.clamp(min=1.0)
        scores = scores / Z

        # Drop padded document slots, yielding one score per real (query, doc) pair.
        scores = scores[valid_mask]

        output.scores = scores
        output.similarity = sim_padded[valid_mask]

        return output
class ColTokenizer(BiEncoderTokenizer):
    """:class:`.LightningIRTokenizer` for Col models."""

    config_class = ColConfig
    """Configuration class for the tokenizer."""

    def __init__(
        self,
        *args,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        add_marker_tokens: bool = False,
        query_expansion: bool = False,
        attend_to_query_expanded_tokens: bool = False,
        doc_expansion: bool = False,
        attend_to_doc_expanded_tokens: bool = False,
        **kwargs,
    ):
        """Initializes a Col model's tokenizer. Encodes queries and documents separately. Optionally adds marker tokens
        to encoded input sequences and expands queries and documents with mask tokens.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
            add_marker_tokens (bool): Whether to add extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            query_expansion (bool): Whether to expand queries with mask tokens. Defaults to False.
            attend_to_query_expanded_tokens (bool): Whether to allow query tokens to attend to mask expanded query
                tokens. Defaults to False.
            doc_expansion (bool): Whether to expand documents with mask tokens. Defaults to False.
            attend_to_doc_expanded_tokens (bool): Whether to allow document tokens to attend to mask expanded document
                tokens. Defaults to False.
        Raises:
            ValueError: If `add_marker_tokens` is True and a non-supported tokenizer is used.
        """
        # All expansion/marker options are forwarded to the base tokenizer unchanged.
        expansion_options = {
            "query_expansion": query_expansion,
            "attend_to_query_expanded_tokens": attend_to_query_expanded_tokens,
            "doc_expansion": doc_expansion,
            "attend_to_doc_expanded_tokens": attend_to_doc_expanded_tokens,
        }
        super().__init__(
            *args,
            query_length=query_length,
            doc_length=doc_length,
            add_marker_tokens=add_marker_tokens,
            **expansion_options,
            **kwargs,
        )
        # Mirror the expansion options as instance attributes for direct access.
        for option_name, option_value in expansion_options.items():
            setattr(self, option_name, option_value)

    def _expand(self, encoding: BatchEncoding, attend_to_expanded_tokens: bool) -> BatchEncoding:
        """Applies mask expansion to the input encoding."""
        token_ids = encoding["input_ids"]
        # Turn padding positions into mask tokens (in place) to expand the sequence.
        token_ids.masked_fill_(token_ids == self.pad_token_id, self.mask_token_id)
        encoding["input_ids"] = token_ids
        if attend_to_expanded_tokens:
            # Let every position attend to the newly inserted mask tokens.
            encoding["attention_mask"].fill_(1)
        return encoding