Source code for lightning_ir.retrieve.faiss.faiss_searcher

 1from __future__ import annotations
 2
 3from pathlib import Path
 4from typing import TYPE_CHECKING, Literal, Tuple
 5
 6import torch
 7
 8from ...bi_encoder.bi_encoder_model import BiEncoderEmbedding
 9from ...models import ColConfig, DprConfig
10from ..base.packed_tensor import PackedTensor
11from ..base.searcher import ApproximateSearchConfig, ApproximateSearcher
12
13if TYPE_CHECKING:
14    from ...bi_encoder import BiEncoderModule
15
16
class FaissSearcher(ApproximateSearcher):
    """Approximate searcher backed by a faiss index stored at ``<index_dir>/index.faiss``."""

    def __init__(
        self,
        index_dir: Path | str,
        search_config: FaissSearchConfig,
        module: BiEncoderModule,
        use_gpu: bool = False,
    ) -> None:
        """Load the faiss index, apply the search-time parameters, then initialise the base searcher.

        Args:
            index_dir: Directory that contains the ``index.faiss`` file.
            search_config: Configuration providing ``n_probe`` and ``ef_search``.
            module: Bi-encoder module; forwarded unchanged to the parent constructor.
            use_gpu: When true *and* the installed faiss exposes
                ``index_cpu_to_all_gpus``, the index is passed through that
                function; otherwise the index is used as loaded.
        """
        # faiss is imported lazily, inside the methods that need it.
        import faiss

        self.search_config: FaissSearchConfig

        index_file = Path(index_dir) / "index.faiss"
        self.index = faiss.read_index(str(index_file))

        gpu_capable = hasattr(faiss, "index_cpu_to_all_gpus")
        if use_gpu and gpu_capable:
            self.index = faiss.index_cpu_to_all_gpus(self.index)

        self._apply_search_parameters(search_config)
        super().__init__(index_dir, search_config, module, use_gpu)

    def _apply_search_parameters(self, search_config: FaissSearchConfig) -> None:
        """Set ``nprobe`` (and ``efSearch`` on an HNSW quantizer) where the loaded index supports them."""
        import faiss

        try:
            ivf_index = faiss.extract_index_ivf(self.index)
        except RuntimeError:
            # Best effort: a RuntimeError here is treated as "no IVF index to tune".
            return
        if ivf_index is None:
            return

        ivf_index.nprobe = search_config.n_probe

        quantizer = getattr(ivf_index, "quantizer", None)
        if quantizer is None:
            return

        # Downcast so that quantizer-specific attributes such as ``hnsw`` become visible.
        concrete_quantizer = faiss.downcast_index(quantizer)
        hnsw = getattr(concrete_quantizer, "hnsw", None)
        if hnsw is not None:
            hnsw.efSearch = search_config.ef_search

    def _candidate_retrieval(self, query_embeddings: BiEncoderEmbedding) -> Tuple[PackedTensor, PackedTensor]:
        """Search the index with the query vectors and return packed scores and packed indices.

        Every query vector contributes exactly ``search_config.candidate_k``
        entries to each of the two returned ``PackedTensor`` objects.
        """
        mask = query_embeddings.scoring_mask
        # Without a scoring mask, position 0 along the second dimension is used
        # (presumably a single pooled vector per query -- confirm against caller);
        # with a mask, only the selected vectors are searched.
        query_vectors = (
            query_embeddings.embeddings[:, 0] if mask is None else query_embeddings.embeddings[mask]
        )

        top_k = self.search_config.candidate_k
        raw_scores, raw_idcs = self.index.search(query_vectors.float().cpu(), top_k)

        # Flatten the per-vector results into 1-D tensors; ``lengths`` records the segment sizes.
        flat_scores = torch.from_numpy(raw_scores).view(-1)
        flat_idcs = torch.from_numpy(raw_idcs).view(-1)
        lengths = [top_k] * query_vectors.shape[0]

        return PackedTensor(flat_scores, lengths=lengths), PackedTensor(flat_idcs, lengths=lengths)

    def _gather_doc_embeddings(self, idcs: torch.Tensor) -> torch.Tensor:
        """Reconstruct the stored vectors for ``idcs`` from the index and return them as a tensor."""
        reconstructed = self.index.reconstruct_batch(idcs)
        return torch.from_numpy(reconstructed)
61 62
class FaissSearchConfig(ApproximateSearchConfig):
    """Search configuration for :class:`FaissSearcher`."""

    # Searcher implementation associated with this configuration.
    search_class = FaissSearcher
    # Model types this configuration declares support for.
    SUPPORTED_MODELS = {ColConfig.model_type, DprConfig.model_type}

    def __init__(
        self,
        k: int = 10,
        candidate_k: int = 100,
        imputation_strategy: Literal["min", "gather", "zero"] = "gather",
        n_probe: int = 1,
        ef_search: int = 16,
    ) -> None:
        """Store the faiss-specific settings on top of the base approximate-search settings.

        Args:
            k: Forwarded to the parent configuration.
            candidate_k: Forwarded to the parent configuration; ``FaissSearcher``
                also uses it as the number of neighbours requested per query vector.
            imputation_strategy: Forwarded to the parent configuration.
            n_probe: Value ``FaissSearcher`` assigns to ``nprobe`` of the IVF index, if any.
            ef_search: Value ``FaissSearcher`` assigns to ``efSearch`` of an HNSW
                quantizer, if any.
        """
        super().__init__(k, candidate_k, imputation_strategy)
        self.ef_search = ef_search
        self.n_probe = n_probe