Source code for lightning_ir.retrieve.seismic.seismic_searcher
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Literal, Tuple

import numpy as np
import torch

try:
    _seismic_available = True
    import seismic
    from seismic import SeismicIndex

    STRING_TYPE = seismic.get_seismic_string()
except ImportError:
    STRING_TYPE = None
    _seismic_available = False
    SeismicIndex = None
from ...bi_encoder.bi_encoder_model import BiEncoderEmbedding
from ...models import SpladeConfig
from ..base.packed_tensor import PackedTensor
from ..base.searcher import ApproximateSearchConfig, ApproximateSearcher

if TYPE_CHECKING:
    from ...bi_encoder import BiEncoderModule

class SeismicSearcher(ApproximateSearcher):
    def __init__(
        self,
        index_dir: Path | str,
        search_config: "SeismicSearchConfig",
        module: BiEncoderModule,
        use_gpu: bool = False,
    ) -> None:
        super().__init__(index_dir, search_config, module, use_gpu)
        if not _seismic_available:
            raise ImportError(
                "Please install the seismic package to use the SeismicSearcher. "
                "Instructions can be found at "
                "https://github.com/TusKANNy/seismic?tab=readme-ov-file#using-the-python-interface"
            )
        assert SeismicIndex is not None
        self.index = SeismicIndex.load(str(self.index_dir / ".index.seismic"))
        self.inverse_doc_ids = {doc_id: idx for idx, doc_id in enumerate(self.doc_ids)}

        self.search_config: SeismicSearchConfig
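    # Editor's note: `self.doc_ids` comes from the base searcher and lists the
    # external document ids in index order; `inverse_doc_ids` inverts it
    # (e.g. ["d1", "d7"] -> {"d1": 0, "d7": 1}) so the string ids Seismic
    # returns can be mapped back to integer row indices in _candidate_retrieval.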
    def _candidate_retrieval(self, query_embeddings: BiEncoderEmbedding) -> Tuple[PackedTensor, PackedTensor]:
        if query_embeddings.scoring_mask is None:
            embeddings = query_embeddings.embeddings[:, 0]
        else:
            embeddings = query_embeddings.embeddings[query_embeddings.scoring_mask]

        # Convert each sparse query vector into Seismic's representation:
        # parallel arrays of token strings and their non-zero weights.
        query_components = []
        query_values = []

        for idx in range(embeddings.shape[0]):
            non_zero = embeddings[idx].nonzero().view(-1)
            values = embeddings[idx][non_zero].float().numpy(force=True)
            tokens = np.array(self.module.tokenizer.convert_ids_to_tokens(non_zero), dtype=STRING_TYPE)
            query_components.append(tokens)
            query_values.append(values)

        results = self.index.batch_search(
            queries_ids=np.array(range(len(query_components)), dtype=STRING_TYPE),
            query_components=query_components,
            query_values=query_values,
            k=self.search_config.k,
            query_cut=self.search_config.query_cut,
            heap_factor=self.search_config.heap_factor,
            num_threads=self.search_config.num_threads,
        )

        # Flatten the per-query result lists into packed tensors, mapping
        # Seismic's string doc ids back to integer indices.
        scores_list = []
        candidate_idcs_list = []
        num_docs = []
        for result in results:
            for _, score, doc_id in result:
                doc_idx = self.inverse_doc_ids[doc_id]
                scores_list.append(score)
                candidate_idcs_list.append(doc_idx)
            num_docs.append(len(result))

        scores = torch.tensor(scores_list)
        candidate_idcs = torch.tensor(candidate_idcs_list, device=query_embeddings.device)

        return PackedTensor(scores, lengths=num_docs), PackedTensor(candidate_idcs, lengths=num_docs)
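    # Editor's sketch of the packed return format (illustration, not library
    # code): for a batch of two queries with 3 and 2 hits respectively,
    #   scores  -> PackedTensor(tensor([s00, s01, s02, s10, s11]), lengths=[3, 2])
    #   indices -> PackedTensor(tensor([d00, d01, d02, d10, d11]), lengths=[3, 2])
    # letting callers recover per-query slices without padding.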
    def _gather_doc_embeddings(self, idcs: torch.Tensor) -> torch.Tensor:
        # Seismic only exposes its own compressed index, so the original
        # document embeddings are not available to gather for score imputation.
        raise NotImplementedError("Gathering doc embeddings is not supported for SeismicSearcher")


class SeismicSearchConfig(ApproximateSearchConfig):

    search_class = SeismicSearcher
    SUPPORTED_MODELS = {SpladeConfig.model_type}
    def __init__(
        self,
        k: int = 10,
        candidate_k: int = 100,
        imputation_strategy: Literal["min", "gather", "zero"] = "min",
        query_cut: int = 10,
        heap_factor: float = 0.7,
        num_threads: int = 1,
    ) -> None:
        # The 'gather' strategy requires raw document embeddings, which
        # SeismicSearcher cannot provide (see _gather_doc_embeddings above).
        if imputation_strategy == "gather":
            raise ValueError("Imputation strategy 'gather' is not supported for SeismicSearcher")
        super().__init__(k=k, candidate_k=candidate_k, imputation_strategy=imputation_strategy)
        self.query_cut = query_cut
        self.heap_factor = heap_factor
        self.num_threads = num_threads
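
# ---------------------------------------------------------------------------
# Usage sketch (editor's addition, not part of the library source). It assumes
# a SPLADE-style BiEncoderModule and an index previously built by the matching
# Seismic indexer; the model name and index path are placeholders.
#
#     from lightning_ir import BiEncoderModule
#
#     module = BiEncoderModule(model_name_or_path="naver/splade-v3")
#     config = SeismicSearchConfig(
#         k=10,             # hits returned per query
#         candidate_k=100,  # candidate pool handed to the base searcher
#         query_cut=10,     # Seismic: number of top query components kept
#         heap_factor=0.7,  # Seismic: pruning threshold for block evaluation
#         num_threads=4,
#     )
#     searcher = SeismicSearcher("path/to/index", config, module)
# ---------------------------------------------------------------------------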