Source code for lightning_ir.models.bi_encoders.splade

  1"""Configuration and model for SPLADE (SParse Lexical AnD Expansion) type models. Originally proposed in
  2`SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking
  3<https://dl.acm.org/doi/abs/10.1145/3404835.3463098>`_.
  4"""
  5
  6import warnings
  7from collections.abc import Sequence
  8from pathlib import Path
  9from typing import Literal, Self
 10
 11import torch
 12from transformers import BatchEncoding
 13
 14from ...bi_encoder import (
 15    BiEncoderEmbedding,
 16    BiEncoderTokenizer,
 17    SingleVectorBiEncoderConfig,
 18    SingleVectorBiEncoderModel,
 19)
 20from ...modeling_utils.embedding_post_processing import Pooler, Sparsifier
 21from ...modeling_utils.lm_head import MODEL_TYPE_TO_LM_HEAD, MODEL_TYPE_TO_STATE_DICT_KEY_MAPPING
 22
 23
class SpladeConfig(SingleVectorBiEncoderConfig):
    """Configuration class for a SPLADE model."""

    model_type = "splade"
    """Model type for a SPLADE model."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        sparsification_strategy: Literal["relu", "relu_log", "relu_2xlog"] | None = "relu_log",
        pooling_strategy: Literal["first", "mean", "max", "sum"] = "max",
        query_weighting: Literal["contextualized", "static"] | None = "contextualized",
        query_expansion: bool = True,
        doc_weighting: Literal["contextualized", "static"] | None = "contextualized",
        doc_expansion: bool = True,
        **kwargs,
    ) -> None:
        """A SPLADE model encodes queries and documents separately. Before computing the similarity score, the
        contextualized token embeddings are projected into a logit distribution over the vocabulary using a
        pre-trained masked language model (MLM) head. The logit distribution is then sparsified and aggregated to
        obtain a single embedding for the query and document.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None, does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None, does not truncate.
                Defaults to 512.
            similarity_function (Literal["cosine", "dot"]): Similarity function to compute scores between query and
                document embeddings. Defaults to "dot".
            sparsification_strategy (Literal["relu", "relu_log", "relu_2xlog"] | None): Whether and which
                sparsification function to apply. Defaults to "relu_log".
            pooling_strategy (Literal["first", "mean", "max", "sum"]): Pooling strategy for query embeddings.
                Defaults to "max".
            query_weighting (Literal["contextualized", "static"] | None): Whether to reweight query embeddings.
                Defaults to "contextualized".
            query_expansion (bool): Whether to allow query expansion. Defaults to True.
            doc_weighting (Literal["contextualized", "static"] | None): Whether to reweight document embeddings.
                Defaults to "contextualized".
            doc_expansion (bool): Whether to allow document expansion. Defaults to True.
        Raises:
            ValueError: If expansion is enabled for queries or documents without a corresponding weighting strategy.
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            sparsification_strategy=sparsification_strategy,
            pooling_strategy=pooling_strategy,
            **kwargs,
        )
        if query_expansion and not query_weighting:
            raise ValueError("If query_expansion is True, query_weighting must not be None.")
        if doc_expansion and not doc_weighting:
            raise ValueError("If doc_expansion is True, doc_weighting must not be None.")
        self.query_weighting = query_weighting
        self.query_expansion = query_expansion
        self.doc_weighting = doc_weighting
        self.doc_expansion = doc_expansion

    @property
    def embedding_dim(self) -> int:
        vocab_size = getattr(self, "vocab_size", None)
        if vocab_size is None:
            raise ValueError("Unable to determine embedding dimension.")
        return vocab_size

    @embedding_dim.setter
    def embedding_dim(self, value: int) -> None:
        pass

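# --- Usage sketch (illustrative, not part of the module) -------------------------------------
# A minimal sketch of how the configuration options above interact. The import path follows
# this module's location; all keyword arguments are the ones defined in SpladeConfig.__init__,
# and the commented-out call shows the validation error raised there.
#
#     from lightning_ir.models.bi_encoders.splade import SpladeConfig
#
#     # Default SPLADE-style setup: contextualized weighting and expansion for queries and documents.
#     config = SpladeConfig(query_length=32, doc_length=512, sparsification_strategy="relu_log")
#
#     # "Inference-free" query side: static term weights, no query expansion.
#     config = SpladeConfig(query_weighting="static", query_expansion=False)
#
#     # Invalid: expansion without a weighting strategy raises a ValueError.
#     # SpladeConfig(query_weighting=None, query_expansion=True)
# ----------------------------------------------------------------------------------------------
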

class SpladeModel(SingleVectorBiEncoderModel):
    """Sparse lexical SPLADE model. See :class:`SpladeConfig` for configuration options."""

    config_class = SpladeConfig
    """Configuration class for a SPLADE model."""

    def __init__(self, config: SingleVectorBiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a SPLADE model given a :class:`SpladeConfig`.

        Args:
            config (SingleVectorBiEncoderConfig): Configuration for the SPLADE model.
        """
        super().__init__(config, *args, **kwargs)
        # grab language modeling head based on backbone model type
        layer_cls = MODEL_TYPE_TO_LM_HEAD[config.backbone_model_type or config.model_type]
        self.projection = layer_cls(config)
        tied_weight_keys = (getattr(self, "_tied_weights_keys", []) or []) + ["projection.decoder.weight"]
        self._tied_weights_keys = tied_weight_keys
        self.query_weights = None
        if config.query_weighting == "static":
            self.query_weights = torch.nn.Embedding(config.vocab_size, 1)
        self.doc_weights = None
        if config.doc_weighting == "static":
            self.doc_weights = torch.nn.Embedding(config.vocab_size, 1)

        self.pooler = Pooler(config)
        self.sparsifier = Sparsifier(config)

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes batched tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask.
        """
        weighting = getattr(self.config, f"{input_type}_weighting")
        expansion = getattr(self.config, f"{input_type}_expansion")
        token_mask = None
        if weighting is None or weighting == "static" or not expansion:
            # build a vocabulary-sized mask that restricts the embedding to the input tokens,
            # weighted either statically (learned per-token weights) or by the attention mask
            token_mask = torch.zeros(
                encoding["input_ids"].shape[0],
                self.config.embedding_dim,
                device=encoding["input_ids"].device,
                dtype=torch.float32,
            )
            if weighting == "static":
                weights = getattr(self, f"{input_type}_weights")(encoding["input_ids"]).squeeze(-1)
            else:
                weights = encoding["attention_mask"].to(token_mask)
            weights = weights.masked_fill(~(encoding["attention_mask"].bool()), 0.0)
            token_mask = token_mask.scatter(1, encoding["input_ids"], weights)
            if weighting is None or weighting == "static":
                # inference free
                return BiEncoderEmbedding(token_mask[:, None], None, encoding)

        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self.projection(embeddings)
        embeddings = self.sparsifier(embeddings)
        embeddings = self.pooler(embeddings, encoding["attention_mask"])
        if token_mask is not None:
            embeddings = embeddings * token_mask[:, None]
        return BiEncoderEmbedding(embeddings, None, encoding)

    @classmethod
    def from_pretrained(cls, model_name_or_path: str | Path, *args, **kwargs) -> Self:
        """Loads a pretrained model and handles mapping the MLM head weights to the projection head weights. Wraps
        the transformers.PreTrainedModel.from_pretrained_ method to return a derived LightningIRModel.
        See :class:`LightningIRModelClassFactory` for more details.

        .. _transformers.PreTrainedModel.from_pretrained: \
            https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained

        .. ::doctest
        .. highlight:: python
        .. code-block:: python

            >>> # Loading using model class and backbone checkpoint
            >>> type(CrossEncoderModel.from_pretrained("bert-base-uncased"))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
            >>> # Loading using base class and backbone checkpoint
            >>> type(LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig()))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>

        Args:
            model_name_or_path (str | Path): Name or path of the pretrained model.
        Returns:
            Self: A derived LightningIRModel consisting of a backbone model and a LightningIRModel mixin.
        Raises:
            ValueError: If called on the abstract class :class:`SpladeModel` and no config is passed.
        """
        key_mapping = kwargs.pop("key_mapping", {})
        config = cls.config_class
        # map mlm projection keys
        model_type = config.backbone_model_type or config.model_type
        if model_type in MODEL_TYPE_TO_STATE_DICT_KEY_MAPPING:
            key_mapping.update(MODEL_TYPE_TO_STATE_DICT_KEY_MAPPING[model_type])
        if not key_mapping:
            warnings.warn(
                f"No mlm key mappings for model_type {model_type} were provided. "
                "The pre-trained mlm weights will not be loaded correctly.",
                stacklevel=2,
            )
        model = super().from_pretrained(model_name_or_path, *args, key_mapping=key_mapping, **kwargs)
        return model

    def set_output_embeddings(self, new_embeddings: torch.nn.Module) -> None:
        if self.config.projection == "mlm":
            raise NotImplementedError("Setting output embeddings is not supported for models with MLM projection.")
        # TODO fix this (not super important, only necessary when additional tokens are added to the model)
        # module_names = MODEL_TYPE_TO_OUTPUT_EMBEDDINGS[self.config.backbone_model_type or self.config.model_type]
        # module = self
        # for module_name in module_names.split(".")[:-1]:
        #     module = getattr(module, module_name)
        # setattr(module, module_names.split(".")[-1], new_embeddings)
        # setattr(module, "bias", new_embeddings.bias)

    def get_output_embeddings(self) -> torch.nn.Module | None:
        """Returns the output embeddings of the model for tying the input and output embeddings. Returns None if no
        MLM head is used for projection.

        Returns:
            torch.nn.Module | None: Output embeddings of the model.
        """
        return self.projection.decoder

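# --- Sketch of the sparsify + pool steps used in SpladeModel.encode (illustrative only) ------
# The actual Sparsifier and Pooler classes live in ...modeling_utils.embedding_post_processing;
# the functions below are a hedged stand-in showing what a "relu_log" sparsification followed by
# "max" pooling plausibly computes on MLM logits of shape (batch, seq_len, vocab_size).
#
#     import torch
#
#     def relu_log(logits: torch.Tensor) -> torch.Tensor:
#         # log-saturated ReLU as in the SPLADE paper: log(1 + relu(logits))
#         return torch.log1p(torch.relu(logits))
#
#     def max_pool(embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
#         # zero out padding positions, then take the per-vocabulary-term maximum over the sequence
#         embeddings = embeddings.masked_fill(~attention_mask.bool()[..., None], 0.0)
#         return embeddings.max(dim=1, keepdim=True).values  # (batch, 1, vocab_size)
#
#     logits = torch.randn(2, 5, 30522)          # fake MLM logits
#     mask = torch.ones(2, 5, dtype=torch.long)  # fake attention mask
#     single_vector = max_pool(relu_log(logits), mask)
# ----------------------------------------------------------------------------------------------
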

class SpladeTokenizer(BiEncoderTokenizer):
    """Tokenizer class for SPLADE models."""

    config_class = SpladeConfig
    """Configuration class for a SPLADE model."""

    def __init__(
        self,
        *args,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        add_marker_tokens: bool = False,
        query_weighting: Literal["contextualized", "static"] = "contextualized",
        doc_weighting: Literal["contextualized", "static"] = "contextualized",
        **kwargs,
    ):
        """Initializes a SPLADE model's tokenizer. Encodes queries and documents separately. Optionally adds
        marker tokens to encoded input sequences.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None, does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None, does not truncate.
                Defaults to 512.
            add_marker_tokens (bool): Whether to add extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            query_weighting (Literal["contextualized", "static"]): Whether to apply weighting to query tokens.
                Defaults to "contextualized".
            doc_weighting (Literal["contextualized", "static"]): Whether to apply weighting to document tokens.
                Defaults to "contextualized".
        Raises:
            ValueError: If `add_marker_tokens` is True and a non-supported tokenizer is used.
        """
        super().__init__(
            *args,
            query_length=query_length,
            doc_length=doc_length,
            add_marker_tokens=add_marker_tokens,
            query_weighting=query_weighting,
            doc_weighting=doc_weighting,
            **kwargs,
        )
        self.query_weighting = query_weighting
        self.doc_weighting = doc_weighting

    def tokenize_input_sequence(
        self, text: Sequence[str] | str, input_type: Literal["query", "doc"], *args, **kwargs
    ) -> BatchEncoding:
        """Tokenizes an input sequence. This method is used to tokenize both queries and documents.

        Args:
            text (Sequence[str] | str): Input text to tokenize.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            BatchEncoding: Tokenized input sequences.
        """
        post_processor = getattr(self, f"{input_type}_post_processor")
        kwargs["max_length"] = getattr(self, f"{input_type}_length")
        weighting = getattr(self, f"{input_type}_weighting")
        if weighting is None or weighting == "static":
            kwargs["add_special_tokens"] = False
        if "padding" not in kwargs:
            kwargs["truncation"] = True
        return self._encode(text, *args, post_processor=post_processor, **kwargs)
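
# --- End-to-end usage sketch (illustrative, not part of the module) --------------------------
# A hedged example of loading a SPLADE checkpoint and encoding a query and a document with the
# classes above. The checkpoint name "naver/splade-v3" and the extra keyword arguments passed to
# the tokenizer and to tokenize_input_sequence are assumptions, not guarantees of this API.
#
#     >>> model = SpladeModel.from_pretrained("naver/splade-v3")
#     >>> tokenizer = SpladeTokenizer.from_pretrained("naver/splade-v3", query_length=32, doc_length=512)
#     >>> query_encoding = tokenizer.tokenize_input_sequence(["what is splade"], "query", return_tensors="pt")
#     >>> doc_encoding = tokenizer.tokenize_input_sequence(
#     ...     ["SPLADE is a sparse lexical and expansion model."], "doc", return_tensors="pt"
#     ... )
#     >>> query_embedding = model.encode(query_encoding, "query")
#     >>> doc_embedding = model.encode(doc_encoding, "doc")
# ----------------------------------------------------------------------------------------------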