"""
Datasets for Lightning IR that handle data loading and sampling.

This module defines several datasets that handle loading and sampling data for training and inference.
"""
6
7import csv
8import warnings
9from itertools import islice
10from pathlib import Path
11from typing import Any, Dict, Iterator, Literal, Sequence, Tuple
12
13import ir_datasets
14import numpy as np
15import pandas as pd
16import torch
17from ir_datasets.formats import GenericDoc, GenericDocPair
18from torch.distributed import get_rank, get_world_size
19from torch.utils.data import Dataset, IterableDataset, get_worker_info
20
21from .data import DocSample, QuerySample, RankSample
22from .external_datasets.ir_datasets_utils import ScoredDocTuple
23
# Column names assigned to whitespace-separated run files (see ``RunDataset._load_csv``).
RUN_HEADER = [
    "query_id",
    "q0",
    "doc_id",
    "rank",
    "score",
    "system",
]
25
26
27class _DummyIterableDataset(IterableDataset):
28 """Dummy iterable dataset to use when all inference datasets are skipped."""
29
30 def __iter__(self) -> Iterator:
31 yield from iter([])
32
33
class IRDataset:
    """Base class that lazily exposes the queries, documents, and qrels of an ``ir_datasets`` dataset."""

    _SKIP: bool = False
    """Set to True to skip the dataset during inference."""

    def __init__(self, dataset: str) -> None:
        """Stores the dataset name and initializes the lazily filled caches.

        :param dataset: Dataset name. Names with dashes are mapped to registered ir_datasets ids by
            :attr:`dataset`
        :type dataset: str
        """
        super().__init__()
        self._dataset = dataset
        # Caches filled on first access of the corresponding property.
        self._queries = None
        self._docs = None
        self._qrels = None

    @property
    def dataset(self) -> str:
        """Dataset name.

        :return: Dataset name
        :rtype: str
        """
        return self.DASHED_DATASET_MAP.get(self._dataset, self._dataset)

    @property
    def dataset_id(self) -> str:
        """Dataset id.

        :return: Dataset id
        :rtype: str
        """
        # Resolve the property once; every access re-runs the ir_datasets lookup.
        ir_dataset = self.ir_dataset
        if ir_dataset is None:
            return self.dataset
        return ir_dataset.dataset_id()

    @property
    def docs_dataset_id(self) -> str:
        """ID of the dataset containing the documents.

        :return: Document dataset id
        :rtype: str
        """
        return ir_datasets.docs_parent_id(self.dataset_id)

    @property
    def ir_dataset(self) -> ir_datasets.Dataset | None:
        """Instance of ir_datasets.Dataset.

        :return: ir_datasets dataset, or None if the name is not registered in ir_datasets
        :rtype: ir_datasets.Dataset | None
        """
        try:
            return ir_datasets.load(self.dataset)
        except KeyError:
            return None

    @property
    def DASHED_DATASET_MAP(self) -> Dict[str, str]:
        """Map of dataset names with dashes to dataset names with slashes.

        :return: Dataset map
        :rtype: Dict[str, str]
        """
        return {dataset.replace("/", "-"): dataset for dataset in ir_datasets.registry._registered}

    @property
    def queries(self) -> pd.Series:
        """Queries in the dataset.

        :raises ValueError: If no queries are found in the dataset
        :return: Queries
        :rtype: pd.Series
        """
        if self._queries is None:
            ir_dataset = self.ir_dataset
            if ir_dataset is None:
                raise ValueError(f"Unable to find dataset {self.dataset} in ir-datasets")
            queries_iter = ir_dataset.queries_iter()
            self._queries = pd.Series(
                {query.query_id: query.default_text() for query in queries_iter},
                name="text",
            )
            self._queries.index.name = "query_id"
        return self._queries

    @property
    def docs(self) -> ir_datasets.indices.Docstore | Dict[str, GenericDoc]:
        """Documents in the dataset.

        :raises ValueError: If no documents are found in the dataset
        :return: Documents
        :rtype: ir_datasets.indices.Docstore | Dict[str, GenericDoc]
        """
        if self._docs is None:
            ir_dataset = self.ir_dataset
            if ir_dataset is None:
                raise ValueError(f"Unable to find dataset {self.dataset} in ir-datasets")
            self._docs = ir_dataset.docs_store()
        return self._docs

    @property
    def qrels(self) -> pd.DataFrame | None:
        """Qrels in the dataset.

        The frame is indexed by ``(query_id, doc_id)`` and has one column per iteration.

        :return: Qrels, or None if the dataset is unknown or has no qrels
        :rtype: pd.DataFrame | None
        """
        if self._qrels is not None:
            return self._qrels
        ir_dataset = self.ir_dataset
        if ir_dataset is None or not ir_dataset.has_qrels():
            return None
        qrels = pd.DataFrame(ir_dataset.qrels_iter()).rename({"subtopic_id": "iteration"}, axis=1)
        if "iteration" not in qrels.columns:
            qrels["iteration"] = 0
        qrels = qrels.drop_duplicates(["query_id", "doc_id", "iteration"])
        # Pivot the iteration level into the columns.
        qrels = qrels.set_index(["query_id", "doc_id", "iteration"]).unstack(level=-1)
        self._qrels = qrels
        return self._qrels

    def prepare_constituent(self, constituent: Literal["qrels", "queries", "docs", "scoreddocs", "docpairs"]) -> None:
        """Downloads the constituent of the dataset using ir_datasets if needed.

        Does nothing if the dataset is not known to ir_datasets or does not provide the constituent.

        :param constituent: Constituent to download
        :type constituent: Literal["qrels", "queries", "docs", "scoreddocs", "docpairs"]
        """
        ir_dataset = self.ir_dataset
        if ir_dataset is None or not ir_dataset.has(constituent):
            return
        if constituent == "docs" and hasattr(ir_dataset, "docs_store"):
            # build docs store if not already built
            docs_store = ir_dataset.docs_store()
            if not docs_store.built():
                docs_store.build()
        else:
            # get first item to trigger download; the default keeps an empty
            # constituent from raising StopIteration
            next(getattr(ir_dataset, f"{constituent}_iter")(), None)
165
166
167class _DataParallelIterableDataset(IterableDataset):
168 # https://github.com/Lightning-AI/pytorch-lightning/issues/15734
169 def __init__(self) -> None:
170 super().__init__()
171 # TODO add support for multi-gpu and multi-worker inference; currently
172 # doesn't work
173 worker_info = get_worker_info()
174 num_workers = worker_info.num_workers if worker_info is not None else 1
175 worker_id = worker_info.id if worker_info is not None else 0
176
177 try:
178 world_size = get_world_size()
179 process_rank = get_rank()
180 except (RuntimeError, ValueError):
181 world_size = 1
182 process_rank = 0
183
184 self.num_replicas = num_workers * world_size
185 self.rank = process_rank * num_workers + worker_id
186
187
class QueryDataset(IRDataset, _DataParallelIterableDataset):
    def __init__(self, query_dataset: str, num_queries: int | None = None) -> None:
        """Dataset containing queries.

        :param query_dataset: Path to file containing queries or valid ir_datasets id
        :type query_dataset: str
        :param num_queries: Number of queries in dataset. If None, an attempt is made to infer the number of
            queries, defaults to None
        :type num_queries: int | None, optional
        """
        super().__init__(query_dataset)
        # explicitly run the initializer of the next class after IRDataset in the MRO
        super(IRDataset, self).__init__()
        self.num_queries = num_queries

    def __len__(self) -> int | None:
        """Number of queries in the dataset. Returns None if the number of queries cannot be inferred.

        :return: Number of queries
        :rtype: int | None
        """
        # TODO fix len for multi-gpu and multi-worker inference
        return self.num_queries or getattr(self.ir_dataset, "queries_count", lambda: None)() or None

    def __iter__(self) -> Iterator[QuerySample]:
        """Iterate over queries in the dataset.

        Every ``num_replicas``-th query starting at ``rank`` is yielded. If the dataset has qrels, the
        judgments of a query are attached to its sample; a query without any judgment gets an empty list.

        :yield: Query sample
        :rtype: Iterator[QuerySample]
        """
        start = self.rank
        stop = self.num_queries
        step = self.num_replicas
        # Look the qrels up once: the property re-resolves the ir_datasets dataset on every
        # access while no qrels are cached, which is wasteful inside the loop.
        qrels = self.qrels
        for sample in islice(self.ir_dataset.queries_iter(), start, stop, step):
            query_sample = QuerySample.from_ir_dataset_sample(sample)
            if qrels is not None:
                if query_sample.query_id in qrels.index:
                    query_sample.qrels = (
                        qrels.loc[[query_sample.query_id]]
                        .stack(future_stack=True)
                        .dropna()
                        .astype(int)
                        .reset_index()
                        .to_dict(orient="records")
                    )
                else:
                    # unjudged query: selecting a missing label with ``.loc`` would fail
                    query_sample.qrels = []
            yield query_sample

    def prepare_data(self) -> None:
        """Downloads queries using ir_datasets if needed."""
        self.prepare_constituent("queries")
237
238
class DocDataset(IRDataset, _DataParallelIterableDataset):
    def __init__(self, doc_dataset: str, num_docs: int | None = None, text_fields: Sequence[str] | None = None) -> None:
        """Dataset containing documents.

        :param doc_dataset: Path to file containing documents or valid ir_datasets id
        :type doc_dataset: str
        :param num_docs: Number of documents in dataset. If None, an attempt is made to infer the number of
            documents, defaults to None
        :type num_docs: int | None, optional
        :param text_fields: Fields to parse the document text from, defaults to None
        :type text_fields: Sequence[str] | None, optional
        """
        super().__init__(doc_dataset)
        # explicitly run the initializer of the next class after IRDataset in the MRO
        super(IRDataset, self).__init__()
        self.num_docs = num_docs
        self.text_fields = text_fields

    def __len__(self) -> int | None:
        """Number of documents in the dataset. Returns None if the number of documents cannot be inferred.

        :return: Number of documents
        :rtype: int | None
        """
        # TODO fix len for multi-gpu and multi-worker inference
        if self.num_docs:
            return self.num_docs
        count_docs = getattr(self.ir_dataset, "docs_count", lambda: None)
        return count_docs() or None

    def __iter__(self) -> Iterator[DocSample]:
        """Iterate over documents in the dataset.

        Every ``num_replicas``-th document starting at ``rank`` is yielded, up to ``num_docs``.

        :yield: Doc sample
        :rtype: Iterator[DocSample]
        """
        shard = slice(self.rank, self.num_docs, self.num_replicas)
        for raw_doc in islice(self.ir_dataset.docs_iter(), shard.start, shard.stop, shard.step):
            yield DocSample.from_ir_dataset_sample(raw_doc, self.text_fields)

    def prepare_data(self) -> None:
        """Downloads documents using ir_datasets if needed."""
        self.prepare_constituent("docs")
280
281
class Sampler:
    """Helper class for sampling subsets of documents from a ranked list."""

    # Names accepted by :meth:`sample`. An explicit list is used instead of ``hasattr`` so that other
    # attributes of the class (e.g. ``sample`` itself or dunder attributes) are rejected.
    _STRATEGIES = ("single_relevant", "top", "random", "log_random", "top_and_random")

    @staticmethod
    def single_relevant(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to randomly sample a single relevant document. The remaining ``sample_size - 1``
        are non-relevant.

        A document counts as relevant if any column whose name contains ``relevance`` is greater than 0.
        Non-relevant documents are only drawn from documents that have a rank.

        :param documents: Ranked list of documents
        :type documents: pd.DataFrame
        :param sample_size: Number of documents to sample
        :type sample_size: int
        :raises ValueError: If the ranked list does not contain a relevant document
        :return: Sampled documents
        :rtype: pd.DataFrame
        """
        relevance = documents.filter(like="relevance").max(axis=1).fillna(0)
        relevant_pool = documents.loc[relevance.gt(0)]
        if relevant_pool.empty:
            raise ValueError("No relevant document found to sample from.")
        relevant = relevant_pool.sample(1)
        non_relevant_bool = relevance.eq(0) & ~documents["rank"].isna()
        # never request more non-relevant documents than are available
        sample_non_relevant = min(sample_size - 1, non_relevant_bool.sum())
        non_relevant = documents.loc[non_relevant_bool].sample(sample_non_relevant)
        return pd.concat([relevant, non_relevant])

    @staticmethod
    def top(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to select the first ``sample_size`` documents of the ranked list.

        :param documents: Ranked list of documents
        :type documents: pd.DataFrame
        :param sample_size: Number of documents to sample
        :type sample_size: int
        :return: Sampled documents
        :rtype: pd.DataFrame
        """
        return documents.head(sample_size)

    @staticmethod
    def top_and_random(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to take half of the ``sample_size`` documents from the top of the ranking and to
        sample the other half randomly from the remaining documents.

        :param documents: Ranked list of documents
        :type documents: pd.DataFrame
        :param sample_size: Number of documents to sample
        :type sample_size: int
        :return: Sampled documents
        :rtype: pd.DataFrame
        """
        top_size = sample_size // 2
        random_size = sample_size - top_size
        top = documents.head(top_size)
        rest = documents.iloc[top_size:].sample(random_size)
        return pd.concat([top, rest])

    @staticmethod
    def random(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to randomly sample ``sample_size`` documents.

        :param documents: Ranked list of documents
        :type documents: pd.DataFrame
        :param sample_size: Number of documents to sample
        :type sample_size: int
        :return: Sampled documents
        :rtype: pd.DataFrame
        """
        return documents.sample(sample_size)

    @staticmethod
    def log_random(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to randomly sample documents with a higher probability to sample documents from the top of
        the ranking.

        The sampling weight of a document is ``1 / log(1 + rank)``. Documents without a rank receive the
        smallest weight.

        :param documents: Ranked list of documents
        :type documents: pd.DataFrame
        :param sample_size: Number of documents to sample
        :type sample_size: int
        :return: Sampled documents
        :rtype: pd.DataFrame
        """
        weights = 1 / np.log1p(documents["rank"])
        weights = weights.fillna(weights.min())
        return documents.sample(sample_size, weights=weights)

    @staticmethod
    def sample(
        df: pd.DataFrame,
        sample_size: int,
        sampling_strategy: Literal["single_relevant", "top", "random", "log_random", "top_and_random"],
    ) -> pd.DataFrame:
        """
        Samples a subset of documents from a ranked list given a sampling_strategy.

        :param df: Ranked list of documents
        :type df: pd.DataFrame
        :param sample_size: Number of documents to sample. If -1 the ranked list is returned unchanged
        :type sample_size: int
        :param sampling_strategy: Name of the sampling strategy to apply
        :type sampling_strategy: Literal['single_relevant', 'top', 'random', 'log_random', 'top_and_random']
        :raises ValueError: If the sampling strategy is unknown
        :return: Sampled documents
        :rtype: pd.DataFrame
        """
        if sample_size == -1:
            return df
        if sampling_strategy not in Sampler._STRATEGIES:
            raise ValueError("Invalid sampling strategy.")
        return getattr(Sampler, sampling_strategy)(df, sample_size)
387
388
class RunDataset(IRDataset, Dataset):
    def __init__(
        self,
        run_path_or_id: Path | str,
        depth: int = -1,
        sample_size: int = -1,
        sampling_strategy: Literal["single_relevant", "top", "random", "log_random", "top_and_random"] = "top",
        targets: Literal["relevance", "subtopic_relevance", "rank", "score"] | None = None,
        normalize_targets: bool = False,
        add_docs_not_in_ranking: bool = False,
    ) -> None:
        """Dataset containing a list of queries with a ranked list of documents per query. Subsets of the ranked list
        can be sampled using different sampling strategies.

        :param run_path_or_id: Path to a run file or valid ir_datasets id
        :type run_path_or_id: Path | str
        :param depth: Depth at which to cut off the ranking. If -1 the full ranking is kept, defaults to -1
        :type depth: int, optional
        :param sample_size: The number of documents to sample per query, defaults to -1
        :type sample_size: int, optional
        :param sampling_strategy: The sample strategy to use to sample documents, defaults to "top"
        :type sampling_strategy: Literal['single_relevant', 'top', 'random', 'log_random', 'top_and_random'], optional
        :param targets: The data type to use as targets for a model during fine-tuning. If relevance the relevance
            judgements are parsed from qrels, defaults to None
        :type targets: Literal['relevance', 'subtopic_relevance', 'rank', 'score'] | None, optional
        :param normalize_targets: Whether to normalize the targets between 0 and 1, defaults to False
        :type normalize_targets: bool, optional
        :param add_docs_not_in_ranking: Whether to add relevant documents to a sample that are in the qrels but not
            in the ranking, defaults to False
        :type add_docs_not_in_ranking: bool, optional
        """
        self.run_path = None
        if Path(run_path_or_id).is_file():
            self.run_path = Path(run_path_or_id)
            # dataset name: text before the first "." of the file name, then after its last "__"
            dataset = self.run_path.name.split(".")[0].split("__")[-1]
        else:
            dataset = str(run_path_or_id)
        super().__init__(dataset)
        self.depth = depth
        self.sample_size = sample_size
        self.sampling_strategy = sampling_strategy
        self.targets = targets
        self.normalize_targets = normalize_targets
        self.add_docs_not_in_ranking = add_docs_not_in_ranking

        # depth == -1 keeps the full ranking, so only an explicit depth can be exceeded
        if self.sampling_strategy == "top" and self.depth != -1 and self.sample_size > self.depth:
            warnings.warn(
                "Sample size is greater than depth and top sampling strategy is used. "
                "This can cause documents to be sampled that are not contained "
                "in the run file, but that are present in the qrels."
            )

        # loaded lazily by ``_setup``
        self.run: pd.DataFrame | None = None

    def prepare_data(self) -> None:
        """Downloads docs, queries, scoreddocs, and qrels using ir_datasets if needed and available."""
        self.prepare_constituent("docs")
        self.prepare_constituent("queries")
        if self.run_path is None:
            self.prepare_constituent("scoreddocs")
        self.prepare_constituent("qrels")

    def _setup(self) -> None:
        """Loads the run on first use, joins it with the qrels, and groups it by query."""
        if self.run is not None:
            return
        self.run = self._load_run()
        self.run = self.run.drop_duplicates(["query_id", "doc_id"])

        if self.qrels is not None:
            run_query_ids = pd.Index(self.run["query_id"].drop_duplicates())
            qrels_query_ids = self.qrels.index.get_level_values("query_id").unique()
            query_ids = run_query_ids.intersection(qrels_query_ids)
            # drop queries of the run that have no judgments
            if len(run_query_ids.difference(qrels_query_ids)):
                self.run = self.run[self.run["query_id"].isin(query_ids)]
            # outer join if docs are from ir_datasets else only keep docs in run
            how = "left"
            if self._docs is None and self.add_docs_not_in_ranking:
                how = "outer"
            self.run = self.run.merge(
                self.qrels.loc[pd.IndexSlice[query_ids, :]].droplevel(0, axis=1).add_prefix("relevance_", axis=1),
                on=["query_id", "doc_id"],
                how=how,
            )

        if self.sample_size != -1:
            # only keep queries with at least ``sample_size`` documents
            num_docs_per_query = self.run.groupby("query_id").transform("size")
            self.run = self.run[num_docs_per_query >= self.sample_size]

        self.run = self.run.sort_values(["query_id", "rank"])
        self.run_groups = self.run.groupby("query_id")
        self.query_ids = list(self.run_groups.groups.keys())

        if self.depth != -1 and self.run["rank"].max() < self.depth:
            warnings.warn("Depth is greater than the maximum rank in the run file.")

    @staticmethod
    def _load_csv(path: Path) -> pd.DataFrame:
        """Reads a whitespace-separated run file without header; the first five columns are kept."""
        return pd.read_csv(
            path,
            sep=r"\s+",
            header=None,
            names=RUN_HEADER,
            usecols=[0, 1, 2, 3, 4],
            dtype={"query_id": str, "doc_id": str},
            quoting=csv.QUOTE_NONE,
            na_filter=False,
        )

    @staticmethod
    def _load_parquet(path: Path) -> pd.DataFrame:
        """Reads a parquet run file and renames common id column aliases to ``query_id`` / ``doc_id``."""
        return pd.read_parquet(path).rename(
            {
                "qid": "query_id",
                "docid": "doc_id",
                "docno": "doc_id",
            },
            axis=1,
        )

    @staticmethod
    def _load_json(path: Path) -> pd.DataFrame:
        """Reads a JSON or JSON-lines run file, parsing the id columns as strings."""
        kwargs: Dict[str, Any] = {}
        if ".jsonl" in path.suffixes:
            kwargs["lines"] = True
            kwargs["orient"] = "records"
        run = pd.read_json(
            path,
            **kwargs,
            dtype={
                "query_id": str,
                "qid": str,
                "doc_id": str,
                "docid": str,
                "docno": str,
            },
        )
        return run

    def _get_run_path(self) -> Path | None:
        """Returns the run file path, falling back to the scoreddocs file of the ir_datasets dataset.

        :raises ValueError: If no run file was given and the dataset does not provide scoreddocs
        :return: Path to the run file, or None if the scoreddocs handler does not expose a path
        :rtype: Path | None
        """
        run_path = self.run_path
        if run_path is None:
            ir_dataset = self.ir_dataset
            if ir_dataset is None or not ir_dataset.has_scoreddocs():
                raise ValueError(f"Run file or dataset with scoreddocs required. Got {self._dataset}")
            try:
                run_path = ir_dataset.scoreddocs_handler().scoreddocs_path()
            except NotImplementedError:
                pass
        return run_path

    def _clean_run(self, run: pd.DataFrame) -> pd.DataFrame:
        """Normalizes column names, extracts inline query and document texts, and applies the depth cut-off."""
        run = run.rename(
            {"qid": "query_id", "docid": "doc_id", "docno": "doc_id", "Q0": "iteration", "q0": "iteration"},
            axis=1,
        )
        if "query" in run.columns:
            # the run carries the query texts itself
            self._queries = run.drop_duplicates("query_id").set_index("query_id")["query"].rename("text")
            run = run.drop("query", axis=1)
        if "text" in run.columns:
            # the run carries the document texts itself
            self._docs = run.set_index("doc_id")["text"].map(lambda x: GenericDoc("", x)).to_dict()
            run = run.drop("text", axis=1)
        if self.depth != -1:
            run = run[run["rank"] <= self.depth]
        dtypes = {"rank": np.int32}
        if "score" in run.columns:
            dtypes["score"] = np.float32
        run = run.astype(dtypes)
        return run

    def _load_run(self) -> pd.DataFrame:
        """Loads the run from the run file or, if that is not possible, from the ir_datasets scoreddocs.

        :raises ValueError: If the run can neither be loaded from a file nor from ir_datasets
        :return: Cleaned run
        :rtype: pd.DataFrame
        """
        suffix_load_map = {
            ".tsv": self._load_csv,
            ".run": self._load_csv,
            ".csv": self._load_csv,
            ".parquet": self._load_parquet,
            ".json": self._load_json,
            ".jsonl": self._load_json,
        }
        run = None
        load_error: Exception | None = None

        # try loading run from file
        run_path = self._get_run_path()
        if run_path is not None:
            # only the first suffix selects the loader; a path without suffix has no loader
            suffixes = run_path.suffixes
            load_func = suffix_load_map.get(suffixes[0], None) if suffixes else None
            if load_func is not None:
                try:
                    run = load_func(run_path)
                except Exception as error:
                    # best effort: the ir_datasets fallback below may still succeed, but
                    # the failure should not go unnoticed
                    load_error = error
                    warnings.warn(f"Unable to load run file {run_path}: {error}")

        # try loading run from ir_datasets
        if run is None:
            ir_dataset = self.ir_dataset
            if ir_dataset is not None and ir_dataset.has_scoreddocs():
                run = pd.DataFrame(ir_dataset.scoreddocs_iter())
                run["rank"] = run.groupby("query_id")["score"].rank("first", ascending=False)
                run = run.sort_values(["query_id", "rank"])

        if run is None:
            raise ValueError("Invalid run file format.") from load_error

        run = self._clean_run(run)
        return run

    @property
    def qrels(self) -> pd.DataFrame | None:
        """The qrels in the dataset. If the dataset does not contain qrels, the qrels are None.

        If the loaded run has a ``relevance`` column, the qrels are built from the run and the ``relevance``
        and ``iteration`` columns are removed from the run.

        :return: Qrels
        :rtype: pd.DataFrame | None
        """
        if self._qrels is not None:
            return self._qrels
        if self.run is not None and "relevance" in self.run:
            qrels = self.run[["query_id", "doc_id", "relevance"]].copy()
            if "iteration" in self.run:
                qrels["iteration"] = self.run["iteration"]
            else:
                qrels["iteration"] = "0"
            self.run = self.run.drop(["relevance", "iteration"], axis=1, errors="ignore")
            qrels = qrels.drop_duplicates(["query_id", "doc_id", "iteration"])
            qrels = qrels.set_index(["query_id", "doc_id", "iteration"]).unstack(level=-1)
            self._qrels = qrels
            return self._qrels
        return super().qrels

    def __len__(self) -> int:
        """Number of queries in the dataset.

        :return: Number of queries
        :rtype: int
        """
        self._setup()
        return len(self.query_ids)

    def __getitem__(self, idx: int) -> RankSample:
        """Samples a single query and corresponding ranked documents from the run. The documents are sampled according
        to the sampling strategy and sample size.

        :param idx: Index of the query
        :type idx: int
        :raises ValueError: If the targets are not found in the run file
        :return: Sampled query and documents
        :rtype: RankSample
        """
        self._setup()
        query_id = str(self.query_ids[idx])
        group = self.run_groups.get_group(query_id).copy()
        query = self.queries[query_id]
        group = Sampler.sample(group, self.sample_size, self.sampling_strategy)

        doc_ids = tuple(group["doc_id"])
        docs = tuple(self.docs.get(doc_id).default_text() for doc_id in doc_ids)

        targets = None
        if self.targets is not None:
            # all columns whose name contains the target name, in the order of ``doc_ids``
            filtered = group.set_index("doc_id").loc[list(doc_ids)].filter(like=self.targets).fillna(0)
            if filtered.empty:
                raise ValueError(f"targets `{self.targets}` not found in run file")
            targets = torch.from_numpy(filtered.values)
            if self.targets == "rank":
                # invert ranks to be higher is better (necessary for loss functions)
                targets = self.depth - targets + 1
            if self.normalize_targets:
                targets_min = targets.min()
                value_range = targets.max() - targets_min
                if value_range == 0:
                    # constant targets: divide by 1 instead of 0 so the result is 0 rather than NaN
                    value_range = torch.ones_like(value_range)
                targets = (targets - targets_min) / value_range
        qrels = None
        if self.qrels is not None:
            qrels = (
                self.qrels.loc[[query_id]]
                .stack(future_stack=True)
                .dropna()
                .astype(int)
                .reset_index()
                .to_dict(orient="records")
            )
        return RankSample(query_id, query, doc_ids, docs, targets, qrels)
665
666
class TupleDataset(IRDataset, IterableDataset):
    def __init__(
        self,
        tuples_dataset: str,
        targets: Literal["order", "score"] = "order",
        num_docs: int | None = None,
    ) -> None:
        """Dataset containing tuples of a query and n-documents. Used for fine-tuning models on ranking tasks.

        :param tuples_dataset: Path to file containing tuples or valid ir_datasets id
        :type tuples_dataset: str
        :param targets: The data type to use as targets for a model during fine-tuning, defaults to "order"
        :type targets: Literal["order", "score"], optional
        :param num_docs: Maximum number of documents per query, defaults to None
        :type num_docs: int | None, optional
        """
        super().__init__(tuples_dataset)
        # explicitly run the initializer of the next class after IRDataset in the MRO
        super(IRDataset, self).__init__()
        self.targets = targets
        self.num_docs = num_docs

    def _parse_sample(
        self, sample: ScoredDocTuple | GenericDocPair
    ) -> Tuple[Tuple[str, ...], Tuple[str, ...], Tuple[float, ...] | None]:
        """Extracts the document ids, document texts, and targets of a single tuple sample.

        A document pair gets the targets ``(1.0, 0.0)``. For a scored tuple, ``order`` targets mark the
        first document with 1.0 and all others with 0.0, ``score`` targets are the scores of the tuple;
        ids and targets of a scored tuple are truncated to ``num_docs``.
        """
        if isinstance(sample, GenericDocPair):
            if self.targets == "score":
                raise ValueError("ScoredDocTuple required for score targets.")
            doc_ids = (sample.doc_id_a, sample.doc_id_b)
            targets = (1.0, 0.0)
        elif isinstance(sample, ScoredDocTuple):
            limit = self.num_docs
            doc_ids = sample.doc_ids[:limit]
            if self.targets == "score":
                if sample.scores is None:
                    raise ValueError("tuples dataset does not contain scores")
                full_targets = sample.scores
            elif self.targets == "order":
                full_targets = (1.0,) + (0.0,) * (sample.num_docs - 1)
            else:
                raise ValueError(f"invalid value for targets, got {self.targets}, " "expected one of (order, score)")
            targets = full_targets[:limit]
        else:
            raise ValueError("Invalid sample type.")
        doc_store = self.docs
        docs = tuple(doc_store.get(doc_id).default_text() for doc_id in doc_ids)
        return doc_ids, docs, targets

    def __iter__(self) -> Iterator[RankSample]:
        """Iterates over tuples in the dataset.

        :yield: A single tuple sample
        :rtype: Iterator[RankSample]
        """
        for raw_tuple in self.ir_dataset.docpairs_iter():
            query_id = raw_tuple.query_id
            query = self.queries.loc[query_id]
            doc_ids, docs, raw_targets = self._parse_sample(raw_tuple)
            targets = None if raw_targets is None else torch.tensor(raw_targets)
            yield RankSample(query_id, query, doc_ids, docs, targets)

    def prepare_data(self) -> None:
        """Downloads tuples using ir_datasets if needed."""
        for constituent in ("docs", "queries", "docpairs"):
            self.prepare_constituent(constituent)