CrossEncoderModule
- class lightning_ir.cross_encoder.cross_encoder_module.CrossEncoderModule(model_name_or_path: str | None = None, config: CrossEncoderConfig | None = None, model: CrossEncoderModel | None = None, loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None, evaluation_metrics: Sequence[str] | None = None, model_kwargs: Mapping[str, Any] | None = None)
Bases: LightningIRModule
- __init__(model_name_or_path: str | None = None, config: CrossEncoderConfig | None = None, model: CrossEncoderModel | None = None, loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None, evaluation_metrics: Sequence[str] | None = None, model_kwargs: Mapping[str, Any] | None = None)
LightningIRModule for cross-encoder models. It contains a CrossEncoderModel and a CrossEncoderTokenizer and implements the training, validation, and testing steps for the model.
- Parameters:
model_name_or_path (str | None) – Name or path of a backbone model or a fine-tuned Lightning IR model. Defaults to None.
config (CrossEncoderConfig | None) – CrossEncoderConfig to apply when loading from a backbone model. Defaults to None.
model (CrossEncoderModel | None) – Already instantiated CrossEncoderModel. Defaults to None.
loss_functions (Sequence[LossFunction | Tuple[LossFunction, float]] | None) – Loss functions to apply during fine-tuning. An optional loss weight can be provided per loss function by passing a (LossFunction, weight) tuple. Defaults to None.
evaluation_metrics (Sequence[str] | None) – Metrics to compute during validation or testing, given as ir-measures measure strings. Defaults to None.
model_kwargs (Mapping[str, Any] | None) – Additional keyword arguments to pass to from_pretrained when loading a model. Defaults to None.
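The loss_functions parameter accepts a mix of bare loss functions and (loss function, weight) tuples. The following minimal sketch shows one way such a list can be normalized into uniform pairs. DummyLoss and normalize_loss_functions are hypothetical names used only for illustration. The default weight of 1.0 for a bare loss function is an assumption; this page does not state it.

```python
from typing import List, Optional, Sequence, Tuple, Union


class DummyLoss:
    """Hypothetical stand-in for a Lightning IR LossFunction (only the name is used)."""

    def __init__(self, name: str) -> None:
        self.name = name


# A loss spec is either a bare loss function or a (loss function, weight) tuple.
LossSpec = Union[DummyLoss, Tuple[DummyLoss, float]]


def normalize_loss_functions(
    loss_functions: Optional[Sequence[LossSpec]],
) -> List[Tuple[DummyLoss, float]]:
    """Return (loss function, weight) pairs; bare losses get weight 1.0 (assumption)."""
    if loss_functions is None:
        return []
    pairs: List[Tuple[DummyLoss, float]] = []
    for spec in loss_functions:
        if isinstance(spec, tuple):
            loss_fn, weight = spec
            pairs.append((loss_fn, float(weight)))
        else:
            pairs.append((spec, 1.0))
    return pairs


pairs = normalize_loss_functions([DummyLoss("margin"), (DummyLoss("in_batch"), 0.5)])
print([(fn.name, w) for fn, w in pairs])  # [('margin', 1.0), ('in_batch', 0.5)]
```

Normalizing up front keeps later code, such as the weighted loss sum, free of type checks.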
Methods
__init__([model_name_or_path, config, ...]) – LightningIRModule for cross-encoder models.
forward(batch) – Runs a forward pass of the model on a batch of data and returns the contextualized embeddings from the backbone model as well as the relevance scores.
Attributes
training
- configure_optimizers() → Optimizer
Configures the optimizer for fine-tuning. This method is ignored when using the CLI. When using Lightning IR programmatically, the optimizer must be set using set_optimizer().
- Returns:
The optimizer set for the model.
- Return type:
torch.optim.Optimizer
- Raises:
ValueError – If optimizer is not set. Call set_optimizer.
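The contract between configure_optimizers() and set_optimizer() can be illustrated with a minimal, framework-free sketch. SketchModule and FakeOptimizer are hypothetical stand-ins, and no torch or Lightning code is involved. The sketch only mirrors the documented behavior:

- set_optimizer() stores an optimizer class plus its keyword arguments and returns the module.
- configure_optimizers() raises a ValueError if no optimizer was set.

```python
from typing import Any, Dict, List, Optional, Type


class FakeOptimizer:
    """Hypothetical stand-in for a torch.optim.Optimizer subclass."""

    def __init__(self, params: List[str], lr: float = 1e-3) -> None:
        self.params = params
        self.lr = lr


class SketchModule:
    """Mirrors the set_optimizer / configure_optimizers contract described above."""

    def __init__(self) -> None:
        self.optimizer: Optional[Type[Any]] = None
        self.optimizer_kwargs: Dict[str, Any] = {}
        self.params = ["weight", "bias"]  # placeholder for self.parameters()

    def set_optimizer(self, optimizer: Type[Any], **optimizer_kwargs: Any) -> "SketchModule":
        self.optimizer = optimizer
        self.optimizer_kwargs = optimizer_kwargs
        return self  # returning self allows call chaining

    def configure_optimizers(self) -> Any:
        if self.optimizer is None:
            raise ValueError("Optimizer is not set. Call set_optimizer().")
        # Instantiate the stored optimizer class with the parameters and stored kwargs.
        return self.optimizer(self.params, **self.optimizer_kwargs)


module = SketchModule()
try:
    module.configure_optimizers()
except ValueError as err:
    print(err)  # Optimizer is not set. Call set_optimizer().

optimizer = module.set_optimizer(FakeOptimizer, lr=1e-5).configure_optimizers()
print(optimizer.lr)  # 1e-05
```

Under the documented signature, the equivalent call on a real module would look like module.set_optimizer(torch.optim.AdamW, lr=1e-5). The learning rate is only an example value.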
- forward(batch: RankBatch | TrainBatch | SearchBatch) → CrossEncoderOutput
Runs a forward pass of the model on a batch of data and returns the contextualized embeddings from the backbone model as well as the relevance scores.
- Parameters:
batch (RankBatch | TrainBatch | SearchBatch) – Batch of data to run the forward pass on.
- Returns:
Output of the model.
- Return type:
CrossEncoderOutput
- Raises:
ValueError – If the batch is a SearchBatch.
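A cross-encoder scores each query-document pair jointly. It therefore produces no standalone document representations that could be indexed ahead of time, which is presumably why forward() rejects a SearchBatch.

The sketch below mirrors that dispatch. It makes two simplifications:

- The batch dataclasses are hypothetical and far simpler than the Lightning IR classes.
- Query-document term overlap stands in for the model's relevance score.

```python
from dataclasses import dataclass, field
from typing import List, Union


@dataclass
class RankBatch:
    """Hypothetical simplified batch: one list of documents per query."""
    queries: List[str]
    docs: List[List[str]]


@dataclass
class TrainBatch(RankBatch):
    targets: List[List[float]] = field(default_factory=list)


@dataclass
class SearchBatch:
    queries: List[str]


def forward_sketch(batch: Union[RankBatch, TrainBatch, SearchBatch]) -> List[List[float]]:
    if isinstance(batch, SearchBatch):
        # No document index exists for a cross-encoder, so search is rejected.
        raise ValueError("Searching is not supported for cross-encoders")
    scores: List[List[float]] = []
    for query, docs in zip(batch.queries, batch.docs):
        # Placeholder relevance: number of shared terms between query and document.
        q_terms = set(query.lower().split())
        scores.append([float(len(q_terms & set(doc.lower().split()))) for doc in docs])
    return scores


print(forward_sketch(RankBatch(["neural ranking"], [["neural ranking models", "cooking pasta"]])))
# [[2.0, 0.0]]
```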
- get_dataset(dataloader_idx: int) → IRDataset | None
Gets the dataset instance from the dataloader index. Returns None if no dataset is found.
- Parameters:
dataloader_idx (int) – Index of the dataloader.
- Returns:
Inference dataset or None if no dataset is found.
- Return type:
IRDataset | None
- get_dataset_id(dataset: IRDataset) → str
Gets the id of a dataset instance for logging.
- Parameters:
dataset (IRDataset) – Dataset instance.
- Returns:
Path to run file, ir-datasets dataset id, or dataloader index.
- Return type:
str
- on_save_checkpoint(checkpoint: Dict[str, Any]) → None
Saves the model and tokenizer to the trainer’s log directory.
- prepare_input(queries: Sequence[str] | None, docs: Sequence[str] | None, num_docs: Sequence[int] | int | None) → Dict[str, BatchEncoding]
Tokenizes queries and documents and returns the tokenized BatchEncoding.
- Parameters:
queries (Sequence[str] | None) – Queries to tokenize.
docs (Sequence[str] | None) – Documents to tokenize.
num_docs (Sequence[int] | int | None) – Number of documents per query. If None, num_docs is inferred as len(docs) // len(queries). Defaults to None.
- Returns:
Tokenized queries and documents; the format depends on the tokenizer.
- Return type:
Dict[str, BatchEncoding]
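When num_docs is None, prepare_input() infers it as len(docs) // len(queries). The sketch below shows that inference and one way an int or a per-query list can be expanded into (query, document) pairs, the unit a cross-encoder encodes jointly.

Both functions are hypothetical helpers. The consistency check on sum(num_docs) is an addition of this sketch, not documented behavior.

```python
from typing import List, Sequence, Tuple, Union


def resolve_num_docs(
    queries: Sequence[str],
    docs: Sequence[str],
    num_docs: Union[Sequence[int], int, None],
) -> List[int]:
    """Return the number of documents belonging to each query."""
    if num_docs is None:
        # Documented inference rule: equal share per query.
        num_docs = len(docs) // len(queries)
    if isinstance(num_docs, int):
        num_docs = [num_docs] * len(queries)
    # Sketch-only check: counts must cover the flat docs list exactly.
    if sum(num_docs) != len(docs):
        raise ValueError("num_docs does not match the number of documents")
    return list(num_docs)


def pair_queries_with_docs(
    queries: Sequence[str],
    docs: Sequence[str],
    num_docs: Union[Sequence[int], int, None],
) -> List[Tuple[str, str]]:
    """Slice the flat docs list per query and pair each document with its query."""
    counts = resolve_num_docs(queries, docs, num_docs)
    pairs: List[Tuple[str, str]] = []
    offset = 0
    for query, count in zip(queries, counts):
        pairs.extend((query, doc) for doc in docs[offset : offset + count])
        offset += count
    return pairs


print(pair_queries_with_docs(["q1", "q2"], ["d1", "d2", "d3", "d4"], [1, 3]))
# [('q1', 'd1'), ('q2', 'd2'), ('q2', 'd3'), ('q2', 'd4')]
```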
- save_pretrained(save_path: str | Path) → None
Saves the model and tokenizer to the save path.
- Parameters:
save_path (str | Path) – Path to save the model and tokenizer.
- score(queries: Sequence[str] | str, docs: Sequence[Sequence[str]] | Sequence[str]) → LightningIROutput
Computes relevance scores for queries and documents.
- Parameters:
queries (Sequence[str] | str) – Queries to score.
docs (Sequence[Sequence[str]] | Sequence[str]) – Documents to score.
- Returns:
Model output containing the scores.
- Return type:
LightningIROutput
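score() accepts either a single query string or a sequence of queries, and either a flat list of documents or one list of documents per query. A minimal sketch of normalizing these input shapes is shown below. It assumes a flat document list belongs to a single query.

normalize_score_inputs is a hypothetical helper, not part of the Lightning IR API. Its length check is an addition of this sketch.

```python
from typing import List, Sequence, Tuple, Union


def normalize_score_inputs(
    queries: Union[Sequence[str], str],
    docs: Union[Sequence[Sequence[str]], Sequence[str]],
) -> Tuple[List[str], List[List[str]]]:
    """Return (queries, docs) with exactly one list of documents per query."""
    if isinstance(queries, str):
        queries = [queries]
    # Assumption: a flat list of strings holds the documents of a single query.
    if len(docs) > 0 and isinstance(docs[0], str):
        docs = [docs]
    query_list = list(queries)
    doc_lists = [list(d) for d in docs]
    if len(query_list) != len(doc_lists):
        raise ValueError("expected one list of documents per query")
    return query_list, doc_lists


print(normalize_score_inputs("what is bm25", ["doc a", "doc b"]))
# (['what is bm25'], [['doc a', 'doc b']])
```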
- set_optimizer(optimizer: Type[Optimizer], **optimizer_kwargs: Dict[str, Any]) → LightningIRModule
Sets the optimizer for the model. Necessary for fine-tuning when not using the CLI.
- Parameters:
optimizer (Type[torch.optim.Optimizer]) – Torch optimizer class.
optimizer_kwargs (Dict[str, Any]) – Arguments to initialize the optimizer.
- Returns:
Self with the optimizer set.
- Return type:
LightningIRModule
- test_step(batch: TrainBatch | RankBatch, batch_idx: int, dataloader_idx: int = 0) → LightningIROutput
Handles the testing step for the model. Passes the batch to the validation step.
- Parameters:
batch (TrainBatch | RankBatch) – Batch of testing data.
batch_idx (int) – Index of the batch.
dataloader_idx (int, optional) – Index of the dataloader. Defaults to 0.
- Returns:
Model output.
- Return type:
LightningIROutput
- training_step(batch: TrainBatch, batch_idx: int) → Tensor
Handles the training step for the model.
- Parameters:
batch (TrainBatch) – Batch of training data.
batch_idx (int) – Index of the batch.
- Returns:
Sum of the losses weighted by the loss weights.
- Return type:
torch.Tensor
- Raises:
ValueError – If no loss functions are set.
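training_step() returns the sum of the losses weighted by their loss weights, that is, sum(w_i * L_i) over all configured loss functions. The sketch below computes that sum with plain floats standing in for loss tensors.

weighted_loss is a hypothetical helper. Raising on an empty list mirrors the documented ValueError.

```python
from typing import Sequence, Tuple


def weighted_loss(losses_and_weights: Sequence[Tuple[float, float]]) -> float:
    """Sum of (loss value * weight) over all (loss value, weight) pairs."""
    if not losses_and_weights:
        raise ValueError("No loss functions are set")
    return sum(weight * loss for loss, weight in losses_and_weights)


# Two losses: 2.0 with weight 1.0 and 1.0 with weight 0.5.
print(weighted_loss([(2.0, 1.0), (1.0, 0.5)]))  # 2.5
```

In real training the loss values would be tensors, so the same expression yields a tensor that can be backpropagated.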
- validate(output: LightningIROutput, batch: TrainBatch | RankBatch | SearchBatch) → Dict[str, float]
Validates the model output with the evaluation metrics and loss functions.
- Parameters:
output (LightningIROutput) – Model output.
batch (TrainBatch | RankBatch | SearchBatch) – Batch of validation or testing data.
- Returns:
Dictionary of evaluation metrics.
- Return type:
Dict[str, float]
- validate_loss(output: LightningIROutput, batch: TrainBatch | RankBatch | SearchBatch) → Dict[str, float]
Validates the model output with the loss functions.
- Parameters:
output (LightningIROutput) – Model output.
batch (TrainBatch | RankBatch | SearchBatch) – Batch of validation or testing data.
- Returns:
Dictionary of validation loss values.
- Return type:
Dict[str, float]
- validate_metrics(output: LightningIROutput, batch: TrainBatch | RankBatch | SearchBatch) → Dict[str, float]
Validates the model output with the evaluation metrics.
- Parameters:
output (LightningIROutput) – Model output.
batch (TrainBatch | RankBatch | SearchBatch) – Batch of validation or testing data.
- Returns:
Dictionary of evaluation metrics.
- Return type:
Dict[str, float]
- Raises:
ValueError – If query_ids or doc_ids are not set in the batch.
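validate_metrics() needs query_ids and doc_ids because ranking metrics are computed per query over identified documents. The sketch below turns a batch into a nested {query_id: {doc_id: score}} run. It assumes scores arrive as one list per query, aligned with doc_ids.

build_run is a hypothetical helper. With ir-measures, such a run would typically be evaluated against qrels with a function such as ir_measures.calc_aggregate, after parsing measure strings like "nDCG@10".

```python
from typing import Dict, Optional, Sequence


def build_run(
    query_ids: Optional[Sequence[str]],
    doc_ids: Optional[Sequence[Sequence[str]]],
    scores: Sequence[Sequence[float]],
) -> Dict[str, Dict[str, float]]:
    """Map each query id to {doc id: score}; ids are required to build the run."""
    if query_ids is None or doc_ids is None:
        raise ValueError("query_ids and doc_ids must be set in the batch")
    run: Dict[str, Dict[str, float]] = {}
    for qid, dids, qscores in zip(query_ids, doc_ids, scores):
        run[qid] = {did: float(score) for did, score in zip(dids, qscores)}
    return run


print(build_run(["q1"], [["d1", "d2"]], [[3.2, 1.5]]))
# {'q1': {'d1': 3.2, 'd2': 1.5}}
```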
- validation_step(batch: TrainBatch | RankBatch | SearchBatch, batch_idx: int, dataloader_idx: int = 0) → LightningIROutput
Handles the validation step for the model.
- Parameters:
batch (TrainBatch | RankBatch | SearchBatch) – Batch of validation or testing data.
batch_idx (int) – Index of the batch.
dataloader_idx (int, optional) – Index of the dataloader. Defaults to 0.
- Returns:
Model output.
- Return type:
LightningIROutput