Source code for lightning_ir.retrieve.faiss.faiss_searcher

  1"""FAISS-based searcher for approximate nearest neighbor retrieval in the Lightning IR framework."""
  2
  3from __future__ import annotations
  4
  5from pathlib import Path
  6from typing import TYPE_CHECKING, Literal, Tuple
  7
  8import torch
  9
 10from ...bi_encoder.bi_encoder_model import BiEncoderEmbedding
 11from ...models import ColConfig, DprConfig
 12from ..base.packed_tensor import PackedTensor
 13from ..base.searcher import ApproximateSearchConfig, ApproximateSearcher
 14
 15if TYPE_CHECKING:
 16    from ...bi_encoder import BiEncoderModule
 17
 18
class FaissSearcher(ApproximateSearcher):
    """FAISS-based searcher for approximate nearest neighbor retrieval in the Lightning IR framework."""

    def __init__(
        self,
        index_dir: Path | str,
        search_config: FaissSearchConfig,
        module: BiEncoderModule,
        use_gpu: bool = False,
    ) -> None:
        """Initialize the FaissSearcher.

        Loads ``index.faiss`` from ``index_dir``, applies the search-time parameters from ``search_config``
        (``n_probe`` for an IVF index, ``ef_search`` for an HNSW quantizer of that IVF index) and, if requested,
        moves the index to all available GPUs.

        Args:
            index_dir (Path | str): Directory containing the FAISS index files.
            search_config (FaissSearchConfig): Configuration for the FAISS searcher.
            module (BiEncoderModule): The bi-encoder module used for embeddings.
            use_gpu (bool): Whether to use GPU for FAISS operations. Defaults to False.
        """
        import faiss

        self.search_config: FaissSearchConfig
        self.index = faiss.read_index(str(Path(index_dir) / "index.faiss"))
        # Configure the search parameters on the CPU index *before* any GPU transfer: GPU indexes (and the
        # shard/replica wrapper used for several GPUs) are not faiss.IndexIVF instances, so extract_index_ivf
        # fails on them and the configured values would be silently ignored. NOTE(review): this relies on the
        # CPU->GPU clone carrying nprobe over — confirm for the FAISS version in use.
        self._configure_search_parameters(self.index, search_config)
        if use_gpu and hasattr(faiss, "index_cpu_to_all_gpus"):
            self.index = faiss.index_cpu_to_all_gpus(self.index)
        super().__init__(index_dir, search_config, module, use_gpu)

    @staticmethod
    def _configure_search_parameters(index, search_config: FaissSearchConfig) -> None:
        """Apply ``n_probe`` (and ``ef_search`` for an HNSW quantizer) to the IVF part of ``index``, if any.

        Args:
            index: A (CPU) FAISS index, possibly wrapping an IVF index.
            search_config (FaissSearchConfig): Provides ``n_probe`` and ``ef_search``.
        """
        import faiss

        try:
            ivf_index = faiss.extract_index_ivf(index)
        except RuntimeError:
            # FAISS raises when no IVF sub-index can be extracted; there is nothing to configure then.
            return
        if ivf_index is None:
            return
        ivf_index.nprobe = search_config.n_probe
        quantizer = getattr(ivf_index, "quantizer", None)
        if quantizer is None:
            return
        # The quantizer is returned as a generic faiss.Index; downcast to reach an HNSW graph if present.
        hnsw = getattr(faiss.downcast_index(quantizer), "hnsw", None)
        if hnsw is not None:
            hnsw.efSearch = search_config.ef_search

    def _candidate_retrieval(self, query_embeddings: BiEncoderEmbedding) -> Tuple[PackedTensor, PackedTensor]:
        """Retrieve up to ``candidate_k`` nearest index vectors for every searched query vector.

        Without a scoring mask only the first vector of each query (``embeddings[:, 0]``) is searched; with a
        scoring mask every masked-in query vector is searched.

        Args:
            query_embeddings (BiEncoderEmbedding): The query embeddings to search with.
        Returns:
            Tuple[PackedTensor, PackedTensor]: Candidate scores and candidate index ids, packed with one segment per
            searched query vector.
        """
        if query_embeddings.scoring_mask is None:
            embeddings = query_embeddings.embeddings[:, 0]
        else:
            embeddings = query_embeddings.embeddings[query_embeddings.scoring_mask]
        candidate_k = self.search_config.candidate_k
        # Hand FAISS an explicit float32 CPU numpy array; detach so gradient-tracking tensors convert cleanly.
        search_input = embeddings.detach().float().cpu().numpy()
        raw_scores, raw_idcs = self.index.search(search_input, candidate_k)
        candidate_scores = torch.from_numpy(raw_scores)
        candidate_idcs = torch.from_numpy(raw_idcs)
        # FAISS pads a result row with id -1 when fewer than candidate_k hits exist (e.g. an IVF index probing only
        # a few lists). Drop that padding so invalid ids never reach downstream lookups; the packed lengths then
        # record the true number of candidates per query vector.
        valid = candidate_idcs >= 0
        if bool(valid.all()):
            num_candidates_per_query_vector = [candidate_k] * embeddings.shape[0]
            candidate_scores = candidate_scores.view(-1)
            candidate_idcs = candidate_idcs.view(-1)
        else:
            num_candidates_per_query_vector = valid.sum(dim=1).tolist()
            # Boolean-mask indexing flattens in row-major order, matching the per-row lengths above.
            candidate_scores = candidate_scores[valid]
            candidate_idcs = candidate_idcs[valid]
        packed_candidate_scores = PackedTensor(candidate_scores, lengths=num_candidates_per_query_vector)
        packed_candidate_idcs = PackedTensor(candidate_idcs, lengths=num_candidates_per_query_vector)
        return packed_candidate_scores, packed_candidate_idcs

    def _gather_doc_embeddings(self, idcs: torch.Tensor) -> torch.Tensor:
        """Reconstruct the stored index vectors for the given ids.

        Args:
            idcs (torch.Tensor): 1-D tensor of index vector ids.
        Returns:
            torch.Tensor: The reconstructed vectors (one row per id) as a CPU tensor.
        """
        # FAISS reads ids from an int64 numpy array on the CPU; convert explicitly so CUDA tensors work too.
        ids = idcs.detach().cpu().numpy().astype("int64", copy=False)
        return torch.from_numpy(self.index.reconstruct_batch(ids))
73 74
class FaissSearchConfig(ApproximateSearchConfig):
    """Search configuration for the FAISS searcher in the Lightning IR framework.

    Adds two FAISS-specific settings on top of the generic approximate-search options: ``n_probe`` (assigned to
    ``nprobe`` of an IVF index) and ``ef_search`` (assigned to ``efSearch`` of an HNSW quantizer).
    """

    search_class = FaissSearcher
    SUPPORTED_MODELS = {DprConfig.model_type, ColConfig.model_type}

    def __init__(
        self,
        k: int = 10,
        candidate_k: int = 100,
        imputation_strategy: Literal["min", "gather", "zero"] = "gather",
        n_probe: int = 1,
        ef_search: int = 16,
    ) -> None:
        """Create a FAISS search configuration.

        Args:
            k (int): How many top-ranked results to return. Defaults to 10.
            candidate_k (int): How many candidates to fetch from the index before ranking. Defaults to 100.
            imputation_strategy (Literal["min", "gather", "zero"]): How scores missing from the candidate set are
                filled in. Defaults to "gather".
            n_probe (int): Number of inverted lists probed by an IVF index. Defaults to 1.
            ef_search (int): Size of the dynamic candidate list used by an HNSW search. Defaults to 16.
        """
        super().__init__(k, candidate_k, imputation_strategy)
        # FAISS-specific knobs; the searcher reads both from its config when loading the index.
        self.n_probe, self.ef_search = n_probe, ef_search