LightningIRModel
- class lightning_ir.base.model.LightningIRModel(config: LightningIRConfig, *args, **kwargs)[source]
Bases: LightningIRAdapterMixin, PreTrainedModel

Base class for Lightning IR models. Derived classes implement the forward method for handling query and document embeddings. It acts as a mixin for a transformers.PreTrainedModel backbone model.
- config_class
Configuration class for the model.
- Type:
Type[LightningIRConfig]
- ALLOW_SUB_BATCHING
Flag to allow mini batches of documents for a single query. Set to False for listwise models to ensure correctness.
- Type:
bool
- __init__(config: LightningIRConfig, *args, **kwargs) None[source]
Initializes the model.
- Parameters:
config (LightningIRConfig) – Configuration class for the model
Methods
__init__(config, *args, **kwargs) – Initializes the model.
forward(*args, **kwargs) – Forward method of the model.
from_pretrained(model_name_or_path, *args[, ...]) – Loads a pretrained model.
pooling(embeddings, attention_mask, ...) – Helper method to apply pooling to the embeddings.
sparsification(embeddings[, ...]) – Helper method to apply sparsification to the embeddings.
Attributes
ALLOW_SUB_BATCHING – Flag to allow mini batches of documents for a single query.
training
- ALLOW_SUB_BATCHING = True
Flag to allow mini batches of documents for a single query. Set to False for listwise models to ensure correctness.
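A minimal sketch of how a listwise model would opt out of sub-batching by overriding the class attribute. The class names here are illustrative stand-ins, not actual lightning_ir classes:

```python
class BaseModel:
    # Stands in for LightningIRModel's default: sub-batching allowed.
    ALLOW_SUB_BATCHING = True


class MyListwiseModel(BaseModel):
    # A listwise loss scores all documents of a query jointly, so
    # splitting them into mini batches would break correctness.
    ALLOW_SUB_BATCHING = False
```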
- config_class
Configuration class for the model.
alias of LightningIRConfig
- forward(*args, **kwargs) LightningIROutput[source]
Forward method of the model. Must be implemented by the derived class.
- classmethod from_pretrained(model_name_or_path: str | Path, *args, BackboneModel: Type[PreTrainedModel] | None = None, **kwargs) Self[source]
Loads a pretrained model. Wraps the transformers.PreTrainedModel.from_pretrained method to return a derived LightningIRModel. See LightningIRModelClassFactory for more details.
>>> # Loading using model class and backbone checkpoint
>>> type(CrossEncoderModel.from_pretrained("bert-base-uncased"))
<class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
>>> # Loading using base class and backbone checkpoint
>>> type(LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig()))
<class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
- Parameters:
model_name_or_path (str | Path) – Name or path of the pretrained model.
BackboneModel (Type[PreTrainedModel] | None) – Huggingface PreTrainedModel class to use as the backbone instead of the default AutoModel. Defaults to None.
- Raises:
ValueError – If called on the abstract class LightningIRModel and no config is passed.
- Returns:
LightningIRModel: A derived LightningIRModel consisting of a backbone model and a LightningIRModel mixin.
- pooling(embeddings: Tensor, attention_mask: Tensor | None, pooling_strategy: 'first' | 'mean' | 'max' | 'sum' | None) Tensor[source]
Helper method to apply pooling to the embeddings.
- Parameters:
embeddings (torch.Tensor) – Query or document embeddings
attention_mask (torch.Tensor | None) – Query or document attention mask
pooling_strategy (Literal['first', 'mean', 'max', 'sum'] | None) – The pooling strategy. No pooling is applied if None.
- Returns:
(Optionally) pooled embeddings.
- Return type:
torch.Tensor
- Raises:
ValueError – If an unknown pooling strategy is passed.
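The pooling strategies can be sketched as follows. This is an illustrative re-implementation, not the library's exact code, assuming embeddings of shape (batch, seq_len, dim) and an optional (batch, seq_len) attention mask:

```python
import torch


def pooling(embeddings, attention_mask, pooling_strategy):
    """Sketch of the pooling helper; shapes and keepdim behavior are assumptions."""
    if pooling_strategy is None:
        return embeddings
    if pooling_strategy == "first":
        # Keep only the first token's embedding (e.g., the [CLS] token).
        return embeddings[:, :1]
    if pooling_strategy in ("mean", "sum"):
        if attention_mask is not None:
            # Zero out padding positions before aggregating.
            embeddings = embeddings * attention_mask.unsqueeze(-1)
        pooled = embeddings.sum(dim=1, keepdim=True)
        if pooling_strategy == "mean":
            if attention_mask is not None:
                pooled = pooled / attention_mask.sum(dim=1, keepdim=True).unsqueeze(-1)
            else:
                pooled = pooled / embeddings.shape[1]
        return pooled
    if pooling_strategy == "max":
        if attention_mask is not None:
            # Exclude padding positions from the maximum.
            embeddings = embeddings.masked_fill(
                ~attention_mask.bool().unsqueeze(-1), float("-inf")
            )
        return embeddings.max(dim=1, keepdim=True).values
    raise ValueError(f"Unknown pooling strategy: {pooling_strategy}")
```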
- sparsification(embeddings: Tensor, sparsification_strategy: 'relu' | 'relu_log' | 'relu_2xlog' | None = None) Tensor[source]
Helper method to apply sparsification to the embeddings.
- Parameters:
embeddings (torch.Tensor) – Query or document embeddings
sparsification_strategy (Literal['relu', 'relu_log', 'relu_2xlog'] | None) – The sparsification strategy. No sparsification is applied if None. Defaults to None.
- Returns:
(Optionally) sparsified embeddings.
- Return type:
torch.Tensor
- Raises:
ValueError – If an unknown sparsification strategy is passed.
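The sparsification strategies can be sketched like so. This is an assumption-laden illustration, not the library's code: relu_log is read here as log1p applied after ReLU, and relu_2xlog as log1p applied twice; the library's exact formulas may differ:

```python
import torch


def sparsification(embeddings, sparsification_strategy=None):
    """Sketch of the sparsification helper; exact formulas are assumptions."""
    if sparsification_strategy is None:
        return embeddings
    if sparsification_strategy == "relu":
        # Clamp negative values to zero, producing sparse non-negative vectors.
        return torch.relu(embeddings)
    if sparsification_strategy == "relu_log":
        # Dampen large activations: log(1 + relu(x)).
        return torch.log1p(torch.relu(embeddings))
    if sparsification_strategy == "relu_2xlog":
        # Assumed reading of "2xlog": apply log1p twice for stronger damping.
        return torch.log1p(torch.log1p(torch.relu(embeddings)))
    raise ValueError(f"Unknown sparsification strategy: {sparsification_strategy}")
```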