# Source code for lightning_ir.retrieve.plaid.plaid_indexer
"""Plaid Indexer using fast-plaid library for Lightning IR Framework"""

from pathlib import Path

import torch

from ...bi_encoder import BiEncoderModule, BiEncoderOutput
from ...data import IndexBatch
from ...models import ColConfig
from ..base import IndexConfig, Indexer
12
class PlaidIndexer(Indexer):
    """Indexer for PLAID using the ``fast_plaid`` library.

    Embeddings are buffered in memory across :meth:`add` calls and the actual
    index is built once in :meth:`save`.
    """

    def __init__(
        self,
        index_dir: Path,
        index_config: "PlaidIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the PlaidIndexer.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (PlaidIndexConfig): Configuration for the Plaid indexer.
            module (BiEncoderModule): The BiEncoder module used for indexing.
            verbose (bool): Whether to print verbose output during indexing. Defaults to False.
        """
        super().__init__(index_dir, index_config, module, verbose)
        self.index_config: PlaidIndexConfig
        self.index = None
        # Per-document embedding tensors accumulated by `add`; consumed in `save`.
        self.embeddings = []

    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        """Buffer embeddings from the index batch for later index creation.

        Args:
            index_batch (IndexBatch): Batch of data containing the document ids being indexed.
            output (BiEncoderOutput): Output from the BiEncoder module containing embeddings.

        Raises:
            ValueError: If the output does not contain document embeddings.
        """
        doc_embeddings = output.doc_embeddings
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")

        if doc_embeddings.scoring_mask is None:
            # No scoring mask: each document contributes exactly one embedding
            # vector (the first one along the sequence dimension).
            doc_lengths = torch.ones(
                doc_embeddings.embeddings.shape[0], device=doc_embeddings.device, dtype=torch.int32
            )
            embeddings = doc_embeddings.embeddings[:, 0]
        else:
            # Number of non-masked (scoring) vectors per document; masked
            # indexing flattens the selected vectors into one long tensor.
            doc_lengths = doc_embeddings.scoring_mask.sum(dim=1)
            embeddings = doc_embeddings.embeddings[doc_embeddings.scoring_mask]
        doc_ids = index_batch.doc_ids

        self.num_embeddings += embeddings.shape[0]
        self.num_docs += len(doc_ids)

        # Convert once and reuse: the original computed `.int().cpu().tolist()`
        # twice, once for bookkeeping and once for splitting the flat tensor.
        lengths = doc_lengths.int().cpu().tolist()
        self.doc_lengths.extend(lengths)
        self.doc_ids.extend(doc_ids)
        self.embeddings.extend(embeddings.cpu().split(lengths))

    def save(self) -> None:
        """Create the PLAID index from the buffered embeddings and save it.

        Also persists the index configuration and document ids via the base
        class's ``save``.
        """
        # Imported lazily so fast_plaid is only required when building an index.
        from fast_plaid import search

        index = search.FastPlaid(index=str(self.index_dir))

        index.create(
            documents_embeddings=self.embeddings,
            kmeans_niters=self.index_config.k_means_iters,
            nbits=self.index_config.n_bits,
            seed=self.index_config.seed,
        )

        super().save()
82
class PlaidIndexConfig(IndexConfig):
    """Configuration class for Plaid indexers in the Lightning IR framework."""

    indexer_class = PlaidIndexer
    SUPPORTED_MODELS = {ColConfig.model_type}

    def __init__(
        self,
        num_centroids: int,
        k_means_iters: int = 4,
        n_bits: int = 2,
        seed: int = 42,
    ) -> None:
        """Initialize the PlaidIndexConfig.

        Args:
            num_centroids (int): Number of centroids for the Plaid index.
            k_means_iters (int): Number of iterations for k-means clustering. Defaults to 4.
            n_bits (int): Number of bits for the residual codec. Defaults to 2.
            seed (int): Random seed for reproducibility. Defaults to 42.
        """
        super().__init__()
        self.num_centroids = num_centroids
        self.k_means_iters = k_means_iters
        self.n_bits = n_bits
        self.seed = seed