LightningIRModule
- class lightning_ir.base.module.LightningIRModule(model_name_or_path: str | None = None, config: LightningIRConfig | None = None, model: LightningIRModel | None = None, loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None, evaluation_metrics: Sequence[str] | None = None, model_kwargs: Mapping[str, Any] | None = None)[source]
Bases:
LightningModule
LightningIRModule base class. It derives from a LightningModule. A LightningIRModule contains a LightningIRModel and a LightningIRTokenizer and implements the training, validation, and testing steps for the model. Derived classes must implement the forward method for the model.
- __init__(model_name_or_path: str | None = None, config: LightningIRConfig | None = None, model: LightningIRModel | None = None, loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None, evaluation_metrics: Sequence[str] | None = None, model_kwargs: Mapping[str, Any] | None = None)[source]
Initializes the LightningIRModule.
- Parameters:
model_name_or_path (str | None, optional) – Name or path of backbone model or fine-tuned Lightning IR model, defaults to None
config (LightningIRConfig | None, optional) – LightningIRConfig to apply when loading from backbone model, defaults to None
model (LightningIRModel | None, optional) – Already instantiated Lightning IR model, defaults to None
loss_functions (Sequence[LossFunction | Tuple[LossFunction, float]] | None, optional) – Loss functions to apply during fine-tuning, optional loss weights can be provided per loss function, defaults to None
evaluation_metrics (Sequence[str] | None, optional) – Metrics corresponding to ir-measures measure strings to apply during validation or testing, defaults to None
model_kwargs (Mapping[str, Any] | None, optional) – Additional keyword arguments to pass to from_pretrained when loading a model, defaults to None
- Raises:
ValueError – If both model and model_name_or_path are provided
ValueError – If neither model nor model_name_or_path are provided
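The two ValueError conditions above amount to an exactly-one-of rule for model and model_name_or_path. A minimal sketch of that rule follows; check_model_source is a hypothetical helper written for illustration and is not part of Lightning IR.

```python
from typing import Optional


def check_model_source(model_name_or_path: Optional[str], model: Optional[object]) -> str:
    """Hypothetical helper mirroring the documented ValueError conditions."""
    if model is not None and model_name_or_path is not None:
        raise ValueError("Provide either model or model_name_or_path, not both.")
    if model is None and model_name_or_path is None:
        raise ValueError("Provide one of model or model_name_or_path.")
    # Report which source would be used to obtain the model.
    return "instance" if model is not None else "pretrained"
```

Passing only a name or path, or only an instantiated model, is accepted; any other combination raises.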
Methods
__init__
([model_name_or_path, config, ...])Initializes the LightningIRModule.
configure_optimizers
()Configures the optimizer for fine-tuning.
forward
(batch)Handles the forward pass of the model.
get_dataset
(dataloader_idx)Gets the dataset instance from the dataloader index.
get_dataset_id
(dataset)Gets the dataset id of a dataset for logging.
on_save_checkpoint
(checkpoint)Saves the model and tokenizer to the trainer's log directory.
on_test_end
()Prints the accumulated metrics for each dataloader.
on_test_start
()Called at the beginning of testing.
on_train_start
()Called at the beginning of training after sanity check.
on_validation_end
()Prints the validation results for each dataloader.
on_validation_start
()Called at the beginning of validation.
prepare_input
(queries, docs, num_docs)Tokenizes queries and documents and returns the tokenized BatchEncoding.
save_pretrained
(save_path)Saves the model and tokenizer to the save path.
score
(queries, docs)Computes relevance scores for queries and documents.
set_optimizer
(optimizer, **optimizer_kwargs)Sets the optimizer for the model.
test_step
(batch, batch_idx[, dataloader_idx])Handles the testing step for the model.
training_step
(batch, batch_idx)Handles the training step for the model.
validate
(output, batch)Validates the model output with the evaluation metrics and loss functions.
validate_loss
(output, batch)Validates the model output with the loss functions.
validate_metrics
(output, batch)Validates the model output with the evaluation metrics.
validation_step
(batch, batch_idx[, ...])Handles the validation step for the model.
Attributes
training
- configure_optimizers() Optimizer [source]
Configures the optimizer for fine-tuning. This method is ignored when using the CLI. When using Lightning IR programmatically, the optimizer must be set using set_optimizer().
- Raises:
ValueError – If optimizer is not set
- Returns:
Optimizer
- Return type:
torch.optim.Optimizer
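The interplay between set_optimizer() and configure_optimizers() can be sketched as a store-then-build pattern. The classes below are stand-ins written for illustration: DummyOptimizer takes the place of a torch.optim.Optimizer subclass and ModuleSketch is hypothetical, so this is not the library's implementation.

```python
from typing import Any, Dict, Optional, Type


class DummyOptimizer:
    """Stand-in for a torch.optim.Optimizer subclass (assumption for illustration)."""

    def __init__(self, params: list, **kwargs: Any) -> None:
        self.params = params
        self.kwargs = kwargs


class ModuleSketch:
    """Hypothetical module that stores an optimizer class until it is needed."""

    def __init__(self) -> None:
        self._optimizer: Optional[Type[DummyOptimizer]] = None
        self._optimizer_kwargs: Dict[str, Any] = {}

    def parameters(self) -> list:
        # Placeholder for the model parameters a real module would expose.
        return [0.0, 0.0]

    def set_optimizer(self, optimizer: Type[DummyOptimizer], **optimizer_kwargs: Any) -> "ModuleSketch":
        self._optimizer = optimizer
        self._optimizer_kwargs = optimizer_kwargs
        return self  # returning self allows call chaining

    def configure_optimizers(self) -> DummyOptimizer:
        if self._optimizer is None:
            raise ValueError("Optimizer is not set. Call set_optimizer() first.")
        # Instantiate the stored class with the stored keyword arguments.
        return self._optimizer(self.parameters(), **self._optimizer_kwargs)
```

With a real module, the documented signature suggests a call of the form module.set_optimizer(torch.optim.AdamW, lr=1e-5) before handing the module to a trainer; the learning rate here is only an example value.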
- forward(batch: TrainBatch | RankBatch | SearchBatch) LightningIROutput [source]
Handles the forward pass of the model.
- Parameters:
batch (TrainBatch | RankBatch | SearchBatch) – Batch of training, ranking, or search data
- Raises:
NotImplementedError – Must be implemented by derived class
- Returns:
Model output
- Return type:
LightningIROutput
- get_dataset(dataloader_idx: int) IRDataset | None [source]
Gets the dataset instance from the dataloader index. Returns None if no dataset is found.
- Parameters:
dataloader_idx (int) – Index of the dataloader
- Returns:
Inference dataset
- Return type:
IRDataset | None
- get_dataset_id(dataset: IRDataset) str [source]
Gets the dataset id of a dataset for logging.
- Parameters:
dataset (IRDataset) – Dataset instance
- Returns:
path to run file, ir-datasets dataset id, or dataloader index
- Return type:
str
- on_save_checkpoint(checkpoint: Dict[str, Any]) None [source]
Saves the model and tokenizer to the trainer’s log directory.
- on_test_end() None [source]
Prints the accumulated metrics for each dataloader.
- on_test_start() None [source]
Called at the beginning of testing.
- on_train_start() None [source]
Called at the beginning of training after sanity check.
- on_validation_end() None [source]
Prints the validation results for each dataloader.
- on_validation_start() None [source]
Called at the beginning of validation.
- prepare_input(queries: Sequence[str] | None, docs: Sequence[str] | None, num_docs: Sequence[int] | int | None) Dict[str, BatchEncoding] [source]
Tokenizes queries and documents and returns the tokenized BatchEncoding.
- Parameters:
queries (Sequence[str] | None) – Queries to tokenize
docs (Sequence[str] | None) – Documents to tokenize
num_docs (Sequence[int] | int | None) – Number of documents per query; if None, num_docs is inferred as len(docs) // len(queries)
- Returns:
Tokenized queries and documents, format depends on the tokenizer
- Return type:
Dict[str, BatchEncoding]
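The num_docs rule can be illustrated without a tokenizer. The sketch below uses a hypothetical helper, infer_num_docs, that expands num_docs into one count per query; the divisibility check is an added assumption and is not stated in the documentation.

```python
from typing import List, Optional, Sequence, Union


def infer_num_docs(
    queries: Sequence[str],
    docs: Sequence[str],
    num_docs: Optional[Union[Sequence[int], int]] = None,
) -> List[int]:
    """Hypothetical helper returning the number of documents for each query."""
    if isinstance(num_docs, int):
        # A single integer applies to every query.
        return [num_docs] * len(queries)
    if num_docs is not None:
        # An explicit per-query sequence is used as given.
        return list(num_docs)
    if len(docs) % len(queries) != 0:
        # Assumption: an uneven split is treated as an error in this sketch.
        raise ValueError("docs cannot be split evenly across queries.")
    # Documented fallback: len(docs) // len(queries) documents per query.
    return [len(docs) // len(queries)] * len(queries)
```

For two queries and six documents with num_docs=None, every query is assigned three documents.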
- save_pretrained(save_path: str | Path) None [source]
Saves the model and tokenizer to the save path.
- Parameters:
save_path (str | Path) – Path to save the model and tokenizer
- score(queries: Sequence[str] | str, docs: Sequence[Sequence[str]] | Sequence[str]) LightningIROutput [source]
Computes relevance scores for queries and documents.
- Parameters:
queries (Sequence[str] | str) – Queries to score
docs (Sequence[Sequence[str]] | Sequence[str]) – Documents to score
- Returns:
Model output
- Return type:
LightningIROutput
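Because score accepts either a single query string with a flat list of documents or several queries with one document list each, the inputs have to be brought into a common shape before tokenization. The following sketch shows one plausible normalization; normalize_score_inputs is a hypothetical helper and the exact behavior of the library is an assumption.

```python
from typing import List, Sequence, Tuple, Union


def normalize_score_inputs(
    queries: Union[Sequence[str], str],
    docs: Union[Sequence[Sequence[str]], Sequence[str]],
) -> Tuple[List[str], List[str], List[int]]:
    """Hypothetical helper: returns queries, flattened docs, and docs per query."""
    if isinstance(queries, str):
        queries = [queries]  # wrap a single query
    if len(docs) > 0 and isinstance(docs[0], str):
        docs = [docs]  # a flat list is assumed to belong to a single query
    if len(queries) != len(docs):
        raise ValueError("Expected one list of documents per query.")
    num_docs = [len(group) for group in docs]
    flat_docs = [doc for group in docs for doc in group]
    return list(queries), flat_docs, num_docs
```

The resulting triple matches the argument order of prepare_input (queries, docs, num_docs).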
- set_optimizer(optimizer: Type[Optimizer], **optimizer_kwargs: Dict[str, Any]) LightningIRModule [source]
Sets the optimizer for the model. Necessary for fine-tuning when not using the CLI.
- Parameters:
optimizer (Type[torch.optim.Optimizer]) – Torch optimizer class
optimizer_kwargs (Dict[str, Any]) – Arguments to initialize the optimizer
- Returns:
self
- Return type:
LightningIRModule
- test_step(batch: TrainBatch | RankBatch, batch_idx: int, dataloader_idx: int = 0) LightningIROutput [source]
Handles the testing step for the model. Passes the batch to the validation step.
- Parameters:
batch (TrainBatch | RankBatch) – Batch of testing data
batch_idx (int) – Index of the batch
dataloader_idx (int, optional) – Index of the dataloader, defaults to 0
- Returns:
Model output
- Return type:
LightningIROutput
- training_step(batch: TrainBatch, batch_idx: int) Tensor [source]
Handles the training step for the model.
- Parameters:
batch (TrainBatch) – Batch of training data
batch_idx (int) – Index of the batch
- Raises:
ValueError – If no loss functions are set
- Returns:
Sum of the losses weighted by the loss weights
- Return type:
torch.Tensor
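The returned value, a sum of losses weighted by their loss weights, can be illustrated with plain floats instead of tensors. In the sketch below, mse and mae are toy loss functions, weighted_loss is a hypothetical helper, and the default weight of 1.0 for a loss given without an explicit weight is an assumption.

```python
from typing import Callable, List, Optional, Sequence, Tuple, Union

LossFn = Callable[[Sequence[float], Sequence[float]], float]
LossSpec = Union[LossFn, Tuple[LossFn, float]]


def mse(scores: Sequence[float], targets: Sequence[float]) -> float:
    """Toy mean squared error."""
    return sum((s - t) ** 2 for s, t in zip(scores, targets)) / len(scores)


def mae(scores: Sequence[float], targets: Sequence[float]) -> float:
    """Toy mean absolute error."""
    return sum(abs(s - t) for s, t in zip(scores, targets)) / len(scores)


def normalize_losses(loss_functions: Sequence[LossSpec]) -> List[Tuple[LossFn, float]]:
    """Turn a mix of bare losses and (loss, weight) tuples into pairs."""
    pairs = []
    for item in loss_functions:
        if isinstance(item, tuple):
            fn, weight = item
        else:
            fn, weight = item, 1.0  # assumed default weight
        pairs.append((fn, float(weight)))
    return pairs


def weighted_loss(
    loss_functions: Optional[Sequence[LossSpec]],
    scores: Sequence[float],
    targets: Sequence[float],
) -> float:
    """Hypothetical helper: weighted sum over all configured losses."""
    if not loss_functions:
        raise ValueError("No loss functions are set.")
    return sum(w * fn(scores, targets) for fn, w in normalize_losses(loss_functions))
```

For example, weighted_loss([mse, (mae, 0.5)], scores, targets) adds the full squared error to half of the absolute error.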
- validate(output: LightningIROutput, batch: TrainBatch | RankBatch | SearchBatch) Dict[str, float] [source]
Validates the model output with the evaluation metrics and loss functions.
- Parameters:
output (LightningIROutput) – Model output
batch (TrainBatch | RankBatch | SearchBatch) – Batch of validation or testing data
- Returns:
Dictionary of evaluation metrics
- Return type:
Dict[str, float]
- validate_loss(output: LightningIROutput, batch: TrainBatch | RankBatch | SearchBatch) Dict[str, float] [source]
Validates the model output with the loss functions.
- Parameters:
output (LightningIROutput) – Model output
batch (TrainBatch | RankBatch | SearchBatch) – Batch of validation or testing data
- Returns:
Evaluation metrics
- Return type:
Dict[str, float]
- validate_metrics(output: LightningIROutput, batch: TrainBatch | RankBatch | SearchBatch) Dict[str, float] [source]
Validates the model output with the evaluation metrics.
- Parameters:
output (LightningIROutput) – Model output
batch (TrainBatch | RankBatch | SearchBatch) – Batch of validation or testing data
- Returns:
Evaluation metrics
- Return type:
Dict[str, float]
- validation_step(batch: TrainBatch | RankBatch | SearchBatch, batch_idx: int, dataloader_idx: int = 0) LightningIROutput [source]
Handles the validation step for the model.
- Parameters:
batch (TrainBatch | RankBatch | SearchBatch) – Batch of validation or testing data
batch_idx (int) – Index of the batch
dataloader_idx (int, optional) – Index of the dataloader, defaults to 0
- Returns:
Model output
- Return type:
LightningIROutput