Source code for lightning_ir.callbacks.callbacks

"""Module containing callbacks for indexing, searching, ranking, and registering custom datasets."""

from __future__ import annotations

import csv
import gc
import itertools
from dataclasses import is_dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, TypeVar

import pandas as pd
import torch
from lightning import Trainer
from lightning.pytorch.callbacks import Callback, TQDMProgressBar

from ..base.validation_utils import evaluate_run
from ..data import LightningIRDataModule, RankBatch, SearchBatch
from ..data.dataset import RUN_HEADER, DocDataset, IRDataset, QueryDataset, RunDataset, _DummyIterableDataset
from ..data.external_datasets.ir_datasets_utils import register_new_dataset
from ..retrieve import IndexConfig, Indexer, SearchConfig, Searcher

if TYPE_CHECKING:
    from ..base import LightningIRModule, LightningIROutput
    from ..bi_encoder import BiEncoderModule, BiEncoderOutput

T = TypeVar("T")


def _format_large_number(number: float) -> str:
    suffixes = ["", "K", "M", "B", "T"]
    suffix_index = 0

    while number >= 1000 and suffix_index < len(suffixes) - 1:
        number /= 1000.0
        suffix_index += 1

    formatted_number = "{:.2f}".format(number)

    suffix = suffixes[suffix_index]
    if suffix:
        formatted_number += f" {suffix}"
    return formatted_number

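# Illustration (not part of the module source): _format_large_number compacts counts
# for the progress bar, e.g. _format_large_number(1_234_567) returns "1.23 M" and
# _format_large_number(512) returns "512.00".
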
class _GatherMixin:
    """Mixin to gather dataclasses across all processes"""

    def _gather(self, pl_module: LightningIRModule, dataclass: T) -> T:
        if is_dataclass(dataclass):
            return dataclass.__class__(
                **{k: self._gather(pl_module, getattr(dataclass, k)) for k in dataclass.__dataclass_fields__}
            )
        return pl_module.all_gather(dataclass)


class _IndexDirMixin:
    """Mixin to get index_dir"""

    index_dir: Path | str | None
    index_name: str | None

    def _get_index_dir(self, pl_module: BiEncoderModule, dataset: DocDataset) -> Path:
        index_dir = self.index_dir
        if index_dir is None:
            default_index_dir = Path(pl_module.config.name_or_path)
            if default_index_dir.exists():
                index_dir = default_index_dir / "indexes"
            else:
                raise ValueError("No index_dir provided and model_name_or_path is not a path")
        index_dir = Path(index_dir)
        if self.index_name is None:
            index_dir = index_dir / dataset.docs_dataset_id
        else:
            index_dir = index_dir / self.index_name
        return index_dir

class _OverwriteMixin:
    """Mixin to skip datasets (for indexing or searching) if they already exist"""

    _get_save_path: Callable[[LightningIRModule, IRDataset], Path]

    def _remove_overwrite_datasets(self, trainer: Trainer, pl_module: LightningIRModule) -> None:
        overwrite = getattr(self, "overwrite", False)
        if not overwrite:
            datamodule: LightningIRDataModule | None = getattr(trainer, "datamodule", None)
            if datamodule is None:
                raise ValueError("No datamodule found")
            if datamodule.inference_datasets is None:
                return
            inference_datasets = list(datamodule.inference_datasets)
            for dataset in inference_datasets:
                save_path = self._get_save_path(pl_module, dataset)
                if save_path.exists():
                    dataset._SKIP = True
                    trainer.print(f"`{save_path}` already exists. Set overwrite=True to overwrite")
                    if (
                        save_path.name.endswith(".run")
                        and dataset.qrels is not None
                        and pl_module.evaluation_metrics is not None
                    ):
                        run = RunDataset._load_csv(save_path)
                        qrels = dataset.qrels.stack(future_stack=True).dropna().astype(int).reset_index()
                        if isinstance(dataset, RunDataset) and dataset.run_path is not None:
                            dataset_id = dataset.run_path.name
                        else:
                            dataset_id = dataset.dataset_id
                        for key, value in evaluate_run(run, qrels, pl_module.evaluation_metrics).items():
                            key = f"{dataset_id}/{key}"
                            pl_module._additional_log_metrics[key] = value

    def _cleanup(self, trainer: Trainer, pl_module: LightningIRModule) -> None:
        # reset skip flag and additional log metrics
        datamodule: LightningIRDataModule | None = getattr(trainer, "datamodule", None)
        if datamodule is not None and datamodule.inference_datasets is not None:
            for dataset in datamodule.inference_datasets:
                dataset._SKIP = False
        pl_module._additional_log_metrics = {}

class IndexCallback(Callback, _GatherMixin, _IndexDirMixin, _OverwriteMixin):
    def __init__(
        self,
        index_config: IndexConfig,
        index_dir: Path | str | None = None,
        index_name: str | None = None,
        overwrite: bool = False,
        verbose: bool = False,
    ) -> None:
        """Callback to index documents using an :py:class:`~lightning_ir.retrieve.base.indexer.Indexer`.

        :param index_config: Configuration for the indexer
        :type index_config: IndexConfig
        :param index_dir: Directory to save index(es) to. If None, indexes will be stored in the model's directory,
            defaults to None
        :type index_dir: Path | str | None, optional
        :param index_name: Name of the index. If None, the dataset's dataset_id or file name will be used,
            defaults to None
        :type index_name: str | None, optional
        :param overwrite: Whether to skip or overwrite already existing indexes, defaults to False
        :type overwrite: bool, optional
        :param verbose: Toggle verbose output, defaults to False
        :type verbose: bool, optional
        """
        super().__init__()
        self.index_config = index_config
        self.index_dir = index_dir
        self.index_name = index_name
        self.overwrite = overwrite
        self.verbose = verbose
        self.indexer: Indexer

    def setup(self, trainer: Trainer, pl_module: BiEncoderModule, stage: str) -> None:
        """Hook to setup the callback.

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: LightningIR bi-encoder module used for indexing
        :type pl_module: BiEncoderModule
        :param stage: Stage of the trainer, must be "test"
        :type stage: str
        :raises ValueError: If the stage is not "test"
        """
        if stage != "test":
            raise ValueError(f"{self.__class__.__name__} can only be used in test stage")
        self._remove_overwrite_datasets(trainer, pl_module)

    def teardown(self, trainer: Trainer, pl_module: BiEncoderModule, stage: str) -> None:
        """Hook to cleanup the callback.

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: LightningIR bi-encoder module used for indexing
        :type pl_module: BiEncoderModule
        :param stage: Stage of the trainer, must be "test"
        :type stage: str
        """
        self._cleanup(trainer, pl_module)

    def _get_save_path(self, pl_module: BiEncoderModule, dataset: IRDataset) -> Path:
        if not isinstance(dataset, DocDataset):
            raise ValueError("Expected DocDataset for indexing")
        return self._get_index_dir(pl_module, dataset)

    def _get_indexer(self, pl_module: BiEncoderModule, dataloader_idx: int) -> Indexer:
        dataset = pl_module.get_dataset(dataloader_idx)
        if dataset is None:
            raise ValueError("No dataset found to index")
        if not isinstance(dataset, DocDataset):
            raise ValueError("Expected DocDataset for indexing")
        index_dir = self._get_save_path(pl_module, dataset)

        indexer = self.index_config.indexer_class(index_dir, self.index_config, pl_module, self.verbose)
        return indexer

    def _log_to_pg(self, info: Dict[str, Any], trainer: Trainer):
        pg_callback = trainer.progress_bar_callback
        if pg_callback is None or not isinstance(pg_callback, TQDMProgressBar):
            return
        pg = pg_callback.test_progress_bar
        info = {k: _format_large_number(v) for k, v in info.items()}
        if pg is not None:
            pg.set_postfix(info)
    def on_test_start(self, trainer: Trainer, pl_module: BiEncoderModule) -> None:
        """Hook to check that the test datasets are configured correctly.

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: LightningIR BiEncoderModule
        :type pl_module: BiEncoderModule
        :raises ValueError: If no test_dataloaders are found
        :raises ValueError: If not all test datasets are :py:class:`~lightning_ir.data.dataset.DocDataset`
        """
        dataloaders = trainer.test_dataloaders
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")
        datasets = [dataloader.dataset for dataloader in dataloaders]
        if not all(isinstance(dataset, (DocDataset, _DummyIterableDataset)) for dataset in datasets):
            raise ValueError("Expected DocDatasets for indexing")

    def on_test_batch_start(
        self, trainer: Trainer, pl_module: BiEncoderModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Hook to setup the indexer between datasets.

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: LightningIR bi-encoder module
        :type pl_module: BiEncoderModule
        :param batch: Batch of input data
        :type batch: Any
        :param batch_idx: Index of batch in the current dataset
        :type batch_idx: int
        :param dataloader_idx: Index of the dataloader, defaults to 0
        :type dataloader_idx: int, optional
        """
        if batch_idx == 0:
            self.indexer = self._get_indexer(pl_module, dataloader_idx)
        super().on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: BiEncoderModule,
        outputs: BiEncoderOutput,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Hook to pass encoded documents to the indexer

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: LightningIR bi-encoder module
        :type pl_module: BiEncoderModule
        :param outputs: Encoded documents
        :type outputs: BiEncoderOutput
        :param batch: Batch of input data
        :type batch: Any
        :param batch_idx: Index of batch in the current dataset
        :type batch_idx: int
        :param dataloader_idx: Index of the dataloader, defaults to 0
        :type dataloader_idx: int, optional
        """
        batch = self._gather(pl_module, batch)
        outputs = self._gather(pl_module, outputs)

        if not trainer.is_global_zero:
            return

        self.indexer.add(batch, outputs)
        self._log_to_pg(
            {
                "num_docs": self.indexer.num_docs,
                "num_embeddings": self.indexer.num_embeddings,
            },
            trainer,
        )
        # TODO if dataset length cannot be inferred, num_test_batches is inf and no index is saved
        if batch_idx == trainer.num_test_batches[dataloader_idx] - 1:
            assert hasattr(self, "indexer")
            self.indexer.save()

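# Usage sketch (illustrative, not part of the module): build an index for a document
# collection by attaching an IndexCallback to a Lightning Trainer. The concrete index
# config class, checkpoint name, and constructor arguments below are assumptions for
# illustration; substitute the IndexConfig subclass and bi-encoder you actually use.
#
#     from lightning import Trainer
#     from lightning_ir.bi_encoder import BiEncoderModule
#     from lightning_ir.callbacks.callbacks import IndexCallback
#     from lightning_ir.data import LightningIRDataModule
#     from lightning_ir.data.dataset import DocDataset
#     from lightning_ir.retrieve import TorchDenseIndexConfig  # assumed concrete IndexConfig
#
#     module = BiEncoderModule(model_name_or_path="path/or/hub-id-of-bi-encoder")
#     callback = IndexCallback(index_config=TorchDenseIndexConfig(), index_dir="./indexes")
#     datamodule = LightningIRDataModule(inference_datasets=[DocDataset("msmarco-passage")])
#     Trainer(callbacks=[callback]).test(module, datamodule=datamodule)
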
class RankCallback(Callback, _GatherMixin, _OverwriteMixin):
    def __init__(
        self, save_dir: Path | str | None = None, run_name: str | None = None, overwrite: bool = False
    ) -> None:
        """Callback to write run file of ranked documents to disk.

        :param save_dir: Directory to save run files to. If None, run files will be saved in the model's directory,
            defaults to None
        :type save_dir: Path | str | None, optional
        :param run_name: Name of the run file. If None, the dataset's dataset_id or file name will be used,
            defaults to None
        :type run_name: str | None, optional
        :param overwrite: Whether to skip or overwrite already existing run files, defaults to False
        :type overwrite: bool, optional
        """
        super().__init__()
        self.save_dir = Path(save_dir) if save_dir is not None else None
        self.run_name = run_name
        self.overwrite = overwrite
        self.run_dfs: List[pd.DataFrame] = []

    def setup(self, trainer: Trainer, pl_module: LightningIRModule, stage: str) -> None:
        """Hook to setup the callback.

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: LightningIR module
        :type pl_module: LightningIRModule
        :param stage: Stage of the trainer, must be "test"
        :type stage: str
        :raises ValueError: If the stage is not "test"
        :raises ValueError: If no save_dir is provided and model_name_or_path is not a path (the model is not local)
        """
        if stage != "test":
            raise ValueError(f"{self.__class__.__name__} can only be used in test stage")
        if self.save_dir is None:
            default_save_dir = Path(pl_module.config.name_or_path)
            if default_save_dir.exists():
                self.save_dir = default_save_dir / "runs"
                print(f"Using default save_dir `{self.save_dir}` to save runs")
            else:
                raise ValueError("No save_dir provided and model_name_or_path is not a path")
        self._remove_overwrite_datasets(trainer, pl_module)

    def teardown(self, trainer: Trainer, pl_module: BiEncoderModule, stage: str) -> None:
        """Hook to cleanup the callback.

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: LightningIR bi-encoder module
        :type pl_module: BiEncoderModule
        :param stage: Stage of the trainer, must be "test"
        :type stage: str
        """
        self._cleanup(trainer, pl_module)

    def _get_save_path(self, pl_module: LightningIRModule, dataset: IRDataset) -> Path:
        if self.save_dir is None:
            raise ValueError("No save_dir found; call setup before using this method")
        if self.run_name is not None:
            run_file = self.run_name
        elif isinstance(dataset, QueryDataset):
            run_file = f"{dataset.dataset_id.replace('/', '-')}.run"
        elif isinstance(dataset, RunDataset):
            if dataset.run_path is None:
                run_file = f"{dataset.dataset_id.replace('/', '-')}.run"
            else:
                run_file = f"{dataset.run_path.name.split('.')[0]}.run"
        else:
            raise ValueError("Expected QueryDataset or RunDataset for ranking")
        run_file_path = self.save_dir / run_file
        return run_file_path

    def _rank(self, batch: RankBatch, output: LightningIROutput) -> Tuple[torch.Tensor, List[str], List[int]]:
        scores = output.scores
        if scores is None:
            raise ValueError("Expected output to have scores")
        doc_ids = batch.doc_ids
        if doc_ids is None:
            raise ValueError("Expected batch to have doc_ids")
        scores = scores.view(-1)
        num_docs = [len(_doc_ids) for _doc_ids in doc_ids]
        doc_ids = list(itertools.chain.from_iterable(doc_ids))
        if scores.shape[0] != len(doc_ids):
            raise ValueError("scores and doc_ids must have the same length")
        return scores, doc_ids, num_docs

    def _write_run_dfs(self, trainer: Trainer, pl_module: LightningIRModule, dataloader_idx: int):
        if not trainer.is_global_zero or not self.run_dfs:
            return
        dataloaders = trainer.test_dataloaders
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")
        dataset = pl_module.get_dataset(dataloader_idx)
        if dataset is None:
            raise ValueError("No dataset found to write run file")
        if not isinstance(dataset, (QueryDataset, RunDataset)):
            raise ValueError("Expected QueryDataset or RunDataset for ranking")
        run_file_path = self._get_save_path(pl_module, dataset)
        run_file_path.parent.mkdir(parents=True, exist_ok=True)
        run_df = pd.concat(self.run_dfs, ignore_index=True)
        run_df.to_csv(run_file_path, header=False, index=False, sep="\t", quoting=csv.QUOTE_NONE)

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningIRModule,
        outputs: LightningIROutput,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Hook to aggregate and write ranking to file.

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: LightningIR Module
        :type pl_module: LightningIRModule
        :param outputs: Scored query-document pairs
        :type outputs: LightningIROutput
        :param batch: Batch of input data
        :type batch: Any
        :param batch_idx: Index of batch in the current dataset
        :type batch_idx: int
        :param dataloader_idx: Index of the dataloader, defaults to 0
        :type dataloader_idx: int, optional
        """
        super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        batch = self._gather(pl_module, batch)
        outputs = self._gather(pl_module, outputs)
        if not trainer.is_global_zero:
            return

        query_ids = batch.query_ids
        if query_ids is None:
            raise ValueError("Expected batch to have query_ids")
        scores, doc_ids, num_docs = self._rank(batch, outputs)
        scores = scores.float().cpu().numpy()

        query_ids = list(
            itertools.chain.from_iterable(itertools.repeat(query_id, num) for query_id, num in zip(query_ids, num_docs))
        )
        run_df = pd.DataFrame(zip(query_ids, doc_ids, scores), columns=["query_id", "doc_id", "score"])
        run_df = run_df.sort_values(["query_id", "score"], ascending=[True, False])
        run_df["rank"] = run_df.groupby("query_id")["score"].rank(ascending=False, method="first").astype(int)
        run_df["q0"] = 0
        run_df["system"] = pl_module.model.__class__.__name__
        run_df = run_df[RUN_HEADER]

        self.run_dfs.append(run_df)

        if batch_idx == trainer.num_test_batches[dataloader_idx] - 1:
            self._write_run_dfs(trainer, pl_module, dataloader_idx)
            self.run_dfs = []

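# Usage sketch (illustrative, not part of the module): write a TREC-style run file for
# the rankings produced during Trainer.test. The cross-encoder module class, checkpoint
# name, and constructor arguments below are assumptions for illustration.
# (ReRankCallback below is an identical subclass typically used for this re-ranking setup.)
#
#     from lightning import Trainer
#     from lightning_ir.callbacks.callbacks import RankCallback
#     from lightning_ir.cross_encoder import CrossEncoderModule  # assumed module class
#     from lightning_ir.data import LightningIRDataModule
#     from lightning_ir.data.dataset import RunDataset
#
#     module = CrossEncoderModule(model_name_or_path="path/or/hub-id-of-cross-encoder")
#     callback = RankCallback(save_dir="./runs", run_name="rerank.run")
#     datamodule = LightningIRDataModule(inference_datasets=[RunDataset("path/to/baseline.run")])
#     Trainer(callbacks=[callback]).test(module, datamodule=datamodule)
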
class SearchCallback(RankCallback, _IndexDirMixin):
    def __init__(
        self,
        search_config: SearchConfig,
        index_dir: Path | str | None = None,
        index_name: str | None = None,
        save_dir: Path | str | None = None,
        run_name: str | None = None,
        overwrite: bool = False,
        use_gpu: bool = True,
    ) -> None:
        """Callback which uses an index to retrieve documents efficiently.

        :param search_config: Configuration of the :py:class:`~lightning_ir.retrieve.base.searcher.Searcher`
        :type search_config: SearchConfig
        :param index_dir: Directory where indexes are stored, defaults to None
        :type index_dir: Path | str | None, optional
        :param index_name: Name of the index. If None, the dataset's dataset_id or file name will be used,
            defaults to None
        :type index_name: str | None, optional
        :param save_dir: Directory to save run files to. If None, run files are saved in the model's directory,
            defaults to None
        :type save_dir: Path | str | None, optional
        :param run_name: Name of the run file. If None, the dataset's dataset_id or file name will be used,
            defaults to None
        :type run_name: str | None, optional
        :param overwrite: Whether to skip or overwrite already existing run files, defaults to False
        :type overwrite: bool, optional
        :param use_gpu: Toggle to use gpu for retrieval, defaults to True
        :type use_gpu: bool, optional
        """
        super().__init__(save_dir=save_dir, run_name=run_name, overwrite=overwrite)
        self.search_config = search_config
        self.index_dir = index_dir
        self.index_name = index_name
        self.overwrite = overwrite
        self.use_gpu = use_gpu
        self.searcher: Searcher

    def _get_searcher(self, trainer: Trainer, pl_module: BiEncoderModule, dataset_idx: int) -> Searcher:
        dataloaders = trainer.test_dataloaders
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")
        dataset = dataloaders[dataset_idx].dataset

        index_dir = self._get_index_dir(pl_module, dataset)
        if hasattr(self, "searcher"):
            if self.searcher.index_dir == index_dir:
                return self.searcher
            # free up memory
            del self.searcher
            gc.collect()
            torch.cuda.empty_cache()

        searcher = self.search_config.search_class(index_dir, self.search_config, pl_module, self.use_gpu)
        return searcher

    def _rank(
        self, batch: SearchBatch | RankBatch, output: LightningIROutput
    ) -> Tuple[torch.Tensor, List[str], List[int]]:
        if batch.doc_ids is None:
            raise ValueError("BiEncoderModule did not return doc_ids when searching")
        dummy_docs = [[""] * len(ids) for ids in batch.doc_ids]
        batch = RankBatch(batch.queries, dummy_docs, batch.query_ids, batch.doc_ids, batch.qrels)
        return super()._rank(batch, output)

    def on_test_start(self, trainer: Trainer, pl_module: BiEncoderModule) -> None:
        """Hook to validate datasets

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: LightningIR BiEncoderModule
        :type pl_module: BiEncoderModule
        :raises ValueError: If no test_dataloaders are found
        :raises ValueError: If not all datasets are :py:class:`~lightning_ir.data.dataset.QueryDataset`
        """
        dataloaders = trainer.test_dataloaders
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")
        datasets = [dataloader.dataset for dataloader in dataloaders]
        if not all(isinstance(dataset, (QueryDataset, _DummyIterableDataset)) for dataset in datasets):
            raise ValueError("Expected QueryDatasets for searching")

    def on_test_batch_start(
        self, trainer: Trainer, pl_module: BiEncoderModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Hook to initialize searcher for new datasets.

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: LightningIR BiEncoderModule
        :type pl_module: BiEncoderModule
        :param batch: Batch of input data
        :type batch: Any
        :param batch_idx: Index of batch in dataset
        :type batch_idx: int
        :param dataloader_idx: Index of the dataloader, defaults to 0
        :type dataloader_idx: int, optional
        """
        if batch_idx == 0:
            self.searcher = self._get_searcher(trainer, pl_module, dataloader_idx)
            pl_module.searcher = self.searcher
        super().on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)

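# Usage sketch (illustrative, not part of the module): retrieve documents from a
# previously built index for a set of queries and write the results as a run file.
# The concrete search config class, its arguments, and the checkpoint name below are
# assumptions for illustration; substitute the SearchConfig subclass matching your index.
#
#     from lightning import Trainer
#     from lightning_ir.bi_encoder import BiEncoderModule
#     from lightning_ir.callbacks.callbacks import SearchCallback
#     from lightning_ir.data import LightningIRDataModule
#     from lightning_ir.data.dataset import QueryDataset
#     from lightning_ir.retrieve import TorchDenseSearchConfig  # assumed concrete SearchConfig
#
#     module = BiEncoderModule(model_name_or_path="path/or/hub-id-of-bi-encoder")
#     callback = SearchCallback(
#         search_config=TorchDenseSearchConfig(k=100),  # assumed signature
#         index_dir="./indexes",
#         save_dir="./runs",
#     )
#     datamodule = LightningIRDataModule(inference_datasets=[QueryDataset("msmarco-passage/trec-dl-2019")])
#     Trainer(callbacks=[callback]).test(module, datamodule=datamodule)
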
class ReRankCallback(RankCallback):
    pass

class RegisterLocalDatasetCallback(Callback):

    def __init__(
        self,
        dataset_id: str,
        docs: str | None = None,
        queries: str | None = None,
        qrels: str | None = None,
        docpairs: str | None = None,
        scoreddocs: str | None = None,
        qrels_defs: Dict[int, str] | None = None,
    ):
        """Registers a local dataset with ``ir_datasets``. After registering the dataset, it can be loaded using
        ``ir_datasets.load(dataset_id)``. Currently, the following (optionally gzipped) file types are supported:

        - ``.tsv``, ``.json``, or ``.jsonl`` for documents and queries
        - ``.tsv`` or ``.qrels`` for qrels
        - ``.tsv`` for training n-tuples
        - ``.tsv`` or ``.run`` for scored documents / run files

        :param dataset_id: Dataset id
        :type dataset_id: str
        :param docs: Path to documents file or valid ir_datasets id from which documents should be taken,
            defaults to None
        :type docs: str | None, optional
        :param queries: Path to queries file or valid ir_datasets id from which queries should be taken,
            defaults to None
        :type queries: str | None, optional
        :param qrels: Path to qrels file or valid ir_datasets id from which qrels will be taken, defaults to None
        :type qrels: str | None, optional
        :param docpairs: Path to training n-tuple file or valid ir_datasets id from which training tuples will be
            taken, defaults to None
        :type docpairs: str | None, optional
        :param scoreddocs: Path to run file or valid ir_datasets id from which scored documents will be taken,
            defaults to None
        :type scoreddocs: str | None, optional
        :param qrels_defs: Optional dictionary describing the relevance levels of the qrels, defaults to None
        :type qrels_defs: Dict[int, str] | None, optional
        """
        super().__init__()
        self.dataset_id = dataset_id
        self.docs = docs
        self.queries = queries
        self.qrels = qrels
        self.docpairs = docpairs
        self.scoreddocs = scoreddocs
        self.qrels_defs = qrels_defs

    def setup(self, trainer: Trainer, pl_module: LightningIRModule, stage: str) -> None:
        """Hook that registers dataset.

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: Lightning IR module
        :type pl_module: LightningIRModule
        :param stage: Stage of the trainer
        :type stage: str
        """
        register_new_dataset(
            self.dataset_id,
            docs=self.docs,
            queries=self.queries,
            qrels=self.qrels,
            docpairs=self.docpairs,
            scoreddocs=self.scoreddocs,
            qrels_defs=self.qrels_defs,
        )
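# Usage sketch (illustrative, not part of the module): register a local collection so it
# can be referenced by its dataset id like any other ir_datasets dataset. The file paths
# and ids below are placeholders.
#
#     from lightning import Trainer
#     from lightning_ir.callbacks.callbacks import RegisterLocalDatasetCallback
#
#     register = RegisterLocalDatasetCallback(
#         dataset_id="my-local-dataset",
#         docs="msmarco-passage",           # reuse documents from an existing ir_datasets id
#         queries="local/queries.tsv",      # local .tsv queries file
#         qrels="local/qrels.tsv",          # local .tsv qrels file
#     )
#     # Attach the callback alongside, e.g., a RankCallback; the dataset is registered in
#     # setup and can afterwards be loaded via ir_datasets.load("my-local-dataset").
#     trainer = Trainer(callbacks=[register])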