Source code for lightning_ir.data.dataset

  1"""
  2Datasets for Lightning IR that handle data loading and sampling.
  3
  4This module defines several datasets that handle loading and sampling data for training and inference.
  5"""
  6
  7import csv
  8import warnings
  9from itertools import islice
 10from pathlib import Path
 11from typing import Any, Dict, Iterator, Literal, Sequence, Tuple
 12
 13import ir_datasets
 14import numpy as np
 15import pandas as pd
 16import torch
 17from ir_datasets.formats import GenericDoc, GenericDocPair
 18from torch.distributed import get_rank, get_world_size
 19from torch.utils.data import Dataset, IterableDataset, get_worker_info
 20
 21from .data import DocSample, QuerySample, RankSample
 22from .external_datasets.ir_datasets_utils import ScoredDocTuple
 23
 24RUN_HEADER = ["query_id", "q0", "doc_id", "rank", "score", "system"]
 25
 26
 27class _DummyIterableDataset(IterableDataset):
 28    """Dummy iterable dataset to use when all inference datasets are skipped."""
 29
 30    def __iter__(self) -> Iterator:
 31        yield from iter([])
 32
 33
class IRDataset:
    """Base class giving lazy access to the queries, documents, and qrels of an ir_datasets dataset.

    Queries, documents, and qrels are loaded on first access and cached on the instance. The underlying
    ``ir_datasets.Dataset`` itself is not cached: every access to :attr:`ir_dataset` calls ``ir_datasets.load``.
    """

    _SKIP: bool = False
    """Set to True to skip the dataset during inference."""

    def __init__(self, dataset: str) -> None:
        """Initializes a new IRDataset.

        Args:
            dataset (str): Dataset name. Registered ir_datasets ids may also be given with their slashes replaced by
                dashes; such names are resolved back to the registered id by :attr:`dataset`.
        """
        super().__init__()
        self._dataset = dataset
        self._queries = None
        self._docs = None
        self._qrels = None

    @property
    def dataset(self) -> str:
        """Dataset name, with dashed aliases of registered ir_datasets ids resolved to the registered id.

        Returns:
            str: Dataset name.
        """
        return self.DASHED_DATASET_MAP.get(self._dataset, self._dataset)

    @property
    def dataset_id(self) -> str:
        """Dataset id.

        Returns:
            str: The id reported by ir_datasets, or :attr:`dataset` if the dataset is not known to ir_datasets.
        """
        ir_dataset = self.ir_dataset
        if ir_dataset is None:
            return self.dataset
        return ir_dataset.dataset_id()

    @property
    def docs_dataset_id(self) -> str:
        """ID of the dataset containing the documents.

        Returns:
            str: Document dataset id.
        """
        return ir_datasets.docs_parent_id(self.dataset_id)

    @property
    def ir_dataset(self) -> ir_datasets.Dataset | None:
        """Instance of ir_datasets.Dataset.

        Not cached: every access calls ``ir_datasets.load``, so bind the result to a local variable when it is
        needed more than once.

        Returns:
            ir_datasets.Dataset | None: Instance of ir_datasets.Dataset or None if the dataset is not found.
        """
        try:
            return ir_datasets.load(self.dataset)
        except KeyError:
            return None

    @property
    def DASHED_DATASET_MAP(self) -> Dict[str, str]:
        """Map of dataset names with dashes to dataset names with slashes.

        Rebuilt from the ir_datasets registry on every access.

        Returns:
            Dict[str, str]: Dataset map.
        """
        return {dataset.replace("/", "-"): dataset for dataset in ir_datasets.registry._registered}

    @property
    def queries(self) -> pd.Series:
        """Queries in the dataset, indexed by ``query_id`` and cached after the first access.

        Returns:
            pd.Series: Query texts (series named ``"text"``).
        Raises:
            ValueError: If the dataset cannot be found in ir_datasets.
        """
        if self._queries is None:
            ir_dataset = self.ir_dataset
            if ir_dataset is None:
                raise ValueError(f"Unable to find dataset {self.dataset} in ir-datasets")
            self._queries = pd.Series(
                {query.query_id: query.default_text() for query in ir_dataset.queries_iter()},
                name="text",
            )
            self._queries.index.name = "query_id"
        return self._queries

    @property
    def docs(self) -> ir_datasets.indices.Docstore | Dict[str, GenericDoc]:
        """Documents in the dataset, cached after the first access.

        Returns:
            ir_datasets.indices.Docstore | Dict[str, GenericDoc]: Documents.
        Raises:
            ValueError: If the dataset cannot be found in ir_datasets.
        """
        if self._docs is None:
            ir_dataset = self.ir_dataset
            if ir_dataset is None:
                raise ValueError(f"Unable to find dataset {self.dataset} in ir-datasets")
            self._docs = ir_dataset.docs_store()
        return self._docs

    @property
    def qrels(self) -> pd.DataFrame | None:
        """Qrels in the dataset, cached after the first successful load.

        The frame is indexed by ``(query_id, doc_id)``. Its columns are the remaining qrel fields unstacked by
        ``iteration``. ``subtopic_id`` is used as the iteration if present; otherwise, if no iteration exists,
        a constant iteration ``0`` is used.

        Returns:
            pd.DataFrame | None: Qrels, or None if the dataset is unknown or has no qrels.
        """
        if self._qrels is not None:
            return self._qrels
        ir_dataset = self.ir_dataset
        if ir_dataset is None or not ir_dataset.has_qrels():
            return None
        qrels = pd.DataFrame(ir_dataset.qrels_iter()).rename({"subtopic_id": "iteration"}, axis=1)
        if "iteration" not in qrels.columns:
            qrels["iteration"] = 0
        qrels = qrels.drop_duplicates(["query_id", "doc_id", "iteration"])
        qrels = qrels.set_index(["query_id", "doc_id", "iteration"]).unstack(level=-1)
        self._qrels = qrels
        return self._qrels

    def prepare_constituent(self, constituent: Literal["qrels", "queries", "docs", "scoreddocs", "docpairs"]) -> None:
        """Downloads the constituent of the dataset using ir_datasets if needed.

        Does nothing if the dataset is unknown to ir_datasets or does not provide the constituent.

        Args:
            constituent (Literal["qrels", "queries", "docs", "scoreddocs", "docpairs"]): Constituent to download.
        """
        ir_dataset = self.ir_dataset
        if ir_dataset is None or not ir_dataset.has(constituent):
            return
        if constituent == "docs" and hasattr(ir_dataset, "docs_store"):
            # build docs store if not already built
            docs_store = ir_dataset.docs_store()
            if not docs_store.built():
                docs_store.build()
        else:
            # get first item to trigger download; an empty constituent must not raise StopIteration
            next(getattr(ir_dataset, f"{constituent}_iter")(), None)
172 173 174class _DataParallelIterableDataset(IterableDataset): 175 # https://github.com/Lightning-AI/pytorch-lightning/issues/15734 176 def __init__(self) -> None: 177 super().__init__() 178 # TODO add support for multi-gpu and multi-worker inference; currently 179 # doesn't work 180 worker_info = get_worker_info() 181 num_workers = worker_info.num_workers if worker_info is not None else 1 182 worker_id = worker_info.id if worker_info is not None else 0 183 184 try: 185 world_size = get_world_size() 186 process_rank = get_rank() 187 except (RuntimeError, ValueError): 188 world_size = 1 189 process_rank = 0 190 191 self.num_replicas = num_workers * world_size 192 self.rank = process_rank * num_workers + worker_id 193 194
class QueryDataset(IRDataset, _DataParallelIterableDataset):
    """Iterable dataset over the queries of an ir_datasets dataset.

    Queries are split across replicas using ``rank`` / ``num_replicas``. If the dataset has qrels, every query that
    appears in them is yielded with its relevance judgements attached.
    """

    def __init__(self, query_dataset: str, num_queries: int | None = None) -> None:
        """Dataset containing queries.

        Args:
            query_dataset (str): Path to file containing queries or valid ir_datasets id.
            num_queries (int | None, optional): Number of queries in dataset. If None, the number of queries is
                inferred from ir_datasets if possible. Defaults to None.
        """
        super().__init__(query_dataset)
        super(IRDataset, self).__init__()
        self.num_queries = num_queries

    def __len__(self) -> int | None:
        """Number of queries in the dataset. Returns None if the number of queries cannot be inferred.

        Returns:
            int | None: Number of queries.
        """
        # TODO fix len for multi-gpu and multi-worker inference
        return self.num_queries or getattr(self.ir_dataset, "queries_count", lambda: None)() or None

    def __iter__(self) -> Iterator[QuerySample]:
        """Iterate over queries in the dataset.

        Yields:
            QuerySample: Query sample. If the query appears in the qrels, its ``qrels`` attribute is set to a list
            of record dicts built from the query's qrels rows; otherwise it is left untouched.
        Raises:
            ValueError: If the dataset cannot be found in ir_datasets.
        """
        start = self.rank
        stop = self.num_queries
        step = self.num_replicas
        ir_dataset = self.ir_dataset
        if ir_dataset is None:
            raise ValueError(f"Unable to find dataset {self.dataset} in ir-datasets")
        # Resolve the qrels once. When the dataset has no qrels, IRDataset.qrels does not cache its result, so a
        # per-query access would reload the ir_datasets dataset for every query.
        qrels = self.qrels
        # ``.loc[[query_id]]`` raises KeyError for a query without judgements, so only look up judged queries.
        judged_query_ids = set() if qrels is None else set(qrels.index.get_level_values("query_id"))
        for sample in islice(ir_dataset.queries_iter(), start, stop, step):
            query_sample = QuerySample.from_ir_dataset_sample(sample)
            if query_sample.query_id in judged_query_ids:
                query_sample.qrels = (
                    qrels.loc[[query_sample.query_id]]
                    .stack(future_stack=True)
                    .dropna()
                    .astype(int)
                    .reset_index()
                    .to_dict(orient="records")
                )
            yield query_sample

    def prepare_data(self) -> None:
        """Downloads queries using ir_datasets if needed."""
        self.prepare_constituent("queries")
243 244
class DocDataset(IRDataset, _DataParallelIterableDataset):
    """Iterable dataset over the documents of an ir_datasets dataset, split across replicas by rank."""

    def __init__(self, doc_dataset: str, num_docs: int | None = None, text_fields: Sequence[str] | None = None) -> None:
        """Dataset containing documents.

        Args:
            doc_dataset (str): Path to file containing documents or valid ir_datasets id.
            num_docs (int | None, optional): Number of documents in dataset. If None, the number of documents is
                inferred from ir_datasets if possible. Defaults to None.
            text_fields (Sequence[str] | None, optional): Fields to parse the document text from; forwarded to
                ``DocSample.from_ir_dataset_sample``. Defaults to None.
        """
        super().__init__(doc_dataset)
        super(IRDataset, self).__init__()
        self.text_fields = text_fields
        self.num_docs = num_docs

    def __len__(self) -> int | None:
        """Number of documents in the dataset, or None if it cannot be determined.

        Returns:
            int | None: Number of documents.
        """
        # TODO fix len for multi-gpu and multi-worker inference
        if self.num_docs:
            return self.num_docs
        count_docs = getattr(self.ir_dataset, "docs_count", None)
        inferred = count_docs() if count_docs is not None else None
        return inferred or None

    def __iter__(self) -> Iterator[DocSample]:
        """Iterate over this replica's share of the documents.

        Yields:
            DocSample: Document sample.
        """
        offset, limit, stride = self.rank, self.num_docs, self.num_replicas
        shard = islice(self.ir_dataset.docs_iter(), offset, limit, stride)
        yield from (DocSample.from_ir_dataset_sample(doc, self.text_fields) for doc in shard)

    def prepare_data(self) -> None:
        """Downloads documents using ir_datasets if needed."""
        self.prepare_constituent("docs")
284 285
class Sampler:
    """Helper class for sampling subsets of documents from a ranked list."""

    # Names of the static methods that implement a sampling strategy; Sampler.sample only dispatches to these.
    _STRATEGIES = frozenset({"single_relevant", "top", "random", "log_random", "top_and_random"})

    @staticmethod
    def single_relevant(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to randomly sample a single relevant document. The remaining ``sample_size - 1``
        are non-relevant.

        A document is relevant if the maximum of its columns whose name contains ``"relevance"`` is greater than 0
        (missing values count as 0). Non-relevant documents must have a ``rank``. If fewer than
        ``sample_size - 1`` of them exist, all of them are returned.

        Args:
            documents (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample.
        Returns:
            pd.DataFrame: The relevant document followed by the sampled non-relevant documents.
        Raises:
            ValueError: If ``documents`` contains no relevant document (raised by ``DataFrame.sample``).
        """
        relevance = documents.filter(like="relevance").max(axis=1).fillna(0)
        relevant = documents.loc[relevance.gt(0)].sample(1)
        non_relevant_bool = relevance.eq(0) & ~documents["rank"].isna()
        num_non_relevant = non_relevant_bool.sum()
        sample_non_relevant = min(sample_size - 1, num_non_relevant)
        non_relevant = documents.loc[non_relevant_bool].sample(sample_non_relevant)
        return pd.concat([relevant, non_relevant])

    @staticmethod
    def top(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy that keeps the first ``sample_size`` documents in their current order.

        This is the top of the ranking when ``documents`` is sorted by rank.

        Args:
            documents (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample.
        Returns:
            pd.DataFrame: Sampled documents.
        """
        return documents.head(sample_size)

    @staticmethod
    def top_and_random(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to take half the ``sample_size`` documents from the top of the ranking and to
        randomly sample the other half from the remaining documents.

        Args:
            documents (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample.
        Returns:
            pd.DataFrame: Sampled documents.
        """
        top_size = sample_size // 2
        random_size = sample_size - top_size
        top = documents.head(top_size)
        random = documents.iloc[top_size:].sample(random_size)
        return pd.concat([top, random])

    @staticmethod
    def random(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to randomly sample ``sample_size`` documents.

        Args:
            documents (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample.
        Returns:
            pd.DataFrame: Sampled documents.
        """
        return documents.sample(sample_size)

    @staticmethod
    def log_random(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to randomly sample documents with a higher probability to sample documents from the top of
        the ranking.

        Each document is weighted by ``1 / log1p(rank)``; documents without a rank receive the smallest weight.

        Args:
            documents (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample.
        Returns:
            pd.DataFrame: Sampled documents.
        """
        # NOTE(review): a rank of 0 yields an infinite weight, which pandas rejects; ranks appear to be 1-based.
        weights = 1 / np.log1p(documents["rank"])
        weights[weights.isna()] = weights.min()
        return documents.sample(sample_size, weights=weights)

    @staticmethod
    def sample(
        df: pd.DataFrame,
        sample_size: int,
        sampling_strategy: Literal["single_relevant", "top", "random", "log_random", "top_and_random"],
    ) -> pd.DataFrame:
        """Samples a subset of documents from a ranked list given a sampling_strategy.

        Args:
            df (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample. If -1, ``df`` is returned unchanged.
            sampling_strategy (Literal["single_relevant", "top", "random", "log_random", "top_and_random"]):
                Name of the sampling strategy to apply.
        Returns:
            pd.DataFrame: Sampled documents.
        Raises:
            ValueError: If ``sampling_strategy`` is not a supported strategy.
        """
        if sample_size == -1:
            return df
        # Only dispatch to real strategies; a plain hasattr check would also accept e.g. "sample" or "__init__".
        if sampling_strategy not in Sampler._STRATEGIES:
            raise ValueError("Invalid sampling strategy.")
        return getattr(Sampler, sampling_strategy)(df, sample_size)
385 386
class RunDataset(IRDataset, Dataset):
    """Map-style dataset over the queries of a run, yielding a sampled ranked list of documents per query."""

    def __init__(
        self,
        run_path_or_id: Path | str,
        depth: int = -1,
        sample_size: int = -1,
        sampling_strategy: Literal["single_relevant", "top", "random", "log_random", "top_and_random"] = "top",
        targets: Literal["relevance", "subtopic_relevance", "rank", "score"] | None = None,
        normalize_targets: bool = False,
        add_docs_not_in_ranking: bool = False,
    ) -> None:
        """Dataset containing a list of queries with a ranked list of documents per query. Subsets of the ranked list
        can be sampled using different sampling strategies.

        Args:
            run_path_or_id (Path | str): Path to a run file or valid ir_datasets id. For a run file, the dataset id is
                taken from the file name (the part before the first ``.`` and after the last ``__``).
            depth (int): Depth at which to cut off the ranking. If -1 the full ranking is kept.
                Defaults to -1.
            sample_size (int): The number of documents to sample per query. If -1 all documents are kept.
                Defaults to -1.
            sampling_strategy (Literal["single_relevant", "top", "random", "log_random", "top_and_random"]):
                The sample strategy to use to sample documents. Defaults to "top".
            targets (Literal["relevance", "subtopic_relevance", "rank", "score"] | None):
                The data type to use as targets for a model during fine-tuning. If "relevance" the relevance
                judgements are parsed from qrels. Defaults to None.
            normalize_targets (bool): Whether to normalize the targets between 0 and 1. Defaults to False.
            add_docs_not_in_ranking (bool): Whether to add relevant documents to a sample that are in the qrels but not
                in the ranking. Defaults to False.
        """
        self.run_path = None
        if Path(run_path_or_id).is_file():
            self.run_path = Path(run_path_or_id)
            dataset = self.run_path.name.split(".")[0].split("__")[-1]
        else:
            dataset = str(run_path_or_id)
        super().__init__(dataset)
        self.depth = depth
        self.sample_size = sample_size
        self.sampling_strategy = sampling_strategy
        self.targets = targets
        self.normalize_targets = normalize_targets
        self.add_docs_not_in_ranking = add_docs_not_in_ranking

        # depth == -1 keeps the full ranking, so the sample size can only exceed the depth when a cut-off is set
        if self.sampling_strategy == "top" and self.depth != -1 and self.sample_size > self.depth:
            warnings.warn(
                "Sample size is greater than depth and top sampling strategy is used. "
                "This can cause documents to be sampled that are not contained "
                "in the run file, but that are present in the qrels."
            )

        self.run: pd.DataFrame | None = None

    def prepare_data(self) -> None:
        """Downloads docs, queries, scoreddocs, and qrels using ir_datasets if needed and available."""
        self.prepare_constituent("docs")
        self.prepare_constituent("queries")
        if self.run_path is None:
            self.prepare_constituent("scoreddocs")
        self.prepare_constituent("qrels")

    def _setup(self):
        """Loads and prepares the run on first use; later calls return immediately.

        Drops duplicate ``(query_id, doc_id)`` pairs and, if qrels exist, restricts the run to queries with
        qrels and merges the qrels in as ``relevance_<iteration>`` columns. Then drops queries with fewer than
        ``sample_size`` documents and groups the run by query.
        """
        if self.run is not None:
            return
        self.run = self._load_run()
        self.run = self.run.drop_duplicates(["query_id", "doc_id"])

        if self.qrels is not None:
            run_query_ids = pd.Index(self.run["query_id"].drop_duplicates())
            qrels_query_ids = self.qrels.index.get_level_values("query_id").unique()
            query_ids = run_query_ids.intersection(qrels_query_ids)
            if len(run_query_ids.difference(qrels_query_ids)):
                self.run = self.run[self.run["query_id"].isin(query_ids)]
            # outer join if docs are from ir_datasets else only keep docs in run
            how = "left"
            if self._docs is None and self.add_docs_not_in_ranking:
                how = "outer"
            self.run = self.run.merge(
                self.qrels.loc[pd.IndexSlice[query_ids, :]].droplevel(0, axis=1).add_prefix("relevance_", axis=1),
                on=["query_id", "doc_id"],
                how=how,
            )

        if self.sample_size != -1:
            num_docs_per_query = self.run.groupby("query_id").transform("size")
            self.run = self.run[num_docs_per_query >= self.sample_size]

        self.run = self.run.sort_values(["query_id", "rank"])
        self.run_groups = self.run.groupby("query_id")
        self.query_ids = list(self.run_groups.groups.keys())

        if self.depth != -1 and self.run["rank"].max() < self.depth:
            warnings.warn("Depth is greater than the maximum rank in the run file.")

    @staticmethod
    def _load_csv(path: Path) -> pd.DataFrame:
        # whitespace-separated run file; only the first five RUN_HEADER columns are read
        return pd.read_csv(
            path,
            sep=r"\s+",
            header=None,
            names=RUN_HEADER,
            usecols=[0, 1, 2, 3, 4],
            dtype={"query_id": str, "doc_id": str},
            quoting=csv.QUOTE_NONE,
            na_filter=False,
        )

    @staticmethod
    def _load_parquet(path: Path) -> pd.DataFrame:
        return pd.read_parquet(path).rename(
            {
                "qid": "query_id",
                "docid": "doc_id",
                "docno": "doc_id",
            },
            axis=1,
        )

    @staticmethod
    def _load_json(path: Path) -> pd.DataFrame:
        kwargs: Dict[str, Any] = {}
        if ".jsonl" in path.suffixes:
            kwargs["lines"] = True
            kwargs["orient"] = "records"
        run = pd.read_json(
            path,
            **kwargs,
            dtype={
                "query_id": str,
                "qid": str,
                "doc_id": str,
                "docid": str,
                "docno": str,
            },
        )
        return run

    def _get_run_path(self) -> Path | None:
        """Returns the user-provided run file, or the scoreddocs file of the ir_datasets dataset if available.

        Raises:
            ValueError: If neither a run file nor a dataset with scoreddocs is available.
        """
        run_path = self.run_path
        if run_path is None:
            if self.ir_dataset is None or not self.ir_dataset.has_scoreddocs():
                raise ValueError(f"Run file or dataset with scoreddocs required. Got {self._dataset}")
            try:
                run_path = self.ir_dataset.scoreddocs_handler().scoreddocs_path()
            except NotImplementedError:
                pass
        return run_path

    def _clean_run(self, run: pd.DataFrame) -> pd.DataFrame:
        """Normalizes column names, extracts inline query/document texts, applies the depth cut-off, and casts dtypes."""
        run = run.rename(
            {"qid": "query_id", "docid": "doc_id", "docno": "doc_id", "Q0": "iteration", "q0": "iteration"},
            axis=1,
        )
        if "query" in run.columns:
            self._queries = run.drop_duplicates("query_id").set_index("query_id")["query"].rename("text")
            run = run.drop("query", axis=1)
        if "text" in run.columns:
            self._docs = run.set_index("doc_id")["text"].map(lambda x: GenericDoc("", x)).to_dict()
            run = run.drop("text", axis=1)
        if self.depth != -1:
            run = run[run["rank"] <= self.depth]
        dtypes = {"rank": np.int32}
        if "score" in run.columns:
            dtypes["score"] = np.float32
        run = run.astype(dtypes)
        return run

    def _load_run(self) -> pd.DataFrame:
        """Loads the run from file, falling back to the scoreddocs of the ir_datasets dataset.

        Raises:
            ValueError: If no run could be loaded; a file-loading error, if any, is chained as the cause.
        """
        suffix_load_map = {
            ".tsv": self._load_csv,
            ".run": self._load_csv,
            ".csv": self._load_csv,
            ".parquet": self._load_parquet,
            ".json": self._load_json,
            ".jsonl": self._load_json,
        }
        run = None
        load_error: Exception | None = None

        # try loading run from file
        run_path = self._get_run_path()
        if run_path is not None:
            # first recognised suffix: handles compressed files (e.g. ``.tsv.gz``) and paths without any suffix
            load_func = next(
                (suffix_load_map[suffix] for suffix in run_path.suffixes if suffix in suffix_load_map),
                None,
            )
            if load_func is not None:
                try:
                    run = load_func(run_path)
                except Exception as error:  # best effort: fall back to ir_datasets below
                    load_error = error

        # try loading run from ir_datasets
        if run is None and self.ir_dataset is not None and self.ir_dataset.has_scoreddocs():
            run = pd.DataFrame(self.ir_dataset.scoreddocs_iter())
            run["rank"] = run.groupby("query_id")["score"].rank("first", ascending=False)
            run = run.sort_values(["query_id", "rank"])

        if run is None:
            raise ValueError("Invalid run file format.") from load_error

        run = self._clean_run(run)
        return run

    @property
    def qrels(self) -> pd.DataFrame | None:
        """The qrels in the dataset. If the dataset does not contain qrels, the qrels are None.

        If the loaded run has a ``relevance`` column, the qrels are built from the run and the ``relevance`` and
        ``iteration`` columns are removed from the run; otherwise the ir_datasets qrels are used.

        Returns:
            pd.DataFrame | None: Qrels.
        """
        if self._qrels is not None:
            return self._qrels
        if self.run is not None and "relevance" in self.run:
            qrels = self.run[["query_id", "doc_id", "relevance"]].copy()
            if "iteration" in self.run:
                qrels["iteration"] = self.run["iteration"]
            else:
                qrels["iteration"] = "0"
            self.run = self.run.drop(["relevance", "iteration"], axis=1, errors="ignore")
            qrels = qrels.drop_duplicates(["query_id", "doc_id", "iteration"])
            qrels = qrels.set_index(["query_id", "doc_id", "iteration"]).unstack(level=-1)
            self._qrels = qrels
            return self._qrels
        return super().qrels

    def __len__(self) -> int:
        """Number of queries in the dataset.

        Returns:
            int: Number of queries.
        """
        self._setup()
        return len(self.query_ids)

    def __getitem__(self, idx: int) -> RankSample:
        """Samples a single query and corresponding ranked documents from the run. The documents are sampled according
        to the sampling strategy and sample size.

        Args:
            idx (int): Index of the query.
        Returns:
            RankSample: Sampled query and documents.
        Raises:
            ValueError: If the targets are not found in the run file.
        """
        self._setup()
        query_id = str(self.query_ids[idx])
        group = self.run_groups.get_group(query_id).copy()
        query = self.queries[query_id]
        group = Sampler.sample(group, self.sample_size, self.sampling_strategy)

        doc_ids = tuple(group["doc_id"])
        docs = tuple(self.docs.get(doc_id).default_text() for doc_id in doc_ids)

        targets = None
        if self.targets is not None:
            filtered = group.set_index("doc_id").loc[list(doc_ids)].filter(like=self.targets).fillna(0)
            if filtered.empty:
                raise ValueError(f"targets `{self.targets}` not found in run file")
            targets = torch.from_numpy(filtered.values)
            if self.targets == "rank":
                # invert ranks to be higher is better (necessary for loss functions)
                # NOTE(review): with depth == -1 this yields -rank, i.e. negative targets; confirm intended.
                targets = self.depth - targets + 1
            if self.normalize_targets:
                targets_min = targets.min()
                targets_max = targets.max()
                span = targets_max - targets_min
                if span == 0:
                    # all targets equal: avoid 0 / 0 (NaN targets); every target normalizes to 0
                    span = torch.ones_like(span)
                targets = (targets - targets_min) / span
        qrels = None
        if self.qrels is not None:
            qrels = (
                self.qrels.loc[[query_id]]
                .stack(future_stack=True)
                .dropna()
                .astype(int)
                .reset_index()
                .to_dict(orient="records")
            )
        return RankSample(query_id, query, doc_ids, docs, targets, qrels)
661 662
class TupleDataset(IRDataset, IterableDataset):
    """Iterable dataset that turns the docpairs of an ir_datasets dataset into rank samples with targets."""

    def __init__(
        self,
        tuples_dataset: str,
        targets: Literal["order", "score"] = "order",
        num_docs: int | None = None,
    ) -> None:
        """Dataset containing tuples of a query and n-documents. Used for fine-tuning models on ranking tasks.

        Args:
            tuples_dataset (str): Path to file containing tuples or valid ir_datasets id.
            targets (Literal["order", "score"], optional): Data type to use as targets for a model during fine-tuning.
                Defaults to "order".
            num_docs (int | None, optional): Maximum number of documents per query. Defaults to None.
        """
        super().__init__(tuples_dataset)
        super(IRDataset, self).__init__()
        self.targets = targets
        self.num_docs = num_docs

    def _parse_sample(
        self, sample: ScoredDocTuple | GenericDocPair
    ) -> Tuple[Tuple[str, ...], Tuple[str, ...], Tuple[float, ...] | None]:
        """Extracts document ids, document texts, and targets from one docpairs sample.

        A GenericDocPair always gets targets ``(1.0, 0.0)``. For a ScoredDocTuple, the ids and targets are
        truncated to ``num_docs``; the targets are either its scores or a one-hot "first document wins" order.

        Raises:
            ValueError: For unsupported sample types, for score targets without scores, or for an unknown
                ``targets`` value on a ScoredDocTuple.
        """
        if not isinstance(sample, (GenericDocPair, ScoredDocTuple)):
            raise ValueError("Invalid sample type.")

        if isinstance(sample, GenericDocPair):
            if self.targets == "score":
                raise ValueError("ScoredDocTuple required for score targets.")
            doc_ids = (sample.doc_id_a, sample.doc_id_b)
            targets = (1.0, 0.0)
        else:
            doc_ids = sample.doc_ids[: self.num_docs]
            if self.targets == "order":
                targets = (1.0,) + (0.0,) * (sample.num_docs - 1)
            elif self.targets == "score":
                if sample.scores is None:
                    raise ValueError("tuples dataset does not contain scores")
                targets = sample.scores
            else:
                raise ValueError(f"invalid value for targets, got {self.targets}, expected one of (order, score)")
            targets = targets[: self.num_docs]

        docs = tuple(self.docs.get(doc_id).default_text() for doc_id in doc_ids)
        return doc_ids, docs, targets

    def __iter__(self) -> Iterator[RankSample]:
        """Iterates over tuples in the dataset.

        Yields:
            RankSample: Sampled query and documents with targets.
        """
        for sample in self.ir_dataset.docpairs_iter():
            query_id = sample.query_id
            query = self.queries.loc[query_id]
            doc_ids, docs, raw_targets = self._parse_sample(sample)
            tensor_targets = None if raw_targets is None else torch.tensor(raw_targets)
            yield RankSample(query_id, query, doc_ids, docs, tensor_targets)

    def prepare_data(self) -> None:
        """Downloads tuples using ir_datasets if needed."""
        for constituent in ("docs", "queries", "docpairs"):
            self.prepare_constituent(constituent)