Source code for lightning_ir.retrieve.plaid.plaid_indexer
import warnings
from array import array
from pathlib import Path

import torch

from ...bi_encoder import BiEncoderModule, BiEncoderOutput
from ...data import IndexBatch
from ...models import ColConfig
from ..base import IndexConfig, Indexer
from .residual_codec import ResidualCodec


class PlaidIndexer(Indexer):
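    """Indexer for PLAID indexes. A :class:`ResidualCodec` is trained on the first
    ``num_train_embeddings`` embeddings; every embedding is then stored as a
    centroid code plus a quantized residual."""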

    def __init__(
        self,
        index_dir: Path,
        index_config: "PlaidIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        super().__init__(index_dir, index_config, module, verbose)

        self.index_config: PlaidIndexConfig

        # buffer for the codec training embeddings, pre-filled with NaN so that
        # unfilled rows can be detected (and dropped) if the corpus is small
        self._train_embeddings: torch.Tensor | None = torch.full(
            (self.index_config.num_train_embeddings, self.module.config.embedding_dim),
            torch.nan,
            dtype=torch.float32,
        )
        self.residual_codec: ResidualCodec | None = None
        # flat buffers for the centroid codes and packed residual bytes
        self.codes = array("l")
        self.residuals = array("B")

    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
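        """Add a batch of document embeddings to the index.

        Embeddings are first routed through the codec-training buffer; once the
        residual codec is trained, the remainder is compressed and appended."""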
        doc_embeddings = output.doc_embeddings
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")

        if doc_embeddings.scoring_mask is None:
            # single-vector embeddings: one embedding per document
            doc_lengths = torch.ones(
                doc_embeddings.embeddings.shape[0], device=doc_embeddings.device, dtype=torch.int32
            )
            embeddings = doc_embeddings.embeddings[:, 0]
        else:
            # multi-vector embeddings: keep only the scored (non-padding) tokens
            doc_lengths = doc_embeddings.scoring_mask.sum(dim=1)
            embeddings = doc_embeddings.embeddings[doc_embeddings.scoring_mask]
        doc_ids = index_batch.doc_ids
        embeddings = self.process_embeddings(embeddings)

        if embeddings.shape[0]:
            if self.residual_codec is None:
                raise ValueError("Residual codec not trained")
            codes, residuals = self.residual_codec.compress(embeddings)
            self.codes.extend(codes.numpy(force=True))
            self.residuals.extend(residuals.view(-1).numpy(force=True))

        self.num_embeddings += embeddings.shape[0]
        self.num_docs += len(doc_ids)

        self.doc_lengths.extend(doc_lengths.int().cpu().tolist())
        self.doc_ids.extend(doc_ids)

    def process_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
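        """Divert embeddings into the training buffer while the codec is untrained,
        triggering training once enough have accumulated; return the rest."""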
        embeddings = self._grab_train_embeddings(embeddings)
        self._train()
        return embeddings

    def _grab_train_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
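        """Copy embeddings into the training buffer until it is full and return
        the overflow that should be compressed directly."""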
        if self._train_embeddings is not None:
            # fill the training buffer until num_train_embeddings is reached;
            # any overflow beyond the buffer is returned to be indexed directly
            start = self.num_embeddings
            end = min(self.index_config.num_train_embeddings, start + embeddings.shape[0])
            length = end - start
            self._train_embeddings[start:end] = embeddings[:length]
            self.num_embeddings += length
            embeddings = embeddings[length:]
        return embeddings

    def _train(self, force: bool = False) -> None:
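        """Train the residual codec on the buffered embeddings and compress the
        buffer. With ``force=True``, train even if the buffer is not yet full."""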
        if self._train_embeddings is None:
            return
        if not force and self.num_embeddings < self.index_config.num_train_embeddings:
            return

        if torch.isnan(self._train_embeddings).any():
            # the corpus was too small to fill the training buffer; drop the NaN rows
            warnings.warn(
                "Corpus contains fewer tokens/documents than num_train_embeddings. Removing NaN embeddings."
            )
            self._train_embeddings = self._train_embeddings[~torch.isnan(self._train_embeddings).any(dim=1)]

        self.residual_codec = ResidualCodec.train(self.index_config, self._train_embeddings, self.verbose)
        codes, residuals = self.residual_codec.compress(self._train_embeddings)
        self.codes.extend(codes.numpy(force=True))
        self.residuals.extend(residuals.view(-1).numpy(force=True))

        self._train_embeddings = None

    def save(self) -> None:
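        """Write the codes, residuals, and trained codec to ``index_dir``,
        training the codec on whatever was buffered if it has not been trained."""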
        if self.residual_codec is None:
            self._train(force=True)
        if self.residual_codec is None:
            raise ValueError("No residual codec to save")
        super().save()

        codes = torch.frombuffer(self.codes, dtype=torch.long)
        residuals = torch.frombuffer(self.residuals, dtype=torch.uint8)
        torch.save(codes, self.index_dir / "codes.pt")
        torch.save(residuals, self.index_dir / "residuals.pt")
        self.residual_codec.save(self.index_dir)


class PlaidIndexConfig(IndexConfig):
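    """Configuration for a PLAID index."""
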
    indexer_class = PlaidIndexer
    SUPPORTED_MODELS = {ColConfig.model_type}

    def __init__(
        self,
        num_centroids: int,
        num_train_embeddings: int | None = None,
        k_means_iters: int = 4,
        n_bits: int = 2,
        seed: int = 42,
    ) -> None:
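        """Create a PLAID index configuration. If ``num_train_embeddings`` is not
        given, it defaults to 256 training embeddings per centroid."""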
        super().__init__()
        # when unset, derive the training set size as 256 points per centroid
        max_points_per_centroid = 256
        self.num_centroids = num_centroids
        self.num_train_embeddings = num_train_embeddings or num_centroids * max_points_per_centroid
        self.k_means_iters = k_means_iters
        self.n_bits = n_bits
        self.seed = seed
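

# A minimal usage sketch (not part of the module), assuming a Col-style
# ``BiEncoderModule`` named ``module`` and an iterable of already-encoded
# (IndexBatch, BiEncoderOutput) pairs named ``encoded_batches``; both names
# are hypothetical stand-ins for whatever your pipeline provides:
#
#     index_config = PlaidIndexConfig(num_centroids=1024)
#     indexer = PlaidIndexer(Path("plaid_index"), index_config, module)
#     for index_batch, output in encoded_batches:
#         indexer.add(index_batch, output)
#     indexer.save()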