Source code for lightning_ir.data.external_datasets.sbert

 1import codecs
 2import json
 3from functools import partial
 4
 5from ir_datasets.util import GzipExtract
 6
 7from lightning_ir.data.external_datasets.ir_datasets_utils import ScoredDocTuple, ScoredDocTuples, register_new_dataset
 8
 9
class SBERTScoredDocTuples(ScoredDocTuples):
    """Scored doc tuples parsed from the SBERT MS MARCO hard-negatives JSONL file.

    Every JSON line yields one ``ScoredDocTuple``. It holds the record's positive
    passages followed by the negatives stored under ``record["neg"][name]``.
    Each passage id is paired with its ``"ce-score"`` value.
    """

    def __init__(self, docpairs_dlc, name):
        """
        Args:
            docpairs_dlc: Source handle for the JSONL file; forwarded unchanged to
                ``ScoredDocTuples.__init__``.
            name: Key into each record's ``"neg"`` mapping that selects whose
                negatives are included (e.g. ``"bm25"``).
        """
        super().__init__(docpairs_dlc)
        self.name = name

    @staticmethod
    def _append_scored(entries, doc_ids, doc_scores):
        """Append ``str(pid)`` and ``float(ce-score)`` of every entry to the lists."""
        for entry in entries:
            doc_ids.append(str(entry["pid"]))
            doc_scores.append(float(entry["ce-score"]))

    def docpairs_iter(self):
        """Yield one ``ScoredDocTuple`` per line of the hard-negatives file.

        The tuple is built positionally as ``(query id, doc ids, scores, count)``:

        - The query id is ``str(record["qid"])``.
        - The doc ids and scores list the ``"pos"`` entries first, then the
          ``"neg"`` entries for ``self.name``.
        - The count is the number of doc ids.

        Raises:
            KeyError: If a record lacks ``"qid"``, ``"pos"``, ``"neg"``, the
                configured ``name``, or an entry lacks ``"pid"``/``"ce-score"``.
        """
        # ``self._docpairs_dlc`` is presumably stored by the base class; its
        # stream yields bytes, so wrap it in a UTF-8 decoding reader.
        with self._docpairs_dlc.stream() as raw_stream:
            decoded_lines = codecs.getreader("utf8")(raw_stream)
            for raw_line in decoded_lines:
                record = json.loads(raw_line)
                query_id = record["qid"]
                doc_ids, doc_scores = [], []
                # Positives must precede negatives in the emitted tuple.
                self._append_scored(record["pos"], doc_ids, doc_scores)
                self._append_scored(record["neg"][self.name], doc_ids, doc_scores)
                yield ScoredDocTuple(str(query_id), tuple(doc_ids), tuple(doc_scores), len(doc_ids))


def register_sbert_docpairs():
    """Register one docpairs dataset per SBERT hard-negative system via ``register_new_dataset``.

    Each dataset is named ``msmarco-passage/train/sbert-<system>-docpairs``.
    It uses ``msmarco-passage`` documents and ``msmarco-passage/train``
    queries and qrels. Its doc pairs are read by ``SBERTScoredDocTuples``,
    restricted to ``<system>``'s negatives.
    """
    # One description (single URL, checksum and cache_path) is shared by every
    # registration; only the selected negatives differ between variants.
    sbert_hard_negatives = {
        "url": "https://sbert.net/datasets/msmarco-hard-negatives.jsonl.gz",
        "expected_md5": "ecf8cafb10197fd7adf4f68aabd15d84",
        "cache_path": "msmarco-passage/train/sbert-docpairs.jsonl",
        "extractors": [GzipExtract],
    }
    negative_systems = (
        "bm25",
        "msmarco-distilbert-base-tas-b",
        "msmarco-distilbert-base-v3",
        "msmarco-MiniLM-L-6-v3",
    )
    for system in negative_systems:
        docpairs_factory = partial(SBERTScoredDocTuples, name=system)
        register_new_dataset(
            f"msmarco-passage/train/sbert-{system}-docpairs",
            docs="msmarco-passage",
            queries="msmarco-passage/train",
            qrels="msmarco-passage/train",
            docpairs=sbert_hard_negatives,
            DocpairsType=docpairs_factory,
        )