Source code for lightning_ir.retrieve.seismic.seismic_searcher

  1from __future__ import annotations
  2
  3from pathlib import Path
  4from typing import TYPE_CHECKING, Literal, Tuple
  5
  6import numpy as np
  7import torch
  8
  9try:
 10    _seismic_available = True
 11    import seismic
 12    from seismic import SeismicIndex
 13
 14    STRING_TYPE = seismic.get_seismic_string()
 15except ImportError:
 16    STRING_TYPE = None
 17    _seismic_available = False
 18    SeismicIndex = None
 19
 20from ...bi_encoder.bi_encoder_model import BiEncoderEmbedding
 21from ...models import SpladeConfig
 22from ..base.packed_tensor import PackedTensor
 23from ..base.searcher import ApproximateSearchConfig, ApproximateSearcher
 24
 25if TYPE_CHECKING:
 26    from ...bi_encoder import BiEncoderModule
 27
 28
class SeismicSearcher(ApproximateSearcher):
    """Approximate searcher over a Seismic sparse inverted index.

    Loads a pre-built index from ``index_dir`` and retrieves candidate
    documents for sparse (SPLADE-style) query embeddings via the optional
    ``seismic`` package.
    """

    def __init__(
        self,
        index_dir: Path | str,
        search_config: "SeismicSearchConfig",
        module: BiEncoderModule,
        use_gpu: bool = False,
    ) -> None:
        super().__init__(index_dir, search_config, module, use_gpu)
        # The seismic package is an optional dependency; fail loudly with
        # install instructions instead of a confusing AttributeError later.
        if not _seismic_available:
            raise ImportError(
                "Please install the seismic package to use the SeismicIndexer. "
                "Instructions can be found at "
                "https://github.com/TusKANNy/seismic?tab=readme-ov-file#using-the-python-interface"
            )
        assert SeismicIndex is not None
        self.index = SeismicIndex.load(str(self.index_dir / ".index.seismic"))
        # Map external document ids back to positional indices in self.doc_ids.
        self.inverse_doc_ids = {doc_id: idx for idx, doc_id in enumerate(self.doc_ids)}

        self.search_config: SeismicSearchConfig

    def _candidate_retrieval(self, query_embeddings: BiEncoderEmbedding) -> Tuple[PackedTensor, PackedTensor]:
        """Run batched Seismic search and return packed (scores, candidate indices).

        Each query's sparse embedding is decomposed into its non-zero token
        strings and weights, which is the representation Seismic expects.
        """
        mask = query_embeddings.scoring_mask
        if mask is None:
            # No mask: use the single pooled embedding per query.
            dense = query_embeddings.embeddings[:, 0]
        else:
            dense = query_embeddings.embeddings[mask]

        components = []
        weights = []
        for vec in dense:
            nz = vec.nonzero().view(-1)
            weights.append(vec[nz].float().numpy(force=True))
            # Seismic keys postings by token string, not vocabulary id.
            components.append(np.array(self.module.tokenizer.convert_ids_to_tokens(nz), dtype=STRING_TYPE))

        raw_results = self.index.batch_search(
            queries_ids=np.array(range(len(components)), dtype=STRING_TYPE),
            query_components=components,
            query_values=weights,
            k=self.search_config.k,
            query_cut=self.search_config.query_cut,
            heap_factor=self.search_config.heap_factor,
            num_threads=self.search_config.num_threads,
        )

        flat_scores = []
        flat_idcs = []
        lengths = []
        for hits in raw_results:
            lengths.append(len(hits))
            for _, hit_score, hit_doc_id in hits:
                flat_scores.append(hit_score)
                flat_idcs.append(self.inverse_doc_ids[hit_doc_id])

        score_tensor = torch.tensor(flat_scores)
        idx_tensor = torch.tensor(flat_idcs, device=query_embeddings.device)

        return PackedTensor(score_tensor, lengths=lengths), PackedTensor(idx_tensor, lengths=lengths)

    def _gather_doc_embeddings(self, idcs: torch.Tensor) -> torch.Tensor:
        # Seismic stores only the inverted index, not raw document vectors,
        # so the "gather" imputation strategy cannot be served.
        raise NotImplementedError("Gathering doc embeddings is not supported for SeismicSearcher")
class SeismicSearchConfig(ApproximateSearchConfig):
    """Configuration for :class:`SeismicSearcher`.

    Extends the approximate-search settings with Seismic-specific knobs
    (``query_cut``, ``heap_factor``, ``num_threads``).
    """

    search_class = SeismicSearcher
    SUPPORTED_MODELS = {SpladeConfig.model_type}

    def __init__(
        self,
        k: int = 10,
        candidate_k: int = 100,
        imputation_strategy: Literal["min", "gather", "zero"] = "min",
        query_cut: int = 10,
        heap_factor: float = 0.7,
        num_threads: int = 1,
    ) -> None:
        """Initialize the Seismic search configuration.

        :param k: Number of final results per query.
        :param candidate_k: Number of candidates to retrieve before re-ranking.
        :param imputation_strategy: How to impute scores for non-retrieved
            documents; ``"gather"`` is rejected because the Seismic index
            cannot return document embeddings.
        :param query_cut: Maximum number of query terms Seismic considers.
        :param heap_factor: Seismic heap pruning factor trading accuracy for speed.
        :param num_threads: Number of threads used by Seismic's batch search.
        :raises ValueError: If ``imputation_strategy`` is ``"gather"``.
        """
        # Reject unsupported strategy up front, before the base class runs.
        if imputation_strategy == "gather":
            raise ValueError("Imputation strategy 'gather' is not supported for SeismicSearcher")
        super().__init__(k=k, candidate_k=candidate_k, imputation_strategy=imputation_strategy)
        # Seismic-specific search parameters (independent assignments).
        self.num_threads = num_threads
        self.heap_factor = heap_factor
        self.query_cut = query_cut