1"""
2Datasets for Lightning IR that data loading and sampling.
3
4This module defines several datasets that handle loading and sampling data for training and inference.
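
Example:
    A minimal usage sketch; assumes the ``msmarco-passage/trec-dl-2019/judged`` id resolves via
    ``ir_datasets`` and that this module is importable as ``lightning_ir.data.dataset``::

        from lightning_ir.data.dataset import QueryDataset

        dataset = QueryDataset("msmarco-passage/trec-dl-2019/judged")
        for sample in dataset:  # yields QuerySample objects
            print(sample.query_id)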
5"""
6
7import csv
8import warnings
9from itertools import islice
10from pathlib import Path
11from typing import Any, Dict, Iterator, Literal, Sequence, Tuple
12
13import ir_datasets
14import numpy as np
15import pandas as pd
16import torch
17from ir_datasets.formats import GenericDoc, GenericDocPair
18from torch.distributed import get_rank, get_world_size
19from torch.utils.data import Dataset, IterableDataset, get_worker_info
20
21from .data import DocSample, QuerySample, RankSample
22from .external_datasets.ir_datasets_utils import ScoredDocTuple
23
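# RUN_HEADER mirrors the six whitespace-separated columns of a TREC run file; a hypothetical
# line "q1 Q0 doc7 1 12.3 my-system" parses to query_id="q1", q0="Q0", doc_id="doc7",
# rank=1, score=12.3, system="my-system".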
RUN_HEADER = ["query_id", "q0", "doc_id", "rank", "score", "system"]


class _DummyIterableDataset(IterableDataset):
    """Dummy iterable dataset to use when all inference datasets are skipped."""

    def __iter__(self) -> Iterator:
        yield from iter([])


class IRDataset:

    _SKIP: bool = False
    """Set to True to skip the dataset during inference."""

    def __init__(self, dataset: str) -> None:
        """Initializes a new IRDataset.

        Args:
            dataset (str): Dataset name.
        """
        super().__init__()
        self._dataset = dataset
        self._queries = None
        self._docs = None
        self._qrels = None

    @property
    def dataset(self) -> str:
        """Dataset name.

        Returns:
            str: Dataset name.
        """
        return self.DASHED_DATASET_MAP.get(self._dataset, self._dataset)

    @property
    def dataset_id(self) -> str:
        """Dataset id.

        Returns:
            str: Dataset id.
        """
        if self.ir_dataset is None:
            return self.dataset
        return self.ir_dataset.dataset_id()

    @property
    def docs_dataset_id(self) -> str:
        """ID of the dataset containing the documents.

        Returns:
            str: Document dataset id.
        """
        return ir_datasets.docs_parent_id(self.dataset_id)

    @property
    def ir_dataset(self) -> ir_datasets.Dataset | None:
        """Instance of ir_datasets.Dataset.

        Returns:
            ir_datasets.Dataset | None: Instance of ir_datasets.Dataset or None if the dataset is not found.
        """
        try:
            return ir_datasets.load(self.dataset)
        except KeyError:
            return None

    @property
    def DASHED_DATASET_MAP(self) -> Dict[str, str]:
        """Map of dataset names with dashes to dataset names with slashes.

        Returns:
            Dict[str, str]: Dataset map.
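
        Example:
            Assuming ``msmarco-passage/train`` is registered with ``ir_datasets``, its dashed form maps
            back to the slashed id::

                >>> IRDataset("msmarco-passage-train").dataset
                'msmarco-passage/train'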
98 """
99 return {dataset.replace("/", "-"): dataset for dataset in ir_datasets.registry._registered}
100
    @property
    def queries(self) -> pd.Series:
        """Queries in the dataset.

        Returns:
            pd.Series: Queries.
        Raises:
            ValueError: If no queries are found in the dataset.
        """
        if self._queries is None:
            if self.ir_dataset is None:
                raise ValueError(f"Unable to find dataset {self.dataset} in ir-datasets")
            queries_iter = self.ir_dataset.queries_iter()
            self._queries = pd.Series(
                {query.query_id: query.default_text() for query in queries_iter},
                name="text",
            )
            self._queries.index.name = "query_id"
        return self._queries

    @property
    def docs(self) -> ir_datasets.indices.Docstore | Dict[str, GenericDoc]:
        """Documents in the dataset.

        Returns:
            ir_datasets.indices.Docstore | Dict[str, GenericDoc]: Documents.
        Raises:
            ValueError: If no documents are found in the dataset.
        """
        if self._docs is None:
            if self.ir_dataset is None:
                raise ValueError(f"Unable to find dataset {self.dataset} in ir-datasets")
            self._docs = self.ir_dataset.docs_store()
        return self._docs

    @property
    def qrels(self) -> pd.DataFrame | None:
        """Qrels in the dataset.

        Returns:
            pd.DataFrame | None: Qrels.
        """
        if self._qrels is not None:
            return self._qrels
        if self.ir_dataset is None or not self.ir_dataset.has_qrels():
            return None
        qrels = pd.DataFrame(self.ir_dataset.qrels_iter()).rename({"subtopic_id": "iteration"}, axis=1)
        if "iteration" not in qrels.columns:
            qrels["iteration"] = 0
        qrels = qrels.drop_duplicates(["query_id", "doc_id", "iteration"])
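        # pivot so each (query_id, doc_id) pair is one row; each iteration becomes its
        # own ("relevance", iteration) column after unstacking the innermost index level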
        qrels = qrels.set_index(["query_id", "doc_id", "iteration"]).unstack(level=-1)
        self._qrels = qrels
        return self._qrels

    def prepare_constituent(self, constituent: Literal["qrels", "queries", "docs", "scoreddocs", "docpairs"]) -> None:
        """Downloads the constituent of the dataset using ir_datasets if needed.

        Args:
            constituent (Literal["qrels", "queries", "docs", "scoreddocs", "docpairs"]): Constituent to download.
        """
        if self.ir_dataset is None:
            return
        if self.ir_dataset.has(constituent):
            if constituent == "docs" and hasattr(self.ir_dataset, "docs_store"):
                # build docs store if not already built
                docs_store = self.ir_dataset.docs_store()
                if not docs_store.built():
                    docs_store.build()
            else:
                # get first item to trigger download
                next(getattr(self.ir_dataset, f"{constituent}_iter")())


class _DataParallelIterableDataset(IterableDataset):
    # https://github.com/Lightning-AI/pytorch-lightning/issues/15734
    def __init__(self) -> None:
        super().__init__()
        # TODO add support for multi-gpu and multi-worker inference; currently
        # doesn't work
        worker_info = get_worker_info()
        num_workers = worker_info.num_workers if worker_info is not None else 1
        worker_id = worker_info.id if worker_info is not None else 0

        try:
            world_size = get_world_size()
            process_rank = get_rank()
        except (RuntimeError, ValueError):
            world_size = 1
            process_rank = 0

        self.num_replicas = num_workers * world_size
        self.rank = process_rank * num_workers + worker_id
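
        # Intended sharding (see the TODO above): with world_size=2 processes and
        # num_workers=2 dataloader workers each, num_replicas is 4 and the ranks are
        # 0..3; subclasses then read every num_replicas-th sample starting at rank,
        # e.g. replica 1 reads items 1, 5, 9, ...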


class QueryDataset(IRDataset, _DataParallelIterableDataset):
    def __init__(self, query_dataset: str, num_queries: int | None = None) -> None:
197 """Dataset containing queries.
198
199 Args:
200 query_dataset (str): Path to file containing queries or valid ir_datasets id.
201 num_queries (int | None, optional): Number of queries in dataset. If None, the number of queries will
202 attempted to be inferred. Defaults to None.
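
        Example:
            A minimal sketch (assumes the ``msmarco-passage/trec-dl-2019/judged`` id is available via
            ``ir_datasets``)::

                dataset = QueryDataset("msmarco-passage/trec-dl-2019/judged")
                sample = next(iter(dataset))  # a QuerySample, with qrels attached if available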
203 """
204 super().__init__(query_dataset)
205 super(IRDataset, self).__init__()
206 self.num_queries = num_queries
207
    def __len__(self) -> int | None:
        """Number of queries in the dataset. Returns None if the number of queries cannot be inferred.

        Returns:
            int | None: Number of queries.
        """
        # TODO fix len for multi-gpu and multi-worker inference
        return self.num_queries or getattr(self.ir_dataset, "queries_count", lambda: None)() or None

    def __iter__(self) -> Iterator[QuerySample]:
        """Iterate over queries in the dataset.

        Yields:
            QuerySample: Query sample.
        """
        start = self.rank
        stop = self.num_queries
        step = self.num_replicas
        for sample in islice(self.ir_dataset.queries_iter(), start, stop, step):
            query_sample = QuerySample.from_ir_dataset_sample(sample)
            if self.qrels is not None:
                qrels = (
                    self.qrels.loc[[query_sample.query_id]]
                    .stack(future_stack=True)
                    .dropna()
                    .astype(int)
                    .reset_index()
                    .to_dict(orient="records")
                )
                query_sample.qrels = qrels
            yield query_sample

    def prepare_data(self) -> None:
        """Downloads queries using ir_datasets if needed."""
        self.prepare_constituent("queries")


class DocDataset(IRDataset, _DataParallelIterableDataset):
    def __init__(self, doc_dataset: str, num_docs: int | None = None, text_fields: Sequence[str] | None = None) -> None:
247 """Dataset containing documents.
248
249 Args:
250 doc_dataset (str): Path to file containing documents or valid ir_datasets id.
251 num_docs (int | None, optional): Number of documents in dataset. If None, the number of documents will
252 attempted to be inferred. Defaults to None.
253 text_fields (Sequence[str] | None, optional): Fields to parse the document text from. Defaults to None.
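
        Example:
            A minimal sketch (assumes the ``msmarco-passage`` id is available via ``ir_datasets``)::

                dataset = DocDataset("msmarco-passage")
                sample = next(iter(dataset))  # a DocSample parsed from the document's text fields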
254 """
255 super().__init__(doc_dataset)
256 super(IRDataset, self).__init__()
257 self.num_docs = num_docs
258 self.text_fields = text_fields
259
    def __len__(self) -> int | None:
        """Number of documents in the dataset. Returns None if the number of documents cannot be inferred.

        Returns:
            int | None: Number of documents.
        """
        # TODO fix len for multi-gpu and multi-worker inference
        return self.num_docs or getattr(self.ir_dataset, "docs_count", lambda: None)() or None

    def __iter__(self) -> Iterator[DocSample]:
        """Iterate over documents in the dataset.

        Yields:
            DocSample: Document sample.
        """
        start = self.rank
        stop = self.num_docs
        step = self.num_replicas
        for sample in islice(self.ir_dataset.docs_iter(), start, stop, step):
            yield DocSample.from_ir_dataset_sample(sample, self.text_fields)

    def prepare_data(self) -> None:
        """Downloads documents using ir_datasets if needed."""
        self.prepare_constituent("docs")


class Sampler:
    """Helper class for sampling subsets of documents from a ranked list."""

    @staticmethod
    def single_relevant(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
291 """Sampling strategy to randomly sample a single relevant document. The remaining ``sample_size - 1``
292 are non-relevant.
293
294 Args:
295 documents (pd.DataFrame): Ranked list of documents.
296 sample_size (int): Number of documents to sample.
297 Returns:
298 pd.DataFrame: Sampled documents.
299 """
300 relevance = documents.filter(like="relevance").max(axis=1).fillna(0)
301 relevant = documents.loc[relevance.gt(0)].sample(1)
302 non_relevant_bool = relevance.eq(0) & ~documents["rank"].isna()
303 num_non_relevant = non_relevant_bool.sum()
304 sample_non_relevant = min(sample_size - 1, num_non_relevant)
305 non_relevant = documents.loc[non_relevant_bool].sample(sample_non_relevant)
306 return pd.concat([relevant, non_relevant])
307
    @staticmethod
    def top(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to take the top ``sample_size`` documents of the ranking.

        Args:
            documents (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample.
        Returns:
            pd.DataFrame: Sampled documents.
        """
        return documents.head(sample_size)

    @staticmethod
    def top_and_random(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to take half the ``sample_size`` documents from the top of the ranking and randomly
        sample the other half from the remaining documents.

        Args:
            documents (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample.
        Returns:
            pd.DataFrame: Sampled documents.
        """
        top_size = sample_size // 2
        random_size = sample_size - top_size
        top = documents.head(top_size)
        random = documents.iloc[top_size:].sample(random_size)
        return pd.concat([top, random])

    @staticmethod
    def random(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to randomly sample ``sample_size`` documents.

        Args:
            documents (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample.
        Returns:
            pd.DataFrame: Sampled documents.
        """
        return documents.sample(sample_size)

    @staticmethod
    def log_random(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to randomly sample documents with a higher probability of sampling documents from
        the top of the ranking.

        Args:
            documents (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample.
        Returns:
            pd.DataFrame: Sampled documents.
        """
        weights = 1 / np.log1p(documents["rank"])
        weights[weights.isna()] = weights.min()
        return documents.sample(sample_size, weights=weights)

    @staticmethod
    def sample(
        df: pd.DataFrame,
        sample_size: int,
        sampling_strategy: Literal["single_relevant", "top", "random", "log_random", "top_and_random"],
    ) -> pd.DataFrame:
        """Samples a subset of documents from a ranked list given a sampling strategy.

        Args:
            df (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample.
            sampling_strategy (Literal["single_relevant", "top", "random", "log_random", "top_and_random"]):
                Sampling strategy to use.
        Returns:
            pd.DataFrame: Sampled documents.
        Raises:
            ValueError: If the sampling strategy is invalid.
        """
        if sample_size == -1:
            return df
        if hasattr(Sampler, sampling_strategy):
            return getattr(Sampler, sampling_strategy)(df, sample_size)
        raise ValueError("Invalid sampling strategy.")


class RunDataset(IRDataset, Dataset):
    def __init__(
        self,
        run_path_or_id: Path | str,
        depth: int = -1,
        sample_size: int = -1,
        sampling_strategy: Literal["single_relevant", "top", "random", "log_random", "top_and_random"] = "top",
        targets: Literal["relevance", "subtopic_relevance", "rank", "score"] | None = None,
        normalize_targets: bool = False,
        add_docs_not_in_ranking: bool = False,
    ) -> None:
398 """Dataset containing a list of queries with a ranked list of documents per query. Subsets of the ranked list
399 can be sampled using different sampling strategies.
400
401 Args:
402 run_path_or_id (Path | str): Path to a run file or valid ir_datasets id.
403 depth (int): Depth at which to cut off the ranking. If -1 the full ranking is kept.
404 Defaults to -1.
405 sample_size (int): The number of documents to sample per query. Defaults to -1.
406 sampling_strategy (Literal["single_relevant", "top", "random", "log_random", "top_and_random"]):
407 The sample strategy to use to sample documents. Defaults to "top".
408 targets (Literal["relevance", "subtopic_relevance", "rank", "score"] | None):
409 The data type to use as targets for a model during fine-tuning. If "relevance" the relevance
410 judgements are parsed from qrels. Defaults to None.
411 normalize_targets (bool): Whether to normalize the targets between 0 and 1. Defaults to False.
412 add_docs_not_in_ranking (bool): Whether to add relevant documents to a sample that are in the qrels but not
413 in the ranking. Defaults to False.
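
        Example:
            A minimal sketch; the dataset id is parsed from the file name, so this assumes a run file named,
            e.g., ``my-system__msmarco-passage-trec-dl-2019-judged.run``::

                dataset = RunDataset("runs/my-system__msmarco-passage-trec-dl-2019-judged.run", sample_size=8)
                sample = dataset[0]  # RankSample with the query and the sampled doc_ids and docs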
414 """
415 self.run_path = None
416 if Path(run_path_or_id).is_file():
417 self.run_path = Path(run_path_or_id)
418 dataset = self.run_path.name.split(".")[0].split("__")[-1]
419 else:
420 dataset = str(run_path_or_id)
421 super().__init__(dataset)
422 self.depth = depth
423 self.sample_size = sample_size
424 self.sampling_strategy = sampling_strategy
425 self.targets = targets
426 self.normalize_targets = normalize_targets
427 self.add_docs_not_in_ranking = add_docs_not_in_ranking
428
429 if self.sampling_strategy == "top" and self.sample_size > self.depth:
430 warnings.warn(
431 "Sample size is greater than depth and top sampling strategy is used. "
432 "This can cause documents to be sampled that are not contained "
433 "in the run file, but that are present in the qrels."
434 )
435
436 self.run: pd.DataFrame | None = None
437
    def prepare_data(self) -> None:
        """Downloads docs, queries, scoreddocs, and qrels using ir_datasets if needed and available."""
        self.prepare_constituent("docs")
        self.prepare_constituent("queries")
        if self.run_path is None:
            self.prepare_constituent("scoreddocs")
        self.prepare_constituent("qrels")

    def _setup(self):
        if self.run is not None:
            return
        self.run = self._load_run()
        self.run = self.run.drop_duplicates(["query_id", "doc_id"])

        if self.qrels is not None:
            run_query_ids = pd.Index(self.run["query_id"].drop_duplicates())
            qrels_query_ids = self.qrels.index.get_level_values("query_id").unique()
            query_ids = run_query_ids.intersection(qrels_query_ids)
            if len(run_query_ids.difference(qrels_query_ids)):
                self.run = self.run[self.run["query_id"].isin(query_ids)]
            # outer join if docs are from ir_datasets else only keep docs in run
            how = "left"
            if self._docs is None and self.add_docs_not_in_ranking:
                how = "outer"
            self.run = self.run.merge(
                self.qrels.loc[pd.IndexSlice[query_ids, :]].droplevel(0, axis=1).add_prefix("relevance_", axis=1),
                on=["query_id", "doc_id"],
                how=how,
            )

        if self.sample_size != -1:
            num_docs_per_query = self.run.groupby("query_id").transform("size")
            self.run = self.run[num_docs_per_query >= self.sample_size]

        self.run = self.run.sort_values(["query_id", "rank"])
        self.run_groups = self.run.groupby("query_id")
        self.query_ids = list(self.run_groups.groups.keys())

        if self.depth != -1 and self.run["rank"].max() < self.depth:
            warnings.warn("Depth is greater than the maximum rank in the run file.")

    @staticmethod
    def _load_csv(path: Path) -> pd.DataFrame:
        return pd.read_csv(
            path,
            sep=r"\s+",
            header=None,
            names=RUN_HEADER,
            usecols=[0, 1, 2, 3, 4],
            dtype={"query_id": str, "doc_id": str},
            quoting=csv.QUOTE_NONE,
            na_filter=False,
        )

    @staticmethod
    def _load_parquet(path: Path) -> pd.DataFrame:
        return pd.read_parquet(path).rename(
            {
                "qid": "query_id",
                "docid": "doc_id",
                "docno": "doc_id",
            },
            axis=1,
        )

    @staticmethod
    def _load_json(path: Path) -> pd.DataFrame:
        kwargs: Dict[str, Any] = {}
        if ".jsonl" in path.suffixes:
            kwargs["lines"] = True
            kwargs["orient"] = "records"
        run = pd.read_json(
            path,
            **kwargs,
            dtype={
                "query_id": str,
                "qid": str,
                "doc_id": str,
                "docid": str,
                "docno": str,
            },
        )
        return run

    def _get_run_path(self) -> Path | None:
        run_path = self.run_path
        if run_path is None:
            if self.ir_dataset is None or not self.ir_dataset.has_scoreddocs():
                raise ValueError(f"Run file or dataset with scoreddocs required. Got {self._dataset}")
            try:
                run_path = self.ir_dataset.scoreddocs_handler().scoreddocs_path()
            except NotImplementedError:
                pass
        return run_path

    def _clean_run(self, run: pd.DataFrame) -> pd.DataFrame:
        run = run.rename(
            {"qid": "query_id", "docid": "doc_id", "docno": "doc_id", "Q0": "iteration", "q0": "iteration"},
            axis=1,
        )
        if "query" in run.columns:
            self._queries = run.drop_duplicates("query_id").set_index("query_id")["query"].rename("text")
            run = run.drop("query", axis=1)
        if "text" in run.columns:
            self._docs = run.set_index("doc_id")["text"].map(lambda x: GenericDoc("", x)).to_dict()
            run = run.drop("text", axis=1)
        if self.depth != -1:
            run = run[run["rank"] <= self.depth]
        dtypes = {"rank": np.int32}
        if "score" in run.columns:
            dtypes["score"] = np.float32
        run = run.astype(dtypes)
        return run

    def _load_run(self) -> pd.DataFrame:
        suffix_load_map = {
            ".tsv": self._load_csv,
            ".run": self._load_csv,
            ".csv": self._load_csv,
            ".parquet": self._load_parquet,
            ".json": self._load_json,
            ".jsonl": self._load_json,
        }
        run = None

        # try loading run from file
        run_path = self._get_run_path()
        if run_path is not None:
            load_func = suffix_load_map.get(run_path.suffixes[0], None)
            if load_func is not None:
                try:
                    run = load_func(run_path)
                except Exception:
                    pass

        # try loading run from ir_datasets
        if run is None and self.ir_dataset is not None and self.ir_dataset.has_scoreddocs():
            run = pd.DataFrame(self.ir_dataset.scoreddocs_iter())
            run["rank"] = run.groupby("query_id")["score"].rank("first", ascending=False)
            run = run.sort_values(["query_id", "rank"])

        if run is None:
            raise ValueError("Invalid run file format.")

        run = self._clean_run(run)
        return run

    @property
    def qrels(self) -> pd.DataFrame | None:
        """The qrels in the dataset. If the dataset does not contain qrels, the qrels are None.

        Returns:
            pd.DataFrame | None: Qrels.
        """
        if self._qrels is not None:
            return self._qrels
        if self.run is not None and "relevance" in self.run:
            qrels = self.run[["query_id", "doc_id", "relevance"]].copy()
            if "iteration" in self.run:
                qrels["iteration"] = self.run["iteration"]
            else:
                qrels["iteration"] = "0"
            self.run = self.run.drop(["relevance", "iteration"], axis=1, errors="ignore")
            qrels = qrels.drop_duplicates(["query_id", "doc_id", "iteration"])
            qrels = qrels.set_index(["query_id", "doc_id", "iteration"]).unstack(level=-1)
            self._qrels = qrels
            return self._qrels
        return super().qrels

    def __len__(self) -> int:
        """Number of queries in the dataset.

        Returns:
            int: Number of queries.
        """
        self._setup()
        return len(self.query_ids)

    def __getitem__(self, idx: int) -> RankSample:
        """Samples a single query and corresponding ranked documents from the run. The documents are sampled
        according to the sampling strategy and sample size.

        Args:
            idx (int): Index of the query.
        Returns:
            RankSample: Sampled query and documents.
        Raises:
            ValueError: If the targets are not found in the run file.
        """
        self._setup()
        query_id = str(self.query_ids[idx])
        group = self.run_groups.get_group(query_id).copy()
        query = self.queries[query_id]
        group = Sampler.sample(group, self.sample_size, self.sampling_strategy)

        doc_ids = tuple(group["doc_id"])
        docs = tuple(self.docs.get(doc_id).default_text() for doc_id in doc_ids)

        targets = None
        if self.targets is not None:
            filtered = group.set_index("doc_id").loc[list(doc_ids)].filter(like=self.targets).fillna(0)
            if filtered.empty:
                raise ValueError(f"targets `{self.targets}` not found in run file")
            targets = torch.from_numpy(filtered.values)
            if self.targets == "rank":
                # invert ranks so that higher is better (necessary for loss functions)
                targets = self.depth - targets + 1
            if self.normalize_targets:
                targets_min = targets.min()
                targets_max = targets.max()
                targets = (targets - targets_min) / (targets_max - targets_min)
        qrels = None
        if self.qrels is not None:
            qrels = (
                self.qrels.loc[[query_id]]
                .stack(future_stack=True)
                .dropna()
                .astype(int)
                .reset_index()
                .to_dict(orient="records")
            )
        return RankSample(query_id, query, doc_ids, docs, targets, qrels)


class TupleDataset(IRDataset, IterableDataset):
    def __init__(
        self,
        tuples_dataset: str,
        targets: Literal["order", "score"] = "order",
        num_docs: int | None = None,
    ) -> None:
670 """Dataset containing tuples of a query and n-documents. Used for fine-tuning models on ranking tasks.
671
672 Args:
673 tuples_dataset (str): Path to file containing tuples or valid ir_datasets id.
674 targets (Literal["order", "score"], optional): Data type to use as targets for a model during fine-tuning.
675 Defaults to "order".
676 num_docs (int | None, optional): Maximum number of documents per query. Defaults to None.
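
        Example:
            A minimal sketch (assumes the ``msmarco-passage/train/triples-small`` id is available via
            ``ir_datasets``)::

                dataset = TupleDataset("msmarco-passage/train/triples-small", targets="order")
                sample = next(iter(dataset))  # RankSample with targets (1.0, 0.0) for the doc pair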
677 """
678 super().__init__(tuples_dataset)
679 super(IRDataset, self).__init__()
680 self.targets = targets
681 self.num_docs = num_docs
682
    def _parse_sample(
        self, sample: ScoredDocTuple | GenericDocPair
    ) -> Tuple[Tuple[str, ...], Tuple[str, ...], Tuple[float, ...] | None]:
        if isinstance(sample, GenericDocPair):
            if self.targets == "score":
                raise ValueError("ScoredDocTuple required for score targets.")
            targets = (1.0, 0.0)
            doc_ids = (sample.doc_id_a, sample.doc_id_b)
        elif isinstance(sample, ScoredDocTuple):
            doc_ids = sample.doc_ids[: self.num_docs]
            if self.targets == "score":
                if sample.scores is None:
                    raise ValueError("tuples dataset does not contain scores")
                targets = sample.scores
            elif self.targets == "order":
                targets = tuple([1.0] + [0.0] * (sample.num_docs - 1))
            else:
                raise ValueError(f"invalid value for targets, got {self.targets}, expected one of (order, score)")
            targets = targets[: self.num_docs]
        else:
            raise ValueError("Invalid sample type.")
        docs = tuple(self.docs.get(doc_id).default_text() for doc_id in doc_ids)
        return doc_ids, docs, targets

    def __iter__(self) -> Iterator[RankSample]:
        """Iterates over tuples in the dataset.

        Yields:
            RankSample: Sampled query and documents with targets.
        """
        for sample in self.ir_dataset.docpairs_iter():
            query_id = sample.query_id
            query = self.queries.loc[query_id]
            doc_ids, docs, targets = self._parse_sample(sample)
            if targets is not None:
                targets = torch.tensor(targets)
            yield RankSample(query_id, query, doc_ids, docs, targets)

    def prepare_data(self) -> None:
        """Downloads docs, queries, and tuples using ir_datasets if needed."""
        self.prepare_constituent("docs")
        self.prepare_constituent("queries")
        self.prepare_constituent("docpairs")