Source code for lightning_ir.models.bi_encoders.col

  1"""Configuration, model, and tokenizer for Col (Contextualized Late Interaction) type models. Originally proposed in
  2`ColBERT: Efficient and Effective Passage Search via Contextualized Late Interaction over BERT \
  3<https://dl.acm.org/doi/abs/10.1145/3397271.3401075>`_ as the ColBERT model. This implementation generalizes the model
  4to work with any transformer backbone model.
  5"""
  6
  7from typing import Literal, Sequence
  8
  9import torch
 10from transformers import BatchEncoding
 11
 12from ...bi_encoder import BiEncoderEmbedding, BiEncoderTokenizer, MultiVectorBiEncoderConfig, MultiVectorBiEncoderModel
 13
 14
class ColConfig(MultiVectorBiEncoderConfig):
    """Configuration class for a Col model."""

    model_type = "col"
    """Model type for a Col model."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalization: Literal["l2"] | None = None,
        add_marker_tokens: bool = False,
        query_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        doc_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        query_aggregation_function: Literal["sum", "mean", "max"] = "sum",
        doc_aggregation_function: Literal["sum", "mean", "max"] = "max",
        embedding_dim: int = 128,
        projection: Literal["linear", "linear_no_bias"] = "linear",
        query_expansion: bool = False,
        attend_to_query_expanded_tokens: bool = False,
        doc_expansion: bool = False,
        attend_to_doc_expanded_tokens: bool = False,
        **kwargs,
    ):
        """A Col model encodes queries and documents separately and computes a late interaction score between the query
        and document embeddings. The aggregation behavior of the late-interaction function can be parameterized with
        the `aggregation_function` arguments. The dimensionality of the token embeddings is down-projected using a
        linear layer. Queries and documents can optionally be expanded with mask tokens. Optionally, a set of tokens
        can be ignored during scoring.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
            similarity_function (Literal["cosine", "dot"]): Similarity function to compute scores between query and
                document embeddings. Defaults to "dot".
            normalization (Literal['l2'] | None): Whether to normalize query and document embeddings. Defaults to None.
            add_marker_tokens (bool): Whether to add extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            query_mask_scoring_tokens (Sequence[str] | Literal["punctuation"] | None): Whether and which query tokens
                to ignore during scoring. Defaults to None.
            doc_mask_scoring_tokens (Sequence[str] | Literal["punctuation"] | None): Whether and which document tokens
                to ignore during scoring. Defaults to None.
            query_aggregation_function (Literal["sum", "mean", "max"]): How to aggregate similarity scores over query
                tokens. Defaults to "sum".
            doc_aggregation_function (Literal["sum", "mean", "max"]): How to aggregate similarity scores over document
                tokens. Defaults to "max".
            embedding_dim (int): The output embedding dimension. Defaults to 128.
            projection (Literal["linear", "linear_no_bias"]): Whether and how to project the output embeddings.
                Defaults to "linear". If set to "linear_no_bias", the projection layer will not have a bias term.
            query_expansion (bool): Whether to expand queries with mask tokens. Defaults to False.
            attend_to_query_expanded_tokens (bool): Whether to allow query tokens to attend to mask expanded query
                tokens. Defaults to False.
            doc_expansion (bool): Whether to expand documents with mask tokens. Defaults to False.
            attend_to_doc_expanded_tokens (bool): Whether to allow document tokens to attend to mask expanded document
                tokens. Defaults to False.
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            normalization=normalization,
            add_marker_tokens=add_marker_tokens,
            query_mask_scoring_tokens=query_mask_scoring_tokens,
            doc_mask_scoring_tokens=doc_mask_scoring_tokens,
            query_aggregation_function=query_aggregation_function,
            doc_aggregation_function=doc_aggregation_function,
            **kwargs,
        )
        self.embedding_dim = embedding_dim
        self.projection = projection
        self.query_expansion = query_expansion
        self.attend_to_query_expanded_tokens = attend_to_query_expanded_tokens
        self.doc_expansion = doc_expansion
        self.attend_to_doc_expanded_tokens = attend_to_doc_expanded_tokens

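# --- Illustrative usage sketch (not part of the original module) ---
# A minimal example of constructing a ColConfig with ColBERT-style settings. The keyword
# arguments mirror the constructor above; the concrete values are only examples, and the
# import path simply follows this module's location.
from lightning_ir.models.bi_encoders.col import ColConfig as _ExampleColConfig

example_config = _ExampleColConfig(
    query_length=32,
    doc_length=512,
    similarity_function="dot",
    query_aggregation_function="sum",
    doc_aggregation_function="max",
    embedding_dim=128,
    projection="linear_no_bias",
    add_marker_tokens=True,
)
print(example_config.model_type)  # -> "col"
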
class ColModel(MultiVectorBiEncoderModel):
    """Multi-vector late-interaction Col model. See :class:`.ColConfig` for configuration options."""

    config_class = ColConfig
    """Configuration class for the Col model."""

    def __init__(self, config: ColConfig, *args, **kwargs) -> None:
        """Initializes a Col model given a :class:`.ColConfig`.

        Args:
            config (ColConfig): Configuration for the Col model.
        Raises:
            ValueError: If the embedding dimension is not specified in the configuration.
        """
        super().__init__(config, *args, **kwargs)
        if config.embedding_dim is None:
            raise ValueError("Embedding dimension must be specified in the configuration.")
        self.projection = torch.nn.Linear(
            config.hidden_size, config.embedding_dim, bias="no_bias" not in config.projection
        )

    def scoring_mask(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> torch.Tensor:
        """Computes a scoring mask for batched tokenized text sequences which is used in the scoring function to mask
        out vectors during scoring.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            torch.Tensor: Scoring mask.
        """
        input_ids = encoding["input_ids"]
        attention_mask = encoding["attention_mask"]
        scoring_mask = attention_mask
        expansion = getattr(self.config, f"{input_type}_expansion")
        if expansion or scoring_mask is None:
            scoring_mask = torch.ones_like(input_ids, dtype=torch.bool)
        scoring_mask = scoring_mask.bool()
        mask_scoring_input_ids = getattr(self, f"{input_type}_mask_scoring_input_ids")
        if mask_scoring_input_ids is not None:
            ignore_mask = input_ids[..., None].eq(mask_scoring_input_ids.to(input_ids.device)).any(-1)
            scoring_mask = scoring_mask & ~ignore_mask
        return scoring_mask

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes batched tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask.
        """
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self.projection(embeddings)
        if self.config.normalization == "l2":
            embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
        scoring_mask = self.scoring_mask(encoding, input_type)
        return BiEncoderEmbedding(embeddings, scoring_mask, encoding)

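# --- Illustrative sketch (not part of the original module) ---
# The late-interaction scoring itself lives in the MultiVectorBiEncoderModel base class. The
# standalone snippet below only illustrates the computation that the default ColConfig
# parameterizes (dot-product similarity, "max" aggregation over document tokens, "sum" over
# query tokens); the tensor shapes and random values are made up for demonstration.
import torch

query_embeddings = torch.randn(2, 32, 128)               # (batch, query tokens, embedding_dim)
doc_embeddings = torch.randn(2, 180, 128)                 # (batch, doc tokens, embedding_dim)
doc_scoring_mask = torch.ones(2, 180, dtype=torch.bool)   # True for document tokens that may be scored

similarity = torch.einsum("bqe,bde->bqd", query_embeddings, doc_embeddings)   # token-level similarities
similarity = similarity.masked_fill(~doc_scoring_mask[:, None, :], float("-inf"))  # drop masked doc tokens
scores = similarity.max(dim=-1).values.sum(dim=-1)        # max over doc tokens, sum over query tokens
print(scores.shape)  # torch.Size([2])
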
class ColTokenizer(BiEncoderTokenizer):
    """:class:`.LightningIRTokenizer` for Col models."""

    config_class = ColConfig
    """Configuration class for the tokenizer."""

    def __init__(
        self,
        *args,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        add_marker_tokens: bool = False,
        query_expansion: bool = False,
        attend_to_query_expanded_tokens: bool = False,
        doc_expansion: bool = False,
        attend_to_doc_expanded_tokens: bool = False,
        **kwargs,
    ):
        """Initializes a Col model's tokenizer. Encodes queries and documents separately. Optionally adds marker tokens
        to encoded input sequences and expands queries and documents with mask tokens.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
            add_marker_tokens (bool): Whether to add extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            query_expansion (bool): Whether to expand queries with mask tokens. Defaults to False.
            attend_to_query_expanded_tokens (bool): Whether to allow query tokens to attend to mask expanded query
                tokens. Defaults to False.
            doc_expansion (bool): Whether to expand documents with mask tokens. Defaults to False.
            attend_to_doc_expanded_tokens (bool): Whether to allow document tokens to attend to mask expanded document
                tokens. Defaults to False.
        Raises:
            ValueError: If `add_marker_tokens` is True and a non-supported tokenizer is used.
        """
        super().__init__(
            *args,
            query_length=query_length,
            doc_length=doc_length,
            query_expansion=query_expansion,
            attend_to_query_expanded_tokens=attend_to_query_expanded_tokens,
            doc_expansion=doc_expansion,
            attend_to_doc_expanded_tokens=attend_to_doc_expanded_tokens,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.query_expansion = query_expansion
        self.attend_to_query_expanded_tokens = attend_to_query_expanded_tokens
        self.doc_expansion = doc_expansion
        self.attend_to_doc_expanded_tokens = attend_to_doc_expanded_tokens

    def _expand(self, encoding: BatchEncoding, attend_to_expanded_tokens: bool) -> BatchEncoding:
        """Applies mask expansion to the input encoding."""
        input_ids = encoding["input_ids"]
        input_ids[input_ids == self.pad_token_id] = self.mask_token_id
        encoding["input_ids"] = input_ids
        if attend_to_expanded_tokens:
            encoding["attention_mask"].fill_(1)
        return encoding

    def tokenize_input_sequence(
        self, text: Sequence[str] | str, input_type: Literal["query", "doc"], *args, **kwargs
    ) -> BatchEncoding:
        """Tokenizes an input sequence. This method is used to tokenize both queries and documents.

        Args:
            text (Sequence[str] | str): Input text to tokenize.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            BatchEncoding: Tokenized input sequences.
        """
        expansion = getattr(self, f"{input_type}_expansion")
        if expansion:
            kwargs["padding"] = "max_length"
        encoding = super().tokenize_input_sequence(text, input_type, *args, **kwargs)
        if expansion:
            encoding = self._expand(encoding, getattr(self, f"attend_to_{input_type}_expanded_tokens"))
        return encoding
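

# --- Illustrative sketch (not part of the original module) ---
# What _expand does to a padded query when query_expansion is enabled: padding positions are
# rewritten to the mask token so the model produces embeddings for them ([MASK]-based query
# expansion, as in ColBERT). The token ids below follow BERT's conventions (0 = [PAD],
# 103 = [MASK]) and are assumptions made only for this example.
import torch

pad_token_id, mask_token_id = 0, 103
input_ids = torch.tensor([[101, 2054, 2003, 102, 0, 0, 0, 0]])   # toy query token ids followed by [PAD]s
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0, 0]])

input_ids[input_ids == pad_token_id] = mask_token_id  # expand the query with mask tokens
attention_mask.fill_(1)  # attend_to_query_expanded_tokens=True: real tokens may attend to the expansion
print(input_ids.tolist())  # [[101, 2054, 2003, 102, 103, 103, 103, 103]]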