Source code for lightning_ir.data.dataset

  1"""
  2Datasets for Lightning IR that handle data loading and sampling.
  3
  4This module defines several datasets that handle loading and sampling data for training and inference.
  5"""
  6
  7import csv
  8import warnings
  9from collections.abc import Iterator, Sequence
 10from itertools import islice
 11from pathlib import Path
 12from typing import Any, Literal, Self
 13
 14import ir_datasets
 15import numpy as np
 16import pandas as pd
 17import torch
 18from ir_datasets.formats import GenericDoc, GenericDocPair
 19from torch.distributed import get_rank, get_world_size
 20from torch.utils.data import Dataset, IterableDataset, get_worker_info
 21
 22from .data import DocSample, QuerySample, RankSample
 23from .external_datasets.ir_datasets_utils import ScoredDocTuple
 24from .external_datasets.lsr_benchmark import register_lsr_dataset_to_ir_datasets
 25
# Column names applied to whitespace-separated run files read by ``RunDataset._load_csv``
# (which only loads the first five columns).
RUN_HEADER = ["query_id", "q0", "doc_id", "rank", "score", "system"]
 27
 28
 29class _DummyIterableDataset(IterableDataset):
 30    """Dummy iterable dataset to use when all inference datasets are skipped."""
 31
 32    def __iter__(self) -> Iterator:
 33        yield from iter([])
 34
 35
class IRDataset:
    """Base class giving access to the queries, documents, and qrels of an ``ir_datasets`` dataset.

    The dataset is identified by name. Queries, documents, and qrels are loaded on first access and cached on
    the instance in ``_queries``, ``_docs``, and ``_qrels``.
    """

    _SKIP: bool = False
    """set to True to skip the dataset during inference."""

    def __init__(self, dataset: str) -> None:
        """Initializes a new IRDataset.

        Args:
            dataset (str): Dataset name.
        """
        super().__init__()
        self._dataset = dataset
        # Lazily populated caches, filled by the ``queries``, ``docs``, and ``qrels`` properties.
        self._queries = None
        self._docs = None
        self._qrels = None

    @property
    def dataset(self) -> str:
        """Dataset name.

        A name written with dashes is translated to its registered slash-separated form via
        ``DASHED_DATASET_MAP``; any other name is returned unchanged.

        Returns:
            str: Dataset name.
        """
        return self.DASHED_DATASET_MAP.get(self._dataset, self._dataset)

    @property
    def dataset_id(self) -> str:
        """Dataset id.

        Returns:
            str: The id reported by ir_datasets, or the dataset name if the dataset is not found in ir_datasets.
        """
        ir_dataset = self.ir_dataset
        if ir_dataset is None:
            return self.dataset
        return ir_dataset.dataset_id()

    @property
    def docs_dataset_id(self) -> str:
        """ID of the dataset containing the documents.

        Returns:
            str: Document dataset id.
        """
        return ir_datasets.docs_parent_id(self.dataset_id)

    @property
    def dashed_docs_dataset_id(self) -> str:
        """Dataset id with dashes instead of slashes for the documents dataset.

        Returns:
            str: Document dataset id with dashes.
        """
        return self.docs_dataset_id.replace("/", "-")

    @property
    def ir_dataset(self) -> ir_datasets.Dataset | None:
        """Instance of ir_datasets.Dataset.

        Returns:
            ir_datasets.Dataset | None: Instance of ir_datasets.Dataset or None if the dataset is not found.
        """
        try:
            return ir_datasets.load(self.dataset)
        except KeyError:
            if self.dataset and self.dataset.startswith("lsr-benchmark/"):
                # lsr-benchmark datasets are registered on demand, then the load is retried once.
                register_lsr_dataset_to_ir_datasets(self.dataset)
                try:
                    return ir_datasets.load(self.dataset)
                except KeyError:
                    return None
            return None

    @property
    def DASHED_DATASET_MAP(self) -> dict[str, str]:
        """Map of dataset names with dashes to dataset names with slashes.

        The map is rebuilt from the ir_datasets registry on every access.

        Returns:
            dict[str, str]: Dataset map.
        """
        return {dataset.replace("/", "-"): dataset for dataset in ir_datasets.registry._registered}

    @property
    def queries(self) -> pd.Series:
        """Queries in the dataset.

        Returns:
            pd.Series: Query texts named ``text`` and indexed by ``query_id``.
        Raises:
            ValueError: If the dataset cannot be found in ir-datasets.
        """
        if self._queries is None:
            ir_dataset = self.ir_dataset
            if ir_dataset is None:
                raise ValueError(f"Unable to find dataset {self.dataset} in ir-datasets")
            queries_iter = ir_dataset.queries_iter()
            self._queries = pd.Series(
                {query.query_id: query.default_text() for query in queries_iter},
                name="text",
            )
            self._queries.index.name = "query_id"
        return self._queries

    @property
    def docs(self) -> ir_datasets.indices.Docstore | dict[str, GenericDoc]:
        """Documents in the dataset.

        Returns:
            ir_datasets.indices.Docstore | dict[str, GenericDoc]: Documents.
        Raises:
            ValueError: If the dataset cannot be found in ir-datasets.
        """
        if self._docs is None:
            ir_dataset = self.ir_dataset
            if ir_dataset is None:
                raise ValueError(f"Unable to find dataset {self.dataset} in ir-datasets")
            self._docs = ir_dataset.docs_store()
        return self._docs

    @property
    def qrels(self) -> pd.DataFrame | None:
        """Qrels in the dataset.

        Returns:
            pd.DataFrame | None: Qrels indexed by ``query_id`` and ``doc_id`` with the ``iteration`` level
            unstacked into columns, or None if the dataset is not found or has no qrels.
        """
        if self._qrels is not None:
            return self._qrels
        ir_dataset = self.ir_dataset
        if ir_dataset is None or not ir_dataset.has_qrels():
            return None
        qrels = pd.DataFrame(ir_dataset.qrels_iter()).rename({"subtopic_id": "iteration"}, axis=1)
        if "iteration" not in qrels.columns:
            qrels["iteration"] = 0
        qrels = qrels.drop_duplicates(["query_id", "doc_id", "iteration"])
        qrels = qrels.set_index(["query_id", "doc_id", "iteration"]).unstack(level=-1)
        self._qrels = qrels
        return self._qrels

    def prepare_constituent(self, constituent: Literal["qrels", "queries", "docs", "scoreddocs", "docpairs"]) -> None:
        """Downloads the constituent of the dataset using ir_datasets if needed.

        Does nothing if the dataset is not found in ir_datasets or does not provide the constituent.

        Args:
            constituent (Literal["qrels", "queries", "docs", "scoreddocs", "docpairs"]): Constituent to download.
        """
        ir_dataset = self.ir_dataset
        if ir_dataset is None or not ir_dataset.has(constituent):
            return
        if constituent == "docs" and hasattr(ir_dataset, "docs_store"):
            # build docs store if not already built
            docs_store = ir_dataset.docs_store()
            if not docs_store.built():
                docs_store.build()
        else:
            # Get the first item to trigger the download. The ``None`` default keeps an empty constituent
            # from leaking ``StopIteration`` to the caller.
            next(getattr(ir_dataset, f"{constituent}_iter")(), None)
class _DataParallelIterableDataset(IterableDataset):
    """Iterable dataset base that records which share of the data an instance should iterate over.

    ``num_replicas`` is the number of dataloader workers multiplied by the distributed world size and ``rank``
    is the position of this instance among those replicas. Both are computed once, when the instance is created.
    """

    # https://github.com/Lightning-AI/pytorch-lightning/issues/15734
    def __init__(self) -> None:
        super().__init__()
        # TODO add support for multi-gpu and multi-worker inference; currently
        # doesn't work
        worker_info = get_worker_info()
        if worker_info is None:
            num_workers, worker_id = 1, 0
        else:
            num_workers, worker_id = worker_info.num_workers, worker_info.id

        try:
            world_size, process_rank = get_world_size(), get_rank()
        except (RuntimeError, ValueError):
            # Fall back to a single process when the distributed queries fail.
            world_size, process_rank = 1, 0

        self.num_replicas = world_size * num_workers
        self.rank = worker_id + process_rank * num_workers
class QueryDataset(IRDataset, _DataParallelIterableDataset):
    def __init__(self, query_dataset: str, num_queries: int | None = None) -> None:
        """Dataset containing queries.

        Args:
            query_dataset (str): Path to file containing queries or valid ir_datasets id.
            num_queries (int | None, optional): Number of queries in dataset. If None, an attempt is made to
                infer the number of queries. Defaults to None.
        """
        super().__init__(query_dataset)
        super(IRDataset, self).__init__()
        self.num_queries = num_queries

    def __len__(self) -> int | None:
        """Number of queries in the dataset. Returns None if the number of queries cannot be inferred.

        Returns:
            int | None: Number of queries.
        """
        # TODO fix len for multi-gpu and multi-worker inference
        if self.num_queries:
            return self.num_queries
        count_queries = getattr(self.ir_dataset, "queries_count", lambda: None)
        return count_queries() or None

    def __iter__(self) -> Iterator[QuerySample]:
        """Iterate over queries in the dataset.

        Starts at index ``rank``, advances in steps of ``num_replicas``, and stops at ``num_queries`` if it is
        set. When qrels are available, the qrels of the query are attached to the sample.

        Yields:
            QuerySample: Query sample.
        """
        first = self.rank
        limit = self.num_queries
        stride = self.num_replicas
        for raw_query in islice(self.ir_dataset.queries_iter(), first, limit, stride):
            query_sample = QuerySample.from_ir_dataset_sample(raw_query)
            if self.qrels is not None:
                query_qrels = self.qrels.loc[[query_sample.query_id]]
                long_qrels = query_qrels.stack(future_stack=True).dropna().astype(int)
                query_sample.qrels = long_qrels.reset_index().to_dict(orient="records")
            yield query_sample

    def prepare_data(self) -> None:
        """Downloads queries using ir_datasets if needed."""
        self.prepare_constituent("queries")
259 260
class DocDataset(IRDataset, _DataParallelIterableDataset):
    def __init__(self, doc_dataset: str, num_docs: int | None = None, text_fields: Sequence[str] | None = None) -> None:
        """Dataset containing documents.

        Args:
            doc_dataset (str): Path to file containing documents or valid ir_datasets id.
            num_docs (int | None, optional): Number of documents in dataset. If None, an attempt is made to
                infer the number of documents. Defaults to None.
            text_fields (Sequence[str] | None, optional): Fields to parse the document text from. Defaults to None.
        """
        super().__init__(doc_dataset)
        super(IRDataset, self).__init__()
        self.num_docs = num_docs
        self.text_fields = text_fields

    def __len__(self) -> int | None:
        """Number of documents in the dataset. Returns None if the number of documents cannot be inferred.

        Returns:
            int | None: Number of documents.
        """
        # TODO fix len for multi-gpu and multi-worker inference
        if self.num_docs:
            return self.num_docs
        count_docs = getattr(self.ir_dataset, "docs_count", lambda: None)
        return count_docs() or None

    def __iter__(self) -> Iterator[DocSample]:
        """Iterate over documents in the dataset.

        Starts at index ``rank``, advances in steps of ``num_replicas``, and stops at ``num_docs`` if it is set.

        Yields:
            DocSample: Document sample.
        """
        first = self.rank
        limit = self.num_docs
        stride = self.num_replicas
        for raw_doc in islice(self.ir_dataset.docs_iter(), first, limit, stride):
            yield DocSample.from_ir_dataset_sample(raw_doc, self.text_fields)

    def prepare_data(self) -> None:
        """Downloads documents using ir_datasets if needed."""
        self.prepare_constituent("docs")
300 301
class Sampler:
    """Helper class for sampling subsets of documents from a ranked list."""

    @staticmethod
    def single_relevant(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to randomly sample a single relevant document. The remaining ``sample_size - 1``
        are non-relevant.

        A document is relevant if the maximum over its ``relevance`` columns is greater than 0. Non-relevant
        documents are those with relevance 0 that have a rank. ``documents`` must contain at least one relevant
        document, since exactly one is drawn from that subset. Fewer than ``sample_size`` documents are returned
        if there are not enough non-relevant documents.

        Args:
            documents (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample.
        Returns:
            pd.DataFrame: Sampled documents, the relevant document first.
        """
        relevance = documents.filter(like="relevance").max(axis=1).fillna(0)
        relevant = documents.loc[relevance.gt(0)].sample(1)
        non_relevant_bool = relevance.eq(0) & ~documents["rank"].isna()
        num_non_relevant = non_relevant_bool.sum()
        sample_non_relevant = min(sample_size - 1, num_non_relevant)
        non_relevant = documents.loc[non_relevant_bool].sample(sample_non_relevant)
        return pd.concat([relevant, non_relevant])

    @staticmethod
    def top(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to select the first ``sample_size`` documents, i.e., the top of the ranking when
        ``documents`` is sorted by rank.

        Args:
            documents (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to select.
        Returns:
            pd.DataFrame: Selected documents.
        """
        return documents.head(sample_size)

    @staticmethod
    def top_and_random(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to take half the ``sample_size`` documents from the top of the ranking and to sample
        the other half randomly from the remaining documents.

        Args:
            documents (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample.
        Returns:
            pd.DataFrame: Sampled documents, the top documents first.
        """
        top_size = sample_size // 2
        random_size = sample_size - top_size
        top = documents.head(top_size)
        random = documents.iloc[top_size:].sample(random_size)
        return pd.concat([top, random])

    @staticmethod
    def random(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to randomly sample ``sample_size`` documents.

        Args:
            documents (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample.
        Returns:
            pd.DataFrame: Sampled documents.
        """
        return documents.sample(sample_size)

    @staticmethod
    def log_random(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to randomly sample documents with a higher probability to sample documents from the top of
        the ranking.

        The sampling weight of a document is ``1 / log(1 + rank)``. Documents without a rank receive the smallest
        weight of the ranked documents.

        Args:
            documents (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample.
        Returns:
            pd.DataFrame: Sampled documents.
        """
        weights = 1 / np.log1p(documents["rank"])
        weights[weights.isna()] = weights.min()
        return documents.sample(sample_size, weights=weights)

    @staticmethod
    def sample(
        df: pd.DataFrame,
        sample_size: int,
        sampling_strategy: Literal["single_relevant", "top", "random", "log_random", "top_and_random"],
    ) -> pd.DataFrame:
        """
        Samples a subset of documents from a ranked list given a sampling_strategy.

        Args:
            df (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample. If -1, ``df`` is returned unchanged.
            sampling_strategy (Literal["single_relevant", "top", "random", "log_random", "top_and_random"]):
                Name of the sampling strategy, i.e., of the static method of this class to apply.
        Returns:
            pd.DataFrame: Sampled documents.
        Raises:
            ValueError: If the sampling strategy is invalid.
        """
        if sample_size == -1:
            return df
        # Strategies are looked up by name on the class. Private/dunder attributes and ``sample`` itself are
        # attributes of the class but not strategies, so they are rejected instead of being called.
        is_strategy_name = (
            isinstance(sampling_strategy, str)
            and not sampling_strategy.startswith("_")
            and sampling_strategy != "sample"
        )
        strategy = getattr(Sampler, sampling_strategy, None) if is_strategy_name else None
        if not callable(strategy):
            raise ValueError("Invalid sampling strategy.")
        return strategy(df, sample_size)
401 402
class RunDataset(IRDataset, Dataset):
    def __init__(
        self,
        run_path_or_id: Path | str,
        depth: int = -1,
        sample_size: int = -1,
        sampling_strategy: Literal["single_relevant", "top", "random", "log_random", "top_and_random"] = "top",
        targets: Literal["relevance", "subtopic_relevance", "rank", "score"] | None = None,
        normalize_targets: bool = False,
        add_docs_not_in_ranking: bool = False,
    ) -> None:
        """Dataset containing a list of queries with a ranked list of documents per query. Subsets of the ranked list
        can be sampled using different sampling strategies.

        Args:
            run_path_or_id (Path | str): Path to a run file or valid ir_datasets id.
            depth (int): Depth at which to cut off the ranking. If -1 the full ranking is kept.
                Defaults to -1.
            sample_size (int): The number of documents to sample per query. Defaults to -1.
            sampling_strategy (Literal["single_relevant", "top", "random", "log_random", "top_and_random"]):
                The sample strategy to use to sample documents. Defaults to "top".
            targets (Literal["relevance", "subtopic_relevance", "rank", "score"] | None):
                The data type to use as targets for a model during fine-tuning. If "relevance" the relevance
                judgements are parsed from qrels. Defaults to None.
            normalize_targets (bool): Whether to normalize the targets between 0 and 1. Defaults to False.
            add_docs_not_in_ranking (bool): Whether to add relevant documents to a sample that are in the qrels but not
                in the ranking. Defaults to False.
        """
        self.run_path = None
        if Path(run_path_or_id).is_file():
            self.run_path = Path(run_path_or_id)
            # The dataset name is the part of the file name before the first "." and after the last "__".
            dataset = self.run_path.name.split(".")[0].split("__")[-1]
        else:
            dataset = str(run_path_or_id)
        super().__init__(dataset)
        self.depth = depth
        self.sample_size = sample_size
        self.sampling_strategy = sampling_strategy
        self.targets = targets
        self.normalize_targets = normalize_targets
        self.add_docs_not_in_ranking = add_docs_not_in_ranking

        # A depth of -1 keeps the full ranking, so the sample size cannot exceed it; only warn for real cutoffs.
        if self.sampling_strategy == "top" and self.depth != -1 and self.sample_size > self.depth:
            warnings.warn(
                "Sample size is greater than depth and top sampling strategy is used. "
                "This can cause documents to be sampled that are not contained "
                "in the run file, but that are present in the qrels.",
                stacklevel=2,
            )

        # Loaded lazily by ``_setup`` on first use.
        self.run: pd.DataFrame | None = None

    def prepare_data(self) -> Self:
        """Downloads docs, queries, scoreddocs, and qrels using ir_datasets if needed and available.

        Scoreddocs are only prepared when no run file was given.

        Returns:
            Self: This dataset.
        """
        self.prepare_constituent("docs")
        self.prepare_constituent("queries")
        if self.run_path is None:
            self.prepare_constituent("scoreddocs")
        self.prepare_constituent("docpairs")
        self.prepare_constituent("qrels")
        return self

    def _setup(self):
        """Loads the run on first use, joins it with the qrels, and groups it by query. No-op once loaded."""
        if self.run is not None:
            return
        self.run = self._load_run()
        self.run = self.run.drop_duplicates(["query_id", "doc_id"])

        if self.qrels is not None:
            run_query_ids = pd.Index(self.run["query_id"].drop_duplicates())
            qrels_query_ids = self.qrels.index.get_level_values("query_id").unique()
            query_ids = run_query_ids.intersection(qrels_query_ids)
            if len(run_query_ids.difference(qrels_query_ids)):
                # drop queries of the run that have no qrels
                self.run = self.run[self.run["query_id"].isin(query_ids)]
            # outer join if docs are from ir_datasets else only keep docs in run
            how = "left"
            if self._docs is None and self.add_docs_not_in_ranking:
                how = "outer"
            self.run = self.run.merge(
                self.qrels.loc[pd.IndexSlice[query_ids, :]].droplevel(0, axis=1).add_prefix("relevance_", axis=1),
                on=["query_id", "doc_id"],
                how=how,
            )

        if self.sample_size != -1:
            # only keep queries with at least ``sample_size`` documents
            num_docs_per_query = self.run.groupby("query_id").transform("size")
            self.run = self.run[num_docs_per_query >= self.sample_size]

        self.run = self.run.sort_values(["query_id", "rank"])
        self.run_groups = self.run.groupby("query_id")
        self.query_ids = list(self.run_groups.groups.keys())

        if self.depth != -1 and self.run["rank"].max() < self.depth:
            warnings.warn("Depth is greater than the maximum rank in the run file.", stacklevel=2)

    @staticmethod
    def _load_csv(path: Path) -> pd.DataFrame:
        """Reads the first five columns of a header-less, whitespace-separated run file."""
        return pd.read_csv(
            path,
            sep=r"\s+",
            header=None,
            names=RUN_HEADER,
            usecols=[0, 1, 2, 3, 4],
            dtype={"query_id": str, "doc_id": str},
            quoting=csv.QUOTE_NONE,
            na_filter=False,
        )

    @staticmethod
    def _load_parquet(path: Path) -> pd.DataFrame:
        """Reads a run from a parquet file."""
        return pd.read_parquet(path)

    @staticmethod
    def _load_json(path: Path) -> pd.DataFrame:
        """Reads a run from a JSON file, or from a JSON lines file if ``.jsonl`` is among the path's suffixes."""
        kwargs: dict[str, Any] = {}
        if ".jsonl" in path.suffixes:
            kwargs["lines"] = True
            kwargs["orient"] = "records"
        run = pd.read_json(path, **kwargs)
        return run

    def _get_run_path(self) -> Path | None:
        """Returns the run file path, falling back to the scoreddocs path of the ir_datasets dataset.

        Returns:
            Path | None: Path of the run, or None if the scoreddocs handler does not implement a path.
        Raises:
            ValueError: If no run file was given and the dataset is not found or has no scoreddocs.
        """
        run_path = self.run_path
        if run_path is None:
            if self.ir_dataset is None or not self.ir_dataset.has_scoreddocs():
                raise ValueError(f"Run file or dataset with scoreddocs required. Got {self._dataset}")
            try:
                run_path = self.ir_dataset.scoreddocs_handler().scoreddocs_path()
            except NotImplementedError:
                pass
        return run_path

    def _clean_run(self, run: pd.DataFrame) -> pd.DataFrame:
        """Normalizes column names and dtypes, extracts inline queries and documents, and applies the depth cutoff.

        A ``query`` column is moved into the cached queries and a ``text`` column into the cached documents.
        """
        run = run.rename(
            {"qid": "query_id", "docid": "doc_id", "docno": "doc_id", "Q0": "iteration", "q0": "iteration"},
            axis=1,
        )
        dtypes = {"rank": np.int32, "query_id": str, "doc_id": str}
        if "score" in run.columns:
            dtypes["score"] = np.float32
        run = run.astype(dtypes)
        if "query" in run.columns:
            self._queries = run.drop_duplicates("query_id").set_index("query_id")["query"].rename("text")
            run = run.drop("query", axis=1)
        if "text" in run.columns:
            self._docs = (
                run.drop_duplicates("doc_id").set_index("doc_id")["text"].map(lambda x: GenericDoc("", x)).to_dict()
            )
            run = run.drop("text", axis=1)
        if self.depth != -1:
            run = run[run["rank"] <= self.depth]
        return run

    def _load_run(self) -> pd.DataFrame:
        """Loads the run from the run file or, failing that, from the scoreddocs of the ir_datasets dataset.

        Returns:
            pd.DataFrame: Cleaned run.
        Raises:
            ValueError: If the run can neither be loaded from file nor from ir_datasets. An error raised while
                reading the run file is attached as the cause.
        """
        suffix_load_map = {
            ".tsv": self._load_csv,
            ".run": self._load_csv,
            ".csv": self._load_csv,
            ".parquet": self._load_parquet,
            ".json": self._load_json,
            ".jsonl": self._load_json,
        }
        run = None
        load_error: Exception | None = None

        # try loading run from file
        run_path = self._get_run_path()
        if run_path is not None:
            run_path = Path(run_path)
            # Use the first known suffix (e.g. ``.tsv`` of ``run.tsv.gz``); files without a suffix have no loader.
            load_func = next(
                (suffix_load_map[suffix] for suffix in run_path.suffixes if suffix in suffix_load_map),
                None,
            )
            if load_func is not None:
                try:
                    run = load_func(run_path)
                except Exception as error:
                    # Best effort: fall back to ir_datasets below, but keep the error for reporting.
                    load_error = error

        # try loading run from ir_datasets
        if run is None and self.ir_dataset is not None and self.ir_dataset.has_scoreddocs():
            run = pd.DataFrame(self.ir_dataset.scoreddocs_iter())
            run["rank"] = run.groupby("query_id")["score"].rank("first", ascending=False)
            run = run.sort_values(["query_id", "rank"])

        if run is None:
            raise ValueError("Invalid run file format.") from load_error

        run = self._clean_run(run)
        return run

    @property
    def qrels(self) -> pd.DataFrame | None:
        """The qrels in the dataset. If the dataset does not contain qrels, the qrels are None.

        If the loaded run has a ``relevance`` column, the qrels are taken from the run and the ``relevance`` and
        ``iteration`` columns are removed from it. Otherwise the qrels of the ir_datasets dataset are used.

        Returns:
            pd.DataFrame | None: Qrels.
        """
        if self._qrels is not None:
            return self._qrels
        if self.run is not None and "relevance" in self.run:
            qrels = self.run[["query_id", "doc_id", "relevance"]].copy()
            if "iteration" in self.run:
                qrels["iteration"] = self.run["iteration"]
            else:
                qrels["iteration"] = "0"
            self.run = self.run.drop(["relevance", "iteration"], axis=1, errors="ignore")
            qrels = qrels.drop_duplicates(["query_id", "doc_id", "iteration"])
            qrels = qrels.set_index(["query_id", "doc_id", "iteration"]).unstack(level=-1)
            self._qrels = qrels
            return self._qrels
        return super().qrels

    def __len__(self) -> int:
        """Number of queries in the dataset.

        Returns:
            int: Number of queries.
        """
        self._setup()
        return len(self.query_ids)

    def __getitem__(self, idx: int) -> RankSample:
        """Samples a single query and corresponding ranked documents from the run. The documents are sampled according
        to the sampling strategy and sample size.

        Args:
            idx (int): Index of the query.
        Returns:
            RankSample: Sampled query and documents.
        Raises:
            ValueError: If the targets are not found in the run file.
        """
        self._setup()
        query_id = str(self.query_ids[idx])
        group = self.run_groups.get_group(query_id).copy()
        query = self.queries[query_id]
        group = Sampler.sample(group, self.sample_size, self.sampling_strategy)

        doc_ids = tuple(group["doc_id"])
        docs = tuple(self.docs.get(doc_id).default_text() for doc_id in doc_ids)

        targets = None
        if self.targets is not None:
            filtered = group.set_index("doc_id").loc[list(doc_ids)].filter(like=self.targets).fillna(0)
            if filtered.empty:
                raise ValueError(f"targets `{self.targets}` not found in run file")
            targets = torch.from_numpy(filtered.values)
            if self.targets == "rank":
                # invert ranks to be higher is better (necessary for loss functions);
                # with depth == -1 this yields -rank, which keeps that ordering
                targets = self.depth - targets + 1
            if self.normalize_targets:
                targets_min = targets.min()
                targets_range = targets.max() - targets_min
                if targets_range == 0:
                    # All targets are equal: divide by 1 so they become 0 instead of 0 / 0 = NaN.
                    targets_range = torch.ones_like(targets_range)
                targets = (targets - targets_min) / targets_range
        qrels = None
        if self.qrels is not None:
            qrels = (
                self.qrels.loc[[query_id]]
                .stack(future_stack=True)
                .dropna()
                .astype(int)
                .reset_index()
                .to_dict(orient="records")
            )
        return RankSample(query_id, query, doc_ids, docs, targets, qrels)
664 665
class TupleDataset(IRDataset, IterableDataset):
    def __init__(
        self,
        tuples_dataset: str,
        targets: Literal["order", "score"] = "order",
        num_docs: int | None = None,
    ) -> None:
        """Dataset containing tuples of a query and n-documents. Used for fine-tuning models on ranking tasks.

        Args:
            tuples_dataset (str): Path to file containing tuples or valid ir_datasets id.
            targets (Literal["order", "score"], optional): Data type to use as targets for a model during fine-tuning.
                Defaults to "order".
            num_docs (int | None, optional): Maximum number of documents per query. Defaults to None.
        """
        super().__init__(tuples_dataset)
        super(IRDataset, self).__init__()
        self.targets = targets
        self.num_docs = num_docs

    def _parse_sample(
        self, sample: ScoredDocTuple | GenericDocPair
    ) -> tuple[tuple[str, ...], tuple[str, ...], tuple[float, ...] | None]:
        """Extracts the document ids, document texts, and targets from a tuple sample.

        Document pairs get the targets ``(1.0, 0.0)``. Scored tuples are truncated to ``num_docs`` documents and
        get either their scores or ``1.0`` for the first document and ``0.0`` for all others as targets.

        Raises:
            ValueError: If the sample type or the ``targets`` setting is invalid, or if scores are requested
                but the sample does not provide them.
        """
        if isinstance(sample, GenericDocPair):
            if self.targets == "score":
                raise ValueError("ScoredDocTuple required for score targets.")
            sample_targets = (1.0, 0.0)
            sample_doc_ids = (sample.doc_id_a, sample.doc_id_b)
        elif isinstance(sample, ScoredDocTuple):
            sample_doc_ids = sample.doc_ids[: self.num_docs]
            if self.targets == "score":
                if sample.scores is None:
                    raise ValueError("tuples dataset does not contain scores")
                full_targets = sample.scores
            elif self.targets == "order":
                full_targets = (1.0,) + (0.0,) * (sample.num_docs - 1)
            else:
                raise ValueError(f"invalid value for targets, got {self.targets}, expected one of (order, score)")
            sample_targets = full_targets[: self.num_docs]
        else:
            raise ValueError("Invalid sample type.")
        doc_texts = tuple(self.docs.get(doc_id).default_text() for doc_id in sample_doc_ids)
        return sample_doc_ids, doc_texts, sample_targets

    def __iter__(self) -> Iterator[RankSample]:
        """Iterates over tuples in the dataset.

        Yields:
            RankSample: Sampled query and documents with targets.
        """
        for doc_tuple in self.ir_dataset.docpairs_iter():
            query_id = doc_tuple.query_id
            query_text = self.queries.loc[query_id]
            doc_ids, doc_texts, raw_targets = self._parse_sample(doc_tuple)
            target_tensor = None if raw_targets is None else torch.tensor(raw_targets)
            yield RankSample(query_id, query_text, doc_ids, doc_texts, target_tensor)

    def prepare_data(self) -> None:
        """Downloads tuples using ir_datasets if needed."""
        for constituent in ("docs", "queries", "docpairs"):
            self.prepare_constituent(constituent)