1"""Configuration and model for SPLADE (SParse Lexical AnD Expansion) type models.
2
SPLADE (SParse Lexical AnD Expansion) is an efficient retrieval model that bridges the gap between traditional
4keyword search and deep neural understanding. It uses a language model to analyze a document and assign importance
5scores to words across the entire vocabulary, expanding the text with highly relevant terms that were not originally
6present. This process creates a sparse, high-dimensional vector that can be stored and searched using inverted indices
7just like traditional search engines.
8
9Originally proposed in
10`SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking
11<https://dl.acm.org/doi/abs/10.1145/3404835.3463098>`_.
12"""
13
14import warnings
15from collections.abc import Sequence
16from pathlib import Path
17from typing import Literal, Self
18
19import torch
20from transformers import BatchEncoding
21
22from ...bi_encoder import (
23 BiEncoderEmbedding,
24 BiEncoderTokenizer,
25 SingleVectorBiEncoderConfig,
26 SingleVectorBiEncoderModel,
27)
28from ...modeling_utils.embedding_post_processing import Pooler, Sparsifier
29from ...modeling_utils.lm_head import MODEL_TYPE_TO_LM_HEAD, MODEL_TYPE_TO_STATE_DICT_KEY_MAPPING
30
31
class SpladeConfig(SingleVectorBiEncoderConfig):
    """Configuration class for a SPLADE model."""

    model_type = "splade"
    """Model type for a SPLADE model."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        sparsification_strategy: Literal["relu", "relu_log", "relu_2xlog"] | None = "relu_log",
        pooling_strategy: Literal["first", "mean", "max", "sum"] = "max",
        query_weighting: Literal["contextualized", "static"] | None = "contextualized",
        query_expansion: bool = True,
        doc_weighting: Literal["contextualized", "static"] | None = "contextualized",
        doc_expansion: bool = True,
        **kwargs,
    ) -> None:
        """A SPLADE model encodes queries and documents separately. Before computing the similarity score, the
        contextualized token embeddings are projected into a logit distribution over the vocabulary using a pre-trained
        masked language model (MLM) head. The logit distribution is then sparsified and aggregated to obtain a single
        embedding for the query and document.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
            similarity_function (Literal["cosine", "dot"]): Similarity function to compute scores between query and
                document embeddings. Defaults to "dot".
            sparsification_strategy (Literal["relu", "relu_log", "relu_2xlog"] | None): Whether and which
                sparsification function to apply. Defaults to "relu_log".
            pooling_strategy (Literal["first", "mean", "max", "sum"]): Pooling strategy for query embeddings.
                Defaults to "max".
            query_weighting (Literal["contextualized", "static"] | None): Whether to reweight query embeddings.
                Defaults to "contextualized".
            query_expansion (bool): Whether to allow query expansion. Defaults to True.
            doc_weighting (Literal["contextualized", "static"] | None): Whether to reweight document embeddings.
                Defaults to "contextualized".
            doc_expansion (bool): Whether to allow document expansion. Defaults to True.

        Raises:
            ValueError: If ``query_expansion`` is enabled without ``query_weighting``, or ``doc_expansion`` is
                enabled without ``doc_weighting``.
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            sparsification_strategy=sparsification_strategy,
            pooling_strategy=pooling_strategy,
            **kwargs,
        )
        # Expansion scores vocabulary terms that are not present in the input; without a weighting
        # scheme there is nothing to score them with, so the combination is rejected up front.
        if query_expansion and not query_weighting:
            raise ValueError("If query_expansion is True, query_weighting must also be True.")
        if doc_expansion and not doc_weighting:
            raise ValueError("If doc_expansion is True, doc_weighting must also be True.")
        self.query_weighting = query_weighting
        self.query_expansion = query_expansion
        self.doc_weighting = doc_weighting
        self.doc_expansion = doc_expansion

    @property
    def embedding_dim(self) -> int:
        """Dimension of the sparse embeddings, equal to the backbone vocabulary size.

        Raises:
            ValueError: If the backbone config exposes no ``vocab_size``.
        """
        vocab_size = getattr(self, "vocab_size", None)
        if vocab_size is None:
            raise ValueError("Unable to determine embedding dimension.")
        return vocab_size

    @embedding_dim.setter
    def embedding_dim(self, value: int) -> None:
        # The embedding dimension is fixed to the vocabulary size; attempts to set it
        # (e.g. by base-class initialization code) are deliberately ignored.
        pass
100
class SpladeModel(SingleVectorBiEncoderModel):
    """Sparse lexical SPLADE model. See :class:`SpladeConfig` for configuration options."""

    config_class = SpladeConfig
    """Configuration class for a SPLADE model."""

    def __init__(self, config: SingleVectorBiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a SPLADE model given a :class:`SpladeConfig`.

        Args:
            config (SingleVectorBiEncoderConfig): Configuration for the SPLADE model.
        """
        super().__init__(config, *args, **kwargs)
        # grab language modeling head based on backbone model type
        layer_cls = MODEL_TYPE_TO_LM_HEAD[config.backbone_model_type or config.model_type]
        self.projection = layer_cls(config)
        # register the MLM decoder weight as tied so it shares parameters with the input embeddings
        tied_weight_keys = (getattr(self, "_tied_weights_keys", []) or []) + ["projection.decoder.weight"]
        self._tied_weights_keys = tied_weight_keys
        # optional learned per-token static weights; used for "static" weighting, which skips the
        # backbone forward pass entirely (see `encode`)
        self.query_weights = None
        if config.query_weighting == "static":
            self.query_weights = torch.nn.Embedding(config.vocab_size, 1)
        self.doc_weights = None
        if config.doc_weighting == "static":
            self.doc_weights = torch.nn.Embedding(config.vocab_size, 1)

        self.pooler = Pooler(config)
        self.sparsifier = Sparsifier(config)

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes batched tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask.
        """
        # read the input-type-specific settings (query_weighting / doc_weighting etc.)
        weighting = getattr(self.config, f"{input_type}_weighting")
        expansion = getattr(self.config, f"{input_type}_expansion")
        token_mask = None
        # build a (batch, vocab_size) mask over the vocabulary whenever the output must be
        # restricted to (or fully determined by) the tokens actually present in the input:
        # no weighting, static weighting, or expansion disabled
        if weighting is None or weighting == "static" or not expansion:
            token_mask = torch.zeros(
                encoding["input_ids"].shape[0],
                self.config.embedding_dim,
                device=encoding["input_ids"].device,
                dtype=torch.float32,
            )
            if weighting == "static":
                # look up a learned scalar weight per input token id
                weights = getattr(self, f"{input_type}_weights")(encoding["input_ids"]).squeeze(-1)
            else:
                # binary weights from the attention mask (cast to the mask's dtype/device)
                weights = encoding["attention_mask"].to(token_mask)
            # zero out padding positions before scattering into the vocabulary dimension
            weights = weights.masked_fill(~(encoding["attention_mask"].bool()), 0.0)
            token_mask = token_mask.scatter(1, encoding["input_ids"], weights)
        if weighting is None or weighting == "static":
            # inference free: the scattered token mask IS the embedding; no backbone forward pass
            return BiEncoderEmbedding(token_mask[:, None], None, encoding)

        # contextualized path: backbone -> MLM head logits -> sparsify -> pool to a single vector
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self.projection(embeddings)
        embeddings = self.sparsifier(embeddings)
        embeddings = self.pooler(embeddings, encoding["attention_mask"])
        if token_mask is not None:
            # expansion disabled: keep only vocabulary entries seen in the input
            embeddings = embeddings * token_mask[:, None]
        return BiEncoderEmbedding(embeddings, None, encoding)

    @classmethod
    def from_pretrained(cls, model_name_or_path: str | Path, *args, **kwargs) -> Self:
        """Loads a pretrained model and handles mapping the MLM head weights to the projection head weights. Wraps
        the transformers.PreTrainedModel.from_pretrained_ method to return a derived LightningIRModel.
        See :class:`LightningIRModelClassFactory` for more details.

        .. _transformers.PreTrainedModel.from_pretrained: \
            https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained

        .. ::doctest
        .. highlight:: python
        .. code-block:: python

            >>> # Loading using model class and backbone checkpoint
            >>> type(CrossEncoderModel.from_pretrained("bert-base-uncased"))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
            >>> # Loading using base class and backbone checkpoint
            >>> type(LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig()))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>

        Args:
            model_name_or_path (str | Path): Name or path of the pretrained model.
        Returns:
            Self: A derived LightningIRModel consisting of a backbone model and a LightningIRModel mixin.
        Raises:
            ValueError: If called on the abstract class :class:`SpladeModel` and no config is passed.
        """
        # caller-supplied mappings take effect first; model-type defaults are merged on top
        key_mapping = kwargs.pop("key_mapping", {})
        # NOTE: this is the config *class*, not an instance; backbone_model_type/model_type are
        # read as class attributes
        config = cls.config_class
        # map mlm projection keys
        model_type = config.backbone_model_type or config.model_type
        if model_type in MODEL_TYPE_TO_STATE_DICT_KEY_MAPPING:
            key_mapping.update(MODEL_TYPE_TO_STATE_DICT_KEY_MAPPING[model_type])
        if not key_mapping:
            # without a mapping the checkpoint's MLM head weights cannot be routed into
            # self.projection; warn rather than fail so loading can still proceed
            warnings.warn(
                f"No mlm key mappings for model_type {model_type} were provided. "
                "The pre-trained mlm weights will not be loaded correctly.",
                stacklevel=2,
            )
        model = super().from_pretrained(model_name_or_path, *args, key_mapping=key_mapping, **kwargs)
        return model

    def set_output_embeddings(self, new_embeddings: torch.nn.Module) -> None:
        """Disallows replacing the output embeddings for MLM-projection models.

        Raises:
            NotImplementedError: If the model uses MLM projection.
        """
        # NOTE(review): `projection` is not defined on the visible SpladeConfig — presumably set by a
        # base/backbone config; confirm before relying on this check.
        if self.config.projection == "mlm":
            raise NotImplementedError("Setting output embeddings is not supported for models with MLM projection.")
        # TODO fix this (not super important, only necessary when additional tokens are added to the model)
        # module_names = MODEL_TYPE_TO_OUTPUT_EMBEDDINGS[self.config.backbone_model_type or self.config.model_type]
        # module = self
        # for module_name in module_names.split(".")[:-1]:
        #     module = getattr(module, module_name)
        # setattr(module, module_names.split(".")[-1], new_embeddings)
        # setattr(module, "bias", new_embeddings.bias)

    def get_output_embeddings(self) -> torch.nn.Module | None:
        """Returns the output embeddings of the model for tieing the input and output embeddings. Returns None if no
        MLM head is used for projection.

        Returns:
            torch.nn.Module | None: Output embeddings of the model.
        """
        return self.projection.decoder
228
class SpladeTokenizer(BiEncoderTokenizer):
    """Tokenizer class for SPLADE models."""

    config_class = SpladeConfig
    """Configuration class for a SPLADE model."""

    def __init__(
        self,
        *args,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        add_marker_tokens: bool = False,
        query_weighting: Literal["contextualized", "static"] = "contextualized",
        doc_weighting: Literal["contextualized", "static"] = "contextualized",
        **kwargs,
    ):
        """Initializes a SPLADE model's tokenizer. Encodes queries and documents separately. Optionally adds
        marker tokens to encoded input sequences.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
            add_marker_tokens (bool): Whether to add extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            query_weighting (Literal["contextualized", "static"]): Whether to apply weighting to query tokens.
                Defaults to "contextualized".
            doc_weighting (Literal["contextualized", "static"]): Whether to apply weighting to document tokens.
                Defaults to "contextualized".
        Raises:
            ValueError: If `add_marker_tokens` is True and a non-supported tokenizer is used.
        """
        # Initialize the base tokenizer exactly once, forwarding all keyword arguments. The original
        # code called super().__init__ twice (the first call without the weighting kwargs), running
        # base initialization twice, and its docstring sat after the first call as a no-op statement.
        super().__init__(
            *args,
            query_length=query_length,
            doc_length=doc_length,
            add_marker_tokens=add_marker_tokens,
            query_weighting=query_weighting,
            doc_weighting=doc_weighting,
            **kwargs,
        )
        self.query_weighting = query_weighting
        self.doc_weighting = doc_weighting