1"""Configuration, model, and tokenizer for Col (Contextualized Late Interaction) type models. Originally proposed in
2`ColBERT: Efficient and Effective Passage Search via Contextualized Late Interaction over BERT \
3<https://dl.acm.org/doi/abs/10.1145/3397271.3401075>`_ as the ColBERT model. This implementation generalizes the model
4to work with any transformer backbone model.
5"""
6
7from typing import Literal, Sequence
8
9import torch
10from transformers import BatchEncoding
11
12from ..bi_encoder import BiEncoderEmbedding, BiEncoderTokenizer, MultiVectorBiEncoderConfig, MultiVectorBiEncoderModel
13
14
class ColConfig(MultiVectorBiEncoderConfig):
    """Configuration class for a Col model."""

    model_type = "col"
    """Model type for a Col model."""

    def __init__(
        self,
        query_length: int = 32,
        doc_length: int = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalize: bool = False,
        add_marker_tokens: bool = False,
        query_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        doc_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        query_aggregation_function: Literal["sum", "mean", "max", "harmonic_mean"] = "sum",
        doc_aggregation_function: Literal["sum", "mean", "max", "harmonic_mean"] = "max",
        embedding_dim: int = 128,
        projection: Literal["linear", "linear_no_bias"] = "linear",
        query_expansion: bool = False,
        attend_to_query_expanded_tokens: bool = False,
        doc_expansion: bool = False,
        attend_to_doc_expanded_tokens: bool = False,
        **kwargs,
    ):
40 """A Col model encodes queries and documents separately and computes a late interaction score between the query
41 and document embeddings. The aggregation behavior of the late-interaction function can be parameterized with
42 the `aggregation_function` arguments. The dimensionality of the token embeddings is down-projected using a
43 linear layer. Queries and documents can optionally be expanded with mask tokens. Optionally, a set of tokens can
44 be ignored during scoring.
45
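        A minimal sketch of the default scoring behavior (dot-product similarity,
        ``doc_aggregation_function="max"``, ``query_aggregation_function="sum"``), i.e., the MaxSim operator of
        ColBERT; the function name is illustrative and not part of the API:

        .. code-block:: python

            import torch

            def late_interaction_score(query_emb: torch.Tensor, doc_emb: torch.Tensor) -> torch.Tensor:
                # token-level similarity matrix of shape (num_query_tokens, num_doc_tokens)
                similarity = query_emb @ doc_emb.transpose(-1, -2)
                # doc_aggregation_function="max": best-matching document token per query token
                per_query_token = similarity.amax(dim=-1)
                # query_aggregation_function="sum": aggregate over query tokens
                return per_query_token.sum(dim=-1)
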
        :param query_length: Maximum query length, defaults to 32
        :type query_length: int, optional
        :param doc_length: Maximum document length, defaults to 512
        :type doc_length: int, optional
        :param similarity_function: Similarity function to compute scores between query and document embeddings,
            defaults to "dot"
        :type similarity_function: Literal['cosine', 'dot'], optional
        :param normalize: Whether to normalize query and document embeddings, defaults to False
        :type normalize: bool, optional
        :param add_marker_tokens: Whether to add extra marker tokens [Q] / [D] to queries / documents,
            defaults to False
        :type add_marker_tokens: bool, optional
        :param query_mask_scoring_tokens: Whether and which query tokens to ignore during scoring, defaults to None
        :type query_mask_scoring_tokens: Sequence[str] | Literal['punctuation'] | None, optional
        :param doc_mask_scoring_tokens: Whether and which document tokens to ignore during scoring, defaults to None
        :type doc_mask_scoring_tokens: Sequence[str] | Literal['punctuation'] | None, optional
        :param query_aggregation_function: How to aggregate similarity scores over query tokens, defaults to "sum"
        :type query_aggregation_function: Literal['sum', 'mean', 'max', 'harmonic_mean'], optional
        :param doc_aggregation_function: How to aggregate similarity scores over document tokens, defaults to "max"
        :type doc_aggregation_function: Literal['sum', 'mean', 'max', 'harmonic_mean'], optional
        :param embedding_dim: The output embedding dimension, defaults to 128
        :type embedding_dim: int, optional
        :param projection: Whether and how to project the output embeddings, defaults to "linear"
        :type projection: Literal['linear', 'linear_no_bias'], optional
        :param query_expansion: Whether to expand queries with mask tokens, defaults to False
        :type query_expansion: bool, optional
        :param attend_to_query_expanded_tokens: Whether to allow query tokens to attend to mask tokens,
            defaults to False
        :type attend_to_query_expanded_tokens: bool, optional
        :param doc_expansion: Whether to expand documents with mask tokens, defaults to False
        :type doc_expansion: bool, optional
        :param attend_to_doc_expanded_tokens: Whether to allow document tokens to attend to mask tokens,
            defaults to False
        :type attend_to_doc_expanded_tokens: bool, optional
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            add_marker_tokens=add_marker_tokens,
            query_mask_scoring_tokens=query_mask_scoring_tokens,
            doc_mask_scoring_tokens=doc_mask_scoring_tokens,
            query_aggregation_function=query_aggregation_function,
            doc_aggregation_function=doc_aggregation_function,
            **kwargs,
        )
        self.embedding_dim = embedding_dim
        self.projection = projection
        self.query_expansion = query_expansion
        self.attend_to_query_expanded_tokens = attend_to_query_expanded_tokens
        self.doc_expansion = doc_expansion
        self.attend_to_doc_expanded_tokens = attend_to_doc_expanded_tokens
        self.normalize = normalize
        self.add_marker_tokens = add_marker_tokens


class ColModel(MultiVectorBiEncoderModel):
    """Multi-vector late-interaction Col model. See :class:`.ColConfig` for configuration options."""

    config_class = ColConfig
    """Configuration class for the Col model."""

    def __init__(self, config: ColConfig, *args, **kwargs) -> None:
        """Initializes a Col model given a :class:`.ColConfig`.

        :param config: Configuration for the Col model
        :type config: ColConfig
        """
        super().__init__(config, *args, **kwargs)
        if config.embedding_dim is None:
            raise ValueError("Embedding dimension must be specified in the configuration.")
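        # down-project the backbone's hidden states to the embedding dimension;
        # "linear_no_bias" drops the bias term from the projection layer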
        self.projection = torch.nn.Linear(
            config.hidden_size, config.embedding_dim, bias="no_bias" not in config.projection
        )

    def scoring_mask(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> torch.Tensor:
        """Computes a scoring mask for batched tokenized text sequences. The mask is used in the scoring function
        to exclude vectors from scoring.

        :param encoding: Tokenizer encodings for the text sequence
        :type encoding: BatchEncoding
        :param input_type: Type of input, either "query" or "doc"
        :type input_type: Literal["query", "doc"]
        :return: Scoring mask
        :rtype: torch.Tensor
        """
        input_ids = encoding["input_ids"]
        attention_mask = encoding["attention_mask"]
        scoring_mask = attention_mask
        expansion = getattr(self.config, f"{input_type}_expansion")
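        # expanded sequences contain no padding (pad tokens were replaced by mask tokens), so every
        # position is scored; this also covers encodings that come without an attention mask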
        if expansion or scoring_mask is None:
            scoring_mask = torch.ones_like(input_ids, dtype=torch.bool)
        scoring_mask = scoring_mask.bool()
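        # optionally exclude configured tokens (e.g., punctuation) from scoring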
        mask_scoring_input_ids = getattr(self, f"{input_type}_mask_scoring_input_ids")
        if mask_scoring_input_ids is not None:
            ignore_mask = input_ids[..., None].eq(mask_scoring_input_ids.to(input_ids.device)).any(-1)
            scoring_mask = scoring_mask & ~ignore_mask
        return scoring_mask

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes batched tokenized text sequences and returns the embeddings and scoring mask.

        :param encoding: Tokenizer encodings for the text sequence
        :type encoding: BatchEncoding
        :param input_type: Type of input, either "query" or "doc"
        :type input_type: Literal["query", "doc"]
        :return: Embeddings and scoring mask
        :rtype: BiEncoderEmbedding
        """
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self.projection(embeddings)
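        # L2-normalize token embeddings so that dot-product scores correspond to cosine similarity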
        if self.config.normalize:
            embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
        scoring_mask = self.scoring_mask(encoding, input_type)
        return BiEncoderEmbedding(embeddings, scoring_mask, encoding)


class ColTokenizer(BiEncoderTokenizer):
    """:class:`.LightningIRTokenizer` for Col models."""

    config_class = ColConfig
    """Configuration class for the tokenizer."""

    def __init__(
        self,
        *args,
        query_length: int = 32,
        doc_length: int = 512,
        add_marker_tokens: bool = False,
        query_expansion: bool = False,
        attend_to_query_expanded_tokens: bool = False,
        doc_expansion: bool = False,
        attend_to_doc_expanded_tokens: bool = False,
        **kwargs,
    ):
        """Initializes a Col model's tokenizer. Encodes queries and documents separately. Optionally adds marker
        tokens to encoded input sequences and expands queries and documents with mask tokens.

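        A minimal usage sketch (the checkpoint name is illustrative, and ``tokenize`` is assumed to be the
        inherited entry point for tokenizing queries and documents separately):

        .. code-block:: python

            tokenizer = ColTokenizer.from_pretrained("bert-base-uncased", query_expansion=True)
            # assumed inherited interface; queries and documents are tokenized separately
            encodings = tokenizer.tokenize(
                queries=["what is late interaction?"],
                docs=["ColBERT scores query and document tokens against each other."],
            )
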
        :param query_length: Maximum query length in number of tokens, defaults to 32
        :type query_length: int, optional
        :param doc_length: Maximum document length in number of tokens, defaults to 512
        :type doc_length: int, optional
        :param add_marker_tokens: Whether to add marker tokens to the query and document input sequences,
            defaults to False
        :type add_marker_tokens: bool, optional
        :param query_expansion: Whether to expand queries with mask tokens, defaults to False
        :type query_expansion: bool, optional
        :param attend_to_query_expanded_tokens: Whether to allow non-expanded query tokens to attend to the mask
            tokens added by query expansion, defaults to False
        :type attend_to_query_expanded_tokens: bool, optional
        :param doc_expansion: Whether to expand documents with mask tokens, defaults to False
        :type doc_expansion: bool, optional
        :param attend_to_doc_expanded_tokens: Whether to allow non-expanded document tokens to attend to the mask
            tokens added by document expansion, defaults to False
        :type attend_to_doc_expanded_tokens: bool, optional
        :raises ValueError: If add_marker_tokens is True and an unsupported tokenizer is used
        """
        super().__init__(
            *args,
            query_length=query_length,
            doc_length=doc_length,
            query_expansion=query_expansion,
            attend_to_query_expanded_tokens=attend_to_query_expanded_tokens,
            doc_expansion=doc_expansion,
            attend_to_doc_expanded_tokens=attend_to_doc_expanded_tokens,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.query_expansion = query_expansion
        self.attend_to_query_expanded_tokens = attend_to_query_expanded_tokens
        self.doc_expansion = doc_expansion
        self.attend_to_doc_expanded_tokens = attend_to_doc_expanded_tokens

    def _expand(self, encoding: BatchEncoding, attend_to_expanded_tokens: bool) -> BatchEncoding:
        """Applies mask expansion to the input encoding."""
        input_ids = encoding["input_ids"]
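        # the tokenizer pads sequences to the full query/document length; replacing pad tokens
        # with mask tokens expands the text with mask tokens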
        input_ids[input_ids == self.pad_token_id] = self.mask_token_id
        encoding["input_ids"] = input_ids
        if attend_to_expanded_tokens:
            encoding["attention_mask"].fill_(1)
        return encoding