"""
Datasets for Lightning IR that handle data loading and sampling.

This module defines several datasets that handle loading and sampling data for training and inference.
"""
6
7import csv
8import warnings
9from itertools import islice
10from pathlib import Path
11from typing import Any, Dict, Iterator, Literal, Sequence, Tuple
12
13import ir_datasets
14import numpy as np
15import pandas as pd
16import torch
17from ir_datasets.formats import GenericDoc, GenericDocPair
18from torch.distributed import get_rank, get_world_size
19from torch.utils.data import Dataset, IterableDataset, get_worker_info
20
21from .data import DocSample, QuerySample, RankSample
22from .external_datasets.ir_datasets_utils import ScoredDocTuple
23
# Column names assigned, in file order, to the whitespace-separated columns of a run file when it is parsed.
RUN_HEADER = [
    "query_id",
    "q0",
    "doc_id",
    "rank",
    "score",
    "system",
]
25
26
27class _DummyIterableDataset(IterableDataset):
28 """Dummy iterable dataset to use when all inference datasets are skipped."""
29
30 def __iter__(self) -> Iterator:
31 yield from iter([])
32
33
class IRDataset:
    """Base class for datasets identified by an ir_datasets id.

    Queries, documents and qrels are loaded from ir_datasets on first access and cached on the instance.
    """

    _SKIP: bool = False
    """Set to True to skip the dataset during inference."""

    def __init__(self, dataset: str) -> None:
        """Initializes a new IRDataset.

        Args:
            dataset (str): Dataset name. A dashed variant of a registered ir_datasets id (slashes replaced by
                dashes) is mapped back to the registered id.
        """
        # cooperative init: subclasses mix this class with torch dataset classes
        super().__init__()
        self._dataset = dataset
        self._queries = None
        self._docs = None
        self._qrels = None

    @property
    def dataset(self) -> str:
        """Dataset name, with dashed ir_datasets ids mapped back to their registered (slashed) form.

        Returns:
            str: Dataset name.
        """
        return self.DASHED_DATASET_MAP.get(self._dataset, self._dataset)

    @property
    def dataset_id(self) -> str:
        """Dataset id. Falls back to the dataset name if the dataset is not found in ir_datasets.

        Returns:
            str: Dataset id.
        """
        # resolve once: every access of ``ir_dataset`` performs a fresh registry lookup
        ir_dataset = self.ir_dataset
        if ir_dataset is None:
            return self.dataset
        return ir_dataset.dataset_id()

    @property
    def docs_dataset_id(self) -> str:
        """ID of the dataset containing the documents.

        Returns:
            str: Document dataset id.
        """
        return ir_datasets.docs_parent_id(self.dataset_id)

    @property
    def dashed_docs_dataset_id(self) -> str:
        """Dataset id with dashes instead of slashes for the documents dataset.

        Returns:
            str: Document dataset id with dashes.
        """
        return self.docs_dataset_id.replace("/", "-")

    @property
    def ir_dataset(self) -> ir_datasets.Dataset | None:
        """Instance of ir_datasets.Dataset. Not cached: ``ir_datasets.load`` is called on every access.

        Returns:
            ir_datasets.Dataset | None: Instance of ir_datasets.Dataset or None if the dataset is not found.
        """
        try:
            return ir_datasets.load(self.dataset)
        except KeyError:
            return None

    @property
    def DASHED_DATASET_MAP(self) -> Dict[str, str]:
        """Map of dataset names with dashes to dataset names with slashes.

        Rebuilt from the ir_datasets registry on every access.

        Returns:
            Dict[str, str]: Dataset map.
        """
        return {dataset.replace("/", "-"): dataset for dataset in ir_datasets.registry._registered}

    @property
    def queries(self) -> pd.Series:
        """Queries in the dataset, indexed by ``query_id`` with the query text as values.

        Returns:
            pd.Series: Queries.
        Raises:
            ValueError: If the dataset is not found in ir_datasets.
        """
        if self._queries is None:
            ir_dataset = self.ir_dataset
            if ir_dataset is None:
                raise ValueError(f"Unable to find dataset {self.dataset} in ir-datasets")
            self._queries = pd.Series(
                {query.query_id: query.default_text() for query in ir_dataset.queries_iter()},
                name="text",
            )
            self._queries.index.name = "query_id"
        return self._queries

    @property
    def docs(self) -> ir_datasets.indices.Docstore | Dict[str, GenericDoc]:
        """Documents in the dataset.

        Returns:
            ir_datasets.indices.Docstore | Dict[str, GenericDoc]: Documents.
        Raises:
            ValueError: If the dataset is not found in ir_datasets.
        """
        if self._docs is None:
            ir_dataset = self.ir_dataset
            if ir_dataset is None:
                raise ValueError(f"Unable to find dataset {self.dataset} in ir-datasets")
            self._docs = ir_dataset.docs_store()
        return self._docs

    @property
    def qrels(self) -> pd.DataFrame | None:
        """Qrels in the dataset.

        The frame is indexed by ``(query_id, doc_id)``. Its columns are ``(value column, iteration)`` pairs
        produced by unstacking the iteration level.

        Returns:
            pd.DataFrame | None: Qrels, or None if the dataset is not found or has no qrels.
        """
        if self._qrels is not None:
            return self._qrels
        ir_dataset = self.ir_dataset
        if ir_dataset is None or not ir_dataset.has_qrels():
            return None
        qrels = pd.DataFrame(ir_dataset.qrels_iter()).rename({"subtopic_id": "iteration"}, axis=1)
        # qrels without an iteration/subtopic column get a single constant iteration
        if "iteration" not in qrels.columns:
            qrels["iteration"] = 0
        qrels = qrels.drop_duplicates(["query_id", "doc_id", "iteration"])
        qrels = qrels.set_index(["query_id", "doc_id", "iteration"]).unstack(level=-1)
        self._qrels = qrels
        return self._qrels

    def prepare_constituent(self, constituent: Literal["qrels", "queries", "docs", "scoreddocs", "docpairs"]) -> None:
        """Downloads the constituent of the dataset using ir_datasets if needed.

        Does nothing if the dataset is not found in ir_datasets or does not provide the constituent.

        Args:
            constituent (Literal["qrels", "queries", "docs", "scoreddocs", "docpairs"]): Constituent to download.
        """
        ir_dataset = self.ir_dataset
        if ir_dataset is None or not ir_dataset.has(constituent):
            return
        if constituent == "docs" and hasattr(ir_dataset, "docs_store"):
            # build docs store if not already built
            docs_store = ir_dataset.docs_store()
            if not docs_store.built():
                docs_store.build()
            return
        # get first item to trigger download; the None default keeps an empty constituent from
        # leaking StopIteration out of prepare_data
        next(getattr(ir_dataset, f"{constituent}_iter")(), None)
181
182
183class _DataParallelIterableDataset(IterableDataset):
184 # https://github.com/Lightning-AI/pytorch-lightning/issues/15734
185 def __init__(self) -> None:
186 super().__init__()
187 # TODO add support for multi-gpu and multi-worker inference; currently
188 # doesn't work
189 worker_info = get_worker_info()
190 num_workers = worker_info.num_workers if worker_info is not None else 1
191 worker_id = worker_info.id if worker_info is not None else 0
192
193 try:
194 world_size = get_world_size()
195 process_rank = get_rank()
196 except (RuntimeError, ValueError):
197 world_size = 1
198 process_rank = 0
199
200 self.num_replicas = num_workers * world_size
201 self.rank = process_rank * num_workers + worker_id
202
203
class QueryDataset(IRDataset, _DataParallelIterableDataset):
    """Iterable dataset of queries from ir_datasets, annotated with their qrels when available."""

    def __init__(self, query_dataset: str, num_queries: int | None = None) -> None:
        """Dataset containing queries.

        Args:
            query_dataset (str): Path to file containing queries or valid ir_datasets id.
            num_queries (int | None, optional): Number of queries in dataset. If None, the number of queries is
                inferred from ir_datasets if possible. Iteration stops after this many queries. Defaults to None.
        """
        super().__init__(query_dataset)
        super(IRDataset, self).__init__()
        self.num_queries = num_queries

    def __len__(self) -> int | None:
        """Number of queries in the dataset. Returns None if the number of queries cannot be inferred.

        Returns:
            int | None: Number of queries.
        """
        # TODO fix len for multi-gpu and multi-worker inference
        return self.num_queries or getattr(self.ir_dataset, "queries_count", lambda: None)() or None

    def __iter__(self) -> Iterator[QuerySample]:
        """Iterate over queries in the dataset.

        Queries that have no relevance judgements are yielded without qrels.

        Yields:
            QuerySample: Query sample.
        Raises:
            ValueError: If the dataset is not found in ir_datasets.
        """
        ir_dataset = self.ir_dataset
        if ir_dataset is None:
            raise ValueError(f"Unable to find dataset {self.dataset} in ir-datasets")
        # resolve the qrels and the set of judged query ids once instead of once per query
        qrels = self.qrels
        judged_query_ids = set() if qrels is None else set(qrels.index.get_level_values("query_id"))
        # keep every num_replicas-th query starting at this replica's rank
        for sample in islice(ir_dataset.queries_iter(), self.rank, self.num_queries, self.num_replicas):
            query_sample = QuerySample.from_ir_dataset_sample(sample)
            # .loc[[query_id]] raises KeyError for unjudged queries, so only look up judged ones
            if query_sample.query_id in judged_query_ids:
                query_sample.qrels = (
                    qrels.loc[[query_sample.query_id]]
                    .stack(future_stack=True)
                    .dropna()
                    .astype(int)
                    .reset_index()
                    .to_dict(orient="records")
                )
            yield query_sample

    def prepare_data(self) -> None:
        """Downloads queries using ir_datasets if needed."""
        self.prepare_constituent("queries")
252
253
class DocDataset(IRDataset, _DataParallelIterableDataset):
    """Iterable dataset of documents from ir_datasets."""

    def __init__(self, doc_dataset: str, num_docs: int | None = None, text_fields: Sequence[str] | None = None) -> None:
        """Dataset containing documents.

        Args:
            doc_dataset (str): Path to file containing documents or valid ir_datasets id.
            num_docs (int | None, optional): Number of documents in dataset. If None, the number of documents is
                inferred from ir_datasets if possible. Iteration stops after this many documents. Defaults to None.
            text_fields (Sequence[str] | None, optional): Fields to parse the document text from. Defaults to None.
        """
        super().__init__(doc_dataset)
        super(IRDataset, self).__init__()
        self.num_docs = num_docs
        self.text_fields = text_fields

    def __len__(self) -> int | None:
        """Number of documents in the dataset. Returns None if the number of documents cannot be inferred.

        Returns:
            int | None: Number of documents.
        """
        # TODO fix len for multi-gpu and multi-worker inference
        return self.num_docs or getattr(self.ir_dataset, "docs_count", lambda: None)() or None

    def __iter__(self) -> Iterator[DocSample]:
        """Iterate over documents in the dataset.

        Yields:
            DocSample: Document sample.
        Raises:
            ValueError: If the dataset is not found in ir_datasets.
        """
        ir_dataset = self.ir_dataset
        if ir_dataset is None:
            raise ValueError(f"Unable to find dataset {self.dataset} in ir-datasets")
        # keep every num_replicas-th document starting at this replica's rank
        for sample in islice(ir_dataset.docs_iter(), self.rank, self.num_docs, self.num_replicas):
            yield DocSample.from_ir_dataset_sample(sample, self.text_fields)

    def prepare_data(self) -> None:
        """Downloads documents using ir_datasets if needed."""
        self.prepare_constituent("docs")
293
294
class Sampler:
    """Helper class for sampling subsets of documents from a ranked list.

    Each strategy is a static method that takes the ranked documents of one query and the number of documents
    to sample. :meth:`sample` dispatches to a strategy by name.
    """

    # strategies that ``sample`` may dispatch to; stops names such as "sample" or dunder attributes from
    # being called as strategies
    _STRATEGIES = ("single_relevant", "top", "random", "log_random", "top_and_random")

    @staticmethod
    def single_relevant(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to randomly sample a single relevant document. The remaining ``sample_size - 1``
        are non-relevant.

        A document is relevant if any column whose name contains "relevance" is greater than 0. Non-relevant
        documents must also have a rank (documents with a NaN rank are never sampled as negatives). If fewer
        than ``sample_size - 1`` such documents exist, all of them are used.

        Args:
            documents (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample.
        Returns:
            pd.DataFrame: Sampled documents, the relevant document first.
        """
        relevance = documents.filter(like="relevance").max(axis=1).fillna(0)
        relevant = documents.loc[relevance.gt(0)].sample(1)
        non_relevant_bool = relevance.eq(0) & ~documents["rank"].isna()
        num_non_relevant = non_relevant_bool.sum()
        sample_non_relevant = min(sample_size - 1, num_non_relevant)
        non_relevant = documents.loc[non_relevant_bool].sample(sample_non_relevant)
        return pd.concat([relevant, non_relevant])

    @staticmethod
    def top(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to take the first ``sample_size`` documents of the ranked list.

        Args:
            documents (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample.
        Returns:
            pd.DataFrame: Sampled documents.
        """
        return documents.head(sample_size)

    @staticmethod
    def top_and_random(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to take the first ``sample_size // 2`` documents of the ranking and sample the
        remaining ``sample_size - sample_size // 2`` randomly from the rest of the ranking.

        Args:
            documents (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample.
        Returns:
            pd.DataFrame: Sampled documents, top documents first.
        """
        top_size = sample_size // 2
        random_size = sample_size - top_size
        top = documents.head(top_size)
        random = documents.iloc[top_size:].sample(random_size)
        return pd.concat([top, random])

    @staticmethod
    def random(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to randomly sample ``sample_size`` documents.

        Args:
            documents (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample.
        Returns:
            pd.DataFrame: Sampled documents.
        """
        return documents.sample(sample_size)

    @staticmethod
    def log_random(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to randomly sample documents with a higher probability to sample documents from the top of
        the ranking.

        Documents are weighted by ``1 / log(1 + rank)``. Documents without a rank get the smallest weight.

        Args:
            documents (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample.
        Returns:
            pd.DataFrame: Sampled documents.
        """
        weights = 1 / np.log1p(documents["rank"])
        weights[weights.isna()] = weights.min()
        return documents.sample(sample_size, weights=weights)

    @staticmethod
    def sample(
        df: pd.DataFrame,
        sample_size: int,
        sampling_strategy: Literal["single_relevant", "top", "random", "log_random", "top_and_random"],
    ) -> pd.DataFrame:
        """
        Samples a subset of documents from a ranked list given a sampling_strategy.

        Args:
            df (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample. If -1, ``df`` is returned unchanged.
            sampling_strategy (Literal["single_relevant", "top", "random", "log_random", "top_and_random"]):
                Name of the sampling strategy to use.
        Returns:
            pd.DataFrame: Sampled documents.
        Raises:
            ValueError: If ``sampling_strategy`` is not a known strategy.
        """
        if sample_size == -1:
            return df
        if sampling_strategy not in Sampler._STRATEGIES:
            raise ValueError(
                f"Invalid sampling strategy. Got {sampling_strategy!r}, expected one of {Sampler._STRATEGIES}."
            )
        return getattr(Sampler, sampling_strategy)(df, sample_size)
394
395
class RunDataset(IRDataset, Dataset):
    """Map-style dataset of queries, each paired with a (sampled) ranked list of documents from a run."""

    def __init__(
        self,
        run_path_or_id: Path | str,
        depth: int = -1,
        sample_size: int = -1,
        sampling_strategy: Literal["single_relevant", "top", "random", "log_random", "top_and_random"] = "top",
        targets: Literal["relevance", "subtopic_relevance", "rank", "score"] | None = None,
        normalize_targets: bool = False,
        add_docs_not_in_ranking: bool = False,
    ) -> None:
        """Dataset containing a list of queries with a ranked list of documents per query. Subsets of the ranked list
        can be sampled using different sampling strategies.

        Args:
            run_path_or_id (Path | str): Path to a run file or valid ir_datasets id. For a run file, the dataset
                id is taken from the file name: the part before the first "." and after the last "__".
            depth (int): Depth at which to cut off the ranking. If -1 the full ranking is kept.
                Defaults to -1.
            sample_size (int): The number of documents to sample per query. If -1 all documents are kept and
                queries are not filtered by their number of documents. Defaults to -1.
            sampling_strategy (Literal["single_relevant", "top", "random", "log_random", "top_and_random"]):
                The sample strategy to use to sample documents. Defaults to "top".
            targets (Literal["relevance", "subtopic_relevance", "rank", "score"] | None):
                The data type to use as targets for a model during fine-tuning. If "relevance" the relevance
                judgements are parsed from qrels. Defaults to None.
            normalize_targets (bool): Whether to min-max normalize the targets of each sample to [0, 1]. If all
                targets of a sample are equal they are all mapped to 0. Defaults to False.
            add_docs_not_in_ranking (bool): Whether to add relevant documents to a sample that are in the qrels but not
                in the ranking. Only takes effect when documents are not loaded from a "text" column of the run.
                Defaults to False.
        """
        self.run_path = None
        if Path(run_path_or_id).is_file():
            self.run_path = Path(run_path_or_id)
            dataset = self.run_path.name.split(".")[0].split("__")[-1]
        else:
            dataset = str(run_path_or_id)
        super().__init__(dataset)
        self.depth = depth
        self.sample_size = sample_size
        self.sampling_strategy = sampling_strategy
        self.targets = targets
        self.normalize_targets = normalize_targets
        self.add_docs_not_in_ranking = add_docs_not_in_ranking

        # a depth of -1 keeps the full ranking, so only a finite depth can be exceeded by the sample size
        if self.sampling_strategy == "top" and self.depth != -1 and self.sample_size > self.depth:
            warnings.warn(
                "Sample size is greater than depth and top sampling strategy is used. "
                "This can cause documents to be sampled that are not contained "
                "in the run file, but that are present in the qrels."
            )

        self.run: pd.DataFrame | None = None

    def prepare_data(self) -> None:
        """Downloads docs, queries, scoreddocs, and qrels using ir_datasets if needed and available."""
        self.prepare_constituent("docs")
        self.prepare_constituent("queries")
        if self.run_path is None:
            self.prepare_constituent("scoreddocs")
        self.prepare_constituent("qrels")

    def _setup(self):
        """Loads and prepares the run on first use; later calls return immediately."""
        if self.run is not None:
            return
        self.run = self._load_run()
        self.run = self.run.drop_duplicates(["query_id", "doc_id"])

        # note: accessing self.qrels may move a "relevance" column of the run into the qrels
        if self.qrels is not None:
            run_query_ids = pd.Index(self.run["query_id"].drop_duplicates())
            qrels_query_ids = self.qrels.index.get_level_values("query_id").unique()
            query_ids = run_query_ids.intersection(qrels_query_ids)
            # drop queries that have no relevance judgements
            if len(run_query_ids.difference(qrels_query_ids)):
                self.run = self.run[self.run["query_id"].isin(query_ids)]
            # outer join if docs are from ir_datasets else only keep docs in run
            how = "left"
            if self._docs is None and self.add_docs_not_in_ranking:
                how = "outer"
            # qrels columns are (value, iteration) pairs; dropping the value level yields relevance_<iteration>
            self.run = self.run.merge(
                self.qrels.loc[pd.IndexSlice[query_ids, :]].droplevel(0, axis=1).add_prefix("relevance_", axis=1),
                on=["query_id", "doc_id"],
                how=how,
            )

        if self.sample_size != -1:
            # drop queries with fewer documents than the requested sample size
            num_docs_per_query = self.run.groupby("query_id").transform("size")
            self.run = self.run[num_docs_per_query >= self.sample_size]

        self.run = self.run.sort_values(["query_id", "rank"])
        self.run_groups = self.run.groupby("query_id")
        self.query_ids = list(self.run_groups.groups.keys())

        if self.depth != -1 and self.run["rank"].max() < self.depth:
            warnings.warn("Depth is greater than the maximum rank in the run file.")

    @staticmethod
    def _load_csv(path: Path) -> pd.DataFrame:
        # whitespace-separated run file; only the first five RUN_HEADER columns are read
        return pd.read_csv(
            path,
            sep=r"\s+",
            header=None,
            names=RUN_HEADER,
            usecols=[0, 1, 2, 3, 4],
            dtype={"query_id": str, "doc_id": str},
            quoting=csv.QUOTE_NONE,
            na_filter=False,
        )

    @staticmethod
    def _load_parquet(path: Path) -> pd.DataFrame:
        return pd.read_parquet(path)

    @staticmethod
    def _load_json(path: Path) -> pd.DataFrame:
        kwargs: Dict[str, Any] = {}
        if ".jsonl" in path.suffixes:
            kwargs["lines"] = True
            kwargs["orient"] = "records"
        run = pd.read_json(path, **kwargs)
        return run

    def _get_run_path(self) -> Path | None:
        """Returns the run file path, or the ir_datasets scoreddocs file path if no run file was given.

        Returns None if the scoreddocs handler does not expose a file path.

        Raises:
            ValueError: If neither a run file nor a dataset with scoreddocs is available.
        """
        run_path = self.run_path
        if run_path is None:
            ir_dataset = self.ir_dataset
            if ir_dataset is None or not ir_dataset.has_scoreddocs():
                raise ValueError(f"Run file or dataset with scoreddocs required. Got {self._dataset}")
            try:
                run_path = ir_dataset.scoreddocs_handler().scoreddocs_path()
            except NotImplementedError:
                pass
        return run_path

    def _clean_run(self, run: pd.DataFrame) -> pd.DataFrame:
        """Normalizes column names and dtypes, extracts inline queries/documents, and applies the depth cut-off."""
        run = run.rename(
            {"qid": "query_id", "docid": "doc_id", "docno": "doc_id", "Q0": "iteration", "q0": "iteration"},
            axis=1,
        )
        dtypes = {"rank": np.int32, "query_id": str, "doc_id": str}
        if "score" in run.columns:
            dtypes["score"] = np.float32
        run = run.astype(dtypes)
        # query texts stored in the run take precedence over ir_datasets queries
        if "query" in run.columns:
            self._queries = run.drop_duplicates("query_id").set_index("query_id")["query"].rename("text")
            run = run.drop("query", axis=1)
        # document texts stored in the run take precedence over the ir_datasets docs store
        if "text" in run.columns:
            self._docs = run.set_index("doc_id")["text"].map(lambda x: GenericDoc("", x)).to_dict()
            run = run.drop("text", axis=1)
        if self.depth != -1:
            run = run[run["rank"] <= self.depth]
        return run

    def _load_run(self) -> pd.DataFrame:
        """Loads the run from a file (chosen by its first suffix) or, failing that, from ir_datasets scoreddocs.

        Raises:
            ValueError: If no run could be loaded. The error from the file loader, if any, is chained as the cause.
        """
        suffix_load_map = {
            ".tsv": self._load_csv,
            ".run": self._load_csv,
            ".csv": self._load_csv,
            ".parquet": self._load_parquet,
            ".json": self._load_json,
            ".jsonl": self._load_json,
        }
        run = None
        load_error: Exception | None = None

        # try loading run from file
        run_path = self._get_run_path()
        if run_path is not None:
            suffixes = Path(run_path).suffixes
            # only the first suffix selects the loader (e.g. ".run.gz" selects the ".run" loader);
            # paths without any suffix have no loader
            load_func = suffix_load_map.get(suffixes[0]) if suffixes else None
            if load_func is not None:
                try:
                    run = load_func(run_path)
                except Exception as error:
                    # fall back to ir_datasets below; keep the error to report it if that also fails
                    load_error = error

        # try loading run from ir_datasets
        if run is None and self.ir_dataset is not None and self.ir_dataset.has_scoreddocs():
            run = pd.DataFrame(self.ir_dataset.scoreddocs_iter())
            run["rank"] = run.groupby("query_id")["score"].rank("first", ascending=False)
            run = run.sort_values(["query_id", "rank"])

        if run is None:
            raise ValueError("Invalid run file format.") from load_error

        return self._clean_run(run)

    @property
    def qrels(self) -> pd.DataFrame | None:
        """The qrels in the dataset. If the dataset does not contain qrels, the qrels are None.

        If the loaded run has a "relevance" column, the qrels are built from it. That column, and the
        "iteration" column, are then removed from the run.

        Returns:
            pd.DataFrame | None: Qrels.
        """
        if self._qrels is not None:
            return self._qrels
        if self.run is not None and "relevance" in self.run:
            qrels = self.run[["query_id", "doc_id", "relevance"]].copy()
            if "iteration" in self.run:
                qrels["iteration"] = self.run["iteration"]
            else:
                qrels["iteration"] = "0"
            self.run = self.run.drop(["relevance", "iteration"], axis=1, errors="ignore")
            qrels = qrels.drop_duplicates(["query_id", "doc_id", "iteration"])
            qrels = qrels.set_index(["query_id", "doc_id", "iteration"]).unstack(level=-1)
            self._qrels = qrels
            return self._qrels
        return super().qrels

    def __len__(self) -> int:
        """Number of queries in the dataset.

        Returns:
            int: Number of queries.
        """
        self._setup()
        return len(self.query_ids)

    def __getitem__(self, idx: int) -> RankSample:
        """Samples a single query and corresponding ranked documents from the run. The documents are sampled according
        to the sampling strategy and sample size.

        Args:
            idx (int): Index of the query.
        Returns:
            RankSample: Sampled query and documents.
        Raises:
            ValueError: If the targets are not found in the run file.
        """
        self._setup()
        query_id = str(self.query_ids[idx])
        group = self.run_groups.get_group(query_id).copy()
        query = self.queries[query_id]
        group = Sampler.sample(group, self.sample_size, self.sampling_strategy)

        doc_ids = tuple(group["doc_id"])
        docs = tuple(self.docs.get(doc_id).default_text() for doc_id in doc_ids)

        targets = None
        if self.targets is not None:
            filtered = group.set_index("doc_id").loc[list(doc_ids)].filter(like=self.targets).fillna(0)
            if filtered.empty:
                raise ValueError(f"targets `{self.targets}` not found in run file")
            targets = torch.from_numpy(filtered.values)
            if self.targets == "rank":
                # invert ranks to be higher is better (necessary for loss functions);
                # with depth -1 this yields -rank, which is still higher-is-better
                targets = self.depth - targets + 1
            if self.normalize_targets:
                targets_min = targets.min()
                span = targets.max() - targets_min
                # if all targets are equal the span is 0 and min-max scaling would give NaN (0 / 0);
                # dividing by 1 instead maps every target to 0
                span = torch.where(span > 0, span, torch.ones_like(span))
                targets = (targets - targets_min) / span
        qrels = None
        if self.qrels is not None:
            qrels = (
                self.qrels.loc[[query_id]]
                .stack(future_stack=True)
                .dropna()
                .astype(int)
                .reset_index()
                .to_dict(orient="records")
            )
        return RankSample(query_id, query, doc_ids, docs, targets, qrels)
653
654
class TupleDataset(IRDataset, IterableDataset):
    """Iterable dataset of training tuples built from the docpairs of an ir_datasets dataset."""

    def __init__(
        self,
        tuples_dataset: str,
        targets: Literal["order", "score"] = "order",
        num_docs: int | None = None,
    ) -> None:
        """Dataset containing tuples of a query and n-documents. Used for fine-tuning models on ranking tasks.

        Args:
            tuples_dataset (str): Path to file containing tuples or valid ir_datasets id.
            targets (Literal["order", "score"], optional): Data type to use as targets for a model during fine-tuning.
                Defaults to "order".
            num_docs (int | None, optional): Maximum number of documents per query. Only applied to
                ScoredDocTuple samples. Defaults to None.
        """
        super().__init__(tuples_dataset)
        super(IRDataset, self).__init__()
        self.targets = targets
        self.num_docs = num_docs

    def _parse_sample(
        self, sample: ScoredDocTuple | GenericDocPair
    ) -> Tuple[Tuple[str, ...], Tuple[str, ...], Tuple[float, ...] | None]:
        """Extracts document ids, document texts and targets from a docpairs sample.

        Args:
            sample (ScoredDocTuple | GenericDocPair): Sample from the docpairs iterator.
        Returns:
            Tuple[Tuple[str, ...], Tuple[str, ...], Tuple[float, ...] | None]: Document ids, document texts, targets.
        Raises:
            ValueError: If score targets are requested but unavailable, the targets type is invalid, or the
                sample type is unsupported.
        """
        if isinstance(sample, GenericDocPair):
            if self.targets == "score":
                raise ValueError("ScoredDocTuple required for score targets.")
            # doc pairs carry no scores: the first document gets target 1.0, the second 0.0
            targets = (1.0, 0.0)
            doc_ids = (sample.doc_id_a, sample.doc_id_b)
        elif isinstance(sample, ScoredDocTuple):
            doc_ids = sample.doc_ids[: self.num_docs]
            if self.targets == "score":
                if sample.scores is None:
                    raise ValueError("tuples dataset does not contain scores")
                targets = sample.scores
            elif self.targets == "order":
                # the first document of the tuple gets target 1.0, all others 0.0
                targets = tuple([1.0] + [0.0] * (sample.num_docs - 1))
            else:
                raise ValueError(f"invalid value for targets, got {self.targets}, expected one of (order, score)")
            targets = targets[: self.num_docs]
        else:
            raise ValueError("Invalid sample type.")
        docs = tuple(self.docs.get(doc_id).default_text() for doc_id in doc_ids)
        return doc_ids, docs, targets

    def __iter__(self) -> Iterator[RankSample]:
        """Iterates over tuples in the dataset.

        Yields:
            RankSample: Sampled query and documents with targets.
        Raises:
            ValueError: If the dataset is not found in ir_datasets.
        """
        ir_dataset = self.ir_dataset
        if ir_dataset is None:
            raise ValueError(f"Unable to find dataset {self.dataset} in ir-datasets")
        # NOTE(review): unlike QueryDataset/DocDataset this iterator is not split by worker or rank, so with
        # several data loader workers each worker may iterate over all tuples -- confirm this is intended.
        for sample in ir_dataset.docpairs_iter():
            query_id = sample.query_id
            query = self.queries.loc[query_id]
            doc_ids, docs, targets = self._parse_sample(sample)
            if targets is not None:
                targets = torch.tensor(targets)
            yield RankSample(query_id, query, doc_ids, docs, targets)

    def prepare_data(self) -> None:
        """Downloads tuples using ir_datasets if needed."""
        self.prepare_constituent("docs")
        self.prepare_constituent("queries")
        self.prepare_constituent("docpairs")