BiEncoderModule
- class lightning_ir.bi_encoder.bi_encoder_module.BiEncoderModule(model_name_or_path: str | None = None, config: BiEncoderConfig | None = None, model: BiEncoderModel | None = None, loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None, evaluation_metrics: Sequence[str] | None = None, index_dir: Path | None = None, search_config: SearchConfig | None = None, model_kwargs: Mapping[str, Any] | None = None)
Bases: LightningIRModule
- __init__(model_name_or_path: str | None = None, config: BiEncoderConfig | None = None, model: BiEncoderModel | None = None, loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None, evaluation_metrics: Sequence[str] | None = None, index_dir: Path | None = None, search_config: SearchConfig | None = None, model_kwargs: Mapping[str, Any] | None = None)
LightningIRModule for bi-encoder models. It contains a BiEncoderModel and a BiEncoderTokenizer and implements the training, validation, and testing steps for the model.
- Parameters:
model_name_or_path (str | None) – Name or path of backbone model or fine-tuned Lightning IR model. Defaults to None.
config (BiEncoderConfig | None) – BiEncoderConfig to apply when loading from backbone model. Defaults to None.
model (BiEncoderModel | None) – Already instantiated BiEncoderModel. Defaults to None.
loss_functions (Sequence[LossFunction | Tuple[LossFunction, float]] | None) – Loss functions to apply during fine-tuning. Optional loss weights can be provided per loss function. Defaults to None.
evaluation_metrics (Sequence[str] | None) – Metrics corresponding to ir-measures measure strings to apply during validation or testing. Defaults to None.
index_dir (Path | None) – Path to an index used for retrieval. Defaults to None.
search_config (SearchConfig | None) – Configuration to use during retrieval. Defaults to None.
model_kwargs (Mapping[str, Any] | None) – Additional keyword arguments to pass to from_pretrained when loading a model. Defaults to None.
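As a rough illustration of how a mixed loss_functions list could be handled, the sketch below converts bare loss functions and (loss function, weight) pairs into uniform pairs. It is a hypothetical helper, not Lightning IR code, and the default weight of 1.0 for a bare loss function is an assumption.

```python
from typing import Any, Callable, List, Sequence, Tuple, Union

# Hypothetical stand-in for a Lightning IR LossFunction: any callable.
LossFn = Callable[..., Any]


def normalize_loss_functions(
    loss_functions: Sequence[Union[LossFn, Tuple[LossFn, float]]],
) -> List[Tuple[LossFn, float]]:
    """Turn a mixed list of bare losses and (loss, weight) pairs into pairs.

    Assumption: a bare loss function receives a weight of 1.0.
    """
    normalized: List[Tuple[LossFn, float]] = []
    for entry in loss_functions:
        if isinstance(entry, tuple):
            # Explicit (loss, weight) pair: keep the given weight.
            loss_fn, weight = entry
            normalized.append((loss_fn, float(weight)))
        else:
            # Bare loss function: fall back to the assumed default weight.
            normalized.append((entry, 1.0))
    return normalized
```

For example, `normalize_loss_functions([loss_a, (loss_b, 0.5)])` yields `[(loss_a, 1.0), (loss_b, 0.5)]`, where `loss_a` and `loss_b` are placeholders.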
Methods
__init__([model_name_or_path, config, ...]) – LightningIRModule for bi-encoder models.
forward(batch) – Runs a forward pass of the model on a batch of data.
on_test_start() – Called at the beginning of testing.
score(queries, docs) – Computes relevance scores for queries and documents.
validation_step(batch, batch_idx[, ...]) – Handles the validation step for the model.
Attributes
searcher – Searcher used for retrieval if index_dir and search_config are set.
training
- configure_optimizers() Optimizer
Configures the optimizer for fine-tuning. This method is ignored when using the CLI. When using Lightning IR programmatically, the optimizer must be set using set_optimizer().
- Returns:
The optimizer set for the model.
- Return type:
torch.optim.Optimizer
- Raises:
ValueError – If optimizer is not set. Call set_optimizer.
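The configure_optimizers / set_optimizer split amounts to storing an optimizer class together with its keyword arguments and instantiating it later. The sketch below mimics that pattern with a hypothetical OptimizerHolder class so that it runs without PyTorch; it is not the Lightning IR implementation.

```python
from typing import Any, Dict, Iterable, Optional, Type


class OptimizerHolder:
    """Hypothetical sketch of the set_optimizer / configure_optimizers pattern."""

    def __init__(self, params: Iterable[Any]) -> None:
        self._params = list(params)
        self._optimizer_cls: Optional[Type[Any]] = None
        self._optimizer_kwargs: Dict[str, Any] = {}

    def set_optimizer(self, optimizer: Type[Any], **optimizer_kwargs: Any) -> "OptimizerHolder":
        # Store the class and its arguments; nothing is instantiated yet.
        self._optimizer_cls = optimizer
        self._optimizer_kwargs = optimizer_kwargs
        return self  # returning self allows call chaining

    def configure_optimizers(self) -> Any:
        if self._optimizer_cls is None:
            raise ValueError("Optimizer is not set. Call set_optimizer.")
        # Like torch.optim optimizers, the class receives the parameters first.
        return self._optimizer_cls(self._params, **self._optimizer_kwargs)
```

With Lightning IR itself, the documented call would be along the lines of `module.set_optimizer(torch.optim.AdamW, lr=1e-5)` before the module is handed to a trainer.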
- forward(batch: RankBatch | IndexBatch | SearchBatch) BiEncoderOutput
Runs a forward pass of the model on a batch of data. The output varies depending on the type of batch. If the batch is a RankBatch, query and document embeddings are computed, and the relevance score is the similarity between the two embeddings. If the batch is an IndexBatch, only document embeddings are computed. If the batch is a SearchBatch, only query embeddings are computed, and the model additionally retrieves documents if searcher is set.
- Parameters:
batch (RankBatch | IndexBatch | SearchBatch) – Input batch containing queries and/or documents.
- Returns:
Output of the model.
- Return type:
BiEncoderOutput
- Raises:
ValueError – If the input batch contains neither queries nor documents.
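To make the batch-type dispatch concrete, here is a toy version in plain Python. The embed function, the nested docs layout (one list of documents per query), and the use of a dot product as the similarity are all assumptions for illustration; the retrieval step for a SearchBatch is omitted.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional, Sequence

Vector = List[float]


@dataclass
class ToyBiEncoderOutput:
    """Hypothetical stand-in for BiEncoderOutput."""
    query_embeddings: Optional[List[Vector]] = None
    doc_embeddings: Optional[List[List[Vector]]] = None
    scores: Optional[List[List[float]]] = None


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def toy_forward(
    embed: Callable[[str], Vector],
    queries: Optional[Sequence[str]] = None,
    docs: Optional[Sequence[Sequence[str]]] = None,
) -> ToyBiEncoderOutput:
    if queries is None and docs is None:
        raise ValueError("Batch contains neither queries nor documents.")
    output = ToyBiEncoderOutput()
    if queries is not None:
        # Search-style or rank-style batch: embed the queries.
        output.query_embeddings = [embed(q) for q in queries]
    if docs is not None:
        # Index-style or rank-style batch: embed the documents.
        output.doc_embeddings = [[embed(d) for d in group] for group in docs]
    if output.query_embeddings is not None and output.doc_embeddings is not None:
        # Rank-style batch: score each query against its own documents.
        output.scores = [
            [dot(q, d) for d in group]
            for q, group in zip(output.query_embeddings, output.doc_embeddings)
        ]
    return output
```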
- get_dataset(dataloader_idx: int) IRDataset | None
Gets the dataset instance from the dataloader index. Returns None if no dataset is found.
- Parameters:
dataloader_idx (int) – Index of the dataloader.
- Returns:
Inference dataset or None if no dataset is found.
- Return type:
IRDataset | None
- get_dataset_id(dataset: IRDataset) str
Gets an identifier for the dataset instance, used for logging.
- Parameters:
dataset (IRDataset) – Dataset instance.
- Returns:
Path to run file, ir-datasets dataset id, or dataloader index.
- Return type:
str
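One plausible reading of the return value is a fallback chain. The sketch below assumes the priority run file path, then ir-datasets ID, then dataloader index; the ToyDataset fields and the extra dataloader_idx argument are hypothetical and do not reflect the actual attributes of IRDataset.

```python
from dataclasses import dataclass
from pathlib import Path
from typing import Optional


@dataclass
class ToyDataset:
    """Hypothetical dataset record; the field names are assumptions."""
    run_path: Optional[Path] = None
    ir_datasets_id: Optional[str] = None


def toy_dataset_id(dataset: Optional[ToyDataset], dataloader_idx: int) -> str:
    # Assumed priority: run file path, then ir-datasets id, then dataloader index.
    if dataset is not None and dataset.run_path is not None:
        return str(dataset.run_path)
    if dataset is not None and dataset.ir_datasets_id is not None:
        return dataset.ir_datasets_id
    return str(dataloader_idx)
```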
- on_save_checkpoint(checkpoint: Dict[str, Any]) None
Saves the model and tokenizer to the trainer’s log directory.
- on_test_start() None
Called at the beginning of testing. Initializes the searcher if index_dir and search_config are set.
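The conditional set-up can be pictured with the hypothetical class below: a searcher is only created when both index_dir and search_config are present; otherwise the searcher property stays None. The make_searcher factory is an assumption that stands in for whatever Lightning IR does internally.

```python
from pathlib import Path
from typing import Any, Callable, Optional


class ToySearchModule:
    """Hypothetical sketch of conditional searcher set-up; not Lightning IR code."""

    def __init__(
        self,
        index_dir: Optional[Path],
        search_config: Optional[Any],
        make_searcher: Callable[[Path, Any], Any],
    ) -> None:
        self.index_dir = index_dir
        self.search_config = search_config
        self._make_searcher = make_searcher
        self._searcher: Optional[Any] = None

    @property
    def searcher(self) -> Optional[Any]:
        return self._searcher

    def on_test_start(self) -> None:
        # Only build a searcher when both an index and a search config are given.
        if self.index_dir is not None and self.search_config is not None:
            self._searcher = self._make_searcher(self.index_dir, self.search_config)
```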
- prepare_input(queries: Sequence[str] | None, docs: Sequence[str] | None, num_docs: Sequence[int] | int | None) Dict[str, BatchEncoding]
Tokenizes queries and documents and returns the tokenized BatchEncoding.
- Parameters:
queries (Sequence[str] | None) – Queries to tokenize.
docs (Sequence[str] | None) – Documents to tokenize.
num_docs (Sequence[int] | int | None) – Number of documents per query. If None, num_docs is inferred as len(docs) // len(queries). Defaults to None.
- Returns:
Tokenized queries and documents, format depends on the tokenizer.
- Return type:
Dict[str, BatchEncoding]
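The documented fallback for num_docs can be written out as follows. Expanding an integer to one count per query and the final consistency check are assumptions added for this sketch, and the helper name is hypothetical.

```python
from typing import List, Optional, Sequence, Union


def infer_num_docs(
    queries: Sequence[str],
    docs: Sequence[str],
    num_docs: Optional[Union[Sequence[int], int]] = None,
) -> List[int]:
    """Return a per-query document count (sketch of the documented rule)."""
    if num_docs is None:
        # Documented fallback: assume documents are evenly split across queries.
        num_docs = len(docs) // len(queries)
    if isinstance(num_docs, int):
        # Assumption: an integer applies to every query.
        num_docs = [num_docs] * len(queries)
    counts = list(num_docs)
    if sum(counts) != len(docs):
        # Extra sanity check added for this sketch; not part of the documented rule.
        raise ValueError("num_docs does not match the number of documents")
    return counts
```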
- save_pretrained(save_path: str | Path) None
Saves the model and tokenizer to the save path.
- Parameters:
save_path (str | Path) – Path to save the model and tokenizer.
- score(queries: Sequence[str] | str, docs: Sequence[Sequence[str]] | Sequence[str]) BiEncoderOutput
Computes relevance scores for queries and documents.
- Parameters:
queries (Sequence[str] | str) – Queries to score.
docs (Sequence[Sequence[str]] | Sequence[str]) – Documents to score.
- Returns:
Output of the model.
- Return type:
BiEncoderOutput
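score accepts either a single query or a list of queries, with documents given flat or nested. A plausible normalization is sketched below under the assumption that a flat document list belongs to a single query; the helper name and its return layout are hypothetical.

```python
from typing import List, Sequence, Tuple, Union


def normalize_score_inputs(
    queries: Union[Sequence[str], str],
    docs: Union[Sequence[Sequence[str]], Sequence[str]],
) -> Tuple[List[str], List[str], List[int]]:
    """Flatten score()-style inputs into queries, flat docs, and per-query counts."""
    if isinstance(queries, str):
        # A single query string becomes a one-element list.
        queries = [queries]
    if docs and isinstance(docs[0], str):
        # Assumption: a flat document list belongs to a single query.
        docs = [docs]
    doc_groups = [list(group) for group in docs]
    if len(doc_groups) != len(queries):
        raise ValueError("Expected one list of documents per query")
    flat_docs = [doc for group in doc_groups for doc in group]
    num_docs = [len(group) for group in doc_groups]
    return list(queries), flat_docs, num_docs
```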
- property searcher: Searcher | None
Searcher used for retrieval if index_dir and search_config are set.
- Returns:
Searcher class.
- Return type:
Searcher | None
- set_optimizer(optimizer: Type[Optimizer], **optimizer_kwargs: Dict[str, Any]) LightningIRModule
Sets the optimizer for the model. Necessary for fine-tuning when not using the CLI.
- Parameters:
optimizer (Type[torch.optim.Optimizer]) – Torch optimizer class.
optimizer_kwargs (Dict[str, Any]) – Arguments to initialize the optimizer.
- Returns:
Self with the optimizer set.
- Return type:
LightningIRModule
- test_step(batch: TrainBatch | RankBatch, batch_idx: int, dataloader_idx: int = 0) LightningIROutput
Handles the testing step for the model. Passes the batch to the validation step.
- Parameters:
batch (TrainBatch | RankBatch) – Batch of testing data.
batch_idx (int) – Index of the batch.
dataloader_idx (int, optional) – Index of the dataloader. Defaults to 0.
- Returns:
Model output.
- Return type:
LightningIROutput
- training_step(batch: TrainBatch, batch_idx: int) Tensor
Handles the training step for the model.
- Parameters:
batch (TrainBatch) – Batch of training data.
batch_idx (int) – Index of the batch.
- Returns:
Sum of the losses weighted by the loss weights.
- Return type:
torch.Tensor
- Raises:
ValueError – If no loss functions are set.
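The weighted sum returned by training_step can be sketched with plain floats in place of tensors. The loss functions here are hypothetical callables taking (output, batch); in Lightning IR the result would be a torch.Tensor rather than a float.

```python
from typing import Any, Callable, Sequence, Tuple

# Hypothetical loss signature: (model output, batch) -> scalar loss.
LossFn = Callable[[Any, Any], float]


def weighted_loss_sum(
    loss_functions: Sequence[Tuple[LossFn, float]], output: Any, batch: Any
) -> float:
    if not loss_functions:
        raise ValueError("Loss functions are not set")
    total = 0.0
    for loss_fn, weight in loss_functions:
        # Each loss contributes proportionally to its weight.
        total += weight * loss_fn(output, batch)
    return total
```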
- validate(output: LightningIROutput, batch: TrainBatch | RankBatch | SearchBatch) Dict[str, float]
Validates the model output with the evaluation metrics and loss functions.
- Parameters:
output (LightningIROutput) – Model output.
batch (TrainBatch | RankBatch | SearchBatch) – Batch of validation or testing data.
- Returns:
Dictionary of evaluation metrics.
- Return type:
Dict[str, float]
- validate_loss(output: LightningIROutput, batch: TrainBatch | RankBatch | SearchBatch) Dict[str, float]
Validates the model output with the loss functions.
- Parameters:
output (LightningIROutput) – Model output.
batch (TrainBatch | RankBatch | SearchBatch) – Batch of validation or testing data.
- Returns:
Dictionary of loss values.
- Return type:
Dict[str, float]
- validate_metrics(output: LightningIROutput, batch: TrainBatch | RankBatch | SearchBatch) Dict[str, float]
Validates the model output with the evaluation metrics.
- Parameters:
output (LightningIROutput) – Model output.
batch (TrainBatch | RankBatch | SearchBatch) – Batch of validation or testing data.
- Returns:
Dictionary of evaluation metrics.
- Return type:
Dict[str, float]
- Raises:
ValueError – If query_ids or doc_ids are not set in the batch.
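Before metrics can be computed, per-query scores have to be tied to their query and document IDs, which is why missing IDs are an error. The sketch below builds a nested {query_id: {doc_id: score}} run, a shape the ir-measures package accepts; the function name and argument layout are assumptions.

```python
from typing import Dict, Optional, Sequence


def build_run(
    query_ids: Optional[Sequence[str]],
    doc_ids: Optional[Sequence[Sequence[str]]],
    scores: Sequence[Sequence[float]],
) -> Dict[str, Dict[str, float]]:
    """Turn per-query scores into a {query_id: {doc_id: score}} run dict."""
    if query_ids is None or doc_ids is None:
        raise ValueError("query_ids and doc_ids must be set in the batch")
    run: Dict[str, Dict[str, float]] = {}
    for query_id, query_doc_ids, query_scores in zip(query_ids, doc_ids, scores):
        # Pair each document id with its score for this query.
        run[query_id] = {
            doc_id: float(score) for doc_id, score in zip(query_doc_ids, query_scores)
        }
    return run
```

Such a run, together with qrels in the analogous {query_id: {doc_id: relevance}} shape, could then be passed to ir_measures.calc_aggregate along with parsed measures such as nDCG@10.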
- validation_step(batch: TrainBatch | IndexBatch | SearchBatch | RankBatch, batch_idx: int, dataloader_idx: int = 0) BiEncoderOutput
Handles the validation step for the model.
- Parameters:
batch (TrainBatch | IndexBatch | SearchBatch | RankBatch) – Batch of validation or testing data.
batch_idx (int) – Index of the batch.
dataloader_idx (int, optional) – Index of the dataloader. Defaults to 0.
- Returns:
Output of the model.
- Return type:
BiEncoderOutput