Source code for lightning_ir.models.col

  1"""Configuration, model, and tokenizer for Col (Contextualized Late Interaction) type models. Originally proposed in
  2`ColBERT: Efficient and Effective Passage Search via Contextualized Late Interaction over BERT \
  3<https://dl.acm.org/doi/abs/10.1145/3397271.3401075>`_ as the ColBERT model. This implementation generalizes the model
  4to work with any transformer backbone model.
  5"""
  6
  7from typing import Literal, Sequence
  8
  9import torch
 10from transformers import BatchEncoding
 11
 12from ..bi_encoder import BiEncoderEmbedding, BiEncoderTokenizer, MultiVectorBiEncoderConfig, MultiVectorBiEncoderModel
 13
 14
class ColConfig(MultiVectorBiEncoderConfig):
    """Configuration class for a Col model."""

    model_type = "col"
    """Model type for a Col model."""

    def __init__(
        self,
        query_length: int = 32,
        doc_length: int = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalize: bool = False,
        add_marker_tokens: bool = False,
        query_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        doc_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        query_aggregation_function: Literal["sum", "mean", "max", "harmonic_mean"] = "sum",
        doc_aggregation_function: Literal["sum", "mean", "max", "harmonic_mean"] = "max",
        embedding_dim: int = 128,
        projection: Literal["linear", "linear_no_bias"] = "linear",
        query_expansion: bool = False,
        attend_to_query_expanded_tokens: bool = False,
        doc_expansion: bool = False,
        attend_to_doc_expanded_tokens: bool = False,
        **kwargs,
    ):
        """A Col model encodes queries and documents separately and computes a late interaction score between the query
        and document embeddings. The aggregation behavior of the late-interaction function can be parameterized with
        the `aggregation_function` arguments. The dimensionality of the token embeddings is down-projected using a
        linear layer. Queries and documents can optionally be expanded with mask tokens. Optionally, a set of tokens
        can be ignored during scoring.

        :param query_length: Maximum query length, defaults to 32
        :type query_length: int, optional
        :param doc_length: Maximum document length, defaults to 512
        :type doc_length: int, optional
        :param similarity_function: Similarity function to compute scores between query and document embeddings,
            defaults to "dot"
        :type similarity_function: Literal['cosine', 'dot'], optional
        :param normalize: Whether to normalize query and document embeddings, defaults to False
        :type normalize: bool, optional
        :param add_marker_tokens: Whether to add extra marker tokens [Q] / [D] to queries / documents,
            defaults to False
        :type add_marker_tokens: bool, optional
        :param query_mask_scoring_tokens: Whether and which query tokens to ignore during scoring, defaults to None
        :type query_mask_scoring_tokens: Sequence[str] | Literal['punctuation'] | None, optional
        :param doc_mask_scoring_tokens: Whether and which document tokens to ignore during scoring, defaults to None
        :type doc_mask_scoring_tokens: Sequence[str] | Literal['punctuation'] | None, optional
        :param query_aggregation_function: How to aggregate similarity scores over query tokens, defaults to "sum"
        :type query_aggregation_function: Literal['sum', 'mean', 'max', 'harmonic_mean'], optional
        :param doc_aggregation_function: How to aggregate similarity scores over document tokens, defaults to "max"
        :type doc_aggregation_function: Literal['sum', 'mean', 'max', 'harmonic_mean'], optional
        :param embedding_dim: The output embedding dimension, defaults to 128
        :type embedding_dim: int, optional
        :param projection: Whether and how to project the output embeddings, defaults to "linear"
        :type projection: Literal['linear', 'linear_no_bias'], optional
        :param query_expansion: Whether to expand queries with mask tokens, defaults to False
        :type query_expansion: bool, optional
        :param attend_to_query_expanded_tokens: Whether to allow query tokens to attend to mask tokens,
            defaults to False
        :type attend_to_query_expanded_tokens: bool, optional
        :param doc_expansion: Whether to expand documents with mask tokens, defaults to False
        :type doc_expansion: bool, optional
        :param attend_to_doc_expanded_tokens: Whether to allow document tokens to attend to mask tokens,
            defaults to False
        :type attend_to_doc_expanded_tokens: bool, optional
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            add_marker_tokens=add_marker_tokens,
            query_mask_scoring_tokens=query_mask_scoring_tokens,
            doc_mask_scoring_tokens=doc_mask_scoring_tokens,
            query_aggregation_function=query_aggregation_function,
            doc_aggregation_function=doc_aggregation_function,
            **kwargs,
        )
        self.embedding_dim = embedding_dim
        self.projection = projection
        self.query_expansion = query_expansion
        self.attend_to_query_expanded_tokens = attend_to_query_expanded_tokens
        self.doc_expansion = doc_expansion
        self.attend_to_doc_expanded_tokens = attend_to_doc_expanded_tokens
        self.normalize = normalize
        self.add_marker_tokens = add_marker_tokens

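# A minimal configuration sketch (not part of the library source). Every option
# below is a ColConfig parameter defined above; the concrete values are
# illustrative assumptions, not recommended settings.
example_config = ColConfig(
    query_length=32,
    doc_length=512,
    similarity_function="dot",
    normalize=True,
    add_marker_tokens=True,
    query_expansion=True,
    doc_mask_scoring_tokens="punctuation",
    embedding_dim=128,
)
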
class ColModel(MultiVectorBiEncoderModel):
    """Multi-vector late-interaction Col model. See :class:`.ColConfig` for configuration options."""

    config_class = ColConfig
    """Configuration class for the Col model."""

    def __init__(self, config: ColConfig, *args, **kwargs) -> None:
        """Initializes a Col model given a :class:`.ColConfig`.

        :param config: Configuration for the Col model
        :type config: ColConfig
        """
        super().__init__(config, *args, **kwargs)
        if config.embedding_dim is None:
            raise ValueError("Embedding dimension must be specified in the configuration.")
        # Down-projects the backbone's token embeddings to embedding_dim; the bias
        # is dropped when the projection is configured as "linear_no_bias".
        self.projection = torch.nn.Linear(
            config.hidden_size, config.embedding_dim, bias="no_bias" not in config.projection
        )

    def scoring_mask(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> torch.Tensor:
        """Computes a scoring mask for batched tokenized text sequences which is used in the scoring function to mask
        out vectors during scoring.

        :param encoding: Tokenizer encodings for the text sequence
        :type encoding: BatchEncoding
        :param input_type: Type of input, either "query" or "doc"
        :type input_type: Literal["query", "doc"]
        :return: Scoring mask
        :rtype: torch.Tensor
        """
        input_ids = encoding["input_ids"]
        attention_mask = encoding["attention_mask"]
        scoring_mask = attention_mask
        expansion = getattr(self.config, f"{input_type}_expansion")
        if expansion or scoring_mask is None:
            # Expanded (mask) tokens contribute to scoring, so score every position.
            scoring_mask = torch.ones_like(input_ids, dtype=torch.bool)
        scoring_mask = scoring_mask.bool()
        mask_scoring_input_ids = getattr(self, f"{input_type}_mask_scoring_input_ids")
        if mask_scoring_input_ids is not None:
            # Exclude positions whose token id matches one of the ignored tokens
            # (e.g., punctuation) from scoring.
            ignore_mask = input_ids[..., None].eq(mask_scoring_input_ids.to(input_ids.device)).any(-1)
            scoring_mask = scoring_mask & ~ignore_mask
        return scoring_mask

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes batched tokenized text sequences and returns the embeddings and scoring mask.

        :param encoding: Tokenizer encodings for the text sequence
        :type encoding: BatchEncoding
        :param input_type: Type of input, either "query" or "doc"
        :type input_type: Literal["query", "doc"]
        :return: Embeddings and scoring mask
        :rtype: BiEncoderEmbedding
        """
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self.projection(embeddings)
        if self.config.normalize:
            embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
        scoring_mask = self.scoring_mask(encoding, input_type)
        return BiEncoderEmbedding(embeddings, scoring_mask, encoding)

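# A self-contained sketch (plain torch, independent of the classes above) of the
# late interaction that consumes ColModel's embeddings and scoring masks:
# token-level dot products, "max" aggregation over document tokens (MaxSim), and
# "sum" aggregation over query tokens, matching the ColConfig defaults.
def late_interaction_sketch(
    query_emb: torch.Tensor,  # (batch, query_len, emb_dim)
    doc_emb: torch.Tensor,  # (batch, doc_len, emb_dim)
    query_scoring_mask: torch.Tensor,  # (batch, query_len), bool
    doc_scoring_mask: torch.Tensor,  # (batch, doc_len), bool
) -> torch.Tensor:
    # Token-level similarity matrix of shape (batch, query_len, doc_len).
    sim = torch.einsum("bqe,bde->bqd", query_emb, doc_emb)
    # Masked document tokens must never be selected as the maximum.
    sim = sim.masked_fill(~doc_scoring_mask[:, None, :], float("-inf"))
    # "max" aggregation over document tokens.
    per_query_token = sim.max(dim=-1).values
    # Drop masked query tokens, then "sum" aggregation over query tokens.
    per_query_token = per_query_token.masked_fill(~query_scoring_mask, 0.0)
    return per_query_token.sum(dim=-1)
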
class ColTokenizer(BiEncoderTokenizer):
    """:class:`.LightningIRTokenizer` for Col models."""

    config_class = ColConfig
    """Configuration class for the tokenizer."""

    def __init__(
        self,
        *args,
        query_length: int = 32,
        doc_length: int = 512,
        add_marker_tokens: bool = False,
        query_expansion: bool = False,
        attend_to_query_expanded_tokens: bool = False,
        doc_expansion: bool = False,
        attend_to_doc_expanded_tokens: bool = False,
        **kwargs,
    ):
        """Initializes a Col model's tokenizer. Encodes queries and documents separately. Optionally adds marker tokens
        to encoded input sequences and expands queries and documents with mask tokens.

        :param query_length: Maximum query length in number of tokens, defaults to 32
        :type query_length: int, optional
        :param doc_length: Maximum document length in number of tokens, defaults to 512
        :type doc_length: int, optional
        :param add_marker_tokens: Whether to add marker tokens to the query and document input sequences,
            defaults to False
        :type add_marker_tokens: bool, optional
        :param query_expansion: Whether to expand queries with mask tokens, defaults to False
        :type query_expansion: bool, optional
        :param attend_to_query_expanded_tokens: Whether to let non-expanded query tokens attend to mask-expanded
            query tokens, defaults to False
        :type attend_to_query_expanded_tokens: bool, optional
        :param doc_expansion: Whether to expand documents with mask tokens, defaults to False
        :type doc_expansion: bool, optional
        :param attend_to_doc_expanded_tokens: Whether to let non-expanded document tokens attend to mask-expanded
            document tokens, defaults to False
        :type attend_to_doc_expanded_tokens: bool, optional
        :raises ValueError: If add_marker_tokens is True and a non-supported tokenizer is used
        """
        super().__init__(
            *args,
            query_length=query_length,
            doc_length=doc_length,
            query_expansion=query_expansion,
            attend_to_query_expanded_tokens=attend_to_query_expanded_tokens,
            doc_expansion=doc_expansion,
            attend_to_doc_expanded_tokens=attend_to_doc_expanded_tokens,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.query_expansion = query_expansion
        self.attend_to_query_expanded_tokens = attend_to_query_expanded_tokens
        self.doc_expansion = doc_expansion
        self.attend_to_doc_expanded_tokens = attend_to_doc_expanded_tokens

    def _expand(self, encoding: BatchEncoding, attend_to_expanded_tokens: bool) -> BatchEncoding:
        """Applies mask expansion to the input encoding."""
        input_ids = encoding["input_ids"]
        input_ids[input_ids == self.pad_token_id] = self.mask_token_id
        encoding["input_ids"] = input_ids
        if attend_to_expanded_tokens:
            encoding["attention_mask"].fill_(1)
        return encoding

    def tokenize_input_sequence(
        self, text: Sequence[str] | str, input_type: Literal["query", "doc"], *args, **kwargs
    ) -> BatchEncoding:
        """Tokenizes an input sequence. This method is used to tokenize both queries and documents.

        :param text: Single string or multiple strings to tokenize
        :type text: Sequence[str] | str
        :param input_type: Type of input, either "query" or "doc"
        :type input_type: Literal["query", "doc"]
        :return: Tokenized input sequences
        :rtype: BatchEncoding
        """
        expansion = getattr(self, f"{input_type}_expansion")
        if expansion:
            kwargs["padding"] = "max_length"
        encoding = super().tokenize_input_sequence(text, input_type, *args, **kwargs)
        if expansion:
            encoding = self._expand(encoding, getattr(self, f"attend_to_{input_type}_expanded_tokens"))
        return encoding
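
# A usage sketch under two assumptions: ColTokenizer exposes the Hugging Face
# from_pretrained interface, and the "bert-base-uncased" checkpoint is available;
# both are illustrative, not prescribed by this module. With query_expansion=True,
# queries are padded to query_length and _expand rewrites the padding into mask tokens.
tokenizer = ColTokenizer.from_pretrained("bert-base-uncased", query_length=32, query_expansion=True)
encoding = tokenizer.tokenize_input_sequence("what is late interaction", "query", return_tensors="pt")
# All 32 positions carry token ids; trailing positions hold mask_token_id
# instead of pad_token_id.
print(encoding["input_ids"].shape)  # expected: torch.Size([1, 32])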