Source code for lightning_ir.retrieve.seismic.seismic_indexer

  1"""SeismicIndexer class for indexing documents using the Seismic library."""
  2
  3import os
  4from pathlib import Path
  5
  6import numpy as np
  7import torch
  8
  9try:
 10    _seismic_available = True
 11    import seismic
 12    from seismic import SeismicDataset, SeismicIndex
 13
 14    STRING_TYPE = seismic.get_seismic_string()
 15except ImportError:
 16    STRING_TYPE = None
 17    _seismic_available = False
 18    SeismicIndex = SeismicDataset = None
 19
 20
 21from ...bi_encoder import BiEncoderModule, BiEncoderOutput
 22from ...data import IndexBatch
 23from ...models import SpladeConfig
 24from ..base import IndexConfig, Indexer
 25
 26
class SeismicIndexer(Indexer):
    """Indexer for Seismic, a residual-based indexing method for efficient retrieval.

    Documents are accumulated in a ``SeismicDataset`` by :meth:`add`; the Seismic index itself is
    only built from that dataset and written to ``index_dir`` in :meth:`save`.
    """

    def __init__(
        self,
        index_dir: Path,
        index_config: "SeismicIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the SeismicIndexer.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (SeismicIndexConfig): Configuration for the Seismic indexer.
            module (BiEncoderModule): The BiEncoder module used for indexing.
            verbose (bool): Whether to print verbose output during indexing. Defaults to False.
        Raises:
            ImportError: If the seismic package is not available.
        """
        # Fail fast: check the optional dependency before the base class performs any setup.
        if not _seismic_available:
            raise ImportError(
                "Please install the seismic package to use the SeismicIndexer. "
                "Instructions can be found at "
                "https://github.com/TusKANNy/seismic?tab=readme-ov-file#using-the-python-interface"
            )
        super().__init__(index_dir, index_config, module, verbose)
        self.index_config: SeismicIndexConfig
        # Narrows the optional import (None when seismic is missing) for type checkers.
        assert SeismicDataset is not None
        self.seismic_dataset = SeismicDataset()

    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        """Add the documents of a batch to the pending Seismic dataset.

        Each document must be represented by exactly one embedding vector. Its non-zero dimensions
        are mapped to tokens with the module's tokenizer and stored together with their weights.

        Args:
            index_batch (IndexBatch): Batch of data containing the ids of the documents to index.
            output (BiEncoderOutput): Output from the BiEncoder module containing embeddings.
        Raises:
            ValueError: If the output does not contain document embeddings, or if the number of
                selected embeddings does not equal the number of documents in the batch.
        """
        doc_embeddings = output.doc_embeddings
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")

        if doc_embeddings.scoring_mask is None:
            # Without a mask, every document contributes its first embedding vector only.
            doc_lengths = torch.ones(
                doc_embeddings.embeddings.shape[0], device=doc_embeddings.device, dtype=torch.int32
            )
            embeddings = doc_embeddings.embeddings[:, 0]
        else:
            doc_lengths = doc_embeddings.scoring_mask.sum(dim=1)
            embeddings = doc_embeddings.embeddings[doc_embeddings.scoring_mask]

        num_docs = len(index_batch.doc_ids)
        # Rows are paired with doc ids positionally below; validate before mutating any state so a
        # bad batch cannot leave doc_ids / counters out of sync with the Seismic dataset.
        if embeddings.shape[0] != num_docs:
            raise ValueError(
                "SeismicIndexer expects exactly one embedding per document, got "
                f"{embeddings.shape[0]} embeddings for {num_docs} documents"
            )

        self.doc_ids.extend(index_batch.doc_ids)
        self.doc_lengths.extend(doc_lengths.int().cpu().tolist())
        self.num_embeddings += embeddings.shape[0]
        self.num_docs += num_docs

        tokenizer = self.module.tokenizer
        for doc_id, doc_embedding in zip(index_batch.doc_ids, embeddings):
            token_ids = doc_embedding.nonzero().view(-1)
            values = doc_embedding[token_ids].float().numpy(force=True)
            # Hand the tokenizer plain Python ints (its API takes int / list[int]) instead of a
            # tensor it would have to convert element by element.
            token_strings = tokenizer.convert_ids_to_tokens(token_ids.tolist())
            # STRING_TYPE comes from seismic.get_seismic_string(), keeping the token dtype in sync
            # with what the installed seismic version expects (previously hard-coded as "U30").
            tokens = np.array(token_strings, dtype=STRING_TYPE)
            self.seismic_dataset.add_document(doc_id, tokens, values)

    def save(self) -> None:
        """Save the base index metadata, then build the Seismic index and write it to disk.

        The index is built from every document previously passed to :meth:`add`, using the
        hyperparameters stored in ``self.index_config``.
        """
        super().save()

        # Narrows the optional import for type checkers; __init__ already guaranteed availability.
        assert SeismicIndex is not None
        index = SeismicIndex.build_from_dataset(
            self.seismic_dataset,
            n_postings=self.index_config.num_postings,
            centroid_fraction=self.index_config.centroid_fraction,
            min_cluster_size=self.index_config.min_cluster_size,
            summary_energy=self.index_config.summary_energy,
            nknn=self.index_config.num_k_nearest_neighbors,
            batched_indexing=self.index_config.batch_size,
            num_threads=self.index_config.num_threads,
        )
        # NOTE(review): seismic receives the directory with a trailing separator; presumably it
        # appends its own file name to this prefix — the loading side must use the same path.
        index.save(str(self.index_dir) + os.path.sep)
108 109
class SeismicIndexConfig(IndexConfig):
    """Configuration for the Seismic indexer.

    Bundles the hyperparameters that ``SeismicIndexer.save`` forwards to
    ``SeismicIndex.build_from_dataset``. Only SPLADE models are supported.
    """

    # Indexer implementation constructed for this configuration.
    indexer_class = SeismicIndexer
    # Model types this index can be built for.
    SUPPORTED_MODELS = {SpladeConfig.model_type}

    def __init__(
        self,
        num_postings: int = 3_500,
        centroid_fraction: float = 0.1,
        min_cluster_size: int = 2,
        summary_energy: float = 0.4,
        num_k_nearest_neighbors: int = 0,
        batch_size: int | None = None,
        num_threads: int = 0,
    ) -> None:
        """Store the Seismic build hyperparameters.

        Each argument is kept verbatim on the instance; the name in parentheses is the keyword
        it is passed as when the Seismic index is built.

        Args:
            num_postings (int): Number of postings to keep in the index (``n_postings``).
                Defaults to 3500.
            centroid_fraction (float): Fraction of centroids to keep (``centroid_fraction``).
                Defaults to 0.1.
            min_cluster_size (int): Minimum size of clusters (``min_cluster_size``). Defaults to 2.
            summary_energy (float): Energy threshold for summaries (``summary_energy``).
                Defaults to 0.4.
            num_k_nearest_neighbors (int): Number of nearest neighbors to consider (``nknn``).
                Defaults to 0.
            batch_size (int | None): Batch size for indexing (``batched_indexing``).
                Defaults to None.
            num_threads (int): Number of threads to use for indexing (``num_threads``).
                Defaults to 0.
        """
        super().__init__()
        # Assignment order mirrors the original so any order-sensitive serialization is unchanged.
        self.num_postings = num_postings
        self.centroid_fraction = centroid_fraction
        self.summary_energy = summary_energy
        self.min_cluster_size = min_cluster_size
        self.num_k_nearest_neighbors = num_k_nearest_neighbors
        self.batch_size = batch_size
        self.num_threads = num_threads