1"""
Datasets for Lightning IR that handle data loading and sampling.
3
4This module defines several datasets that handle loading and sampling data for training and inference.
5"""
6
7import csv
8import warnings
9from collections.abc import Iterator, Sequence
10from itertools import islice
11from pathlib import Path
12from typing import Any, Literal, Self
13
14import ir_datasets
15import numpy as np
16import pandas as pd
17import torch
18from ir_datasets.formats import GenericDoc, GenericDocPair
19from torch.distributed import get_rank, get_world_size
20from torch.utils.data import Dataset, IterableDataset, get_worker_info
21
22from .data import DocSample, QuerySample, RankSample
23from .external_datasets.ir_datasets_utils import ScoredDocTuple
24from .external_datasets.lsr_benchmark import register_lsr_dataset_to_ir_datasets
25
# Column names assigned to header-less, whitespace-separated run files (see ``RunDataset._load_csv``).
RUN_HEADER = ["query_id", "q0", "doc_id", "rank", "score", "system"]
27
28
class _DummyIterableDataset(IterableDataset):
    """Placeholder iterable dataset that yields nothing; used when all inference datasets are skipped."""

    def __iter__(self) -> Iterator:
        # Delegating to an empty iterable makes this a generator that finishes immediately.
        yield from ()
34
35
[docs]
class IRDataset:
    """Base class giving lazy, cached access to the queries, documents and qrels of an ``ir_datasets`` dataset."""

    _SKIP: bool = False
    """set to True to skip the dataset during inference."""

    def __init__(self, dataset: str) -> None:
        """Initializes a new IRDataset.

        Args:
            dataset (str): Dataset name.
        """
        super().__init__()
        self._dataset = dataset
        # caches that are filled on first access of the corresponding property
        self._queries = None
        self._docs = None
        self._qrels = None

    @property
    def dataset(self) -> str:
        """Dataset name.

        A name written with dashes instead of slashes is mapped back to the registered ir_datasets id; any other
        name is returned unchanged.

        Returns:
            str: Dataset name.
        """
        return self.DASHED_DATASET_MAP.get(self._dataset, self._dataset)

    @property
    def dataset_id(self) -> str:
        """Dataset id.

        Returns:
            str: Dataset id as reported by ir_datasets, or the dataset name if the dataset is not found.
        """
        # resolve the dataset once; every access of ``self.ir_dataset`` triggers a new registry lookup
        ir_dataset = self.ir_dataset
        if ir_dataset is None:
            return self.dataset
        return ir_dataset.dataset_id()

    @property
    def docs_dataset_id(self) -> str:
        """ID of the dataset containing the documents.

        Returns:
            str: Document dataset id.
        """
        return ir_datasets.docs_parent_id(self.dataset_id)

    @property
    def dashed_docs_dataset_id(self) -> str:
        """Dataset id with dashes instead of slashes for the documents dataset.

        Returns:
            str: Document dataset id with dashes.
        """
        return self.docs_dataset_id.replace("/", "-")

    @property
    def ir_dataset(self) -> ir_datasets.Dataset | None:
        """Instance of ir_datasets.Dataset.

        Returns:
            ir_datasets.Dataset | None: Instance of ir_datasets.Dataset or None if the dataset is not found.
        """
        dataset = self.dataset
        try:
            return ir_datasets.load(dataset)
        except KeyError:
            # lsr-benchmark datasets are registered on demand; retry once after registering
            if dataset and dataset.startswith("lsr-benchmark/"):
                register_lsr_dataset_to_ir_datasets(dataset)
                try:
                    return ir_datasets.load(dataset)
                except KeyError:
                    return None
            return None

    @property
    def DASHED_DATASET_MAP(self) -> dict[str, str]:
        """Map of dataset names with dashes to dataset names with slashes.

        Returns:
            dict[str, str]: Dataset map.
        """
        return {dataset.replace("/", "-"): dataset for dataset in ir_datasets.registry._registered}

    @property
    def queries(self) -> pd.Series:
        """Queries in the dataset.

        Returns:
            pd.Series: Query texts named ``text`` and indexed by ``query_id``.
        Raises:
            ValueError: If the dataset is not found in ir_datasets.
        """
        if self._queries is None:
            ir_dataset = self.ir_dataset
            if ir_dataset is None:
                raise ValueError(f"Unable to find dataset {self.dataset} in ir-datasets")
            queries_iter = ir_dataset.queries_iter()
            self._queries = pd.Series(
                {query.query_id: query.default_text() for query in queries_iter},
                name="text",
            )
            self._queries.index.name = "query_id"
        return self._queries

    @property
    def docs(self) -> ir_datasets.indices.Docstore | dict[str, GenericDoc]:
        """Documents in the dataset.

        Returns:
            ir_datasets.indices.Docstore | dict[str, GenericDoc]: Documents.
        Raises:
            ValueError: If the dataset is not found in ir_datasets.
        """
        if self._docs is None:
            ir_dataset = self.ir_dataset
            if ir_dataset is None:
                raise ValueError(f"Unable to find dataset {self.dataset} in ir-datasets")
            self._docs = ir_dataset.docs_store()
        return self._docs

    @property
    def qrels(self) -> pd.DataFrame | None:
        """Qrels in the dataset.

        Returns:
            pd.DataFrame | None: Qrels indexed by ``query_id`` and ``doc_id`` with the ``iteration`` level unstacked
                into columns, or None if the dataset is not found or has no qrels.
        """
        if self._qrels is not None:
            return self._qrels
        ir_dataset = self.ir_dataset
        if ir_dataset is None or not ir_dataset.has_qrels():
            return None
        qrels = pd.DataFrame(ir_dataset.qrels_iter()).rename({"subtopic_id": "iteration"}, axis=1)
        if "iteration" not in qrels.columns:
            qrels["iteration"] = 0
        qrels = qrels.drop_duplicates(["query_id", "doc_id", "iteration"])
        qrels = qrels.set_index(["query_id", "doc_id", "iteration"]).unstack(level=-1)
        self._qrels = qrels
        return self._qrels

    def prepare_constituent(self, constituent: Literal["qrels", "queries", "docs", "scoreddocs", "docpairs"]) -> None:
        """Downloads the constituent of the dataset using ir_datasets if needed.

        Does nothing if the dataset is not found or does not provide the constituent.

        Args:
            constituent (Literal["qrels", "queries", "docs", "scoreddocs", "docpairs"]): Constituent to download.
        """
        ir_dataset = self.ir_dataset
        if ir_dataset is None or not ir_dataset.has(constituent):
            return
        if constituent == "docs" and hasattr(ir_dataset, "docs_store"):
            # build docs store if not already built
            docs_store = ir_dataset.docs_store()
            if not docs_store.built():
                docs_store.build()
        else:
            # get first item to trigger download; the default keeps an empty constituent from raising StopIteration
            next(getattr(ir_dataset, f"{constituent}_iter")(), None)
188
189
class _DataParallelIterableDataset(IterableDataset):
    """Iterable dataset base that derives ``num_replicas`` and ``rank`` from the data loader worker info and the
    distributed process group."""

    # https://github.com/Lightning-AI/pytorch-lightning/issues/15734
    def __init__(self) -> None:
        super().__init__()
        # TODO add support for multi-gpu and multi-worker inference; currently
        # doesn't work
        info = get_worker_info()
        if info is None:
            # no worker info available: behave like a single worker
            num_workers, worker_id = 1, 0
        else:
            num_workers, worker_id = info.num_workers, info.id

        try:
            world_size, process_rank = get_world_size(), get_rank()
        except (RuntimeError, ValueError):
            # querying the process group failed: fall back to a single process
            world_size, process_rank = 1, 0

        self.num_replicas = world_size * num_workers
        self.rank = worker_id + process_rank * num_workers
209
210
[docs]
class QueryDataset(IRDataset, _DataParallelIterableDataset):
    def __init__(self, query_dataset: str, num_queries: int | None = None) -> None:
        """Dataset containing queries.

        Args:
            query_dataset (str): Path to file containing queries or valid ir_datasets id.
            num_queries (int | None, optional): Number of queries in dataset. If None, an attempt is made to infer
                the number of queries. Defaults to None.
        """
        super().__init__(query_dataset)
        # IRDataset precedes the iterable base in the MRO; initialize the remaining bases explicitly
        super(IRDataset, self).__init__()
        self.num_queries = num_queries

    def __len__(self) -> int | None:
        """Number of queries in the dataset. Returns None if the number of queries cannot be inferred.

        Returns:
            int | None: Number of queries.
        """
        # TODO fix len for multi-gpu and multi-worker inference
        return self.num_queries or getattr(self.ir_dataset, "queries_count", lambda: None)() or None

    def __iter__(self) -> Iterator[QuerySample]:
        """Iterate over queries in the dataset.

        Every ``num_replicas``-th query starting at ``rank`` is yielded, stopping at ``num_queries`` if it is set.
        Relevance judgements are attached to a query when the dataset has qrels for it.

        Yields:
            QuerySample: Query sample.
        Raises:
            ValueError: If the dataset is not found in ir_datasets.
        """
        ir_dataset = self.ir_dataset
        if ir_dataset is None:
            raise ValueError(f"Unable to find dataset {self.dataset} in ir-datasets")

        # Resolve the qrels once: while a dataset has no qrels the property re-loads the dataset on every access.
        all_qrels = self.qrels
        judged_query_ids: set = set()
        if all_qrels is not None:
            judged_query_ids = set(all_qrels.index.unique(level="query_id"))

        for sample in islice(ir_dataset.queries_iter(), self.rank, self.num_queries, self.num_replicas):
            query_sample = QuerySample.from_ir_dataset_sample(sample)
            # ``.loc`` raises a KeyError for unknown labels, so only look up queries that have judgements
            if all_qrels is not None and query_sample.query_id in judged_query_ids:
                query_sample.qrels = (
                    all_qrels.loc[[query_sample.query_id]]
                    .stack(future_stack=True)
                    .dropna()
                    .astype(int)
                    .reset_index()
                    .to_dict(orient="records")
                )
            yield query_sample

    def prepare_data(self) -> None:
        """Downloads queries using ir_datasets if needed."""
        self.prepare_constituent("queries")
259
260
[docs]
class DocDataset(IRDataset, _DataParallelIterableDataset):
    def __init__(self, doc_dataset: str, num_docs: int | None = None, text_fields: Sequence[str] | None = None) -> None:
        """Dataset containing documents.

        Args:
            doc_dataset (str): Path to file containing documents or valid ir_datasets id.
            num_docs (int | None, optional): Number of documents in dataset. If None, an attempt is made to infer
                the number of documents. Defaults to None.
            text_fields (Sequence[str] | None, optional): Fields to parse the document text from. Defaults to None.
        """
        super().__init__(doc_dataset)
        # IRDataset precedes the iterable base in the MRO; initialize the remaining bases explicitly
        super(IRDataset, self).__init__()
        self.num_docs = num_docs
        self.text_fields = text_fields

    def __len__(self) -> int | None:
        """Number of documents in the dataset. Returns None if the number of documents cannot be inferred.

        Returns:
            int | None: Number of documents.
        """
        # TODO fix len for multi-gpu and multi-worker inference
        return self.num_docs or getattr(self.ir_dataset, "docs_count", lambda: None)() or None

    def __iter__(self) -> Iterator[DocSample]:
        """Iterate over documents in the dataset.

        Every ``num_replicas``-th document starting at ``rank`` is yielded, stopping at ``num_docs`` if it is set.

        Yields:
            DocSample: Document sample.
        Raises:
            ValueError: If the dataset is not found in ir_datasets.
        """
        ir_dataset = self.ir_dataset
        if ir_dataset is None:
            # fail with the same error as the ``docs`` property instead of an AttributeError on None
            raise ValueError(f"Unable to find dataset {self.dataset} in ir-datasets")
        for sample in islice(ir_dataset.docs_iter(), self.rank, self.num_docs, self.num_replicas):
            yield DocSample.from_ir_dataset_sample(sample, self.text_fields)

    def prepare_data(self) -> None:
        """Downloads documents using ir_datasets if needed."""
        self.prepare_constituent("docs")
300
301
[docs]
class Sampler:
    """Helper class for sampling subsets of documents from a ranked list."""

    @staticmethod
    def single_relevant(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to randomly sample a single relevant document. The remaining ``sample_size - 1``
        are non-relevant.

        A document counts as relevant if any of its ``relevance`` columns is greater than 0. Non-relevant documents
        are only drawn from documents that have a rank; fewer than ``sample_size - 1`` are returned if not enough
        are available.

        Args:
            documents (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample.
        Returns:
            pd.DataFrame: Sampled documents, the relevant document first.
        """
        relevance = documents.filter(like="relevance").max(axis=1).fillna(0)
        relevant = documents.loc[relevance.gt(0)].sample(1)
        non_relevant_bool = relevance.eq(0) & ~documents["rank"].isna()
        num_non_relevant = non_relevant_bool.sum()
        sample_non_relevant = min(sample_size - 1, num_non_relevant)
        non_relevant = documents.loc[non_relevant_bool].sample(sample_non_relevant)
        return pd.concat([relevant, non_relevant])

    @staticmethod
    def top(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to select the first ``sample_size`` documents of the ranked list.

        Args:
            documents (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample.
        Returns:
            pd.DataFrame: Sampled documents.
        """
        return documents.head(sample_size)

    @staticmethod
    def top_and_random(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to sample half the ``sample_size`` documents from the top of the ranking and the
        other half randomly from the remaining documents.

        Args:
            documents (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample.
        Returns:
            pd.DataFrame: Sampled documents, the top documents first.
        """
        top_size = sample_size // 2
        # the random part receives the extra document when ``sample_size`` is odd
        random_size = sample_size - top_size
        top = documents.head(top_size)
        random = documents.iloc[top_size:].sample(random_size)
        return pd.concat([top, random])

    @staticmethod
    def random(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to randomly sample ``sample_size`` documents.

        Args:
            documents (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample.
        Returns:
            pd.DataFrame: Sampled documents.
        """
        return documents.sample(sample_size)

    @staticmethod
    def log_random(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to randomly sample documents with a higher probability to sample documents from the top of
        the ranking.

        Args:
            documents (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample.
        Returns:
            pd.DataFrame: Sampled documents.
        """
        weights = 1 / np.log1p(documents["rank"])
        # documents without a rank receive the smallest weight of the ranked documents
        weights[weights.isna()] = weights.min()
        return documents.sample(sample_size, weights=weights)

    @staticmethod
    def sample(
        df: pd.DataFrame,
        sample_size: int,
        sampling_strategy: Literal["single_relevant", "top", "random", "log_random", "top_and_random"],
    ) -> pd.DataFrame:
        """
        Samples a subset of documents from a ranked list given a sampling_strategy.

        Args:
            df (pd.DataFrame): Ranked list of documents.
            sample_size (int): Number of documents to sample. If -1 the ranked list is returned unchanged.
            sampling_strategy (Literal["single_relevant", "top", "random", "log_random", "top_and_random"]):
                Name of the sampling strategy to apply.
        Returns:
            pd.DataFrame: Sampled documents.
        Raises:
            ValueError: If ``sampling_strategy`` does not name a sampling strategy.
        """
        if sample_size == -1:
            return df
        strategy = None
        # A plain ``hasattr`` check would also accept this dispatcher itself and dunder attributes, which then fail
        # with a TypeError instead of the documented ValueError.
        if sampling_strategy != "sample" and not sampling_strategy.startswith("_"):
            strategy = getattr(Sampler, sampling_strategy, None)
        if not callable(strategy):
            raise ValueError("Invalid sampling strategy.")
        return strategy(df, sample_size)
401
402
[docs]
class RunDataset(IRDataset, Dataset):
    def __init__(
        self,
        run_path_or_id: Path | str,
        depth: int = -1,
        sample_size: int = -1,
        sampling_strategy: Literal["single_relevant", "top", "random", "log_random", "top_and_random"] = "top",
        targets: Literal["relevance", "subtopic_relevance", "rank", "score"] | None = None,
        normalize_targets: bool = False,
        add_docs_not_in_ranking: bool = False,
    ) -> None:
        """Dataset containing a list of queries with a ranked list of documents per query. Subsets of the ranked list
        can be sampled using different sampling strategies.

        Args:
            run_path_or_id (Path | str): Path to a run file or valid ir_datasets id.
            depth (int): Depth at which to cut off the ranking. If -1 the full ranking is kept.
                Defaults to -1.
            sample_size (int): The number of documents to sample per query. Defaults to -1.
            sampling_strategy (Literal["single_relevant", "top", "random", "log_random", "top_and_random"]):
                The sample strategy to use to sample documents. Defaults to "top".
            targets (Literal["relevance", "subtopic_relevance", "rank", "score"] | None):
                The data type to use as targets for a model during fine-tuning. If "relevance" the relevance
                judgements are parsed from qrels. Defaults to None.
            normalize_targets (bool): Whether to normalize the targets between 0 and 1. Defaults to False.
            add_docs_not_in_ranking (bool): Whether to add relevant documents to a sample that are in the qrels but not
                in the ranking. Defaults to False.
        """
        self.run_path = None
        if Path(run_path_or_id).is_file():
            self.run_path = Path(run_path_or_id)
            # the dataset name is the part of the file name before the first "." and after the last "__"
            dataset = self.run_path.name.split(".")[0].split("__")[-1]
        else:
            dataset = str(run_path_or_id)
        super().__init__(dataset)
        self.depth = depth
        self.sample_size = sample_size
        self.sampling_strategy = sampling_strategy
        self.targets = targets
        self.normalize_targets = normalize_targets
        self.add_docs_not_in_ranking = add_docs_not_in_ranking

        # depth == -1 means "no cut-off", so it must not be compared numerically against the sample size
        if self.sampling_strategy == "top" and self.depth != -1 and self.sample_size > self.depth:
            warnings.warn(
                "Sample size is greater than depth and top sampling strategy is used. "
                "This can cause documents to be sampled that are not contained "
                "in the run file, but that are present in the qrels.",
                stacklevel=2,
            )

        self.run: pd.DataFrame | None = None

    def prepare_data(self) -> Self:
        """Downloads docs, queries, scoreddocs, and qrels using ir_datasets if needed and available."""
        self.prepare_constituent("docs")
        self.prepare_constituent("queries")
        if self.run_path is None:
            self.prepare_constituent("scoreddocs")
            self.prepare_constituent("docpairs")
        self.prepare_constituent("qrels")
        return self

    def _setup(self):
        """Loads the run on first use, joins it with the qrels and groups it by query."""
        if self.run is not None:
            return
        self.run = self._load_run()
        self.run = self.run.drop_duplicates(["query_id", "doc_id"])

        if self.qrels is not None:
            run_query_ids = pd.Index(self.run["query_id"].drop_duplicates())
            qrels_query_ids = self.qrels.index.get_level_values("query_id").unique()
            query_ids = run_query_ids.intersection(qrels_query_ids)
            # drop queries from the run that have no qrels
            if len(run_query_ids.difference(qrels_query_ids)):
                self.run = self.run[self.run["query_id"].isin(query_ids)]
            # outer join if docs are from ir_datasets else only keep docs in run
            how = "left"
            if self._docs is None and self.add_docs_not_in_ranking:
                how = "outer"
            self.run = self.run.merge(
                self.qrels.loc[pd.IndexSlice[query_ids, :]].droplevel(0, axis=1).add_prefix("relevance_", axis=1),
                on=["query_id", "doc_id"],
                how=how,
            )

        if self.sample_size != -1:
            # only keep queries that have at least ``sample_size`` documents
            num_docs_per_query = self.run.groupby("query_id").transform("size")
            self.run = self.run[num_docs_per_query >= self.sample_size]

        self.run = self.run.sort_values(["query_id", "rank"])
        self.run_groups = self.run.groupby("query_id")
        self.query_ids = list(self.run_groups.groups.keys())

        if self.depth != -1 and self.run["rank"].max() < self.depth:
            warnings.warn("Depth is greater than the maximum rank in the run file.", stacklevel=2)

    @staticmethod
    def _load_csv(path: Path) -> pd.DataFrame:
        """Reads a header-less, whitespace-separated run file; only the first five columns are kept."""
        return pd.read_csv(
            path,
            sep=r"\s+",
            header=None,
            names=RUN_HEADER,
            usecols=[0, 1, 2, 3, 4],
            dtype={"query_id": str, "doc_id": str},
            quoting=csv.QUOTE_NONE,
            na_filter=False,
        )

    @staticmethod
    def _load_parquet(path: Path) -> pd.DataFrame:
        """Reads a run stored as a parquet file."""
        return pd.read_parquet(path)

    @staticmethod
    def _load_json(path: Path) -> pd.DataFrame:
        """Reads a run stored as JSON; files with a ``.jsonl`` suffix are read as one record per line."""
        kwargs: dict[str, Any] = {}
        if ".jsonl" in path.suffixes:
            kwargs["lines"] = True
            kwargs["orient"] = "records"
        run = pd.read_json(path, **kwargs)
        return run

    def _get_run_path(self) -> Path | None:
        """Returns the run file path, falling back to the scoreddocs path of the ir_datasets dataset.

        Returns:
            Path | None: Path of the run file or None if the scoreddocs handler does not provide a path.
        Raises:
            ValueError: If no run file was given and the dataset is not found or has no scoreddocs.
        """
        run_path = self.run_path
        if run_path is None:
            if self.ir_dataset is None or not self.ir_dataset.has_scoreddocs():
                raise ValueError(f"Run file or dataset with scoreddocs required. Got {self._dataset}")
            try:
                run_path = self.ir_dataset.scoreddocs_handler().scoreddocs_path()
            except NotImplementedError:
                pass
        return run_path

    def _clean_run(self, run: pd.DataFrame) -> pd.DataFrame:
        """Normalizes column names and dtypes of a loaded run and applies the depth cut-off.

        Queries (``query`` column) and document texts (``text`` column) embedded in the run are moved into the
        query and document caches of the dataset and removed from the run.
        """
        run = run.rename(
            {"qid": "query_id", "docid": "doc_id", "docno": "doc_id", "Q0": "iteration", "q0": "iteration"},
            axis=1,
        )
        dtypes = {"rank": np.int32, "query_id": str, "doc_id": str}
        if "score" in run.columns:
            dtypes["score"] = np.float32
        run = run.astype(dtypes)
        if "query" in run.columns:
            self._queries = run.drop_duplicates("query_id").set_index("query_id")["query"].rename("text")
            run = run.drop("query", axis=1)
        if "text" in run.columns:
            self._docs = (
                run.drop_duplicates("doc_id").set_index("doc_id")["text"].map(lambda x: GenericDoc("", x)).to_dict()
            )
            run = run.drop("text", axis=1)
        if self.depth != -1:
            run = run[run["rank"] <= self.depth]
        return run

    def _load_run(self) -> pd.DataFrame:
        """Loads the run from the run file or, if that is not possible, from the scoreddocs of the dataset.

        Returns:
            pd.DataFrame: Cleaned run.
        Raises:
            ValueError: If the run can neither be loaded from a file nor from ir_datasets. An error raised while
                reading the file is attached as the cause.
        """
        suffix_load_map = {
            ".tsv": self._load_csv,
            ".run": self._load_csv,
            ".csv": self._load_csv,
            ".parquet": self._load_parquet,
            ".json": self._load_json,
            ".jsonl": self._load_json,
        }
        run = None
        load_error: Exception | None = None

        # try loading run from file
        run_path = self._get_run_path()
        if run_path is not None:
            # the first suffix selects the loader; a path without any suffix has no loader
            suffixes = run_path.suffixes
            load_func = suffix_load_map.get(suffixes[0], None) if suffixes else None
            if load_func is not None:
                try:
                    run = load_func(run_path)
                except Exception as error:
                    # best effort: fall back to ir_datasets below, but remember why reading the file failed
                    load_error = error

        # try loading run from ir_datasets
        if run is None and self.ir_dataset is not None and self.ir_dataset.has_scoreddocs():
            run = pd.DataFrame(self.ir_dataset.scoreddocs_iter())
            run["rank"] = run.groupby("query_id")["score"].rank("first", ascending=False)
            run = run.sort_values(["query_id", "rank"])

        if run is None:
            raise ValueError("Invalid run file format.") from load_error

        run = self._clean_run(run)
        return run

    @property
    def qrels(self) -> pd.DataFrame | None:
        """The qrels in the dataset. If the dataset does not contain qrels, the qrels are None.

        If the loaded run contains a ``relevance`` column, the qrels are built from the run and the ``relevance``
        and ``iteration`` columns are removed from the run. Otherwise the qrels of the ir_datasets dataset are used.

        Returns:
            pd.DataFrame | None: Qrels.
        """
        if self._qrels is not None:
            return self._qrels
        if self.run is not None and "relevance" in self.run:
            qrels = self.run[["query_id", "doc_id", "relevance"]].copy()
            if "iteration" in self.run:
                qrels["iteration"] = self.run["iteration"]
            else:
                qrels["iteration"] = "0"
            self.run = self.run.drop(["relevance", "iteration"], axis=1, errors="ignore")
            qrels = qrels.drop_duplicates(["query_id", "doc_id", "iteration"])
            qrels = qrels.set_index(["query_id", "doc_id", "iteration"]).unstack(level=-1)
            self._qrels = qrels
            return self._qrels
        return super().qrels

    def __len__(self) -> int:
        """Number of queries in the dataset.

        Returns:
            int: Number of queries.
        """
        self._setup()
        return len(self.query_ids)

    def __getitem__(self, idx: int) -> RankSample:
        """Samples a single query and corresponding ranked documents from the run. The documents are sampled according
        to the sampling strategy and sample size.

        Args:
            idx (int): Index of the query.
        Returns:
            RankSample: Sampled query and documents.
        Raises:
            ValueError: If the targets are not found in the run file.
        """
        self._setup()
        query_id = str(self.query_ids[idx])
        group = self.run_groups.get_group(query_id).copy()
        query = self.queries[query_id]
        group = Sampler.sample(group, self.sample_size, self.sampling_strategy)

        doc_ids = tuple(group["doc_id"])
        docs = tuple(self.docs.get(doc_id).default_text() for doc_id in doc_ids)

        targets = None
        if self.targets is not None:
            # all columns whose name contains the target name are used as targets
            filtered = group.set_index("doc_id").loc[list(doc_ids)].filter(like=self.targets).fillna(0)
            if filtered.empty:
                raise ValueError(f"targets `{self.targets}` not found in run file")
            targets = torch.from_numpy(filtered.values)
            if self.targets == "rank":
                # invert ranks to be higher is better (necessary for loss functions)
                targets = self.depth - targets + 1
            if self.normalize_targets:
                targets_min = targets.min()
                targets_max = targets.max()
                targets_range = targets_max - targets_min
                # identical targets would divide by zero and produce NaN; divide by one so they normalize to 0
                targets_range = torch.where(targets_range == 0, torch.ones_like(targets_range), targets_range)
                targets = (targets - targets_min) / targets_range
        qrels = None
        if self.qrels is not None:
            qrels = (
                self.qrels.loc[[query_id]]
                .stack(future_stack=True)
                .dropna()
                .astype(int)
                .reset_index()
                .to_dict(orient="records")
            )
        return RankSample(query_id, query, doc_ids, docs, targets, qrels)
664
665
[docs]
class TupleDataset(IRDataset, IterableDataset):
    def __init__(
        self,
        tuples_dataset: str,
        targets: Literal["order", "score"] = "order",
        num_docs: int | None = None,
    ) -> None:
        """Dataset of tuples consisting of a query and n documents, used to fine-tune models on ranking tasks.

        Args:
            tuples_dataset (str): Path to file containing tuples or valid ir_datasets id.
            targets (Literal["order", "score"], optional): Data type to use as targets for a model during fine-tuning.
                Defaults to "order".
            num_docs (int | None, optional): Maximum number of documents per query. Defaults to None.
        """
        super().__init__(tuples_dataset)
        super(IRDataset, self).__init__()
        self.num_docs = num_docs
        self.targets = targets

    def _parse_sample(
        self, sample: ScoredDocTuple | GenericDocPair
    ) -> tuple[tuple[str, ...], tuple[str, ...], tuple[float, ...] | None]:
        """Turns a tuple sample into document ids, document texts and targets.

        Raises:
            ValueError: If the sample type is unknown, the targets setting is invalid, or scores are requested but
                not available.
        """
        limit = self.num_docs
        wants_scores = self.targets == "score"
        if isinstance(sample, GenericDocPair):
            # a pair carries no scores; its first document is the preferred one
            if wants_scores:
                raise ValueError("ScoredDocTuple required for score targets.")
            doc_ids = (sample.doc_id_a, sample.doc_id_b)
            targets = (1.0, 0.0)
        elif isinstance(sample, ScoredDocTuple):
            doc_ids = sample.doc_ids[:limit]
            if wants_scores:
                all_targets = sample.scores
                if all_targets is None:
                    raise ValueError("tuples dataset does not contain scores")
            elif self.targets == "order":
                # only the first document of the tuple receives a positive target
                all_targets = (1.0,) + (0.0,) * (sample.num_docs - 1)
            else:
                raise ValueError(f"invalid value for targets, got {self.targets}, expected one of (order, score)")
            targets = all_targets[:limit]
        else:
            raise ValueError("Invalid sample type.")
        doc_store = self.docs
        docs = tuple(doc_store.get(doc_id).default_text() for doc_id in doc_ids)
        return doc_ids, docs, targets

    def __iter__(self) -> Iterator[RankSample]:
        """Iterates over tuples in the dataset.

        Yields:
            RankSample: Sampled query and documents with targets.
        """
        for doc_tuple in self.ir_dataset.docpairs_iter():
            query_id = doc_tuple.query_id
            query_text = self.queries.loc[query_id]
            doc_ids, docs, raw_targets = self._parse_sample(doc_tuple)
            targets = None if raw_targets is None else torch.tensor(raw_targets)
            yield RankSample(query_id, query_text, doc_ids, docs, targets)

    def prepare_data(self) -> None:
        """Downloads tuples using ir_datasets if needed."""
        for constituent in ("docs", "queries", "docpairs"):
            self.prepare_constituent(constituent)