Source code for lightning_ir.retrieve.seismic.seismic_indexer
1import os
2from pathlib import Path
3
4import numpy as np
5import torch
6
# `seismic` is an optional dependency. Import it if present and record the outcome in
# module-level names so the rest of the module can be imported either way.
try:
    import seismic
    from seismic import SeismicDataset, SeismicIndex

    # String dtype identifier reported by the seismic package itself.
    STRING_TYPE = seismic.get_seismic_string()
    _seismic_available = True
except ImportError:
    # Placeholders keep the names defined when the package is missing.
    SeismicDataset = None
    SeismicIndex = None
    STRING_TYPE = None
    _seismic_available = False
17
18
19from ...bi_encoder import BiEncoderModule, BiEncoderOutput
20from ...data import IndexBatch
21from ...models import SpladeConfig
22from ..base import IndexConfig, Indexer
23
24
[docs]
class SeismicIndexer(Indexer):
    """Indexer that accumulates sparse document embeddings and builds a Seismic index.

    Documents are collected in a ``SeismicDataset`` by :meth:`add`. The index is built from
    that dataset and written out by :meth:`save`.
    """

    def __init__(
        self,
        index_dir: Path,
        index_config: "SeismicIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the indexer and create an empty Seismic dataset.

        Args:
            index_dir: Directory the index is saved to.
            index_config: Configuration holding the Seismic build parameters.
            module: Bi-encoder module; its tokenizer is used to map token ids to token strings.
            verbose: Forwarded to the base ``Indexer``.

        Raises:
            ImportError: If the optional ``seismic`` package is not installed.
        """
        super().__init__(index_dir, index_config, module, verbose)
        if not _seismic_available:
            raise ImportError(
                "Please install the seismic package to use the SeismicIndexer. "
                "Instructions can be found at "
                "https://github.com/TusKANNy/seismic?tab=readme-ov-file#using-the-python-interface"
            )
        self.index_config: SeismicIndexConfig
        # Type-narrowing only: the availability check above already guarantees this.
        assert SeismicDataset is not None
        self.seismic_dataset = SeismicDataset()

    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        """Add the documents of one batch to the Seismic dataset.

        For every document, the non-zero entries of its embedding are converted to
        ``(token, value)`` pairs and passed to ``SeismicDataset.add_document``.

        Args:
            index_batch: Batch providing the document ids.
            output: Model output providing ``doc_embeddings``.

        Raises:
            ValueError: If ``output.doc_embeddings`` is ``None``.
        """
        doc_embeddings = output.doc_embeddings
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")

        if doc_embeddings.scoring_mask is None:
            # No mask: count one embedding per document and take the first vector along dim 1.
            doc_lengths = torch.ones(
                doc_embeddings.embeddings.shape[0], device=doc_embeddings.device, dtype=torch.int32
            )
            embeddings = doc_embeddings.embeddings[:, 0]
        else:
            doc_lengths = doc_embeddings.scoring_mask.sum(dim=1)
            embeddings = doc_embeddings.embeddings[doc_embeddings.scoring_mask]

        num_docs = len(index_batch.doc_ids)
        self.doc_ids.extend(index_batch.doc_ids)
        self.doc_lengths.extend(doc_lengths.int().cpu().tolist())
        self.num_embeddings += embeddings.shape[0]
        self.num_docs += num_docs

        convert_ids_to_tokens = self.module.tokenizer.convert_ids_to_tokens
        # NOTE(review): rows of `embeddings` are addressed by document position, i.e. this
        # assumes one embedding row per document -- confirm for masked multi-vector outputs.
        for idx, doc_id in enumerate(index_batch.doc_ids):
            embedding = embeddings[idx]
            non_zero = embedding.nonzero().view(-1)
            values = embedding[non_zero].float().numpy(force=True)
            # Hand the tokenizer a plain list of ints: a single bulk transfer instead of
            # iterating a (possibly CUDA) tensor element by element.
            token_ids = non_zero.tolist()
            # Use the string dtype reported by seismic rather than a hard-coded width.
            tokens = np.array(convert_ids_to_tokens(token_ids), dtype=STRING_TYPE)
            self.seismic_dataset.add_document(doc_id, tokens, values)

    def save(self) -> None:
        """Save the base indexer state, then build and save the Seismic index.

        The index is built from the accumulated dataset using the parameters stored in
        ``self.index_config``.
        """
        super().save()

        # Type-narrowing only: __init__ refuses to run without seismic installed.
        assert SeismicIndex is not None
        index = SeismicIndex.build_from_dataset(
            self.seismic_dataset,
            n_postings=self.index_config.num_postings,
            centroid_fraction=self.index_config.centroid_fraction,
            min_cluster_size=self.index_config.min_cluster_size,
            summary_energy=self.index_config.summary_energy,
            nknn=self.index_config.num_k_nearest_neighbors,
            batched_indexing=self.index_config.batch_size,
            num_threads=self.index_config.num_threads,
        )
        # The trailing separator is appended to the directory path before it is handed to
        # seismic (presumably so the index files are placed inside index_dir -- confirm).
        index.save(str(self.index_dir) + os.path.sep)
85
86
[docs]
class SeismicIndexConfig(IndexConfig):
    """Configuration for :class:`SeismicIndexer`.

    The stored values are forwarded to ``SeismicIndex.build_from_dataset`` when the index
    is saved.
    """

    indexer_class = SeismicIndexer
    SUPPORTED_MODELS = {SpladeConfig.model_type}

    def __init__(
        self,
        num_postings: int = 3_500,
        centroid_fraction: float = 0.1,
        min_cluster_size: int = 2,
        summary_energy: float = 0.4,
        num_k_nearest_neighbors: int = 0,
        batch_size: int | None = None,
        num_threads: int = 0,
    ) -> None:
        """Store the Seismic build parameters.

        Args:
            num_postings: Passed to the index build as ``n_postings``.
            centroid_fraction: Passed to the index build as ``centroid_fraction``.
            min_cluster_size: Passed to the index build as ``min_cluster_size``.
            summary_energy: Passed to the index build as ``summary_energy``.
            num_k_nearest_neighbors: Passed to the index build as ``nknn``.
            batch_size: Passed to the index build as ``batched_indexing``.
            num_threads: Passed to the index build as ``num_threads``.
        """
        super().__init__()
        # Posting-list and clustering parameters.
        self.num_postings = num_postings
        self.centroid_fraction = centroid_fraction
        self.min_cluster_size = min_cluster_size
        self.summary_energy = summary_energy
        # Neighbour-graph and execution parameters.
        self.num_k_nearest_neighbors = num_k_nearest_neighbors
        self.batch_size = batch_size
        self.num_threads = num_threads