BiEncoderModule

class lightning_ir.bi_encoder.bi_encoder_module.BiEncoderModule(model_name_or_path: str | None = None, config: BiEncoderConfig | None = None, model: BiEncoderModel | None = None, loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None, evaluation_metrics: Sequence[str] | None = None, index_dir: Path | None = None, search_config: SearchConfig | None = None, model_kwargs: Mapping[str, Any] | None = None)

Bases: LightningIRModule

__init__(model_name_or_path: str | None = None, config: BiEncoderConfig | None = None, model: BiEncoderModel | None = None, loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None, evaluation_metrics: Sequence[str] | None = None, index_dir: Path | None = None, search_config: SearchConfig | None = None, model_kwargs: Mapping[str, Any] | None = None)

LightningIRModule for bi-encoder models. It contains a BiEncoderModel and a BiEncoderTokenizer and implements the training, validation, and testing steps for the model.

Parameters:
  • model_name_or_path (str | None, optional) – Name or path of backbone model or fine-tuned Lightning IR model, defaults to None

  • config (BiEncoderConfig | None, optional) – BiEncoderConfig to apply when loading from backbone model, defaults to None

  • model (BiEncoderModel | None, optional) – Already instantiated BiEncoderModel, defaults to None

  • loss_functions (Sequence[LossFunction | Tuple[LossFunction, float]] | None, optional) – Loss functions to apply during fine-tuning, optional loss weights can be provided per loss function, defaults to None

  • evaluation_metrics (Sequence[str] | None, optional) – Metrics corresponding to ir-measures measure strings to apply during validation or testing, defaults to None

  • index_dir (Path | None, optional) – Path to an index used for retrieval, defaults to None

  • search_config (SearchConfig | None, optional) – Configuration to use during retrieval, defaults to None

  • model_kwargs (Mapping[str, Any] | None, optional) – Additional keyword arguments to pass to from_pretrained when loading a model, defaults to None

Methods

__init__([model_name_or_path, config, ...])

LightningIRModule for bi-encoder models.

forward(batch)

Runs a forward pass of the model on a batch of data.

on_test_start()

Called at the beginning of testing.

score(queries, docs)

Computes relevance scores for queries and documents.

validation_step(batch, batch_idx[, ...])

Handles the validation step for the model.

Attributes

searcher

Searcher used for retrieval if index_dir and search_config are set.

training

configure_optimizers() → Optimizer

Configures the optimizer for fine-tuning. This method is ignored when using the CLI. When using Lightning IR programmatically, the optimizer must be set using set_optimizer().

Raises:

ValueError – If optimizer is not set

Returns:

Optimizer

Return type:

torch.optim.Optimizer

forward(batch: RankBatch | IndexBatch | SearchBatch) → BiEncoderOutput

Runs a forward pass of the model on a batch of data. The output varies with the batch type. If the batch is a RankBatch, query and document embeddings are computed and the relevance score is the similarity between the two embeddings. If the batch is an IndexBatch, only document embeddings are computed. If the batch is a SearchBatch, only query embeddings are computed, and the model additionally retrieves documents if searcher is set.

Parameters:

batch (RankBatch | IndexBatch | SearchBatch) – Input batch containing queries, documents, or both

Raises:

ValueError – If the input batch contains neither queries nor documents

Returns:

Output of the model

Return type:

BiEncoderOutput
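The dispatch on batch contents described above can be sketched in plain Python. This is a minimal illustration, not the library's implementation: ToyBatch, ToyOutput, embed, dot, and toy_forward are hypothetical stand-ins, the "embedding" is a simple letter-count vector, and a dot product is assumed as the similarity function.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence


@dataclass
class ToyBatch:
    """Hypothetical stand-in for RankBatch / IndexBatch / SearchBatch."""
    queries: Optional[Sequence[str]] = None
    docs: Optional[Sequence[Sequence[str]]] = None  # one list of documents per query


@dataclass
class ToyOutput:
    """Hypothetical stand-in for BiEncoderOutput."""
    query_embeddings: Optional[List[List[float]]] = None
    doc_embeddings: Optional[List[List[float]]] = None
    scores: Optional[List[float]] = None


def embed(text: str) -> List[float]:
    # Toy embedding: counts of the letters "a", "b", and "c" (not a real encoder).
    return [float(text.count(ch)) for ch in "abc"]


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    # Assumed similarity function: dot product of two embeddings.
    return sum(x * y for x, y in zip(u, v))


def toy_forward(batch: ToyBatch) -> ToyOutput:
    if batch.queries is None and batch.docs is None:
        raise ValueError("Batch contains neither queries nor documents")
    output = ToyOutput()
    if batch.queries is not None:
        output.query_embeddings = [embed(q) for q in batch.queries]
    if batch.docs is not None:
        output.doc_embeddings = [embed(d) for docs in batch.docs for d in docs]
    if batch.queries is not None and batch.docs is not None:
        # Score each document against the query it belongs to.
        output.scores = [
            dot(embed(q), embed(d))
            for q, docs in zip(batch.queries, batch.docs)
            for d in docs
        ]
    return output
```

In this sketch a batch with only documents yields document embeddings and no scores, mirroring the IndexBatch case. Retrieval for the SearchBatch case is left out because it depends on the searcher.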

get_dataset(dataloader_idx: int) → IRDataset | None

Gets the dataset instance from the dataloader index. Returns None if no dataset is found.

Parameters:

dataloader_idx (int) – Index of the dataloader

Returns:

Inference dataset

Return type:

IRDataset | None

get_dataset_id(dataset: IRDataset) → str

Gets the dataset id from the dataloader index for logging.

Parameters:

dataloader_idx (int) – Index of the dataloader

Returns:

Path to run file, ir-datasets dataset id, or dataloader index

Return type:

str

on_save_checkpoint(checkpoint: Dict[str, Any]) → None

Saves the model and tokenizer to the trainer’s log directory.

on_test_end() → None

Prints the accumulated metrics for each dataloader.

on_test_start() → None

Called at the beginning of testing. Initializes the searcher if index_dir and search_config are set.

on_train_start() → None

Called at the beginning of training after sanity check.

on_validation_end() → None

Prints the validation results for each dataloader.

on_validation_start() → None

Called at the beginning of validation.

prepare_input(queries: Sequence[str] | None, docs: Sequence[str] | None, num_docs: Sequence[int] | int | None) → Dict[str, BatchEncoding]

Tokenizes queries and documents and returns the tokenized BatchEncoding.

Parameters:
  • queries (Sequence[str] | None) – Queries to tokenize

  • docs (Sequence[str] | None) – Documents to tokenize

  • num_docs (Sequence[int] | int | None) – Number of documents per query; if None, num_docs is inferred as len(docs) // len(queries), defaults to None

Returns:

Tokenized queries and documents, format depends on the tokenizer

Return type:

Dict[str, BatchEncoding]
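The num_docs inference can be sketched as follows. The helper resolve_num_docs is hypothetical and not part of Lightning IR. The len(docs) // len(queries) rule comes from the parameter description above, while broadcasting a single integer to every query is an assumption made for illustration.

```python
from typing import List, Optional, Sequence, Union


def resolve_num_docs(
    queries: Sequence[str],
    docs: Sequence[str],
    num_docs: Optional[Union[Sequence[int], int]] = None,
) -> List[int]:
    """Return the number of documents for each query (hypothetical helper)."""
    if num_docs is None:
        # Inferred as documented: len(docs) // len(queries)
        num_docs = len(docs) // len(queries)
    if isinstance(num_docs, int):
        # Assumption: a single integer applies to every query
        return [num_docs] * len(queries)
    return list(num_docs)
```

Because integer division discards the remainder, a flat docs list that does not divide evenly among the queries would be better served by passing num_docs explicitly.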

save_pretrained(save_path: str | Path) → None

Saves the model and tokenizer to the save path.

Parameters:

save_path (str | Path) – Path to save the model and tokenizer

score(queries: Sequence[str] | str, docs: Sequence[Sequence[str]] | Sequence[str]) → BiEncoderOutput

Computes relevance scores for queries and documents.

Parameters:
  • queries (Sequence[str] | str) – Queries to score

  • docs (Sequence[Sequence[str]] | Sequence[str]) – Documents to score

Returns:

Model output

Return type:

BiEncoderOutput
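The signature accepts either a single query with a flat list of documents or several queries with one document list each. One plausible way to bring both shapes into a common nested form is sketched below. The helper normalize_score_inputs is hypothetical, and the normalization it performs is an assumption about how such inputs could be handled, not a description of the library's internals.

```python
from typing import List, Sequence, Tuple, Union


def normalize_score_inputs(
    queries: Union[Sequence[str], str],
    docs: Union[Sequence[Sequence[str]], Sequence[str]],
) -> Tuple[List[str], List[List[str]], List[int]]:
    """Hypothetical helper: return queries, nested documents, and documents per query."""
    if isinstance(queries, str):
        queries = [queries]  # single query -> list with one query
    if len(docs) > 0 and isinstance(docs[0], str):
        docs = [docs]  # flat document list -> documents of the single query
    nested = [list(query_docs) for query_docs in docs]
    num_docs = [len(query_docs) for query_docs in nested]
    return list(queries), nested, num_docs
```

The returned per-query counts have the same shape as the num_docs argument of prepare_input.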

property searcher: Searcher | None

Searcher used for retrieval if index_dir and search_config are set.

Returns:

Searcher class

Return type:

Searcher | None

set_optimizer(optimizer: Type[Optimizer], **optimizer_kwargs: Dict[str, Any]) → LightningIRModule

Sets the optimizer for the model. Necessary for fine-tuning when not using the CLI.

Parameters:
  • optimizer (Type[torch.optim.Optimizer]) – Torch optimizer class

  • optimizer_kwargs (Dict[str, Any]) – Arguments to initialize the optimizer

Returns:

self

Return type:

LightningIRModule
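The interplay between set_optimizer() and configure_optimizers() follows a common store-then-instantiate pattern, sketched here without any dependencies. ToyOptimizer and ToyModule are hypothetical stand-ins for torch.optim.Optimizer and LightningIRModule, and the internal attribute names are assumptions.

```python
from typing import Any, Dict, List, Optional, Type


class ToyOptimizer:
    """Hypothetical stand-in for a torch.optim.Optimizer subclass."""

    def __init__(self, params: List[float], lr: float = 0.001) -> None:
        self.params = params
        self.lr = lr


class ToyModule:
    """Hypothetical stand-in for a LightningIRModule."""

    def __init__(self) -> None:
        self._optimizer: Optional[Type[ToyOptimizer]] = None
        self._optimizer_kwargs: Dict[str, Any] = {}
        self.weights: List[float] = [0.0, 0.0]

    def set_optimizer(self, optimizer: Type[ToyOptimizer], **optimizer_kwargs: Any) -> "ToyModule":
        # Store the optimizer class and its arguments; nothing is instantiated yet.
        self._optimizer = optimizer
        self._optimizer_kwargs = optimizer_kwargs
        return self  # returning self allows call chaining

    def configure_optimizers(self) -> ToyOptimizer:
        if self._optimizer is None:
            raise ValueError("Optimizer is not set. Call set_optimizer() first.")
        # Instantiate the optimizer only when it is requested.
        return self._optimizer(self.weights, **self._optimizer_kwargs)
```

Returning self is what makes a chained call such as ToyModule().set_optimizer(ToyOptimizer, lr=0.1) possible in this sketch.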

test_step(batch: TrainBatch | RankBatch, batch_idx: int, dataloader_idx: int = 0) → LightningIROutput

Handles the testing step for the model. Passes the batch to the validation step.

Parameters:
  • batch (TrainBatch | RankBatch) – Batch of testing data

  • batch_idx (int) – Index of the batch

  • dataloader_idx (int, optional) – Index of the dataloader, defaults to 0

Returns:

Model output

Return type:

LightningIROutput

training_step(batch: TrainBatch, batch_idx: int) → Tensor

Handles the training step for the model.

Parameters:
  • batch (TrainBatch) – Batch of training data

  • batch_idx (int) – Index of the batch

Raises:

ValueError – If no loss functions are set

Returns:

Sum of the losses weighted by the loss weights

Return type:

torch.Tensor
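The weighted sum of losses can be illustrated with plain floats instead of tensors. In this sketch, weighted_loss and the two toy loss functions are hypothetical, and treating a loss function given without a weight as having weight 1.0 is an assumption.

```python
from typing import Callable, Optional, Sequence, Tuple, Union

LossFn = Callable[[Sequence[float], Sequence[float]], float]


def mean_squared_error(scores: Sequence[float], targets: Sequence[float]) -> float:
    return sum((s - t) ** 2 for s, t in zip(scores, targets)) / len(scores)


def mean_absolute_error(scores: Sequence[float], targets: Sequence[float]) -> float:
    return sum(abs(s - t) for s, t in zip(scores, targets)) / len(scores)


def weighted_loss(
    loss_functions: Optional[Sequence[Union[LossFn, Tuple[LossFn, float]]]],
    scores: Sequence[float],
    targets: Sequence[float],
) -> float:
    """Sum the losses, each multiplied by its weight (hypothetical helper)."""
    if not loss_functions:
        raise ValueError("No loss functions are set")
    total = 0.0
    for entry in loss_functions:
        # Assumption: a bare loss function is given weight 1.0
        loss_fn, weight = entry if isinstance(entry, tuple) else (entry, 1.0)
        total += weight * loss_fn(scores, targets)
    return total
```

For example, weighted_loss([mean_squared_error, (mean_absolute_error, 0.5)], [1.0, 3.0], [0.0, 1.0]) adds the full squared error to half of the absolute error.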

validate(output: LightningIROutput, batch: TrainBatch | RankBatch | SearchBatch) → Dict[str, float]

Validates the model output with the evaluation metrics and loss functions.

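Parameters:
  • output (LightningIROutput) – Model output

  • batch (TrainBatch | RankBatch | SearchBatch) – Batch of validation or testing data

Returns: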

Dictionary of evaluation metrics

Return type:

Dict[str, float]

validate_loss(output: LightningIROutput, batch: TrainBatch | RankBatch | SearchBatch) → Dict[str, float]

Validates the model output with the loss functions.

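Parameters:
  • output (LightningIROutput) – Model output

  • batch (TrainBatch | RankBatch | SearchBatch) – Batch of validation or testing data

Returns: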

Evaluation metrics

Return type:

Dict[str, float]

validate_metrics(output: LightningIROutput, batch: TrainBatch | RankBatch | SearchBatch) → Dict[str, float]

Validates the model output with the evaluation metrics.

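Parameters:
  • output (LightningIROutput) – Model output

  • batch (TrainBatch | RankBatch | SearchBatch) – Batch of validation or testing data

Returns: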

Evaluation metrics

Return type:

Dict[str, float]

validation_step(batch: TrainBatch | IndexBatch | SearchBatch | RankBatch, batch_idx: int, dataloader_idx: int = 0) → BiEncoderOutput

Handles the validation step for the model.

Parameters:
  • batch (TrainBatch | IndexBatch | SearchBatch | RankBatch) – Batch of validation or testing data

  • batch_idx (int) – Index of the batch

  • dataloader_idx (int, optional) – Index of the dataloader, defaults to 0

Returns:

Model output

Return type:

BiEncoderOutput