Source code for lightning_ir.models.splade

  1"""Configuration and model for SPLADE (SParse Lexical AnD Expansion) type models. Originally proposed in
  2`SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking
  3<https://dl.acm.org/doi/abs/10.1145/3404835.3463098>`_.
  4"""
  5
  6import warnings
  7from pathlib import Path
  8from typing import Literal, Self
  9
 10import torch
 11from transformers import BatchEncoding
 12
 13from ..bi_encoder import BiEncoderEmbedding, SingleVectorBiEncoderConfig, SingleVectorBiEncoderModel
 14from ..modeling_utils.mlm_head import (
 15    MODEL_TYPE_TO_KEY_MAPPING,
 16    MODEL_TYPE_TO_LM_HEAD,
 17    MODEL_TYPE_TO_OUTPUT_EMBEDDINGS,
 18    MODEL_TYPE_TO_TIED_WEIGHTS_KEYS,
 19)
 20
 21
class SpladeConfig(SingleVectorBiEncoderConfig):
    """Configuration class for a SPLADE model."""

    model_type = "splade"
    """Model type for a SPLADE model."""

    def __init__(
        self,
        query_length: int = 32,
        doc_length: int = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        sparsification: Literal["relu", "relu_log"] | None = "relu_log",
        query_pooling_strategy: Literal["first", "mean", "max", "sum"] = "max",
        doc_pooling_strategy: Literal["first", "mean", "max", "sum"] = "max",
        **kwargs,
    ) -> None:
        """A SPLADE model encodes queries and documents separately. Before computing the similarity score, the
        contextualized token embeddings are projected into a logit distribution over the vocabulary using a
        pre-trained masked language model (MLM) head. The logit distribution is then sparsified and aggregated
        to obtain a single embedding for the query and document.

        Args:
            query_length (int): Maximum query length. Defaults to 32.
            doc_length (int): Maximum document length. Defaults to 512.
            similarity_function (Literal["cosine", "dot"]): Similarity function to compute scores between query
                and document embeddings. Defaults to "dot".
            sparsification (Literal["relu", "relu_log"] | None): Sparsification function to apply.
                Defaults to "relu_log".
            query_pooling_strategy (Literal["first", "mean", "max", "sum"]): Pooling strategy for query
                embeddings. Defaults to "max".
            doc_pooling_strategy (Literal["first", "mean", "max", "sum"]): Pooling strategy for document
                embeddings. Defaults to "max".
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            sparsification=sparsification,
            query_pooling_strategy=query_pooling_strategy,
            doc_pooling_strategy=doc_pooling_strategy,
            **kwargs,
        )

    @property
    def embedding_dim(self) -> int:
        # SPLADE embeddings live in vocabulary space, so the embedding dimension equals the vocabulary size.
        vocab_size = getattr(self, "vocab_size", None)
        if vocab_size is None:
            raise ValueError("Unable to determine embedding dimension.")
        return vocab_size

    @embedding_dim.setter
    def embedding_dim(self, value: int) -> None:
        # The embedding dimension is derived from the vocabulary size and cannot be set directly.
        pass

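# Illustrative sketch, not part of the original module: the "relu_log" sparsification named in
# :class:`SpladeConfig` corresponds to ``log(1 + relu(x))`` applied to the vocabulary logits, and the
# default "max" pooling takes the maximum over the sequence dimension. The actual implementation lives
# in the bi-encoder base classes; the tensor shapes below are assumptions made for illustration.
#
# >>> logits = torch.randn(2, 8, 30522)  # (batch size, sequence length, vocabulary size)
# >>> sparsified = torch.log1p(torch.relu(logits))  # "relu_log" sparsification
# >>> pooled = sparsified.max(dim=1).values  # "max" pooling -> one sparse vector per sequence
# >>> pooled.shape
# torch.Size([2, 30522])
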
class SpladeModel(SingleVectorBiEncoderModel):
    """Sparse lexical SPLADE model. See :class:`SpladeConfig` for configuration options."""

    config_class = SpladeConfig
    """Configuration class for a SPLADE model."""

    def __init__(self, config: SingleVectorBiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a SPLADE model given a :class:`SpladeConfig`.

        Args:
            config (SingleVectorBiEncoderConfig): Configuration for the SPLADE model.
        """
        super().__init__(config, *args, **kwargs)
        # grab language modeling head based on backbone model type
        layer_cls = MODEL_TYPE_TO_LM_HEAD[config.backbone_model_type or config.model_type]
        self.projection = layer_cls(config)
        tied_weight_keys = getattr(self, "_tied_weights_keys", []) or []
        tied_weight_keys = tied_weight_keys + [
            f"projection.{key}"
            for key in MODEL_TYPE_TO_TIED_WEIGHTS_KEYS[config.backbone_model_type or config.model_type]
        ]
        setattr(self, "_tied_weights_keys", tied_weight_keys)

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes a batch of tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask.
        """
        pooling_strategy = getattr(self.config, f"{input_type}_pooling_strategy")
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self.projection(embeddings)
        embeddings = self.sparsification(embeddings, self.config.sparsification)
        embeddings = self.pooling(embeddings, encoding["attention_mask"], pooling_strategy)
        return BiEncoderEmbedding(embeddings, None, encoding)

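    # Illustrative sketch, not part of the original module: encoding a query with a Hugging Face
    # tokenizer and passing the resulting BatchEncoding to ``encode``. The checkpoint name and the
    # ``embeddings`` attribute of :class:`BiEncoderEmbedding` are assumptions made for illustration.
    #
    # >>> from transformers import AutoTokenizer
    # >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    # >>> model = SpladeModel.from_pretrained("bert-base-uncased")
    # >>> encoding = tokenizer(["what is splade?"], return_tensors="pt")
    # >>> query_embedding = model.encode(encoding, input_type="query")
    # >>> query_embedding.embeddings.shape[-1] == model.config.vocab_size  # embeddings live in vocabulary space
    # True
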
    @classmethod
    def from_pretrained(cls, model_name_or_path: str | Path, *args, **kwargs) -> Self:
        """Loads a pretrained model and handles mapping the MLM head weights to the projection head weights. Wraps
        the transformers.PreTrainedModel.from_pretrained_ method to return a derived LightningIRModel.
        See :class:`LightningIRModelClassFactory` for more details.

        .. _transformers.PreTrainedModel.from_pretrained: \
            https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained

        .. ::doctest
        .. highlight:: python
        .. code-block:: python

            >>> # Loading using model class and backbone checkpoint
            >>> type(CrossEncoderModel.from_pretrained("bert-base-uncased"))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
            >>> # Loading using base class and backbone checkpoint
            >>> type(LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig()))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>

        Args:
            model_name_or_path (str | Path): Name or path of the pretrained model.
        Returns:
            Self: A derived LightningIRModel consisting of a backbone model and a LightningIRModel mixin.
        Raises:
            ValueError: If called on the abstract class :class:`SpladeModel` and no config is passed.
        """
        key_mapping = kwargs.pop("key_mapping", {})
        config = cls.config_class
        # map mlm projection keys
        model_type = config.backbone_model_type or config.model_type
        if model_type in MODEL_TYPE_TO_KEY_MAPPING:
            key_mapping.update(MODEL_TYPE_TO_KEY_MAPPING[model_type])
        if not key_mapping:
            warnings.warn(
                f"No mlm key mappings for model_type {model_type} were provided. "
                "The pre-trained mlm weights will not be loaded correctly."
            )
        model = super().from_pretrained(model_name_or_path, *args, key_mapping=key_mapping, **kwargs)
        return model

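    # Illustrative sketch, not part of the original module: ``from_pretrained`` remaps the backbone's
    # MLM-head parameter names onto the ``projection`` module via ``key_mapping``, so a plain
    # masked-language-model checkpoint can seed the SPLADE projection head. The checkpoint name is an
    # assumption made for illustration.
    #
    # >>> model = SpladeModel.from_pretrained("bert-base-uncased")  # MLM head weights load into ``projection``
    # >>> isinstance(model.projection, torch.nn.Module)
    # True
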
    def set_output_embeddings(self, new_embeddings: torch.nn.Module) -> None:
        if self.config.projection == "mlm":
            raise NotImplementedError("Setting output embeddings is not supported for models with MLM projection.")
        # TODO fix this (not super important, only necessary when additional tokens are added to the model)
        # module_names = MODEL_TYPE_TO_OUTPUT_EMBEDDINGS[self.config.backbone_model_type or self.config.model_type]
        # module = self
        # for module_name in module_names.split(".")[:-1]:
        #     module = getattr(module, module_name)
        # setattr(module, module_names.split(".")[-1], new_embeddings)
        # setattr(module, "bias", new_embeddings.bias)

    def get_output_embeddings(self) -> torch.nn.Module | None:
        """Returns the output embeddings of the model for tying the input and output embeddings. Returns None if no
        MLM head is used for projection.

        Returns:
            torch.nn.Module | None: Output embeddings of the model.
        """
        module_names = MODEL_TYPE_TO_OUTPUT_EMBEDDINGS[self.config.backbone_model_type or self.config.model_type]
        output = self.projection
        for module_name in module_names.split("."):
            output = getattr(output, module_name)
        return output
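
# Illustrative sketch, not part of the original module: because the MLM head doubles as the projection
# layer, its output embeddings are typically tied to the backbone's input embeddings (see
# ``_tied_weights_keys`` in ``__init__`` and ``get_output_embeddings`` above). Whether the weights are
# actually shared depends on the backbone configuration; the checkpoint name and the tying check below
# are assumptions made for illustration.
#
# >>> model = SpladeModel.from_pretrained("bert-base-uncased")
# >>> output_embeddings = model.get_output_embeddings()
# >>> input_embeddings = model.get_input_embeddings()
# >>> output_embeddings.weight.data_ptr() == input_embeddings.weight.data_ptr()
# True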