1"""Configuration, model, and embedding for COIL (Contextualized Inverted list) type models.
2
3COIL is an information retrieval model that combines the speed of traditional keyword search with the deep
4understanding of neural networks. It generates context-aware vector representations for every word in a document and
5stores them in a standard inverted index. During a search, the model only calculates similarity scores for the exact
6words that appear in both the query and the document.
7
8Originally proposed in
9`COIL: Revisit Exact Lexical Match in Information Retrieval with Contextualized Inverted list \
10<https://arxiv.org/abs/2104.07186>`_."""
11
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Literal

import torch
from transformers import BatchEncoding

from ...bi_encoder import (
    BiEncoderEmbedding,
    BiEncoderOutput,
    MultiVectorBiEncoderConfig,
    MultiVectorBiEncoderModel,
    SingleVectorBiEncoderConfig,
    SingleVectorBiEncoderModel,
)
27
28
@dataclass
class CoilEmbedding(BiEncoderEmbedding):
    """Dataclass containing embeddings and the encoding for COIL models."""

    embeddings: torch.Tensor
    """Raw embeddings of the COIL model. Should not be used directly for scoring."""
    token_embeddings: torch.Tensor | None = None
    """Projected token embeddings of the COIL model (one vector per non-[CLS] token)."""
    cls_embeddings: torch.Tensor | None = None
    """Separate [CLS] token embeddings."""
39
40
@dataclass
class CoilOutput(BiEncoderOutput):
    """Dataclass containing the output of a COIL model."""

    query_embeddings: CoilEmbedding | None = None
    """Query embeddings generated by the model."""
    doc_embeddings: CoilEmbedding | None = None
    """Document embeddings generated by the model."""
49
50
class CoilConfig(MultiVectorBiEncoderConfig):
    """Configuration class for COIL models."""

    model_type = "coil"
    """Model type for COIL models."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        add_marker_tokens: bool = False,
        token_embedding_dim: int = 32,
        cls_embedding_dim: int = 768,
        projection: Literal["linear", "linear_no_bias"] = "linear",
        **kwargs,
    ) -> None:
        """A COIL model encodes queries and documents separately and computes a similarity score by combining a
        [CLS] embedding similarity with the similarities of contextualized token embeddings of exactly matching
        query and document tokens.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
            similarity_function (Literal["cosine", "dot"]): Similarity function to compute scores between query and
                document embeddings. Defaults to "dot".
            add_marker_tokens (bool): Whether to add extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            token_embedding_dim (int): The output embedding dimension for tokens. Defaults to 32.
            cls_embedding_dim (int): The output embedding dimension for the [CLS] token. Defaults to 768.
            projection (Literal["linear", "linear_no_bias"]): Whether and how to project the embeddings.
                Defaults to "linear".
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.projection = projection
        self.token_embedding_dim = token_embedding_dim
        self.cls_embedding_dim = cls_embedding_dim
93
94
class CoilModel(MultiVectorBiEncoderModel):
    """Multi-vector COIL model. See :class:`.CoilConfig` for configuration options."""

    config_class = CoilConfig
    """Configuration class for COIL models."""

    def __init__(self, config: CoilConfig, *args, **kwargs) -> None:
        """Initializes a COIL model given a :class:`.CoilConfig` configuration.

        Args:
            config (CoilConfig): Configuration for the COIL model.
        """
        super().__init__(config, *args, **kwargs)
        self.config: CoilConfig
        # Separate projection heads: one for the contextualized token embeddings and one for the [CLS] token.
        use_bias = "no_bias" not in self.config.projection
        self.token_projection = torch.nn.Linear(
            self.config.hidden_size,
            self.config.token_embedding_dim,
            bias=use_bias,
        )
        self.cls_projection = torch.nn.Linear(
            self.config.hidden_size,
            self.config.cls_embedding_dim,
            bias=use_bias,
        )

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> CoilEmbedding:
        """Encodes a batched tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            CoilEmbedding: Raw, [CLS], and token embeddings together with the scoring mask and encoding.
        """
        embeddings = self._backbone_forward(**encoding).last_hidden_state

        # Position 0 holds the [CLS] token; `[:, [0]]` keeps the sequence dimension for later similarity calls.
        cls_embeddings = self.cls_projection(embeddings[:, [0]])
        token_embeddings = self.token_projection(embeddings[:, 1:])

        scoring_mask = self.scoring_mask(encoding, input_type)
        return CoilEmbedding(
            embeddings,
            scoring_mask,
            encoding,
            cls_embeddings=cls_embeddings,
            token_embeddings=token_embeddings,
        )

    def score(
        self,
        output: CoilOutput,
        num_docs: Sequence[int] | int | None = None,
    ) -> CoilOutput:
        """Compute relevance scores between queries and documents.

        Args:
            output (CoilOutput): Output containing query and document :class:`.CoilEmbedding`\\ s
                ([CLS] embeddings, token embeddings, and scoring masks).
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence of
                integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the
                number of documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
                the number of documents by dividing the number of documents by the number of queries. Defaults to None.
        Returns:
            CoilOutput: The input output object with the `scores` attribute set to the relevance scores.
        Raises:
            ValueError: If query or document embeddings, their scoring masks, or their [CLS] / token embeddings
                are missing.
        """
        query_embeddings = output.query_embeddings
        doc_embeddings = output.doc_embeddings
        if query_embeddings is None or doc_embeddings is None:
            raise ValueError("Query and document embeddings must be provided for scoring.")
        if query_embeddings.scoring_mask is None or doc_embeddings.scoring_mask is None:
            raise ValueError("Scoring masks expected for scoring multi-vector embeddings")
        if (
            query_embeddings.cls_embeddings is None
            or doc_embeddings.cls_embeddings is None
            or query_embeddings.token_embeddings is None
            or doc_embeddings.token_embeddings is None
        ):
            raise ValueError("COIL embeddings must contain cls_embeddings and token_embeddings")

        # [CLS]-to-[CLS] similarity, one score per query-document pair.
        cls_scores = self.compute_similarity(
            BiEncoderEmbedding(query_embeddings.cls_embeddings),
            BiEncoderEmbedding(doc_embeddings.cls_embeddings),
            num_docs,
        ).view(-1)

        token_similarity = self.compute_similarity(
            BiEncoderEmbedding(query_embeddings.token_embeddings),
            BiEncoderEmbedding(doc_embeddings.token_embeddings),
            num_docs,
        )
        num_docs_t = self._parse_num_docs(
            query_embeddings.embeddings.shape[0],
            doc_embeddings.embeddings.shape[0],
            num_docs,
            query_embeddings.device,
        )
        # Drop the [CLS] position and zero out similarities of query/doc token pairs whose
        # token ids do not match exactly (COIL only scores exact lexical matches).
        query = query_embeddings.encoding.input_ids.repeat_interleave(num_docs_t, 0)[:, 1:]
        docs = doc_embeddings.encoding.input_ids[:, 1:]
        mask = (query[:, :, None] == docs[:, None, :]).to(token_similarity)
        token_similarity = token_similarity * mask
        token_scores = self.aggregate_similarity(
            token_similarity,
            query_embeddings.scoring_mask[:, 1:],
            doc_embeddings.scoring_mask[:, 1:],
            num_docs,
        )

        output.scores = cls_scores + token_scores
        return output
204
205
class UniCoilConfig(SingleVectorBiEncoderConfig):
    """Configuration class for UniCOIL models."""

    model_type = "unicoil"
    """Model type for UniCOIL models."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        projection: Literal["linear", "linear_no_bias"] = "linear",
        **kwargs,
    ) -> None:
        """A UniCOIL model encodes queries and documents separately, and computes a similarity score using the maximum
        similarity of token embeddings between query and document.

        Args:
            query_length (int | None): Maximum query length in number of tokens. Defaults to 32.
            doc_length (int | None): Maximum document length in number of tokens. Defaults to 512.
            similarity_function (Literal["cosine", "dot"]): Similarity function to compute scores between query and
                document embeddings. Defaults to "dot".
            projection (Literal["linear", "linear_no_bias"]): Whether and how to project the embeddings.
                Defaults to "linear".
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            **kwargs,
        )
        self.projection = projection

    @property
    def embedding_dim(self) -> int:
        # UniCOIL embeddings live in vocabulary space, so the embedding dimension is fixed
        # by the backbone's vocabulary size rather than being a free hyperparameter.
        vocab_size = getattr(self, "vocab_size", None)
        if vocab_size is None:
            raise ValueError("Unable to determine embedding dimension.")
        return vocab_size

    @embedding_dim.setter
    def embedding_dim(self, value: int) -> None:
        # Intentionally a no-op: the embedding dimension is derived from `vocab_size`
        # and must not be overridden (the base config may still try to assign it).
        pass
249
250
class UniCoilModel(SingleVectorBiEncoderModel):
    """Single-vector UniCOIL model. See :class:`.UniCoilConfig` for configuration options."""

    config_class = UniCoilConfig
    """Configuration class for UniCOIL models."""

    def __init__(self, config: UniCoilConfig, *args, **kwargs) -> None:
        """Initializes a UniCOIL model given a :class:`.UniCoilConfig` configuration.

        Args:
            config (UniCoilConfig): Configuration for the UniCOIL model.
        """
        super().__init__(config, *args, **kwargs)
        self.config: UniCoilConfig
        # Projects each contextualized token embedding to a single scalar importance weight.
        self.token_projection = torch.nn.Linear(
            self.config.hidden_size, 1, bias="no_bias" not in self.config.projection
        )

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes a batched tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            BiEncoderEmbedding: A single vocabulary-sized vector per sequence holding non-negative
            token importance weights scattered to the corresponding token ids (no scoring mask).
        """
        contextualized_embeddings = self._backbone_forward(**encoding).last_hidden_state

        token_weights = self.token_projection(contextualized_embeddings).squeeze(-1)
        # Zero out weights of padding tokens before the non-negativity clamp.
        if encoding["attention_mask"] is not None:
            token_weights = token_weights.masked_fill(~(encoding["attention_mask"].bool()), 0)
        token_weights = torch.relu(token_weights)
        # Scatter the per-token weights into a sparse-style vocabulary-sized vector;
        # repeated token ids keep the last scattered weight.
        embeddings = torch.zeros(
            encoding.input_ids.shape[0],
            self.config.vocab_size,
            device=token_weights.device,
        )
        embeddings = embeddings.scatter(1, encoding.input_ids, token_weights)
        return BiEncoderEmbedding(embeddings[:, None], None, encoding)