LightningIRModel

class lightning_ir.base.model.LightningIRModel(config: LightningIRConfig, *args, **kwargs)

Bases: LightningIRAdapterMixin, PreTrainedModel

Base class for Lightning IR models. Derived classes implement the forward method for handling query and document embeddings. It acts as a mixin for a transformers.PreTrainedModel backbone model.
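The mixin relationship can be illustrated with plain Python multiple inheritance. This is a minimal sketch, and every name in it is a hypothetical stand-in: `ToyBertModel` plays the role of a transformers backbone such as BertModel, `CrossEncoderMixin` plays the role of a derived Lightning IR class, and `compose_model_class` is not LightningIRModelClassFactory, whose internals may differ. The generated class name only mimics the `CrossEncoderBertModel` pattern shown in the from_pretrained examples.

```python
class ToyBertModel:
    """Hypothetical stand-in for a transformers backbone class such as BertModel."""

    def __init__(self, config: dict):
        self.config = config

    def encode(self, text: str) -> list[int]:
        # Toy "embedding": the character codes of the input text.
        return [ord(c) for c in text]


class CrossEncoderMixin:
    """Hypothetical stand-in for a derived Lightning IR class that implements forward."""

    def forward(self, query: str, doc: str) -> int:
        # Relies on encode(), which comes from whatever backbone the mixin is combined with.
        return sum(self.encode(query + " " + doc))


def compose_model_class(mixin: type, backbone: type) -> type:
    """Hypothetical helper: build a class that inherits from both mixin and backbone."""
    # The mixin is listed first, so its methods take precedence in the MRO.
    name = mixin.__name__.removesuffix("Mixin") + backbone.__name__
    return type(name, (mixin, backbone), {})


Model = compose_model_class(CrossEncoderMixin, ToyBertModel)
model = Model(config={"hidden_size": 4})  # __init__ is inherited from the backbone
print(Model.__name__)           # CrossEncoderToyBertModel
print(model.forward("a", "b"))  # 227 (= ord("a") + ord(" ") + ord("b"))
```

Putting the mixin first in the bases lets its methods override the backbone's, while the backbone's own methods (here `encode`) stay available to the mixin.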

config_class

Configuration class for the model.

Type:

type[LightningIRConfig]

ALLOW_SUB_BATCHING

Flag to allow mini batches of documents for a single query. Set to False for listwise models to ensure correctness, since a listwise model must see all documents of a query together.

Type:

bool

__init__(config: LightningIRConfig, *args, **kwargs) -> None

Initializes the model.

Parameters:

config (LightningIRConfig) – Configuration for the model

Methods

__init__(config, *args, **kwargs)

Initializes the model.

forward(*args, **kwargs)

Forward method of the model.

from_pretrained(model_name_or_path, *args[, ...])

Loads a pretrained model. Wraps the transformers.PreTrainedModel.from_pretrained method to return a derived LightningIRModel.

Attributes

ALLOW_SUB_BATCHING

Flag to allow mini batches of documents for a single query.

training

ALLOW_SUB_BATCHING = True

Flag to allow mini batches of documents for a single query. Set to False for listwise models to ensure correctness, since a listwise model must see all documents of a query together.
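Why the flag matters for listwise models can be shown with a toy sketch. Everything here is hypothetical: `score_documents` is not a Lightning IR function, and the two toy models only mimic the difference between scoring each document independently and scoring documents relative to the whole candidate list. If the listwise model were fed sub-batches, its scores would be normalized over partial lists and change; honoring ALLOW_SUB_BATCHING = False avoids that.

```python
from collections.abc import Sequence


class PointwiseToyModel:
    # Each document is scored on its own, so splitting the list is safe.
    ALLOW_SUB_BATCHING = True

    def score(self, query: str, docs: Sequence[str]) -> list[float]:
        query_terms = set(query.split())
        return [float(len(query_terms & set(doc.split()))) for doc in docs]


class ListwiseToyModel:
    # Scores are normalized over the whole list, so all documents must be seen together.
    ALLOW_SUB_BATCHING = False

    def score(self, query: str, docs: Sequence[str]) -> list[float]:
        lengths = [len(doc) for doc in docs]
        total = sum(lengths)
        return [n / total for n in lengths]


def score_documents(model, query: str, docs: Sequence[str], max_docs: int) -> list[float]:
    """Hypothetical helper: score docs in chunks of max_docs only if the model allows it."""
    if not model.ALLOW_SUB_BATCHING:
        return model.score(query, docs)
    scores: list[float] = []
    for start in range(0, len(docs), max_docs):
        scores.extend(model.score(query, docs[start:start + max_docs]))
    return scores


docs = ["aa", "bb", "cccc"]
# Sub-batching is skipped, so scores are normalized over all three documents.
print(score_documents(ListwiseToyModel(), "q", docs, max_docs=2))  # [0.25, 0.25, 0.5]
# Chunked scoring gives the same result as scoring the full list for a pointwise model.
print(score_documents(PointwiseToyModel(), "a b", ["a x", "b a", "y"], max_docs=2))  # [1.0, 2.0, 0.0]
```

Had the listwise model been split into chunks of two, it would have returned [0.5, 0.5, 1.0] instead.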

config_class

Configuration class for the model.

alias of LightningIRConfig

forward(*args, **kwargs) -> LightningIROutput

Forward method of the model. Must be implemented by the derived class.
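The contract can be sketched in plain Python: the base class leaves forward unimplemented and a derived class returns an output object. This is a hedged illustration only; `ToyOutput` stands in for LightningIROutput and its `scores` field is made up for this sketch, and `TermOverlapModel` is a hypothetical scoring model, not part of Lightning IR.

```python
from dataclasses import dataclass


@dataclass
class ToyOutput:
    """Hypothetical stand-in for LightningIROutput; its field is made up for this sketch."""

    scores: list[float]


class ToyBaseModel:
    """Hypothetical stand-in for LightningIRModel: forward is left to subclasses."""

    def forward(self, *args, **kwargs) -> ToyOutput:
        raise NotImplementedError("forward must be implemented by the derived class")


class TermOverlapModel(ToyBaseModel):
    """Hypothetical derived model: scores each document by query-term overlap."""

    def forward(self, query: str, docs: list[str]) -> ToyOutput:
        query_terms = set(query.lower().split())
        scores = [float(len(query_terms & set(doc.lower().split()))) for doc in docs]
        return ToyOutput(scores=scores)


output = TermOverlapModel().forward(
    "neural ranking", ["Neural IR models", "ranking with neural nets", "cooking"]
)
print(output.scores)  # [1.0, 2.0, 0.0]

try:
    ToyBaseModel().forward()
except NotImplementedError as err:
    print(err)  # forward must be implemented by the derived class
```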

classmethod from_pretrained(model_name_or_path: str | Path, *args, BackboneModel: type[PreTrainedModel] | None = None, **kwargs) -> Self

Loads a pretrained model. Wraps the transformers.PreTrainedModel.from_pretrained method to return a derived LightningIRModel. See LightningIRModelClassFactory for more details.

>>> from lightning_ir import CrossEncoderConfig, CrossEncoderModel, LightningIRModel
>>> # Loading using model class and backbone checkpoint
>>> type(CrossEncoderModel.from_pretrained("bert-base-uncased"))
<class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
>>> # Loading using base class and backbone checkpoint
>>> type(LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig()))
<class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>

Args:

model_name_or_path (str | Path): Name or path of the pretrained model.

BackboneModel (type[PreTrainedModel] | None): Hugging Face PreTrainedModel class to use as the backbone instead of the default AutoModel. Defaults to None.

Raises:

ValueError: If called on the abstract class LightningIRModel and no config is passed.

Returns:

LightningIRModel: A derived LightningIRModel consisting of a backbone model and a LightningIRModel mixin.
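The dispatch described above (a derived class loads directly, the abstract base class needs a config to pick a derived class, and ValueError is raised otherwise) can be sketched without any model weights. This is a hedged illustration, not the library's code: `ToyModel`, the toy configs, and the `CONFIG_TO_MODEL` registry are hypothetical, and the backbone composition handled by LightningIRModelClassFactory is left out entirely. "bert-base-uncased" is only passed through as a name; nothing is downloaded.

```python
from __future__ import annotations


class ToyConfig:
    """Hypothetical stand-in for LightningIRConfig."""


class ToyCrossEncoderConfig(ToyConfig):
    """Hypothetical stand-in for CrossEncoderConfig."""


class ToyModel:
    """Hypothetical stand-in for LightningIRModel (no backbone, no weights)."""

    config_class = ToyConfig

    def __init__(self, name: str, config: ToyConfig):
        self.name = name
        self.config = config

    @classmethod
    def from_pretrained(cls, name: str, config: ToyConfig | None = None) -> ToyModel:
        if cls is ToyModel:
            # Abstract base class: only the config can tell which derived class to build.
            if config is None:
                raise ValueError("from_pretrained on the base class requires a config")
            cls = CONFIG_TO_MODEL[type(config)]
        elif config is None:
            # Derived class: fall back to its own default config.
            config = cls.config_class()
        return cls(name, config)


class ToyCrossEncoderModel(ToyModel):
    config_class = ToyCrossEncoderConfig


# Hypothetical registry mapping config classes to model classes.
CONFIG_TO_MODEL = {ToyCrossEncoderConfig: ToyCrossEncoderModel}

print(type(ToyCrossEncoderModel.from_pretrained("bert-base-uncased")).__name__)  # ToyCrossEncoderModel
print(type(ToyModel.from_pretrained("bert-base-uncased", config=ToyCrossEncoderConfig())).__name__)  # ToyCrossEncoderModel
try:
    ToyModel.from_pretrained("bert-base-uncased")
except ValueError as err:
    print(err)  # from_pretrained on the base class requires a config
```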