1"""Configuration, model, and embedding for COIL (Contextualized Inverted List) type models. Originally proposed in
2`COIL: Revisit Exact Lexical Match in Information Retrieval with Contextualized Inverted List \
3<https://arxiv.org/abs/2104.07186>`_."""
4
5from dataclasses import dataclass
6from typing import Literal, Sequence
7
8import torch
9from transformers import BatchEncoding
10
11from ..bi_encoder import BiEncoderEmbedding, BiEncoderOutput, MultiVectorBiEncoderConfig, MultiVectorBiEncoderModel
12
13
@dataclass
class CoilEmbedding(BiEncoderEmbedding):
    """Dataclass containing embeddings and the encoding for COIL models."""

    embeddings: torch.Tensor
    """Raw embeddings of the COIL model. Should not be used directly for scoring."""
    token_embeddings: torch.Tensor | None = None
    """Token embeddings of the COIL model."""
    cls_embeddings: torch.Tensor | None = None
    """Separate [CLS] token embeddings."""


@dataclass
class CoilOutput(BiEncoderOutput):
    """Dataclass containing the output of a COIL model."""

    query_embeddings: CoilEmbedding | None = None
    """Query embeddings generated by the model."""
    doc_embeddings: CoilEmbedding | None = None
    """Document embeddings generated by the model."""


class CoilConfig(MultiVectorBiEncoderConfig):
    """Configuration class for COIL models."""

    model_type = "coil"
    """Model type for COIL models."""

    def __init__(
        self,
        query_length: int = 32,
        doc_length: int = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalize: bool = False,
        add_marker_tokens: bool = False,
        token_embedding_dim: int = 32,
        cls_embedding_dim: int = 768,
        projection: Literal["linear", "linear_no_bias"] = "linear",
        **kwargs,
    ) -> None:
54 """A COIL model encodes queries and documents separately, and computes a similarity score using the maximum
55 similarity ...
56
57 Args:
58 query_length (int, optional): Maximum query length in number of tokens. Defaults to 32.
59 doc_length (int, optional): Maximum document length in number of tokens. Defaults to 512.
60 similarity_function (Literal["cosine", "dot"]): Similarity function to compute scores between query and
61 document embeddings. Defaults to "dot".
62 normalize (bool): Whether to normalize query and document embeddings. Defaults to False.
63 add_marker_tokens (bool): Whether to add extra marker tokens [Q] / [D] to queries / documents.
64 Defaults to False.
65 token_embedding_dim (int, optional): The output embedding dimension for tokens. Defaults to 32.
66 cls_embedding_dim (int, optional): The output embedding dimension for the [CLS] token. Defaults to 768.
67 projection (Literal["linear", "linear_no_bias"], optional): Whether and how to project the embeddings.
68 Defaults to "linear".
69 """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            normalize=normalize,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.projection = projection
        self.token_embedding_dim = token_embedding_dim
        self.cls_embedding_dim = cls_embedding_dim


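# A minimal, illustrative configuration sketch (not part of the library API): it only shows how the
# dimensions documented above might be chosen, e.g. a low-dimensional token space alongside a
# full-sized [CLS] embedding.
#
#     config = CoilConfig(token_embedding_dim=32, cls_embedding_dim=768, projection="linear")
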
class CoilModel(MultiVectorBiEncoderModel):
    """Multi-vector COIL model. See :class:`.CoilConfig` for configuration options."""

    config_class = CoilConfig
    """Configuration class for COIL models."""

    def __init__(self, config: CoilConfig, *args, **kwargs) -> None:
        """Initializes a COIL model given a :class:`.CoilConfig` configuration.

        Args:
            config (CoilConfig): Configuration for the COIL model.
        """
        super().__init__(config, *args, **kwargs)
        self.config: CoilConfig
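        # Two projection heads map the backbone's hidden states into the COIL embedding spaces: a
        # low-dimensional space for contextualized token embeddings and a (typically larger) space
        # for the [CLS] embedding. The bias follows the configured projection variant.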
        self.token_projection = torch.nn.Linear(
            self.config.hidden_size,
            self.config.token_embedding_dim,
            bias="no_bias" not in self.config.projection,
        )
        self.cls_projection = torch.nn.Linear(
            self.config.hidden_size,
            self.config.cls_embedding_dim,
            bias="no_bias" not in self.config.projection,
        )

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> CoilEmbedding:
        """Encodes batched tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            CoilEmbedding: Embeddings and scoring mask.
        """
        embeddings = self._backbone_forward(**encoding).last_hidden_state

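        # The first position holds the [CLS] token; it is projected separately from the remaining
        # token embeddings.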
        cls_embeddings = self.cls_projection(embeddings[:, [0]])
        token_embeddings = self.token_projection(embeddings[:, 1:])

        if self.config.normalize:
            cls_embeddings = torch.nn.functional.normalize(cls_embeddings, dim=-1)
            token_embeddings = torch.nn.functional.normalize(token_embeddings, dim=-1)

        scoring_mask = self.scoring_mask(encoding, input_type)
        return CoilEmbedding(
            embeddings, scoring_mask, encoding, cls_embeddings=cls_embeddings, token_embeddings=token_embeddings
        )

    def score(
        self,
        output: CoilOutput,
        num_docs: Sequence[int] | int | None = None,
    ) -> CoilOutput:
        """Compute relevance scores between queries and documents.

        Args:
            output (CoilOutput): Output containing the query and document embeddings ([CLS] embeddings, token
                embeddings, and scoring masks) to score.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence
                of integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to
                the number of documents, i.e., the sequence contains one value per query specifying the number of
                documents for that query. If an integer, assumes an equal number of documents per query. If None,
                tries to infer the number of documents by dividing the number of documents by the number of queries.
                Defaults to None.
        Returns:
            CoilOutput: Output updated with the relevance scores."""
        query_embeddings = output.query_embeddings
        doc_embeddings = output.doc_embeddings
        if query_embeddings is None or doc_embeddings is None:
            raise ValueError("Query and document embeddings must be provided for scoring.")
        if query_embeddings.scoring_mask is None or doc_embeddings.scoring_mask is None:
            raise ValueError("Scoring masks expected for scoring multi-vector embeddings")
        if (
            query_embeddings.cls_embeddings is None
            or doc_embeddings.cls_embeddings is None
            or query_embeddings.token_embeddings is None
            or doc_embeddings.token_embeddings is None
        ):
            raise ValueError("COIL embeddings must contain cls_embeddings and token_embeddings")

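        # The [CLS] embeddings are compared as single vectors, yielding one dense score per
        # query-document pair.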
        cls_scores = self.compute_similarity(
            BiEncoderEmbedding(query_embeddings.cls_embeddings),
            BiEncoderEmbedding(doc_embeddings.cls_embeddings),
            num_docs,
        ).view(-1)

        token_similarity = self.compute_similarity(
            BiEncoderEmbedding(query_embeddings.token_embeddings),
            BiEncoderEmbedding(doc_embeddings.token_embeddings),
            num_docs,
        )
        num_docs_t = self._parse_num_docs(
            query_embeddings.embeddings.shape[0], doc_embeddings.embeddings.shape[0], num_docs, query_embeddings.device
        )
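        # COIL's exact lexical match constraint: compare the (non-[CLS]) query token ids, repeated
        # once per document, against the document token ids, and zero out similarities between
        # tokens that do not share the same id.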
        query = query_embeddings.encoding.input_ids.repeat_interleave(num_docs_t, 0)[:, 1:]
        docs = doc_embeddings.encoding.input_ids[:, 1:]
        mask = (query[:, :, None] == docs[:, None, :]).to(token_similarity)
        token_similarity = token_similarity * mask
        token_scores = self.aggregate_similarity(
            token_similarity, query_embeddings.scoring_mask[:, 1:], doc_embeddings.scoring_mask[:, 1:], num_docs
        )

        output.scores = cls_scores + token_scores
        return output
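
# A rough end-to-end sketch of how these pieces fit together (illustrative only; the model and
# tokenizer loading steps are assumptions and not defined in this module):
#
#     query_encoding = tokenizer(queries, ...)  # hypothetical tokenization of a list of queries
#     doc_encoding = tokenizer(docs, ...)       # hypothetical tokenization of a list of documents
#     output = CoilOutput(
#         query_embeddings=model.encode(query_encoding, "query"),
#         doc_embeddings=model.encode(doc_encoding, "doc"),
#     )
#     output = model.score(output, num_docs=[len(docs)])
#     print(output.scores)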