1"""Configuration, model, and tokenizer for Col (Contextualized Late Interaction) type models. Originally proposed in
2`ColBERT: Efficient and Effective Passage Search via Contextualized Late Interaction over BERT \
3<https://dl.acm.org/doi/abs/10.1145/3397271.3401075>`_ as the ColBERT model. This implementation generalizes the model
4to work with any transformer backbone model.
5"""
6
7from collections.abc import Sequence
8from typing import Literal
9
10import torch
11from transformers import BatchEncoding
12
13from ...bi_encoder import (
14 BiEncoderEmbedding,
15 BiEncoderOutput,
16 BiEncoderTokenizer,
17 MultiVectorBiEncoderConfig,
18 MultiVectorBiEncoderModel,
19)
20
21
class ColConfig(MultiVectorBiEncoderConfig):
    """Configuration class for a Col model."""

    model_type = "col"
    """Model type for a Col model."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalization_strategy: Literal["l2"] | None = None,
        add_marker_tokens: bool = False,
        query_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        doc_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        query_aggregation_function: Literal["sum", "mean", "max"] = "sum",
        doc_aggregation_function: Literal["sum", "mean", "max"] = "max",
        embedding_dim: int = 128,
        projection: Literal["linear", "linear_no_bias"] = "linear",
        query_expansion: bool = False,
        attend_to_query_expanded_tokens: bool = False,
        doc_expansion: bool = False,
        attend_to_doc_expanded_tokens: bool = False,
        k_train: int | None = None,
        **kwargs,
    ):
48 """A Col model encodes queries and documents separately and computes a late interaction score between the query
49 and document embeddings. The aggregation behavior of the late-interaction function can be parameterized with
50 the `aggregation_function` arguments. The dimensionality of the token embeddings is down-projected using a
51 linear layer. Queries and documents can optionally be expanded with mask tokens. Optionally, a set of tokens can
52 be ignored during scoring.
53
54 Args:
55 query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
56 doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
57 similarity_function (Literal["cosine", "dot"]): Similarity function to compute scores between query and
58 document embeddings. Defaults to "dot".
59 normalization_strategy (Literal['l2'] | None): Whether to normalize query and document embeddings.
60 Defaults to None.
61 add_marker_tokens (bool): Whether to add extra marker tokens [Q] / [D] to queries / documents.
62 Defaults to False.
63 query_mask_scoring_tokens (Sequence[str] | Literal["punctuation"] | None): Whether and which query tokens
64 to ignore during scoring. Defaults to None.
65 doc_mask_scoring_tokens (Sequence[str] | Literal["punctuation"] | None): Whether and which document tokens
66 to ignore during scoring. Defaults to None.
67 query_aggregation_function (Literal["sum", "mean", "max"]): How to aggregate
68 similarity scores over query tokens. Defaults to "sum".
69 doc_aggregation_function (Literal["sum", "mean", "max"]): How to aggregate
70 similarity scores over document tokens. Defaults to "max".
71 embedding_dim (int): The output embedding dimension. Defaults to 128.
72 projection (Literal["linear", "linear_no_bias"]): Whether and how to project the output embeddings.
73 Defaults to "linear". If set to "linear_no_bias", the projection layer will not have a bias term.
74 query_expansion (bool): Whether to expand queries with mask tokens. Defaults to False.
75 attend_to_query_expanded_tokens (bool): Whether to allow query tokens to attend to mask expanded query
76 tokens. Defaults to False.
77 doc_expansion (bool): Whether to expand documents with mask tokens. Defaults to False.
78 attend_to_doc_expanded_tokens (bool): Whether to allow document tokens to attend to mask expanded document
79 tokens. Defaults to False.
80 k_train (int | None): Whether to use XTR_'s in-batch token retrieval during training and how many top-k
81 document tokens to use. Defaults to 128.
82
83 .. _XTR: \
84https://proceedings.neurips.cc/paper_files/paper/2023/file/31d997278ee9069d6721bc194174bb4c-Paper-Conference.pdf
85 """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            normalization_strategy=normalization_strategy,
            add_marker_tokens=add_marker_tokens,
            query_mask_scoring_tokens=query_mask_scoring_tokens,
            doc_mask_scoring_tokens=doc_mask_scoring_tokens,
            query_aggregation_function=query_aggregation_function,
            doc_aggregation_function=doc_aggregation_function,
            **kwargs,
        )
        self.embedding_dim = embedding_dim
        self.projection = projection
        self.query_expansion = query_expansion
        self.attend_to_query_expanded_tokens = attend_to_query_expanded_tokens
        self.doc_expansion = doc_expansion
        self.attend_to_doc_expanded_tokens = attend_to_doc_expanded_tokens
        self.k_train = k_train

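# Illustrative sketch (kept as a comment, not executed): a ColBERT-style configuration built only from the
# parameters documented above -- 128-dimensional token embeddings, cosine similarity over l2-normalized
# vectors, and [Q] / [D] marker tokens. Additional backbone-specific options would be forwarded via **kwargs.
#
#     config = ColConfig(
#         query_length=32,
#         doc_length=512,
#         similarity_function="cosine",
#         normalization_strategy="l2",
#         add_marker_tokens=True,
#         embedding_dim=128,
#     )
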
class ColModel(MultiVectorBiEncoderModel):
    """Multi-vector late-interaction Col model. See :class:`.ColConfig` for configuration options."""

    config_class = ColConfig
    """Configuration class for the Col model."""

    def __init__(self, config: ColConfig, *args, **kwargs) -> None:
        """Initializes a Col model given a :class:`.ColConfig`.

        Args:
            config (ColConfig): Configuration for the Col model.
        Raises:
            ValueError: If the embedding dimension is not specified in the configuration.
        """
        super().__init__(config, *args, **kwargs)
        if config.embedding_dim is None:
            raise ValueError("Embedding dimension must be specified in the configuration.")
        self.projection = torch.nn.Linear(
            config.hidden_size, config.embedding_dim, bias="no_bias" not in config.projection
        )

    def score(
        self,
        output: BiEncoderOutput,
        num_docs: Sequence[int] | int | None = None,
    ) -> BiEncoderOutput:
        """Compute relevance scores between queries and documents.

        Args:
            output (BiEncoderOutput): Output containing embeddings and scoring mask.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence of
                integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the
                number of documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
                the number of documents by dividing the total number of documents by the number of queries.
                Defaults to None.
        Returns:
            BiEncoderOutput: Output containing relevance scores.
        """
        if self.training and self.config.k_train is not None:
            return self._score_xtr_in_batch(output, num_docs)

        return super().score(output, num_docs)

    def scoring_mask(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> torch.Tensor:
        """Computes a scoring mask for batched tokenized text sequences. The scoring function uses this mask to
        ignore vectors of masked-out tokens during scoring.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            torch.Tensor: Scoring mask.
        """
        input_ids = encoding["input_ids"]
        attention_mask = encoding["attention_mask"]
        scoring_mask = attention_mask
        expansion = getattr(self.config, f"{input_type}_expansion")
        if expansion or scoring_mask is None:
            scoring_mask = torch.ones_like(input_ids, dtype=torch.bool)
        scoring_mask = scoring_mask.bool()
        mask_scoring_input_ids = getattr(self, f"{input_type}_mask_scoring_input_ids")
        if mask_scoring_input_ids is not None:
            ignore_mask = input_ids[..., None].eq(mask_scoring_input_ids.to(input_ids.device)).any(-1)
            scoring_mask = scoring_mask & ~ignore_mask
        return scoring_mask

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes batched tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask.
        """
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self.projection(embeddings)
        if self.config.normalization_strategy == "l2":
            embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
        scoring_mask = self.scoring_mask(encoding, input_type)
        return BiEncoderEmbedding(embeddings, scoring_mask, encoding)

    def _score_xtr_in_batch(
        self, output: BiEncoderOutput, num_docs: Sequence[int] | int | None = None
    ) -> BiEncoderOutput:
        """XTR in-batch token retrieval scoring.

        Args:
            output (BiEncoderOutput): Output containing embeddings and scoring mask.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence of
                integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the
                number of documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
                the number of documents by dividing the total number of documents by the number of queries.
                Defaults to None.
        Returns:
            BiEncoderOutput: Output containing relevance scores.
        """
        query_embeddings = output.query_embeddings
        doc_embeddings = output.doc_embeddings
        if query_embeddings is None or doc_embeddings is None:
            raise ValueError("Both query and document embeddings must be provided for scoring.")
        similarities = self.compute_similarity(query_embeddings, doc_embeddings, num_docs)

        query_mask = query_embeddings.scoring_mask
        doc_mask = doc_embeddings.scoring_mask

        num_docs_t = self._parse_num_docs(
            query_embeddings.embeddings.shape[0],
            doc_embeddings.embeddings.shape[0],
            num_docs,
            query_embeddings.device,
        )

        query_mask_expanded = query_mask.repeat_interleave(num_docs_t, dim=0).unsqueeze(-1)
        doc_mask_expanded = doc_mask.unsqueeze(1)

        similarities = similarities.masked_fill(~doc_mask_expanded, float("-inf"))
        similarities = similarities.masked_fill(~query_mask_expanded, float("-inf"))

        batch_size = query_embeddings.embeddings.shape[0]
        q_len = query_embeddings.embeddings.shape[1]
        doc_len = doc_embeddings.embeddings.shape[1]
        max_docs = torch.max(num_docs_t)

        sim_list = similarities.split(num_docs_t.tolist(), dim=0)
        sim_padded = torch.nn.utils.rnn.pad_sequence(sim_list, batch_first=True, padding_value=float("-inf"))

        valid_mask = torch.arange(max_docs, device=num_docs_t.device).unsqueeze(0) < num_docs_t.unsqueeze(1)

        sim_flat = sim_padded.view(batch_size, -1)
        k_train = min(self.config.k_train, sim_flat.size(-1))
        minimum_values = torch.topk(sim_flat, k=k_train, dim=-1).values[:, -1].unsqueeze(-1)

        sim_padded = sim_padded.view(batch_size, -1)
        sim_padded = sim_padded.masked_fill(sim_padded < minimum_values, 0.0)
        sim_padded = sim_padded.view(batch_size, max_docs, q_len, doc_len)

        scores = sim_padded.max(dim=-1).values.sum(dim=-1)
        Z = (sim_padded.max(dim=-1).values > 0).sum(dim=-1).float()
        Z = Z.clamp(min=1.0)
        scores = scores / Z

        scores = scores[valid_mask]

        output.scores = scores
        output.similarity = sim_padded[valid_mask]

        return output

class ColTokenizer(BiEncoderTokenizer):
    """:class:`.LightningIRTokenizer` for Col models."""

    config_class = ColConfig
    """Configuration class for the tokenizer."""

    def __init__(
        self,
        *args,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        add_marker_tokens: bool = False,
        query_expansion: bool = False,
        attend_to_query_expanded_tokens: bool = False,
        doc_expansion: bool = False,
        attend_to_doc_expanded_tokens: bool = False,
        **kwargs,
    ):
275 """Initializes a Col model's tokenizer. Encodes queries and documents separately. Optionally adds marker tokens
276 to encoded input sequences and expands queries and documents with mask tokens.
277
278 Args:
279 query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
280 doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
281 add_marker_tokens (bool): Whether to add extra marker tokens [Q] / [D] to queries / documents.
282 Defaults to False.
283 query_expansion (bool): Whether to expand queries with mask tokens. Defaults to False.
284 attend_to_query_expanded_tokens (bool): Whether to allow query tokens to attend to mask expanded query
285 tokens. Defaults to False.
286 doc_expansion (bool): Whether to expand documents with mask tokens. Defaults to False.
287 attend_to_doc_expanded_tokens (bool): Whether to allow document tokens to attend to mask expanded document
288 tokens. Defaults to False.
289 Raises:
290 ValueError: If `add_marker_tokens` is True and a non-supported tokenizer is used.
291 """
        super().__init__(
            *args,
            query_length=query_length,
            doc_length=doc_length,
            query_expansion=query_expansion,
            attend_to_query_expanded_tokens=attend_to_query_expanded_tokens,
            doc_expansion=doc_expansion,
            attend_to_doc_expanded_tokens=attend_to_doc_expanded_tokens,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.query_expansion = query_expansion
        self.attend_to_query_expanded_tokens = attend_to_query_expanded_tokens
        self.doc_expansion = doc_expansion
        self.attend_to_doc_expanded_tokens = attend_to_doc_expanded_tokens

    def _expand(self, encoding: BatchEncoding, attend_to_expanded_tokens: bool) -> BatchEncoding:
        """Applies mask expansion to the input encoding."""
        input_ids = encoding["input_ids"]
        input_ids[input_ids == self.pad_token_id] = self.mask_token_id
        encoding["input_ids"] = input_ids
        if attend_to_expanded_tokens:
            encoding["attention_mask"].fill_(1)
        return encoding