1"""Configuration and model for SPLADE (SParse Lexical AnD Expansion) type models. Originally proposed in
2`SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking
3<https://dl.acm.org/doi/abs/10.1145/3404835.3463098>`_.
4"""
5
6import warnings
7from pathlib import Path
8from typing import Literal, Self
9
10import torch
11from transformers import BatchEncoding
12
13from ..bi_encoder import BiEncoderEmbedding, SingleVectorBiEncoderConfig, SingleVectorBiEncoderModel
14from ..modeling_utils.mlm_head import (
15 MODEL_TYPE_TO_KEY_MAPPING,
16 MODEL_TYPE_TO_LM_HEAD,
17 MODEL_TYPE_TO_OUTPUT_EMBEDDINGS,
18 MODEL_TYPE_TO_TIED_WEIGHTS_KEYS,
19)
20
21
class SpladeConfig(SingleVectorBiEncoderConfig):
    """Configuration class for a SPLADE model."""

    model_type = "splade"
    """Model type for a SPLADE model."""

    def __init__(
        self,
        query_length: int = 32,
        doc_length: int = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        sparsification: Literal["relu", "relu_log"] | None = "relu_log",
        query_pooling_strategy: Literal["first", "mean", "max", "sum"] = "max",
        doc_pooling_strategy: Literal["first", "mean", "max", "sum"] = "max",
        **kwargs,
    ) -> None:
38 """A SPLADE model encodes queries and documents separately. Before computing the similarity score, the
39 contextualized token embeddings are projected into a logit distribution over the vocabulary using a pre-trained
40 masked language model (MLM) head. The logit distribution is then sparsified and aggregated to obtain a single
41 embedding for the query and document.
42
43 Args:
44 query_length (int): Maximum query length. Defaults to 32.
45 doc_length (int): Maximum document length. Defaults to 512.
46 similarity_function (Literal["cosine", "dot"]): Similarity function to compute scores between query and
47 document embeddings. Defaults to "dot".
48 sparsification (Literal["relu", "relu_log"] | None): Sparsification function to apply.
49 Defaults to "relu_log".
50 query_pooling_strategy (Literal["first", "mean", "max", "sum"]): Pooling strategy for query embeddings.
51 Defaults to "max".
52 doc_pooling_strategy (Literal["first", "mean", "max", "sum"]): Pooling strategy for document embeddings.
53 Defaults to "max".
54 """
55 super().__init__(
56 query_length=query_length,
57 doc_length=doc_length,
58 similarity_function=similarity_function,
59 sparsification=sparsification,
60 query_pooling_strategy=query_pooling_strategy,
61 doc_pooling_strategy=doc_pooling_strategy,
62 **kwargs,
63 )

    @property
    def embedding_dim(self) -> int:
        """Embedding dimension of the SPLADE embeddings, equal to the vocabulary size of the backbone model."""
        vocab_size = getattr(self, "vocab_size", None)
        if vocab_size is None:
            raise ValueError("Unable to determine embedding dimension.")
        return vocab_size

    @embedding_dim.setter
    def embedding_dim(self, value: int) -> None:
        # The embedding dimension is fixed to the vocabulary size and cannot be overridden.
        pass


class SpladeModel(SingleVectorBiEncoderModel):
    """Sparse lexical SPLADE model. See :class:`SpladeConfig` for configuration options."""

    config_class = SpladeConfig
    """Configuration class for a SPLADE model."""

    def __init__(self, config: SingleVectorBiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a SPLADE model given a :class:`SpladeConfig`.

        Args:
            config (SingleVectorBiEncoderConfig): Configuration for the SPLADE model.
        """
        super().__init__(config, *args, **kwargs)
        # grab language modeling head based on backbone model type
        layer_cls = MODEL_TYPE_TO_LM_HEAD[config.backbone_model_type or config.model_type]
        self.projection = layer_cls(config)
        tied_weight_keys = getattr(self, "_tied_weights_keys", []) or []
        tied_weight_keys = tied_weight_keys + [
            f"projection.{key}"
            for key in MODEL_TYPE_TO_TIED_WEIGHTS_KEYS[config.backbone_model_type or config.model_type]
        ]
        setattr(self, "_tied_weights_keys", tied_weight_keys)

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes batched tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".

        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask.
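
        A usage sketch (the checkpoint name and tokenizer pairing are illustrative assumptions, not part of this
        module):

        .. code-block:: python

            >>> from transformers import AutoTokenizer
            >>> model = SpladeModel.from_pretrained("naver/splade-v3")  # assumed SPLADE checkpoint
            >>> tokenizer = AutoTokenizer.from_pretrained("naver/splade-v3")
            >>> encoding = tokenizer("sparse retrieval with splade", return_tensors="pt")
            >>> embedding = model.encode(encoding, input_type="query")  # vocabulary-sized sparse embedding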
108 """
109 pooling_strategy = getattr(self.config, f"{input_type}_pooling_strategy")
110 embeddings = self._backbone_forward(**encoding).last_hidden_state
111 embeddings = self.projection(embeddings)
112 embeddings = self.sparsification(embeddings, self.config.sparsification)
113 embeddings = self.pooling(embeddings, encoding["attention_mask"], pooling_strategy)
114 return BiEncoderEmbedding(embeddings, None, encoding)
115
    @classmethod
    def from_pretrained(cls, model_name_or_path: str | Path, *args, **kwargs) -> Self:
        """Loads a pretrained model and handles mapping the MLM head weights to the projection head weights. Wraps
        the transformers.PreTrainedModel.from_pretrained_ method to return a derived LightningIRModel.
        See :class:`LightningIRModelClassFactory` for more details.

        .. _transformers.PreTrainedModel.from_pretrained: \
            https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained

        .. ::doctest
        .. highlight:: python
        .. code-block:: python

            >>> # Loading using model class and backbone checkpoint
            >>> type(CrossEncoderModel.from_pretrained("bert-base-uncased"))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
            >>> # Loading using base class and backbone checkpoint
            >>> type(LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig()))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>

        Args:
            model_name_or_path (str | Path): Name or path of the pretrained model.

        Returns:
            Self: A derived LightningIRModel consisting of a backbone model and a LightningIRModel mixin.

        Raises:
            ValueError: If called on the abstract class :class:`SpladeModel` and no config is passed.
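
        For SPLADE specifically, loading a plain MLM backbone checkpoint reuses its MLM head weights for the
        ``projection`` module (a sketch; the derived class is produced by the class factory and downloading the
        checkpoint is assumed to succeed):

        .. code-block:: python

            >>> model = SpladeModel.from_pretrained("bert-base-uncased")
            >>> hasattr(model, "projection")  # MLM head weights are mapped onto the projection head
            True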
142 """
143 key_mapping = kwargs.pop("key_mapping", {})
144 config = cls.config_class
145 # map mlm projection keys
146 model_type = config.backbone_model_type or config.model_type
147 if model_type in MODEL_TYPE_TO_KEY_MAPPING:
148 key_mapping.update(MODEL_TYPE_TO_KEY_MAPPING[model_type])
149 if not key_mapping:
150 warnings.warn(
151 f"No mlm key mappings for model_type {model_type} were provided. "
152 "The pre-trained mlm weights will not be loaded correctly."
153 )
154 model = super().from_pretrained(model_name_or_path, *args, key_mapping=key_mapping, **kwargs)
155 return model
156
    def set_output_embeddings(self, new_embeddings: torch.nn.Module) -> None:
        if self.config.projection == "mlm":
            raise NotImplementedError("Setting output embeddings is not supported for models with MLM projection.")
        # TODO fix this (not super important, only necessary when additional tokens are added to the model)
        # module_names = MODEL_TYPE_TO_OUTPUT_EMBEDDINGS[self.config.backbone_model_type or self.config.model_type]
        # module = self
        # for module_name in module_names.split(".")[:-1]:
        #     module = getattr(module, module_name)
        # setattr(module, module_names.split(".")[-1], new_embeddings)
        # setattr(module, "bias", new_embeddings.bias)

    def get_output_embeddings(self) -> torch.nn.Module | None:
        """Returns the output embeddings of the model for tying the input and output embeddings. Returns None if no
        MLM head is used for projection.

        Returns:
            torch.nn.Module | None: Output embeddings of the model.
        """
        module_names = MODEL_TYPE_TO_OUTPUT_EMBEDDINGS[self.config.backbone_model_type or self.config.model_type]
        output = self.projection
        for module_name in module_names.split("."):
            output = getattr(output, module_name)
        return output