Source code for lightning_ir.models.coil

  1"""Configuration, model, and embedding for COIL (Contextualized Inverted List) type models. Originally proposed in
  2`COIL: Revisit Exact Lexical Match in Information Retrieval with Contextualized Inverted List \
  3<https://arxiv.org/abs/2104.07186>`_."""
  4
  5from dataclasses import dataclass
  6from typing import Literal, Sequence
  7
  8import torch
  9from transformers import BatchEncoding
 10
 11from ..bi_encoder import BiEncoderEmbedding, BiEncoderOutput, MultiVectorBiEncoderConfig, MultiVectorBiEncoderModel
 12
 13
@dataclass
class CoilEmbedding(BiEncoderEmbedding):
    """Dataclass containing embeddings and the encoding for COIL models."""

    embeddings: torch.Tensor
    """Raw embeddings of the COIL model. Should not be used directly for scoring."""
    token_embeddings: torch.Tensor | None = None
    """Token embeddings of the COIL model."""
    cls_embeddings: torch.Tensor | None = None
    """Separate [CLS] token embeddings."""
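
# Illustrative note (not part of the module source): `embeddings` holds the raw backbone
# hidden states, while `cls_embeddings` and `token_embeddings` hold the projected [CLS]
# and non-[CLS] token representations produced by `CoilModel.encode` below; scoring uses
# only the projected embeddings.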


@dataclass
class CoilOutput(BiEncoderOutput):
    """Dataclass containing the output of a COIL model."""

    query_embeddings: CoilEmbedding | None = None
    """Query embeddings generated by the model."""
    doc_embeddings: CoilEmbedding | None = None
    """Document embeddings generated by the model."""


class CoilConfig(MultiVectorBiEncoderConfig):
    """Configuration class for COIL models."""

    model_type = "coil"
    """Model type for COIL models."""

    def __init__(
        self,
        query_length: int = 32,
        doc_length: int = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalize: bool = False,
        add_marker_tokens: bool = False,
        token_embedding_dim: int = 32,
        cls_embedding_dim: int = 768,
        projection: Literal["linear", "linear_no_bias"] = "linear",
        **kwargs,
    ) -> None:
        """A COIL model encodes queries and documents separately, and computes a similarity score using the maximum
        similarity between contextualized embeddings of query and document tokens that match exactly, added to the
        similarity of the separate [CLS] embeddings.

        Args:
            query_length (int, optional): Maximum query length in number of tokens. Defaults to 32.
            doc_length (int, optional): Maximum document length in number of tokens. Defaults to 512.
            similarity_function (Literal["cosine", "dot"]): Similarity function to compute scores between query and
                document embeddings. Defaults to "dot".
            normalize (bool): Whether to normalize query and document embeddings. Defaults to False.
            add_marker_tokens (bool): Whether to add extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            token_embedding_dim (int, optional): The output embedding dimension for tokens. Defaults to 32.
            cls_embedding_dim (int, optional): The output embedding dimension for the [CLS] token. Defaults to 768.
            projection (Literal["linear", "linear_no_bias"], optional): Whether and how to project the embeddings.
                Defaults to "linear".
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            normalize=normalize,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.projection = projection
        self.token_embedding_dim = token_embedding_dim
        self.cls_embedding_dim = cls_embedding_dim
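

# Illustrative usage sketch (not part of the module source): constructing a CoilConfig.
# The keyword arguments below simply restate the defaults documented above; any extra
# keyword arguments would be forwarded to MultiVectorBiEncoderConfig via **kwargs.
_example_config = CoilConfig(
    query_length=32,
    doc_length=512,
    similarity_function="dot",
    normalize=False,
    add_marker_tokens=False,
    token_embedding_dim=32,
    cls_embedding_dim=768,
    projection="linear",
)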


class CoilModel(MultiVectorBiEncoderModel):
    """Multi-vector COIL model. See :class:`.CoilConfig` for configuration options."""

    config_class = CoilConfig
    """Configuration class for COIL models."""

    def __init__(self, config: CoilConfig, *args, **kwargs) -> None:
        """Initializes a COIL model given a :class:`.CoilConfig` configuration.

        Args:
            config (CoilConfig): Configuration for the COIL model.
        """
        super().__init__(config, *args, **kwargs)
        self.config: CoilConfig
        self.token_projection = torch.nn.Linear(
            self.config.hidden_size,
            self.config.token_embedding_dim,
            bias="no_bias" not in self.config.projection,
        )
        self.cls_projection = torch.nn.Linear(
            self.config.hidden_size,
            self.config.cls_embedding_dim,
            bias="no_bias" not in self.config.projection,
        )
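
    # Note (illustrative, not part of the module source): both projections map the
    # backbone hidden size to the configured output dimensions (token_embedding_dim for
    # non-[CLS] tokens, cls_embedding_dim for the [CLS] token); with
    # projection="linear_no_bias" the substring check above disables the bias term of
    # both linear layers.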
    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> CoilEmbedding:
        """Encodes batched tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequences.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            CoilEmbedding: Embeddings and scoring mask.
        """
        embeddings = self._backbone_forward(**encoding).last_hidden_state

        cls_embeddings = self.cls_projection(embeddings[:, [0]])
        token_embeddings = self.token_projection(embeddings[:, 1:])

        if self.config.normalize:
            cls_embeddings = torch.nn.functional.normalize(cls_embeddings, dim=-1)
            token_embeddings = torch.nn.functional.normalize(token_embeddings, dim=-1)

        scoring_mask = self.scoring_mask(encoding, input_type)
        return CoilEmbedding(
            embeddings, scoring_mask, encoding, cls_embeddings=cls_embeddings, token_embeddings=token_embeddings
        )
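
    # Illustrative usage sketch (not part of the module source). It assumes `model` is a
    # CoilModel instance and `tokenizer` is the matching Hugging Face tokenizer returning
    # a BatchEncoding with tensors:
    #
    #   encoding = tokenizer(["what is coil?"], return_tensors="pt")
    #   query_embeddings = model.encode(encoding, input_type="query")
    #   # query_embeddings.embeddings:       (1, seq_len, hidden_size)
    #   # query_embeddings.cls_embeddings:   (1, 1, cls_embedding_dim)
    #   # query_embeddings.token_embeddings: (1, seq_len - 1, token_embedding_dim)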
    def score(
        self,
        output: CoilOutput,
        num_docs: Sequence[int] | int | None = None,
    ) -> CoilOutput:
        """Compute relevance scores between queries and documents.

        Args:
            output (CoilOutput): Output containing the query and document embeddings ([CLS] embeddings, token
                embeddings, and scoring masks) to compute scores for.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence of
                integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the
                number of documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
                the number of documents per query by dividing the total number of documents by the number of queries.
                Defaults to None.
        Returns:
            CoilOutput: Output with the relevance scores set."""
        query_embeddings = output.query_embeddings
        doc_embeddings = output.doc_embeddings
        if query_embeddings is None or doc_embeddings is None:
            raise ValueError("Query and document embeddings must be provided for scoring.")
        if query_embeddings.scoring_mask is None or doc_embeddings.scoring_mask is None:
            raise ValueError("Scoring masks expected for scoring multi-vector embeddings")
        if (
            query_embeddings.cls_embeddings is None
            or doc_embeddings.cls_embeddings is None
            or query_embeddings.token_embeddings is None
            or doc_embeddings.token_embeddings is None
        ):
            raise ValueError("COIL embeddings must contain cls_embeddings and token_embeddings")

        cls_scores = self.compute_similarity(
            BiEncoderEmbedding(query_embeddings.cls_embeddings),
            BiEncoderEmbedding(doc_embeddings.cls_embeddings),
            num_docs,
        ).view(-1)

        token_similarity = self.compute_similarity(
            BiEncoderEmbedding(query_embeddings.token_embeddings),
            BiEncoderEmbedding(doc_embeddings.token_embeddings),
            num_docs,
        )
        num_docs_t = self._parse_num_docs(
            query_embeddings.embeddings.shape[0], doc_embeddings.embeddings.shape[0], num_docs, query_embeddings.device
        )
        # Exact lexical match mask: only similarities between query and document tokens
        # with the same vocabulary id contribute to the token score.
        query = query_embeddings.encoding.input_ids.repeat_interleave(num_docs_t, 0)[:, 1:]
        docs = doc_embeddings.encoding.input_ids[:, 1:]
        mask = (query[:, :, None] == docs[:, None, :]).to(token_similarity)
        token_similarity = token_similarity * mask
        token_scores = self.aggregate_similarity(
            token_similarity, query_embeddings.scoring_mask[:, 1:], doc_embeddings.scoring_mask[:, 1:], num_docs
        )

        output.scores = cls_scores + token_scores
        return output
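

# Illustrative sketch (not part of the module source): a simplified, self-contained
# version of the COIL token-scoring idea used in `score` above. Token similarities only
# contribute where a query token id exactly matches a document token id; for each query
# token the best matching document token is kept and the results are summed over the
# query. It assumes dot-product similarity, a single query paired with a single document,
# and ignores [CLS] scoring, scoring masks, and the library's aggregate_similarity; the
# function name `_coil_token_score_sketch` is hypothetical.
def _coil_token_score_sketch(
    query_ids: torch.Tensor,  # (query_len,) token ids, [CLS] already removed
    doc_ids: torch.Tensor,  # (doc_len,) token ids, [CLS] already removed
    query_token_embeddings: torch.Tensor,  # (query_len, token_embedding_dim)
    doc_token_embeddings: torch.Tensor,  # (doc_len, token_embedding_dim)
) -> torch.Tensor:
    # Dot-product similarity between every query token and every document token.
    similarity = query_token_embeddings @ doc_token_embeddings.transpose(-1, -2)
    # Exact lexical match mask: entry (i, j) is non-zero only if query token i and
    # document token j share the same vocabulary id.
    mask = (query_ids[:, None] == doc_ids[None, :]).to(similarity)
    similarity = similarity * mask
    # For each query token, keep the best matching document token, then sum over the
    # query tokens to obtain the token-level relevance score.
    return similarity.amax(dim=-1).sum()


# Example call (token ids and embeddings are random placeholders): only the positions
# where the ids coincide contribute to the score.
# _coil_token_score_sketch(
#     torch.tensor([1996, 7329]),
#     torch.tensor([7329, 2003, 1996]),
#     torch.randn(2, 32),
#     torch.randn(3, 32),
# )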