Source code for lightning_ir.models.bi_encoders.splade

"""Configuration and model for SPLADE (SParse Lexical AnD Expansion) type models.

SPLADE (SParse Lexical AnD Expansion) is an efficient retrieval model that bridges the gap between traditional
keyword search and deep neural understanding. It uses a language model to analyze a document and assign importance
scores to words across the entire vocabulary, expanding the text with highly relevant terms that were not originally
present. This process creates a sparse, high-dimensional vector that can be stored and searched using inverted indices
just like traditional search engines.

Originally proposed in
`SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking
<https://dl.acm.org/doi/abs/10.1145/3404835.3463098>`_.
"""
 13
 14import warnings
 15from collections.abc import Sequence
 16from pathlib import Path
 17from typing import Literal, Self
 18
 19import torch
 20from transformers import BatchEncoding
 21
 22from ...bi_encoder import (
 23    BiEncoderEmbedding,
 24    BiEncoderTokenizer,
 25    SingleVectorBiEncoderConfig,
 26    SingleVectorBiEncoderModel,
 27)
 28from ...modeling_utils.embedding_post_processing import Pooler, Sparsifier
 29from ...modeling_utils.lm_head import MODEL_TYPE_TO_LM_HEAD, MODEL_TYPE_TO_STATE_DICT_KEY_MAPPING
 30
 31
class SpladeConfig(SingleVectorBiEncoderConfig):
    """Configuration class for a SPLADE model."""

    model_type = "splade"
    """Model type for a SPLADE model."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        sparsification_strategy: Literal["relu", "relu_log", "relu_2xlog"] | None = "relu_log",
        pooling_strategy: Literal["first", "mean", "max", "sum"] = "max",
        query_weighting: Literal["contextualized", "static"] | None = "contextualized",
        query_expansion: bool = True,
        doc_weighting: Literal["contextualized", "static"] | None = "contextualized",
        doc_expansion: bool = True,
        **kwargs,
    ) -> None:
        """A SPLADE model encodes queries and documents separately. Before computing the similarity score, the
        contextualized token embeddings are projected into a logit distribution over the vocabulary using a pre-trained
        masked language model (MLM) head. The logit distribution is then sparsified and aggregated to obtain a single
        embedding for the query and document.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
            similarity_function (Literal["cosine", "dot"]): Similarity function to compute scores between query and
                document embeddings. Defaults to "dot".
            sparsification_strategy (Literal['relu', 'relu_log', 'relu_2xlog'] | None): Whether and which sparsification
                function to apply. Defaults to "relu_log".
            pooling_strategy (Literal["first", "mean", "max", "sum"]): Pooling strategy for query embeddings.
                Defaults to "max".
            query_weighting (Literal["contextualized", "static"] | None): Whether to reweight query embeddings.
                Defaults to "contextualized".
            query_expansion (bool): Whether to allow query expansion. Defaults to True.
            doc_weighting (Literal["contextualized", "static"] | None): Whether to reweight document embeddings.
                Defaults to "contextualized".
            doc_expansion (bool): Whether to allow document expansion. Defaults to True.
        Raises:
            ValueError: If expansion is enabled for queries or documents without a corresponding weighting strategy.
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            sparsification_strategy=sparsification_strategy,
            pooling_strategy=pooling_strategy,
            **kwargs,
        )
        # Expanded (non-input) tokens only carry signal if a weighting strategy assigns them mass,
        # so expansion without weighting is invalid.
        if query_expansion and not query_weighting:
            raise ValueError("If query_expansion is True, query_weighting must not be None.")
        if doc_expansion and not doc_weighting:
            raise ValueError("If doc_expansion is True, doc_weighting must not be None.")
        self.query_weighting = query_weighting
        self.query_expansion = query_expansion
        self.doc_weighting = doc_weighting
        self.doc_expansion = doc_expansion

    @property
    def embedding_dim(self) -> int:
        # SPLADE embeddings live in vocabulary space, so the embedding dimension is the
        # backbone's vocabulary size (set by the backbone config).
        vocab_size = getattr(self, "vocab_size", None)
        if vocab_size is None:
            raise ValueError("Unable to determine embedding dimension.")
        return vocab_size

    @embedding_dim.setter
    def embedding_dim(self, value: int) -> None:
        # The dimension is derived from vocab_size; assignments (e.g. from the base class) are ignored.
        pass
class SpladeModel(SingleVectorBiEncoderModel):
    """Sparse lexical SPLADE model. See :class:`SpladeConfig` for configuration options."""

    config_class = SpladeConfig
    """Configuration class for a SPLADE model."""

    def __init__(self, config: SingleVectorBiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a SPLADE model given a :class:`SpladeConfig`.

        Args:
            config (SingleVectorBiEncoderConfig): Configuration for the SPLADE model.
        """
        super().__init__(config, *args, **kwargs)
        # grab language modeling head based on backbone model type
        layer_cls = MODEL_TYPE_TO_LM_HEAD[config.backbone_model_type or config.model_type]
        self.projection = layer_cls(config)
        # Register the MLM decoder weight in _tied_weights_keys — presumably so transformers ties it
        # to the backbone's input embeddings; confirm against the base model's weight-tying logic.
        tied_weight_keys = (getattr(self, "_tied_weights_keys", []) or []) + ["projection.decoder.weight"]
        self._tied_weights_keys = tied_weight_keys
        # Learned per-vocabulary-id scalar weights, only created for "static" (inference-free) weighting.
        self.query_weights = None
        if config.query_weighting == "static":
            self.query_weights = torch.nn.Embedding(config.vocab_size, 1)
        self.doc_weights = None
        if config.doc_weighting == "static":
            self.doc_weights = torch.nn.Embedding(config.vocab_size, 1)

        self.pooler = Pooler(config)
        self.sparsifier = Sparsifier(config)

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes a batched tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): type of input, either "query" or "doc".
        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask.
        """
        weighting = getattr(self.config, f"{input_type}_weighting")
        expansion = getattr(self.config, f"{input_type}_expansion")
        token_mask = None
        # If expansion is disabled (or weighting is static/absent), only vocabulary ids that actually
        # occur in the input may receive mass: build a (batch, vocab_size) mask by scattering per-token
        # weights to their input_ids positions.
        if weighting is None or weighting == "static" or not expansion:
            token_mask = torch.zeros(
                encoding["input_ids"].shape[0],
                self.config.embedding_dim,
                device=encoding["input_ids"].device,
                dtype=torch.float32,
            )
            if weighting == "static":
                # learned scalar weight per vocabulary id
                weights = getattr(self, f"{input_type}_weights")(encoding["input_ids"]).squeeze(-1)
            else:
                # binary occurrence weights taken from the attention mask
                weights = encoding["attention_mask"].to(token_mask)
            # zero out padding positions before scattering
            weights = weights.masked_fill(~(encoding["attention_mask"].bool()), 0.0)
            token_mask = token_mask.scatter(1, encoding["input_ids"], weights)
            if weighting is None or weighting == "static":
                # inference free
                return BiEncoderEmbedding(token_mask[:, None], None, encoding)

        embeddings = self._backbone_forward(**encoding).last_hidden_state
        # project contextualized token embeddings to vocabulary logits via the MLM head
        embeddings = self.projection(embeddings)
        embeddings = self.sparsifier(embeddings)
        embeddings = self.pooler(embeddings, encoding["attention_mask"])
        if token_mask is not None:
            # expansion disabled: restrict the pooled embedding to tokens present in the input
            embeddings = embeddings * token_mask[:, None]
        return BiEncoderEmbedding(embeddings, None, encoding)

    @classmethod
    def from_pretrained(cls, model_name_or_path: str | Path, *args, **kwargs) -> Self:
        """Loads a pretrained model and handles mapping the MLM head weights to the projection head weights. Wraps
        the transformers.PreTrainedModel.from_pretrained_ method to return a derived LightningIRModel.
        See :class:`LightningIRModelClassFactory` for more details.

        .. _transformers.PreTrainedModel.from_pretrained: \
            https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained

        .. ::doctest
        .. highlight:: python
        .. code-block:: python

            >>> # Loading using model class and backbone checkpoint
            >>> type(CrossEncoderModel.from_pretrained("bert-base-uncased"))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
            >>> # Loading using base class and backbone checkpoint
            >>> type(LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig()))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>

        Args:
            model_name_or_path (str | Path): Name or path of the pretrained model.
        Returns:
            Self: A derived LightningIRModel consisting of a backbone model and a LightningIRModel mixin.
        Raises:
            ValueError: If called on the abstract class :class:`SpladeModel` and no config is passed.
        """
        key_mapping = kwargs.pop("key_mapping", {})
        # NOTE(review): `config` here is the config *class*, not an instance; `backbone_model_type`
        # is read as a class attribute. Confirm this is intended rather than reading the loaded config.
        config = cls.config_class
        # map mlm projection keys
        model_type = config.backbone_model_type or config.model_type
        if model_type in MODEL_TYPE_TO_STATE_DICT_KEY_MAPPING:
            key_mapping.update(MODEL_TYPE_TO_STATE_DICT_KEY_MAPPING[model_type])
        if not key_mapping:
            # without a mapping the checkpoint's MLM head weights cannot be routed into self.projection
            warnings.warn(
                f"No mlm key mappings for model_type {model_type} were provided. "
                "The pre-trained mlm weights will not be loaded correctly.",
                stacklevel=2,
            )
        model = super().from_pretrained(model_name_or_path, *args, key_mapping=key_mapping, **kwargs)
        return model

    def set_output_embeddings(self, new_embeddings: torch.nn.Module) -> None:
        """Disallows replacing the output embeddings when an MLM head is used for projection.

        Args:
            new_embeddings (torch.nn.Module): Replacement output embedding module.
        Raises:
            NotImplementedError: If the model projects via an MLM head.
        """
        # NOTE(review): `projection` is not defined on SpladeConfig in this file — presumably inherited
        # from the base bi-encoder config; verify the attribute exists for all config instances.
        if self.config.projection == "mlm":
            raise NotImplementedError("Setting output embeddings is not supported for models with MLM projection.")
        # TODO fix this (not super important, only necessary when additional tokens are added to the model)
        # module_names = MODEL_TYPE_TO_OUTPUT_EMBEDDINGS[self.config.backbone_model_type or self.config.model_type]
        # module = self
        # for module_name in module_names.split(".")[:-1]:
        #     module = getattr(module, module_name)
        # setattr(module, module_names.split(".")[-1], new_embeddings)
        # setattr(module, "bias", new_embeddings.bias)

    def get_output_embeddings(self) -> torch.nn.Module | None:
        """Returns the output embeddings of the model for tying the input and output embeddings. Returns None if no
        MLM head is used for projection.

        Returns:
            torch.nn.Module | None: Output embeddings of the model.
        """
        return self.projection.decoder
class SpladeTokenizer(BiEncoderTokenizer):
    """Tokenizer class for SPLADE models."""

    config_class = SpladeConfig
    """Configuration class for a SPLADE model."""

    def __init__(
        self,
        *args,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        add_marker_tokens: bool = False,
        query_weighting: Literal["contextualized", "static"] = "contextualized",
        doc_weighting: Literal["contextualized", "static"] = "contextualized",
        **kwargs,
    ):
        """Initializes a SPLADE model's tokenizer. Encodes queries and documents separately. Optionally adds
        marker tokens to encoded input sequences.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
            add_marker_tokens (bool): Whether to add extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            query_weighting (Literal["contextualized", "static"]): Whether to apply weighting to query tokens.
                Defaults to "contextualized".
            doc_weighting (Literal["contextualized", "static"]): Whether to apply weighting to document tokens.
                Defaults to "contextualized".
        Raises:
            ValueError: If `add_marker_tokens` is True and a non-supported tokenizer is used.
        """
        # Single base-class initialization. The original called super().__init__ twice (the first call
        # before the docstring and without the weighting arguments), initializing the base tokenizer
        # redundantly; only this call is needed.
        super().__init__(
            *args,
            query_length=query_length,
            doc_length=doc_length,
            add_marker_tokens=add_marker_tokens,
            query_weighting=query_weighting,
            doc_weighting=doc_weighting,
            **kwargs,
        )
        self.query_weighting = query_weighting
        self.doc_weighting = doc_weighting

    def tokenize_input_sequence(
        self, text: Sequence[str] | str, input_type: Literal["query", "doc"], *args, **kwargs
    ) -> BatchEncoding:
        """Tokenizes an input sequence. This method is used to tokenize both queries and documents.

        Args:
            text (Sequence[str] | str): Input text to tokenize.
            input_type (Literal["query", "doc"]): type of input, either "query" or "doc".
        Returns:
            BatchEncoding: Tokenized input sequences.
        """
        post_processor = getattr(self, f"{input_type}_post_processor")
        kwargs["max_length"] = getattr(self, f"{input_type}_length")
        weighting = getattr(self, f"{input_type}_weighting")
        if weighting is None or weighting == "static":
            # inference-free inputs are scattered by raw vocabulary id, so special tokens are omitted
            kwargs["add_special_tokens"] = False
        if "padding" not in kwargs:
            kwargs["truncation"] = True
        return self._encode(text, *args, post_processor=post_processor, **kwargs)