Source code for lightning_ir.retrieve.plaid.plaid_indexer

  1"""Plaid Indexer using fast-plaid library for Lightning IR Framework"""
  2
  3from pathlib import Path
  4
  5import torch
  6
  7from ...bi_encoder import BiEncoderModule, BiEncoderOutput
  8from ...data import IndexBatch
  9from ...models import ColConfig
 10from ..base import IndexConfig, Indexer
 11
 12
class PlaidIndexer(Indexer):
    """Indexer for Plaid using fast-plaid library."""

    def __init__(
        self,
        index_dir: Path,
        index_config: "PlaidIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the PlaidIndexer.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (PlaidIndexConfig): Configuration for the Plaid indexer.
            module (BiEncoderModule): The BiEncoder module used for indexing.
            verbose (bool): Whether to print verbose output during indexing. Defaults to False.
        """
        super().__init__(index_dir, index_config, module, verbose)
        self.index_config: PlaidIndexConfig
        self.index = None
        self.embeddings = []

    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        """Add embeddings from the index batch to the Plaid index.

        Args:
            index_batch (IndexBatch): Batch of data containing embeddings to be indexed.
            output (BiEncoderOutput): Output from the BiEncoder module containing embeddings.

        Raises:
            ValueError: If the output does not contain document embeddings.
        """
        doc_embeddings = output.doc_embeddings
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")

        if doc_embeddings.scoring_mask is None:
            # No scoring mask: treat each document as a single embedding.
            doc_lengths = torch.ones(
                doc_embeddings.embeddings.shape[0], device=doc_embeddings.device, dtype=torch.int32
            )
            embeddings = doc_embeddings.embeddings[:, 0]
        else:
            doc_lengths = doc_embeddings.scoring_mask.sum(dim=1)
            embeddings = doc_embeddings.embeddings[doc_embeddings.scoring_mask]
        doc_ids = index_batch.doc_ids

        self.num_embeddings += embeddings.shape[0]
        self.num_docs += len(doc_ids)

        # Compute per-document token counts once; used for bookkeeping and splitting.
        lengths = doc_lengths.int().cpu().tolist()
        self.doc_lengths.extend(lengths)
        self.doc_ids.extend(doc_ids)

        self.embeddings.extend(embeddings.cpu().split(lengths))

    def save(self) -> None:
        """Create the fast-plaid index from the accumulated embeddings and save the index
        configuration and document IDs to the index directory."""
        from fast_plaid import search

        index = search.FastPlaid(index=str(self.index_dir))

        index.create(
            documents_embeddings=self.embeddings,
            kmeans_niters=self.index_config.k_means_iters,
            nbits=self.index_config.n_bits,
            seed=self.index_config.seed,
        )

        super().save()

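The per-document splitting in PlaidIndexer.add can be illustrated in isolation. The following is a minimal, self-contained sketch (the tensor shapes are made up for this example) of how a boolean scoring mask flattens a padded batch of token embeddings and how the flat tensor is split back into per-document chunks:

    import torch

    # Padded token embeddings for a batch of 2 documents, 4 token slots each, dim 8.
    embeddings = torch.randn(2, 4, 8)
    # Scoring mask: document 0 has 3 valid tokens, document 1 has 2.
    scoring_mask = torch.tensor([[True, True, True, False], [True, True, False, False]])

    doc_lengths = scoring_mask.sum(dim=1)       # tensor([3, 2])
    flat = embeddings[scoring_mask]             # shape (5, 8): all valid tokens, concatenated
    per_doc = flat.split(doc_lengths.tolist())  # tuple of tensors with shapes (3, 8) and (2, 8)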
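Once save has run, the written directory can be reopened with fast-plaid for retrieval. A rough sketch, assuming fast-plaid exposes a FastPlaid.search method that takes batched query embeddings and a top_k (as in the fast-plaid README); consult the fast-plaid documentation for the exact signature, and note that the path and shapes below are placeholders:

    import torch
    from fast_plaid import search

    # Directory previously written by PlaidIndexer.save.
    index = search.FastPlaid(index="path/to/index_dir")
    # Query embeddings: (num_queries, query_tokens, embedding_dim); shapes are illustrative.
    scores = index.search(queries_embeddings=torch.randn(2, 32, 128), top_k=10)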
class PlaidIndexConfig(IndexConfig):
    """Configuration class for Plaid indexers in the Lightning IR framework."""

    indexer_class = PlaidIndexer
    SUPPORTED_MODELS = {ColConfig.model_type}

    def __init__(
        self,
        num_centroids: int,
        k_means_iters: int = 4,
        n_bits: int = 2,
        seed: int = 42,
    ) -> None:
        """Initialize the PlaidIndexConfig.

        Args:
            num_centroids (int): Number of centroids for the Plaid index.
            k_means_iters (int): Number of iterations for k-means clustering. Defaults to 4.
            n_bits (int): Number of bits for the residual codec. Defaults to 2.
            seed (int): Random seed for reproducibility. Defaults to 42.
        """
        super().__init__()
        self.num_centroids = num_centroids
        self.k_means_iters = k_means_iters
        self.n_bits = n_bits
        self.seed = seed
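
Tying the two classes together, a usage sketch follows. The checkpoint name and paths are placeholders, the imports are taken from this module's location, and in practice Lightning IR drives indexing through its callback/CLI machinery rather than by calling the indexer directly:

    from pathlib import Path

    from lightning_ir import BiEncoderModule
    from lightning_ir.retrieve.plaid.plaid_indexer import PlaidIndexConfig, PlaidIndexer

    # Placeholder ColBERT-style checkpoint; PlaidIndexConfig only supports Col* models
    # (see SUPPORTED_MODELS above).
    module = BiEncoderModule(model_name_or_path="path/to/col-checkpoint")

    index_config = PlaidIndexConfig(num_centroids=1024, k_means_iters=4, n_bits=2, seed=42)
    indexer = PlaidIndexer(Path("indexes/my-plaid-index"), index_config, module, verbose=True)

    # Illustrative add/save life cycle over a dataloader of IndexBatch instances:
    # for index_batch in dataloader:
    #     output = module.forward(index_batch)  # BiEncoderOutput with doc_embeddings
    #     indexer.add(index_batch, output)
    # indexer.save()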