1"""Configuration, model, and embedding for COIL (Contextualized Inverted List) type models. Originally proposed in
2`COIL: Revisit Exact Lexical Match in Information Retrieval with Contextualized Inverted List \
3<https://arxiv.org/abs/2104.07186>`_."""
4
5from dataclasses import dataclass
6from typing import Literal, Sequence
7
8import torch
9from transformers import BatchEncoding
10
11from ...bi_encoder import (
12 BiEncoderEmbedding,
13 BiEncoderOutput,
14 MultiVectorBiEncoderConfig,
15 MultiVectorBiEncoderModel,
16 SingleVectorBiEncoderConfig,
17 SingleVectorBiEncoderModel,
18)
19
20
[docs]
21@dataclass
22class CoilEmbedding(BiEncoderEmbedding):
23 """Dataclass containing embeddings and the encoding for COIL models."""
24
25 embeddings: torch.Tensor
26 """Raw embeddings of the COIL model. Should not be used directly for scoring."""
27 token_embeddings: torch.Tensor | None = None
28 """Token embeddings of the COIL model."""
29 cls_embeddings: torch.Tensor | None = None
30 """Separate [CLS] token embeddings."""
31
32
[docs]
33@dataclass
34class CoilOutput(BiEncoderOutput):
35 """Dataclass containing the output of a COIL model."""
36
37 query_embeddings: CoilEmbedding | None = None
38 """Query embeddings generated by the model."""
39 doc_embeddings: CoilEmbedding | None = None
40 """Document embeddings generated by the model."""
41
42
[docs]
43class CoilConfig(MultiVectorBiEncoderConfig):
44 """Configuration class for COIL models."""
45
46 model_type = "coil"
47 """Model type for COIL models."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        add_marker_tokens: bool = False,
        token_embedding_dim: int = 32,
        cls_embedding_dim: int = 768,
        projection: Literal["linear", "linear_no_bias"] = "linear",
        **kwargs,
    ) -> None:
60 """A COIL model encodes queries and documents separately, and computes a similarity score using the maximum
61 similarity ...
62
63 Args:
64 query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
65 doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
66 similarity_function (Literal["cosine", "dot"]): Similarity function to compute scores between query and
67 document embeddings. Defaults to "dot".
68 add_marker_tokens (bool): Whether to add extra marker tokens [Q] / [D] to queries / documents.
69 Defaults to False.
70 token_embedding_dim (int | None): The output embedding dimension for tokens. Defaults to 32.
71 cls_embedding_dim (int | None): The output embedding dimension for the [CLS] token. Defaults to 768.
72 projection (Literal["linear", "linear_no_bias"], optional): Whether and how to project the embeddings.
73 Defaults to "linear".
74 """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.projection = projection
        self.token_embedding_dim = token_embedding_dim
        self.cls_embedding_dim = cls_embedding_dim
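
# Illustrative usage sketch (kept as a comment so importing this module has no side effects; the
# values below are simply the defaults, not a recommendation): a COIL configuration projects each
# contextualized token vector down to 32 dimensions and keeps a separate 768-dimensional [CLS] vector.
#
#   config = CoilConfig(token_embedding_dim=32, cls_embedding_dim=768, projection="linear")
#   assert config.model_type == "coil"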


class CoilModel(MultiVectorBiEncoderModel):
    """Multi-vector COIL model. See :class:`.CoilConfig` for configuration options."""

    config_class = CoilConfig
    """Configuration class for COIL models."""

    def __init__(self, config: CoilConfig, *args, **kwargs) -> None:
        """Initializes a COIL model given a :class:`.CoilConfig` configuration.

        Args:
            config (CoilConfig): Configuration for the COIL model.
        """
        super().__init__(config, *args, **kwargs)
        self.config: CoilConfig
        self.token_projection = torch.nn.Linear(
            self.config.hidden_size,
            self.config.token_embedding_dim,
            bias="no_bias" not in self.config.projection,
        )
        self.cls_projection = torch.nn.Linear(
            self.config.hidden_size,
            self.config.cls_embedding_dim,
            bias="no_bias" not in self.config.projection,
        )

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> CoilEmbedding:
        """Encodes batched tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            CoilEmbedding: Embeddings and scoring mask.
        """
        embeddings = self._backbone_forward(**encoding).last_hidden_state

        # project the [CLS] vector and the remaining token vectors separately
        cls_embeddings = self.cls_projection(embeddings[:, [0]])
        token_embeddings = self.token_projection(embeddings[:, 1:])

        scoring_mask = self.scoring_mask(encoding, input_type)
        return CoilEmbedding(
            embeddings, scoring_mask, encoding, cls_embeddings=cls_embeddings, token_embeddings=token_embeddings
        )
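
    # Shape sketch for ``encode`` (illustrative only; ``b`` = batch size, ``s`` = tokenized sequence
    # length, dimensions taken from the default :class:`.CoilConfig`):
    #
    #   out = model.encode(encoding, "doc")   # ``encoding`` is a hypothetical pre-tokenized batch
    #   out.cls_embeddings.shape              # (b, 1, cls_embedding_dim)       -- projected [CLS]
    #   out.token_embeddings.shape            # (b, s - 1, token_embedding_dim) -- projected tokens
    #   out.embeddings.shape                  # (b, s, hidden_size)             -- raw, not scored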

    def score(
        self,
        output: CoilOutput,
        num_docs: Sequence[int] | int | None = None,
    ) -> CoilOutput:
        """Compute relevance scores between queries and documents.

        Args:
            output (CoilOutput): Output containing the query and document embeddings to score.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence of
                integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the
                number of documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
                the number of documents by dividing the number of documents by the number of queries. Defaults to None.
        Returns:
            CoilOutput: Output with the relevance scores set."""
        query_embeddings = output.query_embeddings
        doc_embeddings = output.doc_embeddings
        if query_embeddings is None or doc_embeddings is None:
            raise ValueError("Query and document embeddings must be provided for scoring.")
        if query_embeddings.scoring_mask is None or doc_embeddings.scoring_mask is None:
            raise ValueError("Scoring masks expected for scoring multi-vector embeddings")
        if (
            query_embeddings.cls_embeddings is None
            or doc_embeddings.cls_embeddings is None
            or query_embeddings.token_embeddings is None
            or doc_embeddings.token_embeddings is None
        ):
            raise ValueError("COIL embeddings must contain cls_embeddings and token_embeddings")

        # similarity between the [CLS] embeddings of queries and documents
        cls_scores = self.compute_similarity(
            BiEncoderEmbedding(query_embeddings.cls_embeddings),
            BiEncoderEmbedding(doc_embeddings.cls_embeddings),
            num_docs,
        ).view(-1)

        # token-level similarities, restricted to exact lexical matches below
        token_similarity = self.compute_similarity(
            BiEncoderEmbedding(query_embeddings.token_embeddings),
            BiEncoderEmbedding(doc_embeddings.token_embeddings),
            num_docs,
        )
        num_docs_t = self._parse_num_docs(
            query_embeddings.embeddings.shape[0], doc_embeddings.embeddings.shape[0], num_docs, query_embeddings.device
        )
        # mask out similarities between tokens with different token ids ([CLS] is excluded via [:, 1:])
        query = query_embeddings.encoding.input_ids.repeat_interleave(num_docs_t, 0)[:, 1:]
        docs = doc_embeddings.encoding.input_ids[:, 1:]
        mask = (query[:, :, None] == docs[:, None, :]).to(token_similarity)
        token_similarity = token_similarity * mask
        token_scores = self.aggregate_similarity(
            token_similarity, query_embeddings.scoring_mask[:, 1:], doc_embeddings.scoring_mask[:, 1:], num_docs
        )

        output.scores = cls_scores + token_scores
        return output
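
    # Minimal worked example of the exact-match mask above (hypothetical token ids, unrelated to any
    # particular tokenizer): a query/document token pair only contributes to ``token_scores`` when the
    # two token ids are identical, which is COIL's lexical-match constraint.
    #
    #   query input_ids (without [CLS]): [2054, 2003]
    #   doc input_ids   (without [CLS]): [2003, 1037, 2054]
    #   mask = (query[:, :, None] == docs[:, None, :]):
    #       [[0, 0, 1],
    #        [1, 0, 0]]
    #   token_similarity is multiplied by this mask before being aggregated into ``token_scores``.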


class UniCoilConfig(SingleVectorBiEncoderConfig):
    """Configuration class for UniCOIL models."""

    model_type = "unicoil"
    """Model type for UniCOIL models."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        projection: Literal["linear", "linear_no_bias"] = "linear",
        **kwargs,
    ) -> None:
202 """A UniCOIL model encodes queries and documents separately, and computes a similarity score using the maximum
203 similarity of token embeddings between query and document.
204
205 Args:
206 query_length (int | None): Maximum query length in number of tokens. Defaults to 32.
207 doc_length (int | None): Maximum document length in number of tokens. Defaults to 512.
208 similarity_function (Literal["cosine", "dot"]): Similarity function to compute scores between query and
209 document embeddings. Defaults to "dot".
210 projection (Literal["linear", "linear_no_bias"], optional): Whether and how to project the embeddings.
211 Defaults to "linear".
212 """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            **kwargs,
        )
        self.projection = projection

    @property
    def embedding_dim(self) -> int:
        """Dimension of the embeddings; equal to the size of the backbone's vocabulary."""
        vocab_size = getattr(self, "vocab_size", None)
        if vocab_size is None:
            raise ValueError("Unable to determine embedding dimension.")
        return vocab_size

    @embedding_dim.setter
    def embedding_dim(self, value: int) -> None:
        # embedding_dim is derived from vocab_size; setting it explicitly has no effect
        pass
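
    # Illustrative note (hypothetical numbers): ``embedding_dim`` mirrors the backbone's ``vocab_size``
    # because a UniCOIL "embedding" is a sparse bag-of-words weight vector over the vocabulary.
    #
    #   config = UniCoilConfig()
    #   config.vocab_size = 30522        # normally provided by the backbone's configuration
    #   assert config.embedding_dim == 30522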


class UniCoilModel(SingleVectorBiEncoderModel):
    """Single-vector UniCOIL model. See :class:`.UniCoilConfig` for configuration options."""

    config_class = UniCoilConfig
    """Configuration class for UniCOIL models."""

    def __init__(self, config: UniCoilConfig, *args, **kwargs) -> None:
        """Initializes a UniCOIL model given a :class:`.UniCoilConfig` configuration.

        Args:
            config (UniCoilConfig): Configuration for the UniCOIL model.
        """
        super().__init__(config, *args, **kwargs)
        self.config: UniCoilConfig
        self.token_projection = torch.nn.Linear(
            self.config.hidden_size, 1, bias="no_bias" not in self.config.projection
        )

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes batched tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask.
        """
        contextualized_embeddings = self._backbone_forward(**encoding).last_hidden_state

        # one non-negative weight per token, with padding positions zeroed out
        token_weights = self.token_projection(contextualized_embeddings).squeeze(-1)
        if encoding["attention_mask"] is not None:
            token_weights = token_weights.masked_fill(~(encoding["attention_mask"].bool()), 0)
        token_weights = torch.relu(token_weights)
        # scatter the token weights into a sparse, vocabulary-sized vector per input sequence
        embeddings = torch.zeros(encoding.input_ids.shape[0], self.config.vocab_size, device=token_weights.device)
        embeddings = embeddings.scatter(1, encoding.input_ids, token_weights)
        return BiEncoderEmbedding(embeddings[:, None], None, encoding)
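
    # Minimal sketch of the scatter step above (hypothetical ids and weights): each token's weight is
    # written into the slot of its vocabulary id, yielding one sparse vector per text.
    #
    #   input_ids     = [[101, 2054, 2003, 102]]
    #   token_weights = [[0.2, 1.3, 0.4, 0.1]]   # after masking and ReLU
    #   embeddings[0, 101] == 0.2, embeddings[0, 2054] == 1.3,
    #   embeddings[0, 2003] == 0.4, embeddings[0, 102] == 0.1, all other entries == 0.0
    #
    # The trailing ``[:, None]`` adds a length-1 sequence dimension so the sparse vector can be scored
    # by the shared single-vector similarity code.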