Source code for lightning_ir.retrieve.base.indexer

 1from __future__ import annotations
 2
 3import array
 4import json
 5from abc import ABC, abstractmethod
 6from pathlib import Path
 7from typing import TYPE_CHECKING, List, Set, Type
 8
 9import torch
10
11if TYPE_CHECKING:
12    from ...bi_encoder import BiEncoderModule, BiEncoderOutput
13    from ...data import IndexBatch
14
15
class Indexer(ABC):
    """Abstract base class for objects that build an index inside ``index_dir``.

    The base class holds the bookkeeping state (document ids, per-document
    lengths and two counters) and knows how to persist the ids and lengths.
    Concrete subclasses implement :meth:`add`; they are presumably responsible
    for updating the bookkeeping state there, since nothing in this class
    modifies it after construction.
    """

    def __init__(
        self,
        index_dir: Path,
        index_config: IndexConfig,
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Store the collaborators and initialise empty bookkeeping state.

        Args:
            index_dir: Directory that :meth:`save` writes the index files to.
            index_config: Configuration object; its ``save`` method is called
                with ``index_dir`` by :meth:`save`.
            module: Bi-encoder module, stored as-is for use by subclasses.
            verbose: Verbosity flag, stored as-is for use by subclasses.
        """
        self.index_dir = index_dir
        self.index_config = index_config
        self.module = module
        self.doc_ids: List[str] = []
        # Typecode "I" stores unboxed unsigned ints instead of Python objects.
        self.doc_lengths = array.array("I")
        self.num_embeddings = 0
        self.num_docs = 0
        self.verbose = verbose

    @abstractmethod
    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        """Add one batch to the index; must be implemented by subclasses.

        Args:
            index_batch: The batch being indexed.
            output: The bi-encoder output that belongs to ``index_batch``.
        """
        ...

    def save(self) -> None:
        """Persist the configuration, document ids and document lengths.

        Writes three files into ``index_dir``: whatever ``index_config.save``
        produces, ``doc_ids.txt`` (one id per line, no trailing newline) and
        ``doc_lengths.pt`` (an ``int32`` tensor written with ``torch.save``).
        """
        # The config is saved first; the remaining writes assume that the
        # directory exists afterwards.
        self.index_config.save(self.index_dir)
        # NOTE(review): write_text uses the platform default encoding here;
        # confirm what the reader of doc_ids.txt expects before pinning one.
        (self.index_dir / "doc_ids.txt").write_text("\n".join(self.doc_ids))
        if self.doc_lengths:
            # Reinterpret the array's buffer as int32 without copying it.
            doc_lengths = torch.frombuffer(self.doc_lengths, dtype=torch.int32)
        else:
            # torch.frombuffer rejects zero-length buffers, so an index without
            # any documents gets an explicit empty tensor of the same dtype.
            doc_lengths = torch.empty(0, dtype=torch.int32)
        torch.save(doc_lengths, self.index_dir / "doc_lengths.pt")
41 42
class IndexConfig:
    """Base class for index configurations stored as ``config.json``.

    The instance attributes (``__dict__``) are the serialized configuration.
    :meth:`save` adds the ``index_dir`` and ``index_type`` entries, and
    :meth:`from_pretrained` strips them again before calling the constructor.
    """

    # Declared here without a value; presumably assigned by subclasses or by
    # the module that defines the matching indexer -- TODO confirm.
    indexer_class: Type[Indexer]
    # Declared here without a value; nothing in this class reads it.
    SUPPORTED_MODELS: Set[str]

    @classmethod
    def from_pretrained(cls, index_dir: Path | str) -> "IndexConfig":
        """Load the configuration stored in ``index_dir / "config.json"``.

        Args:
            index_dir: Directory that contains ``config.json``.

        Returns:
            An instance of ``cls`` built from the stored keyword arguments.

        Raises:
            ValueError: If the stored ``index_type`` is not ``cls.__name__``.
        """
        index_dir = Path(index_dir)
        with open(index_dir / "config.json", "r") as f:
            data = json.load(f)
        if data["index_type"] != cls.__name__:
            raise ValueError(f"Expected index_type {cls.__name__}, got {data['index_type']}")
        # Both entries are bookkeeping added by save(), not constructor
        # arguments.
        data.pop("index_type", None)
        data.pop("index_dir", None)
        return cls(**data)

    def save(self, index_dir: Path) -> None:
        """Write the configuration to ``index_dir / "config.json"``.

        Args:
            index_dir: Target directory; created (with parents) if missing.

        Raises:
            TypeError: If an attribute is not JSON serializable. Nothing is
                written or created in that case.
        """
        data = self.__dict__.copy()
        data["index_dir"] = str(index_dir)
        data["index_type"] = self.__class__.__name__
        # Serialize before touching the filesystem: opening the file in "w"
        # mode truncates it, so a serialization error after that point would
        # destroy a previously valid config.json.
        payload = json.dumps(data)
        index_dir.mkdir(parents=True, exist_ok=True)
        with open(index_dir / "config.json", "w") as f:
            f.write(payload)

    def to_dict(self) -> dict:
        """Return a shallow copy of the instance attributes."""
        return self.__dict__.copy()