Source code for lightning_ir.callbacks.callbacks

  1"""Module containing callbacks for indexing, searching, ranking, and registering custom datasets."""
  2
  3from __future__ import annotations
  4
  5import csv
  6import gc
  7import itertools
  8from dataclasses import is_dataclass
  9from pathlib import Path
 10from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, TypeVar
 11
 12import pandas as pd
 13import torch
 14from lightning import Trainer
 15from lightning.pytorch.callbacks import Callback, TQDMProgressBar
 16
 17from ..base.validation_utils import evaluate_run
 18from ..data import LightningIRDataModule, RankBatch, SearchBatch
 19from ..data.dataset import RUN_HEADER, DocDataset, IRDataset, QueryDataset, RunDataset, _DummyIterableDataset
 20from ..data.external_datasets.ir_datasets_utils import register_new_dataset
 21from ..retrieve import IndexConfig, Indexer, SearchConfig, Searcher
 22
 23if TYPE_CHECKING:
 24    from ..base import LightningIRModule, LightningIROutput
 25    from ..bi_encoder import BiEncoderModule, BiEncoderOutput
 26
 27T = TypeVar("T")
 28
 29
 30def _format_large_number(number: float) -> str:
 31    suffixes = ["", "K", "M", "B", "T"]
 32    suffix_index = 0
 33
 34    while number >= 1000 and suffix_index < len(suffixes) - 1:
 35        number /= 1000.0
 36        suffix_index += 1
 37
 38    formatted_number = "{:.2f}".format(number)
 39
 40    suffix = suffixes[suffix_index]
 41    if suffix:
 42        formatted_number += f" {suffix}"
 43    return formatted_number
 44
 45
 46class _GatherMixin:
 47    """Mixin to gather dataclasses across all processes"""
 48
 49    def _gather(self, pl_module: LightningIRModule, dataclass: T) -> T:
 50        if is_dataclass(dataclass):
 51            return dataclass.__class__(
 52                **{k: self._gather(pl_module, getattr(dataclass, k)) for k in dataclass.__dataclass_fields__}
 53            )
 54        return pl_module.all_gather(dataclass)
 55
 56
 57class _IndexDirMixin:
 58    """Mixin to get index_dir"""
 59
 60    index_dir: Path | str | None
 61    index_name: str | None
 62
 63    def _get_index_dir(self, pl_module: BiEncoderModule, dataset: DocDataset) -> Path:
 64        index_dir = self.index_dir
 65        if index_dir is None:
 66            default_index_dir = Path(pl_module.config.name_or_path)
 67            if default_index_dir.exists():
 68                index_dir = default_index_dir / "indexes"
 69            else:
 70                raise ValueError("No index_dir provided and model_name_or_path is not a path")
 71        index_dir = Path(index_dir)
 72        if self.index_name is None:
 73            index_dir = index_dir / dataset.dashed_docs_dataset_id
 74        else:
 75            index_dir = index_dir / self.index_name
 76        return index_dir
 77
 78
 79class _OverwriteMixin:
 80    """Mixin to skip datasets (for indexing or searching) if they already exist"""
 81
 82    _get_save_path: Callable[[LightningIRModule, IRDataset], Path]
 83
 84    def _remove_overwrite_datasets(self, trainer: Trainer, pl_module: LightningIRModule) -> None:
 85        overwrite = getattr(self, "overwrite", False)
 86        if not overwrite:
 87            datamodule: LightningIRDataModule | None = getattr(trainer, "datamodule", None)
 88            if datamodule is None:
 89                raise ValueError("No datamodule found")
 90            if datamodule.inference_datasets is None:
 91                return
 92            inference_datasets = list(datamodule.inference_datasets)
 93            for dataset in inference_datasets:
 94                save_path = self._get_save_path(pl_module, dataset)
 95                if save_path.exists():
 96                    dataset._SKIP = True
 97                    trainer.print(f"`{save_path}` already exists. Set overwrite=True to overwrite")
 98                    if (
 99                        save_path.name.endswith(".run")
100                        and dataset.qrels is not None
101                        and pl_module.evaluation_metrics is not None
102                    ):
103                        run = RunDataset._load_csv(save_path)
104                        qrels = dataset.qrels.stack(future_stack=True).dropna().astype(int).reset_index()
105                        if isinstance(dataset, RunDataset) and dataset.run_path is not None:
106                            dataset_id = dataset.run_path.name
107                        else:
108                            dataset_id = dataset.dataset_id
109                        for key, value in evaluate_run(run, qrels, pl_module.evaluation_metrics).items():
110                            key = f"{dataset_id}/{key}"
111                            pl_module._additional_log_metrics[key] = value
112
    def _cleanup(self, trainer: Trainer, pl_module: LightningIRModule) -> None:
        # reset skip flag and additional log metrics
        datamodule: LightningIRDataModule | None = getattr(trainer, "datamodule", None)
        if datamodule is not None and datamodule.inference_datasets is not None:
            for dataset in datamodule.inference_datasets:
                dataset._SKIP = False
        pl_module._additional_log_metrics = {}

class IndexCallback(Callback, _GatherMixin, _IndexDirMixin, _OverwriteMixin):
    def __init__(
        self,
        index_config: IndexConfig,
        index_dir: Path | str | None = None,
        index_name: str | None = None,
        overwrite: bool = False,
        verbose: bool = False,
    ) -> None:
        """Callback to index documents using an :py:class:`~lightning_ir.retrieve.base.indexer.Indexer`.

        Args:
            index_config (IndexConfig): Configuration for the indexer.
            index_dir (Path | str | None): Directory to save index(es) to. If None, indexes will be stored in the
                model's directory. Defaults to None.
            index_name (str | None): Name of the index. If None, the dataset's dataset_id or file name will be used.
                Defaults to None.
            overwrite (bool): Whether to skip or overwrite already existing indexes. Defaults to False.
            verbose (bool): Toggle verbose output. Defaults to False.
        """
        super().__init__()
        self.index_config = index_config
        self.index_dir = index_dir
        self.index_name = index_name
        self.overwrite = overwrite
        self.verbose = verbose
        self.indexer: Indexer

    def setup(self, trainer: Trainer, pl_module: BiEncoderModule, stage: str) -> None:
        """Hook to setup the callback.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (BiEncoderModule): LightningIR bi-encoder module used for indexing.
            stage (str): Stage of the trainer, must be "test".
        Raises:
            ValueError: If the stage is not "test".
        """
        if stage != "test":
            raise ValueError(f"{self.__class__.__name__} can only be used in test stage")
        self._remove_overwrite_datasets(trainer, pl_module)

    def teardown(self, trainer: Trainer, pl_module: BiEncoderModule, stage: str) -> None:
        """Hook to cleanup the callback.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (BiEncoderModule): LightningIR bi-encoder module used for indexing.
            stage (str): Stage of the trainer.
        """
        self._cleanup(trainer, pl_module)

    def _get_save_path(self, pl_module: BiEncoderModule, dataset: IRDataset) -> Path:
        if not isinstance(dataset, DocDataset):
            raise ValueError("Expected DocDataset for indexing")
        return self._get_index_dir(pl_module, dataset)

    def _get_indexer(self, pl_module: BiEncoderModule, dataloader_idx: int) -> Indexer:
        dataset = pl_module.get_dataset(dataloader_idx)
        if dataset is None:
            raise ValueError("No dataset found to index")
        if not isinstance(dataset, DocDataset):
            raise ValueError("Expected DocDataset for indexing")
        index_dir = self._get_save_path(pl_module, dataset)

        indexer = self.index_config.indexer_class(index_dir, self.index_config, pl_module, self.verbose)
        return indexer

    def _log_to_pg(self, info: Dict[str, Any], trainer: Trainer):
        pg_callback = trainer.progress_bar_callback
        if pg_callback is None or not isinstance(pg_callback, TQDMProgressBar):
            return
        pg = pg_callback.test_progress_bar
        info = {k: _format_large_number(v) for k, v in info.items()}
        if pg is not None:
            pg.set_postfix(info)

    def on_test_start(self, trainer: Trainer, pl_module: BiEncoderModule) -> None:
        """Hook to check that the test datasets are configured correctly.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (BiEncoderModule): LightningIR bi-encoder module.
        Raises:
            ValueError: If no test_dataloaders are found.
            ValueError: If not all test datasets are :py:class:`~lightning_ir.data.dataset.DocDataset`.
        """
        dataloaders = trainer.test_dataloaders
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")
        datasets = [dataloader.dataset for dataloader in dataloaders]
        if not all(isinstance(dataset, (DocDataset, _DummyIterableDataset)) for dataset in datasets):
            raise ValueError("Expected DocDatasets for indexing")

    def on_test_batch_start(
        self, trainer: Trainer, pl_module: BiEncoderModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Hook to setup the indexer between datasets.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (BiEncoderModule): LightningIR bi-encoder module.
            batch (Any): Batch of input data.
            batch_idx (int): Index of batch in the current dataset.
            dataloader_idx (int): Index of the dataloader. Defaults to 0.
        """
        if batch_idx == 0:
            self.indexer = self._get_indexer(pl_module, dataloader_idx)
        super().on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: BiEncoderModule,
        outputs: BiEncoderOutput,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Hook to pass encoded documents to the indexer.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (BiEncoderModule): LightningIR bi-encoder module.
            outputs (BiEncoderOutput): Encoded documents.
            batch (Any): Batch of input data.
            batch_idx (int): Index of batch in the current dataset.
            dataloader_idx (int): Index of the dataloader. Defaults to 0.
        """
        batch = self._gather(pl_module, batch)
        outputs = self._gather(pl_module, outputs)

        if not trainer.is_global_zero:
            return

        self.indexer.add(batch, outputs)
        self._log_to_pg(
            {
                "num_docs": self.indexer.num_docs,
                "num_embeddings": self.indexer.num_embeddings,
            },
            trainer,
        )
        # TODO if dataset length cannot be inferred, num_test_batches is inf and no index is saved
        if batch_idx == trainer.num_test_batches[dataloader_idx] - 1:
            assert hasattr(self, "indexer")
            self.indexer.save()
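
    # Illustrative usage sketch (not part of the original source). The index
    # configuration instance, module, and datamodule names below are assumptions;
    # any IndexConfig subclass and a DocDataset-based datamodule should work:
    #
    #     index_callback = IndexCallback(index_config=some_index_config, index_dir="./indexes")
    #     trainer = Trainer(callbacks=[index_callback])
    #     trainer.test(bi_encoder_module, datamodule=doc_datamodule)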


class RankCallback(Callback, _GatherMixin, _OverwriteMixin):
    def __init__(
        self, save_dir: Path | str | None = None, run_name: str | None = None, overwrite: bool = False
    ) -> None:
        """Callback to write a run file of ranked documents to disk.

        Args:
            save_dir (Path | str | None): Directory to save run files to. If None, run files will be saved in the
                model's directory. Defaults to None.
            run_name (str | None): Name of the run file. If None, the dataset's dataset_id or file name will be used.
                Defaults to None.
            overwrite (bool): Whether to skip or overwrite already existing run files. Defaults to False.
        """
        super().__init__()
        self.save_dir = Path(save_dir) if save_dir is not None else None
        self.run_name = run_name
        self.overwrite = overwrite
        self.run_dfs: List[pd.DataFrame] = []

    def setup(self, trainer: Trainer, pl_module: LightningIRModule, stage: str) -> None:
        """Hook to setup the callback.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (LightningIRModule): LightningIR module.
            stage (str): Stage of the trainer, must be "test".
        Raises:
            ValueError: If the stage is not "test".
            ValueError: If no save_dir is provided and model_name_or_path is not a path (the model is not local).
        """
        if stage != "test":
            raise ValueError(f"{self.__class__.__name__} can only be used in test stage")
        if self.save_dir is None:
            default_save_dir = Path(pl_module.config.name_or_path)
            if default_save_dir.exists():
                self.save_dir = default_save_dir / "runs"
                print(f"Using default save_dir `{self.save_dir}` to save runs")
            else:
                raise ValueError("No save_dir provided and model_name_or_path is not a path")
        self._remove_overwrite_datasets(trainer, pl_module)

    def teardown(self, trainer: Trainer, pl_module: LightningIRModule, stage: str) -> None:
        """Hook to cleanup the callback.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (LightningIRModule): LightningIR module.
            stage (str): Stage of the trainer, must be "test".
        """
        self._cleanup(trainer, pl_module)

    def _get_save_path(self, pl_module: LightningIRModule, dataset: IRDataset) -> Path:
        if self.save_dir is None:
            raise ValueError("No save_dir found; call setup before using this method")
        if self.run_name is not None:
            run_file = self.run_name
        elif isinstance(dataset, QueryDataset):
            run_file = f"{dataset.dataset_id.replace('/', '-')}.run"
        elif isinstance(dataset, RunDataset):
            if dataset.run_path is None:
                run_file = f"{dataset.dataset_id.replace('/', '-')}.run"
            else:
                run_file = f"{dataset.run_path.name.split('.')[0]}.run"
        else:
            raise ValueError("Expected QueryDataset or RunDataset for ranking")
        run_file_path = self.save_dir / run_file
        return run_file_path

    def _rank(self, batch: RankBatch, output: LightningIROutput) -> Tuple[torch.Tensor, List[str], List[int]]:
        scores = output.scores
        if scores is None:
            raise ValueError("Expected output to have scores")
        doc_ids = batch.doc_ids
        if doc_ids is None:
            raise ValueError("Expected batch to have doc_ids")
        scores = scores.view(-1)
        num_docs = [len(_doc_ids) for _doc_ids in doc_ids]
        doc_ids = list(itertools.chain.from_iterable(doc_ids))
        if scores.shape[0] != len(doc_ids):
            raise ValueError("scores and doc_ids must have the same length")
        return scores, doc_ids, num_docs

    def _write_run_dfs(self, trainer: Trainer, pl_module: LightningIRModule, dataloader_idx: int):
        if not trainer.is_global_zero or not self.run_dfs:
            return
        dataloaders = trainer.test_dataloaders
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")
        dataset = pl_module.get_dataset(dataloader_idx)
        if dataset is None:
            raise ValueError("No dataset found to write run file")
        if not isinstance(dataset, (QueryDataset, RunDataset)):
            raise ValueError("Expected QueryDataset or RunDataset for ranking")
        run_file_path = self._get_save_path(pl_module, dataset)
        run_file_path.parent.mkdir(parents=True, exist_ok=True)
        run_df = pd.concat(self.run_dfs, ignore_index=True)
        run_df.to_csv(run_file_path, header=False, index=False, sep="\t", quoting=csv.QUOTE_NONE)

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningIRModule,
        outputs: LightningIROutput,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Hook to aggregate and write ranking to file.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (LightningIRModule): LightningIR module.
            outputs (LightningIROutput): Scored query-document pairs.
            batch (Any): Batch of input data.
            batch_idx (int): Index of batch in the current dataset.
            dataloader_idx (int): Index of the dataloader. Defaults to 0.
        Raises:
            ValueError: If the batch does not have query_ids.
        """
        super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        batch = self._gather(pl_module, batch)
        outputs = self._gather(pl_module, outputs)
        if not trainer.is_global_zero:
            return

        query_ids = batch.query_ids
        if query_ids is None:
            raise ValueError("Expected batch to have query_ids")
        scores, doc_ids, num_docs = self._rank(batch, outputs)
        scores = scores.float().cpu().numpy()

        query_ids = list(
            itertools.chain.from_iterable(itertools.repeat(query_id, num) for query_id, num in zip(query_ids, num_docs))
        )
        run_df = pd.DataFrame(zip(query_ids, doc_ids, scores), columns=["query_id", "doc_id", "score"])
        run_df = run_df.sort_values(["query_id", "score"], ascending=[True, False])
        run_df["rank"] = run_df.groupby("query_id")["score"].rank(ascending=False, method="first").astype(int)
        run_df["q0"] = 0
        run_df["system"] = pl_module.model.__class__.__name__
        run_df = run_df[RUN_HEADER]

        self.run_dfs.append(run_df)

        if batch_idx == trainer.num_test_batches[dataloader_idx] - 1:
            self._write_run_dfs(trainer, pl_module, dataloader_idx)
            self.run_dfs = []
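
    # The run file written at the end of each dataset is a tab-separated,
    # TREC-style run. Assuming RUN_HEADER orders the columns as
    # query_id, q0, doc_id, rank, score, system, a line looks like
    # (values are illustrative):
    #
    #     q1    0    doc42    1    12.34    BiEncoderModel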


class SearchCallback(RankCallback, _IndexDirMixin):
    def __init__(
        self,
        search_config: SearchConfig,
        index_dir: Path | str | None = None,
        index_name: str | None = None,
        save_dir: Path | str | None = None,
        run_name: str | None = None,
        overwrite: bool = False,
        use_gpu: bool = True,
    ) -> None:
        """Callback that uses an index to retrieve documents efficiently.

        Args:
            search_config (SearchConfig): Configuration of the :py:class:`~lightning_ir.retrieve.base.searcher.Searcher`.
            index_dir (Path | str | None): Directory where indexes are stored. Defaults to None.
            index_name (str | None): Name of the index. If None, the dataset's dataset_id or file name will be used.
                Defaults to None.
            save_dir (Path | str | None): Directory to save run files to. If None, run files will be saved in the
                model's directory. Defaults to None.
            run_name (str | None): Name of the run file. If None, the dataset's dataset_id or file name will be used.
                Defaults to None.
            overwrite (bool): Whether to skip or overwrite already existing run files. Defaults to False.
            use_gpu (bool): Toggle to use GPU for retrieval. Defaults to True.
        """
        super().__init__(save_dir=save_dir, run_name=run_name, overwrite=overwrite)
        self.search_config = search_config
        self.index_dir = index_dir
        self.index_name = index_name
        self.overwrite = overwrite
        self.use_gpu = use_gpu
        self.searcher: Searcher

    def _get_searcher(self, trainer: Trainer, pl_module: BiEncoderModule, dataset_idx: int) -> Searcher:
        dataloaders = trainer.test_dataloaders
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")
        dataset = dataloaders[dataset_idx].dataset

        index_dir = self._get_index_dir(pl_module, dataset)
        if hasattr(self, "searcher"):
            if self.searcher.index_dir == index_dir:
                return self.searcher
            # free up memory
            del self.searcher
            gc.collect()
            torch.cuda.empty_cache()

        searcher = self.search_config.search_class(index_dir, self.search_config, pl_module, self.use_gpu)
        return searcher

    def _rank(
        self, batch: SearchBatch | RankBatch, output: LightningIROutput
    ) -> Tuple[torch.Tensor, List[str], List[int]]:
        if batch.doc_ids is None:
            raise ValueError("BiEncoderModule did not return doc_ids when searching")
        dummy_docs = [[""] * len(ids) for ids in batch.doc_ids]
        batch = RankBatch(batch.queries, dummy_docs, batch.query_ids, batch.doc_ids, batch.qrels)
        return super()._rank(batch, output)

    def on_test_start(self, trainer: Trainer, pl_module: BiEncoderModule) -> None:
        """Hook to validate datasets.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (BiEncoderModule): LightningIR bi-encoder module.
        Raises:
            ValueError: If no test_dataloaders are found.
            ValueError: If not all test datasets are :py:class:`~lightning_ir.data.dataset.QueryDataset`.
        """
        dataloaders = trainer.test_dataloaders
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")
        datasets = [dataloader.dataset for dataloader in dataloaders]
        if not all(isinstance(dataset, (QueryDataset, _DummyIterableDataset)) for dataset in datasets):
            raise ValueError("Expected QueryDatasets for searching")

    def on_test_batch_start(
        self, trainer: Trainer, pl_module: BiEncoderModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Hook to initialize searcher for new datasets.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (BiEncoderModule): LightningIR bi-encoder module.
            batch (Any): Batch of input data.
            batch_idx (int): Index of the batch in the dataset.
            dataloader_idx (int): Index of the dataloader. Defaults to 0.
        """
        if batch_idx == 0:
            self.searcher = self._get_searcher(trainer, pl_module, dataloader_idx)
            pl_module.searcher = self.searcher
        super().on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
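
    # Illustrative usage sketch (not part of the original source). The search
    # configuration instance and directory paths are assumptions; any SearchConfig
    # subclass pointing at an existing index works:
    #
    #     search_callback = SearchCallback(
    #         search_config=some_search_config,
    #         index_dir="./indexes",  # directory previously created by an IndexCallback
    #         save_dir="./runs",
    #     )
    #     trainer = Trainer(callbacks=[search_callback])
    #     trainer.test(bi_encoder_module, datamodule=query_datamodule)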


class ReRankCallback(RankCallback):
    pass
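
    # Behaves exactly like RankCallback; the separate name signals its intended
    # use for re-ranking. Illustrative usage sketch (not part of the original
    # source), assuming the test datasets are RunDatasets whose documents are
    # re-scored by the module:
    #
    #     rerank_callback = ReRankCallback(save_dir="./runs")
    #     trainer = Trainer(callbacks=[rerank_callback])
    #     trainer.test(ir_module, datamodule=run_datamodule)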


class RegisterLocalDatasetCallback(Callback):

    def __init__(
        self,
        dataset_id: str,
        docs: str | None = None,
        queries: str | None = None,
        qrels: str | None = None,
        docpairs: str | None = None,
        scoreddocs: str | None = None,
        qrels_defs: Dict[int, str] | None = None,
    ):
        """Registers a local dataset with ``ir_datasets``. After registering the dataset, it can be loaded using
        ``ir_datasets.load(dataset_id)``. Currently, the following (optionally gzipped) file types are supported:

        - ``.tsv``, ``.json``, or ``.jsonl`` for documents and queries
        - ``.tsv`` or ``.qrels`` for qrels
        - ``.tsv`` for training n-tuples
        - ``.tsv`` or ``.run`` for scored documents / run files

        Args:
            dataset_id (str): Dataset id.
            docs (str | None): Path to documents file or valid ir_datasets id from which documents should be taken.
                Defaults to None.
            queries (str | None): Path to queries file or valid ir_datasets id from which queries should be taken.
                Defaults to None.
            qrels (str | None): Path to qrels file or valid ir_datasets id from which qrels will be taken.
                Defaults to None.
            docpairs (str | None): Path to training n-tuple file or valid ir_datasets id from which training tuples
                will be taken. Defaults to None.
            scoreddocs (str | None): Path to run file or valid ir_datasets id from which scored documents will be
                taken. Defaults to None.
            qrels_defs (Dict[int, str] | None): Optional dictionary describing the relevance levels of the qrels.
                Defaults to None.
        """
        super().__init__()
        self.dataset_id = dataset_id
        self.docs = docs
        self.queries = queries
        self.qrels = qrels
        self.docpairs = docpairs
        self.scoreddocs = scoreddocs
        self.qrels_defs = qrels_defs

    def setup(self, trainer: Trainer, pl_module: LightningIRModule, stage: str) -> None:
        """Hook that registers the dataset.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer.
            pl_module (LightningIRModule): LightningIR module.
            stage (str): Stage of the trainer.
        """
        register_new_dataset(
            self.dataset_id,
            docs=self.docs,
            queries=self.queries,
            qrels=self.qrels,
            docpairs=self.docpairs,
            scoreddocs=self.scoreddocs,
            qrels_defs=self.qrels_defs,
        )
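

# Illustrative usage sketch for RegisterLocalDatasetCallback (not part of the
# original source). The file paths are hypothetical; see the __init__ docstring
# above for supported formats:
#
#     register_cb = RegisterLocalDatasetCallback(
#         dataset_id="my-local-dataset",
#         docs="data/docs.tsv",
#         queries="data/queries.tsv",
#         qrels="data/qrels.tsv",
#     )
#     trainer = Trainer(callbacks=[register_cb])
#     # after setup() has run, ir_datasets.load("my-local-dataset") resolves the dataset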