1"""Module containing callbacks for indexing, searching, ranking, and registering custom datasets."""
2
3from __future__ import annotations
4
5import csv
6import gc
7import itertools
8from dataclasses import is_dataclass
9from pathlib import Path
10from typing import TYPE_CHECKING, Any, Callable, TypeVar
11
12import pandas as pd
13import torch
14from lightning import Trainer
15from lightning.pytorch.callbacks import Callback, TQDMProgressBar
16
17from ..base.validation_utils import evaluate_run
18from ..data import LightningIRDataModule, RankBatch, SearchBatch
19from ..data.dataset import RUN_HEADER, DocDataset, IRDataset, QueryDataset, RunDataset, _DummyIterableDataset
20from ..data.external_datasets.ir_datasets_utils import register_new_dataset
21from ..retrieve import IndexConfig, Indexer, SearchConfig, Searcher
22
23if TYPE_CHECKING:
24 from ..base import LightningIRModule, LightningIROutput
25 from ..bi_encoder import BiEncoderModule, BiEncoderOutput
26
27T = TypeVar("T")
28
29
30def _format_large_number(number: float) -> str:
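    """Format a large number with a thousands suffix for progress-bar display.

    Illustrative examples (derived from the logic below): ``_format_large_number(512)`` returns ``"512.00"`` and
    ``_format_large_number(1_234_567)`` returns ``"1.23 M"``.
    """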
    suffixes = ["", "K", "M", "B", "T"]
    suffix_index = 0

    while number >= 1000 and suffix_index < len(suffixes) - 1:
        number /= 1000.0
        suffix_index += 1

    formatted_number = f"{number:.2f}"

    suffix = suffixes[suffix_index]
    if suffix:
        formatted_number += f" {suffix}"
    return formatted_number


class _GatherMixin:
    """Mixin to gather dataclasses across all processes"""

    def _gather(self, pl_module: LightningIRModule, dataclass: T) -> T:
        if is_dataclass(dataclass):
            return dataclass.__class__(
                **{k: self._gather(pl_module, getattr(dataclass, k)) for k in dataclass.__dataclass_fields__}
            )
        return pl_module.all_gather(dataclass)


class _IndexDirMixin:
    """Mixin to get index_dir"""

    index_dir: Path | str | None
    index_name: str | None

    def _get_index_dir(self, pl_module: BiEncoderModule, dataset: DocDataset) -> Path:
        index_dir = self.index_dir
        if index_dir is None:
            default_index_dir = Path(pl_module.config.name_or_path)
            if default_index_dir.exists():
                index_dir = default_index_dir / "indexes"
            else:
                raise ValueError("No index_dir provided and model_name_or_path is not a path")
        index_dir = Path(index_dir)
        if self.index_name is None:
            index_dir = index_dir / dataset.dashed_docs_dataset_id
        else:
            index_dir = index_dir / self.index_name
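        # Illustrative example (hypothetical paths): with index_dir=None, a local model at "models/bi-encoder", and a
        # docs dataset with id "msmarco-passage", the index resolves to "models/bi-encoder/indexes/msmarco-passage".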
        return index_dir


class _OverwriteMixin:
    """Mixin to skip datasets (for indexing or searching) if they already exist"""

    _get_save_path: Callable[[LightningIRModule, IRDataset], Path]

    def _remove_overwrite_datasets(self, trainer: Trainer, pl_module: LightningIRModule) -> None:
        overwrite = getattr(self, "overwrite", False)
        if not overwrite:
            datamodule: LightningIRDataModule | None = getattr(trainer, "datamodule", None)
            if datamodule is None:
                raise ValueError("No datamodule found")
            if datamodule.inference_datasets is None:
                return
            inference_datasets = list(datamodule.inference_datasets)
            for dataset in inference_datasets:
                save_path = self._get_save_path(pl_module, dataset)
                if save_path.exists():
                    dataset._SKIP = True
                    trainer.print(f"`{save_path}` already exists. set overwrite=True to overwrite")
                    if (
                        save_path.name.endswith(".run")
                        and dataset.qrels is not None
                        and pl_module.evaluation_metrics is not None
                    ):
                        run = RunDataset._load_csv(save_path)
                        qrels = dataset.qrels.stack(future_stack=True).dropna().astype(int).reset_index()
                        if isinstance(dataset, RunDataset) and dataset.run_path is not None:
                            dataset_id = dataset.run_path.name
                        else:
                            dataset_id = dataset.dataset_id
                        for key, value in evaluate_run(run, qrels, pl_module.evaluation_metrics).items():
                            key = f"{dataset_id}/{key}"
                            pl_module._additional_log_metrics[key] = value

    def _cleanup(self, trainer: Trainer, pl_module: LightningIRModule) -> None:
        # reset skip flags and additional log metrics
        datamodule: LightningIRDataModule | None = getattr(trainer, "datamodule", None)
        if datamodule is not None and datamodule.inference_datasets is not None:
            for dataset in datamodule.inference_datasets:
                dataset._SKIP = False
        pl_module._additional_log_metrics = {}


class IndexCallback(Callback, _GatherMixin, _IndexDirMixin, _OverwriteMixin):
    def __init__(
        self,
        index_config: IndexConfig,
        index_dir: Path | str | None = None,
        index_name: str | None = None,
        overwrite: bool = False,
        verbose: bool = False,
    ) -> None:
        """Callback to index documents using an :py:class:`~lightning_ir.retrieve.base.indexer.Indexer`.

        Args:
            index_config (IndexConfig): Configuration for the indexer.
            index_dir (Path | str | None): Directory to save index(es) to. If None, indexes will be stored in the
                model's directory. Defaults to None.
            index_name (str | None): Name of the index. If None, the dataset's dataset_id or file name will be used.
                Defaults to None.
            overwrite (bool): Whether to skip or overwrite already existing indexes. Defaults to False.
            verbose (bool): Toggle verbose output. Defaults to False.
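
        Example:
            A minimal indexing sketch. The model name, ``FaissFlatIndexConfig``, and the ``LightningIRTrainer.index``
            entry point are illustrative assumptions; adjust them to your setup::

                from lightning_ir import (
                    BiEncoderModule,
                    DocDataset,
                    FaissFlatIndexConfig,
                    IndexCallback,
                    LightningIRDataModule,
                    LightningIRTrainer,
                )

                module = BiEncoderModule("webis/bert-bi-encoder")  # hypothetical model name
                data_module = LightningIRDataModule(inference_datasets=[DocDataset("msmarco-passage")])
                callback = IndexCallback(index_config=FaissFlatIndexConfig(), index_dir="./indexes")
                trainer = LightningIRTrainer(callbacks=[callback], logger=False, enable_checkpointing=False)
                trainer.index(module, data_module)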
141 """
142 super().__init__()
143 self.index_config = index_config
144 self.index_dir = index_dir
145 self.index_name = index_name
146 self.overwrite = overwrite
147 self.verbose = verbose
148 self.indexer: Indexer
149
    def setup(self, trainer: Trainer, pl_module: BiEncoderModule, stage: str) -> None:
        """Hook to setup the callback.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (BiEncoderModule): LightningIR bi-encoder module used for indexing.
            stage (str): Stage of the trainer, must be "test".
        Raises:
            ValueError: If the stage is not "test".
        """
        if stage != "test":
            raise ValueError(f"{self.__class__.__name__} can only be used in test stage")
        self._remove_overwrite_datasets(trainer, pl_module)

    def teardown(self, trainer: Trainer, pl_module: BiEncoderModule, stage: str) -> None:
        """Hook to cleanup the callback.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (BiEncoderModule): LightningIR bi-encoder module used for indexing.
            stage (str): Stage of the trainer.
        """
        self._cleanup(trainer, pl_module)

    def _get_save_path(self, pl_module: BiEncoderModule, dataset: IRDataset) -> Path:
        if not isinstance(dataset, DocDataset):
            raise ValueError("Expected DocDataset for indexing")
        return self._get_index_dir(pl_module, dataset)

    def _get_indexer(self, pl_module: BiEncoderModule, dataloader_idx: int) -> Indexer:
        dataset = pl_module.get_dataset(dataloader_idx)
        if dataset is None:
            raise ValueError("No dataset found to index")
        if not isinstance(dataset, DocDataset):
            raise ValueError("Expected DocDataset for indexing")
        index_dir = self._get_save_path(pl_module, dataset)

        indexer = self.index_config.indexer_class(index_dir, self.index_config, pl_module, self.verbose)
        return indexer

    def _log_to_pg(self, info: dict[str, Any], trainer: Trainer):
        pg_callback = trainer.progress_bar_callback
        if pg_callback is None or not isinstance(pg_callback, TQDMProgressBar):
            return
        pg = pg_callback.test_progress_bar
        info = {k: _format_large_number(v) for k, v in info.items()}
        if pg is not None:
            pg.set_postfix(info)

    def on_test_start(self, trainer: Trainer, pl_module: BiEncoderModule) -> None:
        """Hook to check that the test datasets are configured correctly.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (BiEncoderModule): LightningIR bi-encoder module.
        Raises:
            ValueError: If no test_dataloaders are found.
            ValueError: If not all test datasets are :py:class:`~lightning_ir.data.dataset.DocDataset`.
        """
        dataloaders = trainer.test_dataloaders
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")
        datasets = [dataloader.dataset for dataloader in dataloaders]
        if not all(isinstance(dataset, (DocDataset, _DummyIterableDataset)) for dataset in datasets):
            raise ValueError("Expected DocDatasets for indexing")

    def on_test_batch_start(
        self, trainer: Trainer, pl_module: BiEncoderModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Hook to setup the indexer between datasets.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (BiEncoderModule): LightningIR bi-encoder module.
            batch (Any): Batch of input data.
            batch_idx (int): Index of batch in the current dataset.
            dataloader_idx (int): Index of the dataloader. Defaults to 0.
        """
        if not trainer.is_global_zero:
            return
        if batch_idx == 0:
            if hasattr(self, "indexer"):
                self.indexer.save()
            self.indexer = self._get_indexer(pl_module, dataloader_idx)
        super().on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: BiEncoderModule,
        outputs: BiEncoderOutput,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Hook to pass encoded documents to the indexer.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (BiEncoderModule): LightningIR bi-encoder module.
            outputs (BiEncoderOutput): Encoded documents.
            batch (Any): Batch of input data.
            batch_idx (int): Index of batch in the current dataset.
            dataloader_idx (int): Index of the dataloader. Defaults to 0.
        """
        batch = self._gather(pl_module, batch)
        outputs = self._gather(pl_module, outputs)

        if not trainer.is_global_zero:
            return

        self.indexer.add(batch, outputs)
        self._log_to_pg(
            {
                "num_docs": self.indexer.num_docs,
                "num_embeddings": self.indexer.num_embeddings,
            },
            trainer,
        )
        # TODO if dataset length cannot be inferred, num_test_batches is inf and no index is saved
        if batch_idx == trainer.num_test_batches[dataloader_idx] - 1:
            assert hasattr(self, "indexer")

    def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningIRModule) -> None:
        """Hook to save the index after indexing is done.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (LightningIRModule): LightningIR module.
        """
        if not trainer.is_global_zero:
            return
        if hasattr(self, "indexer"):
            self.indexer.save()


class RankCallback(Callback, _GatherMixin, _OverwriteMixin):
    def __init__(
        self, save_dir: Path | str | None = None, run_name: str | None = None, overwrite: bool = False
    ) -> None:
        """Callback to write run file of ranked documents to disk.

        Args:
            save_dir (Path | str | None): Directory to save run files to. If None, run files will be saved in the
                model's directory. Defaults to None.
            run_name (str | None): Name of the run file. If None, the dataset's dataset_id or file name will be used.
                Defaults to None.
            overwrite (bool): Whether to skip or overwrite already existing run files. Defaults to False.
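
        Example:
            A minimal re-ranking sketch (:py:class:`ReRankCallback` is a thin subclass of this callback). The model
            name, run file path, and the ``LightningIRTrainer.re_rank`` entry point are illustrative assumptions::

                from lightning_ir import (
                    CrossEncoderModule,
                    LightningIRDataModule,
                    LightningIRTrainer,
                    ReRankCallback,
                    RunDataset,
                )

                module = CrossEncoderModule("webis/monoelectra-base")  # hypothetical model name
                data_module = LightningIRDataModule(inference_datasets=[RunDataset("path/to/bm25.run")])
                callback = ReRankCallback(save_dir="./runs")
                trainer = LightningIRTrainer(callbacks=[callback], logger=False, enable_checkpointing=False)
                trainer.re_rank(module, data_module)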
298 """
299 super().__init__()
300 self.save_dir = Path(save_dir) if save_dir is not None else None
301 self.run_name = run_name
302 self.overwrite = overwrite
303 self.run_dfs: list[pd.DataFrame] = []
304
    def setup(self, trainer: Trainer, pl_module: LightningIRModule, stage: str) -> None:
        """Hook to setup the callback.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (LightningIRModule): LightningIR module.
            stage (str): Stage of the trainer, must be "test".
        Raises:
            ValueError: If the stage is not "test".
            ValueError: If no save_dir is provided and model_name_or_path is not a path (the model is not local).
        """
        if stage != "test":
            raise ValueError(f"{self.__class__.__name__} can only be used in test stage")
        if self.save_dir is None:
            default_save_dir = Path(pl_module.config.name_or_path)
            if default_save_dir.exists():
                self.save_dir = default_save_dir / "runs"
                print(f"Using default save_dir `{self.save_dir}` to save runs")
            else:
                raise ValueError("No save_dir provided and model_name_or_path is not a path")
        self._remove_overwrite_datasets(trainer, pl_module)

    def teardown(self, trainer: Trainer, pl_module: LightningIRModule, stage: str) -> None:
        """Hook to cleanup the callback.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (LightningIRModule): LightningIR module used for ranking.
            stage (str): Stage of the trainer, must be "test".
        """
        self._cleanup(trainer, pl_module)

    def _get_save_path(self, pl_module: LightningIRModule, dataset: IRDataset) -> Path:
        if self.save_dir is None:
            raise ValueError("No save_dir found; call setup before using this method")
        if self.run_name is not None:
            run_file = self.run_name
        elif isinstance(dataset, QueryDataset):
            run_file = f"{dataset.dataset_id.replace('/', '-')}.run"
        elif isinstance(dataset, RunDataset):
            if dataset.run_path is None:
                run_file = f"{dataset.dataset_id.replace('/', '-')}.run"
            else:
                run_file = f"{dataset.run_path.name.split('.')[0]}.run"
        else:
            raise ValueError("Expected QueryDataset or RunDataset for ranking")
        run_file_path = self.save_dir / run_file
        return run_file_path

    def _rank(self, batch: RankBatch, output: LightningIROutput) -> tuple[torch.Tensor, list[str], list[int]]:
        scores = output.scores
        if scores is None:
            raise ValueError("Expected output to have scores")
        doc_ids = batch.doc_ids
        if doc_ids is None:
            raise ValueError("Expected batch to have doc_ids")
        scores = scores.view(-1)
        num_docs = [len(_doc_ids) for _doc_ids in doc_ids]
        doc_ids = list(itertools.chain.from_iterable(doc_ids))
        if scores.shape[0] != len(doc_ids):
            raise ValueError("scores and doc_ids must have the same length")
        return scores, doc_ids, num_docs

    def _write_run_dfs(self, trainer: Trainer, pl_module: LightningIRModule, dataloader_idx: int):
        if not trainer.is_global_zero or not self.run_dfs:
            return
        dataloaders = trainer.test_dataloaders
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")
        dataset = pl_module.get_dataset(dataloader_idx)
        if dataset is None:
            raise ValueError("No dataset found to write run file")
        if not isinstance(dataset, (QueryDataset, RunDataset)):
            raise ValueError("Expected QueryDataset or RunDataset for ranking")
        run_file_path = self._get_save_path(pl_module, dataset)
        run_file_path.parent.mkdir(parents=True, exist_ok=True)
        run_df = pd.concat(self.run_dfs, ignore_index=True)
        run_df.to_csv(run_file_path, header=False, index=False, sep="\t", quoting=csv.QUOTE_NONE)

    def on_test_batch_start(
        self, trainer: Trainer, pl_module: LightningIRModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Hook to write the run file of the previous dataset when a new dataset starts.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (LightningIRModule): LightningIR module.
            batch (Any): Batch of input data.
            batch_idx (int): Index of batch in the current dataset.
            dataloader_idx (int, optional): Index of the dataloader. Defaults to 0.
        """
        if not trainer.is_global_zero:
            return
        if batch_idx == 0:
            if self.run_dfs:
                # the buffered run_dfs belong to the previous dataloader, so write them under that dataset
                self._write_run_dfs(trainer, pl_module, dataloader_idx - 1)
            self.run_dfs = []
        super().on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningIRModule,
        outputs: LightningIROutput,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Hook to aggregate and write ranking to file.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (LightningIRModule): LightningIR module.
            outputs (LightningIROutput): Scored query-document pairs.
            batch (Any): Batch of input data.
            batch_idx (int): Index of batch in the current dataset.
            dataloader_idx (int): Index of the dataloader. Defaults to 0.
        Raises:
            ValueError: If the batch does not have query_ids.
        """
        super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        batch = self._gather(pl_module, batch)
        outputs = self._gather(pl_module, outputs)
        if not trainer.is_global_zero:
            return

        query_ids = batch.query_ids
        if query_ids is None:
            raise ValueError("Expected batch to have query_ids")
        scores, doc_ids, num_docs = self._rank(batch, outputs)
        scores = scores.float().cpu().numpy()

        query_ids = list(
            itertools.chain.from_iterable(itertools.repeat(query_id, num) for query_id, num in zip(query_ids, num_docs))
        )
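        # The rows built below follow the TREC run file layout (query id, an unused iteration/"Q0" column, doc id,
        # rank, score, system tag); RUN_HEADER defines the exact column order written to the .run file.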
        run_df = pd.DataFrame(zip(query_ids, doc_ids, scores), columns=["query_id", "doc_id", "score"])
        run_df = run_df.sort_values(["query_id", "score"], ascending=[True, False])
        run_df["rank"] = run_df.groupby("query_id")["score"].rank(ascending=False, method="first").astype(int)
        run_df["q0"] = 0
        run_df["system"] = pl_module.model.__class__.__name__
        run_df = run_df[RUN_HEADER]

        self.run_dfs.append(run_df)

    def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningIRModule) -> None:
        """Hook to write remaining run files after ranking is done.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (LightningIRModule): LightningIR module.
        """
        if not trainer.is_global_zero:
            return
        if self.run_dfs:
            # any remaining run_dfs belong to the last dataloader
            self._write_run_dfs(trainer, pl_module, -1)


class SearchCallback(RankCallback, _IndexDirMixin):
    def __init__(
        self,
        search_config: SearchConfig,
        index_dir: Path | str | None = None,
        index_name: str | None = None,
        save_dir: Path | str | None = None,
        run_name: str | None = None,
        overwrite: bool = False,
        use_gpu: bool = True,
    ) -> None:
        """Callback that uses an index to efficiently retrieve documents.

        Args:
            search_config (SearchConfig): Configuration of the
                :py:class:`~lightning_ir.retrieve.base.searcher.Searcher`.
            index_dir (Path | str | None): Directory where indexes are stored. Defaults to None.
            index_name (str | None): Name of the index. If None, the dataset's dataset_id or file name will be used.
                Defaults to None.
            save_dir (Path | str | None): Directory to save run files to. If None, run files will be saved in the
                model's directory. Defaults to None.
            run_name (str | None): Name of the run file. If None, the dataset's dataset_id or file name will be used.
                Defaults to None.
            overwrite (bool): Whether to skip or overwrite already existing run files. Defaults to False.
            use_gpu (bool): Toggle to use GPU for retrieval. Defaults to True.
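
        Example:
            A minimal retrieval sketch on top of a previously built index. The model name, ``FaissSearchConfig``, and
            the ``LightningIRTrainer.search`` entry point are illustrative assumptions; adjust them to your setup::

                from lightning_ir import (
                    BiEncoderModule,
                    FaissSearchConfig,
                    LightningIRDataModule,
                    LightningIRTrainer,
                    QueryDataset,
                    SearchCallback,
                )

                module = BiEncoderModule("webis/bert-bi-encoder")  # hypothetical model name
                data_module = LightningIRDataModule(inference_datasets=[QueryDataset("msmarco-passage/dev/small")])
                callback = SearchCallback(FaissSearchConfig(k=100), index_dir="./indexes", save_dir="./runs")
                trainer = LightningIRTrainer(callbacks=[callback], logger=False, enable_checkpointing=False)
                trainer.search(module, data_module)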
486 """
487 super().__init__(save_dir=save_dir, run_name=run_name, overwrite=overwrite)
488 self.search_config = search_config
489 self.index_dir = index_dir
490 self.index_name = index_name
491 self.overwrite = overwrite
492 self.use_gpu = use_gpu
493 self.searcher: Searcher
494
    def _get_searcher(self, trainer: Trainer, pl_module: BiEncoderModule, dataset_idx: int) -> Searcher:
        dataloaders = trainer.test_dataloaders
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")
        dataset = dataloaders[dataset_idx].dataset

        index_dir = self._get_index_dir(pl_module, dataset)
        if hasattr(self, "searcher"):
            if self.searcher.index_dir == index_dir:
                return self.searcher
            # free up memory
            del self.searcher
            gc.collect()
            torch.cuda.empty_cache()

        searcher = self.search_config.search_class(index_dir, self.search_config, pl_module, self.use_gpu)
        return searcher

    def _rank(
        self, batch: SearchBatch | RankBatch, output: LightningIROutput
    ) -> tuple[torch.Tensor, list[str], list[int]]:
        if batch.doc_ids is None:
            raise ValueError("BiEncoderModule did not return doc_ids when searching")
        dummy_docs = [[""] * len(ids) for ids in batch.doc_ids]
        batch = RankBatch(batch.queries, dummy_docs, batch.query_ids, batch.doc_ids, batch.qrels)
        return super()._rank(batch, output)

    def on_test_start(self, trainer: Trainer, pl_module: BiEncoderModule) -> None:
        """Hook to validate datasets.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (BiEncoderModule): LightningIR bi-encoder module.
        Raises:
            ValueError: If no test_dataloaders are found.
            ValueError: If not all test datasets are :py:class:`~lightning_ir.data.dataset.QueryDataset`.
        """
        dataloaders = trainer.test_dataloaders
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")
        datasets = [dataloader.dataset for dataloader in dataloaders]
        if not all(isinstance(dataset, (QueryDataset, _DummyIterableDataset)) for dataset in datasets):
            raise ValueError("Expected QueryDatasets for searching")

    def on_test_batch_start(
        self, trainer: Trainer, pl_module: BiEncoderModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Hook to initialize searcher for new datasets.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (BiEncoderModule): LightningIR bi-encoder module.
            batch (Any): Batch of input data.
            batch_idx (int): Index of the batch in the dataset.
            dataloader_idx (int): Index of the dataloader. Defaults to 0.
        """
        if batch_idx == 0:
            self.searcher = self._get_searcher(trainer, pl_module, dataloader_idx)
            pl_module.searcher = self.searcher
        super().on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)


class ReRankCallback(RankCallback):
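    """Callback to write run files of re-ranked documents to disk.

    Behaviorally identical to :py:class:`RankCallback`; it is provided as a separately named callback for re-ranking
    workflows (see :py:class:`RankCallback` for usage).
    """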
    pass


class RegisterLocalDatasetCallback(Callback):
    def __init__(
        self,
        dataset_id: str,
        docs: str | None = None,
        queries: str | None = None,
        qrels: str | None = None,
        docpairs: str | None = None,
        scoreddocs: str | None = None,
        qrels_defs: dict[int, str] | None = None,
    ):
        """Registers a local dataset with ``ir_datasets``. After registering the dataset, it can be loaded using
        ``ir_datasets.load(dataset_id)``. Currently, the following (optionally gzipped) file types are supported:

        - ``.tsv``, ``.json``, or ``.jsonl`` for documents and queries
        - ``.tsv`` or ``.qrels`` for qrels
        - ``.tsv`` for training n-tuples
        - ``.tsv`` or ``.run`` for scored documents / run files

        Args:
            dataset_id (str): Dataset id.
            docs (str | None): Path to documents file or valid ir_datasets id from which documents should be taken.
                Defaults to None.
            queries (str | None): Path to queries file or valid ir_datasets id from which queries should be taken.
                Defaults to None.
            qrels (str | None): Path to qrels file or valid ir_datasets id from which qrels will be taken.
                Defaults to None.
            docpairs (str | None): Path to training n-tuple file or valid ir_datasets id from which training tuples
                will be taken. Defaults to None.
            scoreddocs (str | None): Path to run file or valid ir_datasets id from which scored documents will be
                taken. Defaults to None.
            qrels_defs (dict[int, str] | None): Optional dictionary describing the relevance levels of the qrels.
                Defaults to None.
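
        Example:
            A minimal registration sketch. The file paths and dataset id are hypothetical, and the top-level imports
            and ``LightningIRTrainer`` wiring are illustrative assumptions::

                from lightning_ir import LightningIRTrainer, QueryDataset, RegisterLocalDatasetCallback

                callback = RegisterLocalDatasetCallback(
                    dataset_id="my-local-dataset",
                    docs="data/docs.tsv",
                    queries="data/queries.tsv",
                    qrels="data/qrels.tsv",
                )
                trainer = LightningIRTrainer(callbacks=[callback])
                # once setup has run, the registered id can be used like any other ir_datasets id, e.g.
                # QueryDataset("my-local-dataset")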
594 """
595 super().__init__()
596 self.dataset_id = dataset_id
597 self.docs = docs
598 self.queries = queries
599 self.qrels = qrels
600 self.docpairs = docpairs
601 self.scoreddocs = scoreddocs
602 self.qrels_defs = qrels_defs
603
    def setup(self, trainer: Trainer, pl_module: LightningIRModule, stage: str) -> None:
        """Hook that registers the dataset.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (LightningIRModule): Lightning IR module.
            stage (str): Stage of the trainer.
        """
        register_new_dataset(
            self.dataset_id,
            docs=self.docs,
            queries=self.queries,
            qrels=self.qrels,
            docpairs=self.docpairs,
            scoreddocs=self.scoreddocs,
            qrels_defs=self.qrels_defs,
        )