Source code for lightning_ir.models.bi_encoders.splade

"""Configuration and model for SPLADE (SParse Lexical AnD Expansion) type models. Originally proposed in
`SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking
<https://dl.acm.org/doi/abs/10.1145/3404835.3463098>`_.
"""

import warnings
from pathlib import Path
from typing import Literal, Self, Sequence

import torch
from transformers import BatchEncoding

from ...bi_encoder import (
    BiEncoderEmbedding,
    BiEncoderTokenizer,
    SingleVectorBiEncoderConfig,
    SingleVectorBiEncoderModel,
)
from ...modeling_utils.mlm_head import (
    MODEL_TYPE_TO_LM_HEAD,
    MODEL_TYPE_TO_OUTPUT_EMBEDDINGS,
    MODEL_TYPE_TO_STATE_DICT_KEY_MAPPING,
    MODEL_TYPE_TO_TIED_WEIGHTS_KEYS,
)


class SpladeConfig(SingleVectorBiEncoderConfig):
    """Configuration class for a SPLADE model."""

    model_type = "splade"
    """Model type for a SPLADE model."""
    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        sparsification: Literal["relu", "relu_log", "relu_2xlog"] | None = "relu_log",
        query_pooling_strategy: Literal["first", "mean", "max", "sum"] = "max",
        query_weighting: Literal["contextualized", "static"] | None = "contextualized",
        query_expansion: bool = True,
        doc_pooling_strategy: Literal["first", "mean", "max", "sum"] = "max",
        doc_weighting: Literal["contextualized", "static"] | None = "contextualized",
        doc_expansion: bool = True,
        **kwargs,
    ) -> None:
        """A SPLADE model encodes queries and documents separately. Before computing the similarity score, the
        contextualized token embeddings are projected into a logit distribution over the vocabulary using a
        pre-trained masked language model (MLM) head. The logit distribution is then sparsified and aggregated to
        obtain a single embedding for the query and document.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None, does not truncate.
                Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None, does not truncate.
                Defaults to 512.
            similarity_function (Literal["cosine", "dot"]): Similarity function to compute scores between query and
                document embeddings. Defaults to "dot".
            sparsification (Literal["relu", "relu_log", "relu_2xlog"] | None): Which sparsification function to
                apply, if any. Defaults to "relu_log".
            query_pooling_strategy (Literal["first", "mean", "max", "sum"]): Pooling strategy for query embeddings.
                Defaults to "max".
            query_weighting (Literal["contextualized", "static"] | None): Whether and how to reweight query tokens.
                Defaults to "contextualized".
            query_expansion (bool): Whether to allow query expansion. Defaults to True.
            doc_pooling_strategy (Literal["first", "mean", "max", "sum"]): Pooling strategy for document embeddings.
                Defaults to "max".
            doc_weighting (Literal["contextualized", "static"] | None): Whether and how to reweight document tokens.
                Defaults to "contextualized".
            doc_expansion (bool): Whether to allow document expansion. Defaults to True.
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            sparsification=sparsification,
            query_pooling_strategy=query_pooling_strategy,
            doc_pooling_strategy=doc_pooling_strategy,
            **kwargs,
        )
        if query_expansion and not query_weighting:
            raise ValueError("If query_expansion is True, query_weighting must not be None.")
        if doc_expansion and not doc_weighting:
            raise ValueError("If doc_expansion is True, doc_weighting must not be None.")
        self.query_weighting = query_weighting
        self.query_expansion = query_expansion
        self.doc_weighting = doc_weighting
        self.doc_expansion = doc_expansion

    @property
    def embedding_dim(self) -> int:
        vocab_size = getattr(self, "vocab_size", None)
        if vocab_size is None:
            raise ValueError("Unable to determine embedding dimension.")
        return vocab_size

    @embedding_dim.setter
    def embedding_dim(self, value: int) -> None:
        pass

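
# --- Illustrative sketch, not part of the module source ---
# A minimal example of the projection -> sparsification -> pooling pipeline described in the
# SpladeConfig docstring, using the default "relu_log" sparsification and "max" pooling.
# The random logits, sequence length, and vocabulary size are assumptions for illustration only.
import torch

mlm_logits = torch.randn(2, 8, 30522)                # (batch, seq_len, vocab_size), e.g. MLM head output
attention_mask = torch.ones(2, 8, dtype=torch.bool)  # no padding in this toy batch
sparsified = torch.log1p(torch.relu(mlm_logits))     # "relu_log": log(1 + relu(logits))
sparsified = sparsified.masked_fill(~attention_mask[..., None], 0.0)  # zero out padded positions
embedding = sparsified.amax(dim=1)                   # "max" pooling -> one sparse vector per text
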

class SpladeModel(SingleVectorBiEncoderModel):
    """Sparse lexical SPLADE model. See :class:`SpladeConfig` for configuration options."""

    config_class = SpladeConfig
    """Configuration class for a SPLADE model."""
    def __init__(self, config: SingleVectorBiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a SPLADE model given a :class:`SpladeConfig`.

        Args:
            config (SingleVectorBiEncoderConfig): Configuration for the SPLADE model.
        """
        super().__init__(config, *args, **kwargs)
        # grab language modeling head based on backbone model type
        layer_cls = MODEL_TYPE_TO_LM_HEAD[config.backbone_model_type or config.model_type]
        self.projection = layer_cls(config)
        tied_weight_keys = getattr(self, "_tied_weights_keys", []) or []
        tied_weight_keys = tied_weight_keys + [
            f"projection.{key}"
            for key in MODEL_TYPE_TO_TIED_WEIGHTS_KEYS[config.backbone_model_type or config.model_type]
        ]
        setattr(self, "_tied_weights_keys", tied_weight_keys)
        self.query_weights = None
        if config.query_weighting == "static":
            self.query_weights = torch.nn.Embedding(config.vocab_size, 1)
        self.doc_weights = None
        if config.doc_weighting == "static":
            self.doc_weights = torch.nn.Embedding(config.vocab_size, 1)

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes a batch of tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask.
        """
        pooling_strategy = getattr(self.config, f"{input_type}_pooling_strategy")
        weighting = getattr(self.config, f"{input_type}_weighting")
        expansion = getattr(self.config, f"{input_type}_expansion")
        token_mask = None
        if weighting is None or weighting == "static" or not expansion:
            token_mask = torch.zeros(
                encoding["input_ids"].shape[0],
                self.config.embedding_dim,
                device=encoding["input_ids"].device,
                dtype=torch.float32,
            )
            if weighting == "static":
                weights = getattr(self, f"{input_type}_weights")(encoding["input_ids"]).squeeze(-1)
            else:
                weights = encoding["attention_mask"].to(token_mask)
            weights = weights.masked_fill(~(encoding["attention_mask"].bool()), 0.0)
            token_mask = token_mask.scatter(1, encoding["input_ids"], weights)
            if weighting is None or weighting == "static":
                # inference free
                return BiEncoderEmbedding(token_mask[:, None], None, encoding)

        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self.projection(embeddings)
        embeddings = self.sparsification(embeddings, self.config.sparsification)
        embeddings = self.pooling(embeddings, encoding["attention_mask"], pooling_strategy)
        if token_mask is not None:
            embeddings = embeddings * token_mask[:, None]
        return BiEncoderEmbedding(embeddings, None, encoding)

    @classmethod
    def from_pretrained(cls, model_name_or_path: str | Path, *args, **kwargs) -> Self:
        """Loads a pretrained model and handles mapping the MLM head weights to the projection head weights. Wraps
        the transformers.PreTrainedModel.from_pretrained_ method to return a derived LightningIRModel.
        See :class:`LightningIRModelClassFactory` for more details.

        .. _transformers.PreTrainedModel.from_pretrained: \
            https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained

        .. ::doctest
        .. highlight:: python
        .. code-block:: python

            >>> # Loading using model class and backbone checkpoint
            >>> type(CrossEncoderModel.from_pretrained("bert-base-uncased"))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
            >>> # Loading using base class and backbone checkpoint
            >>> type(LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig()))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>

        Args:
            model_name_or_path (str | Path): Name or path of the pretrained model.
        Returns:
            Self: A derived LightningIRModel consisting of a backbone model and a LightningIRModel mixin.
        Raises:
            ValueError: If called on the abstract class :class:`SpladeModel` and no config is passed.
        """
        key_mapping = kwargs.pop("key_mapping", {})
        config = cls.config_class
        # map mlm projection keys
        model_type = config.backbone_model_type or config.model_type
        if model_type in MODEL_TYPE_TO_STATE_DICT_KEY_MAPPING:
            key_mapping.update(MODEL_TYPE_TO_STATE_DICT_KEY_MAPPING[model_type])
        if not key_mapping:
            warnings.warn(
                f"No mlm key mappings for model_type {model_type} were provided. "
                "The pre-trained mlm weights will not be loaded correctly."
            )
        model = super().from_pretrained(model_name_or_path, *args, key_mapping=key_mapping, **kwargs)
        return model

    def set_output_embeddings(self, new_embeddings: torch.nn.Module) -> None:
        if self.config.projection == "mlm":
            raise NotImplementedError("Setting output embeddings is not supported for models with MLM projection.")
        # TODO fix this (not super important, only necessary when additional tokens are added to the model)
        # module_names = MODEL_TYPE_TO_OUTPUT_EMBEDDINGS[self.config.backbone_model_type or self.config.model_type]
        # module = self
        # for module_name in module_names.split(".")[:-1]:
        #     module = getattr(module, module_name)
        # setattr(module, module_names.split(".")[-1], new_embeddings)
        # setattr(module, "bias", new_embeddings.bias)

    def get_output_embeddings(self) -> torch.nn.Module | None:
        """Returns the output embeddings of the model for tying the input and output embeddings. Returns None if no
        MLM head is used for projection.

        Returns:
            torch.nn.Module | None: Output embeddings of the model.
        """
        module_names = MODEL_TYPE_TO_OUTPUT_EMBEDDINGS[self.config.backbone_model_type or self.config.model_type]
        output = self.projection
        for module_name in module_names.split("."):
            output = getattr(output, module_name)
        return output

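
# --- Illustrative sketch, not part of the module source ---
# How the "inference free" branch of SpladeModel.encode builds an embedding without a backbone
# forward pass: per-token weights are scattered into a (batch, vocab_size) tensor at the positions
# given by the input ids. The token ids, vocabulary size, and uniform weights are assumptions for
# illustration only.
import torch

input_ids = torch.tensor([[101, 2054, 2003, 102]])       # (batch, seq_len), hypothetical token ids
attention_mask = torch.tensor([[1, 1, 1, 1]])
vocab_size = 30522
token_mask = torch.zeros(input_ids.shape[0], vocab_size)
weights = attention_mask.to(token_mask)                   # one weight per non-padding token
weights = weights.masked_fill(~attention_mask.bool(), 0.0)
token_mask = token_mask.scatter(1, input_ids, weights)    # (batch, vocab_size) bag-of-tokens embedding
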

class SpladeTokenizer(BiEncoderTokenizer):
    """Tokenizer class for SPLADE models."""

    config_class = SpladeConfig
    """Configuration class for a SPLADE model."""
    def __init__(
        self,
        *args,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        add_marker_tokens: bool = False,
        query_weighting: Literal["contextualized", "static"] = "contextualized",
        doc_weighting: Literal["contextualized", "static"] = "contextualized",
        **kwargs,
    ):
        """Initializes a SPLADE model's tokenizer. Encodes queries and documents separately. Optionally adds
        marker tokens to encoded input sequences.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None, does not truncate.
                Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None, does not truncate.
                Defaults to 512.
            add_marker_tokens (bool): Whether to add extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            query_expansion (bool): Whether to expand queries with mask tokens. Defaults to False.
            attend_to_query_expanded_tokens (bool): Whether to allow query tokens to attend to mask expanded query
                tokens. Defaults to False.
            doc_expansion (bool): Whether to expand documents with mask tokens. Defaults to False.
            attend_to_doc_expanded_tokens (bool): Whether to allow document tokens to attend to mask expanded
                document tokens. Defaults to False.
            query_weighting (Literal["contextualized", "static"]): Whether to apply contextualized or static
                weighting to query tokens. Defaults to "contextualized".
            doc_weighting (Literal["contextualized", "static"]): Whether to apply contextualized or static
                weighting to document tokens. Defaults to "contextualized".
        Raises:
            ValueError: If `add_marker_tokens` is True and a non-supported tokenizer is used.
        """
        super().__init__(
            *args,
            query_length=query_length,
            doc_length=doc_length,
            add_marker_tokens=add_marker_tokens,
            query_weighting=query_weighting,
            doc_weighting=doc_weighting,
            **kwargs,
        )
        self.query_weighting = query_weighting
        self.doc_weighting = doc_weighting

    def tokenize_input_sequence(
        self, text: Sequence[str] | str, input_type: Literal["query", "doc"], *args, **kwargs
    ) -> BatchEncoding:
        """Tokenizes an input sequence. This method is used to tokenize both queries and documents.

        Args:
            text (Sequence[str] | str): Input text to tokenize.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            BatchEncoding: Tokenized input sequences.
        """
        post_processor = getattr(self, f"{input_type}_post_processor")
        kwargs["max_length"] = getattr(self, f"{input_type}_length")
        weighting = getattr(self, f"{input_type}_weighting")
        if weighting is None or weighting == "static":
            kwargs["add_special_tokens"] = False
        if "padding" not in kwargs:
            kwargs["truncation"] = True
        return self._encode(text, *args, post_processor=post_processor, **kwargs)
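

# --- Illustrative end-to-end sketch, not part of the module source ---
# A hedged usage example tying the pieces together: load a backbone checkpoint as a SPLADE model
# (so that the checkpoint's MLM head weights are remapped onto the projection head) together with
# the matching tokenizer, then tokenize and encode a query and a document. The top-level import
# path, the checkpoint name, the forwarded tokenizer kwargs, and the example texts are assumptions
# for illustration; adjust them to your setup.
from lightning_ir import SpladeConfig, SpladeModel, SpladeTokenizer

config = SpladeConfig(query_length=32, doc_length=512)
model = SpladeModel.from_pretrained("bert-base-uncased", config=config)
tokenizer = SpladeTokenizer.from_pretrained("bert-base-uncased", query_length=32, doc_length=512)
query_encoding = tokenizer.tokenize_input_sequence(
    "what is splade", "query", padding=True, return_tensors="pt"
)
doc_encoding = tokenizer.tokenize_input_sequence(
    ["SPLADE is a sparse retrieval model."], "doc", padding=True, return_tensors="pt"
)
query_embedding = model.encode(query_encoding, "query")
doc_embedding = model.encode(doc_encoding, "doc")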