SearchCallback
- class lightning_ir.callbacks.callbacks.SearchCallback(search_config: SearchConfig, index_dir: Path | str | None = None, index_name: str | None = None, save_dir: Path | str | None = None, run_name: str | None = None, overwrite: bool = False, use_gpu: bool = True)[source]
Bases: RankCallback, _IndexDirMixin
- __init__(search_config: SearchConfig, index_dir: Path | str | None = None, index_name: str | None = None, save_dir: Path | str | None = None, run_name: str | None = None, overwrite: bool = False, use_gpu: bool = True) None [source]
Callback that uses an index to retrieve documents efficiently.
- Parameters:
search_config (SearchConfig) – Configuration of the Searcher
index_dir (Path | str | None, optional) – Directory where indexes are stored, defaults to None
index_name (str | None, optional) – Name of the index. If None, the dataset’s dataset_id or file name will be used, defaults to None
save_dir (Path | str | None, optional) – Directory to save run files to. If None, run files are saved in the model’s directory, defaults to None
run_name (str | None, optional) – Name of the run file. If None, the dataset’s dataset_id or file name will be used, defaults to None
overwrite (bool, optional) – Whether to overwrite already existing run files (if False, datasets with existing run files are skipped), defaults to False
use_gpu (bool, optional) – Toggle to use GPU for retrieval, defaults to True
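A minimal usage sketch of wiring this callback into a test run. The callback import follows the class path documented above; the FaissSearchConfig import location, the model checkpoint, the dataset id, and the directory paths are placeholder assumptions for illustration, not taken from this page.

```python
# Hedged sketch: run retrieval over query datasets with SearchCallback.
# Assumes an index was previously built (e.g. with an IndexCallback) in ./index.
from lightning_ir import BiEncoderModule, LightningIRDataModule, LightningIRTrainer, QueryDataset
from lightning_ir.callbacks.callbacks import SearchCallback
from lightning_ir.retrieve import FaissSearchConfig  # assumed import path

# Hypothetical bi-encoder checkpoint; substitute your own model.
module = BiEncoderModule(model_name_or_path="path/to/bi-encoder")

# Datasets must be QueryDataset instances (see on_test_start below);
# the dataset id here is a placeholder.
data_module = LightningIRDataModule(
    inference_datasets=[QueryDataset("some-dataset-id")],
    inference_batch_size=32,
)

callback = SearchCallback(
    search_config=FaissSearchConfig(k=100),  # retrieve top-100 documents per query
    index_dir="./index",   # directory where the index is stored
    save_dir="./runs",     # one run file per dataset is written here
    overwrite=False,       # skip datasets whose run files already exist
    use_gpu=False,         # set True to run retrieval on GPU
)

trainer = LightningIRTrainer(callbacks=[callback], logger=False, enable_checkpointing=False)
trainer.test(module, datamodule=data_module)
```

If `save_dir` is omitted, run files are saved next to a local model checkpoint; for models loaded from a hub, `save_dir` must be provided (see the `setup` hook below).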
Methods
__init__(search_config[, index_dir, ...]) – Callback that uses an index to retrieve documents efficiently.
on_test_batch_start(trainer, pl_module, ...) – Hook to initialize searcher for new datasets.
on_test_start(trainer, pl_module) – Hook to validate datasets.
Attributes
index_dir
index_name
- on_test_batch_end(trainer: Trainer, pl_module: LightningIRModule, outputs: LightningIROutput, batch: Any, batch_idx: int, dataloader_idx: int = 0) None
Hook to aggregate and write ranking to file.
- Parameters:
trainer (Trainer) – PyTorch Lightning Trainer
pl_module (LightningIRModule) – LightningIR Module
outputs (LightningIROutput) – Scored query–document pairs
batch (Any) – Batch of input data
batch_idx (int) – Index of batch in the current dataset
dataloader_idx (int, optional) – Index of the dataloader, defaults to 0
- on_test_batch_start(trainer: Trainer, pl_module: BiEncoderModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) None [source]
Hook to initialize searcher for new datasets.
- Parameters:
trainer (Trainer) – PyTorch Lightning Trainer
pl_module (BiEncoderModule) – LightningIR BiEncoderModule
batch (Any) – Batch of input data
batch_idx (int) – Index of batch in the current dataset
dataloader_idx (int, optional) – Index of the dataloader, defaults to 0
- on_test_start(trainer: Trainer, pl_module: BiEncoderModule) None [source]
Hook to validate datasets.
- Parameters:
trainer (Trainer) – PyTorch Lightning Trainer
pl_module (BiEncoderModule) – LightningIR BiEncoderModule
- Raises:
ValueError – If no test_dataloaders are found
ValueError – If not all datasets are QueryDataset
- setup(trainer: Trainer, pl_module: LightningIRModule, stage: str) None
Hook to setup the callback.
- Parameters:
trainer (Trainer) – PyTorch Lightning Trainer
pl_module (LightningIRModule) – LightningIR module
stage (str) – Stage of the trainer, must be “test”
- Raises:
ValueError – If the stage is not “test”
ValueError – If no save_dir is provided and model_name_or_path is not a path (the model is not local)
- teardown(trainer: Trainer, pl_module: BiEncoderModule, stage: str) None
Hook to cleanup the callback.
- Parameters:
trainer (Trainer) – PyTorch Lightning Trainer
pl_module (BiEncoderModule) – LightningIR bi-encoder module used for indexing
stage (str) – Stage of the trainer, must be “test”