Source code for lightning_ir.retrieve.plaid.plaid_indexer
1"""Plaid Indexer for Lightning IR Framework"""
2
3import warnings
4from array import array
5from pathlib import Path
6
7import torch
8
9from ...bi_encoder import BiEncoderModule, BiEncoderOutput
10from ...data import IndexBatch
11from ...models import ColConfig
12from ..base import IndexConfig, Indexer
13from .residual_codec import ResidualCodec
14
15
class PlaidIndexer(Indexer):
    """Indexer for Plaid, a residual-based indexing method for efficient retrieval."""

    def __init__(
        self,
        index_dir: Path,
        index_config: "PlaidIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the PlaidIndexer.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (PlaidIndexConfig): Configuration for the Plaid indexer.
            module (BiEncoderModule): The BiEncoder module used for indexing.
            verbose (bool): Whether to print verbose output during indexing. Defaults to False.
        """
        super().__init__(index_dir, index_config, module, verbose)

        self.index_config: PlaidIndexConfig

        self._train_embeddings: torch.Tensor | None = torch.full(
            (self.index_config.num_train_embeddings, self.module.config.embedding_dim),
            torch.nan,
            dtype=torch.float32,
        )
        self.residual_codec: ResidualCodec | None = None
        self.codes = array("l")
        self.residuals = array("B")

    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        """Add embeddings from the index batch to the Plaid index.

        Args:
            index_batch (IndexBatch): Batch of data containing embeddings to be indexed.
            output (BiEncoderOutput): Output from the BiEncoder module containing embeddings.
        Raises:
            ValueError: If the output does not contain document embeddings.
            ValueError: If the residual codec is not trained.
        """
        doc_embeddings = output.doc_embeddings
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")

        if doc_embeddings.scoring_mask is None:
            doc_lengths = torch.ones(
                doc_embeddings.embeddings.shape[0], device=doc_embeddings.device, dtype=torch.int32
            )
            embeddings = doc_embeddings.embeddings[:, 0]
        else:
            doc_lengths = doc_embeddings.scoring_mask.sum(dim=1)
            embeddings = doc_embeddings.embeddings[doc_embeddings.scoring_mask]
        doc_ids = index_batch.doc_ids
        embeddings = self.process_embeddings(embeddings)

        if embeddings.shape[0]:
            if self.residual_codec is None:
                raise ValueError("Residual codec not trained")
            codes, residuals = self.residual_codec.compress(embeddings)
            self.codes.extend(codes.numpy(force=True))
            self.residuals.extend(residuals.view(-1).numpy(force=True))

        self.num_embeddings += embeddings.shape[0]
        self.num_docs += len(doc_ids)

        self.doc_lengths.extend(doc_lengths.int().cpu().tolist())
        self.doc_ids.extend(doc_ids)
    def process_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Process embeddings before indexing.

        Args:
            embeddings (torch.Tensor): The embeddings to be processed.
        Returns:
            torch.Tensor: The processed embeddings.
        """
        embeddings = self._grab_train_embeddings(embeddings)
        self._train()
        return embeddings

    def _grab_train_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        if self._train_embeddings is not None:
            # save training embeddings until num_train_embeddings is reached
            # if num_train_embeddings overflows, save the remaining embeddings
            start = self.num_embeddings
            end = min(self.index_config.num_train_embeddings, start + embeddings.shape[0])
            length = end - start
            self._train_embeddings[start:end] = embeddings[:length]
            self.num_embeddings += length
            embeddings = embeddings[length:]
        return embeddings

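    # Worked example for _grab_train_embeddings (assumed values, not from the library):
    # with num_train_embeddings=8 and num_embeddings=6, a batch of 5 embeddings gives
    # start=6, end=min(8, 11)=8 and length=2, so the first two rows fill the training
    # buffer and the remaining three rows are returned to be compressed in add() once
    # the residual codec has been trained.
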
    def _train(self, force: bool = False) -> None:
        if self._train_embeddings is None:
            return
        if not force and self.num_embeddings < self.index_config.num_train_embeddings:
            return

        if torch.isnan(self._train_embeddings).any():
            warnings.warn(
                "Corpus contains fewer tokens/documents than num_train_embeddings. Removing NaN embeddings."
            )
            self._train_embeddings = self._train_embeddings[~torch.isnan(self._train_embeddings).any(dim=1)]

        self.residual_codec = ResidualCodec.train(self.index_config, self._train_embeddings, self.verbose)
        codes, residuals = self.residual_codec.compress(self._train_embeddings)
        self.codes.extend(codes.numpy(force=True))
        self.residuals.extend(residuals.view(-1).numpy(force=True))

        self._train_embeddings = None

    def save(self) -> None:
        """Save the Plaid index to the specified directory.

        Raises:
            ValueError: If no residual codec could be trained.
        """
        if self.residual_codec is None:
            self._train(force=True)
        if self.residual_codec is None:
            raise ValueError("No residual codec to save")
        super().save()

        codes = torch.frombuffer(self.codes, dtype=torch.long)
        residuals = torch.frombuffer(self.residuals, dtype=torch.uint8)
        torch.save(codes, self.index_dir / "codes.pt")
        torch.save(residuals, self.index_dir / "residuals.pt")
        self.residual_codec.save(self.index_dir)


class PlaidIndexConfig(IndexConfig):
    """Configuration class for Plaid indexers in the Lightning IR framework."""

    indexer_class = PlaidIndexer
    SUPPORTED_MODELS = {ColConfig.model_type}

    def __init__(
        self,
        num_centroids: int,
        num_train_embeddings: int | None = None,
        k_means_iters: int = 4,
        n_bits: int = 2,
        seed: int = 42,
    ) -> None:
        """Initialize the PlaidIndexConfig.

        Args:
            num_centroids (int): Number of centroids for the Plaid index.
            num_train_embeddings (int | None): Number of embeddings to use for training the index. If None,
                defaults to num_centroids * 256. Defaults to None.
            k_means_iters (int): Number of iterations for k-means clustering. Defaults to 4.
            n_bits (int): Number of bits for the residual codec. Defaults to 2.
            seed (int): Random seed for reproducibility. Defaults to 42.
        """
        super().__init__()
        max_points_per_centroid = 256
        self.num_centroids = num_centroids
        self.num_train_embeddings = num_train_embeddings or num_centroids * max_points_per_centroid
        self.k_means_iters = k_means_iters
        self.n_bits = n_bits
        self.seed = seed
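

A minimal usage sketch of how the two classes fit together. The checkpoint name, the index batch loop, and the module.forward call are assumptions made for illustration; in practice the indexer is typically driven by Lightning IR's indexing pipeline rather than called by hand.

from pathlib import Path

# Assumed configuration values for illustration only.
index_config = PlaidIndexConfig(num_centroids=1024, k_means_iters=4, n_bits=2)
# num_train_embeddings defaults to num_centroids * 256 = 262144 here.

module = BiEncoderModule(model_name_or_path="some/col-checkpoint")  # hypothetical checkpoint
indexer = PlaidIndexer(Path("indexes/plaid"), index_config, module, verbose=True)

for index_batch in index_batches:         # IndexBatch objects from an index dataloader (assumed)
    output = module.forward(index_batch)  # BiEncoderOutput containing doc_embeddings (assumed call)
    indexer.add(index_batch, output)

indexer.save()  # writes codes.pt, residuals.pt and the residual codec files to indexes/plaid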