Source code for lightning_ir.models.col

  1"""Configuration, model, and tokenizer for Col (Contextualized Late Interaction) type models. Originally proposed in
  2`ColBERT: Efficient and Effective Passage Search via Contextualized Late Interaction over BERT \
  3<https://dl.acm.org/doi/abs/10.1145/3397271.3401075>`_ as the ColBERT model. This implementation generalizes the model
  4to work with any transformer backbone model.
  5"""
  6
  7from typing import Literal, Sequence
  8
  9import torch
 10from transformers import BatchEncoding
 11
 12from ..bi_encoder import BiEncoderEmbedding, BiEncoderTokenizer, MultiVectorBiEncoderConfig, MultiVectorBiEncoderModel
 13
 14
class ColConfig(MultiVectorBiEncoderConfig):
    """Configuration class for a Col model."""

    model_type = "col"
    """Model type for a Col model."""

    def __init__(
        self,
        query_length: int = 32,
        doc_length: int = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalize: bool = False,
        add_marker_tokens: bool = False,
        query_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        doc_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        query_aggregation_function: Literal["sum", "mean", "max", "harmonic_mean"] = "sum",
        doc_aggregation_function: Literal["sum", "mean", "max", "harmonic_mean"] = "max",
        embedding_dim: int = 128,
        projection: Literal["linear", "linear_no_bias"] = "linear",
        query_expansion: bool = False,
        attend_to_query_expanded_tokens: bool = False,
        doc_expansion: bool = False,
        attend_to_doc_expanded_tokens: bool = False,
        **kwargs,
    ):
        """A Col model encodes queries and documents separately and computes a late interaction score between the
        query and document embeddings. The aggregation behavior of the late-interaction function can be parameterized
        with the `aggregation_function` arguments. The dimensionality of the token embeddings is down-projected using
        a linear layer. Queries and documents can optionally be expanded with mask tokens. Optionally, a set of tokens
        can be ignored during scoring.

        Args:
            query_length (int): Maximum query length in number of tokens. Defaults to 32.
            doc_length (int): Maximum document length in number of tokens. Defaults to 512.
            similarity_function (Literal["cosine", "dot"]): Similarity function to compute scores between query and
                document embeddings. Defaults to "dot".
            normalize (bool): Whether to normalize query and document embeddings. Defaults to False.
            add_marker_tokens (bool): Whether to add extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            query_mask_scoring_tokens (Sequence[str] | Literal["punctuation"] | None): Whether and which query tokens
                to ignore during scoring. Defaults to None.
            doc_mask_scoring_tokens (Sequence[str] | Literal["punctuation"] | None): Whether and which document tokens
                to ignore during scoring. Defaults to None.
            query_aggregation_function (Literal["sum", "mean", "max", "harmonic_mean"]): How to aggregate similarity
                scores over query tokens. Defaults to "sum".
            doc_aggregation_function (Literal["sum", "mean", "max", "harmonic_mean"]): How to aggregate similarity
                scores over document tokens. Defaults to "max".
            embedding_dim (int): The output embedding dimension. Defaults to 128.
            projection (Literal["linear", "linear_no_bias"]): Whether and how to project the output embeddings.
                Defaults to "linear". If set to "linear_no_bias", the projection layer will not have a bias term.
            query_expansion (bool): Whether to expand queries with mask tokens. Defaults to False.
            attend_to_query_expanded_tokens (bool): Whether to allow query tokens to attend to mask expanded query
                tokens. Defaults to False.
            doc_expansion (bool): Whether to expand documents with mask tokens. Defaults to False.
            attend_to_doc_expanded_tokens (bool): Whether to allow document tokens to attend to mask expanded document
                tokens. Defaults to False.
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            add_marker_tokens=add_marker_tokens,
            query_mask_scoring_tokens=query_mask_scoring_tokens,
            doc_mask_scoring_tokens=doc_mask_scoring_tokens,
            query_aggregation_function=query_aggregation_function,
            doc_aggregation_function=doc_aggregation_function,
            **kwargs,
        )
        self.embedding_dim = embedding_dim
        self.projection = projection
        self.query_expansion = query_expansion
        self.attend_to_query_expanded_tokens = attend_to_query_expanded_tokens
        self.doc_expansion = doc_expansion
        self.attend_to_doc_expanded_tokens = attend_to_doc_expanded_tokens
        self.normalize = normalize
        self.add_marker_tokens = add_marker_tokens

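# Illustrative sketch (not part of the original module): constructing a ColConfig with settings
# reminiscent of the original ColBERT setup (marker tokens, query expansion, normalized embeddings,
# punctuation masking for documents). The values mirror parameters documented above; the helper
# name `_example_col_config` is hypothetical.
def _example_col_config() -> ColConfig:
    return ColConfig(
        query_length=32,
        doc_length=512,
        similarity_function="cosine",
        normalize=True,
        add_marker_tokens=True,
        doc_mask_scoring_tokens="punctuation",
        embedding_dim=128,
        query_expansion=True,
    )
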
class ColModel(MultiVectorBiEncoderModel):
    """Multi-vector late-interaction Col model. See :class:`.ColConfig` for configuration options."""

    config_class = ColConfig
    """Configuration class for the Col model."""

    def __init__(self, config: ColConfig, *args, **kwargs) -> None:
        """Initializes a Col model given a :class:`.ColConfig`.

        Args:
            config (ColConfig): Configuration for the Col model.
        Raises:
            ValueError: If the embedding dimension is not specified in the configuration.
        """
        super().__init__(config, *args, **kwargs)
        if config.embedding_dim is None:
            raise ValueError("Embedding dimension must be specified in the configuration.")
        self.projection = torch.nn.Linear(
            config.hidden_size, config.embedding_dim, bias="no_bias" not in config.projection
        )

    def scoring_mask(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> torch.Tensor:
        """Computes a scoring mask for batched tokenized text sequences which is used in the scoring function to mask
        out vectors during scoring.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            torch.Tensor: Scoring mask.
        """
        input_ids = encoding["input_ids"]
        attention_mask = encoding["attention_mask"]
        scoring_mask = attention_mask
        expansion = getattr(self.config, f"{input_type}_expansion")
        if expansion or scoring_mask is None:
            scoring_mask = torch.ones_like(input_ids, dtype=torch.bool)
        scoring_mask = scoring_mask.bool()
        mask_scoring_input_ids = getattr(self, f"{input_type}_mask_scoring_input_ids")
        if mask_scoring_input_ids is not None:
            ignore_mask = input_ids[..., None].eq(mask_scoring_input_ids.to(input_ids.device)).any(-1)
            scoring_mask = scoring_mask & ~ignore_mask
        return scoring_mask

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes batched tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask.
        """
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self.projection(embeddings)
        if self.config.normalize:
            embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
        scoring_mask = self.scoring_mask(encoding, input_type)
        return BiEncoderEmbedding(embeddings, scoring_mask, encoding)

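# Illustrative sketch (not part of the original module): a minimal, manual late-interaction (MaxSim)
# computation over the embeddings returned by ``encode``, mirroring the default configuration
# (dot-product similarity, "max" aggregation over document tokens, "sum" over query tokens).
# Assumes paired query/document batches of equal size and that BiEncoderEmbedding exposes the
# ``embeddings`` and ``scoring_mask`` fields it is constructed with above; the library's own scoring
# function should be preferred in practice. The helper name `_example_late_interaction` is hypothetical.
def _example_late_interaction(
    model: ColModel, query_encoding: BatchEncoding, doc_encoding: BatchEncoding
) -> torch.Tensor:
    query_embeddings = model.encode(query_encoding, "query")  # per-token query embeddings + scoring mask
    doc_embeddings = model.encode(doc_encoding, "doc")  # per-token document embeddings + scoring mask
    # token-level similarity matrix of shape (batch, num_query_tokens, num_doc_tokens)
    similarity = torch.einsum("bqd,bkd->bqk", query_embeddings.embeddings, doc_embeddings.embeddings)
    # exclude masked document tokens, take the maximum similarity per query token,
    # zero out masked query tokens, and sum over query tokens
    similarity = similarity.masked_fill(~doc_embeddings.scoring_mask[:, None, :], float("-inf"))
    per_query_token_scores = similarity.max(dim=-1).values
    per_query_token_scores = per_query_token_scores.masked_fill(~query_embeddings.scoring_mask, 0.0)
    return per_query_token_scores.sum(dim=-1)
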
class ColTokenizer(BiEncoderTokenizer):
    """:class:`.LightningIRTokenizer` for Col models."""

    config_class = ColConfig
    """Configuration class for the tokenizer."""

    def __init__(
        self,
        *args,
        query_length: int = 32,
        doc_length: int = 512,
        add_marker_tokens: bool = False,
        query_expansion: bool = False,
        attend_to_query_expanded_tokens: bool = False,
        doc_expansion: bool = False,
        attend_to_doc_expanded_tokens: bool = False,
        **kwargs,
    ):
        """Initializes a Col model's tokenizer. Encodes queries and documents separately. Optionally adds marker
        tokens to encoded input sequences and expands queries and documents with mask tokens.

        Args:
            query_length (int): Maximum query length in number of tokens. Defaults to 32.
            doc_length (int): Maximum document length in number of tokens. Defaults to 512.
            add_marker_tokens (bool): Whether to add extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            query_expansion (bool): Whether to expand queries with mask tokens. Defaults to False.
            attend_to_query_expanded_tokens (bool): Whether to allow query tokens to attend to mask expanded query
                tokens. Defaults to False.
            doc_expansion (bool): Whether to expand documents with mask tokens. Defaults to False.
            attend_to_doc_expanded_tokens (bool): Whether to allow document tokens to attend to mask expanded document
                tokens. Defaults to False.
        Raises:
            ValueError: If `add_marker_tokens` is True and a non-supported tokenizer is used.
        """
        super().__init__(
            *args,
            query_length=query_length,
            doc_length=doc_length,
            query_expansion=query_expansion,
            attend_to_query_expanded_tokens=attend_to_query_expanded_tokens,
            doc_expansion=doc_expansion,
            attend_to_doc_expanded_tokens=attend_to_doc_expanded_tokens,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.query_expansion = query_expansion
        self.attend_to_query_expanded_tokens = attend_to_query_expanded_tokens
        self.doc_expansion = doc_expansion
        self.attend_to_doc_expanded_tokens = attend_to_doc_expanded_tokens

    def _expand(self, encoding: BatchEncoding, attend_to_expanded_tokens: bool) -> BatchEncoding:
        """Applies mask expansion to the input encoding."""
        input_ids = encoding["input_ids"]
        input_ids[input_ids == self.pad_token_id] = self.mask_token_id
        encoding["input_ids"] = input_ids
        if attend_to_expanded_tokens:
            encoding["attention_mask"].fill_(1)
        return encoding

    def tokenize_input_sequence(
        self, text: Sequence[str] | str, input_type: Literal["query", "doc"], *args, **kwargs
    ) -> BatchEncoding:
        """Tokenizes an input sequence. This method is used to tokenize both queries and documents.

        Args:
            text (Sequence[str] | str): Input text to tokenize.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            BatchEncoding: Tokenized input sequences.
        """
        expansion = getattr(self, f"{input_type}_expansion")
        if expansion:
            kwargs["padding"] = "max_length"
        encoding = super().tokenize_input_sequence(text, input_type, *args, **kwargs)
        if expansion:
            encoding = self._expand(encoding, getattr(self, f"attend_to_{input_type}_expanded_tokens"))
        return encoding
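
# Illustrative end-to-end sketch (not part of the original module): tokenizing a query / document
# pair and encoding both with a Col model. Assumes the model and tokenizer were loaded elsewhere
# and that keyword arguments such as ``return_tensors="pt"`` are forwarded to the underlying
# Hugging Face tokenizer; the helper name `_example_tokenize_and_encode` is hypothetical.
def _example_tokenize_and_encode(
    model: ColModel, tokenizer: ColTokenizer
) -> tuple[BiEncoderEmbedding, BiEncoderEmbedding]:
    query_encoding = tokenizer.tokenize_input_sequence("what is late interaction?", "query", return_tensors="pt")
    doc_encoding = tokenizer.tokenize_input_sequence(
        "Late interaction computes token-level similarities between query and document embeddings.",
        "doc",
        return_tensors="pt",
    )
    return model.encode(query_encoding, "query"), model.encode(doc_encoding, "doc")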