Source code for lightning_ir.callbacks.callbacks

  1"""Module containing callbacks for indexing, searching, ranking, and registering custom datasets."""
  2
  3from __future__ import annotations
  4
  5import csv
  6import gc
  7import itertools
  8from dataclasses import is_dataclass
  9from pathlib import Path
 10from typing import TYPE_CHECKING, Any, Callable, TypeVar
 11
 12import pandas as pd
 13import torch
 14from lightning import Trainer
 15from lightning.pytorch.callbacks import Callback, TQDMProgressBar
 16
 17from ..base.validation_utils import evaluate_run
 18from ..data import LightningIRDataModule, RankBatch, SearchBatch
 19from ..data.dataset import RUN_HEADER, DocDataset, IRDataset, QueryDataset, RunDataset, _DummyIterableDataset
 20from ..data.external_datasets.ir_datasets_utils import register_new_dataset
 21from ..retrieve import IndexConfig, Indexer, SearchConfig, Searcher
 22
 23if TYPE_CHECKING:
 24    from ..base import LightningIRModule, LightningIROutput
 25    from ..bi_encoder import BiEncoderModule, BiEncoderOutput
 26
 27T = TypeVar("T")
 28
 29
 30def _format_large_number(number: float) -> str:
 31    suffixes = ["", "K", "M", "B", "T"]
 32    suffix_index = 0
 33
 34    while number >= 1000 and suffix_index < len(suffixes) - 1:
 35        number /= 1000.0
 36        suffix_index += 1
 37
 38    formatted_number = f"{number:.2f}"
 39
 40    suffix = suffixes[suffix_index]
 41    if suffix:
 42        formatted_number += f" {suffix}"
 43    return formatted_number
 44
 45
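
# Illustrative only (not part of the original module): _format_large_number abbreviates
# progress-bar counts with a space-separated suffix, e.g.
#
#   _format_large_number(512)        -> "512.00"
#   _format_large_number(1_234_567)  -> "1.23 M"
#   _format_large_number(9.87e12)    -> "9.87 T"
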
class _GatherMixin:
    """Mixin to gather dataclasses across all processes."""

    def _gather(self, pl_module: LightningIRModule, dataclass: T) -> T:
        if is_dataclass(dataclass):
            return dataclass.__class__(
                **{k: self._gather(pl_module, getattr(dataclass, k)) for k in dataclass.__dataclass_fields__}
            )
        return pl_module.all_gather(dataclass)


class _IndexDirMixin:
    """Mixin to resolve the index_dir."""

    index_dir: Path | str | None
    index_name: str | None

    def _get_index_dir(self, pl_module: BiEncoderModule, dataset: DocDataset) -> Path:
        index_dir = self.index_dir
        if index_dir is None:
            default_index_dir = Path(pl_module.config.name_or_path)
            if default_index_dir.exists():
                index_dir = default_index_dir / "indexes"
            else:
                raise ValueError("No index_dir provided and model_name_or_path is not a path")
        index_dir = Path(index_dir)
        if self.index_name is None:
            index_dir = index_dir / dataset.dashed_docs_dataset_id
        else:
            index_dir = index_dir / self.index_name
        return index_dir


class _OverwriteMixin:
    """Mixin to skip datasets (for indexing or searching) if they already exist."""

    _get_save_path: Callable[[LightningIRModule, IRDataset], Path]

    def _remove_overwrite_datasets(self, trainer: Trainer, pl_module: LightningIRModule) -> None:
        overwrite = getattr(self, "overwrite", False)
        if not overwrite:
            datamodule: LightningIRDataModule | None = getattr(trainer, "datamodule", None)
            if datamodule is None:
                raise ValueError("No datamodule found")
            if datamodule.inference_datasets is None:
                return
            inference_datasets = list(datamodule.inference_datasets)
            for dataset in inference_datasets:
                save_path = self._get_save_path(pl_module, dataset)
                if save_path.exists():
                    dataset._SKIP = True
                    trainer.print(f"`{save_path}` already exists. Set overwrite=True to overwrite it.")
                    if (
                        save_path.name.endswith(".run")
                        and dataset.qrels is not None
                        and pl_module.evaluation_metrics is not None
                    ):
                        run = RunDataset._load_csv(save_path)
                        qrels = dataset.qrels.stack(future_stack=True).dropna().astype(int).reset_index()
                        if isinstance(dataset, RunDataset) and dataset.run_path is not None:
                            dataset_id = dataset.run_path.name
                        else:
                            dataset_id = dataset.dataset_id
                        for key, value in evaluate_run(run, qrels, pl_module.evaluation_metrics).items():
                            key = f"{dataset_id}/{key}"
                            pl_module._additional_log_metrics[key] = value

    def _cleanup(self, trainer: Trainer, pl_module: LightningIRModule) -> None:
        # reset skip flag and additional log metrics
        datamodule: LightningIRDataModule | None = getattr(trainer, "datamodule", None)
        if datamodule is not None and datamodule.inference_datasets is not None:
            for dataset in datamodule.inference_datasets:
                dataset._SKIP = False
        pl_module._additional_log_metrics = {}

class IndexCallback(Callback, _GatherMixin, _IndexDirMixin, _OverwriteMixin):
    def __init__(
        self,
        index_config: IndexConfig,
        index_dir: Path | str | None = None,
        index_name: str | None = None,
        overwrite: bool = False,
        verbose: bool = False,
    ) -> None:
        """Callback to index documents using an :py:class:`~lightning_ir.retrieve.base.indexer.Indexer`.

        Args:
            index_config (IndexConfig): Configuration for the indexer.
            index_dir (Path | str | None): Directory to save index(es) to. If None, indexes will be stored in the
                model's directory. Defaults to None.
            index_name (str | None): Name of the index. If None, the dataset's dataset_id or file name will be used.
                Defaults to None.
            overwrite (bool): Whether to skip or overwrite already existing indexes. Defaults to False.
            verbose (bool): Toggle verbose output. Defaults to False.
        """
        super().__init__()
        self.index_config = index_config
        self.index_dir = index_dir
        self.index_name = index_name
        self.overwrite = overwrite
        self.verbose = verbose
        self.indexer: Indexer

    def setup(self, trainer: Trainer, pl_module: BiEncoderModule, stage: str) -> None:
        """Hook to set up the callback.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (BiEncoderModule): LightningIR bi-encoder module used for indexing.
            stage (str): Stage of the trainer, must be "test".
        Raises:
            ValueError: If the stage is not "test".
        """
        if stage != "test":
            raise ValueError(f"{self.__class__.__name__} can only be used in test stage")
        self._remove_overwrite_datasets(trainer, pl_module)

    def teardown(self, trainer: Trainer, pl_module: BiEncoderModule, stage: str) -> None:
        """Hook to clean up the callback.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (BiEncoderModule): LightningIR bi-encoder module used for indexing.
            stage (str): Stage of the trainer.
        """
        self._cleanup(trainer, pl_module)

    def _get_save_path(self, pl_module: BiEncoderModule, dataset: IRDataset) -> Path:
        if not isinstance(dataset, DocDataset):
            raise ValueError("Expected DocDataset for indexing")
        return self._get_index_dir(pl_module, dataset)

    def _get_indexer(self, pl_module: BiEncoderModule, dataloader_idx: int) -> Indexer:
        dataset = pl_module.get_dataset(dataloader_idx)
        if dataset is None:
            raise ValueError("No dataset found to index")
        if not isinstance(dataset, DocDataset):
            raise ValueError("Expected DocDataset for indexing")
        index_dir = self._get_save_path(pl_module, dataset)

        indexer = self.index_config.indexer_class(index_dir, self.index_config, pl_module, self.verbose)
        return indexer

    def _log_to_pg(self, info: dict[str, Any], trainer: Trainer):
        pg_callback = trainer.progress_bar_callback
        if pg_callback is None or not isinstance(pg_callback, TQDMProgressBar):
            return
        pg = pg_callback.test_progress_bar
        info = {k: _format_large_number(v) for k, v in info.items()}
        if pg is not None:
            pg.set_postfix(info)

    def on_test_start(self, trainer: Trainer, pl_module: BiEncoderModule) -> None:
        """Hook to check that the test datasets are configured correctly.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (BiEncoderModule): LightningIR bi-encoder module.
        Raises:
            ValueError: If no test_dataloaders are found.
            ValueError: If not all test datasets are :py:class:`~lightning_ir.data.dataset.DocDataset`.
        """
        dataloaders = trainer.test_dataloaders
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")
        datasets = [dataloader.dataset for dataloader in dataloaders]
        if not all(isinstance(dataset, (DocDataset, _DummyIterableDataset)) for dataset in datasets):
            raise ValueError("Expected DocDatasets for indexing")

    def on_test_batch_start(
        self, trainer: Trainer, pl_module: BiEncoderModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Hook to set up the indexer between datasets.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (BiEncoderModule): LightningIR bi-encoder module.
            batch (Any): Batch of input data.
            batch_idx (int): Index of batch in the current dataset.
            dataloader_idx (int): Index of the dataloader. Defaults to 0.
        """
        if not trainer.is_global_zero:
            return
        if batch_idx == 0:
            if hasattr(self, "indexer"):
                self.indexer.save()
            self.indexer = self._get_indexer(pl_module, dataloader_idx)
        super().on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: BiEncoderModule,
        outputs: BiEncoderOutput,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Hook to pass encoded documents to the indexer.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (BiEncoderModule): LightningIR bi-encoder module.
            outputs (BiEncoderOutput): Encoded documents.
            batch (Any): Batch of input data.
            batch_idx (int): Index of batch in the current dataset.
            dataloader_idx (int): Index of the dataloader. Defaults to 0.
        """
        batch = self._gather(pl_module, batch)
        outputs = self._gather(pl_module, outputs)

        if not trainer.is_global_zero:
            return

        self.indexer.add(batch, outputs)
        self._log_to_pg(
            {
                "num_docs": self.indexer.num_docs,
                "num_embeddings": self.indexer.num_embeddings,
            },
            trainer,
        )
        # TODO if dataset length cannot be inferred, num_test_batches is inf and no index is saved
        if batch_idx == trainer.num_test_batches[dataloader_idx] - 1:
            assert hasattr(self, "indexer")

    def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningIRModule) -> None:
        """Hook to save the index after indexing is done.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (LightningIRModule): LightningIR module.
        """
        if not trainer.is_global_zero:
            return
        if hasattr(self, "indexer"):
            self.indexer.save()

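
# --------------------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): indexing a document
# collection with a bi-encoder. The model name, dataset id, and the concrete IndexConfig
# subclass (here a FAISS flat-index config) are placeholder assumptions.
#
#   from lightning_ir import (
#       BiEncoderModule, DocDataset, FaissFlatIndexConfig, IndexCallback,
#       LightningIRDataModule, LightningIRTrainer,
#   )
#
#   module = BiEncoderModule("webis/bert-bi-encoder")
#   data_module = LightningIRDataModule(inference_datasets=[DocDataset("msmarco-passage")])
#   callback = IndexCallback(index_config=FaissFlatIndexConfig(), index_dir="./indexes")
#   trainer = LightningIRTrainer(callbacks=[callback], logger=False, enable_checkpointing=False)
#   trainer.index(module, data_module)
# --------------------------------------------------------------------------------------
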
class RankCallback(Callback, _GatherMixin, _OverwriteMixin):
    def __init__(
        self, save_dir: Path | str | None = None, run_name: str | None = None, overwrite: bool = False
    ) -> None:
        """Callback to write run file of ranked documents to disk.

        Args:
            save_dir (Path | str | None): Directory to save run files to. If None, run files will be saved in the
                model's directory. Defaults to None.
            run_name (str | None): Name of the run file. If None, the dataset's dataset_id or file name will be used.
                Defaults to None.
            overwrite (bool): Whether to skip or overwrite already existing run files. Defaults to False.
        """
        super().__init__()
        self.save_dir = Path(save_dir) if save_dir is not None else None
        self.run_name = run_name
        self.overwrite = overwrite
        self.run_dfs: list[pd.DataFrame] = []

    def setup(self, trainer: Trainer, pl_module: LightningIRModule, stage: str) -> None:
        """Hook to set up the callback.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (LightningIRModule): LightningIR module.
            stage (str): Stage of the trainer, must be "test".
        Raises:
            ValueError: If the stage is not "test".
            ValueError: If no save_dir is provided and model_name_or_path is not a path (the model is not local).
        """
        if stage != "test":
            raise ValueError(f"{self.__class__.__name__} can only be used in test stage")
        if self.save_dir is None:
            default_save_dir = Path(pl_module.config.name_or_path)
            if default_save_dir.exists():
                self.save_dir = default_save_dir / "runs"
                print(f"Using default save_dir `{self.save_dir}` to save runs")
            else:
                raise ValueError("No save_dir provided and model_name_or_path is not a path")
        self._remove_overwrite_datasets(trainer, pl_module)

    def teardown(self, trainer: Trainer, pl_module: LightningIRModule, stage: str) -> None:
        """Hook to clean up the callback.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (LightningIRModule): LightningIR module.
            stage (str): Stage of the trainer, must be "test".
        """
        self._cleanup(trainer, pl_module)

    def _get_save_path(self, pl_module: LightningIRModule, dataset: IRDataset) -> Path:
        if self.save_dir is None:
            raise ValueError("No save_dir found; call setup before using this method")
        if self.run_name is not None:
            run_file = self.run_name
        elif isinstance(dataset, QueryDataset):
            run_file = f"{dataset.dataset_id.replace('/', '-')}.run"
        elif isinstance(dataset, RunDataset):
            if dataset.run_path is None:
                run_file = f"{dataset.dataset_id.replace('/', '-')}.run"
            else:
                run_file = f"{dataset.run_path.name.split('.')[0]}.run"
        else:
            raise ValueError("Expected QueryDataset or RunDataset for ranking")
        run_file_path = self.save_dir / run_file
        return run_file_path

    def _rank(self, batch: RankBatch, output: LightningIROutput) -> tuple[torch.Tensor, list[str], list[int]]:
        scores = output.scores
        if scores is None:
            raise ValueError("Expected output to have scores")
        doc_ids = batch.doc_ids
        if doc_ids is None:
            raise ValueError("Expected batch to have doc_ids")
        scores = scores.view(-1)
        num_docs = [len(_doc_ids) for _doc_ids in doc_ids]
        doc_ids = list(itertools.chain.from_iterable(doc_ids))
        if scores.shape[0] != len(doc_ids):
            raise ValueError("scores and doc_ids must have the same length")
        return scores, doc_ids, num_docs

    def _write_run_dfs(self, trainer: Trainer, pl_module: LightningIRModule, dataloader_idx: int):
        if not trainer.is_global_zero or not self.run_dfs:
            return
        dataloaders = trainer.test_dataloaders
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")
        dataset = pl_module.get_dataset(dataloader_idx)
        if dataset is None:
            raise ValueError("No dataset found to write run file")
        if not isinstance(dataset, (QueryDataset, RunDataset)):
            raise ValueError("Expected QueryDataset or RunDataset for ranking")
        run_file_path = self._get_save_path(pl_module, dataset)
        run_file_path.parent.mkdir(parents=True, exist_ok=True)
        run_df = pd.concat(self.run_dfs, ignore_index=True)
        run_df.to_csv(run_file_path, header=False, index=False, sep="\t", quoting=csv.QUOTE_NONE)

    def on_test_batch_start(
        self, trainer: Trainer, pl_module: LightningIRModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Hook to write the accumulated run file before a new dataset starts.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (LightningIRModule): LightningIR module.
            batch (Any): Batch of input data.
            batch_idx (int): Index of batch in the current dataset.
            dataloader_idx (int): Index of the dataloader. Defaults to 0.
        """
        if not trainer.is_global_zero:
            return
        if batch_idx == 0:
            if self.run_dfs:
                self._write_run_dfs(trainer, pl_module, dataloader_idx)
                self.run_dfs = []
        super().on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningIRModule,
        outputs: LightningIROutput,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Hook to aggregate and write ranking to file.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (LightningIRModule): LightningIR module.
            outputs (LightningIROutput): Scored query-document pairs.
            batch (Any): Batch of input data.
            batch_idx (int): Index of batch in the current dataset.
            dataloader_idx (int): Index of the dataloader. Defaults to 0.
        Raises:
            ValueError: If the batch does not have query_ids.
        """
        super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        batch = self._gather(pl_module, batch)
        outputs = self._gather(pl_module, outputs)
        if not trainer.is_global_zero:
            return

        query_ids = batch.query_ids
        if query_ids is None:
            raise ValueError("Expected batch to have query_ids")
        scores, doc_ids, num_docs = self._rank(batch, outputs)
        scores = scores.float().cpu().numpy()

        query_ids = list(
            itertools.chain.from_iterable(
                itertools.repeat(query_id, num) for query_id, num in zip(query_ids, num_docs)
            )
        )
        run_df = pd.DataFrame(zip(query_ids, doc_ids, scores), columns=["query_id", "doc_id", "score"])
        run_df = run_df.sort_values(["query_id", "score"], ascending=[True, False])
        run_df["rank"] = run_df.groupby("query_id")["score"].rank(ascending=False, method="first").astype(int)
        run_df["q0"] = 0
        run_df["system"] = pl_module.model.__class__.__name__
        run_df = run_df[RUN_HEADER]

        self.run_dfs.append(run_df)

    def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningIRModule) -> None:
        """Hook to write remaining run files after ranking is done.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (LightningIRModule): LightningIR module.
        """
        if not trainer.is_global_zero:
            return
        if self.run_dfs:
            self._write_run_dfs(trainer, pl_module, 0)

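
# --------------------------------------------------------------------------------------
# Note (illustrative, not part of the original module): RankCallback writes tab-separated,
# TREC-style run files whose columns follow RUN_HEADER (imported above). Assuming
# RUN_HEADER orders the columns as query_id, q0, doc_id, rank, score, system, one line of
# the resulting ``.run`` file would look like:
#
#   q1    0    doc42    1    12.34    BiEncoderModel
# --------------------------------------------------------------------------------------
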
class SearchCallback(RankCallback, _IndexDirMixin):
    def __init__(
        self,
        search_config: SearchConfig,
        index_dir: Path | str | None = None,
        index_name: str | None = None,
        save_dir: Path | str | None = None,
        run_name: str | None = None,
        overwrite: bool = False,
        use_gpu: bool = True,
    ) -> None:
        """Callback which uses an index to retrieve documents efficiently.

        Args:
            search_config (SearchConfig): Configuration of the :py:class:`~lightning_ir.retrieve.base.searcher.Searcher`.
            index_dir (Path | str | None): Directory where indexes are stored. Defaults to None.
            index_name (str | None): Name of the index. If None, the dataset's dataset_id or file name will be used.
                Defaults to None.
            save_dir (Path | str | None): Directory to save run files to. If None, run files will be saved in the
                model's directory. Defaults to None.
            run_name (str | None): Name of the run file. If None, the dataset's dataset_id or file name will be used.
                Defaults to None.
            overwrite (bool): Whether to skip or overwrite already existing run files. Defaults to False.
            use_gpu (bool): Toggle to use GPU for retrieval. Defaults to True.
        """
        super().__init__(save_dir=save_dir, run_name=run_name, overwrite=overwrite)
        self.search_config = search_config
        self.index_dir = index_dir
        self.index_name = index_name
        self.overwrite = overwrite
        self.use_gpu = use_gpu
        self.searcher: Searcher

    def _get_searcher(self, trainer: Trainer, pl_module: BiEncoderModule, dataset_idx: int) -> Searcher:
        dataloaders = trainer.test_dataloaders
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")
        dataset = dataloaders[dataset_idx].dataset

        index_dir = self._get_index_dir(pl_module, dataset)
        if hasattr(self, "searcher"):
            if self.searcher.index_dir == index_dir:
                return self.searcher
            # free up memory
            del self.searcher
            gc.collect()
            torch.cuda.empty_cache()

        searcher = self.search_config.search_class(index_dir, self.search_config, pl_module, self.use_gpu)
        return searcher

    def _rank(
        self, batch: SearchBatch | RankBatch, output: LightningIROutput
    ) -> tuple[torch.Tensor, list[str], list[int]]:
        if batch.doc_ids is None:
            raise ValueError("BiEncoderModule did not return doc_ids when searching")
        dummy_docs = [[""] * len(ids) for ids in batch.doc_ids]
        batch = RankBatch(batch.queries, dummy_docs, batch.query_ids, batch.doc_ids, batch.qrels)
        return super()._rank(batch, output)

    def on_test_start(self, trainer: Trainer, pl_module: BiEncoderModule) -> None:
        """Hook to validate datasets.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (BiEncoderModule): LightningIR bi-encoder module.
        Raises:
            ValueError: If no test_dataloaders are found.
            ValueError: If not all test datasets are :py:class:`~lightning_ir.data.dataset.QueryDataset`.
        """
        dataloaders = trainer.test_dataloaders
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")
        datasets = [dataloader.dataset for dataloader in dataloaders]
        if not all(isinstance(dataset, (QueryDataset, _DummyIterableDataset)) for dataset in datasets):
            raise ValueError("Expected QueryDatasets for searching")

    def on_test_batch_start(
        self, trainer: Trainer, pl_module: BiEncoderModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Hook to initialize searcher for new datasets.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (BiEncoderModule): LightningIR bi-encoder module.
            batch (Any): Batch of input data.
            batch_idx (int): Index of the batch in the dataset.
            dataloader_idx (int): Index of the dataloader. Defaults to 0.
        """
        if batch_idx == 0:
            self.searcher = self._get_searcher(trainer, pl_module, dataloader_idx)
            pl_module.searcher = self.searcher
        super().on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)

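
# --------------------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): retrieving documents for a
# query set from a previously built index. The model name, dataset id, and the concrete
# SearchConfig subclass (here a FAISS search config) are placeholder assumptions.
#
#   from lightning_ir import (
#       BiEncoderModule, FaissSearchConfig, LightningIRDataModule, LightningIRTrainer,
#       QueryDataset, SearchCallback,
#   )
#
#   module = BiEncoderModule("webis/bert-bi-encoder", evaluation_metrics=["nDCG@10"])
#   data_module = LightningIRDataModule(inference_datasets=[QueryDataset("msmarco-passage/trec-dl-2019/judged")])
#   callback = SearchCallback(search_config=FaissSearchConfig(k=100), index_dir="./indexes", save_dir="./runs")
#   trainer = LightningIRTrainer(callbacks=[callback], logger=False, enable_checkpointing=False)
#   trainer.search(module, data_module)
# --------------------------------------------------------------------------------------
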
class ReRankCallback(RankCallback):
    pass

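
# --------------------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): re-ranking an existing run
# file with a cross-encoder. The model name and run-file path are placeholder assumptions.
#
#   from lightning_ir import CrossEncoderModule, LightningIRDataModule, LightningIRTrainer, ReRankCallback, RunDataset
#
#   module = CrossEncoderModule("webis/monoelectra-base", evaluation_metrics=["nDCG@10"])
#   data_module = LightningIRDataModule(inference_datasets=[RunDataset("path/to/baseline.run")])
#   callback = ReRankCallback(save_dir="./re-ranked-runs")
#   trainer = LightningIRTrainer(callbacks=[callback], logger=False, enable_checkpointing=False)
#   trainer.re_rank(module, data_module)
# --------------------------------------------------------------------------------------
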
class RegisterLocalDatasetCallback(Callback):
    def __init__(
        self,
        dataset_id: str,
        docs: str | None = None,
        queries: str | None = None,
        qrels: str | None = None,
        docpairs: str | None = None,
        scoreddocs: str | None = None,
        qrels_defs: dict[int, str] | None = None,
    ):
        """Registers a local dataset with ``ir_datasets``. After registering the dataset, it can be loaded using
        ``ir_datasets.load(dataset_id)``. Currently, the following (optionally gzipped) file types are supported:

        - ``.tsv``, ``.json``, or ``.jsonl`` for documents and queries
        - ``.tsv`` or ``.qrels`` for qrels
        - ``.tsv`` for training n-tuples
        - ``.tsv`` or ``.run`` for scored documents / run files

        Args:
            dataset_id (str): Dataset id.
            docs (str | None): Path to documents file or valid ir_datasets id from which documents should be taken.
                Defaults to None.
            queries (str | None): Path to queries file or valid ir_datasets id from which queries should be taken.
                Defaults to None.
            qrels (str | None): Path to qrels file or valid ir_datasets id from which qrels will be taken.
                Defaults to None.
            docpairs (str | None): Path to training n-tuple file or valid ir_datasets id from which training tuples
                will be taken. Defaults to None.
            scoreddocs (str | None): Path to run file or valid ir_datasets id from which scored documents will be
                taken. Defaults to None.
            qrels_defs (dict[int, str] | None): Optional dictionary describing the relevance levels of the qrels.
                Defaults to None.
        """
        super().__init__()
        self.dataset_id = dataset_id
        self.docs = docs
        self.queries = queries
        self.qrels = qrels
        self.docpairs = docpairs
        self.scoreddocs = scoreddocs
        self.qrels_defs = qrels_defs

    def setup(self, trainer: Trainer, pl_module: LightningIRModule, stage: str) -> None:
        """Hook that registers the dataset.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (LightningIRModule): LightningIR module.
            stage (str): Stage of the trainer.
        """
        register_new_dataset(
            self.dataset_id,
            docs=self.docs,
            queries=self.queries,
            qrels=self.qrels,
            docpairs=self.docpairs,
            scoreddocs=self.scoreddocs,
            qrels_defs=self.qrels_defs,
        )
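

# --------------------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): registering a local dataset
# so its dataset_id can be used like any other ir_datasets id. The dataset id and file
# paths are placeholder assumptions.
#
#   from lightning_ir import LightningIRDataModule, LightningIRTrainer, QueryDataset, RegisterLocalDatasetCallback
#
#   register_callback = RegisterLocalDatasetCallback(
#       dataset_id="my-collection",
#       docs="local/docs.tsv",
#       queries="local/queries.tsv",
#       qrels="local/qrels.tsv",
#   )
#   data_module = LightningIRDataModule(inference_datasets=[QueryDataset("my-collection")])
#   trainer = LightningIRTrainer(callbacks=[register_callback], logger=False, enable_checkpointing=False)
# --------------------------------------------------------------------------------------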