BiEncoderModule

class lightning_ir.bi_encoder.bi_encoder_module.BiEncoderModule(model_name_or_path: str | None = None, config: BiEncoderConfig | None = None, model: BiEncoderModel | None = None, loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None, evaluation_metrics: Sequence[str] | None = None, index_dir: Path | None = None, search_config: SearchConfig | None = None, model_kwargs: Mapping[str, Any] | None = None)

Bases: LightningIRModule

__init__(model_name_or_path: str | None = None, config: BiEncoderConfig | None = None, model: BiEncoderModel | None = None, loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None, evaluation_metrics: Sequence[str] | None = None, index_dir: Path | None = None, search_config: SearchConfig | None = None, model_kwargs: Mapping[str, Any] | None = None)

LightningIRModule for bi-encoder models. It contains a BiEncoderModel and a BiEncoderTokenizer and implements the training, validation, and testing steps for the model.

Parameters:
  • model_name_or_path (str | None) – Name or path of backbone model or fine-tuned Lightning IR model. Defaults to None.

  • config (BiEncoderConfig | None) – BiEncoderConfig to apply when loading from backbone model. Defaults to None.

  • model (BiEncoderModel | None) – Already instantiated BiEncoderModel. Defaults to None.

  • loss_functions (Sequence[LossFunction | Tuple[LossFunction, float]] | None) – Loss functions to apply during fine-tuning. An optional loss weight can be provided per loss function. Defaults to None.

  • evaluation_metrics (Sequence[str] | None) – Evaluation metrics, given as ir-measures measure strings, to apply during validation or testing. Defaults to None.

  • index_dir (Path | None) – Path to an index used for retrieval. Defaults to None.

  • search_config (SearchConfig | None) – Configuration to use during retrieval. Defaults to None.

  • model_kwargs (Mapping[str, Any] | None) – Additional keyword arguments to pass to from_pretrained when loading a model. Defaults to None.
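Because loss_functions may mix bare loss functions with (loss function, weight) tuples, the entries have to be brought into a uniform shape before use. The following is a minimal sketch of that normalization, not Lightning IR's implementation: it assumes an unweighted entry counts with weight 1.0, and `normalize_loss_functions`, `LossFn`, and the toy `mse` are hypothetical names standing in for LossFunction objects.

```python
from typing import Callable, List, Optional, Sequence, Tuple, Union

# Hypothetical stand-in for LossFunction: any callable mapping
# (scores, targets) to a scalar loss.
LossFn = Callable[[List[float], List[float]], float]
LossSpec = Union[LossFn, Tuple[LossFn, float]]


def normalize_loss_functions(
    loss_functions: Optional[Sequence[LossSpec]],
) -> List[Tuple[LossFn, float]]:
    """Convert bare losses and (loss, weight) tuples into (loss, weight) pairs."""
    if loss_functions is None:
        return []
    pairs: List[Tuple[LossFn, float]] = []
    for entry in loss_functions:
        if isinstance(entry, tuple):
            loss_fn, weight = entry
            pairs.append((loss_fn, float(weight)))
        else:
            # Assumption of this sketch: an unweighted loss counts with weight 1.0.
            pairs.append((entry, 1.0))
    return pairs


def mse(scores: List[float], targets: List[float]) -> float:
    """Toy mean squared error used only for illustration."""
    return sum((s - t) ** 2 for s, t in zip(scores, targets)) / len(scores)


pairs = normalize_loss_functions([mse, (mse, 0.5)])
```

Here `pairs` holds the same loss twice, once with the implicit weight 1.0 and once with the explicit weight 0.5.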

Methods

__init__([model_name_or_path, config, ...])

LightningIRModule for bi-encoder models.

forward(batch)

Runs a forward pass of the model on a batch of data.

on_test_start()

Called at the beginning of testing.

score(queries, docs)

Computes relevance scores for queries and documents.

validation_step(batch, batch_idx[, ...])

Handles the validation step for the model.

Attributes

searcher

Searcher used for retrieval if index_dir and search_config are set.

training

configure_optimizers() → Optimizer

Configures the optimizer for fine-tuning. This method is ignored when using the CLI. When using Lightning IR programmatically, the optimizer must be set using set_optimizer().

Returns:

The optimizer set for the model.

Return type:

torch.optim.Optimizer

Raises:

ValueError – If optimizer is not set. Call set_optimizer.
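The interplay between set_optimizer() and configure_optimizers() follows a deferred-construction pattern: the optimizer class and its keyword arguments are stored first, and the optimizer is only instantiated once the model parameters are available. The sketch below illustrates that pattern in isolation; `OptimizerSettingModule`, `DummyOptimizer`, and the private attribute names are hypothetical stand-ins, not Lightning IR internals.

```python
from typing import Any, Dict, List, Optional, Type


class DummyOptimizer:
    """Hypothetical stand-in for a torch.optim.Optimizer subclass."""

    def __init__(self, params: List[float], lr: float = 1e-3) -> None:
        self.params = list(params)
        self.lr = lr


class OptimizerSettingModule:
    """Hypothetical sketch: store the optimizer class and kwargs, build it later."""

    def __init__(self) -> None:
        self._optimizer_cls: Optional[Type[Any]] = None
        self._optimizer_kwargs: Dict[str, Any] = {}

    def parameters(self) -> List[float]:
        # A real module would return its trainable torch parameters here.
        return [0.1, 0.2]

    def set_optimizer(self, optimizer: Type[Any], **optimizer_kwargs: Any) -> "OptimizerSettingModule":
        self._optimizer_cls = optimizer
        self._optimizer_kwargs = optimizer_kwargs
        return self  # returning self allows call chaining

    def configure_optimizers(self) -> Any:
        if self._optimizer_cls is None:
            raise ValueError("Optimizer is not set. Call `set_optimizer`.")
        # Instantiate only now, when the parameters are available.
        return self._optimizer_cls(self.parameters(), **self._optimizer_kwargs)


module = OptimizerSettingModule().set_optimizer(DummyOptimizer, lr=2e-5)
optimizer = module.configure_optimizers()
```

With the documented signature, the equivalent call on a real BiEncoderModule would pass a torch optimizer class and its arguments, for example module.set_optimizer(torch.optim.AdamW, lr=1e-5).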

forward(batch: RankBatch | IndexBatch | SearchBatch) → BiEncoderOutput

Runs a forward pass of the model on a batch of data. The output varies with the type of batch. If the batch is a RankBatch, query and document embeddings are computed and the relevance score is the similarity between the two embeddings. If the batch is an IndexBatch, only document embeddings are computed. If the batch is a SearchBatch, only query embeddings are computed, and the model additionally retrieves documents if searcher is set.

Parameters:

batch (RankBatch | IndexBatch | SearchBatch) – Input batch containing queries and/or documents.

Returns:

Output of the model.

Return type:

BiEncoderOutput

Raises:

ValueError – If the input batch contains neither queries nor documents.
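The batch-type dispatch described for forward() can be sketched without any model weights. In this illustration the batch classes are simplified stand-ins with assumed fields (RankBatch carries a flattened document list plus a uniform num_docs), `encode` is a toy vowel-count encoder, and similarity is a plain dot product; none of these are Lightning IR's actual batch definitions, encoder, or similarity function. Retrieval for a SearchBatch is omitted.

```python
from dataclasses import dataclass
from typing import List, Optional, Union

Vector = List[float]


def encode(texts: List[str]) -> List[Vector]:
    """Toy encoder (hypothetical): counts of each vowel as a 5-dim embedding."""
    return [[float(t.count(v)) for v in "aeiou"] for t in texts]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


@dataclass
class RankBatch:  # simplified stand-ins with assumed fields
    queries: List[str]
    docs: List[str]  # flattened: num_docs consecutive docs per query
    num_docs: int


@dataclass
class IndexBatch:
    docs: List[str]


@dataclass
class SearchBatch:
    queries: List[str]


@dataclass
class OutputSketch:
    query_embeddings: Optional[List[Vector]] = None
    doc_embeddings: Optional[List[Vector]] = None
    scores: Optional[List[float]] = None


def forward(batch: Union[RankBatch, IndexBatch, SearchBatch]) -> OutputSketch:
    queries = getattr(batch, "queries", None)
    docs = getattr(batch, "docs", None)
    if queries is None and docs is None:
        raise ValueError("Input batch contains neither queries nor documents")
    output = OutputSketch()
    if queries is not None:
        output.query_embeddings = encode(queries)
    if docs is not None:
        output.doc_embeddings = encode(docs)
    if output.query_embeddings is not None and output.doc_embeddings is not None:
        # RankBatch: score each query against its own block of num_docs documents.
        n = batch.num_docs
        output.scores = [
            dot(q, d)
            for i, q in enumerate(output.query_embeddings)
            for d in output.doc_embeddings[i * n : (i + 1) * n]
        ]
    # A SearchBatch would additionally trigger retrieval if a searcher is set (omitted).
    return output
```

An IndexBatch therefore yields only document embeddings, a SearchBatch only query embeddings, and only a RankBatch produces scores.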

get_dataset(dataloader_idx: int) → IRDataset | None

Gets the dataset instance from the dataloader index. Returns None if no dataset is found.

Parameters:

dataloader_idx (int) – Index of the dataloader.

Returns:

Inference dataset or None if no dataset is found.

Return type:

IRDataset | None

get_dataset_id(dataset: IRDataset) → str

Gets an identifier for the dataset, used for logging.

Parameters:

dataset (IRDataset) – Dataset instance.

Returns:

Path to run file, ir-datasets dataset id, or dataloader index.

Return type:

str

on_save_checkpoint(checkpoint: Dict[str, Any]) → None

Saves the model and tokenizer to the trainer’s log directory.

on_test_end() → None

Prints the accumulated metrics for each dataloader.

on_test_start() → None

Called at the beginning of testing. Initializes the searcher if index_dir and search_config are set.

on_train_start() → None

Called at the beginning of training after sanity check.

on_validation_end() → None

Prints the validation results for each dataloader.

on_validation_start() → None

Called at the beginning of validation.

prepare_input(queries: Sequence[str] | None, docs: Sequence[str] | None, num_docs: Sequence[int] | int | None) → Dict[str, BatchEncoding]

Tokenizes queries and documents and returns the tokenized BatchEncoding.

Parameters:
  • queries (Sequence[str] | None) – Queries to tokenize.

  • docs (Sequence[str] | None) – Documents to tokenize.

  • num_docs (Sequence[int] | int | None) – Number of documents per query. If None, num_docs is inferred as len(docs) // len(queries). Defaults to None.

Returns:

Tokenized queries and documents, format depends on the tokenizer.

Return type:

Dict[str, BatchEncoding]
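The num_docs argument of prepare_input() can take three shapes. A minimal sketch of resolving them into one count per query follows; only the len(docs) // len(queries) fallback comes from the description above, while the expansion of a single int and the final consistency check are assumptions of this sketch, and `resolve_num_docs` is a hypothetical helper rather than part of Lightning IR.

```python
from typing import List, Optional, Sequence, Union


def resolve_num_docs(
    queries: Optional[Sequence[str]],
    docs: Optional[Sequence[str]],
    num_docs: Union[Sequence[int], int, None] = None,
) -> Optional[List[int]]:
    """Hypothetical helper: return one document count per query."""
    if queries is None or docs is None:
        return None  # only one side is being tokenized, nothing to pair
    if len(queries) == 0:
        raise ValueError("At least one query is required to group documents")
    if num_docs is None:
        # Documented fallback: assume an equal number of documents per query.
        num_docs = len(docs) // len(queries)
    if isinstance(num_docs, int):
        num_docs = [num_docs] * len(queries)
    counts = list(num_docs)
    # Assumption of this sketch: the counts must cover every document exactly once.
    if len(counts) != len(queries) or sum(counts) != len(docs):
        raise ValueError(f"num_docs {counts} does not match {len(docs)} documents")
    return counts
```

Note that the floor division silently assumes an even split: with two queries and three documents, the inferred count is 1 per query, which this sketch rejects rather than dropping a document. Whether Lightning IR performs a similar check is not stated here.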

save_pretrained(save_path: str | Path) → None

Saves the model and tokenizer to the save path.

Parameters:

save_path (str | Path) – Path to save the model and tokenizer.

score(queries: Sequence[str] | str, docs: Sequence[Sequence[str]] | Sequence[str]) → BiEncoderOutput

Computes relevance scores for queries and documents.

Parameters:
  • queries (Sequence[str] | str) – Queries to score.

  • docs (Sequence[Sequence[str]] | Sequence[str]) – Documents to score.

Returns:

Output of the model.

Return type:

BiEncoderOutput
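The signature of score() accepts either a single query string or a list of queries, and either a flat list of documents or one list of documents per query. The sketch below shows one way such inputs could be normalized into flat lists plus per-query counts. It assumes that a flat document list belongs to a single query; `normalize_score_inputs` is a hypothetical helper, not Lightning IR's implementation.

```python
from typing import List, Sequence, Tuple, Union


def normalize_score_inputs(
    queries: Union[Sequence[str], str],
    docs: Union[Sequence[Sequence[str]], Sequence[str]],
) -> Tuple[List[str], List[str], List[int]]:
    """Return (queries, flattened docs, number of docs per query)."""
    if isinstance(queries, str):
        queries = [queries]
    queries = list(queries)
    if len(docs) > 0 and isinstance(docs[0], str):
        # Assumption of this sketch: a flat list holds the docs of a single query.
        docs = [docs]
    if len(docs) != len(queries):
        raise ValueError("Expected one list of documents per query")
    flat_docs = [doc for per_query in docs for doc in per_query]
    num_docs = [len(per_query) for per_query in docs]
    return queries, flat_docs, num_docs
```

A call such as normalize_score_inputs("q", ["d1", "d2"]) yields (["q"], ["d1", "d2"], [2]).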

property searcher: Searcher | None

Searcher used for retrieval if index_dir and search_config are set.

Returns:

The searcher, or None if no searcher has been initialized.

Return type:

Searcher | None
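The searcher is only available when both index_dir and search_config are provided, and on_test_start() is where it is documented to be initialized. The following sketch mirrors that condition with hypothetical stand-ins (`ModuleSketch`, `SearcherSketch`, `SearchConfigSketch` and its `k` field are illustrative names, not Lightning IR classes).

```python
from dataclasses import dataclass
from pathlib import Path
from typing import Optional


@dataclass
class SearchConfigSketch:  # hypothetical stand-in for SearchConfig
    k: int = 10


class SearcherSketch:  # hypothetical stand-in for Searcher
    def __init__(self, index_dir: Path, config: SearchConfigSketch) -> None:
        self.index_dir = index_dir
        self.config = config


class ModuleSketch:
    def __init__(
        self,
        index_dir: Optional[Path] = None,
        search_config: Optional[SearchConfigSketch] = None,
    ) -> None:
        self.index_dir = index_dir
        self.search_config = search_config
        self._searcher: Optional[SearcherSketch] = None

    def on_test_start(self) -> None:
        # Build a searcher only if both the index and the config are set.
        if self.index_dir is not None and self.search_config is not None:
            self._searcher = SearcherSketch(self.index_dir, self.search_config)

    @property
    def searcher(self) -> Optional[SearcherSketch]:
        return self._searcher
```

Until on_test_start() runs, or when either setting is missing, the property returns None.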

set_optimizer(optimizer: Type[Optimizer], **optimizer_kwargs: Dict[str, Any]) → LightningIRModule

Sets the optimizer for the model. Necessary for fine-tuning when not using the CLI.

Parameters:
  • optimizer (Type[torch.optim.Optimizer]) – Torch optimizer class.

  • optimizer_kwargs (Dict[str, Any]) – Arguments to initialize the optimizer.

Returns:

Self with the optimizer set.

Return type:

LightningIRModule

test_step(batch: TrainBatch | RankBatch, batch_idx: int, dataloader_idx: int = 0) → LightningIROutput

Handles the testing step for the model. Passes the batch to the validation step.

Parameters:
  • batch (TrainBatch | RankBatch) – Batch of testing data.

  • batch_idx (int) – Index of the batch.

  • dataloader_idx (int, optional) – Index of the dataloader. Defaults to 0.

Returns:

Model output.

Return type:

LightningIROutput

training_step(batch: TrainBatch, batch_idx: int) → Tensor

Handles the training step for the model.

Parameters:
  • batch (TrainBatch) – Batch of training data.

  • batch_idx (int) – Index of the batch.

Returns:

Sum of the losses weighted by the loss weights.

Return type:

torch.Tensor

Raises:

ValueError – If no loss functions are set.
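The return value of training_step() is described as the sum of the losses weighted by their loss weights. The sketch below reproduces that arithmetic with plain floats instead of torch tensors; the toy losses `mse` and `mae` and the helper `weighted_loss` are hypothetical and stand in for Lightning IR's LossFunction objects.

```python
from typing import Callable, List, Sequence, Tuple

LossFn = Callable[[List[float], List[float]], float]


def mse(scores: List[float], targets: List[float]) -> float:
    """Toy mean squared error."""
    return sum((s - t) ** 2 for s, t in zip(scores, targets)) / len(scores)


def mae(scores: List[float], targets: List[float]) -> float:
    """Toy mean absolute error."""
    return sum(abs(s - t) for s, t in zip(scores, targets)) / len(scores)


def weighted_loss(
    loss_functions: Sequence[Tuple[LossFn, float]],
    scores: List[float],
    targets: List[float],
) -> float:
    """Sum each loss multiplied by its weight."""
    if not loss_functions:
        # Mirrors the documented ValueError when no loss functions are set.
        raise ValueError("Loss functions are not set")
    return sum(weight * loss_fn(scores, targets) for loss_fn, weight in loss_functions)


total = weighted_loss([(mse, 1.0), (mae, 0.5)], scores=[1.0, 0.0], targets=[1.0, 1.0])
```

Here both toy losses evaluate to 0.5, so the weighted total is 1.0 * 0.5 + 0.5 * 0.5 = 0.75.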

validate(output: LightningIROutput, batch: TrainBatch | RankBatch | SearchBatch) → Dict[str, float]

Validates the model output with the evaluation metrics and loss functions.

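Parameters:
  • output (LightningIROutput) – Output of the model.

  • batch (TrainBatch | RankBatch | SearchBatch) – Batch of validation or testing data.
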
Returns:

Dictionary of evaluation metrics.

Return type:

Dict[str, float]

validate_loss(output: LightningIROutput, batch: TrainBatch | RankBatch | SearchBatch) → Dict[str, float]

Validates the model output with the loss functions.

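Parameters:
  • output (LightningIROutput) – Output of the model.

  • batch (TrainBatch | RankBatch | SearchBatch) – Batch of validation or testing data.
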
Returns:

Dictionary of evaluation metrics.

Return type:

Dict[str, float]

validate_metrics(output: LightningIROutput, batch: TrainBatch | RankBatch | SearchBatch) → Dict[str, float]

Validates the model output with the evaluation metrics.

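Parameters:
  • output (LightningIROutput) – Output of the model.

  • batch (TrainBatch | RankBatch | SearchBatch) – Batch of validation or testing data.
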
Returns:

Dictionary of evaluation metrics.

Return type:

Dict[str, float]

Raises:

ValueError – If query_ids or doc_ids are not set in the batch.
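Computing ir-measures metrics requires the model's flat score list to be regrouped into a run keyed by query and document ids, which is why missing query_ids or doc_ids are an error. A minimal sketch of that regrouping follows, assuming scores are ordered query by query in the same order as doc_ids; `build_run` is a hypothetical helper, not Lightning IR's implementation.

```python
from typing import Dict, Optional, Sequence


def build_run(
    query_ids: Optional[Sequence[str]],
    doc_ids: Optional[Sequence[Sequence[str]]],
    scores: Sequence[float],
) -> Dict[str, Dict[str, float]]:
    """Regroup flat scores into {query_id: {doc_id: score}}."""
    if query_ids is None or doc_ids is None:
        raise ValueError("query_ids and doc_ids must be set in the batch")
    if sum(len(per_query) for per_query in doc_ids) != len(scores):
        raise ValueError("Number of scores does not match number of doc_ids")
    run: Dict[str, Dict[str, float]] = {}
    offset = 0
    for query_id, per_query_doc_ids in zip(query_ids, doc_ids):
        # Take the next len(per_query_doc_ids) scores for this query.
        run[query_id] = {
            doc_id: float(scores[offset + i]) for i, doc_id in enumerate(per_query_doc_ids)
        }
        offset += len(per_query_doc_ids)
    return run
```

The nested-dict shape matches the run format that ir-measures accepts alongside qrels in the same shape, so a run built this way could be passed to its aggregate calculation together with parsed measure strings.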

validation_step(batch: TrainBatch | IndexBatch | SearchBatch | RankBatch, batch_idx: int, dataloader_idx: int = 0) → BiEncoderOutput

Handles the validation step for the model.

Parameters:
  • batch (TrainBatch | IndexBatch | SearchBatch | RankBatch) – Batch of validation or testing data.

  • batch_idx (int) – Index of the batch.

  • dataloader_idx (int, optional) – Index of the dataloader. Defaults to 0.

Returns:

Output of the model.

Return type:

BiEncoderOutput