Source code for lightning_ir.retrieve.base.indexer
1"""Base indexer class and configuration for retrieval tasks."""
2
3from __future__ import annotations
4
5import array
6import json
7from abc import ABC, abstractmethod
8from pathlib import Path
9from typing import TYPE_CHECKING, List, Set, Type
10
11import torch
12
13if TYPE_CHECKING:
14 from ...bi_encoder import BiEncoderModule, BiEncoderOutput
15 from ...data import IndexBatch
16
17
[docs]
class Indexer(ABC):
    """Base class for indexers that create and manage indices for retrieval tasks."""

    def __init__(
        self,
        index_dir: Path,
        index_config: IndexConfig,
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the Indexer.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (IndexConfig): Configuration for the index.
            module (BiEncoderModule): The bi-encoder module used for encoding documents.
            verbose (bool): Whether to print verbose output. Defaults to False.
        """
        self.index_dir = index_dir
        self.index_config = index_config
        self.module = module
        # Per-document bookkeeping written out by ``save``. Nothing in this base class appends
        # to these; presumably concrete ``add`` implementations do (verify in subclasses).
        self.doc_ids: List[str] = []
        # Stored as C unsigned ints ("I") and reinterpreted as int32 in ``save``; this assumes
        # the platform's "I" itemsize is 4 bytes (true on mainstream platforms).
        self.doc_lengths = array.array("I")
        self.num_embeddings = 0
        self.num_docs = 0
        self.verbose = verbose

    @abstractmethod
    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        """Add a batch of documents to the index.

        Must be implemented by concrete indexers.

        Args:
            index_batch (IndexBatch): The batch of documents to add.
            output (BiEncoderOutput): The output from the bi-encoder module containing document embeddings.
        """
        ...

    def save(self) -> None:
        """Save the index configuration, document IDs, and document lengths to the index directory.

        Writes, inside ``self.index_dir``:

        * the index configuration, via ``self.index_config.save`` (the base ``IndexConfig.save``
          also creates the directory);
        * ``doc_ids.txt`` -- the document IDs joined by newlines;
        * ``doc_lengths.pt`` -- a 1-D ``torch.int32`` tensor of document lengths (empty when no
          documents have been added).
        """
        self.index_config.save(self.index_dir)
        (self.index_dir / "doc_ids.txt").write_text("\n".join(self.doc_ids))
        if self.doc_lengths:
            # Zero-copy view of the array's buffer; only lives for the duration of torch.save.
            doc_lengths = torch.frombuffer(self.doc_lengths, dtype=torch.int32)
        else:
            # torch.frombuffer rejects zero-length buffers, so an index with no documents
            # would otherwise fail to save.
            doc_lengths = torch.empty(0, dtype=torch.int32)
        torch.save(doc_lengths, self.index_dir / "doc_lengths.pt")
61
62
[docs]
class IndexConfig:
    """Configuration class for indexers that defines the index type and other parameters."""

    # Declared for concrete subclasses; neither attribute is assigned in this base class.
    indexer_class: Type[Indexer]
    SUPPORTED_MODELS: Set[str]

    @classmethod
    def from_pretrained(cls, index_dir: Path | str) -> "IndexConfig":
        """Load the index configuration from a directory.

        Reads ``config.json`` from ``index_dir`` and checks that its ``index_type`` entry names this
        class. It then drops the bookkeeping keys ``index_type`` and ``index_dir`` (both written by
        :meth:`save`) and passes the remaining entries to the constructor as keyword arguments.

        Args:
            index_dir (Path | str): Path to the directory containing the index configuration.
        Returns:
            IndexConfig: An instance of the index configuration class.
        Raises:
            ValueError: If the index type in the configuration is missing or does not match the
                expected class name.
        """
        index_dir = Path(index_dir)
        with open(index_dir / "config.json", "r", encoding="utf-8") as f:
            data = json.load(f)
        # pop() rather than data["index_type"], so a config lacking the key raises the documented
        # ValueError instead of a bare KeyError.
        index_type = data.pop("index_type", None)
        if index_type != cls.__name__:
            raise ValueError(f"Expected index_type {cls.__name__}, got {index_type}")
        data.pop("index_dir", None)
        return cls(**data)

    def save(self, index_dir: Path | str) -> None:
        """Save the index configuration to a directory.

        Creates ``index_dir`` (including parents) if needed. It then writes ``config.json``
        containing the instance attributes plus two bookkeeping keys: ``index_dir`` (as a string)
        and ``index_type`` (the class name, checked by :meth:`from_pretrained`).

        Args:
            index_dir (Path | str): The directory to save the index configuration.
        Raises:
            TypeError: If an attribute is not JSON-serializable. Any existing ``config.json`` is
                left untouched in that case.
        """
        index_dir = Path(index_dir)
        index_dir.mkdir(parents=True, exist_ok=True)
        data = self.__dict__.copy()
        data["index_dir"] = str(index_dir)
        data["index_type"] = self.__class__.__name__
        # Serialize before opening the file: opening with "w" truncates it, so a failure part-way
        # through json.dump would otherwise leave a corrupt, partially written config.json behind.
        serialized = json.dumps(data)
        (index_dir / "config.json").write_text(serialized, encoding="utf-8")

    def to_dict(self) -> dict:
        """Convert the index configuration to a dictionary.

        Returns:
            dict: A shallow copy of the instance attributes (without the ``index_dir`` /
            ``index_type`` keys that :meth:`save` adds).
        """
        return self.__dict__.copy()