Source code for lightning_ir.data.external_datasets.ir_datasets_utils
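
"""Utilities for registering local or custom datasets with `ir_datasets`."""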

import codecs
import json
from pathlib import Path
from typing import Any, Dict, Literal, NamedTuple, Tuple, Type

import ir_datasets
from ir_datasets.datasets.base import Dataset
from ir_datasets.datasets.nano_beir import parquet_iter
from ir_datasets.formats import (
    BaseDocPairs,
    BaseDocs,
    BaseQrels,
    BaseQueries,
    BaseScoredDocs,
    GenericScoredDoc,
    jsonl,
    trec,
    tsv,
)
from ir_datasets.util import Cache, DownloadConfig

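# Maps each constituent type to the parser class used for each supported file suffix.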
CONSTITUENT_TYPE_MAP: Dict[str, Dict[str, Type]] = {
    "docs": {
        ".json": jsonl.JsonlDocs,
        ".jsonl": jsonl.JsonlDocs,
        ".tsv": tsv.TsvDocs,
    },
    "queries": {
        ".json": jsonl.JsonlQueries,
        ".jsonl": jsonl.JsonlQueries,
        ".tsv": tsv.TsvQueries,
    },
    "qrels": {".tsv": trec.TrecQrels, ".qrels": trec.TrecQrels},
    "scoreddocs": {".run": trec.TrecScoredDocs, ".tsv": trec.TrecScoredDocs},
    "docpairs": {".tsv": tsv.TsvDocPairs},
}


def _load_constituent(
    dataset_id: str,
    constituent: Path | str | Dict[str, Any] | None,
    constituent_type: Literal["docs", "queries", "qrels", "scoreddocs", "docpairs"] | Type,
    **kwargs,
) -> Any:
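    """Resolve a single constituent given as a download-config dict, the id of an already
    registered `ir_datasets` dataset, or a local file path, and wrap it in the matching
    handler class."""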
    if constituent is None:
        return None
    if isinstance(constituent, dict):
        constituent_path = Path(constituent["cache_path"])
        cache = _register_and_get_cache(dataset_id, constituent)
    elif constituent in ir_datasets.registry._registered:
        return getattr(ir_datasets.load(constituent), f"{constituent_type}_handler")()
    else:
        constituent_path = Path(constituent)
        cache = Cache(None, constituent_path)
        if not constituent_path.exists():
            raise ValueError(f"unable to load {constituent}, expected an `ir_datasets` id or valid path")
    if isinstance(constituent_type, str):
        suffix = constituent_path.suffixes[0]
        constituent_types = CONSTITUENT_TYPE_MAP[constituent_type]
        if suffix not in constituent_types:
            raise ValueError(f"Unknown file type: {suffix}, expected one of {constituent_types.keys()}")
        ConstituentType = constituent_types[suffix]
    else:
        ConstituentType = constituent_type
    return ConstituentType(cache, **kwargs)


def _register_and_get_cache(dataset_id: str, dlc_contents: Dict[str, Any]) -> Cache:
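    """Register the download config for a constituent under the dataset's base id, apply any
    extractors (e.g., decompression), and return a `Cache` for the resulting file."""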
    extractors = dlc_contents.pop("extractors", [])
    base_id = dataset_id.split("/")[0]
    new_id = dataset_id.removeprefix(base_id + "/")
    base_path = ir_datasets.util.home_path()
    dlc = DownloadConfig.context(base_id, base_path / base_id)
    dlc.contents()[new_id] = dlc_contents
    dataset_dlc = dlc[new_id]
    file_path = Path(dlc_contents["cache_path"])
    for extractor in extractors:
        dataset_dlc = extractor(dataset_dlc)
    return Cache(dataset_dlc, base_path / file_path)


def register_new_dataset(
    dataset_id: str,
    docs: Path | str | Dict[str, str] | None = None,
    DocsType: Type[BaseDocs] | None = None,
    queries: Path | str | Dict[str, str] | None = None,
    QueriesType: Type[BaseQueries] | None = None,
    qrels: Path | str | Dict[str, str] | None = None,
    QrelsType: Type[BaseQrels] | None = None,
    docpairs: Path | str | Dict[str, str] | None = None,
    DocpairsType: Type[BaseDocPairs] | None = None,
    scoreddocs: Path | str | Dict[str, str] | None = None,
    ScoreddocsType: Type[BaseScoredDocs] | None = None,
    qrels_defs: Dict[int, str] | None = None,
):
    if dataset_id in ir_datasets.registry._registered:
        return

    docs = _load_constituent(dataset_id, docs, DocsType or "docs")
    queries = _load_constituent(dataset_id, queries, QueriesType or "queries")
    qrels = _load_constituent(
        dataset_id, qrels, QrelsType or "qrels", qrels_defs=qrels_defs if qrels_defs is not None else {}
    )
    docpairs = _load_constituent(dataset_id, docpairs, DocpairsType or "docpairs")
    scoreddocs = _load_constituent(dataset_id, scoreddocs, ScoreddocsType or "scoreddocs")

    ir_datasets.registry.register(dataset_id, Dataset(docs, queries, qrels, docpairs, scoreddocs))
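
# Usage sketch (hypothetical dataset id and file paths, for illustration only):
#
#     register_new_dataset(
#         "my-dataset/train",
#         docs="/path/to/docs.jsonl",
#         queries="/path/to/queries.tsv",
#         qrels="/path/to/train.qrels",
#     )
#
# Each constituent may also be given as the id of an already registered `ir_datasets`
# dataset, or as a download-config dict with a "cache_path" key (and optional
# "extractors"), as resolved by `_load_constituent` above.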


class ScoredDocTuple(NamedTuple):
    """A query id with a tuple of document ids and their optional scores."""

    query_id: str
    doc_ids: Tuple[str, ...]
    scores: Tuple[float, ...] | None
    num_docs: int


class ScoredDocTuples(BaseDocPairs):
    def __init__(self, docpairs_dlc):
        self._docpairs_dlc = docpairs_dlc

    def docpairs_path(self):
        return self._docpairs_dlc.path()

    def docpairs_iter(self):
        file_type = None
        if self._docpairs_dlc.path().suffix == ".json":
            file_type = "json"
        elif self._docpairs_dlc.path().suffix in (".tsv", ".run"):
            file_type = "tsv"
        else:
            raise ValueError(f"Unknown file type: {self._docpairs_dlc.path().suffix}")
        with self._docpairs_dlc.stream() as f:
            f = codecs.getreader("utf8")(f)
            for line in f:
                if file_type == "json":
                    # One JSON array per line: [qid, [pid, score], [pid, score], ...]
                    data = json.loads(line)
                    qid, *doc_data = data
                    pids, scores = zip(*doc_data)
                    pids = tuple(str(pid) for pid in pids)
                else:
                    # Whitespace-separated columns: pos_score neg_score qid pos_pid neg_pid
                    cols = line.rstrip().split()
                    pos_score, neg_score, qid, pid1, pid2 = cols
                    pids = (pid1, pid2)
                    scores = (float(pos_score), float(neg_score))
                yield ScoredDocTuple(str(qid), pids, scores, len(pids))

    def docpairs_cls(self):
        return ScoredDocTuple
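
# For illustration: a .tsv/.run line "14.2 9.1 q1 d7 d9" (hypothetical values) yields
# ScoredDocTuple("q1", ("d7", "d9"), (14.2, 9.1), 2).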


class ParquetScoredDocs(BaseScoredDocs):
    def __init__(self, scoreddocs_dlc, negate_score=False):
        self._scoreddocs_dlc = scoreddocs_dlc
        self._negate_score = negate_score

    def scoreddocs_path(self):
        return self._scoreddocs_dlc.path()

    def scoreddocs_iter(self):
        # Only the "query-id" and "corpus-id" columns are read; every pair is emitted
        # with a constant relevance score, and ``negate_score`` flips its sign.
        score = -1 if self._negate_score else 1
        for d in parquet_iter(self._scoreddocs_dlc.path()):
            yield GenericScoredDoc(d["query-id"], d["corpus-id"], score)