Source code for lightning_ir.retrieve.seismic.seismic_searcher
1"""Seismic Searcher for Lightning IR Framework"""
2
3from __future__ import annotations
4
5from pathlib import Path
6from typing import TYPE_CHECKING, Literal, Tuple
7
8import numpy as np
9import torch
10
11try:
12 _seismic_available = True
13 import seismic
14 from seismic import SeismicIndex
15
16 STRING_TYPE = seismic.get_seismic_string()
17except ImportError:
18 STRING_TYPE = None
19 _seismic_available = False
20 SeismicIndex = None
21
22from ...bi_encoder.bi_encoder_model import BiEncoderEmbedding
23from ...models import SpladeConfig
24from ..base.packed_tensor import PackedTensor
25from ..base.searcher import ApproximateSearchConfig, ApproximateSearcher
26
27if TYPE_CHECKING:
28 from ...bi_encoder import BiEncoderModule
29
30
class SeismicSearcher(ApproximateSearcher):
    """Seismic Searcher for efficient retrieval using Seismic indexing."""

    def __init__(
        self,
        index_dir: Path | str,
        search_config: "SeismicSearchConfig",
        module: BiEncoderModule,
        use_gpu: bool = False,
    ) -> None:
        """Initialize the SeismicSearcher.

        Args:
            index_dir (Path | str): Directory where the Seismic index is stored.
            search_config (SeismicSearchConfig): Configuration for the Seismic searcher.
            module (BiEncoderModule): The BiEncoder module used for searching.
            use_gpu (bool): Whether to use GPU for searching. Defaults to False.

        Raises:
            ImportError: If the seismic package is not available.
        """
        super().__init__(index_dir, search_config, module, use_gpu)
        if not _seismic_available:
            raise ImportError(
                "Please install the seismic package to use the SeismicSearcher. "
                "Instructions can be found at "
                "https://github.com/TusKANNy/seismic?tab=readme-ov-file#using-the-python-interface"
            )
        assert SeismicIndex is not None
        self.index = SeismicIndex.load(str(self.index_dir / ".index.seismic"))
        self.inverse_doc_ids = {doc_id: idx for idx, doc_id in enumerate(self.doc_ids)}

        self.search_config: SeismicSearchConfig

    def _candidate_retrieval(self, query_embeddings: BiEncoderEmbedding) -> Tuple[PackedTensor, PackedTensor]:
        if query_embeddings.scoring_mask is None:
            embeddings = query_embeddings.embeddings[:, 0]
        else:
            embeddings = query_embeddings.embeddings[query_embeddings.scoring_mask]

        # Convert each sparse query embedding into the (token, value) pairs expected by Seismic.
        query_components = []
        query_values = []

        for idx in range(embeddings.shape[0]):
            non_zero = embeddings[idx].nonzero().view(-1)
            values = embeddings[idx][non_zero].float().numpy(force=True)
            tokens = np.array(self.module.tokenizer.convert_ids_to_tokens(non_zero), dtype=STRING_TYPE)
            query_components.append(tokens)
            query_values.append(values)

        results = self.index.batch_search(
            queries_ids=np.array(range(len(query_components)), dtype=STRING_TYPE),
            query_components=query_components,
            query_values=query_values,
            k=self.search_config.k,
            query_cut=self.search_config.query_cut,
            heap_factor=self.search_config.heap_factor,
            num_threads=self.search_config.num_threads,
        )

        # Flatten the per-query result lists into packed tensors of scores and document indices.
        scores_list = []
        candidate_idcs_list = []
        num_docs = []
        for result in results:
            for _, score, doc_id in result:
                doc_idx = self.inverse_doc_ids[doc_id]
                scores_list.append(score)
                candidate_idcs_list.append(doc_idx)
            num_docs.append(len(result))

        scores = torch.tensor(scores_list)
        candidate_idcs = torch.tensor(candidate_idcs_list, device=query_embeddings.device)

        return PackedTensor(scores, lengths=num_docs), PackedTensor(candidate_idcs, lengths=num_docs)

    def _gather_doc_embeddings(self, idcs: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError("Gathering doc embeddings is not supported for SeismicSearcher")


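# Usage sketch (illustrative only, not part of the original module): constructing
# the searcher directly. It assumes a SPLADE ``BiEncoderModule`` can be loaded from
# a model name (the model name below is only a placeholder) and that ``index_dir``
# was built by the corresponding Seismic indexer. ``SeismicSearchConfig`` is defined
# below; the keyword arguments mirror its defaults.
#
#     from lightning_ir import BiEncoderModule
#
#     module = BiEncoderModule(model_name_or_path="naver/splade-v3")
#     search_config = SeismicSearchConfig(k=10, query_cut=10, heap_factor=0.7, num_threads=4)
#     searcher = SeismicSearcher("path/to/index_dir", search_config, module, use_gpu=False)

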
class SeismicSearchConfig(ApproximateSearchConfig):
    """Configuration for SeismicSearcher."""

    search_class = SeismicSearcher
    SUPPORTED_MODELS = {SpladeConfig.model_type}

    def __init__(
        self,
        k: int = 10,
        candidate_k: int = 100,
        imputation_strategy: Literal["min", "gather", "zero"] = "min",
        query_cut: int = 10,
        heap_factor: float = 0.7,
        num_threads: int = 1,
    ) -> None:
        """Initialize the SeismicSearchConfig.

        Args:
            k (int): Number of top candidates to retrieve. Defaults to 10.
            candidate_k (int): Number of candidates to consider for each query. Defaults to 100.
            imputation_strategy (Literal["min", "gather", "zero"]): Strategy for handling missing values.
                Defaults to "min".
            query_cut (int): Maximum number of components per query. Defaults to 10.
            heap_factor (float): Factor to control the size of the heap used in the search. Defaults to 0.7.
            num_threads (int): Number of threads to use for parallel processing. Defaults to 1.

        Raises:
            ValueError: If imputation_strategy is "gather", as it is not supported for SeismicSearcher.
        """
        if imputation_strategy == "gather":
            raise ValueError("Imputation strategy 'gather' is not supported for SeismicSearcher")
        super().__init__(k=k, candidate_k=candidate_k, imputation_strategy=imputation_strategy)
        self.query_cut = query_cut
        self.heap_factor = heap_factor
        self.num_threads = num_threads
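
The following is a minimal usage sketch, not part of the module above: it constructs the configuration with non-default values (the numbers are arbitrary placeholders) and shows that the unsupported "gather" imputation strategy is rejected, as implemented in the constructor.

from lightning_ir.retrieve.seismic.seismic_searcher import SeismicSearchConfig, SeismicSearcher

# Override the defaults; unspecified options keep the values documented above.
config = SeismicSearchConfig(k=100, candidate_k=200, query_cut=20, heap_factor=0.8, num_threads=8)
assert config.search_class is SeismicSearcher

# The "gather" strategy is explicitly disallowed for this searcher.
try:
    SeismicSearchConfig(imputation_strategy="gather")
except ValueError as err:
    print(err)  # Imputation strategy 'gather' is not supported for SeismicSearcher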