Source code for lightning_ir.models.splade

"""Configuration and model for SPLADE (SParse Lexical AnD Expansion) type models. Originally proposed in
`SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking
<https://dl.acm.org/doi/abs/10.1145/3404835.3463098>`_.
"""

import warnings
from pathlib import Path
from typing import Literal, Self

import torch
from transformers import BatchEncoding

from ..bi_encoder import BiEncoderEmbedding, SingleVectorBiEncoderConfig, SingleVectorBiEncoderModel
from ..modeling_utils.mlm_head import (
    MODEL_TYPE_TO_KEY_MAPPING,
    MODEL_TYPE_TO_LM_HEAD,
    MODEL_TYPE_TO_OUTPUT_EMBEDDINGS,
    MODEL_TYPE_TO_TIED_WEIGHTS_KEYS,
)


class SpladeConfig(SingleVectorBiEncoderConfig):
    """Configuration class for a SPLADE model."""

    model_type = "splade"
    """Model type for a SPLADE model."""

    def __init__(
        self,
        query_length: int = 32,
        doc_length: int = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        sparsification: Literal["relu", "relu_log"] | None = "relu_log",
        query_pooling_strategy: Literal["first", "mean", "max", "sum"] = "max",
        doc_pooling_strategy: Literal["first", "mean", "max", "sum"] = "max",
        **kwargs,
    ) -> None:
        """A SPLADE model encodes queries and documents separately. Before computing the similarity score, the
        contextualized token embeddings are projected into a logit distribution over the vocabulary using a
        pre-trained masked language model (MLM) head. The logit distribution is then sparsified and aggregated to
        obtain a single embedding for the query and document.

        :param query_length: Maximum query length, defaults to 32
        :type query_length: int, optional
        :param doc_length: Maximum document length, defaults to 512
        :type doc_length: int, optional
        :param similarity_function: Similarity function to compute scores between query and document embeddings,
            defaults to "dot"
        :type similarity_function: Literal['cosine', 'dot'], optional
        :param sparsification: Whether and which sparsification function to apply, defaults to "relu_log"
        :type sparsification: Literal['relu', 'relu_log'] | None, optional
        :param query_pooling_strategy: Whether and how to pool the query token embeddings, defaults to "max"
        :type query_pooling_strategy: Literal['first', 'mean', 'max', 'sum'], optional
        :param doc_pooling_strategy: Whether and how to pool document token embeddings, defaults to "max"
        :type doc_pooling_strategy: Literal['first', 'mean', 'max', 'sum'], optional
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            sparsification=sparsification,
            query_pooling_strategy=query_pooling_strategy,
            doc_pooling_strategy=doc_pooling_strategy,
            **kwargs,
        )

    @property
    def embedding_dim(self) -> int:
        vocab_size = getattr(self, "vocab_size", None)
        if vocab_size is None:
            raise ValueError("Unable to determine embedding dimension.")
        return vocab_size

    @embedding_dim.setter
    def embedding_dim(self, value: int) -> None:
        pass
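

# Illustrative sketch (not part of the library source): constructing a SpladeConfig. The
# embedding dimension of a SPLADE model equals the backbone's vocabulary size, since the
# pooled embeddings live in vocabulary space. "relu_log" sparsification is assumed here to
# correspond to log(1 + relu(x)), following the SPLADE paper; the exact helper lives in the
# bi-encoder base class.
#
#     config = SpladeConfig(query_length=32, doc_length=512, sparsification="relu_log")
#     config.query_pooling_strategy  # "max"

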
class SpladeModel(SingleVectorBiEncoderModel):
    """Sparse lexical SPLADE model. See :class:`SpladeConfig` for configuration options."""

    config_class = SpladeConfig
    """Configuration class for a SPLADE model."""

    def __init__(self, config: SingleVectorBiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a SPLADE model given a :class:`SpladeConfig`.

        :param config: Configuration for the SPLADE model
        :type config: SingleVectorBiEncoderConfig
        """
        super().__init__(config, *args, **kwargs)
        # grab language modeling head based on backbone model type
        layer_cls = MODEL_TYPE_TO_LM_HEAD[config.backbone_model_type or config.model_type]
        self.projection = layer_cls(config)
        tied_weight_keys = getattr(self, "_tied_weights_keys", []) or []
        tied_weight_keys = tied_weight_keys + [
            f"projection.{key}"
            for key in MODEL_TYPE_TO_TIED_WEIGHTS_KEYS[config.backbone_model_type or config.model_type]
        ]
        setattr(self, "_tied_weights_keys", tied_weight_keys)
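
    # Hedged sketch (an assumption, not taken from this module's mapping tables): for a BERT
    # backbone, the MODEL_TYPE_TO_LM_HEAD lookup above is expected to resolve to transformers'
    # BertOnlyMLMHead, which projects hidden states to vocabulary-sized logits:
    #
    #     from transformers import BertConfig
    #     from transformers.models.bert.modeling_bert import BertOnlyMLMHead
    #
    #     head = BertOnlyMLMHead(BertConfig())
    #     head(torch.zeros(1, 4, 768)).shape  # torch.Size([1, 4, 30522])
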
    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes a batch of tokenized text sequences and returns the embeddings and scoring mask.

        :param encoding: Tokenizer encodings for the text sequences
        :type encoding: BatchEncoding
        :param input_type: Type of input, either "query" or "doc"
        :type input_type: Literal["query", "doc"]
        :return: Embeddings and scoring mask
        :rtype: BiEncoderEmbedding
        """
        pooling_strategy = getattr(self.config, f"{input_type}_pooling_strategy")
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self.projection(embeddings)
        embeddings = self.sparsification(embeddings, self.config.sparsification)
        embeddings = self.pooling(embeddings, encoding["attention_mask"], pooling_strategy)
        return BiEncoderEmbedding(embeddings, None, encoding)
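
    # Illustrative sketch (not part of the library source) of the pipeline in `encode` above,
    # written with plain torch ops. The base class's sparsification and pooling helpers may
    # differ in detail; "relu_log" is assumed to be log(1 + relu(x)) and "max" pooling a
    # masked maximum over the token dimension:
    #
    #     logits = torch.randn(2, 32, 30522)                        # (batch, tokens, vocab) MLM logits
    #     attention_mask = torch.ones(2, 32, dtype=torch.bool)
    #     sparse = torch.log1p(torch.relu(logits))                  # "relu_log" sparsification
    #     sparse = sparse.masked_fill(~attention_mask[..., None], 0.0)
    #     embedding = sparse.max(dim=1).values                      # "max" pooling -> (batch, vocab)
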

    @classmethod
    def from_pretrained(cls, model_name_or_path: str | Path, *args, **kwargs) -> Self:
        """Loads a pretrained model and handles mapping the MLM head weights to the projection head weights. Wraps
        the transformers.PreTrainedModel.from_pretrained_ method to return a derived LightningIRModel.
        See :class:`LightningIRModelClassFactory` for more details.

        .. _transformers.PreTrainedModel.from_pretrained: \
            https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained

        :param model_name_or_path: Name or path of the pretrained model
        :type model_name_or_path: str | Path
        :raises ValueError: If called on the abstract class :class:`LightningIRModel` and no config is passed
        :return: A derived LightningIRModel consisting of a backbone model and a LightningIRModel mixin
        :rtype: LightningIRModel

        .. ::doctest
        .. highlight:: python
        .. code-block:: python

            >>> # Loading using model class and backbone checkpoint
            >>> type(CrossEncoderModel.from_pretrained("bert-base-uncased"))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
            >>> # Loading using base class and backbone checkpoint
            >>> type(LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig()))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
        """
        key_mapping = kwargs.pop("key_mapping", {})
        config = cls.config_class
        # map mlm projection keys
        model_type = config.backbone_model_type or config.model_type
        if model_type in MODEL_TYPE_TO_KEY_MAPPING:
            key_mapping.update(MODEL_TYPE_TO_KEY_MAPPING[model_type])
        if not key_mapping:
            warnings.warn(
                f"No mlm key mappings for model_type {model_type} were provided. "
                "The pre-trained mlm weights will not be loaded correctly."
            )
        model = super().from_pretrained(model_name_or_path, *args, key_mapping=key_mapping, **kwargs)
        return model

    def set_output_embeddings(self, new_embeddings: torch.nn.Module) -> None:
        if self.config.projection == "mlm":
            raise NotImplementedError("Setting output embeddings is not supported for models with MLM projection.")
        # TODO fix this (not super important, only necessary when additional tokens are added to the model)
        # module_names = MODEL_TYPE_TO_OUTPUT_EMBEDDINGS[self.config.backbone_model_type or self.config.model_type]
        # module = self
        # for module_name in module_names.split(".")[:-1]:
        #     module = getattr(module, module_name)
        # setattr(module, module_names.split(".")[-1], new_embeddings)
        # setattr(module, "bias", new_embeddings.bias)

    def get_output_embeddings(self) -> torch.nn.Module | None:
        """Returns the output embeddings of the model for tying the input and output embeddings. Returns None if no
        MLM head is used for projection.

        :return: Output embeddings of the model
        :rtype: torch.nn.Module | None
        """
        module_names = MODEL_TYPE_TO_OUTPUT_EMBEDDINGS[self.config.backbone_model_type or self.config.model_type]
        output = self.projection
        for module_name in module_names.split("."):
            output = getattr(output, module_name)
        return output
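

# Hedged end-to-end usage sketch (not part of the library source). The checkpoint name, the
# tokenizer, and the `embeddings` attribute of BiEncoderEmbedding are assumptions for
# illustration; any masked-language-model backbone with an entry in MODEL_TYPE_TO_KEY_MAPPING
# should behave similarly:
#
#     from transformers import AutoTokenizer
#     from lightning_ir.models.splade import SpladeModel
#
#     model = SpladeModel.from_pretrained("bert-base-uncased")  # maps MLM head weights onto `projection`
#     tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#     query = tokenizer("what is splade", return_tensors="pt")
#     doc = tokenizer("SPLADE is a sparse lexical retrieval model", return_tensors="pt")
#     query_embedding = model.encode(query, "query").embeddings  # vocabulary-sized sparse vector
#     doc_embedding = model.encode(doc, "doc").embeddings
#     score = query_embedding @ doc_embedding.transpose(-1, -2)  # dot-product similarity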