1"""Module containing callbacks for indexing, searching, ranking, and registering custom datasets."""
2
3from __future__ import annotations
4
5import csv
6import gc
7import itertools
8from dataclasses import is_dataclass
9from pathlib import Path
10from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, TypeVar
11
12import pandas as pd
13import torch
14from lightning import Trainer
15from lightning.pytorch.callbacks import Callback, TQDMProgressBar
16
17from ..base.validation_utils import evaluate_run
18from ..data import LightningIRDataModule, RankBatch, SearchBatch
19from ..data.dataset import RUN_HEADER, DocDataset, IRDataset, QueryDataset, RunDataset, _DummyIterableDataset
20from ..data.external_datasets.ir_datasets_utils import register_new_dataset
21from ..retrieve import IndexConfig, Indexer, SearchConfig, Searcher
22
23if TYPE_CHECKING:
24 from ..base import LightningIRModule, LightningIROutput
25 from ..bi_encoder import BiEncoderModule, BiEncoderOutput
26
# Generic type variable: lets `_GatherMixin._gather` declare that it returns the same type it was given.
T = TypeVar("T")
28
29
def _format_large_number(number: float) -> str:
    """Format a number with two decimals and a magnitude suffix (K, M, B, T).

    The value is divided by 1000 until its rounded magnitude drops below 1000 or the largest suffix ("T") is
    reached. Numbers without a suffix are returned as plain two-decimal strings.

    Args:
        number (float): Number to format.
    Returns:
        str: Formatted number, e.g. ``"999.00"``, ``"1.50 M"`` or ``"-5.00 K"``.
    """
    suffixes = ("", "K", "M", "B", "T")
    suffix_index = 0

    # Compare the *rounded magnitude*: a value such as 999_999 is 999.999 after one division and would otherwise be
    # printed as "1000.00 K" instead of rolling over to "1.00 M". Using abs() scales negative values the same way.
    while round(abs(number), 2) >= 1000 and suffix_index < len(suffixes) - 1:
        number /= 1000.0
        suffix_index += 1

    formatted_number = f"{number:.2f}"

    suffix = suffixes[suffix_index]
    if suffix:
        formatted_number += f" {suffix}"
    return formatted_number
44
45
class _GatherMixin:
    """Mixin that collects (possibly nested) dataclasses from all processes"""

    def _gather(self, pl_module: LightningIRModule, dataclass: T) -> T:
        """Gather ``dataclass`` via ``pl_module.all_gather``.

        A dataclass is walked field by field (recursively) and rebuilt from the gathered field values; any other
        value is passed to ``pl_module.all_gather`` directly.

        Args:
            pl_module (LightningIRModule): Module providing ``all_gather``.
            dataclass (T): Dataclass instance or plain value to gather.
        Returns:
            T: Object of the same class whose leaves are the results of ``all_gather``.
        """
        if not is_dataclass(dataclass):
            # leaf value: let the module do the actual gathering
            return pl_module.all_gather(dataclass)
        gathered_fields = {}
        for field_name in dataclass.__dataclass_fields__:
            gathered_fields[field_name] = self._gather(pl_module, getattr(dataclass, field_name))
        return dataclass.__class__(**gathered_fields)
55
56
class _IndexDirMixin:
    """Mixin that resolves the directory of an index"""

    index_dir: Path | str | None
    index_name: str | None

    def _get_index_dir(self, pl_module: BiEncoderModule, dataset: DocDataset) -> Path:
        """Build the index directory for ``dataset``.

        The base directory is ``self.index_dir`` or, if that is None, ``<model path>/indexes`` where the model path
        is ``pl_module.config.name_or_path``. The sub-directory is ``self.index_name`` or, if that is None, the
        dataset's ``dashed_docs_dataset_id``.

        Args:
            pl_module (BiEncoderModule): Module whose config provides the fallback base directory.
            dataset (DocDataset): Dataset providing the fallback index name.
        Returns:
            Path: Directory of the index.
        Raises:
            ValueError: If no index_dir is set and the model's name_or_path does not exist on disk.
        """
        base_dir = self.index_dir
        if base_dir is None:
            model_dir = Path(pl_module.config.name_or_path)
            if not model_dir.exists():
                raise ValueError("No index_dir provided and model_name_or_path is not a path")
            base_dir = model_dir / "indexes"
        sub_dir = dataset.dashed_docs_dataset_id if self.index_name is None else self.index_name
        return Path(base_dir) / sub_dir
77
78
class _OverwriteMixin:
    """Mixin that flags datasets (for indexing or searching) as skipped if their output already exists"""

    # provided by the concrete callback: maps (module, dataset) to the output path of that dataset
    _get_save_path: Callable[[LightningIRModule, IRDataset], Path]

    def _remove_overwrite_datasets(self, trainer: Trainer, pl_module: LightningIRModule) -> None:
        """Mark inference datasets whose save path already exists as skipped.

        Does nothing if ``self.overwrite`` is truthy. For a skipped dataset whose save path ends in ``.run``, the
        existing run file is evaluated (if the dataset has qrels and the module has evaluation metrics) and the
        results are stored in ``pl_module._additional_log_metrics`` under ``<dataset_id>/<metric>``.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (LightningIRModule): LightningIR module.
        Raises:
            ValueError: If the trainer has no datamodule.
        """
        if getattr(self, "overwrite", False):
            return
        datamodule: LightningIRDataModule | None = getattr(trainer, "datamodule", None)
        if datamodule is None:
            raise ValueError("No datamodule found")
        if datamodule.inference_datasets is None:
            return
        for dataset in list(datamodule.inference_datasets):
            save_path = self._get_save_path(pl_module, dataset)
            if not save_path.exists():
                continue
            dataset._SKIP = True
            trainer.print(f"`{save_path}` already exists. Set overwrite=True to overwrite")
            can_evaluate = (
                save_path.name.endswith(".run")
                and dataset.qrels is not None
                and pl_module.evaluation_metrics is not None
            )
            if not can_evaluate:
                continue
            # re-use the run file that is already on disk to report metrics for the skipped dataset
            run = RunDataset._load_csv(save_path)
            qrels = dataset.qrels.stack(future_stack=True).dropna().astype(int).reset_index()
            if isinstance(dataset, RunDataset) and dataset.run_path is not None:
                dataset_id = dataset.run_path.name
            else:
                dataset_id = dataset.dataset_id
            metrics = evaluate_run(run, qrels, pl_module.evaluation_metrics)
            for metric_name, value in metrics.items():
                pl_module._additional_log_metrics[f"{dataset_id}/{metric_name}"] = value

    def _cleanup(self, trainer: Trainer, pl_module: LightningIRModule) -> None:
        """Reset the skip flag of all inference datasets and clear the additional log metrics.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (LightningIRModule): LightningIR module.
        """
        datamodule: LightningIRDataModule | None = getattr(trainer, "datamodule", None)
        has_datasets = datamodule is not None and datamodule.inference_datasets is not None
        if has_datasets:
            for dataset in datamodule.inference_datasets:
                dataset._SKIP = False
        pl_module._additional_log_metrics = {}
120
121
class IndexCallback(Callback, _GatherMixin, _IndexDirMixin, _OverwriteMixin):
    def __init__(
        self,
        index_config: IndexConfig,
        index_dir: Path | str | None = None,
        index_name: str | None = None,
        overwrite: bool = False,
        verbose: bool = False,
    ) -> None:
        """Callback to index documents using an :py:class:`~lightning_ir.retrieve.base.indexer.Indexer`.

        Args:
            index_config (IndexConfig): Configuration for the indexer.
            index_dir (Path | str | None): Directory to save index(es) to. If None, indexes will be stored in the
                model's directory. Defaults to None.
            index_name (str | None): Name of the index. If None, the dataset's dataset_id or file name will be used.
                Defaults to None.
            overwrite (bool): Whether to skip or overwrite already existing indexes. Defaults to False.
            verbose (bool): Toggle verbose output. Defaults to False.
        """
        super().__init__()
        self.index_config = index_config
        self.index_dir = index_dir
        self.index_name = index_name
        self.overwrite = overwrite
        self.verbose = verbose
        # created lazily in `on_test_batch_start`, once per dataloader
        self.indexer: Indexer

    def setup(self, trainer: Trainer, pl_module: BiEncoderModule, stage: str) -> None:
        """Hook to setup the callback.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (BiEncoderModule): LightningIR bi-encoder module used for indexing.
            stage (str): Stage of the trainer, must be "test".
        Raises:
            ValueError: If the stage is not "test".
        """
        if stage != "test":
            raise ValueError(f"{self.__class__.__name__} can only be used in test stage")
        self._remove_overwrite_datasets(trainer, pl_module)

    def teardown(self, trainer: Trainer, pl_module: BiEncoderModule, stage: str) -> None:
        """Hook to cleanup the callback.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (BiEncoderModule): LightningIR bi-encoder module used for indexing.
            stage (str): Stage of the trainer.
        """
        self._cleanup(trainer, pl_module)

    def _get_save_path(self, pl_module: BiEncoderModule, dataset: IRDataset) -> Path:
        """Return the index directory of ``dataset``; raises ValueError if it is not a DocDataset."""
        if isinstance(dataset, DocDataset):
            return self._get_index_dir(pl_module, dataset)
        raise ValueError("Expected DocDataset for indexing")

    def _get_indexer(self, pl_module: BiEncoderModule, dataloader_idx: int) -> Indexer:
        """Instantiate the configured indexer class for the dataset of the given dataloader.

        Raises:
            ValueError: If the module has no dataset for ``dataloader_idx`` or it is not a DocDataset.
        """
        doc_dataset = pl_module.get_dataset(dataloader_idx)
        if doc_dataset is None:
            raise ValueError("No dataset found to index")
        if not isinstance(doc_dataset, DocDataset):
            raise ValueError("Expected DocDataset for indexing")
        save_path = self._get_save_path(pl_module, doc_dataset)
        return self.index_config.indexer_class(save_path, self.index_config, pl_module, self.verbose)

    def _log_to_pg(self, info: Dict[str, Any], trainer: Trainer):
        """Show ``info`` (formatted as abbreviated numbers) as postfix of the TQDM test progress bar, if any."""
        progress_callback = trainer.progress_bar_callback
        # `None` is not a TQDMProgressBar either, so a single isinstance check covers both cases
        if not isinstance(progress_callback, TQDMProgressBar):
            return
        progress_bar = progress_callback.test_progress_bar
        formatted = {name: _format_large_number(value) for name, value in info.items()}
        if progress_bar is None:
            return
        progress_bar.set_postfix(formatted)

    def on_test_start(self, trainer: Trainer, pl_module: BiEncoderModule) -> None:
        """Hook to test datasets are configured correctly.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (BiEncoderModule): LightningIR bi-encoder module.
        Raises:
            ValueError: If no test_dataloaders are found.
            ValueError: If not all test datasets are :py:class:`~lightning_ir.data.dataset.DocDataset`.
        """
        test_dataloaders = trainer.test_dataloaders
        if test_dataloaders is None:
            raise ValueError("No test_dataloaders found")
        allowed_types = (DocDataset, _DummyIterableDataset)
        test_datasets = [test_dataloader.dataset for test_dataloader in test_dataloaders]
        if any(not isinstance(test_dataset, allowed_types) for test_dataset in test_datasets):
            raise ValueError("Expected DocDatasets for indexing")

    def on_test_batch_start(
        self, trainer: Trainer, pl_module: BiEncoderModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Hook to setup the indexer between datasets.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (BiEncoderModule): LightningIR bi-encoder module.
            batch (Any): Batch of input data.
            batch_idx (int): Index of batch in the current dataset.
            dataloader_idx (int): Index of the dataloader. Defaults to 0.
        """
        is_first_batch = batch_idx == 0
        if is_first_batch:
            # a new dataloader starts: build a fresh indexer for its dataset
            self.indexer = self._get_indexer(pl_module, dataloader_idx)
        super().on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: BiEncoderModule,
        outputs: BiEncoderOutput,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Hook to pass encoded documents to the indexer

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (BiEncoderModule): LightningIR bi-encoder module.
            outputs (BiEncoderOutput): Encoded documents.
            batch (Any): Batch of input data.
            batch_idx (int): Index of batch in the current dataset.
            dataloader_idx (int): Index of the dataloader. Defaults to 0.
        """
        # gathering happens on every process; only the global-zero process feeds the indexer
        batch = self._gather(pl_module, batch)
        outputs = self._gather(pl_module, outputs)

        if not trainer.is_global_zero:
            return

        self.indexer.add(batch, outputs)
        progress_info = {
            "num_docs": self.indexer.num_docs,
            "num_embeddings": self.indexer.num_embeddings,
        }
        self._log_to_pg(progress_info, trainer)
        # TODO if dataset length cannot be inferred, num_test_batches is inf and no index is saved
        last_batch_idx = trainer.num_test_batches[dataloader_idx] - 1
        if batch_idx == last_batch_idx:
            assert hasattr(self, "indexer")
            self.indexer.save()
269
270
class RankCallback(Callback, _GatherMixin, _OverwriteMixin):
    def __init__(
        self, save_dir: Path | str | None = None, run_name: str | None = None, overwrite: bool = False
    ) -> None:
        """Callback to write run file of ranked documents to disk.

        Args:
            save_dir (Path | str | None): Directory to save run files to. If None, run files will be saved in the
                model's directory. Defaults to None.
            run_name (str | None): Name of the run file. If None, the dataset's dataset_id or file name will be used.
                Defaults to None.
            overwrite (bool): Whether to skip or overwrite already existing run files. Defaults to False.
        """
        super().__init__()
        self.save_dir = Path(save_dir) if save_dir is not None else None
        self.run_name = run_name
        self.overwrite = overwrite
        # per-batch run data frames of the dataset currently being processed
        self.run_dfs: List[pd.DataFrame] = []

    def setup(self, trainer: Trainer, pl_module: LightningIRModule, stage: str) -> None:
        """Hook to setup the callback.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (LightningIRModule): LightningIR module.
            stage (str): Stage of the trainer, must be "test".
        Raises:
            ValueError: If the stage is not "test".
            ValueError: If no save_dir is provided and model_name_or_path is not a path (the model is not local).
        """
        if stage != "test":
            raise ValueError(f"{self.__class__.__name__} can only be used in test stage")
        if self.save_dir is None:
            default_save_dir = Path(pl_module.config.name_or_path)
            if default_save_dir.exists():
                self.save_dir = default_save_dir / "runs"
                # use trainer.print (as `_OverwriteMixin` does) so the notice is not repeated by every process
                trainer.print(f"Using default save_dir `{self.save_dir}` to save runs")
            else:
                raise ValueError("No save_dir provided and model_name_or_path is not a path")
        self._remove_overwrite_datasets(trainer, pl_module)

    def teardown(self, trainer: Trainer, pl_module: LightningIRModule, stage: str) -> None:
        """Hook to cleanup the callback.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (LightningIRModule): LightningIR module.
            stage (str): Stage of the trainer.
        """
        self._cleanup(trainer, pl_module)

    def _get_save_path(self, pl_module: LightningIRModule, dataset: IRDataset) -> Path:
        """Return the path of the run file for ``dataset``.

        The file name is ``self.run_name`` if set. Otherwise it is derived from the dataset: the dataset_id (with
        ``/`` replaced by ``-``) for query datasets and run datasets without run_path, or the part of the run file's
        name before the first ``.`` for run datasets with a run_path; ``.run`` is appended in both cases.

        Raises:
            ValueError: If save_dir is not set or the dataset is neither a QueryDataset nor a RunDataset.
        """
        if self.save_dir is None:
            raise ValueError("No save_dir found; call setup before using this method")
        if self.run_name is not None:
            run_file = self.run_name
        elif isinstance(dataset, QueryDataset):
            run_file = f"{dataset.dataset_id.replace('/', '-')}.run"
        elif isinstance(dataset, RunDataset):
            if dataset.run_path is None:
                run_file = f"{dataset.dataset_id.replace('/', '-')}.run"
            else:
                run_file = f"{dataset.run_path.name.split('.')[0]}.run"
        else:
            raise ValueError("Expected QueryDataset or RunDataset for ranking")
        return self.save_dir / run_file

    def _rank(self, batch: RankBatch, output: LightningIROutput) -> Tuple[torch.Tensor, List[str], List[int]]:
        """Flatten the scores and document ids of a batch.

        Returns:
            Tuple[torch.Tensor, List[str], List[int]]: Flattened scores, flattened document ids, and the number of
                documents per query.
        Raises:
            ValueError: If the output has no scores, the batch has no doc_ids, or their lengths do not match.
        """
        scores = output.scores
        if scores is None:
            raise ValueError("Expected output to have scores")
        doc_ids = batch.doc_ids
        if doc_ids is None:
            raise ValueError("Expected batch to have doc_ids")
        scores = scores.view(-1)
        num_docs = [len(_doc_ids) for _doc_ids in doc_ids]
        doc_ids = list(itertools.chain.from_iterable(doc_ids))
        if scores.shape[0] != len(doc_ids):
            raise ValueError("scores and doc_ids must have the same length")
        return scores, doc_ids, num_docs

    def _write_run_dfs(self, trainer: Trainer, pl_module: LightningIRModule, dataloader_idx: int) -> None:
        """Concatenate the collected run data frames and write them as a tab-separated file without header.

        Only the global-zero process writes, and only if at least one data frame was collected.

        Raises:
            ValueError: If no test dataloaders or no dataset are found, or the dataset has the wrong type.
        """
        if not trainer.is_global_zero or not self.run_dfs:
            return
        dataloaders = trainer.test_dataloaders
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")
        dataset = pl_module.get_dataset(dataloader_idx)
        if dataset is None:
            raise ValueError("No dataset found to write run file")
        if not isinstance(dataset, (QueryDataset, RunDataset)):
            raise ValueError("Expected QueryDataset or RunDataset for ranking")
        run_file_path = self._get_save_path(pl_module, dataset)
        run_file_path.parent.mkdir(parents=True, exist_ok=True)
        run_df = pd.concat(self.run_dfs, ignore_index=True)
        run_df.to_csv(run_file_path, header=False, index=False, sep="\t", quoting=csv.QUOTE_NONE)

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningIRModule,
        outputs: LightningIROutput,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Hook to aggregate and write ranking to file.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (LightningIRModule): LightningIR module.
            outputs (LightningIROutput): Scored query documents pairs.
            batch (Any): Batch of input data.
            batch_idx (int): Index of batch in the current dataset.
            dataloader_idx (int): Index of the dataloader. Defaults to 0.
        Raises:
            ValueError: If the batch does not have query_ids.
        """
        super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        # gathering happens on every process; only the global-zero process builds and writes the run
        batch = self._gather(pl_module, batch)
        outputs = self._gather(pl_module, outputs)
        if not trainer.is_global_zero:
            return

        query_ids = batch.query_ids
        if query_ids is None:
            raise ValueError("Expected batch to have query_ids")
        scores, doc_ids, num_docs = self._rank(batch, outputs)
        scores = scores.float().cpu().numpy()

        # repeat every query id once per document so that it lines up with the flattened doc_ids and scores
        query_ids = list(
            itertools.chain.from_iterable(itertools.repeat(query_id, num) for query_id, num in zip(query_ids, num_docs))
        )
        run_df = pd.DataFrame(zip(query_ids, doc_ids, scores), columns=["query_id", "doc_id", "score"])
        run_df = run_df.sort_values(["query_id", "score"], ascending=[True, False])
        # ranks are assigned per query within this batch; ties are broken by order of appearance
        run_df["rank"] = run_df.groupby("query_id")["score"].rank(ascending=False, method="first").astype(int)
        run_df["q0"] = 0
        run_df["system"] = pl_module.model.__class__.__name__
        run_df = run_df[RUN_HEADER]

        self.run_dfs.append(run_df)

        # the last batch of a dataloader flushes the collected data frames to disk
        if batch_idx == trainer.num_test_batches[dataloader_idx] - 1:
            self._write_run_dfs(trainer, pl_module, dataloader_idx)
            self.run_dfs = []
417
418
class SearchCallback(RankCallback, _IndexDirMixin):
    def __init__(
        self,
        search_config: SearchConfig,
        index_dir: Path | str | None = None,
        index_name: str | None = None,
        save_dir: Path | str | None = None,
        run_name: str | None = None,
        overwrite: bool = False,
        use_gpu: bool = True,
    ) -> None:
        """Callback which uses an index to retrieve documents efficiently.

        Args:
            search_config (SearchConfig): Configuration of the :py:class:`~lightning_ir.retrieve.base.searcher.Searcher`
            index_dir (Path | str | None): Directory where indexes are stored. Defaults to None.
            index_name (str | None): Name of the index. If None, the dataset's dataset_id or file name will be used.
                Defaults to None.
            save_dir (Path | str | None): Directory to save run files to. If None, run files will be saved in the
                model's directory. Defaults to None.
            run_name (str | None): Name of the run file. If None, the dataset's dataset_id or file name will be used.
                Defaults to None.
            overwrite (bool): Whether to skip or overwrite already existing run files. Defaults to False.
            use_gpu (bool): Toggle to use GPU for retrieval. Defaults to True.
        """
        # `overwrite` is stored by RankCallback.__init__, no need to assign it again here
        super().__init__(save_dir=save_dir, run_name=run_name, overwrite=overwrite)
        self.search_config = search_config
        self.index_dir = index_dir
        self.index_name = index_name
        self.use_gpu = use_gpu
        # created lazily in `on_test_batch_start`, re-used while the index directory stays the same
        self.searcher: Searcher

    def _get_searcher(self, trainer: Trainer, pl_module: BiEncoderModule, dataset_idx: int) -> Searcher:
        """Return a searcher for the index belonging to the dataset of the given dataloader.

        An existing searcher is re-used if it points to the same index directory; otherwise it is deleted and
        memory is freed before a new searcher is instantiated.

        Raises:
            ValueError: If no test dataloaders are found.
        """
        dataloaders = trainer.test_dataloaders
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")
        dataset = dataloaders[dataset_idx].dataset

        index_dir = self._get_index_dir(pl_module, dataset)
        if hasattr(self, "searcher"):
            if self.searcher.index_dir == index_dir:
                return self.searcher
            # free up memory
            del self.searcher
            gc.collect()
            torch.cuda.empty_cache()

        searcher = self.search_config.search_class(index_dir, self.search_config, pl_module, self.use_gpu)
        return searcher

    def _rank(
        self, batch: SearchBatch | RankBatch, output: LightningIROutput
    ) -> Tuple[torch.Tensor, List[str], List[int]]:
        """Flatten the scores and retrieved document ids of a batch.

        The batch is converted into a RankBatch with empty placeholder document texts (one per doc id) and then
        handled by the parent implementation.

        Raises:
            ValueError: If the batch has no doc_ids.
        """
        if batch.doc_ids is None:
            raise ValueError("BiEncoderModule did not return doc_ids when searching")
        dummy_docs = [[""] * len(ids) for ids in batch.doc_ids]
        batch = RankBatch(batch.queries, dummy_docs, batch.query_ids, batch.doc_ids, batch.qrels)
        return super()._rank(batch, output)

    def on_test_start(self, trainer: Trainer, pl_module: BiEncoderModule) -> None:
        """Hook to validate datasets

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (BiEncoderModule): LightningIR bi-encoder module.
        Raises:
            ValueError: If no test_dataloaders are found.
            ValueError: If not all test datasets are :py:class:`~lightning_ir.data.dataset.QueryDataset`.
        """
        dataloaders = trainer.test_dataloaders
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")
        datasets = [dataloader.dataset for dataloader in dataloaders]
        if not all(isinstance(dataset, (QueryDataset, _DummyIterableDataset)) for dataset in datasets):
            raise ValueError("Expected QueryDatasets for searching")

    def on_test_batch_start(
        self, trainer: Trainer, pl_module: BiEncoderModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Hook to initialize searcher for new datasets.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (BiEncoderModule): LightningIR bi-encoder module.
            batch (Any): Batch of input data.
            batch_idx (int): Index of the batch in the dataset.
            dataloader_idx (int): Index of the dataloader. Defaults to 0.
        """
        if batch_idx == 0:
            # a new dataloader starts: make sure the searcher matches its index and hand it to the module
            self.searcher = self._get_searcher(trainer, pl_module, dataloader_idx)
            pl_module.searcher = self.searcher
        super().on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
512
513
class ReRankCallback(RankCallback):
    """Callback to write run files of re-ranked documents; adds no behavior on top of ``RankCallback``."""
516
517
class RegisterLocalDatasetCallback(Callback):

    def __init__(
        self,
        dataset_id: str,
        docs: str | None = None,
        queries: str | None = None,
        qrels: str | None = None,
        docpairs: str | None = None,
        scoreddocs: str | None = None,
        qrels_defs: Dict[int, str] | None = None,
    ):
        """Registers a local dataset with ``ir_datasets``. After registering the dataset, it can be loaded using
        ``ir_datasets.load(dataset_id)``. Currently, the following (optionally gzipped) file types are supported:

        - ``.tsv``, ``.json``, or ``.jsonl`` for documents and queries
        - ``.tsv`` or ``.qrels`` for qrels
        - ``.tsv`` for training n-tuples
        - ``.tsv`` or ``.run`` for scored documents / run files

        Args:
            dataset_id (str): Dataset id.
            docs (str | None): Path to documents file or valid ir_datasets id from which documents should be taken.
                Defaults to None.
            queries (str | None): Path to queries file or valid ir_datasets id from which queries should be taken.
                Defaults to None.
            qrels (str | None): Path to qrels file or valid ir_datasets id from which qrels will be taken.
                Defaults to None.
            docpairs (str | None): Path to training n-tuple file or valid ir_datasets id from which training tuples
                will be taken. Defaults to None.
            scoreddocs (str | None): Path to run file or valid ir_datasets id from which scored documents will be taken.
                Defaults to None.
            qrels_defs (Dict[int, str] | None): Optional dictionary describing the relevance levels of the qrels.
                Defaults to None.
        """
        super().__init__()
        self.dataset_id = dataset_id
        # sources of the individual dataset components
        self.docs, self.queries, self.qrels = docs, queries, qrels
        self.docpairs, self.scoreddocs = docpairs, scoreddocs
        self.qrels_defs = qrels_defs

    def setup(self, trainer: Trainer, pl_module: LightningIRModule, stage: str) -> None:
        """Hook that registers dataset.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (LightningIRModule): Lightning IR module.
            stage (str): Stage of the trainer.
        """
        # everything except the dataset id is forwarded by keyword
        components = {
            "docs": self.docs,
            "queries": self.queries,
            "qrels": self.qrels,
            "docpairs": self.docpairs,
            "scoreddocs": self.scoreddocs,
            "qrels_defs": self.qrels_defs,
        }
        register_new_dataset(self.dataset_id, **components)