Source code for lightning_ir.models.bi_encoders.coil

  1"""Configuration, model, and embedding for COIL (Contextualized Inverted List) type models. Originally proposed in
  2`COIL: Revisit Exact Lexical Match in Information Retrieval with Contextualized Inverted List \
  3<https://arxiv.org/abs/2104.07186>`_."""
  4
  5from dataclasses import dataclass
  6from typing import Literal, Sequence
  7
  8import torch
  9from transformers import BatchEncoding
 10
 11from ...bi_encoder import (
 12    BiEncoderEmbedding,
 13    BiEncoderOutput,
 14    MultiVectorBiEncoderConfig,
 15    MultiVectorBiEncoderModel,
 16    SingleVectorBiEncoderConfig,
 17    SingleVectorBiEncoderModel,
 18)
 19
 20
@dataclass
class CoilEmbedding(BiEncoderEmbedding):
    """Dataclass containing embeddings and the encoding for COIL models."""

    embeddings: torch.Tensor
    """Raw embeddings of the COIL model. Should not be used directly for scoring."""
    token_embeddings: torch.Tensor | None = None
    """Token embeddings of the COIL model."""
    cls_embeddings: torch.Tensor | None = None
    """Separate [CLS] token embeddings."""


@dataclass
class CoilOutput(BiEncoderOutput):
    """Dataclass containing the output of a COIL model."""

    query_embeddings: CoilEmbedding | None = None
    """Query embeddings generated by the model."""
    doc_embeddings: CoilEmbedding | None = None
    """Document embeddings generated by the model."""


class CoilConfig(MultiVectorBiEncoderConfig):
    """Configuration class for COIL models."""

    model_type = "coil"
    """Model type for COIL models."""
    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        add_marker_tokens: bool = False,
        token_embedding_dim: int = 32,
        cls_embedding_dim: int = 768,
        projection: Literal["linear", "linear_no_bias"] = "linear",
        **kwargs,
    ) -> None:
        """A COIL model encodes queries and documents separately and computes a similarity score by combining the
        similarity of the [CLS] embeddings with, for each query token, the maximum similarity to document tokens
        that share the same token id.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate.
                Defaults to 512.
            similarity_function (Literal["cosine", "dot"]): Similarity function to compute scores between query and
                document embeddings. Defaults to "dot".
            add_marker_tokens (bool): Whether to add extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            token_embedding_dim (int): The output embedding dimension for tokens. Defaults to 32.
            cls_embedding_dim (int): The output embedding dimension for the [CLS] token. Defaults to 768.
            projection (Literal["linear", "linear_no_bias"]): Whether and how to project the embeddings.
                Defaults to "linear".
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.projection = projection
        self.token_embedding_dim = token_embedding_dim
        self.cls_embedding_dim = cls_embedding_dim
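
def _coil_config_sketch() -> CoilConfig:
    """Usage sketch (illustrative, not part of the original module): builds a :class:`CoilConfig` with its default
    dimensions, i.e. 32-dimensional token embeddings and a 768-dimensional [CLS] embedding. All keyword arguments
    are parameters of :meth:`CoilConfig.__init__` above; in practice the config would typically be passed to a
    bi-encoder module together with a backbone model."""
    return CoilConfig(
        query_length=32,
        doc_length=512,
        similarity_function="dot",
        add_marker_tokens=False,
        token_embedding_dim=32,
        cls_embedding_dim=768,
        projection="linear",
    )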

class CoilModel(MultiVectorBiEncoderModel):
    """Multi-vector COIL model. See :class:`.CoilConfig` for configuration options."""

    config_class = CoilConfig
    """Configuration class for COIL models."""

    def __init__(self, config: CoilConfig, *args, **kwargs) -> None:
        """Initializes a COIL model given a :class:`.CoilConfig` configuration.

        Args:
            config (CoilConfig): Configuration for the COIL model.
        """
        super().__init__(config, *args, **kwargs)
        self.config: CoilConfig
        self.token_projection = torch.nn.Linear(
            self.config.hidden_size,
            self.config.token_embedding_dim,
            bias="no_bias" not in self.config.projection,
        )
        self.cls_projection = torch.nn.Linear(
            self.config.hidden_size,
            self.config.cls_embedding_dim,
            bias="no_bias" not in self.config.projection,
        )
    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> CoilEmbedding:
        """Encodes a batch of tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            CoilEmbedding: Embeddings and scoring mask.
        """
        embeddings = self._backbone_forward(**encoding).last_hidden_state

        # Project the [CLS] token and the remaining tokens separately.
        cls_embeddings = self.cls_projection(embeddings[:, [0]])
        token_embeddings = self.token_projection(embeddings[:, 1:])

        scoring_mask = self.scoring_mask(encoding, input_type)
        return CoilEmbedding(
            embeddings, scoring_mask, encoding, cls_embeddings=cls_embeddings, token_embeddings=token_embeddings
        )
    def score(
        self,
        output: CoilOutput,
        num_docs: Sequence[int] | int | None = None,
    ) -> CoilOutput:
        """Compute relevance scores between queries and documents.

        Args:
            output (CoilOutput): Output containing the query and document embeddings (CLS embeddings, token
                embeddings, and scoring masks).
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence of
                integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the
                number of documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
                the number of documents per query by dividing the number of documents by the number of queries.
                Defaults to None.
        Returns:
            CoilOutput: Output with the relevance scores set."""
        query_embeddings = output.query_embeddings
        doc_embeddings = output.doc_embeddings
        if query_embeddings is None or doc_embeddings is None:
            raise ValueError("Query and document embeddings must be provided for scoring.")
        if query_embeddings.scoring_mask is None or doc_embeddings.scoring_mask is None:
            raise ValueError("Scoring masks expected for scoring multi-vector embeddings")
        if (
            query_embeddings.cls_embeddings is None
            or doc_embeddings.cls_embeddings is None
            or query_embeddings.token_embeddings is None
            or doc_embeddings.token_embeddings is None
        ):
            raise ValueError("COIL embeddings must contain cls_embeddings and token_embeddings")

        # Similarity between the [CLS] embeddings of queries and documents.
        cls_scores = self.compute_similarity(
            BiEncoderEmbedding(query_embeddings.cls_embeddings),
            BiEncoderEmbedding(doc_embeddings.cls_embeddings),
            num_docs,
        ).view(-1)

        # Pairwise similarities between query and document token embeddings.
        token_similarity = self.compute_similarity(
            BiEncoderEmbedding(query_embeddings.token_embeddings),
            BiEncoderEmbedding(doc_embeddings.token_embeddings),
            num_docs,
        )
        num_docs_t = self._parse_num_docs(
            query_embeddings.embeddings.shape[0], doc_embeddings.embeddings.shape[0], num_docs, query_embeddings.device
        )
        # Only token pairs with identical token ids contribute to the score (exact lexical match).
        query = query_embeddings.encoding.input_ids.repeat_interleave(num_docs_t, 0)[:, 1:]
        docs = doc_embeddings.encoding.input_ids[:, 1:]
        mask = (query[:, :, None] == docs[:, None, :]).to(token_similarity)
        token_similarity = token_similarity * mask
        token_scores = self.aggregate_similarity(
            token_similarity, query_embeddings.scoring_mask[:, 1:], doc_embeddings.scoring_mask[:, 1:], num_docs
        )

        output.scores = cls_scores + token_scores
        return output
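
def _coil_exact_match_scoring_sketch() -> torch.Tensor:
    """Illustrative sketch (not part of the original module): loosely mirrors the exact-lexical-match masking in
    :meth:`CoilModel.score` on made-up tensors. Token similarities only contribute when the query token id matches
    the document token id; a full COIL score additionally adds the [CLS] similarity. All ids, shapes, and values
    below are invented for the example."""
    query_ids = torch.tensor([[2054, 2003, 26794]])      # (batch, query_len)
    doc_ids = torch.tensor([[2003, 1996, 26794, 2154]])  # (batch, doc_len)
    query_tokens = torch.randn(1, 3, 8)                  # (batch, query_len, token_embedding_dim)
    doc_tokens = torch.randn(1, 4, 8)                    # (batch, doc_len, token_embedding_dim)
    # Pairwise dot-product similarities between all query and document tokens.
    similarity = torch.einsum("bqd,bkd->bqk", query_tokens, doc_tokens)
    # Zero out pairs whose token ids do not match exactly (the COIL constraint).
    mask = (query_ids[:, :, None] == doc_ids[:, None, :]).to(similarity)
    similarity = similarity * mask
    # Take the maximum over document tokens and sum over query tokens.
    return similarity.max(dim=-1).values.sum(dim=-1)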

class UniCoilConfig(SingleVectorBiEncoderConfig):
    """Configuration class for UniCOIL models."""

    model_type = "unicoil"
    """Model type for UniCOIL models."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        projection: Literal["linear", "linear_no_bias"] = "linear",
        **kwargs,
    ) -> None:
        """A UniCOIL model encodes queries and documents separately into sparse, vocabulary-sized weight vectors,
        where each token's contextualized embedding is projected to a single scalar weight, and computes a
        similarity score between these vectors.

        Args:
            query_length (int | None): Maximum query length in number of tokens. Defaults to 32.
            doc_length (int | None): Maximum document length in number of tokens. Defaults to 512.
            similarity_function (Literal["cosine", "dot"]): Similarity function to compute scores between query and
                document embeddings. Defaults to "dot".
            projection (Literal["linear", "linear_no_bias"]): Whether and how to project the embeddings.
                Defaults to "linear".
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            **kwargs,
        )
        self.projection = projection

    @property
    def embedding_dim(self) -> int:
        # UniCOIL encodes texts as vocabulary-sized weight vectors, so the embedding dimension equals the
        # vocabulary size of the backbone model.
        vocab_size = getattr(self, "vocab_size", None)
        if vocab_size is None:
            raise ValueError("Unable to determine embedding dimension.")
        return vocab_size

    @embedding_dim.setter
    def embedding_dim(self, value: int) -> None:
        # The embedding dimension is derived from the vocabulary size and cannot be set directly.
        pass
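
def _unicoil_embedding_dim_sketch() -> int:
    """Illustrative sketch (not part of the original module): for UniCOIL the embedding dimension equals the
    vocabulary size of the backbone, since texts are encoded as vocabulary-sized weight vectors (see
    :meth:`UniCoilModel.encode` below). ``vocab_size`` is normally provided by the backbone configuration; it is
    set explicitly here only for the example."""
    config = UniCoilConfig(query_length=32, doc_length=512)
    config.vocab_size = 30522  # e.g. the BERT vocabulary size
    return config.embedding_dim  # == 30522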

class UniCoilModel(SingleVectorBiEncoderModel):
    """Single-vector UniCOIL model. See :class:`.UniCoilConfig` for configuration options."""

    config_class = UniCoilConfig
    """Configuration class for UniCOIL models."""

    def __init__(self, config: UniCoilConfig, *args, **kwargs) -> None:
        """Initializes a UniCOIL model given a :class:`.UniCoilConfig` configuration.

        Args:
            config (UniCoilConfig): Configuration for the UniCOIL model.
        """
        super().__init__(config, *args, **kwargs)
        self.config: UniCoilConfig
        # Projects each contextualized token embedding to a single scalar term weight.
        self.token_projection = torch.nn.Linear(
            self.config.hidden_size, 1, bias="no_bias" not in self.config.projection
        )
    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes a batch of tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask.
        """
        contextualized_embeddings = self._backbone_forward(**encoding).last_hidden_state

        # Compute a non-negative scalar weight per token; padding tokens are masked out.
        token_weights = self.token_projection(contextualized_embeddings).squeeze(-1)
        if encoding["attention_mask"] is not None:
            token_weights = token_weights.masked_fill(~(encoding["attention_mask"].bool()), 0)
        token_weights = torch.relu(token_weights)
        # Scatter the token weights into a sparse, vocabulary-sized vector.
        embeddings = torch.zeros(encoding.input_ids.shape[0], self.config.vocab_size, device=token_weights.device)
        embeddings = embeddings.scatter(1, encoding.input_ids, token_weights)
        return BiEncoderEmbedding(embeddings[:, None], None, encoding)
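
def _unicoil_scatter_sketch() -> torch.Tensor:
    """Illustrative sketch (not part of the original module): mirrors how :meth:`UniCoilModel.encode` scatters
    per-token weights into a single vocabulary-sized vector. The vocabulary size, token ids, and weights below are
    made up for the example; for repeated token ids only one of the weights is kept, matching the behaviour of
    ``Tensor.scatter``."""
    vocab_size = 16
    input_ids = torch.tensor([[3, 7, 7, 12]])             # (batch, seq_len)
    token_weights = torch.tensor([[0.5, 1.2, 0.8, 0.0]])  # one non-negative weight per token
    embeddings = torch.zeros(input_ids.shape[0], vocab_size)
    # Each token's weight is written to the position of its token id in the vocabulary.
    return embeddings.scatter(1, input_ids, token_weights)  # non-zero only at indices 3, 7, and 12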