Source code for lightning_ir.retrieve.seismic.seismic_searcher

  1"""Seismic Searcher for Lightning IR Framework"""
  2
  3from __future__ import annotations
  4
  5from pathlib import Path
  6from typing import TYPE_CHECKING, Literal, Tuple
  7
  8import numpy as np
  9import torch
 10
# seismic is an optional dependency; fall back gracefully if it is not installed
try:
    _seismic_available = True
    import seismic
    from seismic import SeismicIndex

    STRING_TYPE = seismic.get_seismic_string()
except ImportError:
    STRING_TYPE = None
    _seismic_available = False
    SeismicIndex = None

from ...bi_encoder.bi_encoder_model import BiEncoderEmbedding
from ...models import SpladeConfig
from ..base.packed_tensor import PackedTensor
from ..base.searcher import ApproximateSearchConfig, ApproximateSearcher

if TYPE_CHECKING:
    from ...bi_encoder import BiEncoderModule


class SeismicSearcher(ApproximateSearcher):
    """Seismic Searcher for efficient retrieval using Seismic indexing."""

    def __init__(
        self,
        index_dir: Path | str,
        search_config: "SeismicSearchConfig",
        module: BiEncoderModule,
        use_gpu: bool = False,
    ) -> None:
        """Initialize the SeismicSearcher.

        Args:
            index_dir (Path | str): Directory where the Seismic index is stored.
            search_config (SeismicSearchConfig): Configuration for the Seismic searcher.
            module (BiEncoderModule): The BiEncoder module used for searching.
            use_gpu (bool): Whether to use GPU for searching. Defaults to False.

        Raises:
            ImportError: If the seismic package is not available.
        """
        super().__init__(index_dir, search_config, module, use_gpu)
        if not _seismic_available:
            raise ImportError(
                "Please install the seismic package to use the SeismicSearcher. "
                "Instructions can be found at "
                "https://github.com/TusKANNy/seismic?tab=readme-ov-file#using-the-python-interface"
            )
        assert SeismicIndex is not None
        self.index = SeismicIndex.load(str(self.index_dir / ".index.seismic"))
        self.inverse_doc_ids = {doc_id: idx for idx, doc_id in enumerate(self.doc_ids)}

        self.search_config: SeismicSearchConfig

    def _candidate_retrieval(self, query_embeddings: BiEncoderEmbedding) -> Tuple[PackedTensor, PackedTensor]:
        if query_embeddings.scoring_mask is None:
            embeddings = query_embeddings.embeddings[:, 0]
        else:
            embeddings = query_embeddings.embeddings[query_embeddings.scoring_mask]

        query_components = []
        query_values = []

        for idx in range(embeddings.shape[0]):
            non_zero = embeddings[idx].nonzero().view(-1)
            values = embeddings[idx][non_zero].float().numpy(force=True)
            tokens = np.array(self.module.tokenizer.convert_ids_to_tokens(non_zero), dtype=STRING_TYPE)
            query_components.append(tokens)
            query_values.append(values)

        results = self.index.batch_search(
            queries_ids=np.array(range(len(query_components)), dtype=STRING_TYPE),
            query_components=query_components,
            query_values=query_values,
            k=self.search_config.k,
            query_cut=self.search_config.query_cut,
            heap_factor=self.search_config.heap_factor,
            num_threads=self.search_config.num_threads,
        )

        scores_list = []
        candidate_idcs_list = []
        num_docs = []
        for result in results:
            for _, score, doc_id in result:
                doc_idx = self.inverse_doc_ids[doc_id]
                scores_list.append(score)
                candidate_idcs_list.append(doc_idx)
            num_docs.append(len(result))

        scores = torch.tensor(scores_list)
        candidate_idcs = torch.tensor(candidate_idcs_list, device=query_embeddings.device)

        return PackedTensor(scores, lengths=num_docs), PackedTensor(candidate_idcs, lengths=num_docs)

    def _gather_doc_embeddings(self, idcs: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError("Gathering doc embeddings is not supported for SeismicSearcher")

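As an aside, the sparse-query conversion performed at the top of _candidate_retrieval can be reproduced in isolation. The following is a minimal sketch, assuming a Hugging Face tokenizer and a toy SPLADE-style query vector; the checkpoint name is a placeholder, and Seismic's STRING_TYPE dtype is omitted because it is only available once the seismic package is installed.

# Standalone illustration of the query conversion (assumptions: Hugging Face
# tokenizer, placeholder checkpoint, plain numpy string dtype instead of STRING_TYPE).
import numpy as np
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder vocabulary

# Toy SPLADE-style query: one weight per vocabulary entry, almost all zeros.
sparse_query = torch.zeros(tokenizer.vocab_size)
sparse_query[tokenizer.convert_tokens_to_ids("search")] = 1.7
sparse_query[tokenizer.convert_tokens_to_ids("index")] = 0.9

# Keep only the active terms: parallel token/value arrays form one Seismic query.
non_zero = sparse_query.nonzero().view(-1)
values = sparse_query[non_zero].float().numpy(force=True)
tokens = np.array(tokenizer.convert_ids_to_tokens(non_zero.tolist()))

print(list(zip(tokens.tolist(), values.tolist())))  # (token, weight) pairs for the active terms
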
class SeismicSearchConfig(ApproximateSearchConfig):
    """Configuration for SeismicSearcher."""

    search_class = SeismicSearcher
    SUPPORTED_MODELS = {SpladeConfig.model_type}

    def __init__(
        self,
        k: int = 10,
        candidate_k: int = 100,
        imputation_strategy: Literal["min", "gather", "zero"] = "min",
        query_cut: int = 10,
        heap_factor: float = 0.7,
        num_threads: int = 1,
    ) -> None:
        """Initialize the SeismicSearchConfig.

        Args:
            k (int): Number of top candidates to retrieve. Defaults to 10.
            candidate_k (int): Number of candidates to consider for each query. Defaults to 100.
            imputation_strategy (Literal["min", "gather", "zero"]): Strategy for handling missing values.
                Defaults to "min".
            query_cut (int): Maximum number of query components (terms) passed to Seismic per query.
                Defaults to 10.
            heap_factor (float): Factor controlling the size of the heap used during search. Defaults to 0.7.
            num_threads (int): Number of threads to use for parallel processing. Defaults to 1.

        Raises:
            ValueError: If imputation_strategy is "gather", which is not supported by the SeismicSearcher.
        """
        if imputation_strategy == "gather":
            raise ValueError("Imputation strategy 'gather' is not supported for SeismicSearcher")
        super().__init__(k=k, candidate_k=candidate_k, imputation_strategy=imputation_strategy)
        self.query_cut = query_cut
        self.heap_factor = heap_factor
        self.num_threads = num_threads
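
Putting both classes together, a searcher can be constructed directly once a Seismic index exists on disk. The following is a minimal sketch rather than the canonical Lightning IR entry point: the checkpoint identifier and paths are placeholders, and it assumes the index directory was previously populated by the matching Seismic indexer so that the ".index.seismic" file loaded in __init__ exists.

# Minimal usage sketch (assumptions: placeholder SPLADE checkpoint and index path,
# index built beforehand by the corresponding Seismic indexer).
from lightning_ir import BiEncoderModule
from lightning_ir.retrieve.seismic.seismic_searcher import SeismicSearchConfig, SeismicSearcher

module = BiEncoderModule(model_name_or_path="path/or/hub-id/of/splade-checkpoint")
search_config = SeismicSearchConfig(k=10, query_cut=10, heap_factor=0.7, num_threads=4)

searcher = SeismicSearcher(
    index_dir="path/to/seismic-index",
    search_config=search_config,
    module=module,
    use_gpu=False,
)

Retrieval then proceeds through the ApproximateSearcher interface, with _candidate_retrieval supplying Seismic's candidate scores and document indices; note that the "gather" imputation strategy is rejected in SeismicSearchConfig because _gather_doc_embeddings is not implemented.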