Source code for lightning_ir.retrieve.plaid.plaid_indexer

import warnings
from array import array
from pathlib import Path

import torch

from ...bi_encoder import BiEncoderModule, BiEncoderOutput
from ...data import IndexBatch
from ...models import ColConfig
from ..base import IndexConfig, Indexer
from .residual_codec import ResidualCodec

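# PLAID-style indexing (Santhanam et al., "PLAID: An Efficient Engine for Late
# Interaction Retrieval"): every token embedding is stored as a centroid id
# ("code") plus an n_bits-quantized residual, produced by the ResidualCodec
# once it has been trained on a sample of the corpus embeddings.
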
class PlaidIndexer(Indexer):

    def __init__(
        self,
        index_dir: Path,
        index_config: "PlaidIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        super().__init__(index_dir, index_config, module, verbose)

        self.index_config: PlaidIndexConfig

        # buffer for the embeddings used to train the residual codec;
        # NaN-filled rows mark slots that have not been filled yet
        self._train_embeddings: torch.Tensor | None = torch.full(
            (self.index_config.num_train_embeddings, self.module.config.embedding_dim),
            torch.nan,
            dtype=torch.float32,
        )
        self.residual_codec: ResidualCodec | None = None
        # growable buffers for the compressed index: one long centroid code and
        # a run of uint8 residual bytes per token embedding
        self.codes = array("l")
        self.residuals = array("B")
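
    # Illustrative sizing (assuming the codec packs n_bits per dimension, as in
    # ColBERTv2-style residual compression; the codec internals are not shown in
    # this file): with embedding_dim=128 and the default n_bits=2, each token
    # embedding costs one entry in self.codes plus 128 * 2 / 8 = 32 residual
    # bytes in self.residuals.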

    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        doc_embeddings = output.doc_embeddings
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")

        if doc_embeddings.scoring_mask is None:
            # single-vector output: one embedding per document
            doc_lengths = torch.ones(
                doc_embeddings.embeddings.shape[0], device=doc_embeddings.device, dtype=torch.int32
            )
            embeddings = doc_embeddings.embeddings[:, 0]
        else:
            # multi-vector output: keep only the embeddings selected by the scoring mask
            doc_lengths = doc_embeddings.scoring_mask.sum(dim=1)
            embeddings = doc_embeddings.embeddings[doc_embeddings.scoring_mask]
        doc_ids = index_batch.doc_ids
        embeddings = self.process_embeddings(embeddings)

        if embeddings.shape[0]:
            if self.residual_codec is None:
                raise ValueError("Residual codec not trained")
            codes, residuals = self.residual_codec.compress(embeddings)
            self.codes.extend(codes.numpy(force=True))
            self.residuals.extend(residuals.view(-1).numpy(force=True))

        self.num_embeddings += embeddings.shape[0]
        self.num_docs += len(doc_ids)

        self.doc_lengths.extend(doc_lengths.int().cpu().tolist())
        self.doc_ids.extend(doc_ids)

    def process_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        embeddings = self._grab_train_embeddings(embeddings)
        self._train()
        return embeddings

    def _grab_train_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        if self._train_embeddings is not None:
            # buffer training embeddings until num_train_embeddings is reached;
            # if the batch overflows num_train_embeddings, return the remaining
            # embeddings so they are compressed normally
            start = self.num_embeddings
            end = min(self.index_config.num_train_embeddings, start + embeddings.shape[0])
            length = end - start
            self._train_embeddings[start:end] = embeddings[:length]
            self.num_embeddings += length
            embeddings = embeddings[length:]
        return embeddings

    def _train(self, force: bool = False) -> None:
        if self._train_embeddings is None:
            return
        if not force and self.num_embeddings < self.index_config.num_train_embeddings:
            return

        if torch.isnan(self._train_embeddings).any():
            warnings.warn(
                "Corpus contains fewer tokens/documents than num_train_embeddings. Removing NaN embeddings."
            )
            self._train_embeddings = self._train_embeddings[~torch.isnan(self._train_embeddings).any(dim=1)]

        # train the codec on the buffered embeddings, then compress the buffer
        # itself so the training sample also ends up in the index
        self.residual_codec = ResidualCodec.train(self.index_config, self._train_embeddings, self.verbose)
        codes, residuals = self.residual_codec.compress(self._train_embeddings)
        self.codes.extend(codes.numpy(force=True))
        self.residuals.extend(residuals.view(-1).numpy(force=True))

        self._train_embeddings = None

    def save(self) -> None:
        if self.residual_codec is None:
            # corpus was smaller than num_train_embeddings; train on what we have
            self._train(force=True)
        if self.residual_codec is None:
            raise ValueError("No residual codec to save")
        super().save()

        codes = torch.frombuffer(self.codes, dtype=torch.long)
        residuals = torch.frombuffer(self.residuals, dtype=torch.uint8)
        torch.save(codes, self.index_dir / "codes.pt")
        torch.save(residuals, self.index_dir / "residuals.pt")
        self.residual_codec.save(self.index_dir)
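
# On-disk layout written by PlaidIndexer.save() above: codes.pt holds one int64
# centroid id per token embedding, residuals.pt holds the flattened uint8
# residual bytes, and ResidualCodec.save() stores the codec parameters
# (presumably the trained centroids) alongside them in index_dir.
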
class PlaidIndexConfig(IndexConfig):
    indexer_class = PlaidIndexer
    SUPPORTED_MODELS = {ColConfig.model_type}

    def __init__(
        self,
        num_centroids: int,
        num_train_embeddings: int | None = None,
        k_means_iters: int = 4,
        n_bits: int = 2,
        seed: int = 42,
    ) -> None:
        super().__init__()
        # by default, train the codec on 256 sampled embeddings per centroid
        max_points_per_centroid = 256
        self.num_centroids = num_centroids
        self.num_train_embeddings = num_train_embeddings or num_centroids * max_points_per_centroid
        self.k_means_iters = k_means_iters
        self.n_bits = n_bits
        self.seed = seed
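
# Minimal usage sketch (not part of this module). The checkpoint name and the
# BiEncoderModule constructor argument are illustrative assumptions; the
# PlaidIndexConfig and indexer signatures are taken from the code above.
#
#   from pathlib import Path
#   from lightning_ir import BiEncoderModule
#   from lightning_ir.retrieve.plaid.plaid_indexer import PlaidIndexConfig
#
#   module = BiEncoderModule("colbert-ir/colbertv2.0")  # must use a ColConfig model
#   config = PlaidIndexConfig(num_centroids=1024)  # defaults to 1024 * 256 = 262,144 training embeddings
#   indexer = config.indexer_class(Path("plaid-index"), config, module, verbose=True)
#   # an indexing loop then calls indexer.add(index_batch, output) once per
#   # encoded batch and finishes with a single indexer.save()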