import codecs
import json
from pathlib import Path
from typing import Any, Dict, Literal, NamedTuple, Tuple, Type

import ir_datasets
from ir_datasets.datasets.base import Dataset
from ir_datasets.datasets.nano_beir import parquet_iter
from ir_datasets.formats import (
    BaseDocPairs,
    BaseDocs,
    BaseQrels,
    BaseQueries,
    BaseScoredDocs,
    GenericScoredDoc,
    jsonl,
    trec,
    tsv,
)
from ir_datasets.util import Cache, DownloadConfig

CONSTITUENT_TYPE_MAP: Dict[str, Dict[str, Type]] = {
    "docs": {
        ".json": jsonl.JsonlDocs,
        ".jsonl": jsonl.JsonlDocs,
        ".tsv": tsv.TsvDocs,
    },
    "queries": {
        ".json": jsonl.JsonlQueries,
        ".jsonl": jsonl.JsonlQueries,
        ".tsv": tsv.TsvQueries,
    },
    "qrels": {".tsv": trec.TrecQrels, ".qrels": trec.TrecQrels},
    "scoreddocs": {".run": trec.TrecScoredDocs, ".tsv": trec.TrecScoredDocs},
    "docpairs": {".tsv": tsv.TsvDocPairs},
}
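# For example, a local "queries.tsv" file dispatches to tsv.TsvQueries via
# CONSTITUENT_TYPE_MAP["queries"][".tsv"]; compressed files such as
# "docs.jsonl.gz" dispatch on their first suffix (".jsonl").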


def _load_constituent(
    dataset_id: str,
    constituent: Path | str | Dict[str, Any] | None,
    constituent_type: Literal["docs", "queries", "qrels", "scoreddocs", "docpairs"] | Type,
    **kwargs,
) -> Any:
    if constituent is None:
        return None
    if isinstance(constituent, dict):
        # A dict describes downloadable content: register it so `ir_datasets`
        # can fetch and cache the file.
        constituent_path = Path(constituent["cache_path"])
        cache = _register_and_get_cache(dataset_id, constituent)
    elif constituent in ir_datasets.registry._registered:
        # An already registered `ir_datasets` id: reuse its existing handler.
        return getattr(ir_datasets.load(constituent), f"{constituent_type}_handler")()
    else:
        # Otherwise the constituent must point to a local file.
        constituent_path = Path(constituent)
        cache = Cache(None, constituent_path)
        if not constituent_path.exists():
            raise ValueError(f"unable to load {constituent}, expected an `ir_datasets` id or valid path")
    if isinstance(constituent_type, str):
        # Infer the handler class from the first file suffix (so ".jsonl.gz"
        # dispatches on ".jsonl").
        suffix = constituent_path.suffixes[0]
        constituent_types = CONSTITUENT_TYPE_MAP[constituent_type]
        if suffix not in constituent_types:
            raise ValueError(f"Unknown file type: {suffix}, expected one of {constituent_types.keys()}")
        ConstituentType = constituent_types[suffix]
    else:
        ConstituentType = constituent_type
    return ConstituentType(cache, **kwargs)
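
# E.g. _load_constituent("my-dataset", "data/docs.jsonl", "docs") wraps the local
# file in a Cache and returns a jsonl.JsonlDocs handler, while passing an already
# registered `ir_datasets` id instead reuses that dataset's docs handler (the id
# and path here are hypothetical).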


def _register_and_get_cache(dataset_id: str, dlc_contents: Dict[str, Any]) -> Cache:
    # Register the download contents under the dataset's top-level id so that
    # `ir_datasets` can resolve and cache it like any built-in download.
    extractors = dlc_contents.pop("extractors", [])
    base_id = dataset_id.split("/")[0]
    new_id = dataset_id.removeprefix(base_id + "/")
    base_path = ir_datasets.util.home_path()
    dlc = DownloadConfig.context(base_id, base_path / base_id)
    dlc.contents()[new_id] = dlc_contents
    dataset_dlc = dlc[new_id]
    file_path = Path(dlc_contents["cache_path"])
    # Apply any extractors (e.g. decompression) before caching the result.
    for extractor in extractors:
        dataset_dlc = extractor(dataset_dlc)
    return Cache(dataset_dlc, base_path / file_path)


def register_new_dataset(
    dataset_id: str,
    docs: Path | str | Dict[str, str] | None = None,
    DocsType: Type[BaseDocs] | None = None,
    queries: Path | str | Dict[str, str] | None = None,
    QueriesType: Type[BaseQueries] | None = None,
    qrels: Path | str | Dict[str, str] | None = None,
    QrelsType: Type[BaseQrels] | None = None,
    docpairs: Path | str | Dict[str, str] | None = None,
    DocpairsType: Type[BaseDocPairs] | None = None,
    scoreddocs: Path | str | Dict[str, str] | None = None,
    ScoreddocsType: Type[BaseScoredDocs] | None = None,
    qrels_defs: Dict[int, str] | None = None,
):
    # Registering an id that already exists is a no-op.
    if dataset_id in ir_datasets.registry._registered:
        return

    # Each constituent is optional; explicit *Type arguments override the
    # suffix-based lookup in CONSTITUENT_TYPE_MAP.
    docs = _load_constituent(dataset_id, docs, DocsType or "docs")
    queries = _load_constituent(dataset_id, queries, QueriesType or "queries")
    qrels = _load_constituent(
        dataset_id, qrels, QrelsType or "qrels", qrels_defs=qrels_defs if qrels_defs is not None else {}
    )
    docpairs = _load_constituent(dataset_id, docpairs, DocpairsType or "docpairs")
    scoreddocs = _load_constituent(dataset_id, scoreddocs, ScoreddocsType or "scoreddocs")

    ir_datasets.registry.register(dataset_id, Dataset(docs, queries, qrels, docpairs, scoreddocs))
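
# A minimal usage sketch (the file paths and dataset id are hypothetical; the
# qrels file is assumed to be in TREC format):
#
#     register_new_dataset(
#         "my-dataset",
#         docs="data/docs.tsv",
#         queries="data/queries.tsv",
#         qrels="data/qrels.qrels",
#         qrels_defs={1: "relevant"},
#     )
#     dataset = ir_datasets.load("my-dataset")
#     for doc in dataset.docs_iter():
#         ...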


class ScoredDocTuple(NamedTuple):
    """A query id with a tuple of document ids and, optionally, their scores."""

    query_id: str
    doc_ids: Tuple[str, ...]
    scores: Tuple[float, ...] | None
    num_docs: int


class ScoredDocTuples(BaseDocPairs):
    def __init__(self, docpairs_dlc):
        self._docpairs_dlc = docpairs_dlc

    def docpairs_path(self):
        return self._docpairs_dlc.path()

    def docpairs_iter(self):
        # Determine the file format once, from the file suffix.
        file_type = None
        if self._docpairs_dlc.path().suffix == ".json":
            file_type = "json"
        elif self._docpairs_dlc.path().suffix in (".tsv", ".run"):
            file_type = "tsv"
        else:
            raise ValueError(f"Unknown file type: {self._docpairs_dlc.path().suffix}")
        with self._docpairs_dlc.stream() as f:
            f = codecs.getreader("utf8")(f)
            for line in f:
                if file_type == "json":
                    # JSON lines of the form [qid, [pid, score], [pid, score], ...]
                    data = json.loads(line)
                    qid, *doc_data = data
                    pids, scores = zip(*doc_data)
                    pids = tuple(str(pid) for pid in pids)
                else:
                    # Whitespace-separated lines: pos_score neg_score qid pid1 pid2
                    cols = line.rstrip().split()
                    pos_score, neg_score, qid, pid1, pid2 = cols
                    pids = (pid1, pid2)
                    scores = (float(pos_score), float(neg_score))
                yield ScoredDocTuple(str(qid), pids, scores, len(pids))

    def docpairs_cls(self):
        return ScoredDocTuple
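
# A hedged sketch of consuming the tuples, assuming a dataset was registered
# with DocpairsType=ScoredDocTuples (the dataset id is hypothetical):
#
#     for pair in ir_datasets.load("my-dataset").docpairs_iter():
#         print(pair.query_id, pair.doc_ids, pair.scores)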


class ParquetScoredDocs(BaseScoredDocs):
    def __init__(self, scoreddocs_dlc, negate_score=False):
        self._scoreddocs_dlc = scoreddocs_dlc
        self._negate_score = negate_score

    def scoreddocs_path(self):
        return self._scoreddocs_dlc.path()

    def scoreddocs_iter(self):
        # Rows are expected to carry "query-id" and "corpus-id" columns; every
        # pair is yielded with a constant score of 1 (-1 if negate_score is set).
        score = -1 if self._negate_score else 1
        for d in parquet_iter(self._scoreddocs_dlc.path()):
            yield GenericScoredDoc(d["query-id"], d["corpus-id"], score)
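
# A hedged sketch: registering scored docs from a local Parquet file (the path
# and id are hypothetical). Passing ScoreddocsType explicitly bypasses the
# suffix lookup in CONSTITUENT_TYPE_MAP, which has no ".parquet" entry:
#
#     register_new_dataset(
#         "example-parquet",
#         scoreddocs="data/run.parquet",
#         ScoreddocsType=ParquetScoredDocs,
#     )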