1"""Configuration and model for SPLADE (SParse Lexical AnD Expansion) type models. Originally proposed in
2`SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking
3<https://dl.acm.org/doi/abs/10.1145/3404835.3463098>`_.
4"""
5
6import warnings
7from pathlib import Path
8from typing import Literal, Self
9
10import torch
11from transformers import BatchEncoding
12
13from ..bi_encoder import BiEncoderEmbedding, SingleVectorBiEncoderConfig, SingleVectorBiEncoderModel
14from ..modeling_utils.mlm_head import (
15 MODEL_TYPE_TO_KEY_MAPPING,
16 MODEL_TYPE_TO_LM_HEAD,
17 MODEL_TYPE_TO_OUTPUT_EMBEDDINGS,
18 MODEL_TYPE_TO_TIED_WEIGHTS_KEYS,
19)
20
21
class SpladeConfig(SingleVectorBiEncoderConfig):
    """Configuration class for a SPLADE model."""

    model_type = "splade"
    """Model type for a SPLADE model."""

    def __init__(
        self,
        query_length: int = 32,
        doc_length: int = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        sparsification: Literal["relu", "relu_log"] | None = "relu_log",
        query_pooling_strategy: Literal["first", "mean", "max", "sum"] = "max",
        doc_pooling_strategy: Literal["first", "mean", "max", "sum"] = "max",
        **kwargs,
    ) -> None:
38 """A SPLADE model encodes queries and documents separately. Before computing the similarity score, the
39 contextualized token embeddings are projected into a logit distribution over the vocabulary using a pre-trained
40 masked language model (MLM) head. The logit distribution is then sparsified and aggregated to obtain a single
41 embedding for the query and document.
42
43 :param query_length: Maximum query length, defaults to 32
44 :type query_length: int, optional
45 :param doc_length: Maximum document length, defaults to 512
46 :type doc_length: int, optional
47 :param similarity_function: Similarity function to compute scores between query and document embeddings,
48 defaults to "dot"
49 :type similarity_function: Literal['cosine', 'dot'], optional
50 :param sparsification: Whether and which sparsification function to apply, defaults to None
51 :type sparsification: Literal['relu', 'relu_log'] | None, optional
52 :param query_pooling_strategy: Whether and how to pool the query token embeddings, defaults to "max"
53 :type query_pooling_strategy: Literal['first', 'mean', 'max', 'sum'], optional
54 :param doc_pooling_strategy: Whether and how to pool document token embeddings, defaults to "max"
55 :type doc_pooling_strategy: Literal['first', 'mean', 'max', 'sum'], optional
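
        A conceptual sketch (not the exact implementation) of how a single sparse embedding is derived from the
        MLM logits of one sequence, assuming the default ``relu_log`` sparsification and ``max`` pooling:

        .. code-block:: python

            import torch

            seq_len, vocab_size = 32, 30522  # illustrative sizes only
            logits = torch.randn(seq_len, vocab_size)  # MLM head output, one logit vector per token
            sparsified = torch.log1p(torch.relu(logits))  # "relu_log" sparsification
            embedding = sparsified.amax(dim=0)  # "max" pooling over the token dimension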
56 """
57 super().__init__(
58 query_length=query_length,
59 doc_length=doc_length,
60 similarity_function=similarity_function,
61 sparsification=sparsification,
62 query_pooling_strategy=query_pooling_strategy,
63 doc_pooling_strategy=doc_pooling_strategy,
64 **kwargs,
65 )
66
    @property
    def embedding_dim(self) -> int:
        """Dimension of the SPLADE embeddings, i.e., the vocabulary size of the backbone model."""
        vocab_size = getattr(self, "vocab_size", None)
        if vocab_size is None:
            raise ValueError("Unable to determine embedding dimension.")
        return vocab_size

    @embedding_dim.setter
    def embedding_dim(self, value: int) -> None:
        # the embedding dimension is fixed to the vocabulary size and cannot be set directly
        pass


class SpladeModel(SingleVectorBiEncoderModel):
    """Sparse lexical SPLADE model. See :class:`SpladeConfig` for configuration options."""

    config_class = SpladeConfig
    """Configuration class for a SPLADE model."""

    def __init__(self, config: SingleVectorBiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a SPLADE model given a :class:`SpladeConfig`.

        :param config: Configuration for the SPLADE model
        :type config: SingleVectorBiEncoderConfig
        """
        super().__init__(config, *args, **kwargs)
        # grab language modeling head based on backbone model type
        layer_cls = MODEL_TYPE_TO_LM_HEAD[config.backbone_model_type or config.model_type]
        self.projection = layer_cls(config)
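        # register the projection (MLM head) weights as tied weights so they stay shared with the backbone's
        # input embeddings when the model is saved or loaded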
        tied_weight_keys = getattr(self, "_tied_weights_keys", []) or []
        tied_weight_keys = tied_weight_keys + [
            f"projection.{key}"
            for key in MODEL_TYPE_TO_TIED_WEIGHTS_KEYS[config.backbone_model_type or config.model_type]
        ]
        setattr(self, "_tied_weights_keys", tied_weight_keys)

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes a batch of tokenized text sequences and returns the embeddings and scoring mask.

        :param encoding: Tokenizer encodings for the text sequence
        :type encoding: BatchEncoding
        :param input_type: Type of input, either "query" or "doc"
        :type input_type: Literal["query", "doc"]
        :return: Embeddings and scoring mask
        :rtype: BiEncoderEmbedding
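
        A minimal usage sketch; ``model`` and ``tokenizer`` are assumed to be an already loaded SPLADE model and
        its matching tokenizer (illustrative names, not defined in this module):

        .. code-block:: python

            encoding = tokenizer(["splade produces sparse lexical embeddings"], return_tensors="pt")
            embedding = model.encode(encoding, input_type="doc")
            # embedding.embeddings holds one sparse vector over the vocabulary per input sequence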
111 """
112 pooling_strategy = getattr(self.config, f"{input_type}_pooling_strategy")
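        # token embeddings -> vocabulary logits (MLM head) -> sparsification -> pooled single sparse vector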
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self.projection(embeddings)
        embeddings = self.sparsification(embeddings, self.config.sparsification)
        embeddings = self.pooling(embeddings, encoding["attention_mask"], pooling_strategy)
        return BiEncoderEmbedding(embeddings, None, encoding)

    @classmethod
    def from_pretrained(cls, model_name_or_path: str | Path, *args, **kwargs) -> Self:
        """Loads a pretrained model and handles mapping the MLM head weights to the projection head weights. Wraps
        the transformers.PreTrainedModel.from_pretrained_ method to return a derived LightningIRModel.
        See :class:`LightningIRModelClassFactory` for more details.

        .. _transformers.PreTrainedModel.from_pretrained: \
            https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained

        :param model_name_or_path: Name or path of the pretrained model
        :type model_name_or_path: str | Path
        :raises ValueError: If called on the abstract class :class:`LightningIRModel` and no config is passed
        :return: A derived LightningIRModel consisting of a backbone model and a LightningIRModel mixin
        :rtype: LightningIRModel

        .. ::doctest
        .. highlight:: python
        .. code-block:: python

            >>> # Loading using model class and backbone checkpoint
            >>> type(CrossEncoderModel.from_pretrained("bert-base-uncased"))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
            >>> # Loading using base class and backbone checkpoint
            >>> type(LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig()))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
        """
        key_mapping = kwargs.pop("key_mapping", {})
        config = cls.config_class
        # map mlm projection keys
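        # (the backbone checkpoint stores the MLM head under a model-type-specific module name, so its weights
        # are remapped onto the projection head before loading)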
        model_type = config.backbone_model_type or config.model_type
        if model_type in MODEL_TYPE_TO_KEY_MAPPING:
            key_mapping.update(MODEL_TYPE_TO_KEY_MAPPING[model_type])
        if not key_mapping:
            warnings.warn(
                f"No mlm key mappings for model_type {model_type} were provided. "
                "The pre-trained mlm weights will not be loaded correctly."
            )
        model = super().from_pretrained(model_name_or_path, *args, key_mapping=key_mapping, **kwargs)
        return model

    def set_output_embeddings(self, new_embeddings: torch.nn.Module) -> None:
        if self.config.projection == "mlm":
            raise NotImplementedError("Setting output embeddings is not supported for models with MLM projection.")
        # TODO fix this (not super important, only necessary when additional tokens are added to the model)
        # module_names = MODEL_TYPE_TO_OUTPUT_EMBEDDINGS[self.config.backbone_model_type or self.config.model_type]
        # module = self
        # for module_name in module_names.split(".")[:-1]:
        #     module = getattr(module, module_name)
        # setattr(module, module_names.split(".")[-1], new_embeddings)
        # setattr(module, "bias", new_embeddings.bias)

    def get_output_embeddings(self) -> torch.nn.Module | None:
        """Returns the output embeddings of the model for tying the input and output embeddings. Returns None if no
        MLM head is used for projection.

        :return: Output embeddings of the model
        :rtype: torch.nn.Module | None
        """
        module_names = MODEL_TYPE_TO_OUTPUT_EMBEDDINGS[self.config.backbone_model_type or self.config.model_type]
        output = self.projection
        for module_name in module_names.split("."):
            output = getattr(output, module_name)
        return output