Source code for lightning_ir.data.external_datasets.sbert
1import codecs
2import json
3from functools import partial
4
5from ir_datasets.util import GzipExtract
6
7from lightning_ir.data.external_datasets.ir_datasets_utils import ScoredDocTuple, ScoredDocTuples, register_new_dataset
8
9
[docs]
class SBERTScoredDocTuples(ScoredDocTuples):
    """Scored document tuples read from the SBERT MS MARCO hard-negatives JSONL file.

    Each JSON line of the download is read as a record with a ``"qid"``, a ``"pos"``
    list and a ``"neg"`` mapping from retrieval-system name to a list; every list
    entry carries a ``"pid"`` and a ``"ce-score"``. One instance exposes the
    negatives of a single system, selected by ``name``.
    """

    def __init__(self, docpairs_dlc, name):
        """Wrap the docpairs download and select which system's negatives to use.

        Args:
            docpairs_dlc: Download object handed to :class:`ScoredDocTuples`;
                ``docpairs_iter`` reads it through its ``stream()`` method.
            name: Key into each record's ``"neg"`` mapping (e.g. ``"bm25"``).
        """
        super().__init__(docpairs_dlc)
        self.name = name

    def docpairs_iter(self):
        """Yield one :class:`ScoredDocTuple` per record in the download.

        Positive passages come first, followed by the negatives listed under
        ``self.name``. Query and passage ids are converted to ``str`` and scores
        to ``float``.

        Yields:
            ScoredDocTuple: ``(qid, pids, scores, len(pids))``.
        """
        with self._docpairs_dlc.stream() as stream:
            # The stream is decoded as UTF-8 text so it can be read line by line.
            reader = codecs.getreader("utf8")(stream)
            for line in reader:
                if not line.strip():
                    # Skip blank lines rather than failing inside json.loads.
                    continue
                data = json.loads(line)
                # A record may not list negatives for every system; treat a missing
                # system as "no negatives" instead of aborting the whole iteration
                # with a KeyError.
                negatives = data["neg"].get(self.name, [])
                docs = [*data["pos"], *negatives]
                pids = tuple(str(doc["pid"]) for doc in docs)
                scores = tuple(float(doc["ce-score"]) for doc in docs)
                yield ScoredDocTuple(str(data["qid"]), pids, scores, len(pids))
30
31
[docs]
def register_sbert_docpairs():
    """Register one SBERT hard-negative docpairs dataset per retrieval system.

    For each system a dataset with id
    ``msmarco-passage/train/sbert-<system>-docpairs`` is registered. It takes its
    documents from ``msmarco-passage`` and its queries and qrels from
    ``msmarco-passage/train``. All registrations share the same download
    description: one gzip-compressed JSONL file with a fixed MD5 checksum, cached
    at ``msmarco-passage/train/sbert-docpairs.jsonl``. Each dataset reads its docpairs
    with :class:`SBERTScoredDocTuples` bound to its own system name.
    """
    systems = (
        "bm25",
        "msmarco-distilbert-base-tas-b",
        "msmarco-distilbert-base-v3",
        "msmarco-MiniLM-L-6-v3",
    )
    # One download description reused for every system: every system's negatives
    # come from the same file, so it is fetched and cached only once.
    shared_dlc = dict(
        url="https://sbert.net/datasets/msmarco-hard-negatives.jsonl.gz",
        expected_md5="ecf8cafb10197fd7adf4f68aabd15d84",
        cache_path="msmarco-passage/train/sbert-docpairs.jsonl",
        extractors=[GzipExtract],
    )
    for system in systems:
        dataset_id = f"msmarco-passage/train/sbert-{system}-docpairs"
        # Pre-bind the system name so this dataset reads only its own "neg" entry.
        docpairs_type = partial(SBERTScoredDocTuples, name=system)
        register_new_dataset(
            dataset_id,
            docs="msmarco-passage",
            queries="msmarco-passage/train",
            qrels="msmarco-passage/train",
            docpairs=shared_dlc,
            DocpairsType=docpairs_type,
        )