1"""Configuration and model for SPLADE (SParse Lexical AnD Expansion) type models. Originally proposed in
2`SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking
3<https://dl.acm.org/doi/abs/10.1145/3404835.3463098>`_.
4"""
5
6import warnings
7from collections.abc import Sequence
8from pathlib import Path
9from typing import Literal, Self
10
11import torch
12from transformers import BatchEncoding
13
14from ...bi_encoder import (
15 BiEncoderEmbedding,
16 BiEncoderTokenizer,
17 SingleVectorBiEncoderConfig,
18 SingleVectorBiEncoderModel,
19)
20from ...modeling_utils.embedding_post_processing import Pooler, Sparsifier
21from ...modeling_utils.lm_head import MODEL_TYPE_TO_LM_HEAD, MODEL_TYPE_TO_STATE_DICT_KEY_MAPPING
22
23
class SpladeConfig(SingleVectorBiEncoderConfig):
    """Configuration class for a SPLADE model."""

    model_type = "splade"
    """Model type for a SPLADE model."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        sparsification_strategy: Literal["relu", "relu_log", "relu_2xlog"] | None = "relu_log",
        pooling_strategy: Literal["first", "mean", "max", "sum"] = "max",
        query_weighting: Literal["contextualized", "static"] | None = "contextualized",
        query_expansion: bool = True,
        doc_weighting: Literal["contextualized", "static"] | None = "contextualized",
        doc_expansion: bool = True,
        **kwargs,
    ) -> None:
43 """A SPLADE model encodes queries and documents separately. Before computing the similarity score, the
44 contextualized token embeddings are projected into a logit distribution over the vocabulary using a pre-trained
45 masked language model (MLM) head. The logit distribution is then sparsified and aggregated to obtain a single
46 embedding for the query and document.
47
48 Args:
49 query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
50 doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
51 similarity_function (Literal["cosine", "dot"]): Similarity function to compute scores between query and
52 document embeddings. Defaults to "dot".
53 sparsification_strategy (Literal['relu', 'relu_log', 'relu_2xlog'] | None): Whether and which sparsification
54 function to apply. Defaults to None.
55 pooling_strategy (Literal["first", "mean", "max", "sum"]): Pooling strategy for query embeddings.
56 Defaults to "max".
57 query_weighting (Literal["contextualized", "static"] | None): Whether to reweight query embeddings.
58 Defaults to "contextualized".
59 query_expansion (bool): Whether to allow query expansion. Defaults to True.
60 doc_weighting (Literal["contextualized", "static"] | None): Whether to reweight document embeddings.
61 Defaults to "contextualized".
62 doc_expansion (bool): Whether to allow document expansion. Defaults to True.
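
        Example (illustrative sketch, not a verified doctest): configuring an inference-free query side with static
        term weights while keeping the default contextualized document encoding.

        .. code-block:: python

            >>> config = SpladeConfig(query_weighting="static", query_expansion=False)
            >>> config.model_type
            'splade'
            >>> config.doc_weighting, config.doc_expansion
            ('contextualized', True)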
63 """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            sparsification_strategy=sparsification_strategy,
            pooling_strategy=pooling_strategy,
            **kwargs,
        )
        if query_expansion and not query_weighting:
            raise ValueError("If query_expansion is True, query_weighting must not be None.")
        if doc_expansion and not doc_weighting:
            raise ValueError("If doc_expansion is True, doc_weighting must not be None.")
        self.query_weighting = query_weighting
        self.query_expansion = query_expansion
        self.doc_weighting = doc_weighting
        self.doc_expansion = doc_expansion

    @property
    def embedding_dim(self) -> int:
        """Dimension of the embeddings. For SPLADE this is the size of the backbone model's vocabulary."""
        vocab_size = getattr(self, "vocab_size", None)
        if vocab_size is None:
            raise ValueError("Unable to determine embedding dimension.")
        return vocab_size

    @embedding_dim.setter
    def embedding_dim(self, value: int) -> None:
        # the embedding dimension is tied to the vocabulary size and cannot be set directly
        pass


class SpladeModel(SingleVectorBiEncoderModel):
    """Sparse lexical SPLADE model. See :class:`SpladeConfig` for configuration options."""

    config_class = SpladeConfig
    """Configuration class for a SPLADE model."""

    def __init__(self, config: SingleVectorBiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a SPLADE model given a :class:`SpladeConfig`.

        Args:
            config (SingleVectorBiEncoderConfig): Configuration for the SPLADE model.
        """
        super().__init__(config, *args, **kwargs)
        # grab language modeling head based on backbone model type
        layer_cls = MODEL_TYPE_TO_LM_HEAD[config.backbone_model_type or config.model_type]
        self.projection = layer_cls(config)
        # register the projection head's decoder weights for tying with the backbone's input embeddings
        tied_weight_keys = (getattr(self, "_tied_weights_keys", []) or []) + ["projection.decoder.weight"]
        self._tied_weights_keys = tied_weight_keys
        # learnable per-token weights for inference-free ("static") query / document weighting
        self.query_weights = None
        if config.query_weighting == "static":
            self.query_weights = torch.nn.Embedding(config.vocab_size, 1)
        self.doc_weights = None
        if config.doc_weighting == "static":
            self.doc_weights = torch.nn.Embedding(config.vocab_size, 1)

        self.pooler = Pooler(config)
        self.sparsifier = Sparsifier(config)

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes batched tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask.
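
        Example (illustrative sketch, not a verified doctest; assumes ``model`` and a matching ``tokenizer`` were
        loaded from a SPLADE checkpoint, e.g. via ``from_pretrained``):

        .. code-block:: python

            >>> encoding = tokenizer(["what is splade?"], return_tensors="pt")
            >>> query_embedding = model.encode(encoding, input_type="query")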
129 """
        weighting = getattr(self.config, f"{input_type}_weighting")
        expansion = getattr(self.config, f"{input_type}_expansion")
        token_mask = None
        if weighting is None or weighting == "static" or not expansion:
            # build a vocabulary-sized mask holding the (static or binary) weight of each input token;
            # it restricts the embedding to the input tokens when expansion is disabled
            token_mask = torch.zeros(
                encoding["input_ids"].shape[0],
                self.config.embedding_dim,
                device=encoding["input_ids"].device,
                dtype=torch.float32,
            )
            if weighting == "static":
                weights = getattr(self, f"{input_type}_weights")(encoding["input_ids"]).squeeze(-1)
            else:
                weights = encoding["attention_mask"].to(token_mask)
            weights = weights.masked_fill(~(encoding["attention_mask"].bool()), 0.0)
            token_mask = token_mask.scatter(1, encoding["input_ids"], weights)
        if weighting is None or weighting == "static":
            # inference free: the embedding is fully determined by the input tokens, no backbone forward pass
            return BiEncoderEmbedding(token_mask[:, None], None, encoding)

        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self.projection(embeddings)
        embeddings = self.sparsifier(embeddings)
        embeddings = self.pooler(embeddings, encoding["attention_mask"])
        if token_mask is not None:
            # zero out tokens outside the input when expansion is disabled
            embeddings = embeddings * token_mask[:, None]
        return BiEncoderEmbedding(embeddings, None, encoding)

    @classmethod
    def from_pretrained(cls, model_name_or_path: str | Path, *args, **kwargs) -> Self:
        """Loads a pretrained model and handles mapping the MLM head weights to the projection head weights. Wraps
        the transformers.PreTrainedModel.from_pretrained_ method to return a derived LightningIRModel.
        See :class:`LightningIRModelClassFactory` for more details.

        .. _transformers.PreTrainedModel.from_pretrained: \
            https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained

        .. ::doctest
        .. highlight:: python
        .. code-block:: python

            >>> # Loading using model class and backbone checkpoint
            >>> type(CrossEncoderModel.from_pretrained("bert-base-uncased"))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
            >>> # Loading using base class and backbone checkpoint
            >>> type(LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig()))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>

        Args:
            model_name_or_path (str | Path): Name or path of the pretrained model.
        Returns:
            Self: A derived LightningIRModel consisting of a backbone model and a LightningIRModel mixin.
        Raises:
            ValueError: If called on the abstract class :class:`SpladeModel` and no config is passed.
        """
        key_mapping = kwargs.pop("key_mapping", {})
        config = cls.config_class
        # map mlm projection keys
        model_type = config.backbone_model_type or config.model_type
        if model_type in MODEL_TYPE_TO_STATE_DICT_KEY_MAPPING:
            key_mapping.update(MODEL_TYPE_TO_STATE_DICT_KEY_MAPPING[model_type])
        if not key_mapping:
            warnings.warn(
                f"No mlm key mappings for model_type {model_type} were provided. "
                "The pre-trained mlm weights will not be loaded correctly.",
                stacklevel=2,
            )
        model = super().from_pretrained(model_name_or_path, *args, key_mapping=key_mapping, **kwargs)
        return model

    def set_output_embeddings(self, new_embeddings: torch.nn.Module) -> None:
        """Sets the output embeddings of the model. Not supported for models with MLM projection."""
        if self.config.projection == "mlm":
            raise NotImplementedError("Setting output embeddings is not supported for models with MLM projection.")
        # TODO fix this (not super important, only necessary when additional tokens are added to the model)
        # module_names = MODEL_TYPE_TO_OUTPUT_EMBEDDINGS[self.config.backbone_model_type or self.config.model_type]
        # module = self
        # for module_name in module_names.split(".")[:-1]:
        #     module = getattr(module, module_name)
        # setattr(module, module_names.split(".")[-1], new_embeddings)
        # setattr(module, "bias", new_embeddings.bias)

    def get_output_embeddings(self) -> torch.nn.Module | None:
        """Returns the output embeddings of the model for tying the input and output embeddings. Returns None if no
        MLM head is used for projection.

        Returns:
            torch.nn.Module | None: Output embeddings of the model.
        """
        return self.projection.decoder


class SpladeTokenizer(BiEncoderTokenizer):
    """Tokenizer class for SPLADE models."""

    config_class = SpladeConfig
    """Configuration class for a SPLADE model."""

    def __init__(
        self,
        *args,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        add_marker_tokens: bool = False,
        query_weighting: Literal["contextualized", "static"] = "contextualized",
        doc_weighting: Literal["contextualized", "static"] = "contextualized",
        **kwargs,
    ):
240 """Initializes a SPLADE model's tokenizer. Encodes queries and documents separately. Optionally adds
241 marker tokens to encoded input sequences.
242
243 Args:
244 query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
245 doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
246 add_marker_tokens (bool): Whether to add extra marker tokens [Q] / [D] to queries / documents.
247 Defaults to False.
248 query_expansion (bool): Whether to expand queries with mask tokens. Defaults to False.
249 attend_to_query_expanded_tokens (bool): Whether to allow query tokens to attend to mask expanded query
250 tokens. Defaults to False.
251 doc_expansion (bool): Whether to expand documents with mask tokens. Defaults to False.
252 attend_to_doc_expanded_tokens (bool): Whether to allow document tokens to attend to mask expanded document
253 tokens. Defaults to False.
254 query_weighting (Literal["contextualized", "static"]): Whether to apply weighting to query tokens.
255 Defaults to "contextualized".
256 doc_weighting (Literal["contextualized", "static"]): Whether to apply weighting to document tokens.
257 Defaults to "contextualized".
258 Raises:
259 ValueError: If `add_marker_tokens` is True and a non-supported tokenizer is used.
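
        Example (illustrative sketch, not a verified doctest; the checkpoint name is only a placeholder for a
        compatible backbone tokenizer):

        .. code-block:: python

            >>> tokenizer = SpladeTokenizer.from_pretrained("bert-base-uncased", query_length=32, doc_length=512)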
260 """
        super().__init__(
            *args,
            query_length=query_length,
            doc_length=doc_length,
            add_marker_tokens=add_marker_tokens,
            query_weighting=query_weighting,
            doc_weighting=doc_weighting,
            **kwargs,
        )
        self.query_weighting = query_weighting
        self.doc_weighting = doc_weighting