Source code for lightning_ir.retrieve.base.indexer

  1"""Base indexer class and configuration for retrieval tasks."""
  2
  3from __future__ import annotations
  4
  5import array
  6import json
  7from abc import ABC, abstractmethod
  8from pathlib import Path
  9from typing import TYPE_CHECKING, List, Set, Type
 10
 11import torch
 12
 13if TYPE_CHECKING:
 14    from ...bi_encoder import BiEncoderModule, BiEncoderOutput
 15    from ...data import IndexBatch
 16
 17
class Indexer(ABC):
    """Base class for indexers that create and manage indices for retrieval tasks."""

    def __init__(
        self,
        index_dir: Path,
        index_config: IndexConfig,
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the Indexer.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (IndexConfig): Configuration for the index.
            module (BiEncoderModule): The bi-encoder module used for encoding documents.
            verbose (bool): Whether to print verbose output. Defaults to False.
        """
        self.index_dir = index_dir
        self.index_config = index_config
        self.module = module
        # Per-document bookkeeping written out by ``save``; presumably appended to by
        # subclasses' ``add`` implementations.
        self.doc_ids: List[str] = []
        # Compact C ``unsigned int`` array (typecode "I") of per-document lengths.
        self.doc_lengths = array.array("I")
        self.num_embeddings = 0
        self.num_docs = 0
        self.verbose = verbose

    @abstractmethod
    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        """Add a batch of documents to the index.

        Args:
            index_batch (IndexBatch): The batch of documents to add.
            output (BiEncoderOutput): The output from the bi-encoder module containing document embeddings.
        """
        ...

    def save(self) -> None:
        """Save the index configuration, document IDs and document lengths to the index directory.

        Writes the index configuration via ``index_config.save``, ``doc_ids.txt`` (IDs joined
        by newlines) and ``doc_lengths.pt`` (a 1-D ``torch.int32`` tensor, empty when no
        documents have been added).
        """
        self.index_config.save(self.index_dir)
        (self.index_dir / "doc_ids.txt").write_text("\n".join(self.doc_ids))
        if not self.doc_lengths:
            # torch.frombuffer rejects zero-length buffers, so an empty index needs an
            # explicitly constructed empty tensor.
            doc_lengths = torch.empty(0, dtype=torch.int32)
        elif self.doc_lengths.itemsize == 4:
            # Reinterpret the array's raw 4-byte storage without a Python-level copy.
            # The view is only used for the immediate torch.save below.
            doc_lengths = torch.frombuffer(self.doc_lengths, dtype=torch.int32)
        else:
            # C unsigned int is not 4 bytes on this platform; reinterpreting the raw
            # bytes as int32 would be wrong, so convert element-wise instead.
            doc_lengths = torch.tensor(self.doc_lengths.tolist(), dtype=torch.int32)
        torch.save(doc_lengths, self.index_dir / "doc_lengths.pt")
class IndexConfig:
    """Configuration class for indexers that defines the index type and other parameters.

    ``save`` serializes every instance attribute to ``config.json`` (plus ``index_dir`` and
    ``index_type``), and ``from_pretrained`` passes the stored attributes back to the
    constructor as keyword arguments.
    """

    # Declared here without a value; presumably assigned by concrete subclasses.
    indexer_class: Type[Indexer]
    # Declared here without a value; presumably assigned by concrete subclasses.
    SUPPORTED_MODELS: Set[str]

    @classmethod
    def from_pretrained(cls, index_dir: Path | str) -> "IndexConfig":
        """Load the index configuration from a directory.

        Args:
            index_dir (Path | str): Path to the directory containing the index configuration.
        Returns:
            IndexConfig: An instance of the index configuration class.
        Raises:
            ValueError: If the index type in the configuration does not match the expected class name.
            KeyError: If ``config.json`` has no ``index_type`` entry.
        """
        index_dir = Path(index_dir)
        with open(index_dir / "config.json", "r", encoding="utf-8") as f:
            data = json.load(f)
        index_type = data.pop("index_type")
        if index_type != cls.__name__:
            raise ValueError(f"Expected index_type {cls.__name__}, got {index_type}")
        # index_dir is recorded by save() for reference only; it is not a constructor argument.
        data.pop("index_dir", None)
        return cls(**data)

    def save(self, index_dir: Path) -> None:
        """Save the index configuration to a directory.

        Args:
            index_dir (Path): The directory to save the index configuration.
        Raises:
            TypeError: If an attribute is not JSON-serializable. In that case any existing
                ``config.json`` is left untouched.
        """
        data = self.__dict__.copy()
        data["index_dir"] = str(index_dir)
        data["index_type"] = self.__class__.__name__
        # Serialize before touching the file system: opening in "w" mode first would
        # truncate an existing valid config even when serialization then fails.
        serialized = json.dumps(data)
        index_dir.mkdir(parents=True, exist_ok=True)
        (index_dir / "config.json").write_text(serialized, encoding="utf-8")

    def to_dict(self) -> dict:
        """Convert the index configuration to a dictionary.

        Returns:
            dict: A shallow copy of the instance attributes.
        """
        return self.__dict__.copy()