1"""Configuration, model, and tokenizer for Col (Contextualized Late Interaction) type models. Originally proposed in
2`ColBERT: Efficient and Effective Passage Search via Contextualized Late Interaction over BERT \
3<https://dl.acm.org/doi/abs/10.1145/3397271.3401075>`_ as the ColBERT model. This implementation generalizes the model
4to work with any transformer backbone model.
5"""
6
7from typing import Literal, Sequence
8
9import torch
10from transformers import BatchEncoding
11
12from ...bi_encoder import BiEncoderEmbedding, BiEncoderTokenizer, MultiVectorBiEncoderConfig, MultiVectorBiEncoderModel
13
14
class ColConfig(MultiVectorBiEncoderConfig):
    """Configuration class for a Col model."""

    model_type = "col"
    """Model type for a Col model."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalization: Literal["l2"] | None = None,
        add_marker_tokens: bool = False,
        query_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        doc_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        query_aggregation_function: Literal["sum", "mean", "max"] = "sum",
        doc_aggregation_function: Literal["sum", "mean", "max"] = "max",
        embedding_dim: int = 128,
        projection: Literal["linear", "linear_no_bias"] = "linear",
        query_expansion: bool = False,
        attend_to_query_expanded_tokens: bool = False,
        doc_expansion: bool = False,
        attend_to_doc_expanded_tokens: bool = False,
        **kwargs,
    ):
        """A Col model encodes queries and documents separately and computes a late interaction score between the query
        and document embeddings. The aggregation behavior of the late-interaction scoring can be parameterized with the
        `query_aggregation_function` and `doc_aggregation_function` arguments. The dimensionality of the token
        embeddings is down-projected using a linear layer. Queries and documents can optionally be expanded with mask
        tokens, and a set of tokens can optionally be ignored during scoring.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None, does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None, does not truncate.
                Defaults to 512.
            similarity_function (Literal["cosine", "dot"]): Similarity function to compute scores between query and
                document embeddings. Defaults to "dot".
            normalization (Literal["l2"] | None): Whether to normalize query and document embeddings. Defaults to None.
            add_marker_tokens (bool): Whether to add extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            query_mask_scoring_tokens (Sequence[str] | Literal["punctuation"] | None): Whether and which query tokens
                to ignore during scoring. Defaults to None.
            doc_mask_scoring_tokens (Sequence[str] | Literal["punctuation"] | None): Whether and which document tokens
                to ignore during scoring. Defaults to None.
            query_aggregation_function (Literal["sum", "mean", "max"]): How to aggregate similarity scores over query
                tokens. Defaults to "sum".
            doc_aggregation_function (Literal["sum", "mean", "max"]): How to aggregate similarity scores over document
                tokens. Defaults to "max".
            embedding_dim (int): The output embedding dimension. Defaults to 128.
            projection (Literal["linear", "linear_no_bias"]): Whether and how to project the output embeddings.
                Defaults to "linear". If set to "linear_no_bias", the projection layer will not have a bias term.
            query_expansion (bool): Whether to expand queries with mask tokens. Defaults to False.
            attend_to_query_expanded_tokens (bool): Whether to allow query tokens to attend to mask-expanded query
                tokens. Defaults to False.
            doc_expansion (bool): Whether to expand documents with mask tokens. Defaults to False.
            attend_to_doc_expanded_tokens (bool): Whether to allow document tokens to attend to mask-expanded document
                tokens. Defaults to False.
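
        Example:
            A minimal configuration sketch; the values below are illustrative, not tuned recommendations::

                config = ColConfig(query_length=32, doc_length=256, embedding_dim=128, normalization="l2")
                # Conceptually, late-interaction scoring then reduces per-token similarities roughly as follows
                # (simplified sketch; the actual scoring is implemented outside this module):
                #   sim = query_embeddings @ doc_embeddings.transpose(-1, -2)  # one similarity per token pair
                #   score = sim.max(dim=-1).values.sum(dim=-1)  # "max" over doc tokens, "sum" over query tokens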
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            normalization=normalization,
            add_marker_tokens=add_marker_tokens,
            query_mask_scoring_tokens=query_mask_scoring_tokens,
            doc_mask_scoring_tokens=doc_mask_scoring_tokens,
            query_aggregation_function=query_aggregation_function,
            doc_aggregation_function=doc_aggregation_function,
            **kwargs,
        )
        self.embedding_dim = embedding_dim
        self.projection = projection
        self.query_expansion = query_expansion
        self.attend_to_query_expanded_tokens = attend_to_query_expanded_tokens
        self.doc_expansion = doc_expansion
        self.attend_to_doc_expanded_tokens = attend_to_doc_expanded_tokens


class ColModel(MultiVectorBiEncoderModel):
    """Multi-vector late-interaction Col model. See :class:`.ColConfig` for configuration options."""

    config_class = ColConfig
    """Configuration class for the Col model."""

    def __init__(self, config: ColConfig, *args, **kwargs) -> None:
        """Initializes a Col model given a :class:`.ColConfig`.

        Args:
            config (ColConfig): Configuration for the Col model.
        Raises:
            ValueError: If the embedding dimension is not specified in the configuration.
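
        Example:
            A minimal loading sketch; the checkpoint name below is illustrative and assumes a Col-style checkpoint
            matching this configuration is available::

                # illustrative checkpoint; substitute any compatible Col / ColBERT checkpoint
                model = ColModel.from_pretrained("colbert-ir/colbertv2.0")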
        """
        super().__init__(config, *args, **kwargs)
        if config.embedding_dim is None:
            raise ValueError("Embedding dimension must be specified in the configuration.")
        # Down-project the backbone's hidden states; drop the bias term for the "linear_no_bias" projection variant.
        self.projection = torch.nn.Linear(
            config.hidden_size, config.embedding_dim, bias="no_bias" not in config.projection
        )

    def scoring_mask(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> torch.Tensor:
        """Computes a scoring mask for batched tokenized text sequences. The mask is used in the scoring function to
        mask out vectors during scoring.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            torch.Tensor: Boolean scoring mask.
        """
        input_ids = encoding["input_ids"]
        attention_mask = encoding["attention_mask"]
        scoring_mask = attention_mask
        expansion = getattr(self.config, f"{input_type}_expansion")
        # If the input was expanded with mask tokens (or no attention mask is available), score all positions.
        if expansion or scoring_mask is None:
            scoring_mask = torch.ones_like(input_ids, dtype=torch.bool)
        scoring_mask = scoring_mask.bool()
        # Mask out tokens that are configured to be ignored during scoring (e.g., punctuation).
        mask_scoring_input_ids = getattr(self, f"{input_type}_mask_scoring_input_ids")
        if mask_scoring_input_ids is not None:
            ignore_mask = input_ids[..., None].eq(mask_scoring_input_ids.to(input_ids.device)).any(-1)
            scoring_mask = scoring_mask & ~ignore_mask
        return scoring_mask

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes batched tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask.
        """
        # Contextualize the tokens with the backbone and down-project each token embedding to the configured dimension.
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self.projection(embeddings)
        if self.config.normalization == "l2":
            embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
        scoring_mask = self.scoring_mask(encoding, input_type)
        return BiEncoderEmbedding(embeddings, scoring_mask, encoding)


class ColTokenizer(BiEncoderTokenizer):
    """:class:`.LightningIRTokenizer` for Col models."""

    config_class = ColConfig
    """Configuration class for the tokenizer."""

    def __init__(
        self,
        *args,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        add_marker_tokens: bool = False,
        query_expansion: bool = False,
        attend_to_query_expanded_tokens: bool = False,
        doc_expansion: bool = False,
        attend_to_doc_expanded_tokens: bool = False,
        **kwargs,
    ):
        """Initializes a Col model's tokenizer. Encodes queries and documents separately. Optionally adds marker tokens
        to encoded input sequences and expands queries and documents with mask tokens.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None, does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None, does not truncate.
                Defaults to 512.
            add_marker_tokens (bool): Whether to add extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            query_expansion (bool): Whether to expand queries with mask tokens. Defaults to False.
            attend_to_query_expanded_tokens (bool): Whether to allow query tokens to attend to mask-expanded query
                tokens. Defaults to False.
            doc_expansion (bool): Whether to expand documents with mask tokens. Defaults to False.
            attend_to_doc_expanded_tokens (bool): Whether to allow document tokens to attend to mask-expanded document
                tokens. Defaults to False.
        Raises:
            ValueError: If `add_marker_tokens` is True and a non-supported tokenizer is used.
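
        Example:
            A minimal usage sketch; the backbone name below is illustrative and assumes a checkpoint with a
            compatible tokenizer::

                # illustrative backbone; any checkpoint with a supported tokenizer should work
                tokenizer = ColTokenizer.from_pretrained("bert-base-uncased", query_length=32, doc_length=256)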
        """
        super().__init__(
            *args,
            query_length=query_length,
            doc_length=doc_length,
            query_expansion=query_expansion,
            attend_to_query_expanded_tokens=attend_to_query_expanded_tokens,
            doc_expansion=doc_expansion,
            attend_to_doc_expanded_tokens=attend_to_doc_expanded_tokens,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.query_expansion = query_expansion
        self.attend_to_query_expanded_tokens = attend_to_query_expanded_tokens
        self.doc_expansion = doc_expansion
        self.attend_to_doc_expanded_tokens = attend_to_doc_expanded_tokens

    def _expand(self, encoding: BatchEncoding, attend_to_expanded_tokens: bool) -> BatchEncoding:
        """Applies mask expansion to the input encoding."""
        input_ids = encoding["input_ids"]
        # Replace padding tokens with mask tokens to expand the sequence to the full query / document length.
        input_ids[input_ids == self.pad_token_id] = self.mask_token_id
        encoding["input_ids"] = input_ids
        if attend_to_expanded_tokens:
            # Let all tokens attend to the expansion (former padding) positions as well.
            encoding["attention_mask"].fill_(1)
        return encoding