Source code for lightning_ir.retrieve.seismic.seismic_indexer
1"""SeismicIndexer class for indexing documents using the Seismic library."""
2
3import os
4from pathlib import Path
5
6import numpy as np
7import torch
8
# ``seismic`` is an optional dependency. When it cannot be imported, every name
# this module takes from it is bound to ``None`` so the module itself still
# imports; ``SeismicIndexer.__init__`` checks ``_seismic_available`` and raises.
try:
    import seismic
    from seismic import SeismicDataset, SeismicIndex

    STRING_TYPE = seismic.get_seismic_string()
    _seismic_available = True
except ImportError:
    SeismicDataset = None
    SeismicIndex = None
    STRING_TYPE = None
    _seismic_available = False
19
20
21from ...bi_encoder import BiEncoderModule, BiEncoderOutput
22from ...data import IndexBatch
23from ...models import SpladeConfig
24from ..base import IndexConfig, Indexer
25
26
[docs]
class SeismicIndexer(Indexer):
    """Indexer that collects documents in a ``SeismicDataset`` and builds a ``SeismicIndex`` from it on save."""

    def __init__(
        self,
        index_dir: Path,
        index_config: "SeismicIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Set up the base indexer and create an empty Seismic dataset.

        Args:
            index_dir (Path): Directory the index is saved to.
            index_config (SeismicIndexConfig): Build parameters read when the index is saved.
            module (BiEncoderModule): Bi-encoder module; its tokenizer maps ids to token strings in ``add``.
            verbose (bool): Passed through to the base ``Indexer``. Defaults to False.
        Raises:
            ImportError: If the seismic package could not be imported.
        """
        super().__init__(index_dir, index_config, module, verbose)
        if not _seismic_available:
            raise ImportError(
                "Please install the seismic package to use the SeismicIndexer. "
                "Instructions can be found at "
                "https://github.com/TusKANNy/seismic?tab=readme-ov-file#using-the-python-interface"
            )
        # Annotation only: narrows the attribute type for static checkers.
        self.index_config: SeismicIndexConfig
        assert SeismicDataset is not None
        self.seismic_dataset = SeismicDataset()

    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        """Record a batch of documents and append each one to the Seismic dataset.

        For every document id in the batch, the non-zero entries of its embedding are
        looked up, their indices are converted to token strings with the module's
        tokenizer, and the (tokens, values) pair is added to the dataset.

        Args:
            index_batch (IndexBatch): Batch providing the document ids.
            output (BiEncoderOutput): Model output providing the document embeddings.
        Raises:
            ValueError: If ``output.doc_embeddings`` is None.
        """
        encoded_docs = output.doc_embeddings
        if encoded_docs is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")

        mask = encoded_docs.scoring_mask
        if mask is not None:
            # Length of a document = number of set mask entries in its row;
            # boolean indexing keeps only the masked-in vectors.
            lengths = mask.sum(dim=1)
            vectors = encoded_docs.embeddings[mask]
        else:
            # No mask: take index 0 along the second axis and count one vector per document.
            lengths = torch.ones(
                encoded_docs.embeddings.shape[0], device=encoded_docs.device, dtype=torch.int32
            )
            vectors = encoded_docs.embeddings[:, 0]

        # Bookkeeping on counters/lists that are presumably initialised by the base Indexer.
        batch_doc_ids = index_batch.doc_ids
        self.doc_ids.extend(batch_doc_ids)
        self.doc_lengths.extend(lengths.int().cpu().tolist())
        self.num_embeddings += vectors.shape[0]
        self.num_docs += len(batch_doc_ids)

        # NOTE(review): assumes ``vectors[row]`` is the vector of ``batch_doc_ids[row]``,
        # i.e. one vector per document after the step above - TODO confirm.
        for row, doc_id in enumerate(batch_doc_ids):
            doc_vector = vectors[row]
            active_dims = doc_vector.nonzero().view(-1)
            weights = doc_vector[active_dims].float().numpy(force=True)
            # Fixed-width unicode dtype: token strings longer than 30 characters are truncated by numpy.
            terms = np.array(self.module.tokenizer.convert_ids_to_tokens(active_dims), dtype="U30")
            self.seismic_dataset.add_document(doc_id, terms, weights)

    def save(self) -> None:
        """Run the base indexer's save, then build the Seismic index from the dataset and write it out."""
        super().save()

        assert SeismicIndex is not None
        cfg = self.index_config
        built_index = SeismicIndex.build_from_dataset(
            self.seismic_dataset,
            n_postings=cfg.num_postings,
            centroid_fraction=cfg.centroid_fraction,
            min_cluster_size=cfg.min_cluster_size,
            summary_energy=cfg.summary_energy,
            nknn=cfg.num_k_nearest_neighbors,
            batched_indexing=cfg.batch_size,
            num_threads=cfg.num_threads,
        )
        # The index directory is handed over as a string ending in the path separator.
        target = f"{self.index_dir}{os.path.sep}"
        built_index.save(target)
108
109
[docs]
class SeismicIndexConfig(IndexConfig):
    """Build parameters for :class:`SeismicIndexer`; restricted to Splade models via ``SUPPORTED_MODELS``."""

    # Indexer implementation that consumes this configuration.
    indexer_class = SeismicIndexer
    # Model types this index configuration may be used with.
    SUPPORTED_MODELS = {SpladeConfig.model_type}

    def __init__(
        self,
        num_postings: int = 3_500,
        centroid_fraction: float = 0.1,
        min_cluster_size: int = 2,
        summary_energy: float = 0.4,
        num_k_nearest_neighbors: int = 0,
        batch_size: int | None = None,
        num_threads: int = 0,
    ) -> None:
        """Store the Seismic build options.

        Each value is forwarded unchanged to ``SeismicIndex.build_from_dataset`` when
        ``SeismicIndexer.save`` runs; the keyword it is forwarded under is given in
        parentheses.

        Args:
            num_postings (int): Number of postings to keep in the index (``n_postings``). Defaults to 3500.
            centroid_fraction (float): Fraction of centroids to keep (``centroid_fraction``). Defaults to 0.1.
            min_cluster_size (int): Minimum size of clusters (``min_cluster_size``). Defaults to 2.
            summary_energy (float): Energy threshold for summaries (``summary_energy``). Defaults to 0.4.
            num_k_nearest_neighbors (int): Number of nearest neighbors to consider (``nknn``). Defaults to 0.
            batch_size (int | None): Batch size for indexing (``batched_indexing``). Defaults to None.
            num_threads (int): Number of threads to use for indexing (``num_threads``). Defaults to 0.
        """
        super().__init__()
        # Assignment order is left untouched in case the base config serialises
        # attributes in insertion order - TODO confirm.
        self.num_postings = num_postings
        self.centroid_fraction = centroid_fraction
        self.summary_energy = summary_energy
        self.min_cluster_size = min_cluster_size
        self.num_k_nearest_neighbors = num_k_nearest_neighbors
        self.batch_size = batch_size
        self.num_threads = num_threads