Source code for lightning_ir.data.dataset

  1"""
  2Datasets for Lightning IR that handle data loading and sampling.
  3
  4This module defines several datasets that handle loading and sampling data for training and inference.
  5"""
  6
  7import csv
  8import warnings
  9from itertools import islice
 10from pathlib import Path
 11from typing import Any, Dict, Iterator, Literal, Sequence, Tuple
 12
 13import ir_datasets
 14import numpy as np
 15import pandas as pd
 16import torch
 17from ir_datasets.formats import GenericDoc, GenericDocPair
 18from torch.distributed import get_rank, get_world_size
 19from torch.utils.data import Dataset, IterableDataset, get_worker_info
 20
 21from .data import DocSample, QuerySample, RankSample
 22from .external_datasets.ir_datasets_utils import ScoredDocTuple
 23
 24RUN_HEADER = ["query_id", "q0", "doc_id", "rank", "score", "system"]
 25
 26
 27class _DummyIterableDataset(IterableDataset):
 28    """Dummy iterable dataset to use when all inference datasets are skipped."""
 29
 30    def __iter__(self) -> Iterator:
 31        yield from iter([])
 32
 33
class IRDataset:
    # Set to True to skip the dataset during inference.
    _SKIP: bool = False

    def __init__(self, dataset: str) -> None:
        """Base class for datasets backed by ir_datasets. Queries, documents, and qrels are loaded lazily on first
        access and cached on the instance.

        :param dataset: ir_datasets id (slashes may be replaced by dashes) or another dataset name
        :type dataset: str
        """
        super().__init__()
        self._dataset = dataset
        self._queries = None
        self._docs = None
        self._qrels = None

    @property
    def dataset(self) -> str:
        """Dataset name. Dashed names of registered ir_datasets datasets are mapped back to their slashed id.

        :return: Dataset name
        :rtype: str
        """
        return self.DASHED_DATASET_MAP.get(self._dataset, self._dataset)

    @property
    def dataset_id(self) -> str:
        """Dataset id. Falls back to the dataset name if the dataset is not registered with ir_datasets.

        :return: Dataset id
        :rtype: str
        """
        ir_dataset = self.ir_dataset
        if ir_dataset is None:
            return self.dataset
        return ir_dataset.dataset_id()

    @property
    def docs_dataset_id(self) -> str:
        """ID of the dataset containing the documents.

        :return: Document dataset id
        :rtype: str
        """
        return ir_datasets.docs_parent_id(self.dataset_id)

    @property
    def ir_dataset(self) -> ir_datasets.Dataset | None:
        """Instance of ir_datasets.Dataset. Loaded anew on every access.

        :return: ir_datasets dataset, or None if the dataset is not registered with ir_datasets
        :rtype: ir_datasets.Dataset | None
        """
        try:
            return ir_datasets.load(self.dataset)
        except KeyError:
            return None

    @property
    def DASHED_DATASET_MAP(self) -> Dict[str, str]:
        """Map of dataset names with dashes to dataset names with slashes.

        :return: Dataset map
        :rtype: Dict[str, str]
        """
        return {dataset.replace("/", "-"): dataset for dataset in ir_datasets.registry._registered}

    @property
    def queries(self) -> pd.Series:
        """Queries in the dataset, indexed by ``query_id``.

        :raises ValueError: If the dataset is not registered with ir_datasets
        :return: Queries
        :rtype: pd.Series
        """
        if self._queries is None:
            ir_dataset = self.ir_dataset
            if ir_dataset is None:
                raise ValueError(f"Unable to find dataset {self.dataset} in ir-datasets")
            self._queries = pd.Series(
                {query.query_id: query.default_text() for query in ir_dataset.queries_iter()},
                name="text",
            )
            self._queries.index.name = "query_id"
        return self._queries

    @property
    def docs(self) -> ir_datasets.indices.Docstore | Dict[str, GenericDoc]:
        """Documents in the dataset.

        :raises ValueError: If the dataset is not registered with ir_datasets
        :return: Documents
        :rtype: ir_datasets.indices.Docstore | Dict[str, GenericDoc]
        """
        if self._docs is None:
            ir_dataset = self.ir_dataset
            if ir_dataset is None:
                raise ValueError(f"Unable to find dataset {self.dataset} in ir-datasets")
            self._docs = ir_dataset.docs_store()
        return self._docs

    @property
    def qrels(self) -> pd.DataFrame | None:
        """Qrels in the dataset, indexed by (``query_id``, ``doc_id``) with one relevance column per iteration.

        :return: Qrels, or None if the dataset has no qrels
        :rtype: pd.DataFrame | None
        """
        if self._qrels is not None:
            return self._qrels
        ir_dataset = self.ir_dataset
        if ir_dataset is None or not ir_dataset.has_qrels():
            return None
        qrels = pd.DataFrame(ir_dataset.qrels_iter()).rename({"subtopic_id": "iteration"}, axis=1)
        if "iteration" not in qrels.columns:
            qrels["iteration"] = 0
        qrels = qrels.drop_duplicates(["query_id", "doc_id", "iteration"])
        qrels = qrels.set_index(["query_id", "doc_id", "iteration"]).unstack(level=-1)
        self._qrels = qrels
        return self._qrels

    def prepare_constituent(self, constituent: Literal["qrels", "queries", "docs", "scoreddocs", "docpairs"]) -> None:
        """Downloads the constituent of the dataset using ir_datasets if needed.

        Does nothing if the dataset is not registered with ir_datasets or does not provide the constituent.

        :param constituent: Constituent to download
        :type constituent: Literal["qrels", "queries", "docs", "scoreddocs", "docpairs"]
        """
        ir_dataset = self.ir_dataset
        if ir_dataset is None or not ir_dataset.has(constituent):
            return
        if constituent == "docs" and hasattr(ir_dataset, "docs_store"):
            # build docs store if not already built
            docs_store = ir_dataset.docs_store()
            if not docs_store.built():
                docs_store.build()
            return
        # Fetch the first item to trigger the download. The default avoids raising StopIteration when the
        # constituent exists but is empty.
        next(iter(getattr(ir_dataset, f"{constituent}_iter")()), None)
165 166 167class _DataParallelIterableDataset(IterableDataset): 168 # https://github.com/Lightning-AI/pytorch-lightning/issues/15734 169 def __init__(self) -> None: 170 super().__init__() 171 # TODO add support for multi-gpu and multi-worker inference; currently 172 # doesn't work 173 worker_info = get_worker_info() 174 num_workers = worker_info.num_workers if worker_info is not None else 1 175 worker_id = worker_info.id if worker_info is not None else 0 176 177 try: 178 world_size = get_world_size() 179 process_rank = get_rank() 180 except (RuntimeError, ValueError): 181 world_size = 1 182 process_rank = 0 183 184 self.num_replicas = num_workers * world_size 185 self.rank = process_rank * num_workers + worker_id 186 187
class QueryDataset(IRDataset, _DataParallelIterableDataset):
    def __init__(self, query_dataset: str, num_queries: int | None = None) -> None:
        """Dataset containing queries.

        :param query_dataset: Path to file containing queries or valid ir_datasets id
        :type query_dataset: str
        :param num_queries: Number of queries in dataset. If None, the number of queries will be inferred if
            possible, defaults to None
        :type num_queries: int | None, optional
        """
        super().__init__(query_dataset)
        super(IRDataset, self).__init__()
        self.num_queries = num_queries

    def __len__(self) -> int | None:
        """Number of queries in the dataset. Returns None if the number of queries cannot be inferred.

        :return: Number of queries
        :rtype: int | None
        """
        # TODO fix len for multi-gpu and multi-worker inference
        return self.num_queries or getattr(self.ir_dataset, "queries_count", lambda: None)() or None

    def __iter__(self) -> Iterator[QuerySample]:
        """Iterate over queries in the dataset. Each replica yields every ``num_replicas``-th query starting at its
        ``rank``. Qrels are attached to a query only if the dataset has judgments for it.

        :yield: Query sample
        :rtype: Iterator[QuerySample]
        """
        qrels = self.qrels
        # Queries without judgments (e.g. unjudged test queries) would make ``qrels.loc`` raise a KeyError.
        judged_query_ids = set() if qrels is None else set(qrels.index.get_level_values("query_id"))
        queries_iter = self.ir_dataset.queries_iter()
        for sample in islice(queries_iter, self.rank, self.num_queries, self.num_replicas):
            query_sample = QuerySample.from_ir_dataset_sample(sample)
            if qrels is not None and query_sample.query_id in judged_query_ids:
                query_sample.qrels = (
                    qrels.loc[[query_sample.query_id]]
                    .stack(future_stack=True)
                    .dropna()
                    .astype(int)
                    .reset_index()
                    .to_dict(orient="records")
                )
            # NOTE(review): unjudged queries keep whatever default QuerySample.from_ir_dataset_sample sets for
            # ``qrels`` (presumably None) — confirm against QuerySample.
            yield query_sample

    def prepare_data(self) -> None:
        """Downloads queries using ir_datasets if needed."""
        self.prepare_constituent("queries")
237 238
class DocDataset(IRDataset, _DataParallelIterableDataset):
    def __init__(self, doc_dataset: str, num_docs: int | None = None, text_fields: Sequence[str] | None = None) -> None:
        """Dataset of documents loaded through ir_datasets.

        :param doc_dataset: Path to file containing documents or valid ir_datasets id
        :type doc_dataset: str
        :param num_docs: Number of documents to use. If None, the count is taken from ir_datasets when available,
            defaults to None
        :type num_docs: int | None, optional
        :param text_fields: Fields to parse the document text from, defaults to None
        :type text_fields: Sequence[str] | None, optional
        """
        super().__init__(doc_dataset)
        super(IRDataset, self).__init__()
        self.num_docs = num_docs
        self.text_fields = text_fields

    def __len__(self) -> int | None:
        """Number of documents, or None when it cannot be determined.

        :return: Number of documents
        :rtype: int | None
        """
        # TODO fix len for multi-gpu and multi-worker inference
        if self.num_docs:
            return self.num_docs
        count_docs = getattr(self.ir_dataset, "docs_count", lambda: None)
        return count_docs() or None

    def __iter__(self) -> Iterator[DocSample]:
        """Yield this replica's share of the documents (every ``num_replicas``-th one, starting at ``rank``).

        :yield: Doc sample
        :rtype: Iterator[DocSample]
        """
        shard = islice(self.ir_dataset.docs_iter(), self.rank, self.num_docs, self.num_replicas)
        for doc in shard:
            yield DocSample.from_ir_dataset_sample(doc, self.text_fields)

    def prepare_data(self) -> None:
        """Ensure the documents are downloaded via ir_datasets."""
        self.prepare_constituent("docs")
280 281
class Sampler:
    """Helper class for sampling subsets of documents from a ranked list."""

    # Strategies ``sample`` may dispatch to. An explicit allow-list prevents arbitrary attribute names
    # (e.g. "sample" or "__init__") from being looked up and called as a strategy.
    _STRATEGIES = frozenset({"single_relevant", "top", "random", "log_random", "top_and_random"})

    @staticmethod
    def single_relevant(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to randomly sample a single relevant document. Up to ``sample_size - 1`` randomly
        sampled non-relevant, ranked documents are added.

        Relevance is the maximum over all columns whose name contains ``relevance`` (missing values count as 0).

        :param documents: Ranked list of documents
        :type documents: pd.DataFrame
        :param sample_size: Number of documents to sample
        :type sample_size: int
        :raises ValueError: If no document is relevant
        :return: Sampled documents
        :rtype: pd.DataFrame
        """
        relevance = documents.filter(like="relevance").max(axis=1).fillna(0)
        relevant_mask = relevance.gt(0)
        if not relevant_mask.any():
            raise ValueError("single_relevant sampling requires at least one relevant document.")
        relevant = documents.loc[relevant_mask].sample(1)
        # non-relevant documents must come from the ranking (documents only present in the qrels have no rank)
        non_relevant_bool = relevance.eq(0) & ~documents["rank"].isna()
        sample_non_relevant = min(sample_size - 1, non_relevant_bool.sum())
        non_relevant = documents.loc[non_relevant_bool].sample(sample_non_relevant)
        return pd.concat([relevant, non_relevant])

    @staticmethod
    def top(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy that keeps the first ``sample_size`` documents of the ranked list (the list is
        assumed to be sorted by rank).

        :param documents: Ranked list of documents
        :type documents: pd.DataFrame
        :param sample_size: Number of documents to sample
        :type sample_size: int
        :return: Sampled documents
        :rtype: pd.DataFrame
        """
        return documents.head(sample_size)

    @staticmethod
    def top_and_random(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to take half the ``sample_size`` documents from the top of the ranking and sample the
        other half randomly from the remaining documents.

        :param documents: Ranked list of documents
        :type documents: pd.DataFrame
        :param sample_size: Number of documents to sample
        :type sample_size: int
        :return: Sampled documents
        :rtype: pd.DataFrame
        """
        top_size = sample_size // 2
        random_size = sample_size - top_size
        top = documents.head(top_size)
        random = documents.iloc[top_size:].sample(random_size)
        return pd.concat([top, random])

    @staticmethod
    def random(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to randomly sample ``sample_size`` documents.

        :param documents: Ranked list of documents
        :type documents: pd.DataFrame
        :param sample_size: Number of documents to sample
        :type sample_size: int
        :return: Sampled documents
        :rtype: pd.DataFrame
        """
        return documents.sample(sample_size)

    @staticmethod
    def log_random(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to randomly sample documents with a higher probability to sample documents from the top of
        the ranking. Documents are weighted by ``1 / log(1 + rank)``; unranked documents get the smallest weight.

        :param documents: Ranked list of documents
        :type documents: pd.DataFrame
        :param sample_size: Number of documents to sample
        :type sample_size: int
        :return: Sampled documents
        :rtype: pd.DataFrame
        """
        weights = 1 / np.log1p(documents["rank"])
        weights[weights.isna()] = weights.min()
        return documents.sample(sample_size, weights=weights)

    @staticmethod
    def sample(
        df: pd.DataFrame,
        sample_size: int,
        sampling_strategy: Literal["single_relevant", "top", "random", "log_random", "top_and_random"],
    ) -> pd.DataFrame:
        """Samples a subset of documents from a ranked list given a sampling_strategy.

        :param df: Ranked list of documents
        :type df: pd.DataFrame
        :param sample_size: Number of documents to sample. If -1, all documents are returned unchanged
        :type sample_size: int
        :param sampling_strategy: Name of the sampling strategy to use
        :type sampling_strategy: Literal["single_relevant", "top", "random", "log_random", "top_and_random"]
        :raises ValueError: If the sampling strategy is unknown
        :return: Sampled documents
        :rtype: pd.DataFrame
        """
        if sample_size == -1:
            return df
        if sampling_strategy not in Sampler._STRATEGIES:
            raise ValueError(f"Invalid sampling strategy {sampling_strategy!r}.")
        return getattr(Sampler, sampling_strategy)(df, sample_size)
387 388
class RunDataset(IRDataset, Dataset):
    def __init__(
        self,
        run_path_or_id: Path | str,
        depth: int = -1,
        sample_size: int = -1,
        sampling_strategy: Literal["single_relevant", "top", "random", "log_random", "top_and_random"] = "top",
        targets: Literal["relevance", "subtopic_relevance", "rank", "score"] | None = None,
        normalize_targets: bool = False,
        add_docs_not_in_ranking: bool = False,
    ) -> None:
        """Dataset containing a list of queries with a ranked list of documents per query. Subsets of the ranked list
        can be sampled using different sampling strategies.

        :param run_path_or_id: Path to a run file or valid ir_datasets id. For a run file, the dataset id is taken
            from the file name (the part before the first ``.`` and after the last ``__``)
        :type run_path_or_id: Path | str
        :param depth: Depth at which to cut off the ranking. If -1 the full ranking is kept, defaults to -1
        :type depth: int, optional
        :param sample_size: The number of documents to sample per query, defaults to -1
        :type sample_size: int, optional
        :param sampling_strategy: The sample strategy to use to sample documents, defaults to "top"
        :type sampling_strategy: Literal['single_relevant', 'top', 'random', 'log_random', 'top_and_random'], optional
        :param targets: The data type to use as targets for a model during fine-tuning. If relevance the relevance
            judgements are parsed from qrels, defaults to None
        :type targets: Literal['relevance', 'subtopic_relevance', 'rank', 'score'] | None, optional
        :param normalize_targets: Whether to normalize the targets between 0 and 1, defaults to False
        :type normalize_targets: bool, optional
        :param add_docs_not_in_ranking: Whether to add relevant documents from the qrels that are not in the
            ranking to a sample, defaults to False
        :type add_docs_not_in_ranking: bool, optional
        """
        self.run_path = None
        if Path(run_path_or_id).is_file():
            self.run_path = Path(run_path_or_id)
            dataset = self.run_path.name.split(".")[0].split("__")[-1]
        else:
            dataset = str(run_path_or_id)
        super().__init__(dataset)
        self.depth = depth
        self.sample_size = sample_size
        self.sampling_strategy = sampling_strategy
        self.targets = targets
        self.normalize_targets = normalize_targets
        self.add_docs_not_in_ranking = add_docs_not_in_ranking

        # Only warn when the ranking is actually cut off: depth == -1 keeps the full ranking, so comparing
        # sample_size against -1 would warn for every positive sample size.
        if self.sampling_strategy == "top" and self.depth != -1 and self.sample_size > self.depth:
            warnings.warn(
                "Sample size is greater than depth and top sampling strategy is used. "
                "This can cause documents to be sampled that are not contained "
                "in the run file, but that are present in the qrels."
            )

        self.run: pd.DataFrame | None = None

    def prepare_data(self) -> None:
        """Downloads docs, queries, scoreddocs, and qrels using ir_datasets if needed and available."""
        self.prepare_constituent("docs")
        self.prepare_constituent("queries")
        if self.run_path is None:
            self.prepare_constituent("scoreddocs")
        self.prepare_constituent("qrels")

    def _setup(self) -> None:
        """Lazily loads the run, joins it with the qrels, and groups it by query. Does nothing once the run is set."""
        if self.run is not None:
            return
        self.run = self._load_run()
        self.run = self.run.drop_duplicates(["query_id", "doc_id"])

        if self.qrels is not None:
            run_query_ids = pd.Index(self.run["query_id"].drop_duplicates())
            qrels_query_ids = self.qrels.index.get_level_values("query_id").unique()
            query_ids = run_query_ids.intersection(qrels_query_ids)
            if len(run_query_ids.difference(qrels_query_ids)):
                self.run = self.run[self.run["query_id"].isin(query_ids)]
            # outer join if docs are from ir_datasets else only keep docs in run
            how = "left"
            if self._docs is None and self.add_docs_not_in_ranking:
                how = "outer"
            self.run = self.run.merge(
                self.qrels.loc[pd.IndexSlice[query_ids, :]].droplevel(0, axis=1).add_prefix("relevance_", axis=1),
                on=["query_id", "doc_id"],
                how=how,
            )

        if self.sample_size != -1:
            num_docs_per_query = self.run.groupby("query_id").transform("size")
            self.run = self.run[num_docs_per_query >= self.sample_size]

        self.run = self.run.sort_values(["query_id", "rank"])
        self.run_groups = self.run.groupby("query_id")
        self.query_ids = list(self.run_groups.groups.keys())
        # Deepest rank in the run. Used to invert rank targets when the full ranking is kept (depth == -1).
        self._max_rank = int(self.run["rank"].max()) if len(self.run) else 0

        if self.depth != -1 and self.run["rank"].max() < self.depth:
            warnings.warn("Depth is greater than the maximum rank in the run file.")

    @staticmethod
    def _load_csv(path: Path) -> pd.DataFrame:
        """Loads a whitespace-separated run file; only the first five columns of ``RUN_HEADER`` are read."""
        return pd.read_csv(
            path,
            sep=r"\s+",
            header=None,
            names=RUN_HEADER,
            usecols=[0, 1, 2, 3, 4],
            dtype={"query_id": str, "doc_id": str},
            quoting=csv.QUOTE_NONE,
            na_filter=False,
        )

    @staticmethod
    def _load_parquet(path: Path) -> pd.DataFrame:
        """Loads a parquet run file and normalizes common id column names."""
        return pd.read_parquet(path).rename(
            {
                "qid": "query_id",
                "docid": "doc_id",
                "docno": "doc_id",
            },
            axis=1,
        )

    @staticmethod
    def _load_json(path: Path) -> pd.DataFrame:
        """Loads a JSON run file, or a JSON-lines file if ``.jsonl`` is one of the suffixes; ids are read as str."""
        kwargs: Dict[str, Any] = {}
        if ".jsonl" in path.suffixes:
            kwargs["lines"] = True
            kwargs["orient"] = "records"
        run = pd.read_json(
            path,
            **kwargs,
            dtype={
                "query_id": str,
                "qid": str,
                "doc_id": str,
                "docid": str,
                "docno": str,
            },
        )
        return run

    def _get_run_path(self) -> Path | None:
        """Returns the run file path, falling back to the ir_datasets scoreddocs file if no run file was given.

        :raises ValueError: If neither a run file nor a dataset with scoreddocs is available
        """
        run_path = self.run_path
        if run_path is None:
            if self.ir_dataset is None or not self.ir_dataset.has_scoreddocs():
                raise ValueError(f"Run file or dataset with scoreddocs required. Got {self._dataset}")
            try:
                run_path = self.ir_dataset.scoreddocs_handler().scoreddocs_path()
            except NotImplementedError:
                pass
        return run_path

    def _clean_run(self, run: pd.DataFrame) -> pd.DataFrame:
        """Normalizes column names, extracts inline query/document texts, applies the depth cut-off, and casts dtypes."""
        run = run.rename(
            {"qid": "query_id", "docid": "doc_id", "docno": "doc_id", "Q0": "iteration", "q0": "iteration"},
            axis=1,
        )
        if "query" in run.columns:
            self._queries = run.drop_duplicates("query_id").set_index("query_id")["query"].rename("text")
            run = run.drop("query", axis=1)
        if "text" in run.columns:
            self._docs = run.set_index("doc_id")["text"].map(lambda x: GenericDoc("", x)).to_dict()
            run = run.drop("text", axis=1)
        if self.depth != -1:
            run = run[run["rank"] <= self.depth]
        dtypes = {"rank": np.int32}
        if "score" in run.columns:
            dtypes["score"] = np.float32
        run = run.astype(dtypes)
        return run

    def _load_run(self) -> pd.DataFrame:
        """Loads the run from a file (format chosen by suffix) or, failing that, from ir_datasets scoreddocs.

        :raises ValueError: If the run cannot be loaded; a file-loading error, if any, is chained as the cause
        """
        suffix_load_map = {
            ".tsv": self._load_csv,
            ".run": self._load_csv,
            ".csv": self._load_csv,
            ".parquet": self._load_parquet,
            ".json": self._load_json,
            ".jsonl": self._load_json,
        }
        run = None
        load_error: Exception | None = None

        # try loading run from file. Use the first recognised suffix so that names such as ``model.bm25.tsv``
        # or paths without any suffix do not break format detection.
        run_path = self._get_run_path()
        if run_path is not None:
            load_func = next(
                (suffix_load_map[suffix] for suffix in run_path.suffixes if suffix in suffix_load_map), None
            )
            if load_func is not None:
                try:
                    run = load_func(run_path)
                except Exception as error:  # best effort: fall back to ir_datasets scoreddocs below
                    load_error = error

        # try loading run from ir_datasets
        if run is None and self.ir_dataset is not None and self.ir_dataset.has_scoreddocs():
            run = pd.DataFrame(self.ir_dataset.scoreddocs_iter())
            run["rank"] = run.groupby("query_id")["score"].rank("first", ascending=False)
            run = run.sort_values(["query_id", "rank"])

        if run is None:
            raise ValueError("Invalid run file format.") from load_error

        run = self._clean_run(run)
        return run

    @property
    def qrels(self) -> pd.DataFrame | None:
        """The qrels in the dataset. If the dataset does not contain qrels, the qrels are None.

        If the loaded run has a ``relevance`` column, the qrels are built from it and the column is removed from
        the run.

        :return: Qrels
        :rtype: pd.DataFrame | None
        """
        if self._qrels is not None:
            return self._qrels
        if self.run is not None and "relevance" in self.run:
            qrels = self.run[["query_id", "doc_id", "relevance"]].copy()
            if "iteration" in self.run:
                qrels["iteration"] = self.run["iteration"]
            else:
                qrels["iteration"] = "0"
            self.run = self.run.drop(["relevance", "iteration"], axis=1, errors="ignore")
            qrels = qrels.drop_duplicates(["query_id", "doc_id", "iteration"])
            qrels = qrels.set_index(["query_id", "doc_id", "iteration"]).unstack(level=-1)
            self._qrels = qrels
            return self._qrels
        return super().qrels

    def __len__(self) -> int:
        """Number of queries in the dataset.

        :return: Number of queries
        :rtype: int
        """
        self._setup()
        return len(self.query_ids)

    def __getitem__(self, idx: int) -> RankSample:
        """Samples a single query and corresponding ranked documents from the run. The documents are sampled according
        to the sampling strategy and sample size.

        :param idx: Index of the query
        :type idx: int
        :raises ValueError: If the targets are not found in the run file
        :return: Sampled query and documents
        :rtype: RankSample
        """
        self._setup()
        query_id = str(self.query_ids[idx])
        group = self.run_groups.get_group(query_id).copy()
        query = self.queries[query_id]
        group = Sampler.sample(group, self.sample_size, self.sampling_strategy)

        doc_ids = tuple(group["doc_id"])
        docs = tuple(self.docs.get(doc_id).default_text() for doc_id in doc_ids)

        targets = None
        if self.targets is not None:
            filtered = group.set_index("doc_id").loc[list(doc_ids)].filter(like=self.targets).fillna(0)
            if filtered.empty:
                raise ValueError(f"targets `{self.targets}` not found in run file")
            targets = torch.from_numpy(filtered.values)
            if self.targets == "rank":
                # Invert ranks so that higher is better (necessary for loss functions). With depth == -1 the
                # deepest rank in the run is used instead, so the targets do not become negative.
                # NOTE(review): documents added via add_docs_not_in_ranking have a NaN rank that fillna(0) maps
                # to 0, which the inversion places above every ranked document — confirm this is intended.
                max_rank = self._max_rank if self.depth == -1 else self.depth
                targets = max_rank - targets + 1
            if self.normalize_targets:
                targets_min = targets.min()
                targets_max = targets.max()
                span = targets_max - targets_min
                # Avoid 0 / 0 = NaN when all targets are equal. The numerator is then all zeros.
                span = torch.where(span > 0, span, torch.ones_like(span))
                targets = (targets - targets_min) / span
        qrels = None
        if self.qrels is not None:
            qrels = (
                self.qrels.loc[[query_id]]
                .stack(future_stack=True)
                .dropna()
                .astype(int)
                .reset_index()
                .to_dict(orient="records")
            )
        return RankSample(query_id, query, doc_ids, docs, targets, qrels)
665 666
class TupleDataset(IRDataset, IterableDataset):
    def __init__(
        self,
        tuples_dataset: str,
        targets: Literal["order", "score"] = "order",
        num_docs: int | None = None,
    ) -> None:
        """Iterable dataset of (query, n documents) tuples read from ir_datasets docpairs, for ranking fine-tuning.

        :param tuples_dataset: Path to file containing tuples or valid ir_datasets id
        :type tuples_dataset: str
        :param targets: Which targets to emit: "order" (first document is the positive) or "score" (scores
            stored with the tuple), defaults to "order"
        :type targets: Literal["order", "score"], optional
        :param num_docs: Maximum number of documents kept per tuple, defaults to None
        :type num_docs: int | None, optional
        """
        super().__init__(tuples_dataset)
        super(IRDataset, self).__init__()
        self.targets = targets
        self.num_docs = num_docs

    def _parse_sample(
        self, sample: ScoredDocTuple | GenericDocPair
    ) -> Tuple[Tuple[str, ...], Tuple[str, ...], Tuple[float, ...] | None]:
        """Turn a doc pair or scored doc tuple into (doc ids, doc texts, targets).

        :raises ValueError: On score targets for a doc pair, missing scores, an unknown target type, or an
            unsupported sample type
        """
        if isinstance(sample, GenericDocPair):
            if self.targets == "score":
                raise ValueError("ScoredDocTuple required for score targets.")
            doc_ids = (sample.doc_id_a, sample.doc_id_b)
            targets = (1.0, 0.0)
        elif isinstance(sample, ScoredDocTuple):
            doc_ids = sample.doc_ids[: self.num_docs]
            if self.targets == "order":
                # the first document is treated as the positive, the rest as negatives
                all_targets = (1.0,) + (0.0,) * (sample.num_docs - 1)
            elif self.targets == "score":
                if sample.scores is None:
                    raise ValueError("tuples dataset does not contain scores")
                all_targets = sample.scores
            else:
                raise ValueError(f"invalid value for targets, got {self.targets}, expected one of (order, score)")
            targets = all_targets[: self.num_docs]
        else:
            raise ValueError("Invalid sample type.")
        docs = tuple(self.docs.get(doc_id).default_text() for doc_id in doc_ids)
        return doc_ids, docs, targets

    def __iter__(self) -> Iterator[RankSample]:
        """Yield one rank sample per tuple in the dataset's docpairs.

        :yield: A single tuple sample
        :rtype: Iterator[RankSample]
        """
        for record in self.ir_dataset.docpairs_iter():
            query_id = record.query_id
            query = self.queries.loc[query_id]
            doc_ids, docs, raw_targets = self._parse_sample(record)
            tensor_targets = None if raw_targets is None else torch.tensor(raw_targets)
            yield RankSample(query_id, query, doc_ids, docs, tensor_targets)

    def prepare_data(self) -> None:
        """Ensure documents, queries, and docpairs are downloaded via ir_datasets."""
        for constituent in ("docs", "queries", "docpairs"):
            self.prepare_constituent(constituent)