CrossEncoderModule
- class lightning_ir.cross_encoder.cross_encoder_module.CrossEncoderModule(model_name_or_path: str | None = None, config: CrossEncoderConfig | None = None, model: CrossEncoderModel | None = None, loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None, evaluation_metrics: Sequence[str] | None = None, model_kwargs: Mapping[str, Any] | None = None)
Bases: LightningIRModule
- __init__(model_name_or_path: str | None = None, config: CrossEncoderConfig | None = None, model: CrossEncoderModel | None = None, loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None, evaluation_metrics: Sequence[str] | None = None, model_kwargs: Mapping[str, Any] | None = None)
LightningIRModule for cross-encoder models. It contains a CrossEncoderModel and a CrossEncoderTokenizer and implements the training, validation, and testing steps for the model.
- Parameters:
model_name_or_path (str | None, optional) – Name or path of backbone model or fine-tuned Lightning IR model, defaults to None
config (CrossEncoderConfig | None, optional) – CrossEncoderConfig to apply when loading from backbone model, defaults to None
model (CrossEncoderModel | None, optional) – Already instantiated CrossEncoderModel, defaults to None
loss_functions (Sequence[LossFunction | Tuple[LossFunction, float]] | None, optional) – Loss functions to apply during fine-tuning; an optional weight can be given per loss function by passing a (LossFunction, float) tuple, defaults to None
evaluation_metrics (Sequence[str] | None, optional) – ir-measures measure strings (e.g., nDCG@10) of the metrics to compute during validation or testing, defaults to None
model_kwargs (Mapping[str, Any] | None, optional) – Additional keyword arguments to pass to from_pretrained when loading a model, defaults to None
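The loss_functions parameter accepts bare loss functions and (loss function, weight) tuples side by side. The sketch below shows one way such a mixed sequence could be normalized into (function, weight) pairs; the helper normalize_loss_functions and the default weight of 1.0 for bare functions are assumptions for illustration, not Lightning IR's implementation.

```python
from typing import Callable, List, Optional, Sequence, Tuple, Union

# Hypothetical stand-in for Lightning IR's LossFunction type.
LossFunction = Callable[..., float]
LossSpec = Union[LossFunction, Tuple[LossFunction, float]]


def normalize_loss_functions(
    loss_functions: Optional[Sequence[LossSpec]],
) -> List[Tuple[LossFunction, float]]:
    """Hypothetical helper: pair every loss function with a weight.

    Bare loss functions are assumed to get a weight of 1.0.
    """
    if loss_functions is None:
        return []
    normalized = []
    for spec in loss_functions:
        if isinstance(spec, tuple):
            # Explicitly weighted: (loss_function, weight)
            loss_function, weight = spec
            normalized.append((loss_function, float(weight)))
        else:
            # Bare loss function: assumed default weight
            normalized.append((spec, 1.0))
    return normalized
```

Normalizing once at construction time lets later steps treat every loss uniformly as a weighted term.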
Methods
__init__([model_name_or_path, config, ...]) – LightningIRModule for cross-encoder models.
forward(batch) – Runs a forward pass of the model on a batch of data and returns the contextualized embeddings from the backbone model as well as the relevance scores.
Attributes
training
- configure_optimizers() → Optimizer
Configures the optimizer for fine-tuning. This method is ignored when using the CLI. When using Lightning IR programmatically, the optimizer must be set using set_optimizer().
- Raises:
ValueError – If the optimizer is not set
- Returns:
Optimizer
- Return type:
torch.optim.Optimizer
- forward(batch: RankBatch | TrainBatch | SearchBatch) → CrossEncoderOutput
Runs a forward pass of the model on a batch of data and returns the contextualized embeddings from the backbone model as well as the relevance scores.
- Parameters:
batch (RankBatch | TrainBatch | SearchBatch) – Batch of data to run the forward pass on
- Raises:
ValueError – If the batch is a SearchBatch
- Returns:
Output of the model
- Return type:
CrossEncoderOutput
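To illustrate why forward rejects a SearchBatch: a cross-encoder scores each query jointly with a candidate document, so the batch must contain documents. The sketch below uses simplified, hypothetical RankBatch and SearchBatch dataclasses (assuming a search batch carries only queries) and a term-overlap function standing in for the transformer backbone; it is not the real CrossEncoderModule.forward.

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class RankBatch:  # hypothetical, simplified stand-in
    queries: List[str]
    docs: List[List[str]]  # candidate documents per query


@dataclass
class SearchBatch:  # hypothetical, simplified stand-in: queries only
    queries: List[str]


def toy_relevance(query: str, doc: str) -> float:
    """Placeholder for the backbone: fraction of query terms found in the doc."""
    query_terms = set(query.lower().split())
    doc_terms = set(doc.lower().split())
    return len(query_terms & doc_terms) / max(len(query_terms), 1)


def forward(batch: Union[RankBatch, SearchBatch]) -> List[float]:
    # Without candidate documents there is nothing to pair each query with.
    if isinstance(batch, SearchBatch):
        raise ValueError("Searching is not available for cross-encoders")
    scores = []
    for query, docs in zip(batch.queries, batch.docs):
        # Each (query, doc) pair is scored jointly, one score per document.
        scores.extend(toy_relevance(query, doc) for doc in docs)
    return scores
```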
- get_dataset(dataloader_idx: int) → IRDataset | None
Gets the dataset instance from the dataloader index. Returns None if no dataset is found.
- Parameters:
dataloader_idx (int) – Index of the dataloader
- Returns:
Inference dataset
- Return type:
IRDataset | None
- get_dataset_id(dataset: IRDataset) → str
Gets the dataset id from the dataset for logging.
- Parameters:
dataset (IRDataset) – Dataset to get the id for
- Returns:
Path to the run file, ir-datasets dataset id, or dataloader index
- Return type:
str
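The return value of get_dataset_id falls back through several identifiers. A minimal sketch of that fallback, assuming the run file path takes precedence over the ir-datasets id and the dataloader index is the last resort; the function dataset_id_for_logging and its arguments are hypothetical.

```python
from pathlib import Path
from typing import Optional


def dataset_id_for_logging(
    run_path: Optional[Path],
    ir_datasets_id: Optional[str],
    dataloader_idx: int,
) -> str:
    """Hypothetical sketch: return the most descriptive identifier available."""
    if run_path is not None:
        # Assumed first choice: the run file the dataset was loaded from
        return str(run_path)
    if ir_datasets_id is not None:
        # Assumed second choice: the ir-datasets dataset id
        return ir_datasets_id
    # Last resort: the dataloader index as a string
    return str(dataloader_idx)
```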
- on_save_checkpoint(checkpoint: Dict[str, Any]) → None
Saves the model and tokenizer to the trainer’s log directory.
- prepare_input(queries: Sequence[str] | None, docs: Sequence[str] | None, num_docs: Sequence[int] | int | None) → Dict[str, BatchEncoding]
Tokenizes queries and documents and returns a dictionary of tokenized BatchEncodings.
- Parameters:
queries (Sequence[str] | None) – Queries to tokenize
docs (Sequence[str] | None) – Documents to tokenize
num_docs (Sequence[int] | int | None) – Number of documents per query; if None, num_docs is inferred as len(docs) // len(queries), defaults to None
- Returns:
Tokenized queries and documents, format depends on the tokenizer
- Return type:
Dict[str, BatchEncoding]
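prepare_input pairs every query with its documents before tokenization, inferring num_docs as len(docs) // len(queries) when it is None. The sketch below reproduces only that pairing logic on plain strings; the helper pair_queries_and_docs and its consistency check are assumptions, and the real method returns tokenized BatchEncodings rather than string pairs.

```python
from typing import List, Optional, Sequence, Tuple, Union


def pair_queries_and_docs(
    queries: Sequence[str],
    docs: Sequence[str],
    num_docs: Optional[Union[Sequence[int], int]] = None,
) -> List[Tuple[str, str]]:
    """Hypothetical sketch: build the (query, doc) pairs a cross-encoder encodes jointly."""
    if num_docs is None:
        # Documented fallback: the same number of documents for every query
        num_docs = len(docs) // len(queries)
    if isinstance(num_docs, int):
        num_docs = [num_docs] * len(queries)
    if sum(num_docs) != len(docs):
        raise ValueError("num_docs does not match the number of documents")
    pairs = []
    doc_iter = iter(docs)
    for query, n in zip(queries, num_docs):
        # Repeat the query once for each of its n documents.
        pairs.extend((query, next(doc_iter)) for _ in range(n))
    return pairs
```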
- save_pretrained(save_path: str | Path) → None
Saves the model and tokenizer to the save path.
- Parameters:
save_path (str | Path) – Path to save the model and tokenizer
- score(queries: Sequence[str] | str, docs: Sequence[Sequence[str]] | Sequence[str]) → LightningIROutput
Computes relevance scores for queries and documents.
- Parameters:
queries (Sequence[str] | str) – Queries to score
docs (Sequence[Sequence[str]] | Sequence[str]) – Documents to score
- Returns:
Model output
- Return type:
LightningIROutput
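score accepts either a single query string or a sequence of queries, and either a flat or a nested list of documents; a call such as score("what is bm25?", ["doc a", "doc b"]) matches the documented signature. The sketch below shows one hypothetical way to normalize both input shapes into one document list per query; the helper name and the rule that a flat list belongs to a single query are assumptions.

```python
from typing import List, Sequence, Tuple, Union


def normalize_score_inputs(
    queries: Union[Sequence[str], str],
    docs: Union[Sequence[Sequence[str]], Sequence[str]],
) -> Tuple[List[str], List[List[str]]]:
    """Hypothetical sketch: coerce score() arguments into one doc list per query."""
    query_list = [queries] if isinstance(queries, str) else list(queries)
    if docs and all(isinstance(doc, str) for doc in docs):
        # Assumption: a flat document list holds the candidates of a single query.
        doc_lists = [list(docs)]
    else:
        doc_lists = [list(doc_group) for doc_group in docs]
    if len(query_list) != len(doc_lists):
        raise ValueError("Expected one list of documents per query")
    return query_list, doc_lists
```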
- set_optimizer(optimizer: Type[Optimizer], **optimizer_kwargs: Dict[str, Any]) → LightningIRModule
Sets the optimizer for the model. Necessary for fine-tuning when not using the CLI.
- Parameters:
optimizer (Type[torch.optim.Optimizer]) – Torch optimizer class
optimizer_kwargs (Dict[str, Any]) – Arguments to initialize the optimizer
- Returns:
self
- Return type:
LightningIRModule
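Outside the CLI, the optimizer has to be registered before training, for example module.set_optimizer(torch.optim.AdamW, lr=1e-5) according to the documented signature. The sketch below mimics this deferred pattern with a hypothetical OptimizerHolder and a DummyOptimizer standing in for a torch.optim class: set_optimizer only stores the class and its keyword arguments and returns self, and configure_optimizers instantiates it or raises ValueError if nothing was set. Deferring instantiation presumably lets the optimizer receive the model's parameters once they exist.

```python
from typing import Any, Dict, Optional, Type


class DummyOptimizer:
    """Hypothetical stand-in for a torch.optim.Optimizer subclass."""

    def __init__(self, params, lr: float = 1e-3) -> None:
        self.params = list(params)
        self.lr = lr


class OptimizerHolder:
    """Hypothetical sketch of the set_optimizer()/configure_optimizers() pattern."""

    def __init__(self) -> None:
        self._optimizer_cls: Optional[Type[Any]] = None
        self._optimizer_kwargs: Dict[str, Any] = {}
        self.params = [0.0, 1.0]  # placeholder for model parameters

    def set_optimizer(self, optimizer: Type[Any], **optimizer_kwargs: Any) -> "OptimizerHolder":
        # Store the class and its arguments; instantiate later.
        self._optimizer_cls = optimizer
        self._optimizer_kwargs = optimizer_kwargs
        return self  # returning self allows call chaining

    def configure_optimizers(self) -> Any:
        if self._optimizer_cls is None:
            raise ValueError("Optimizer is not set. Call set_optimizer() first.")
        return self._optimizer_cls(self.params, **self._optimizer_kwargs)
```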
- test_step(batch: TrainBatch | RankBatch, batch_idx: int, dataloader_idx: int = 0) → LightningIROutput
Handles the testing step for the model. Passes the batch to the validation step.
- Parameters:
batch (TrainBatch | RankBatch) – Batch of testing data
batch_idx (int) – Index of the batch
dataloader_idx (int, optional) – Index of the dataloader, defaults to 0
- Returns:
Model output
- Return type:
LightningIROutput
- training_step(batch: TrainBatch, batch_idx: int) → Tensor
Handles the training step for the model.
- Parameters:
batch (TrainBatch) – Batch of training data
batch_idx (int) – Index of the batch
- Raises:
ValueError – If no loss functions are set
- Returns:
Sum of the losses weighted by the loss weights
- Return type:
torch.Tensor
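training_step returns the sum of the losses weighted by their loss weights and raises ValueError when no loss functions are set. A minimal sketch of that reduction on plain floats, with hypothetical mse and mae functions standing in for Lightning IR's LossFunction objects (the real step operates on torch tensors).

```python
from typing import Callable, Sequence, Tuple

# Hypothetical: each loss function maps (scores, targets) to a float.
LossFn = Callable[[Sequence[float], Sequence[float]], float]


def mse(scores: Sequence[float], targets: Sequence[float]) -> float:
    return sum((s - t) ** 2 for s, t in zip(scores, targets)) / len(scores)


def mae(scores: Sequence[float], targets: Sequence[float]) -> float:
    return sum(abs(s - t) for s, t in zip(scores, targets)) / len(scores)


def weighted_loss(
    loss_functions: Sequence[Tuple[LossFn, float]],
    scores: Sequence[float],
    targets: Sequence[float],
) -> float:
    if not loss_functions:
        raise ValueError("Loss functions are not set")
    # Weighted sum over all configured (loss_function, weight) pairs
    return sum(weight * fn(scores, targets) for fn, weight in loss_functions)
```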
- validate(output: LightningIROutput, batch: TrainBatch | RankBatch | SearchBatch) → Dict[str, float]
Validates the model output with the evaluation metrics and loss functions.
- Parameters:
output (LightningIROutput) – Model output
batch (TrainBatch | RankBatch | SearchBatch) – Batch of validation or testing data
- Returns:
Dictionary of evaluation metrics
- Return type:
Dict[str, float]
- validate_loss(output: LightningIROutput, batch: TrainBatch | RankBatch | SearchBatch) → Dict[str, float]
Validates the model output with the loss functions.
- Parameters:
output (LightningIROutput) – Model output
batch (TrainBatch | RankBatch | SearchBatch) – Batch of validation or testing data
- Returns:
Evaluation metrics
- Return type:
Dict[str, float]
- validate_metrics(output: LightningIROutput, batch: TrainBatch | RankBatch | SearchBatch) → Dict[str, float]
Validates the model output with the evaluation metrics.
- Parameters:
output (LightningIROutput) – Model output
batch (TrainBatch | RankBatch | SearchBatch) – Batch of validation or testing data
- Returns:
Evaluation metrics
- Return type:
Dict[str, float]
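validate presumably combines the results of validate_metrics and validate_loss. The sketch below shows that split with a hand-rolled reciprocal rank standing in for the ir-measures measure "RR@10" (the real module uses ir-measures measure strings for its metrics); merging the two result dictionaries and the "loss" key name are assumptions.

```python
from typing import Dict, Mapping, Sequence


def reciprocal_rank(scores: Sequence[float], relevance: Sequence[int], cutoff: int = 10) -> float:
    """RR@cutoff for one query: 1 / rank of the first relevant document, else 0."""
    ranking = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    for rank, doc_idx in enumerate(ranking[:cutoff], start=1):
        if relevance[doc_idx] > 0:
            return 1.0 / rank
    return 0.0


def validate_metrics(
    scores_per_query: Sequence[Sequence[float]],
    relevance_per_query: Sequence[Sequence[int]],
) -> Dict[str, float]:
    # Average the per-query values into one aggregate metric.
    values = [reciprocal_rank(s, r) for s, r in zip(scores_per_query, relevance_per_query)]
    return {"RR@10": sum(values) / len(values)}


def validate(
    scores_per_query: Sequence[Sequence[float]],
    relevance_per_query: Sequence[Sequence[int]],
    loss_values: Mapping[str, float],
) -> Dict[str, float]:
    # Assumption: metric and loss results are merged into one dictionary.
    results = validate_metrics(scores_per_query, relevance_per_query)
    results.update(loss_values)
    return results
```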
- validation_step(batch: TrainBatch | RankBatch | SearchBatch, batch_idx: int, dataloader_idx: int = 0) → LightningIROutput
Handles the validation step for the model.
- Parameters:
batch (TrainBatch | RankBatch | SearchBatch) – Batch of validation or testing data
batch_idx (int) – Index of the batch
dataloader_idx (int, optional) – Index of the dataloader, defaults to 0
- Returns:
Model output
- Return type:
LightningIROutput