1"""Module containing callbacks for indexing, searching, ranking, and registering custom datasets."""
2
3from __future__ import annotations
4
5import csv
6import gc
7import itertools
8from dataclasses import is_dataclass
9from pathlib import Path
10from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, TypeVar
11
12import pandas as pd
13import torch
14from lightning import Trainer
15from lightning.pytorch.callbacks import Callback, TQDMProgressBar
16
17from ..base.validation_utils import evaluate_run
18from ..data import LightningIRDataModule, RankBatch, SearchBatch
19from ..data.dataset import RUN_HEADER, DocDataset, IRDataset, QueryDataset, RunDataset, _DummyIterableDataset
20from ..data.external_datasets.ir_datasets_utils import register_new_dataset
21from ..retrieve import IndexConfig, Indexer, SearchConfig, Searcher
22
23if TYPE_CHECKING:
24 from ..base import LightningIRModule, LightningIROutput
25 from ..bi_encoder import BiEncoderModule, BiEncoderOutput
26
27T = TypeVar("T")
28
29
30def _format_large_number(number: float) -> str:
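    """Format a large number using a K/M/B/T suffix (e.g. 2_500_000 -> "2.50 M")."""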
    suffixes = ["", "K", "M", "B", "T"]
    suffix_index = 0

    while number >= 1000 and suffix_index < len(suffixes) - 1:
        number /= 1000.0
        suffix_index += 1

    formatted_number = "{:.2f}".format(number)

    suffix = suffixes[suffix_index]
    if suffix:
        formatted_number += f" {suffix}"
    return formatted_number


class _GatherMixin:
    """Mixin to gather dataclasses across all processes"""

    def _gather(self, pl_module: LightningIRModule, dataclass: T) -> T:
        if is_dataclass(dataclass):
            return dataclass.__class__(
                **{k: self._gather(pl_module, getattr(dataclass, k)) for k in dataclass.__dataclass_fields__}
            )
        return pl_module.all_gather(dataclass)


class _IndexDirMixin:
    """Mixin to get index_dir"""

    index_dir: Path | str | None
    index_name: str | None

    def _get_index_dir(self, pl_module: BiEncoderModule, dataset: DocDataset) -> Path:
        index_dir = self.index_dir
        if index_dir is None:
            default_index_dir = Path(pl_module.config.name_or_path)
            if default_index_dir.exists():
                index_dir = default_index_dir / "indexes"
            else:
                raise ValueError("No index_dir provided and model_name_or_path is not a path")
        index_dir = Path(index_dir)
        if self.index_name is None:
            index_dir = index_dir / dataset.docs_dataset_id
        else:
            index_dir = index_dir / self.index_name
        return index_dir


class _OverwriteMixin:
    """Mixin to skip datasets (for indexing or searching) if they already exist"""

    _get_save_path: Callable[[LightningIRModule, IRDataset], Path]

    def _remove_overwrite_datasets(self, trainer: Trainer, pl_module: LightningIRModule) -> None:
        overwrite = getattr(self, "overwrite", False)
        if not overwrite:
            datamodule: LightningIRDataModule | None = getattr(trainer, "datamodule", None)
            if datamodule is None:
                raise ValueError("No datamodule found")
            if datamodule.inference_datasets is None:
                return
            inference_datasets = list(datamodule.inference_datasets)
            for dataset in inference_datasets:
                save_path = self._get_save_path(pl_module, dataset)
                if save_path.exists():
                    dataset._SKIP = True
                    trainer.print(f"`{save_path}` already exists. Set overwrite=True to overwrite")
                    if (
                        save_path.name.endswith(".run")
                        and dataset.qrels is not None
                        and pl_module.evaluation_metrics is not None
                    ):
                        run = RunDataset._load_csv(save_path)
                        qrels = dataset.qrels.stack(future_stack=True).dropna().astype(int).reset_index()
                        if isinstance(dataset, RunDataset) and dataset.run_path is not None:
                            dataset_id = dataset.run_path.name
                        else:
                            dataset_id = dataset.dataset_id
                        for key, value in evaluate_run(run, qrels, pl_module.evaluation_metrics).items():
                            key = f"{dataset_id}/{key}"
                            pl_module._additional_log_metrics[key] = value
    def _cleanup(self, trainer: Trainer, pl_module: LightningIRModule) -> None:
        # reset skip flag and additional log metrics
        datamodule: LightningIRDataModule | None = getattr(trainer, "datamodule", None)
        if datamodule is not None and datamodule.inference_datasets is not None:
            for dataset in datamodule.inference_datasets:
                dataset._SKIP = False
        pl_module._additional_log_metrics = {}


class IndexCallback(Callback, _GatherMixin, _IndexDirMixin, _OverwriteMixin):
    def __init__(
        self,
        index_config: IndexConfig,
        index_dir: Path | str | None = None,
        index_name: str | None = None,
        overwrite: bool = False,
        verbose: bool = False,
    ) -> None:
        """Callback to index documents using an :py:class:`~lightning_ir.retrieve.base.indexer.Indexer`.

        :param index_config: Configuration for the indexer
        :type index_config: IndexConfig
        :param index_dir: Directory to save index(es) to. If None, indexes will be stored in the model's directory,
            defaults to None
        :type index_dir: Path | str | None, optional
        :param index_name: Name of the index. If None, the dataset's dataset_id or file name will be used,
            defaults to None
        :type index_name: str | None, optional
        :param overwrite: Whether to skip or overwrite already existing indexes, defaults to False
        :type overwrite: bool, optional
        :param verbose: Toggle verbose output, defaults to False
        :type verbose: bool, optional
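
        Example (a minimal sketch; the concrete index config class, model name, and dataset id are illustrative
        and may need to be adapted to your setup)::

            from lightning import Trainer
            from lightning_ir import BiEncoderModule, DocDataset, FaissFlatIndexConfig, IndexCallback
            from lightning_ir import LightningIRDataModule

            callback = IndexCallback(index_config=FaissFlatIndexConfig(), index_dir="./indexes")
            module = BiEncoderModule(model_name_or_path="path/to/bi-encoder")
            datamodule = LightningIRDataModule(inference_datasets=[DocDataset("msmarco-passage")])
            Trainer(callbacks=[callback]).test(module, datamodule=datamodule)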
        """
        super().__init__()
        self.index_config = index_config
        self.index_dir = index_dir
        self.index_name = index_name
        self.overwrite = overwrite
        self.verbose = verbose
        self.indexer: Indexer

    def setup(self, trainer: Trainer, pl_module: BiEncoderModule, stage: str) -> None:
        """Hook to set up the callback.

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: LightningIR bi-encoder module used for indexing
        :type pl_module: BiEncoderModule
        :param stage: Stage of the trainer, must be "test"
        :type stage: str
        :raises ValueError: If the stage is not "test"
        """
        if stage != "test":
            raise ValueError(f"{self.__class__.__name__} can only be used in test stage")
        self._remove_overwrite_datasets(trainer, pl_module)

    def teardown(self, trainer: Trainer, pl_module: BiEncoderModule, stage: str) -> None:
        """Hook to clean up the callback.

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: LightningIR bi-encoder module used for indexing
        :type pl_module: BiEncoderModule
        :param stage: Stage of the trainer, must be "test"
        :type stage: str
        """
        self._cleanup(trainer, pl_module)

    def _get_save_path(self, pl_module: BiEncoderModule, dataset: IRDataset) -> Path:
        if not isinstance(dataset, DocDataset):
            raise ValueError("Expected DocDataset for indexing")
        return self._get_index_dir(pl_module, dataset)

    def _get_indexer(self, pl_module: BiEncoderModule, dataloader_idx: int) -> Indexer:
        dataset = pl_module.get_dataset(dataloader_idx)
        if dataset is None:
            raise ValueError("No dataset found to index")
        if not isinstance(dataset, DocDataset):
            raise ValueError("Expected DocDataset for indexing")
        index_dir = self._get_save_path(pl_module, dataset)

        indexer = self.index_config.indexer_class(index_dir, self.index_config, pl_module, self.verbose)
        return indexer

    def _log_to_pg(self, info: Dict[str, Any], trainer: Trainer):
        pg_callback = trainer.progress_bar_callback
        if pg_callback is None or not isinstance(pg_callback, TQDMProgressBar):
            return
        pg = pg_callback.test_progress_bar
        info = {k: _format_large_number(v) for k, v in info.items()}
        if pg is not None:
            pg.set_postfix(info)

    def on_test_start(self, trainer: Trainer, pl_module: BiEncoderModule) -> None:
        """Hook to check that the test datasets are configured correctly.

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: LightningIR BiEncoderModule
        :type pl_module: BiEncoderModule
        :raises ValueError: If no test_dataloaders are found
        :raises ValueError: If not all test datasets are :py:class:`~lightning_ir.data.dataset.DocDataset`
        """
        dataloaders = trainer.test_dataloaders
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")
        datasets = [dataloader.dataset for dataloader in dataloaders]
        if not all(isinstance(dataset, (DocDataset, _DummyIterableDataset)) for dataset in datasets):
            raise ValueError("Expected DocDatasets for indexing")

    def on_test_batch_start(
        self, trainer: Trainer, pl_module: BiEncoderModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Hook to set up the indexer between datasets.

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: LightningIR bi-encoder module
        :type pl_module: BiEncoderModule
        :param batch: Batch of input data
        :type batch: Any
        :param batch_idx: Index of batch in the current dataset
        :type batch_idx: int
        :param dataloader_idx: Index of the dataloader, defaults to 0
        :type dataloader_idx: int, optional
        """
        if batch_idx == 0:
            self.indexer = self._get_indexer(pl_module, dataloader_idx)
        super().on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: BiEncoderModule,
        outputs: BiEncoderOutput,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Hook to pass encoded documents to the indexer.

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: LightningIR bi-encoder module
        :type pl_module: BiEncoderModule
        :param outputs: Encoded documents
        :type outputs: BiEncoderOutput
        :param batch: Batch of input data
        :type batch: Any
        :param batch_idx: Index of batch in the current dataset
        :type batch_idx: int
        :param dataloader_idx: Index of the dataloader, defaults to 0
        :type dataloader_idx: int, optional
        """
        batch = self._gather(pl_module, batch)
        outputs = self._gather(pl_module, outputs)

        if not trainer.is_global_zero:
            return

        self.indexer.add(batch, outputs)
        self._log_to_pg(
            {
                "num_docs": self.indexer.num_docs,
                "num_embeddings": self.indexer.num_embeddings,
            },
            trainer,
        )
        # TODO if dataset length cannot be inferred, num_test_batches is inf and no index is saved
        if batch_idx == trainer.num_test_batches[dataloader_idx] - 1:
            assert hasattr(self, "indexer")
            self.indexer.save()


class RankCallback(Callback, _GatherMixin, _OverwriteMixin):
    def __init__(
        self, save_dir: Path | str | None = None, run_name: str | None = None, overwrite: bool = False
    ) -> None:
        """Callback to write run files of ranked documents to disk.

        :param save_dir: Directory to save run files to. If None, run files will be saved in the model's directory,
            defaults to None
        :type save_dir: Path | str | None, optional
        :param run_name: Name of the run file. If None, the dataset's dataset_id or file name will be used,
            defaults to None
        :type run_name: str | None, optional
        :param overwrite: Whether to skip or overwrite already existing run files, defaults to False
        :type overwrite: bool, optional
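
        Example (a minimal sketch; the model name and run file path are placeholders, and the cross-encoder
        setup is only one possible use)::

            from lightning import Trainer
            from lightning_ir import CrossEncoderModule, LightningIRDataModule, RankCallback, RunDataset

            callback = RankCallback(save_dir="./runs")
            module = CrossEncoderModule(model_name_or_path="path/to/cross-encoder")
            datamodule = LightningIRDataModule(inference_datasets=[RunDataset("path/to/some.run")])
            Trainer(callbacks=[callback]).test(module, datamodule=datamodule)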
        """
        super().__init__()
        self.save_dir = Path(save_dir) if save_dir is not None else None
        self.run_name = run_name
        self.overwrite = overwrite
        self.run_dfs: List[pd.DataFrame] = []

    def setup(self, trainer: Trainer, pl_module: LightningIRModule, stage: str) -> None:
        """Hook to set up the callback.

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: LightningIR module
        :type pl_module: LightningIRModule
        :param stage: Stage of the trainer, must be "test"
        :type stage: str
        :raises ValueError: If the stage is not "test"
        :raises ValueError: If no save_dir is provided and model_name_or_path is not a path (the model is not local)
        """
        if stage != "test":
            raise ValueError(f"{self.__class__.__name__} can only be used in test stage")
        if self.save_dir is None:
            default_save_dir = Path(pl_module.config.name_or_path)
            if default_save_dir.exists():
                self.save_dir = default_save_dir / "runs"
                print(f"Using default save_dir `{self.save_dir}` to save runs")
            else:
                raise ValueError("No save_dir provided and model_name_or_path is not a path")
        self._remove_overwrite_datasets(trainer, pl_module)

    def teardown(self, trainer: Trainer, pl_module: BiEncoderModule, stage: str) -> None:
        """Hook to clean up the callback.

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: LightningIR module
        :type pl_module: BiEncoderModule
        :param stage: Stage of the trainer, must be "test"
        :type stage: str
        """
        self._cleanup(trainer, pl_module)

    def _get_save_path(self, pl_module: LightningIRModule, dataset: IRDataset) -> Path:
        if self.save_dir is None:
            raise ValueError("No save_dir found; call setup before using this method")
        if self.run_name is not None:
            run_file = self.run_name
        elif isinstance(dataset, QueryDataset):
            run_file = f"{dataset.dataset_id.replace('/', '-')}.run"
        elif isinstance(dataset, RunDataset):
            if dataset.run_path is None:
                run_file = f"{dataset.dataset_id.replace('/', '-')}.run"
            else:
                run_file = f"{dataset.run_path.name.split('.')[0]}.run"
        else:
            raise ValueError("Expected QueryDataset or RunDataset for ranking")
        run_file_path = self.save_dir / run_file
        return run_file_path

    def _rank(self, batch: RankBatch, output: LightningIROutput) -> Tuple[torch.Tensor, List[str], List[int]]:
        scores = output.scores
        if scores is None:
            raise ValueError("Expected output to have scores")
        doc_ids = batch.doc_ids
        if doc_ids is None:
            raise ValueError("Expected batch to have doc_ids")
        scores = scores.view(-1)
        num_docs = [len(_doc_ids) for _doc_ids in doc_ids]
        doc_ids = list(itertools.chain.from_iterable(doc_ids))
        if scores.shape[0] != len(doc_ids):
            raise ValueError("scores and doc_ids must have the same length")
        return scores, doc_ids, num_docs

    def _write_run_dfs(self, trainer: Trainer, pl_module: LightningIRModule, dataloader_idx: int):
        if not trainer.is_global_zero or not self.run_dfs:
            return
        dataloaders = trainer.test_dataloaders
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")
        dataset = pl_module.get_dataset(dataloader_idx)
        if dataset is None:
            raise ValueError("No dataset found to write run file")
        if not isinstance(dataset, (QueryDataset, RunDataset)):
            raise ValueError("Expected QueryDataset or RunDataset for ranking")
        run_file_path = self._get_save_path(pl_module, dataset)
        run_file_path.parent.mkdir(parents=True, exist_ok=True)
        run_df = pd.concat(self.run_dfs, ignore_index=True)
        run_df.to_csv(run_file_path, header=False, index=False, sep="\t", quoting=csv.QUOTE_NONE)

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningIRModule,
        outputs: LightningIROutput,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Hook to aggregate and write ranking to file.

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: LightningIR Module
        :type pl_module: LightningIRModule
        :param outputs: Scored query-document pairs
        :type outputs: LightningIROutput
        :param batch: Batch of input data
        :type batch: Any
        :param batch_idx: Index of batch in the current dataset
        :type batch_idx: int
        :param dataloader_idx: Index of the dataloader, defaults to 0
        :type dataloader_idx: int, optional
        """
        super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        batch = self._gather(pl_module, batch)
        outputs = self._gather(pl_module, outputs)
        if not trainer.is_global_zero:
            return

        query_ids = batch.query_ids
        if query_ids is None:
            raise ValueError("Expected batch to have query_ids")
        scores, doc_ids, num_docs = self._rank(batch, outputs)
        scores = scores.float().cpu().numpy()

        query_ids = list(
            itertools.chain.from_iterable(
                itertools.repeat(query_id, num) for query_id, num in zip(query_ids, num_docs)
            )
        )
        run_df = pd.DataFrame(zip(query_ids, doc_ids, scores), columns=["query_id", "doc_id", "score"])
        run_df = run_df.sort_values(["query_id", "score"], ascending=[True, False])
        run_df["rank"] = run_df.groupby("query_id")["score"].rank(ascending=False, method="first").astype(int)
        run_df["q0"] = 0
        run_df["system"] = pl_module.model.__class__.__name__
        run_df = run_df[RUN_HEADER]

        self.run_dfs.append(run_df)

        if batch_idx == trainer.num_test_batches[dataloader_idx] - 1:
            self._write_run_dfs(trainer, pl_module, dataloader_idx)
            self.run_dfs = []


class SearchCallback(RankCallback, _IndexDirMixin):
    def __init__(
        self,
        search_config: SearchConfig,
        index_dir: Path | str | None = None,
        index_name: str | None = None,
        save_dir: Path | str | None = None,
        run_name: str | None = None,
        overwrite: bool = False,
        use_gpu: bool = True,
    ) -> None:
        """Callback that uses an index to retrieve documents efficiently.

        :param search_config: Configuration of the :py:class:`~lightning_ir.retrieve.base.searcher.Searcher`
        :type search_config: SearchConfig
        :param index_dir: Directory where indexes are stored, defaults to None
        :type index_dir: Path | str | None, optional
        :param index_name: Name of the index. If None, the dataset's dataset_id or file name will be used,
            defaults to None
        :type index_name: str | None, optional
        :param save_dir: Directory to save run files to. If None, run files are saved in the model's directory,
            defaults to None
        :type save_dir: Path | str | None, optional
        :param run_name: Name of the run file. If None, the dataset's dataset_id or file name will be used,
            defaults to None
        :type run_name: str | None, optional
        :param overwrite: Whether to skip or overwrite already existing run files, defaults to False
        :type overwrite: bool, optional
        :param use_gpu: Toggle to use gpu for retrieval, defaults to True
        :type use_gpu: bool, optional
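
        Example (a minimal sketch; the concrete search config class, model name, dataset id, and paths are
        illustrative and may need to be adapted)::

            from lightning import Trainer
            from lightning_ir import BiEncoderModule, FaissSearchConfig, LightningIRDataModule
            from lightning_ir import QueryDataset, SearchCallback

            callback = SearchCallback(
                search_config=FaissSearchConfig(k=100), index_dir="./indexes", save_dir="./runs"
            )
            module = BiEncoderModule(model_name_or_path="path/to/bi-encoder")
            datamodule = LightningIRDataModule(inference_datasets=[QueryDataset("msmarco-passage/trec-dl-2019")])
            Trainer(callbacks=[callback]).test(module, datamodule=datamodule)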
        """
        super().__init__(save_dir=save_dir, run_name=run_name, overwrite=overwrite)
        self.search_config = search_config
        self.index_dir = index_dir
        self.index_name = index_name
        self.overwrite = overwrite
        self.use_gpu = use_gpu
        self.searcher: Searcher

    def _get_searcher(self, trainer: Trainer, pl_module: BiEncoderModule, dataset_idx: int) -> Searcher:
        dataloaders = trainer.test_dataloaders
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")
        dataset = dataloaders[dataset_idx].dataset

        index_dir = self._get_index_dir(pl_module, dataset)
        if hasattr(self, "searcher"):
            if self.searcher.index_dir == index_dir:
                return self.searcher
            # free up memory
            del self.searcher
            gc.collect()
            torch.cuda.empty_cache()

        searcher = self.search_config.search_class(index_dir, self.search_config, pl_module, self.use_gpu)
        return searcher

    def _rank(
        self, batch: SearchBatch | RankBatch, output: LightningIROutput
    ) -> Tuple[torch.Tensor, List[str], List[int]]:
        if batch.doc_ids is None:
            raise ValueError("BiEncoderModule did not return doc_ids when searching")
        dummy_docs = [[""] * len(ids) for ids in batch.doc_ids]
        batch = RankBatch(batch.queries, dummy_docs, batch.query_ids, batch.doc_ids, batch.qrels)
        return super()._rank(batch, output)

    def on_test_start(self, trainer: Trainer, pl_module: BiEncoderModule) -> None:
        """Hook to validate datasets.

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: LightningIR BiEncoderModule
        :type pl_module: BiEncoderModule
        :raises ValueError: If no test_dataloaders are found
        :raises ValueError: If not all datasets are :py:class:`~lightning_ir.data.dataset.QueryDataset`
        """
        dataloaders = trainer.test_dataloaders
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")
        datasets = [dataloader.dataset for dataloader in dataloaders]
        if not all(isinstance(dataset, (QueryDataset, _DummyIterableDataset)) for dataset in datasets):
            raise ValueError("Expected QueryDatasets for searching")

    def on_test_batch_start(
        self, trainer: Trainer, pl_module: BiEncoderModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Hook to initialize searcher for new datasets.

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: LightningIR BiEncoderModule
        :type pl_module: BiEncoderModule
        :param batch: Batch of input data
        :type batch: Any
        :param batch_idx: Index of batch in dataset
        :type batch_idx: int
        :param dataloader_idx: Index of the dataloader, defaults to 0
        :type dataloader_idx: int, optional
        """
        if batch_idx == 0:
            self.searcher = self._get_searcher(trainer, pl_module, dataloader_idx)
            pl_module.searcher = self.searcher
        super().on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)


class ReRankCallback(RankCallback):
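    """Callback to write run files of re-ranked documents to disk. Identical in behavior to
    :py:class:`RankCallback`."""
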
    pass


class RegisterLocalDatasetCallback(Callback):

    def __init__(
        self,
        dataset_id: str,
        docs: str | None = None,
        queries: str | None = None,
        qrels: str | None = None,
        docpairs: str | None = None,
        scoreddocs: str | None = None,
        qrels_defs: Dict[int, str] | None = None,
    ):
        """Registers a local dataset with ``ir_datasets``. After registering the dataset, it can be loaded using
        ``ir_datasets.load(dataset_id)``. Currently, the following (optionally gzipped) file types are supported:

        - ``.tsv``, ``.json``, or ``.jsonl`` for documents and queries
        - ``.tsv`` or ``.qrels`` for qrels
        - ``.tsv`` for training n-tuples
        - ``.tsv`` or ``.run`` for scored documents / run files

        :param dataset_id: Dataset id
        :type dataset_id: str
        :param docs: Path to documents file or valid ir_datasets id from which documents should be taken,
            defaults to None
        :type docs: str | None, optional
        :param queries: Path to queries file or valid ir_datasets id from which queries should be taken,
            defaults to None
        :type queries: str | None, optional
        :param qrels: Path to qrels file or valid ir_datasets id from which qrels will be taken, defaults to None
        :type qrels: str | None, optional
        :param docpairs: Path to training n-tuple file or valid ir_datasets id from which training tuples will be
            taken, defaults to None
        :type docpairs: str | None, optional
        :param scoreddocs: Path to run file or valid ir_datasets id from which scored documents will be taken,
            defaults to None
        :type scoreddocs: str | None, optional
        :param qrels_defs: Optional dictionary describing the relevance levels of the qrels, defaults to None
        :type qrels_defs: Dict[int, str] | None, optional
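
        Example (a minimal sketch; the dataset id and file paths are placeholders, and the top-level import
        path is assumed)::

            from lightning import Trainer
            from lightning_ir import RegisterLocalDatasetCallback

            callback = RegisterLocalDatasetCallback(
                dataset_id="my-local-dataset",
                docs="path/to/docs.tsv",
                queries="path/to/queries.tsv",
                qrels="path/to/qrels.tsv",
            )
            trainer = Trainer(callbacks=[callback])
            # once ``setup`` has run, "my-local-dataset" can be used like any other ir_datasets id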
        """
        super().__init__()
        self.dataset_id = dataset_id
        self.docs = docs
        self.queries = queries
        self.qrels = qrels
        self.docpairs = docpairs
        self.scoreddocs = scoreddocs
        self.qrels_defs = qrels_defs

    def setup(self, trainer: Trainer, pl_module: LightningIRModule, stage: str) -> None:
        """Hook that registers the dataset.

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: Lightning IR module
        :type pl_module: LightningIRModule
        :param stage: Stage of the trainer
        :type stage: str
        """
        register_new_dataset(
            self.dataset_id,
            docs=self.docs,
            queries=self.queries,
            qrels=self.qrels,
            docpairs=self.docpairs,
            scoreddocs=self.scoreddocs,
            qrels_defs=self.qrels_defs,
        )