SearchCallback

class lightning_ir.callbacks.callbacks.SearchCallback(search_config: SearchConfig, index_dir: Path | str | None = None, index_name: str | None = None, save_dir: Path | str | None = None, run_name: str | None = None, overwrite: bool = False, use_gpu: bool = True)

Bases: RankCallback, _IndexDirMixin

__init__(search_config: SearchConfig, index_dir: Path | str | None = None, index_name: str | None = None, save_dir: Path | str | None = None, run_name: str | None = None, overwrite: bool = False, use_gpu: bool = True) → None

Callback that uses an index to retrieve documents efficiently.

Parameters:
  • search_config (SearchConfig) – Configuration of the searcher.

  • index_dir (Path | str | None) – Directory where indexes are stored. Defaults to None.

  • index_name (str | None) – Name of the index. If None, the dataset’s dataset_id or file name will be used. Defaults to None.

  • save_dir (Path | str | None) – Directory to save run files to. If None, run files will be saved in the model’s directory. Defaults to None.

  • run_name (str | None) – Name of the run file. If None, the dataset’s dataset_id or file name will be used. Defaults to None.

  • overwrite (bool) – Whether to overwrite run files that already exist (True) or skip them (False). Defaults to False.

  • use_gpu (bool) – Whether to use the GPU for retrieval. Defaults to True.
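The fallback rules for the naming parameters can be summarised in a short sketch. It is illustrative only: the helper names (default_name, resolve_index_path, resolve_run_path, should_run), the layout index_dir / index_name, the ".run" suffix, and the use of the file name stem are all assumptions, not lightning_ir's actual implementation. Because no fallback is documented for index_dir, the sketch simply requires it.

```python
from __future__ import annotations

from pathlib import Path


def default_name(dataset_id: str | None, file_path: Path | None) -> str:
    """Hypothetical fallback: use the dataset_id, otherwise the file name stem."""
    if dataset_id is not None:
        return dataset_id
    if file_path is not None:
        return file_path.stem
    raise ValueError("dataset has neither a dataset_id nor a file path")


def resolve_index_path(index_dir: Path | str | None, index_name: str | None,
                       dataset_id: str | None, file_path: Path | None) -> Path:
    # No fallback is documented for index_dir, so this sketch requires it.
    if index_dir is None:
        raise ValueError("index_dir must be set")
    name = index_name if index_name is not None else default_name(dataset_id, file_path)
    return Path(index_dir) / name


def resolve_run_path(save_dir: Path | str | None, run_name: str | None, model_dir: Path,
                     dataset_id: str | None, file_path: Path | None) -> Path:
    # As documented, save_dir falls back to the model's directory.
    directory = Path(save_dir) if save_dir is not None else model_dir
    name = run_name if run_name is not None else default_name(dataset_id, file_path)
    # Hypothetical ".run" suffix, chosen for illustration only.
    return directory / f"{name}.run"


def should_run(run_path: Path, overwrite: bool) -> bool:
    # An existing run file is skipped unless overwrite=True.
    return overwrite or not run_path.exists()
```

For example, with run_name=None and save_dir=None, a query file /data/queries.tsv evaluated with a model stored in /models/m would, under these assumptions, produce /models/m/queries.run.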

Methods

__init__(search_config[, index_dir, ...])

Callback that uses an index to retrieve documents efficiently.

on_test_batch_start(trainer, pl_module, ...)

Hook to initialize searcher for new datasets.

on_test_start(trainer, pl_module)

Hook to validate the test datasets.

Attributes

index_dir

index_name

on_test_batch_start(trainer: Trainer, pl_module: BiEncoderModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) → None

Hook to initialize searcher for new datasets.

Parameters:
  • trainer (Trainer) – PyTorch Lightning Trainer.

  • pl_module (BiEncoderModule) – LightningIR bi-encoder module.

  • batch (Any) – Batch of input data.

  • batch_idx (int) – Index of the batch in the dataset.

  • dataloader_idx (int) – Index of the dataloader. Defaults to 0.
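One way to read "initialize searcher for new datasets" is that a searcher is built only when a batch arrives from a dataloader that has not been handled yet, and is then reused for the rest of that dataloader's batches. The sketch below shows that lazy-initialization pattern with a hypothetical LazySearcherCache class and a stand-in build_searcher factory. It does not use lightning_ir's real searcher classes, and whether the callback keys on dataloader_idx exactly like this is an assumption.

```python
from typing import Any, Callable, Optional


class LazySearcherCache:
    """Hypothetical sketch: build one searcher per dataloader, on first use."""

    def __init__(self, build_searcher: Callable[[int], Any]) -> None:
        self.build_searcher = build_searcher  # stand-in factory, keyed by dataloader index
        self.current_idx: Optional[int] = None
        self.searcher: Any = None
        self.builds = 0  # counts how often a searcher was constructed

    def on_batch_start(self, batch_idx: int, dataloader_idx: int = 0) -> Any:
        # batch_idx mirrors the hook signature; it is unused in this sketch.
        # Rebuild only when a batch from a different dataloader arrives.
        if dataloader_idx != self.current_idx:
            self.searcher = self.build_searcher(dataloader_idx)
            self.current_idx = dataloader_idx
            self.builds += 1
        return self.searcher
```

Because Lightning iterates the test dataloaders one after another, checking only whether the index has changed is enough to build each searcher exactly once per dataset.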

on_test_start(trainer: Trainer, pl_module: BiEncoderModule) → None

Hook to validate the test datasets.

Parameters:
  • trainer (Trainer) – PyTorch Lightning Trainer.

  • pl_module (BiEncoderModule) – LightningIR bi-encoder module.

Raises:
  • ValueError – If no test_dataloaders are found.

  • ValueError – If not all test datasets are QueryDataset instances.
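The two documented ValueError conditions can be illustrated with a small standalone check. In a real Lightning run the loaders would come from trainer.test_dataloaders. Here QueryDataset, DocDataset and FakeLoader are local stand-ins defined only for this sketch (not lightning_ir's classes), and the error messages are illustrative.

```python
from typing import Any, List, Optional, Sequence


class QueryDataset:
    """Stand-in for lightning_ir's QueryDataset (illustration only)."""


class DocDataset:
    """Hypothetical non-query dataset, used to trigger the second error."""


class FakeLoader:
    """Minimal stand-in for a DataLoader: it only exposes .dataset."""

    def __init__(self, dataset: Any) -> None:
        self.dataset = dataset


def validate_test_datasets(test_dataloaders: Optional[Sequence[Any]]) -> List[Any]:
    # Condition 1: there must be at least one test dataloader.
    if not test_dataloaders:
        raise ValueError("No test_dataloaders found")
    datasets = [loader.dataset for loader in test_dataloaders]
    # Condition 2: every test dataset must be a QueryDataset.
    if not all(isinstance(ds, QueryDataset) for ds in datasets):
        raise ValueError("Expected all test datasets to be QueryDataset instances")
    return datasets
```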