1"""Configuration, model, and tokenizer for Col (Contextualized Late Interaction) type models. Originally proposed in
2`ColBERT: Efficient and Effective Passage Search via Contextualized Late Interaction over BERT \
3<https://dl.acm.org/doi/abs/10.1145/3397271.3401075>`_ as the ColBERT model. This implementation generalizes the model
4to work with any transformer backbone model.
5"""
6
7from typing import Literal, Sequence
8
9import torch
10from transformers import BatchEncoding
11
12from ..bi_encoder import BiEncoderEmbedding, BiEncoderTokenizer, MultiVectorBiEncoderConfig, MultiVectorBiEncoderModel
13
14
class ColConfig(MultiVectorBiEncoderConfig):
    """Configuration class for a Col model."""

    model_type = "col"
    """Model type for a Col model."""

    def __init__(
        self,
        query_length: int = 32,
        doc_length: int = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalize: bool = False,
        add_marker_tokens: bool = False,
        query_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        doc_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        query_aggregation_function: Literal["sum", "mean", "max", "harmonic_mean"] = "sum",
        doc_aggregation_function: Literal["sum", "mean", "max", "harmonic_mean"] = "max",
        embedding_dim: int = 128,
        projection: Literal["linear", "linear_no_bias"] = "linear",
        query_expansion: bool = False,
        attend_to_query_expanded_tokens: bool = False,
        doc_expansion: bool = False,
        attend_to_doc_expanded_tokens: bool = False,
        **kwargs,
    ):
        """A Col model encodes queries and documents separately and computes a late interaction score between the
        query and document embeddings. The aggregation behavior of the late-interaction function can be parameterized
        with the `query_aggregation_function` and `doc_aggregation_function` arguments. The token embeddings are
        down-projected to a lower dimensionality using a linear layer. Queries and documents can optionally be
        expanded with mask tokens, and a set of tokens can optionally be ignored during scoring.

        Args:
            query_length (int): Maximum query length in number of tokens. Defaults to 32.
            doc_length (int): Maximum document length in number of tokens. Defaults to 512.
            similarity_function (Literal["cosine", "dot"]): Similarity function to compute scores between query and
                document embeddings. Defaults to "dot".
            normalize (bool): Whether to normalize query and document embeddings. Defaults to False.
            add_marker_tokens (bool): Whether to add extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            query_mask_scoring_tokens (Sequence[str] | Literal["punctuation"] | None): Whether and which query tokens
                to ignore during scoring. Defaults to None.
            doc_mask_scoring_tokens (Sequence[str] | Literal["punctuation"] | None): Whether and which document tokens
                to ignore during scoring. Defaults to None.
            query_aggregation_function (Literal["sum", "mean", "max", "harmonic_mean"]): How to aggregate similarity
                scores over query tokens. Defaults to "sum".
            doc_aggregation_function (Literal["sum", "mean", "max", "harmonic_mean"]): How to aggregate similarity
                scores over document tokens. Defaults to "max".
            embedding_dim (int): The output embedding dimension. Defaults to 128.
            projection (Literal["linear", "linear_no_bias"]): How to project the output embeddings. If set to
                "linear_no_bias", the projection layer has no bias term. Defaults to "linear".
            query_expansion (bool): Whether to expand queries with mask tokens. Defaults to False.
            attend_to_query_expanded_tokens (bool): Whether to allow query tokens to attend to mask expanded query
                tokens. Defaults to False.
            doc_expansion (bool): Whether to expand documents with mask tokens. Defaults to False.
            attend_to_doc_expanded_tokens (bool): Whether to allow document tokens to attend to mask expanded document
                tokens. Defaults to False.
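
        Examples:
            A minimal sketch of constructing a configuration (the values shown are illustrative, not recommended
            settings)::

                config = ColConfig(similarity_function="cosine", normalize=True, embedding_dim=64)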
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            add_marker_tokens=add_marker_tokens,
            query_mask_scoring_tokens=query_mask_scoring_tokens,
            doc_mask_scoring_tokens=doc_mask_scoring_tokens,
            query_aggregation_function=query_aggregation_function,
            doc_aggregation_function=doc_aggregation_function,
            **kwargs,
        )
        self.embedding_dim = embedding_dim
        self.projection = projection
        self.query_expansion = query_expansion
        self.attend_to_query_expanded_tokens = attend_to_query_expanded_tokens
        self.doc_expansion = doc_expansion
        self.attend_to_doc_expanded_tokens = attend_to_doc_expanded_tokens
        self.normalize = normalize
        self.add_marker_tokens = add_marker_tokens


class ColModel(MultiVectorBiEncoderModel):
    """Multi-vector late-interaction Col model. See :class:`.ColConfig` for configuration options."""

    config_class = ColConfig
    """Configuration class for the Col model."""

    def __init__(self, config: ColConfig, *args, **kwargs) -> None:
        """Initializes a Col model given a :class:`.ColConfig`.

        Args:
            config (ColConfig): Configuration for the Col model.

        Raises:
            ValueError: If the embedding dimension is not specified in the configuration.
        """
        super().__init__(config, *args, **kwargs)
        if config.embedding_dim is None:
            raise ValueError("Embedding dimension must be specified in the configuration.")
        # Linear layer that down-projects the backbone's hidden states to the configured embedding dimension.
        self.projection = torch.nn.Linear(
            config.hidden_size, config.embedding_dim, bias="no_bias" not in config.projection
        )

    def scoring_mask(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> torch.Tensor:
        """Computes a scoring mask for batched tokenized text sequences. The scoring function uses the mask to exclude
        vectors, e.g., of padding or ignored tokens, from the late-interaction score.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            torch.Tensor: Scoring mask.
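
        Example:
            A hedged sketch using a hand-built toy encoding (real encodings come from the Col tokenizer, and ``model``
            is assumed to be an instantiated :class:`.ColModel`)::

                encoding = BatchEncoding(
                    {
                        "input_ids": torch.tensor([[101, 2023, 1012, 102, 0]]),
                        "attention_mask": torch.tensor([[1, 1, 1, 1, 0]]),
                    }
                )
                mask = model.scoring_mask(encoding, input_type="doc")  # boolean tensor of shape (1, 5)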
        """
        input_ids = encoding["input_ids"]
        attention_mask = encoding["attention_mask"]
        # Start from the attention mask; if the input was expanded with mask tokens (or no attention mask is
        # available), all positions are scored.
        scoring_mask = attention_mask
        expansion = getattr(self.config, f"{input_type}_expansion")
        if expansion or scoring_mask is None:
            scoring_mask = torch.ones_like(input_ids, dtype=torch.bool)
        scoring_mask = scoring_mask.bool()
        # Additionally mask out tokens that are configured to be ignored during scoring (e.g., punctuation).
        mask_scoring_input_ids = getattr(self, f"{input_type}_mask_scoring_input_ids")
        if mask_scoring_input_ids is not None:
            ignore_mask = input_ids[..., None].eq(mask_scoring_input_ids.to(input_ids.device)).any(-1)
            scoring_mask = scoring_mask & ~ignore_mask
        return scoring_mask

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes batched tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask.
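
        Example:
            A hedged sketch (assumes ``model`` is an instantiated :class:`.ColModel` and ``encoding`` was produced by
            the matching :class:`.ColTokenizer`)::

                embedding = model.encode(encoding, input_type="query")
                # embedding bundles the projected token embeddings (last dimension == config.embedding_dim)
                # together with the scoring mask computed by ``scoring_mask``.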
        """
        # Contextualize the tokens with the backbone and project them to the configured embedding dimension.
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self.projection(embeddings)
        if self.config.normalize:
            embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
        scoring_mask = self.scoring_mask(encoding, input_type)
        return BiEncoderEmbedding(embeddings, scoring_mask, encoding)


class ColTokenizer(BiEncoderTokenizer):
    """:class:`.LightningIRTokenizer` for Col models."""

    config_class = ColConfig
    """Configuration class for the tokenizer."""

    def __init__(
        self,
        *args,
        query_length: int = 32,
        doc_length: int = 512,
        add_marker_tokens: bool = False,
        query_expansion: bool = False,
        attend_to_query_expanded_tokens: bool = False,
        doc_expansion: bool = False,
        attend_to_doc_expanded_tokens: bool = False,
        **kwargs,
    ):
        """Initializes a Col model's tokenizer. Encodes queries and documents separately. Optionally adds marker
        tokens to the encoded input sequences and expands queries and documents with mask tokens.

        Args:
            query_length (int): Maximum query length in number of tokens. Defaults to 32.
            doc_length (int): Maximum document length in number of tokens. Defaults to 512.
            add_marker_tokens (bool): Whether to add extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            query_expansion (bool): Whether to expand queries with mask tokens. Defaults to False.
            attend_to_query_expanded_tokens (bool): Whether to allow query tokens to attend to mask expanded query
                tokens. Defaults to False.
            doc_expansion (bool): Whether to expand documents with mask tokens. Defaults to False.
            attend_to_doc_expanded_tokens (bool): Whether to allow document tokens to attend to mask expanded document
                tokens. Defaults to False.
        Raises:
            ValueError: If `add_marker_tokens` is True and a non-supported tokenizer is used.
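
        Example:
            A hedged sketch of loading the tokenizer from a backbone checkpoint; the checkpoint name and keyword
            values are assumptions for illustration only::

                tokenizer = ColTokenizer.from_pretrained(
                    "bert-base-uncased", query_length=32, doc_length=256, add_marker_tokens=True
                )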
        """
        super().__init__(
            *args,
            query_length=query_length,
            doc_length=doc_length,
            query_expansion=query_expansion,
            attend_to_query_expanded_tokens=attend_to_query_expanded_tokens,
            doc_expansion=doc_expansion,
            attend_to_doc_expanded_tokens=attend_to_doc_expanded_tokens,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.query_expansion = query_expansion
        self.attend_to_query_expanded_tokens = attend_to_query_expanded_tokens
        self.doc_expansion = doc_expansion
        self.attend_to_doc_expanded_tokens = attend_to_doc_expanded_tokens

    def _expand(self, encoding: BatchEncoding, attend_to_expanded_tokens: bool) -> BatchEncoding:
        """Applies mask expansion to the input encoding by replacing padding tokens with mask tokens."""
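        # Illustrative example (token ids are hypothetical): with pad_token_id=0 and mask_token_id=103,
        # input_ids [[101, 2054, 102, 0, 0]] becomes [[101, 2054, 102, 103, 103]] after expansion.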
        input_ids = encoding["input_ids"]
        input_ids[input_ids == self.pad_token_id] = self.mask_token_id
        encoding["input_ids"] = input_ids
        if attend_to_expanded_tokens:
            # Optionally let all tokens attend to the expanded (former padding) positions.
            encoding["attention_mask"].fill_(1)
        return encoding