Source code for lightning_ir.models.bi_encoders.coil

  1"""Configuration, model, and embedding for COIL (Contextualized Inverted list) type models.
  2
  3COIL is an information retrieval model that combines the speed of traditional keyword search with the deep
  4understanding of neural networks. It generates context-aware vector representations for every word in a document and
  5stores them in a standard inverted index. During a search, the model only calculates similarity scores for the exact
  6words that appear in both the query and the document.
  7
  8Originally proposed in
  9`COIL: Revisit Exact Lexical Match in Information Retrieval with Contextualized Inverted list \
 10<https://arxiv.org/abs/2104.07186>`_."""
 11
 12from collections.abc import Sequence
 13from dataclasses import dataclass
 14from typing import Literal
 15
 16import torch
 17from transformers import BatchEncoding
 18
 19from ...bi_encoder import (
 20    BiEncoderEmbedding,
 21    BiEncoderOutput,
 22    MultiVectorBiEncoderConfig,
 23    MultiVectorBiEncoderModel,
 24    SingleVectorBiEncoderConfig,
 25    SingleVectorBiEncoderModel,
 26)
 27
 28
@dataclass
class CoilEmbedding(BiEncoderEmbedding):
    """Dataclass containing embeddings and the encoding for COIL models.

    Extends :class:`.BiEncoderEmbedding` with separate per-token and [CLS] embeddings,
    which COIL scores independently (see :meth:`.CoilModel.score`).
    """

    embeddings: torch.Tensor
    """Raw embeddings of the COIL model. Should not be used directly for scoring."""
    token_embeddings: torch.Tensor | None = None
    """Projected embeddings of the non-[CLS] tokens, used for exact-match token scoring."""
    cls_embeddings: torch.Tensor | None = None
    """Separate projected [CLS] token embeddings, used for the semantic (dense) score."""
39 40
@dataclass
class CoilOutput(BiEncoderOutput):
    """Dataclass containing the output of a COIL model.

    Narrows the embedding fields of :class:`.BiEncoderOutput` to :class:`.CoilEmbedding`
    so that :meth:`.CoilModel.score` can access the [CLS] and token embeddings.
    """

    query_embeddings: CoilEmbedding | None = None
    """Query embeddings generated by the model."""
    doc_embeddings: CoilEmbedding | None = None
    """Document embeddings generated by the model."""
49 50
class CoilConfig(MultiVectorBiEncoderConfig):
    """Configuration class for COIL models."""

    model_type = "coil"
    """Model type for COIL models."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        add_marker_tokens: bool = False,
        token_embedding_dim: int = 32,
        cls_embedding_dim: int = 768,
        projection: Literal["linear", "linear_no_bias"] = "linear",
        **kwargs,
    ) -> None:
        """A COIL model encodes queries and documents separately and computes two scores: a dense
        similarity between the projected [CLS] embeddings and a token-level similarity that is
        restricted to exact lexical matches between query and document tokens. The final relevance
        score is the sum of both (see :meth:`.CoilModel.score`).

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
            similarity_function (Literal["cosine", "dot"]): Similarity function to compute scores between query and
                document embeddings. Defaults to "dot".
            add_marker_tokens (bool): Whether to add extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            token_embedding_dim (int): The output embedding dimension for tokens. Defaults to 32.
            cls_embedding_dim (int): The output embedding dimension for the [CLS] token. Defaults to 768.
            projection (Literal["linear", "linear_no_bias"]): Whether the projection layers use a bias term.
                Defaults to "linear".
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.projection = projection
        self.token_embedding_dim = token_embedding_dim
        self.cls_embedding_dim = cls_embedding_dim
93 94
class CoilModel(MultiVectorBiEncoderModel):
    """Multi-vector COIL model. See :class:`.CoilConfig` for configuration options."""

    config_class = CoilConfig
    """Configuration class for COIL models."""

    def __init__(self, config: CoilConfig, *args, **kwargs) -> None:
        """Initializes a COIL model given a :class:`.CoilConfig` configuration.

        Args:
            config (CoilConfig): Configuration for the COIL model.
        """
        super().__init__(config, *args, **kwargs)
        self.config: CoilConfig
        # Two separate projection heads: one low-dimensional head for regular tokens and one
        # (typically full-dimensional) head for the [CLS] token.
        self.token_projection = torch.nn.Linear(
            self.config.hidden_size,
            self.config.token_embedding_dim,
            bias="no_bias" not in self.config.projection,
        )
        self.cls_projection = torch.nn.Linear(
            self.config.hidden_size,
            self.config.cls_embedding_dim,
            bias="no_bias" not in self.config.projection,
        )

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> CoilEmbedding:
        """Encodes a batched tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): type of input, either "query" or "doc".
        Returns:
            CoilEmbedding: Raw, [CLS], and token embeddings together with the scoring mask and encoding.
        """
        embeddings = self._backbone_forward(**encoding).last_hidden_state

        # Sequence position 0 is projected with the [CLS] head (the [0] index keeps the
        # sequence dimension); all remaining positions are projected with the token head.
        # NOTE(review): assumes the [CLS] token is always at position 0 — holds for
        # BERT-style tokenizers; confirm for other backbones.
        cls_embeddings = self.cls_projection(embeddings[:, [0]])
        token_embeddings = self.token_projection(embeddings[:, 1:])

        scoring_mask = self.scoring_mask(encoding, input_type)
        return CoilEmbedding(
            embeddings,
            scoring_mask,
            encoding,
            cls_embeddings=cls_embeddings,
            token_embeddings=token_embeddings,
        )

    def score(
        self,
        output: CoilOutput,
        num_docs: Sequence[int] | int | None = None,
    ) -> CoilOutput:
        """Compute relevance scores between queries and documents.

        The final score is the sum of a [CLS]-embedding similarity and an aggregated token
        similarity in which only exact lexical matches (identical token ids) contribute.

        Args:
            output (CoilOutput): Output containing the query and document :class:`.CoilEmbedding`\\ s
                (with [CLS] embeddings, token embeddings, scoring masks, and encodings).
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence of
                integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the
                number of documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
                the number of documents by dividing the number of documents by the number of queries. Defaults to None.
        Returns:
            CoilOutput: The input output object with its ``scores`` field populated.
        Raises:
            ValueError: If embeddings, scoring masks, or the COIL-specific [CLS]/token embeddings are missing.
        """
        query_embeddings = output.query_embeddings
        doc_embeddings = output.doc_embeddings
        if query_embeddings is None or doc_embeddings is None:
            raise ValueError("Query and document embeddings must be provided for scoring.")
        if query_embeddings.scoring_mask is None or doc_embeddings.scoring_mask is None:
            raise ValueError("Scoring masks expected for scoring multi-vector embeddings")
        if (
            query_embeddings.cls_embeddings is None
            or doc_embeddings.cls_embeddings is None
            or query_embeddings.token_embeddings is None
            or doc_embeddings.token_embeddings is None
        ):
            raise ValueError("COIL embeddings must contain cls_embeddings and token_embeddings")

        # Dense score between the single [CLS] vectors; flattened to one score per query-doc pair.
        cls_scores = self.compute_similarity(
            BiEncoderEmbedding(query_embeddings.cls_embeddings),
            BiEncoderEmbedding(doc_embeddings.cls_embeddings),
            num_docs,
        ).view(-1)

        # Pairwise similarity between all query and document token embeddings.
        token_similarity = self.compute_similarity(
            BiEncoderEmbedding(query_embeddings.token_embeddings),
            BiEncoderEmbedding(doc_embeddings.token_embeddings),
            num_docs,
        )
        num_docs_t = self._parse_num_docs(
            query_embeddings.embeddings.shape[0],
            doc_embeddings.embeddings.shape[0],
            num_docs,
            query_embeddings.device,
        )
        # Align query token ids with their documents and drop the [CLS] position (index 0) to
        # match the token embeddings sliced in encode().
        query = query_embeddings.encoding.input_ids.repeat_interleave(num_docs_t, 0)[:, 1:]
        docs = doc_embeddings.encoding.input_ids[:, 1:]
        # COIL's exact-match constraint: a token pair contributes only if the token ids are equal.
        mask = (query[:, :, None] == docs[:, None, :]).to(token_similarity)
        token_similarity = token_similarity * mask
        token_scores = self.aggregate_similarity(
            token_similarity,
            query_embeddings.scoring_mask[:, 1:],
            doc_embeddings.scoring_mask[:, 1:],
            num_docs,
        )

        output.scores = cls_scores + token_scores
        return output
204 205
class UniCoilConfig(SingleVectorBiEncoderConfig):
    """Configuration class for UniCOIL models."""

    model_type = "unicoil"
    """Model type for UniCOIL models."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        projection: Literal["linear", "linear_no_bias"] = "linear",
        **kwargs,
    ) -> None:
        """A UniCOIL model encodes queries and documents separately, and computes a similarity score using the maximum
        similarity of token embeddings between query and document.

        Args:
            query_length (int | None): Maximum query length in number of tokens. Defaults to 32.
            doc_length (int | None): Maximum document length in number of tokens. Defaults to 512.
            similarity_function (Literal["cosine", "dot"]): Similarity function to compute scores between query and
                document embeddings. Defaults to "dot".
            projection (Literal["linear", "linear_no_bias"]): Whether the token-weight projection layer uses a bias
                term. Defaults to "linear".
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            **kwargs,
        )
        self.projection = projection

    @property
    def embedding_dim(self) -> int:
        """Embedding dimension of a UniCOIL vector, which equals the vocabulary size.

        Raises:
            ValueError: If the vocabulary size is not (yet) set on the config.
        """
        if getattr(self, "vocab_size", None) is None:
            raise ValueError("Unable to determine embedding dimension.")
        return self.vocab_size

    @embedding_dim.setter
    def embedding_dim(self, value: int) -> None:
        # Deliberately a no-op: the embedding dimension is always derived from vocab_size,
        # but a setter must exist so parent-class initialization assigning it does not fail.
        pass
249 250
class UniCoilModel(SingleVectorBiEncoderModel):
    """Single-vector UniCOIL model. See :class:`.UniCoilConfig` for configuration options."""

    config_class = UniCoilConfig
    """Configuration class for UniCOIL models."""

    def __init__(self, config: UniCoilConfig, *args, **kwargs) -> None:
        """Initializes a UniCOIL model given a :class:`.UniCoilConfig` configuration.

        Args:
            config (UniCoilConfig): Configuration for the UniCOIL model.
        """
        super().__init__(config, *args, **kwargs)
        self.config: UniCoilConfig
        # Projects each contextualized token embedding to a single scalar importance weight.
        self.token_projection = torch.nn.Linear(
            self.config.hidden_size, 1, bias="no_bias" not in self.config.projection
        )

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes a batched tokenized text sequences and returns the embeddings and scoring mask.

        The sequence is encoded as a single sparse vocabulary-sized vector: each token's
        contextualized embedding is projected to a non-negative scalar weight, and the weights
        are scattered into the positions of the corresponding token ids.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): type of input, either "query" or "doc".
        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask.
        """
        contextualized_embeddings = self._backbone_forward(**encoding).last_hidden_state

        # One scalar weight per token; zero out padding positions before the non-negativity clamp.
        token_weights = self.token_projection(contextualized_embeddings).squeeze(-1)
        if encoding["attention_mask"] is not None:
            token_weights = token_weights.masked_fill(~(encoding["attention_mask"].bool()), 0)
        token_weights = torch.relu(token_weights)
        # Build the sparse bag-of-words representation over the full vocabulary.
        embeddings = torch.zeros(
            encoding.input_ids.shape[0],
            self.config.vocab_size,
            device=token_weights.device,
        )
        # NOTE(review): scatter with duplicate indices keeps an arbitrary one of the duplicate
        # weights when a token id occurs more than once in a sequence — confirm whether
        # max-pooling over repeated tokens is the intended semantics.
        embeddings = embeddings.scatter(1, encoding.input_ids, token_weights)
        # [:, None] adds a singleton "number of vectors" dimension expected for single-vector models.
        return BiEncoderEmbedding(embeddings[:, None], None, encoding)