"""Configuration and model for SPLADE (SParse Lexical AnD Expansion) type models. Originally proposed in
`SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking
<https://dl.acm.org/doi/abs/10.1145/3404835.3463098>`_.
"""

import warnings
from pathlib import Path
from typing import Literal, Self, Sequence

import torch
from transformers import BatchEncoding

from ...bi_encoder import (
    BiEncoderEmbedding,
    BiEncoderTokenizer,
    SingleVectorBiEncoderConfig,
    SingleVectorBiEncoderModel,
)
from ...modeling_utils.mlm_head import (
    MODEL_TYPE_TO_LM_HEAD,
    MODEL_TYPE_TO_OUTPUT_EMBEDDINGS,
    MODEL_TYPE_TO_STATE_DICT_KEY_MAPPING,
    MODEL_TYPE_TO_TIED_WEIGHTS_KEYS,
)


class SpladeConfig(SingleVectorBiEncoderConfig):
    """Configuration class for a SPLADE model."""

    model_type = "splade"
    """Model type for a SPLADE model."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        sparsification: Literal["relu", "relu_log", "relu_2xlog"] | None = "relu_log",
        query_pooling_strategy: Literal["first", "mean", "max", "sum"] = "max",
        query_weighting: Literal["contextualized", "static"] | None = "contextualized",
        query_expansion: bool = True,
        doc_pooling_strategy: Literal["first", "mean", "max", "sum"] = "max",
        doc_weighting: Literal["contextualized", "static"] | None = "contextualized",
        doc_expansion: bool = True,
        **kwargs,
    ) -> None:
        """A SPLADE model encodes queries and documents separately. Before computing the similarity score, the
        contextualized token embeddings are projected into a logit distribution over the vocabulary using a pre-trained
        masked language model (MLM) head. The logit distribution is then sparsified and aggregated to obtain a single
        embedding for the query and document.

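        With the default ``relu_log`` sparsification and ``max`` pooling, the aggregated weight of vocabulary term
        ``j`` is ``w_j = max_i log(1 + relu(logit_ij))``, where ``i`` ranges over the input token positions
        (corresponding to the SPLADE-max formulation).
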
        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
            similarity_function (Literal["cosine", "dot"]): Similarity function to compute scores between query and
                document embeddings. Defaults to "dot".
            sparsification (Literal["relu", "relu_log", "relu_2xlog"] | None): Whether and which sparsification
                function to apply. Defaults to "relu_log".
            query_pooling_strategy (Literal["first", "mean", "max", "sum"]): Pooling strategy for query embeddings.
                Defaults to "max".
            query_weighting (Literal["contextualized", "static"] | None): Whether to reweight query embeddings.
                Defaults to "contextualized".
            query_expansion (bool): Whether to allow query expansion. Defaults to True.
            doc_pooling_strategy (Literal["first", "mean", "max", "sum"]): Pooling strategy for document embeddings.
                Defaults to "max".
            doc_weighting (Literal["contextualized", "static"] | None): Whether to reweight document embeddings.
                Defaults to "contextualized".
            doc_expansion (bool): Whether to allow document expansion. Defaults to True.
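
        For example, a configuration that uses static query term weights and disables query expansion (so that query
        embeddings can be computed without a forward pass through the backbone) might look as follows:

        .. code-block:: python

            config = SpladeConfig(query_expansion=False, query_weighting="static")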
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            sparsification=sparsification,
            query_pooling_strategy=query_pooling_strategy,
            doc_pooling_strategy=doc_pooling_strategy,
            **kwargs,
        )
        if query_expansion and not query_weighting:
            raise ValueError("If query_expansion is True, query_weighting must not be None.")
        if doc_expansion and not doc_weighting:
            raise ValueError("If doc_expansion is True, doc_weighting must not be None.")
        self.query_weighting = query_weighting
        self.query_expansion = query_expansion
        self.doc_weighting = doc_weighting
        self.doc_expansion = doc_expansion

    @property
    def embedding_dim(self) -> int:
        """Embedding dimension of a SPLADE model, which is equal to the vocabulary size of the backbone model."""
        vocab_size = getattr(self, "vocab_size", None)
        if vocab_size is None:
            raise ValueError("Unable to determine embedding dimension.")
        return vocab_size

    @embedding_dim.setter
    def embedding_dim(self, value: int) -> None:
        # the embedding dimension is derived from the vocabulary size; setting it has no effect
        pass


class SpladeModel(SingleVectorBiEncoderModel):
    """Sparse lexical SPLADE model. See :class:`SpladeConfig` for configuration options."""

    config_class = SpladeConfig
    """Configuration class for a SPLADE model."""

    def __init__(self, config: SingleVectorBiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a SPLADE model given a :class:`SpladeConfig`.

        Args:
            config (SingleVectorBiEncoderConfig): Configuration for the SPLADE model.
        """
        super().__init__(config, *args, **kwargs)
        # grab language modeling head based on backbone model type
        layer_cls = MODEL_TYPE_TO_LM_HEAD[config.backbone_model_type or config.model_type]
        self.projection = layer_cls(config)
        tied_weight_keys = getattr(self, "_tied_weights_keys", []) or []
        tied_weight_keys = tied_weight_keys + [
            f"projection.{key}"
            for key in MODEL_TYPE_TO_TIED_WEIGHTS_KEYS[config.backbone_model_type or config.model_type]
        ]
        setattr(self, "_tied_weights_keys", tied_weight_keys)
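        # for static (inference-free) weighting, a single learnable scalar weight per vocabulary token is used
        # instead of the contextualized MLM logits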
        self.query_weights = None
        if config.query_weighting == "static":
            self.query_weights = torch.nn.Embedding(config.vocab_size, 1)
        self.doc_weights = None
        if config.doc_weighting == "static":
            self.doc_weights = torch.nn.Embedding(config.vocab_size, 1)

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes batched tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask.
        """
        pooling_strategy = getattr(self.config, f"{input_type}_pooling_strategy")
        weighting = getattr(self.config, f"{input_type}_weighting")
        expansion = getattr(self.config, f"{input_type}_expansion")
        token_mask = None
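        # if weighting is static or disabled, or expansion is disabled, build a vocabulary-sized mask that only
        # carries weight at the vocabulary indices of the input tokens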
        if weighting is None or weighting == "static" or not expansion:
            token_mask = torch.zeros(
                encoding["input_ids"].shape[0],
                self.config.embedding_dim,
                device=encoding["input_ids"].device,
                dtype=torch.float32,
            )
            if weighting == "static":
                weights = getattr(self, f"{input_type}_weights")(encoding["input_ids"]).squeeze(-1)
            else:
                weights = encoding["attention_mask"].to(token_mask)
            weights = weights.masked_fill(~(encoding["attention_mask"].bool()), 0.0)
            token_mask = token_mask.scatter(1, encoding["input_ids"], weights)
        if weighting is None or weighting == "static":
            # inference free
            return BiEncoderEmbedding(token_mask[:, None], None, encoding)

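        # contextualized weighting: project the contextualized token embeddings onto the vocabulary with the MLM
        # head, sparsify, and pool into a single vocabulary-sized embedding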
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self.projection(embeddings)
        embeddings = self.sparsification(embeddings, self.config.sparsification)
        embeddings = self.pooling(embeddings, encoding["attention_mask"], pooling_strategy)
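        # without expansion, restrict the embedding to vocabulary terms that occur in the input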
        if token_mask is not None:
            embeddings = embeddings * token_mask[:, None]
        return BiEncoderEmbedding(embeddings, None, encoding)

    @classmethod
    def from_pretrained(cls, model_name_or_path: str | Path, *args, **kwargs) -> Self:
        """Loads a pretrained model and handles mapping the MLM head weights to the projection head weights. Wraps
        the transformers.PreTrainedModel.from_pretrained_ method to return a derived LightningIRModel.
        See :class:`LightningIRModelClassFactory` for more details.

        .. _transformers.PreTrainedModel.from_pretrained: \
            https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained

        .. ::doctest
        .. highlight:: python
        .. code-block:: python

            >>> # Loading using model class and backbone checkpoint
            >>> type(CrossEncoderModel.from_pretrained("bert-base-uncased"))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
            >>> # Loading using base class and backbone checkpoint
            >>> type(LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig()))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>

        Args:
            model_name_or_path (str | Path): Name or path of the pretrained model.
        Returns:
            Self: A derived LightningIRModel consisting of a backbone model and a LightningIRModel mixin.
        Raises:
            ValueError: If called on the abstract class :class:`SpladeModel` and no config is passed.
        """
        key_mapping = kwargs.pop("key_mapping", {})
        config = cls.config_class
        # map mlm projection keys
        model_type = config.backbone_model_type or config.model_type
        if model_type in MODEL_TYPE_TO_STATE_DICT_KEY_MAPPING:
            key_mapping.update(MODEL_TYPE_TO_STATE_DICT_KEY_MAPPING[model_type])
        if not key_mapping:
            warnings.warn(
                f"No mlm key mappings for model_type {model_type} were provided. "
                "The pre-trained mlm weights will not be loaded correctly."
            )
        model = super().from_pretrained(model_name_or_path, *args, key_mapping=key_mapping, **kwargs)
        return model

    def set_output_embeddings(self, new_embeddings: torch.nn.Module) -> None:
        if self.config.projection == "mlm":
            raise NotImplementedError("Setting output embeddings is not supported for models with MLM projection.")
        # TODO fix this (not super important, only necessary when additional tokens are added to the model)
        # module_names = MODEL_TYPE_TO_OUTPUT_EMBEDDINGS[self.config.backbone_model_type or self.config.model_type]
        # module = self
        # for module_name in module_names.split(".")[:-1]:
        #     module = getattr(module, module_name)
        # setattr(module, module_names.split(".")[-1], new_embeddings)
        # setattr(module, "bias", new_embeddings.bias)

    def get_output_embeddings(self) -> torch.nn.Module | None:
        """Returns the output embeddings of the model for tying the input and output embeddings. Returns None if no
        MLM head is used for projection.

        Returns:
            torch.nn.Module | None: Output embeddings of the model.
        """
        module_names = MODEL_TYPE_TO_OUTPUT_EMBEDDINGS[self.config.backbone_model_type or self.config.model_type]
        output = self.projection
        for module_name in module_names.split("."):
            output = getattr(output, module_name)
        return output


class SpladeTokenizer(BiEncoderTokenizer):
    """Tokenizer class for SPLADE models."""

    config_class = SpladeConfig
    """Configuration class for a SPLADE model."""

    def __init__(
        self,
        *args,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        add_marker_tokens: bool = False,
        query_weighting: Literal["contextualized", "static"] = "contextualized",
        doc_weighting: Literal["contextualized", "static"] = "contextualized",
        **kwargs,
    ):
        """Initializes a SPLADE model's tokenizer. Encodes queries and documents separately. Optionally adds
        marker tokens to encoded input sequences.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
            add_marker_tokens (bool): Whether to add extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            query_expansion (bool): Whether to expand queries with mask tokens. Defaults to False.
            attend_to_query_expanded_tokens (bool): Whether to allow query tokens to attend to mask expanded query
                tokens. Defaults to False.
            doc_expansion (bool): Whether to expand documents with mask tokens. Defaults to False.
            attend_to_doc_expanded_tokens (bool): Whether to allow document tokens to attend to mask expanded document
                tokens. Defaults to False.
            query_weighting (Literal["contextualized", "static"]): Whether to apply weighting to query tokens.
                Defaults to "contextualized".
            doc_weighting (Literal["contextualized", "static"]): Whether to apply weighting to document tokens.
                Defaults to "contextualized".
        Raises:
            ValueError: If `add_marker_tokens` is True and a non-supported tokenizer is used.
        """
        super().__init__(
            *args,
            query_length=query_length,
            doc_length=doc_length,
            add_marker_tokens=add_marker_tokens,
            query_weighting=query_weighting,
            doc_weighting=doc_weighting,
            **kwargs,
        )
        self.query_weighting = query_weighting
        self.doc_weighting = doc_weighting