1"""Configuration, model, and embedding for COIL (Contextualized Inverted list) type models. Originally proposed in
2`COIL: Revisit Exact Lexical Match in Information Retrieval with Contextualized Inverted list \
3<https://arxiv.org/abs/2104.07186>`_."""
4
5from collections.abc import Sequence
6from dataclasses import dataclass
7from typing import Literal
8
9import torch
10from transformers import BatchEncoding
11
12from ...bi_encoder import (
13 BiEncoderEmbedding,
14 BiEncoderOutput,
15 MultiVectorBiEncoderConfig,
16 MultiVectorBiEncoderModel,
17 SingleVectorBiEncoderConfig,
18 SingleVectorBiEncoderModel,
19)
20
21
@dataclass
class CoilEmbedding(BiEncoderEmbedding):
    """Dataclass containing embeddings and the encoding for COIL models."""

    embeddings: torch.Tensor
    """Raw embeddings of the COIL model. Should not be used directly for scoring."""
    token_embeddings: torch.Tensor | None = None
    """Token embeddings of the COIL model."""
    cls_embeddings: torch.Tensor | None = None
    """Separate [CLS] token embeddings."""


@dataclass
class CoilOutput(BiEncoderOutput):
    """Dataclass containing the output of a COIL model."""

    query_embeddings: CoilEmbedding | None = None
    """Query embeddings generated by the model."""
    doc_embeddings: CoilEmbedding | None = None
    """Document embeddings generated by the model."""


class CoilConfig(MultiVectorBiEncoderConfig):
    """Configuration class for COIL models."""

    model_type = "coil"
    """Model type for COIL models."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        add_marker_tokens: bool = False,
        token_embedding_dim: int = 32,
        cls_embedding_dim: int = 768,
        projection: Literal["linear", "linear_no_bias"] = "linear",
        **kwargs,
    ) -> None:
        """A COIL model encodes queries and documents separately, and computes a similarity score by combining a
        [CLS] embedding similarity with an exact-match token similarity: each query token is compared only to
        document tokens with the same token id, and the maximum similarity per query token is aggregated.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
            similarity_function (Literal["cosine", "dot"]): Similarity function to compute scores between query and
                document embeddings. Defaults to "dot".
            add_marker_tokens (bool): Whether to add extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            token_embedding_dim (int): The output embedding dimension for tokens. Defaults to 32.
            cls_embedding_dim (int): The output embedding dimension for the [CLS] token. Defaults to 768.
            projection (Literal["linear", "linear_no_bias"], optional): Whether and how to project the embeddings.
                Defaults to "linear".
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.projection = projection
        self.token_embedding_dim = token_embedding_dim
        self.cls_embedding_dim = cls_embedding_dim


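# NOTE: The following is a minimal, self-contained sketch (not part of the library API) of how a COIL
# relevance score is assembled: a [CLS] embedding similarity is added to an exact-match token similarity
# in which a query token is only compared to document tokens with the same token id. The function name,
# shapes, and the plain dot-product similarity are illustrative assumptions; the actual scoring is
# implemented in :meth:`CoilModel.score` below.
def _coil_score_sketch(
    query_cls: torch.Tensor,  # (cls_embedding_dim,)
    doc_cls: torch.Tensor,  # (cls_embedding_dim,)
    query_tokens: torch.Tensor,  # (num_query_tokens, token_embedding_dim)
    doc_tokens: torch.Tensor,  # (num_doc_tokens, token_embedding_dim)
    query_ids: torch.Tensor,  # (num_query_tokens,) token ids
    doc_ids: torch.Tensor,  # (num_doc_tokens,) token ids
) -> torch.Tensor:
    cls_score = query_cls @ doc_cls  # dot product of the [CLS] embeddings
    token_similarity = query_tokens @ doc_tokens.T  # (num_query_tokens, num_doc_tokens)
    exact_match = (query_ids[:, None] == doc_ids[None, :]).to(token_similarity)  # 1 where token ids match
    token_similarity = token_similarity * exact_match  # zero out pairs with different token ids
    token_score = token_similarity.max(dim=-1).values.sum()  # max over doc tokens, summed over query tokens
    return cls_score + token_score

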
class CoilModel(MultiVectorBiEncoderModel):
    """Multi-vector COIL model. See :class:`.CoilConfig` for configuration options."""

    config_class = CoilConfig
    """Configuration class for COIL models."""

    def __init__(self, config: CoilConfig, *args, **kwargs) -> None:
        """Initializes a COIL model given a :class:`.CoilConfig` configuration.

        Args:
            config (CoilConfig): Configuration for the COIL model.
        """
        super().__init__(config, *args, **kwargs)
        self.config: CoilConfig
        self.token_projection = torch.nn.Linear(
            self.config.hidden_size,
            self.config.token_embedding_dim,
            bias="no_bias" not in self.config.projection,
        )
        self.cls_projection = torch.nn.Linear(
            self.config.hidden_size,
            self.config.cls_embedding_dim,
            bias="no_bias" not in self.config.projection,
        )

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> CoilEmbedding:
        """Encodes a batch of tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            CoilEmbedding: Embeddings and scoring mask.
        """
        embeddings = self._backbone_forward(**encoding).last_hidden_state

        cls_embeddings = self.cls_projection(embeddings[:, [0]])
        token_embeddings = self.token_projection(embeddings[:, 1:])

        scoring_mask = self.scoring_mask(encoding, input_type)
        return CoilEmbedding(
            embeddings, scoring_mask, encoding, cls_embeddings=cls_embeddings, token_embeddings=token_embeddings
        )

    def score(
        self,
        output: CoilOutput,
        num_docs: Sequence[int] | int | None = None,
    ) -> CoilOutput:
        """Compute relevance scores between queries and documents.

        Args:
            output (CoilOutput): Output containing the query and document embeddings, i.e., [CLS] embeddings,
                token embeddings, and scoring masks.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence of
                integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the
                number of documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
                the number of documents by dividing the number of documents by the number of queries. Defaults to None.
        Returns:
            CoilOutput: Output with relevance scores."""
        query_embeddings = output.query_embeddings
        doc_embeddings = output.doc_embeddings
        if query_embeddings is None or doc_embeddings is None:
            raise ValueError("Query and document embeddings must be provided for scoring.")
        if query_embeddings.scoring_mask is None or doc_embeddings.scoring_mask is None:
            raise ValueError("Scoring masks expected for scoring multi-vector embeddings")
        if (
            query_embeddings.cls_embeddings is None
            or doc_embeddings.cls_embeddings is None
            or query_embeddings.token_embeddings is None
            or doc_embeddings.token_embeddings is None
        ):
            raise ValueError("COIL embeddings must contain cls_embeddings and token_embeddings")

        cls_scores = self.compute_similarity(
            BiEncoderEmbedding(query_embeddings.cls_embeddings),
            BiEncoderEmbedding(doc_embeddings.cls_embeddings),
            num_docs,
        ).view(-1)

        token_similarity = self.compute_similarity(
            BiEncoderEmbedding(query_embeddings.token_embeddings),
            BiEncoderEmbedding(doc_embeddings.token_embeddings),
            num_docs,
        )
        num_docs_t = self._parse_num_docs(
            query_embeddings.embeddings.shape[0], doc_embeddings.embeddings.shape[0], num_docs, query_embeddings.device
        )
        query = query_embeddings.encoding.input_ids.repeat_interleave(num_docs_t, 0)[:, 1:]
        docs = doc_embeddings.encoding.input_ids[:, 1:]
        mask = (query[:, :, None] == docs[:, None, :]).to(token_similarity)
        token_similarity = token_similarity * mask
        token_scores = self.aggregate_similarity(
            token_similarity, query_embeddings.scoring_mask[:, 1:], doc_embeddings.scoring_mask[:, 1:], num_docs
        )

        output.scores = cls_scores + token_scores
        return output


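# NOTE: A small, illustrative sketch (not part of the library API) of the ``num_docs`` convention used by
# :meth:`CoilModel.score` above: each query is repeated once per document belonging to it so that query and
# document tensors align along the batch dimension. All values are toy assumptions.
def _num_docs_sketch() -> torch.Tensor:
    query_ids = torch.tensor([[101, 7592], [101, 2088]])  # token ids of two toy queries
    num_docs = torch.tensor([2, 1])  # the first query is scored against 2 documents, the second against 1
    # -> tensor([[101, 7592], [101, 7592], [101, 2088]]), aligned with a batch of 3 documents
    return query_ids.repeat_interleave(num_docs, dim=0)

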
class UniCoilConfig(SingleVectorBiEncoderConfig):
    """Configuration class for UniCOIL models."""

    model_type = "unicoil"
    """Model type for UniCOIL models."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        projection: Literal["linear", "linear_no_bias"] = "linear",
        **kwargs,
    ) -> None:
        """A UniCOIL model encodes queries and documents separately into sparse vectors of non-negative token weights
        over the vocabulary and computes a similarity score between these sparse vectors.

        Args:
            query_length (int | None): Maximum query length in number of tokens. Defaults to 32.
            doc_length (int | None): Maximum document length in number of tokens. Defaults to 512.
            similarity_function (Literal["cosine", "dot"]): Similarity function to compute scores between query and
                document embeddings. Defaults to "dot".
            projection (Literal["linear", "linear_no_bias"], optional): Whether and how to project the embeddings.
                Defaults to "linear".
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            **kwargs,
        )
        self.projection = projection

    @property
    def embedding_dim(self) -> int:
        vocab_size = getattr(self, "vocab_size", None)
        if vocab_size is None:
            raise ValueError("Unable to determine embedding dimension.")
        return vocab_size

    @embedding_dim.setter
    def embedding_dim(self, value: int) -> None:
        pass


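# NOTE: A minimal sketch (illustrative only, not part of the library API) of the UniCOIL encoding implemented
# in :meth:`UniCoilModel.encode` below: each token receives a single non-negative weight, and the weights are
# scattered into a vocabulary-sized vector, so two texts can be scored with a plain dot product over shared
# token ids. The vocabulary size and all values are toy assumptions.
def _unicoil_encoding_sketch() -> torch.Tensor:
    vocab_size = 10  # toy vocabulary
    input_ids = torch.tensor([[3, 5, 8]])  # token ids of one toy text
    token_weights = torch.relu(torch.tensor([[0.7, -0.2, 1.1]]))  # negative weights are clipped to zero
    sparse_embedding = torch.zeros(input_ids.shape[0], vocab_size)
    # weights land at vocabulary positions 3, 5, and 8 (the clipped weight at position 5 stays zero)
    return sparse_embedding.scatter(1, input_ids, token_weights)

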
class UniCoilModel(SingleVectorBiEncoderModel):
    """Single-vector UniCOIL model. See :class:`.UniCoilConfig` for configuration options."""

    config_class = UniCoilConfig
    """Configuration class for UniCOIL models."""

    def __init__(self, config: UniCoilConfig, *args, **kwargs) -> None:
        """Initializes a UniCOIL model given a :class:`.UniCoilConfig` configuration.

        Args:
            config (UniCoilConfig): Configuration for the UniCOIL model.
        """
        super().__init__(config, *args, **kwargs)
        self.config: UniCoilConfig
        self.token_projection = torch.nn.Linear(
            self.config.hidden_size, 1, bias="no_bias" not in self.config.projection
        )

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes a batch of tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask.
        """
        contextualized_embeddings = self._backbone_forward(**encoding).last_hidden_state

        token_weights = self.token_projection(contextualized_embeddings).squeeze(-1)
        if encoding["attention_mask"] is not None:
            token_weights = token_weights.masked_fill(~(encoding["attention_mask"].bool()), 0)
        token_weights = torch.relu(token_weights)
        embeddings = torch.zeros(encoding.input_ids.shape[0], self.config.vocab_size, device=token_weights.device)
        embeddings = embeddings.scatter(1, encoding.input_ids, token_weights)
        return BiEncoderEmbedding(embeddings[:, None], None, encoding)