SpladeModel
- class lightning_ir.models.bi_encoders.splade.SpladeModel(config: SingleVectorBiEncoderConfig, *args, **kwargs)[source]
Bases: SingleVectorBiEncoderModel
Sparse lexical SPLADE model. See SpladeConfig for configuration options.
- __init__(config: SingleVectorBiEncoderConfig, *args, **kwargs) None[source]
Initializes a SPLADE model given a SpladeConfig.
- Parameters:
config (SingleVectorBiEncoderConfig) – Configuration for the SPLADE model.
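Since SpladeModel is abstract until combined with a backbone, a concrete instance is usually obtained via from_pretrained (see below). A minimal sketch, assuming SpladeConfig is exported at the package top level and using bert-base-uncased purely as an illustrative backbone:
>>> from lightning_ir import LightningIRModel, SpladeConfig
>>> # Derive a concrete SPLADE model from a backbone checkpoint
>>> model = LightningIRModel.from_pretrained("bert-base-uncased", config=SpladeConfig())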
Methods
__init__(config, *args, **kwargs)
Initializes a SPLADE model given a SpladeConfig.
encode(encoding, input_type)
Encodes batched tokenized text sequences and returns the embeddings and scoring mask.
from_pretrained(model_name_or_path, *args, ...)
Loads a pretrained model and handles mapping the MLM head weights to the projection head weights.
get_output_embeddings()
Returns the output embeddings of the model for tying the input and output embeddings.
set_output_embeddings(new_embeddings)
Sets the model's output embeddings, defaulting to setting new_embeddings to lm_head.
Attributes
training
- config_class
Configuration class for a SPLADE model.
alias of SpladeConfig
- encode(encoding: BatchEncoding, input_type: 'query' | 'doc') BiEncoderEmbedding[source]
Encodes batched tokenized text sequences and returns the embeddings and scoring mask.
- Parameters:
encoding (BatchEncoding) – Tokenizer encodings for the text sequence.
input_type (Literal["query", "doc"]) – Type of input, either “query” or “doc”.
- Returns:
Embeddings and scoring mask.
- Return type:
BiEncoderEmbedding
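A usage sketch, assuming a Hugging Face tokenizer supplies the BatchEncoding, the model was loaded as in the example above, and BiEncoderEmbedding exposes the embeddings and scoring mask described here as embeddings and scoring_mask attributes:
>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
>>> encoding = tokenizer(["what is splade?"], return_tensors="pt", padding=True)
>>> embedding = model.encode(encoding, input_type="query")
>>> # embedding.embeddings holds one vocabulary-sized sparse vector per sequence
>>> # embedding.scoring_mask marks which vectors participate in scoring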
- classmethod from_pretrained(model_name_or_path: str | Path, *args, **kwargs) Self[source]
Loads a pretrained model and handles mapping the MLM head weights to the projection head weights. Wraps the transformers.PreTrainedModel.from_pretrained method to return a derived LightningIRModel. See LightningIRModelClassFactory for more details.
>>> # Loading using model class and backbone checkpoint
>>> type(CrossEncoderModel.from_pretrained("bert-base-uncased"))
<class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
>>> # Loading using base class and backbone checkpoint
>>> type(LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig()))
<class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
- Parameters:
model_name_or_path (str | Path) – Name or path of the pretrained model.
- Returns:
A derived LightningIRModel consisting of a backbone model and a LightningIRModel mixin.
- Return type:
Self
- Raises:
ValueError – If called on the abstract class SpladeModel and no config is passed.
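For a SPLADE checkpoint this weight mapping makes loading a one-liner; a sketch, where the checkpoint name naver/splade-v3 is an assumption and any checkpoint with an MLM head should behave analogously:
>>> from lightning_ir import SpladeModel
>>> # MLM head weights are mapped to the projection head on load
>>> model = SpladeModel.from_pretrained("naver/splade-v3")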
- get_output_embeddings() Module | None[source]
Returns the output embeddings of the model for tying the input and output embeddings. Returns None if no MLM head is used for projection.
- Returns:
Output embeddings of the model.
- Return type:
torch.nn.Module | None
- set_output_embeddings(new_embeddings: Module) None[source]
Sets the model's output embeddings, defaulting to setting new_embeddings to lm_head.
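A short sketch of both accessors together, assuming an MLM head is present so get_output_embeddings() returns a module rather than None:
>>> head = model.get_output_embeddings()
>>> if head is not None:
...     model.set_output_embeddings(head)  # re-register, e.g. after resizing the vocabulary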