LightningIRTrainer
- class lightning_ir.main.LightningIRTrainer(*, accelerator: str | Accelerator = 'auto', strategy: str | Strategy = 'auto', devices: list[int] | str | int = 'auto', num_nodes: int = 1, precision: 64 | 32 | 16 | 'transformer-engine' | 'transformer-engine-float16' | '16-true' | '16-mixed' | 'bf16-true' | 'bf16-mixed' | '32-true' | '64-true' | '64' | '32' | '16' | 'bf16' | None = None, logger: Logger | Iterable[Logger] | bool | None = None, callbacks: list[Callback] | Callback | None = None, fast_dev_run: int | bool = False, max_epochs: int | None = None, min_epochs: int | None = None, max_steps: int = -1, min_steps: int | None = None, max_time: str | timedelta | dict[str, int] | None = None, limit_train_batches: int | float | None = None, limit_val_batches: int | float | None = None, limit_test_batches: int | float | None = None, limit_predict_batches: int | float | None = None, overfit_batches: int | float = 0.0, val_check_interval: int | float | None = None, check_val_every_n_epoch: int | None = 1, num_sanity_val_steps: int | None = None, log_every_n_steps: int | None = None, enable_checkpointing: bool | None = None, enable_progress_bar: bool | None = None, enable_model_summary: bool | None = None, accumulate_grad_batches: int = 1, gradient_clip_val: int | float | None = None, gradient_clip_algorithm: str | None = None, deterministic: bool | 'warn' | None = None, benchmark: bool | None = None, inference_mode: bool = True, use_distributed_sampler: bool = True, profiler: Profiler | str | None = None, detect_anomaly: bool = False, barebones: bool = False, plugins: Precision | ClusterEnvironment | CheckpointIO | LayerSync | list[Precision | ClusterEnvironment | CheckpointIO | LayerSync] | None = None, sync_batchnorm: bool = False, reload_dataloaders_every_n_epochs: int = 0, default_root_dir: str | Path | None = None, model_registry: str | None = None)
Bases: Trainer
Lightning IR Trainer that extends the PyTorch Lightning Trainer with information retrieval specific methods.
This trainer inherits all functionality from the PyTorch Lightning Trainer and adds specialized methods for information retrieval tasks, including document indexing, searching, and re-ranking. It provides a unified interface both for training neural ranking models and for performing inference across the different IR stages.
The trainer seamlessly integrates with Lightning IR callbacks and supports all standard Lightning features including distributed training, mixed precision, gradient accumulation, and checkpointing.
Examples
Basic usage for fine-tuning and inference:
    from lightning_ir import LightningIRTrainer, BiEncoderModule, LightningIRDataModule

    # Initialize trainer with Lightning configuration
    trainer = LightningIRTrainer(
        max_steps=100_000,
        precision="16-mixed",
        devices=2,
        accelerator="gpu",
    )

    # Fine-tune a model
    module = BiEncoderModule(model_name_or_path="bert-base-uncased")
    datamodule = LightningIRDataModule(...)
    trainer.fit(module, datamodule)

    # Index documents
    trainer.index(module, datamodule)

    # Search for relevant documents
    trainer.search(module, datamodule)

    # Re-rank retrieved documents
    trainer.re_rank(module, datamodule)
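Beyond the settings shown above, the standard Lightning training features mentioned earlier (gradient accumulation, checkpointing, etc.) carry over unchanged. A minimal sketch, assuming a hypothetical checkpoint directory and monitored metric name:

    from lightning.pytorch.callbacks import ModelCheckpoint

    from lightning_ir import LightningIRTrainer

    # Accumulate gradients over 4 batches and keep the 2 best checkpoints.
    # "validation/loss" is a hypothetical metric name; monitor whatever your module logs.
    checkpoint_callback = ModelCheckpoint(
        dirpath="./checkpoints",  # assumed output directory
        monitor="validation/loss",
        save_top_k=2,
    )
    trainer = LightningIRTrainer(
        max_steps=100_000,
        accumulate_grad_batches=4,  # effective batch size = 4 x per-device batch size
        callbacks=[checkpoint_callback],
    )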
Note
The trainer requires appropriate callbacks to be configured for each IR task (a combined setup is sketched after this list):
- IndexCallback for indexing operations
- SearchCallback for search operations
- ReRankCallback for re-ranking operations
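As a minimal sketch of how these callbacks fit together in a single trainer (the directory paths are placeholders, and a real pipeline may just as well use a separate trainer per stage):

    from lightning_ir import (
        IndexCallback,
        LightningIRTrainer,
        RankCallback,
        ReRankCallback,
        SearchCallback,
        TorchDenseIndexConfig,
        TorchDenseSearchConfig,
    )

    # One callback per IR task; all paths are placeholders.
    trainer = LightningIRTrainer(
        callbacks=[
            IndexCallback(index_dir="./index", index_config=TorchDenseIndexConfig()),
            SearchCallback(index_dir="./index", search_config=TorchDenseSearchConfig(k=100)),
            RankCallback(results_dir="./results"),  # saves search results to disk
            ReRankCallback(results_dir="./reranked_results"),
        ]
    )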
Methods
- index([model, dataloaders, ckpt_path, ...]): Index a collection of documents using a fine-tuned bi-encoder model.
- re_rank([model, dataloaders, ckpt_path, ...]): Re-rank a set of retrieved documents using bi-encoder or cross-encoder models.
- search([model, dataloaders, ckpt_path, ...]): Search for relevant documents using a bi-encoder model and pre-built index.
- index(model: LightningModule | None = None, dataloaders: Any | LightningDataModule | None = None, ckpt_path: str | Path | None = None, verbose: bool = True, datamodule: LightningDataModule | None = None) → List[Mapping[str, float]]
Index a collection of documents using a fine-tuned bi-encoder model.
This method performs document indexing by running inference on a document collection and storing the resulting embeddings in an index structure. It requires an IndexCallback to be configured in the trainer to handle the actual indexing process.
- Parameters:
model (LightningModule | None) – The LightningIRModule containing the bi-encoder model to use for encoding documents. If None, uses the model from the datamodule.
dataloaders (Any | LightningDataModule | None) – DataLoader(s) or LightningIRDataModule containing the document collection to index. Should contain DocDataset instances.
ckpt_path (str | Path | None) – Path to a model checkpoint to load before indexing. If None, uses the current model state. (A short sketch follows the notes below.)
verbose (bool) – Whether to display progress during indexing. Defaults to True.
datamodule (LightningDataModule | None) – LightningIRDataModule instance. Alternative to passing dataloaders directly.
- Returns:
List of dictionaries containing indexing metrics and results.
- Return type:
List[Mapping[str, float]]
Example
    from lightning_ir import LightningIRTrainer, BiEncoderModule, LightningIRDataModule
    from lightning_ir import IndexCallback, TorchDenseIndexConfig, DocDataset

    # Setup trainer with index callback
    callback = IndexCallback(
        index_dir="./index",
        index_config=TorchDenseIndexConfig(),
    )
    trainer = LightningIRTrainer(callbacks=[callback])

    # Setup model and data
    module = BiEncoderModule(model_name_or_path="webis/bert-bi-encoder")
    datamodule = LightningIRDataModule(
        inference_datasets=[DocDataset("msmarco-passage")]
    )

    # Index the documents
    trainer.index(module, datamodule)
Note
- Requires IndexCallback to be configured in the trainer callbacks
- Only works with bi-encoder models that can encode documents
- The index type and configuration are specified in the IndexCallback
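Continuing the example above, a fine-tuned model can also be loaded directly from a checkpoint via the ckpt_path parameter; a minimal sketch (the checkpoint path is a placeholder):

    # Load fine-tuned weights from a Lightning checkpoint before indexing;
    # "path/to/checkpoint.ckpt" is a placeholder for an actual checkpoint file.
    trainer.index(module, datamodule, ckpt_path="path/to/checkpoint.ckpt")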
- re_rank(model: LightningModule | None = None, dataloaders: Any | LightningDataModule | None = None, ckpt_path: str | Path | None = None, verbose: bool = True, datamodule: LightningDataModule | None = None) → List[Mapping[str, float]]
Re-rank a set of retrieved documents using bi-encoder or cross-encoder models.
This method performs re-ranking by scoring query-document pairs and reordering them based on relevance scores. Cross-encoders typically provide higher effectiveness for re-ranking tasks compared to bi-encoders. It requires a ReRankCallback to be configured in the trainer to handle saving the re-ranked results.
- Parameters:
model (LightningModule | None) – The LightningIRModule containing the model to use for re-ranking. Can be either BiEncoderModule or CrossEncoderModule. If None, uses the model from the datamodule.
dataloaders (Any | LightningDataModule | None) – DataLoader(s) or LightningIRDataModule containing the query-document pairs to re-rank. Should contain RunDataset instances.
ckpt_path (str | Path | None) – Path to a model checkpoint to load before re-ranking. If None, uses the current model state.
verbose (bool) – Whether to display progress during re-ranking. Defaults to True.
datamodule (LightningDataModule | None) – LightningIRDataModule instance. Alternative to passing dataloaders directly.
- Returns:
List of dictionaries containing re-ranking metrics and effectiveness results (if relevance judgments are available).
- Return type:
List[Mapping[str, float]]
Example
    from lightning_ir import LightningIRTrainer, CrossEncoderModule, LightningIRDataModule
    from lightning_ir import ReRankCallback, RunDataset

    # Setup trainer with re-rank callback
    rerank_callback = ReRankCallback(results_dir="./reranked_results")
    trainer = LightningIRTrainer(callbacks=[rerank_callback])

    # Setup model and data
    module = CrossEncoderModule(model_name_or_path="webis/bert-cross-encoder")
    datamodule = LightningIRDataModule(
        inference_datasets=[RunDataset("path/to/run/file.txt")]
    )

    # Re-rank the documents
    results = trainer.re_rank(module, datamodule)
Note
- Requires ReRankCallback to be configured in the trainer callbacks
- Input data should be in run file format (query-document pairs with initial scores), as illustrated below
- Cross-encoders typically provide better effectiveness than bi-encoders for re-ranking
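For reference, run files in the widely used six-column TREC format list query id, the literal Q0, document id, rank, score, and run name per line (assuming RunDataset follows this convention; the ids and scores below are made up for illustration):

    q1 Q0 doc42 1 12.7 bm25
    q1 Q0 doc17 2 11.3 bm25
    q2 Q0 doc99 1 10.2 bm25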
- search(model: LightningModule | None = None, dataloaders: Any | LightningDataModule | None = None, ckpt_path: str | Path | None = None, verbose: bool = True, datamodule: LightningDataModule | None = None) → List[Mapping[str, float]]
Search for relevant documents using a bi-encoder model and pre-built index.
This method performs dense or sparse retrieval by encoding queries and searching through a pre-built index to find the most relevant documents. It requires a SearchCallback to be configured in the trainer to handle the search process and optionally a RankCallback to save results.
- Parameters:
model (LightningModule | None) – The LightningIRModule containing the bi-encoder model to use for encoding queries. If None, uses the model from the datamodule.
dataloaders (Any | LightningDataModule | None) – DataLoader(s) or LightningIRDataModule containing the queries to search for. Should contain QueryDataset instances.
ckpt_path (str | Path | None) – Path to a model checkpoint to load before searching. If None, uses the current model state.
verbose (bool) – Whether to display progress during searching. Defaults to True.
datamodule (LightningDataModule | None) – LightningIRDataModule instance. Alternative to passing dataloaders directly.
- Returns:
List of dictionaries containing search metrics and effectiveness results (if relevance judgments are available).
- Return type:
List[Mapping[str, float]]
Example
    from lightning_ir import LightningIRTrainer, BiEncoderModule, LightningIRDataModule
    from lightning_ir import SearchCallback, RankCallback, QueryDataset
    from lightning_ir import TorchDenseSearchConfig

    # Setup trainer with search and rank callbacks
    search_callback = SearchCallback(
        index_dir="./index",
        search_config=TorchDenseSearchConfig(k=100),
    )
    rank_callback = RankCallback(results_dir="./results")
    trainer = LightningIRTrainer(callbacks=[search_callback, rank_callback])

    # Setup model and data
    module = BiEncoderModule(model_name_or_path="webis/bert-bi-encoder")
    datamodule = LightningIRDataModule(
        inference_datasets=[QueryDataset("trec-dl-2019/queries")]
    )

    # Search for relevant documents
    results = trainer.search(module, datamodule)
Note
- Requires SearchCallback to be configured in the trainer callbacks
- The index must be built beforehand using the index() method
- The search configuration must match the index configuration used during indexing (see the sketch below)
- Add a RankCallback to save search results to disk
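Putting the last two notes together, a minimal end-to-end sketch (dataset names and paths are taken from the examples above; other index/search config pairs must be matched analogously):

    from lightning_ir import (
        BiEncoderModule,
        DocDataset,
        IndexCallback,
        LightningIRDataModule,
        LightningIRTrainer,
        QueryDataset,
        RankCallback,
        SearchCallback,
        TorchDenseIndexConfig,
        TorchDenseSearchConfig,
    )

    module = BiEncoderModule(model_name_or_path="webis/bert-bi-encoder")

    # Stage 1: build the index with a dense index configuration.
    index_trainer = LightningIRTrainer(
        callbacks=[IndexCallback(index_dir="./index", index_config=TorchDenseIndexConfig())]
    )
    doc_datamodule = LightningIRDataModule(inference_datasets=[DocDataset("msmarco-passage")])
    index_trainer.index(module, doc_datamodule)

    # Stage 2: search the same index; the dense search config mirrors the
    # dense index config used above.
    search_trainer = LightningIRTrainer(
        callbacks=[
            SearchCallback(index_dir="./index", search_config=TorchDenseSearchConfig(k=100)),
            RankCallback(results_dir="./results"),  # saves the rankings to disk
        ]
    )
    query_datamodule = LightningIRDataModule(inference_datasets=[QueryDataset("trec-dl-2019/queries")])
    results = search_trainer.search(module, query_datamodule)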