Source code for lightning_ir.retrieve.plaid.plaid_indexer

  1"""Plaid Indexer for Lightning IR Framework"""
  2
  3import warnings
  4from array import array
  5from pathlib import Path
  6
  7import torch
  8
  9from ...bi_encoder import BiEncoderModule, BiEncoderOutput
 10from ...data import IndexBatch
 11from ...models import ColConfig
 12from ..base import IndexConfig, Indexer
 13from .residual_codec import ResidualCodec
 14
 15
class PlaidIndexer(Indexer):
    """Indexer for Plaid, a residual-based indexing method for efficient retrieval."""

    def __init__(
        self,
        index_dir: Path,
        index_config: "PlaidIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the PlaidIndexer.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (PlaidIndexConfig): Configuration for the Plaid indexer.
            module (BiEncoderModule): The BiEncoder module used for indexing.
            verbose (bool): Whether to print verbose output during indexing. Defaults to False.
        """
        super().__init__(index_dir, index_config, module, verbose)

        self.index_config: PlaidIndexConfig

        # Buffer for the first ``num_train_embeddings`` embeddings, on which the residual codec is trained.
        # Rows that are never filled stay NaN so they can be detected and dropped if the corpus is too small.
        # Set to None once the codec has been trained.
        self._train_embeddings: torch.Tensor | None = torch.full(
            (self.index_config.num_train_embeddings, self.module.config.embedding_dim),
            torch.nan,
            dtype=torch.float32,
        )
        self.residual_codec: ResidualCodec | None = None
        # "q" (signed long long) has 8-byte items on every platform, matching the torch.long reinterpretation in
        # save(); "l" (C long) is only 4 bytes on Windows, which would corrupt codes.pt.
        self.codes = array("q")
        self.residuals = array("B")

    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        """Add the document embeddings of a batch to the Plaid index.

        Until ``num_train_embeddings`` embeddings have been collected, embeddings are buffered for training the
        residual codec. Once the codec is trained, embeddings are compressed into codes and residuals.

        Args:
            index_batch (IndexBatch): Batch of data providing the ids of the documents being indexed.
            output (BiEncoderOutput): Output from the BiEncoder module containing the document embeddings.
        Raises:
            ValueError: If the output does not contain document embeddings.
            ValueError: If the residual codec is not trained.
        """
        doc_embeddings = output.doc_embeddings
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")

        if doc_embeddings.scoring_mask is None:
            # Without a scoring mask only the first embedding of every document is indexed (length 1 each).
            doc_lengths = torch.ones(
                doc_embeddings.embeddings.shape[0], device=doc_embeddings.device, dtype=torch.int32
            )
            embeddings = doc_embeddings.embeddings[:, 0]
        else:
            doc_lengths = doc_embeddings.scoring_mask.sum(dim=1)
            embeddings = doc_embeddings.embeddings[doc_embeddings.scoring_mask]
        doc_ids = index_batch.doc_ids
        # Consumes (and counts in num_embeddings) any embeddings still needed for training; returns the rest.
        embeddings = self.process_embeddings(embeddings)

        if embeddings.shape[0]:
            if self.residual_codec is None:
                raise ValueError("Residual codec not trained")
            codes, residuals = self.residual_codec.compress(embeddings)
            self._store_compressed(codes, residuals)

        self.num_embeddings += embeddings.shape[0]
        self.num_docs += len(doc_ids)

        self.doc_lengths.extend(doc_lengths.int().cpu().tolist())
        self.doc_ids.extend(doc_ids)

    def process_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Process embeddings before indexing.

        Buffers embeddings for codec training until the training buffer is full. When it is full, trains the
        residual codec.

        Args:
            embeddings (torch.Tensor): The embeddings to be processed.
        Returns:
            torch.Tensor: The embeddings that were not consumed for training and still need to be compressed.
        """
        embeddings = self._grab_train_embeddings(embeddings)
        self._train()
        return embeddings

    def _grab_train_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Copy embeddings into the training buffer until it holds ``num_train_embeddings`` rows.

        Copied embeddings are counted in ``num_embeddings``. Returns the embeddings that did not fit, or all of
        them if the training buffer has already been consumed.
        """
        if self._train_embeddings is None:
            return embeddings
        # num_embeddings doubles as the write offset while the buffer is being filled.
        start = self.num_embeddings
        end = min(self.index_config.num_train_embeddings, start + embeddings.shape[0])
        length = end - start
        self._train_embeddings[start:end] = embeddings[:length]
        self.num_embeddings += length
        return embeddings[length:]

    def _store_compressed(self, codes: torch.Tensor, residuals: torch.Tensor) -> None:
        """Append compressed codes and flattened residuals to the in-memory index buffers."""
        self.codes.extend(codes.numpy(force=True))
        self.residuals.extend(residuals.view(-1).numpy(force=True))

    def _train(self, force: bool = False) -> None:
        """Train the residual codec on the buffered embeddings, compress them, and free the buffer.

        Does nothing if the training buffer has already been consumed. Unless ``force`` is set, it also waits
        until the buffer holds ``num_train_embeddings`` embeddings.

        Args:
            force (bool): Train even if fewer than ``num_train_embeddings`` embeddings were collected.
        """
        if self._train_embeddings is None:
            return
        if not force and self.num_embeddings < self.index_config.num_train_embeddings:
            return

        # Rows that were never filled are still NaN (the corpus is smaller than the training buffer).
        nan_rows = torch.isnan(self._train_embeddings).any(dim=1)
        if nan_rows.any():
            warnings.warn("Corpus contains less tokens/documents than num_train_embeddings. Removing NaN embeddings.")
            self._train_embeddings = self._train_embeddings[~nan_rows]

        self.residual_codec = ResidualCodec.train(self.index_config, self._train_embeddings, self.verbose)
        # The training embeddings are the first embeddings of the corpus, so their codes are stored first.
        codes, residuals = self.residual_codec.compress(self._train_embeddings)
        self._store_compressed(codes, residuals)

        self._train_embeddings = None

    def save(self) -> None:
        """Save the Plaid index to the specified directory.

        If the residual codec has not been trained yet, it is trained on whatever embeddings were collected.
        After the base class ``save()``, ``codes.pt``, ``residuals.pt`` and the codec are written to ``index_dir``.

        Raises:
            ValueError: If residual_codec is None.
        """
        if self.residual_codec is None:
            self._train(force=True)
        if self.residual_codec is None:
            raise ValueError("No residual codec to save")
        super().save()

        # frombuffer reinterprets the arrays' memory without copying; the "q" items of self.codes are 8 bytes,
        # matching torch.long, and the "B" items of self.residuals are 1 byte, matching torch.uint8.
        codes = torch.frombuffer(self.codes, dtype=torch.long)
        residuals = torch.frombuffer(self.residuals, dtype=torch.uint8)
        torch.save(codes, self.index_dir / "codes.pt")
        torch.save(residuals, self.index_dir / "residuals.pt")
        self.residual_codec.save(self.index_dir)
143 144
class PlaidIndexConfig(IndexConfig):
    """Configuration class for Plaid indexers in the Lightning IR framework."""

    indexer_class = PlaidIndexer
    SUPPORTED_MODELS = {ColConfig.model_type}

    def __init__(
        self,
        num_centroids: int,
        num_train_embeddings: int | None = None,
        k_means_iters: int = 4,
        n_bits: int = 2,
        seed: int = 42,
    ) -> None:
        """Initialize the PlaidIndexConfig.

        Args:
            num_centroids (int): Number of centroids for the Plaid index.
            num_train_embeddings (int | None): Number of embeddings to use for training the index. If None (or 0),
                it defaults to ``num_centroids * 256``, i.e. 256 training points per centroid. Defaults to None.
            k_means_iters (int): Number of iterations for k-means clustering. Defaults to 4.
            n_bits (int): Number of bits for the residual codec. Defaults to 2.
            seed (int): Random seed for reproducibility. Defaults to 42.
        """
        super().__init__()
        # Default budget of training points per centroid when num_train_embeddings is not given.
        max_points_per_centroid = 256
        self.num_centroids = num_centroids
        # `or` (not `is None`) also replaces 0, which would otherwise mean training on no embeddings.
        self.num_train_embeddings = num_train_embeddings or num_centroids * max_points_per_centroid
        self.k_means_iters = k_means_iters
        self.n_bits = n_bits
        self.seed = seed