Source code for lightning_ir.data.dataset

  1"""
  2Datasets for Lightning IR that handle data loading and sampling.
  3
  4This module defines several datasets that handle loading and sampling data for training and inference.
  5"""
  6
  7import csv
  8import warnings
  9from itertools import islice
 10from pathlib import Path
 11from typing import Any, Dict, Iterator, Literal, Sequence, Tuple
 12
 13import ir_datasets
 14import numpy as np
 15import pandas as pd
 16import torch
 17from ir_datasets.formats import GenericDoc, GenericDocPair
 18from torch.distributed import get_rank, get_world_size
 19from torch.utils.data import Dataset, IterableDataset, get_worker_info
 20
 21from .data import DocSample, QuerySample, RankSample
 22from .external_datasets.ir_datasets_utils import ScoredDocTuple
 23
 24RUN_HEADER = ["query_id", "q0", "doc_id", "rank", "score", "system"]
 25
 26
 27class _DummyIterableDataset(IterableDataset):
 28    """Dummy iterable dataset to use when all inference datasets are skipped."""
 29
 30    def __iter__(self) -> Iterator:
 31        yield from iter([])
 32
 33
class IRDataset:
    """Base class exposing queries, documents and qrels of a dataset resolved through ``ir_datasets``."""

    _SKIP: bool = False
    """Set to True to skip the dataset during inference."""

    def __init__(self, dataset: str) -> None:
        """Initializes a new IRDataset.

        Args:
            dataset (str): Dataset name.
        """
        super().__init__()
        self._dataset = dataset
        # lazily populated caches, filled on first access of the matching property
        self._queries = None
        self._docs = None
        self._qrels = None

    @property
    def dataset(self) -> str:
        """Dataset name.

        A dashed dataset name that is registered in ir_datasets is translated back to its slashed form; any other
        name is returned unchanged.

        Returns:
            str: Dataset name.
        """
        return self.DASHED_DATASET_MAP.get(self._dataset, self._dataset)

    @property
    def dataset_id(self) -> str:
        """Dataset id.

        Returns:
            str: Dataset id.
        """
        # resolve the dataset once; every access of ``self.ir_dataset`` triggers a new lookup
        ir_dataset = self.ir_dataset
        if ir_dataset is None:
            return self.dataset
        return ir_dataset.dataset_id()

    @property
    def docs_dataset_id(self) -> str:
        """ID of the dataset containing the documents.

        Returns:
            str: Document dataset id.
        """
        return ir_datasets.docs_parent_id(self.dataset_id)

    @property
    def dashed_docs_dataset_id(self) -> str:
        """Dataset id with dashes instead of slashes for the documents dataset.

        Returns:
            str: Document dataset id with dashes.
        """
        return self.docs_dataset_id.replace("/", "-")

    @property
    def ir_dataset(self) -> ir_datasets.Dataset | None:
        """Instance of ir_datasets.Dataset.

        Returns:
            ir_datasets.Dataset | None: Instance of ir_datasets.Dataset or None if the dataset is not found.
        """
        try:
            return ir_datasets.load(self.dataset)
        except KeyError:
            return None

    @property
    def DASHED_DATASET_MAP(self) -> Dict[str, str]:
        """Map of dataset names with dashes to dataset names with slashes.

        Returns:
            Dict[str, str]: Dataset map.
        """
        return {dataset.replace("/", "-"): dataset for dataset in ir_datasets.registry._registered}

    @property
    def queries(self) -> pd.Series:
        """Queries in the dataset.

        Returns:
            pd.Series: Queries.
        Raises:
            ValueError: If no queries are found in the dataset.
        """
        if self._queries is None:
            ir_dataset = self.ir_dataset
            if ir_dataset is None:
                raise ValueError(f"Unable to find dataset {self.dataset} in ir-datasets")
            queries_iter = ir_dataset.queries_iter()
            self._queries = pd.Series(
                {query.query_id: query.default_text() for query in queries_iter},
                name="text",
            )
            self._queries.index.name = "query_id"
        return self._queries

    @property
    def docs(self) -> ir_datasets.indices.Docstore | Dict[str, GenericDoc]:
        """Documents in the dataset.

        Returns:
            ir_datasets.indices.Docstore | Dict[str, GenericDoc]: Documents.
        Raises:
            ValueError: If no documents are found in the dataset.
        """
        if self._docs is None:
            ir_dataset = self.ir_dataset
            if ir_dataset is None:
                raise ValueError(f"Unable to find dataset {self.dataset} in ir-datasets")
            self._docs = ir_dataset.docs_store()
        return self._docs

    @property
    def qrels(self) -> pd.DataFrame | None:
        """Qrels in the dataset.

        Returns:
            pd.DataFrame | None: Qrels.
        """
        if self._qrels is not None:
            return self._qrels
        ir_dataset = self.ir_dataset
        if ir_dataset is None or not ir_dataset.has_qrels():
            return None
        qrels = pd.DataFrame(ir_dataset.qrels_iter()).rename({"subtopic_id": "iteration"}, axis=1)
        if "iteration" not in qrels.columns:
            qrels["iteration"] = 0
        qrels = qrels.drop_duplicates(["query_id", "doc_id", "iteration"])
        # one row per (query_id, doc_id); the iteration level is pivoted into columns
        qrels = qrels.set_index(["query_id", "doc_id", "iteration"]).unstack(level=-1)
        self._qrels = qrels
        return self._qrels

    def prepare_constituent(self, constituent: Literal["qrels", "queries", "docs", "scoreddocs", "docpairs"]) -> None:
        """Downloads the constituent of the dataset using ir_datasets if needed.

        Args:
            constituent (Literal["qrels", "queries", "docs", "scoreddocs", "docpairs"]): Constituent to download.
        """
        ir_dataset = self.ir_dataset
        if ir_dataset is None:
            return
        if ir_dataset.has(constituent):
            if constituent == "docs" and hasattr(ir_dataset, "docs_store"):
                # build docs store if not already built
                docs_store = ir_dataset.docs_store()
                if not docs_store.built():
                    docs_store.build()
            else:
                # get first item to trigger download; the default keeps an empty
                # constituent from raising StopIteration
                next(getattr(ir_dataset, f"{constituent}_iter")(), None)
181 182 183class _DataParallelIterableDataset(IterableDataset): 184 # https://github.com/Lightning-AI/pytorch-lightning/issues/15734 185 def __init__(self) -> None: 186 super().__init__() 187 # TODO add support for multi-gpu and multi-worker inference; currently 188 # doesn't work 189 worker_info = get_worker_info() 190 num_workers = worker_info.num_workers if worker_info is not None else 1 191 worker_id = worker_info.id if worker_info is not None else 0 192 193 try: 194 world_size = get_world_size() 195 process_rank = get_rank() 196 except (RuntimeError, ValueError): 197 world_size = 1 198 process_rank = 0 199 200 self.num_replicas = num_workers * world_size 201 self.rank = process_rank * num_workers + worker_id 202 203
class QueryDataset(IRDataset, _DataParallelIterableDataset):
    def __init__(self, query_dataset: str, num_queries: int | None = None) -> None:
        """Dataset containing queries.

        Args:
            query_dataset (str): Path to file containing queries or valid ir_datasets id.
            num_queries (int | None, optional): Number of queries in dataset. If None, an attempt is made to infer
                the number of queries. Defaults to None.
        """
        super().__init__(query_dataset)
        # explicitly run the initializer that follows IRDataset in the MRO
        super(IRDataset, self).__init__()
        self.num_queries = num_queries

    def __len__(self) -> int | None:
        """Number of queries in the dataset. Returns None if the number of queries cannot be inferred.

        Returns:
            int | None: Number of queries.
        """
        # TODO fix len for multi-gpu and multi-worker inference
        return self.num_queries or getattr(self.ir_dataset, "queries_count", lambda: None)() or None

    def __iter__(self) -> Iterator[QuerySample]:
        """Iterate over queries in the dataset.

        Every ``num_replicas``-th query starting at ``rank`` is yielded, up to ``num_queries`` if it is set.

        Yields:
            QuerySample: Query sample.
        """
        start = self.rank
        stop = self.num_queries
        step = self.num_replicas
        # The qrels do not change while iterating, so look them up once instead of
        # twice per query (each lookup may re-resolve the dataset when no qrels exist).
        all_qrels = self.qrels
        for sample in islice(self.ir_dataset.queries_iter(), start, stop, step):
            query_sample = QuerySample.from_ir_dataset_sample(sample)
            if all_qrels is not None:
                qrels = (
                    all_qrels.loc[[query_sample.query_id]]
                    .stack(future_stack=True)
                    .dropna()
                    .astype(int)
                    .reset_index()
                    .to_dict(orient="records")
                )
                query_sample.qrels = qrels
            yield query_sample

    def prepare_data(self) -> None:
        """Downloads queries using ir_datasets if needed."""
        self.prepare_constituent("queries")
252 253
class DocDataset(IRDataset, _DataParallelIterableDataset):
    def __init__(self, doc_dataset: str, num_docs: int | None = None, text_fields: Sequence[str] | None = None) -> None:
        """Dataset containing documents.

        Args:
            doc_dataset (str): Path to file containing documents or valid ir_datasets id.
            num_docs (int | None, optional): Number of documents in dataset. If None, an attempt is made to infer
                the number of documents. Defaults to None.
            text_fields (Sequence[str] | None, optional): Fields to parse the document text from. Defaults to None.
        """
        super().__init__(doc_dataset)
        super(IRDataset, self).__init__()
        self.num_docs = num_docs
        self.text_fields = text_fields

    def __len__(self) -> int | None:
        """Number of documents in the dataset. Returns None if the number of documents cannot be inferred.

        Returns:
            int | None: Number of documents.
        """
        # TODO fix len for multi-gpu and multi-worker inference
        explicit_count = self.num_docs
        if explicit_count:
            return explicit_count
        count_docs = getattr(self.ir_dataset, "docs_count", lambda: None)
        return count_docs() or None

    def __iter__(self) -> Iterator[DocSample]:
        """Iterate over documents in the dataset.

        Yields:
            DocSample: Document sample.
        """
        # take every ``num_replicas``-th document starting at ``rank``, up to ``num_docs``
        shard = islice(self.ir_dataset.docs_iter(), self.rank, self.num_docs, self.num_replicas)
        for raw_doc in shard:
            yield DocSample.from_ir_dataset_sample(raw_doc, self.text_fields)

    def prepare_data(self) -> None:
        """Downloads documents using ir_datasets if needed."""
        self.prepare_constituent("docs")
293 294
class Sampler:
    """Helper class for sampling subsets of documents from a ranked list."""

    # Names of the static methods that may be selected through ``Sampler.sample``.
    _STRATEGIES = ("single_relevant", "top", "random", "log_random", "top_and_random")

    @staticmethod
    def single_relevant(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to randomly sample a single relevant document. The remaining ``sample_size - 1``
        are non-relevant.

        Args:
            documents (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample.
        Returns:
            pd.DataFrame: Sampled documents.
        """
        # a document counts as relevant if any of its ``relevance*`` columns is positive
        relevance = documents.filter(like="relevance").max(axis=1).fillna(0)
        relevant = documents.loc[relevance.gt(0)].sample(1)
        # only documents that carry a rank are eligible as non-relevant samples
        non_relevant_bool = relevance.eq(0) & ~documents["rank"].isna()
        num_non_relevant = int(non_relevant_bool.sum())
        sample_non_relevant = min(sample_size - 1, num_non_relevant)
        non_relevant = documents.loc[non_relevant_bool].sample(sample_non_relevant)
        return pd.concat([relevant, non_relevant])

    @staticmethod
    def top(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to take the first ``sample_size`` documents of the ranked list.

        Args:
            documents (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample.
        Returns:
            pd.DataFrame: Sampled documents.
        """
        return documents.head(sample_size)

    @staticmethod
    def top_and_random(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to sample half the ``sample_size`` documents from the top of the ranking and the
        other half randomly from the remaining documents.

        Args:
            documents (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample.
        Returns:
            pd.DataFrame: Sampled documents.
        """
        top_size = sample_size // 2
        random_size = sample_size - top_size
        top_docs = documents.head(top_size)
        random_docs = documents.iloc[top_size:].sample(random_size)
        return pd.concat([top_docs, random_docs])

    @staticmethod
    def random(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to randomly sample ``sample_size`` documents.

        Args:
            documents (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample.
        Returns:
            pd.DataFrame: Sampled documents.
        """
        return documents.sample(sample_size)

    @staticmethod
    def log_random(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to randomly sample documents with a higher probability to sample documents from the top of
        the ranking.

        Args:
            documents (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample.
        Returns:
            pd.DataFrame: Sampled documents.
        """
        weights = 1 / np.log1p(documents["rank"])
        # documents without a rank receive the smallest weight
        weights[weights.isna()] = weights.min()
        return documents.sample(sample_size, weights=weights)

    @staticmethod
    def sample(
        df: pd.DataFrame,
        sample_size: int,
        sampling_strategy: Literal["single_relevant", "top", "random", "log_random", "top_and_random"],
    ) -> pd.DataFrame:
        """
        Samples a subset of documents from a ranked list given a sampling_strategy.

        Args:
            df (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample. If -1, ``df`` is returned unchanged.
            sampling_strategy (Literal["single_relevant", "top", "random", "log_random", "top_and_random"]):
                Name of the sampling strategy to apply.
        Returns:
            pd.DataFrame: Sampled documents.
        Raises:
            ValueError: If ``sampling_strategy`` is not a known strategy.
        """
        if sample_size == -1:
            return df
        # Dispatch only to the known strategies: a plain attribute lookup would also
        # accept "sample" (endless recursion) or any other attribute of the class.
        if sampling_strategy in Sampler._STRATEGIES:
            return getattr(Sampler, sampling_strategy)(df, sample_size)
        raise ValueError("Invalid sampling strategy.")
394 395
class RunDataset(IRDataset, Dataset):
    def __init__(
        self,
        run_path_or_id: Path | str,
        depth: int = -1,
        sample_size: int = -1,
        sampling_strategy: Literal["single_relevant", "top", "random", "log_random", "top_and_random"] = "top",
        targets: Literal["relevance", "subtopic_relevance", "rank", "score"] | None = None,
        normalize_targets: bool = False,
        add_docs_not_in_ranking: bool = False,
    ) -> None:
        """Dataset containing a list of queries with a ranked list of documents per query. Subsets of the ranked list
        can be sampled using different sampling strategies.

        Args:
            run_path_or_id (Path | str): Path to a run file or valid ir_datasets id.
            depth (int): Depth at which to cut off the ranking. If -1 the full ranking is kept.
                Defaults to -1.
            sample_size (int): The number of documents to sample per query. Defaults to -1.
            sampling_strategy (Literal["single_relevant", "top", "random", "log_random", "top_and_random"]):
                The sample strategy to use to sample documents. Defaults to "top".
            targets (Literal["relevance", "subtopic_relevance", "rank", "score"] | None):
                The data type to use as targets for a model during fine-tuning. If "relevance" the relevance
                judgements are parsed from qrels. Defaults to None.
            normalize_targets (bool): Whether to normalize the targets between 0 and 1. Defaults to False.
            add_docs_not_in_ranking (bool): Whether to add relevant documents to a sample that are in the qrels but not
                in the ranking. Defaults to False.
        """
        self.run_path = None
        if Path(run_path_or_id).is_file():
            self.run_path = Path(run_path_or_id)
            # the dataset name is the part of the file name after the last "__" and before the first "."
            dataset = self.run_path.name.split(".")[0].split("__")[-1]
        else:
            dataset = str(run_path_or_id)
        super().__init__(dataset)
        self.depth = depth
        self.sample_size = sample_size
        self.sampling_strategy = sampling_strategy
        self.targets = targets
        self.normalize_targets = normalize_targets
        self.add_docs_not_in_ranking = add_docs_not_in_ranking

        # depth == -1 means "no cut-off", so only compare against an actual depth
        if self.sampling_strategy == "top" and self.depth != -1 and self.sample_size > self.depth:
            warnings.warn(
                "Sample size is greater than depth and top sampling strategy is used. "
                "This can cause documents to be sampled that are not contained "
                "in the run file, but that are present in the qrels."
            )

        self.run: pd.DataFrame | None = None

    def prepare_data(self) -> None:
        """Downloads docs, queries, scoreddocs, and qrels using ir_datasets if needed and available."""
        self.prepare_constituent("docs")
        self.prepare_constituent("queries")
        if self.run_path is None:
            self.prepare_constituent("scoreddocs")
        self.prepare_constituent("qrels")

    def _setup(self):
        """Loads the run on first use, joins it with the qrels and groups it by query. No-op afterwards."""
        if self.run is not None:
            return
        self.run = self._load_run()
        self.run = self.run.drop_duplicates(["query_id", "doc_id"])

        # NOTE: the first access of ``self.qrels`` may move a ``relevance`` column out of ``self.run``
        qrels = self.qrels
        if qrels is not None:
            run_query_ids = pd.Index(self.run["query_id"].drop_duplicates())
            qrels_query_ids = qrels.index.get_level_values("query_id").unique()
            query_ids = run_query_ids.intersection(qrels_query_ids)
            if len(run_query_ids.difference(qrels_query_ids)):
                self.run = self.run[self.run["query_id"].isin(query_ids)]
            # outer join if docs are from ir_datasets else only keep docs in run
            how = "left"
            if self._docs is None and self.add_docs_not_in_ranking:
                how = "outer"
            self.run = self.run.merge(
                qrels.loc[pd.IndexSlice[query_ids, :]].droplevel(0, axis=1).add_prefix("relevance_", axis=1),
                on=["query_id", "doc_id"],
                how=how,
            )

        if self.sample_size != -1:
            # drop queries that have fewer documents than requested per sample
            num_docs_per_query = self.run.groupby("query_id").transform("size")
            self.run = self.run[num_docs_per_query >= self.sample_size]

        self.run = self.run.sort_values(["query_id", "rank"])
        self.run_groups = self.run.groupby("query_id")
        self.query_ids = list(self.run_groups.groups.keys())

        if self.depth != -1 and self.run["rank"].max() < self.depth:
            warnings.warn("Depth is greater than the maximum rank in the run file.")

    @staticmethod
    def _load_csv(path: Path) -> pd.DataFrame:
        """Reads a whitespace-separated, header-less run file; the ``system`` column is not loaded."""
        return pd.read_csv(
            path,
            sep=r"\s+",
            header=None,
            names=RUN_HEADER,
            usecols=[0, 1, 2, 3, 4],
            dtype={"query_id": str, "doc_id": str},
            quoting=csv.QUOTE_NONE,
            na_filter=False,
        )

    @staticmethod
    def _load_parquet(path: Path) -> pd.DataFrame:
        """Reads a run stored as a parquet file."""
        return pd.read_parquet(path)

    @staticmethod
    def _load_json(path: Path) -> pd.DataFrame:
        """Reads a run stored as JSON, or as JSON lines if ``.jsonl`` is among the file suffixes."""
        kwargs: Dict[str, Any] = {}
        if ".jsonl" in path.suffixes:
            kwargs["lines"] = True
            kwargs["orient"] = "records"
        run = pd.read_json(path, **kwargs)
        return run

    def _get_run_path(self) -> Path | None:
        """Returns the path of the run file, asking the ir_datasets scoreddocs handler if no file was given.

        Raises:
            ValueError: If no run file was given and the dataset has no scoreddocs.
        """
        run_path = self.run_path
        if run_path is None:
            ir_dataset = self.ir_dataset
            if ir_dataset is None or not ir_dataset.has_scoreddocs():
                raise ValueError(f"Run file or dataset with scoreddocs required. Got {self._dataset}")
            try:
                run_path = ir_dataset.scoreddocs_handler().scoreddocs_path()
            except NotImplementedError:
                # no path exposed by the handler; the caller falls back to iterating the scoreddocs
                pass
        return run_path

    def _clean_run(self, run: pd.DataFrame) -> pd.DataFrame:
        """Normalizes column names and dtypes, extracts inline queries / document texts and applies ``depth``."""
        run = run.rename(
            {"qid": "query_id", "docid": "doc_id", "docno": "doc_id", "Q0": "iteration", "q0": "iteration"},
            axis=1,
        )
        dtypes = {"rank": np.int32, "query_id": str, "doc_id": str}
        if "score" in run.columns:
            dtypes["score"] = np.float32
        run = run.astype(dtypes)
        if "query" in run.columns:
            # queries shipped inside the run take the place of the ir_datasets queries
            self._queries = run.drop_duplicates("query_id").set_index("query_id")["query"].rename("text")
            run = run.drop("query", axis=1)
        if "text" in run.columns:
            # document texts shipped inside the run take the place of the ir_datasets docs store
            self._docs = run.set_index("doc_id")["text"].map(lambda x: GenericDoc("", x)).to_dict()
            run = run.drop("text", axis=1)
        if self.depth != -1:
            run = run[run["rank"] <= self.depth]
        return run

    def _load_run(self) -> pd.DataFrame:
        """Loads the run from the run file or, failing that, from the ir_datasets scoreddocs.

        Raises:
            ValueError: If the run can be loaded from neither source. If reading the run file failed, that error is
                attached as the cause.
        """
        suffix_load_map = {
            ".tsv": self._load_csv,
            ".run": self._load_csv,
            ".csv": self._load_csv,
            ".parquet": self._load_parquet,
            ".json": self._load_json,
            ".jsonl": self._load_json,
        }
        run = None
        load_error: Exception | None = None

        # try loading run from file
        run_path = self._get_run_path()
        if run_path is not None:
            suffixes = run_path.suffixes
            # a file without any suffix has no known loader
            load_func = suffix_load_map.get(suffixes[0]) if suffixes else None
            if load_func is not None:
                try:
                    run = load_func(run_path)
                except Exception as exc:
                    # best effort: fall through to ir_datasets, but remember why the file could not be read
                    load_error = exc

        # try loading run from ir_datasets
        if run is None:
            ir_dataset = self.ir_dataset
            if ir_dataset is not None and ir_dataset.has_scoreddocs():
                run = pd.DataFrame(ir_dataset.scoreddocs_iter())
                run["rank"] = run.groupby("query_id")["score"].rank("first", ascending=False)
                run = run.sort_values(["query_id", "rank"])

        if run is None:
            raise ValueError("Invalid run file format.") from load_error

        run = self._clean_run(run)
        return run

    @property
    def qrels(self) -> pd.DataFrame | None:
        """The qrels in the dataset. If the dataset does not contain qrels, the qrels are None.

        If the loaded run has a ``relevance`` column, the qrels are built from the run and the ``relevance`` and
        ``iteration`` columns are removed from it.

        Returns:
            pd.DataFrame | None: Qrels.
        """
        if self._qrels is not None:
            return self._qrels
        if self.run is not None and "relevance" in self.run:
            qrels = self.run[["query_id", "doc_id", "relevance"]].copy()
            if "iteration" in self.run:
                qrels["iteration"] = self.run["iteration"]
            else:
                qrels["iteration"] = "0"
            self.run = self.run.drop(["relevance", "iteration"], axis=1, errors="ignore")
            qrels = qrels.drop_duplicates(["query_id", "doc_id", "iteration"])
            qrels = qrels.set_index(["query_id", "doc_id", "iteration"]).unstack(level=-1)
            self._qrels = qrels
            return self._qrels
        return super().qrels

    def __len__(self) -> int:
        """Number of queries in the dataset.

        Returns:
            int: Number of queries.
        """
        self._setup()
        return len(self.query_ids)

    def __getitem__(self, idx: int) -> RankSample:
        """Samples a single query and corresponding ranked documents from the run. The documents are sampled according
        to the sampling strategy and sample size.

        Args:
            idx (int): Index of the query.
        Returns:
            RankSample: Sampled query and documents.
        Raises:
            ValueError: If the targets are not found in the run file.
        """
        self._setup()
        query_id = str(self.query_ids[idx])
        group = self.run_groups.get_group(query_id).copy()
        query = self.queries[query_id]
        group = Sampler.sample(group, self.sample_size, self.sampling_strategy)

        doc_ids = tuple(group["doc_id"])
        docs = tuple(self.docs.get(doc_id).default_text() for doc_id in doc_ids)

        targets = None
        if self.targets is not None:
            # select every column whose name contains the requested target type
            filtered = group.set_index("doc_id").loc[list(doc_ids)].filter(like=self.targets).fillna(0)
            if filtered.empty:
                raise ValueError(f"targets `{self.targets}` not found in run file")
            targets = torch.from_numpy(filtered.values)
            if self.targets == "rank":
                # invert ranks to be higher is better (necessary for loss functions)
                targets = self.depth - targets + 1
            if self.normalize_targets:
                targets_min = targets.min()
                targets_range = targets.max() - targets_min
                # constant targets would divide by zero and produce NaNs; divide by one instead (-> all zeros)
                targets_range = torch.where(targets_range == 0, torch.ones_like(targets_range), targets_range)
                targets = (targets - targets_min) / targets_range
        qrels = None
        all_qrels = self.qrels
        if all_qrels is not None:
            qrels = (
                all_qrels.loc[[query_id]]
                .stack(future_stack=True)
                .dropna()
                .astype(int)
                .reset_index()
                .to_dict(orient="records")
            )
        return RankSample(query_id, query, doc_ids, docs, targets, qrels)
653 654
class TupleDataset(IRDataset, IterableDataset):
    def __init__(
        self,
        tuples_dataset: str,
        targets: Literal["order", "score"] = "order",
        num_docs: int | None = None,
    ) -> None:
        """Dataset containing tuples of a query and n-documents. Used for fine-tuning models on ranking tasks.

        Args:
            tuples_dataset (str): Path to file containing tuples or valid ir_datasets id.
            targets (Literal["order", "score"], optional): Data type to use as targets for a model during fine-tuning.
                Defaults to "order".
            num_docs (int | None, optional): Maximum number of documents per query. Defaults to None.
        """
        super().__init__(tuples_dataset)
        super(IRDataset, self).__init__()
        self.targets = targets
        self.num_docs = num_docs

    def _parse_sample(
        self, sample: ScoredDocTuple | GenericDocPair
    ) -> Tuple[Tuple[str, ...], Tuple[str, ...], Tuple[float, ...] | None]:
        """Extracts document ids, document texts and targets from a docpair or scored tuple sample.

        Raises:
            ValueError: If the sample type or the configured targets are not supported, or if score targets are
                requested but the sample carries no scores.
        """
        limit = self.num_docs
        if isinstance(sample, GenericDocPair):
            # plain pairs carry no scores, so only order targets can be derived from them
            if self.targets == "score":
                raise ValueError("ScoredDocTuple required for score targets.")
            doc_ids = (sample.doc_id_a, sample.doc_id_b)
            targets = (1.0, 0.0)
        elif isinstance(sample, ScoredDocTuple):
            doc_ids = sample.doc_ids[:limit]
            target_type = self.targets
            if target_type == "order":
                # the first document gets target 1.0, every other document 0.0
                targets = (1.0,) + (0.0,) * (sample.num_docs - 1)
            elif target_type == "score":
                if sample.scores is None:
                    raise ValueError("tuples dataset does not contain scores")
                targets = sample.scores
            else:
                raise ValueError(f"invalid value for targets, got {self.targets}, expected one of (order, score)")
            targets = targets[:limit]
        else:
            raise ValueError("Invalid sample type.")
        doc_store = self.docs
        docs = tuple(doc_store.get(doc_id).default_text() for doc_id in doc_ids)
        return doc_ids, docs, targets

    def __iter__(self) -> Iterator[RankSample]:
        """Iterates over tuples in the dataset.

        Yields:
            RankSample: Sampled query and documents with targets.
        """
        for sample in self.ir_dataset.docpairs_iter():
            query_id = sample.query_id
            query = self.queries.loc[query_id]
            doc_ids, docs, raw_targets = self._parse_sample(sample)
            targets = None if raw_targets is None else torch.tensor(raw_targets)
            yield RankSample(query_id, query, doc_ids, docs, targets)

    def prepare_data(self) -> None:
        """Downloads tuples using ir_datasets if needed."""
        for constituent in ("docs", "queries", "docpairs"):
            self.prepare_constituent(constituent)