Source code for lightning_ir.retrieve.faiss.faiss_searcher
1"""FAISS-based searcher for approximate nearest neighbor retrieval in the Lightning IR framework."""
2
3from __future__ import annotations
4
5from pathlib import Path
6from typing import TYPE_CHECKING, Literal, Tuple
7
8import torch
9
10from ...bi_encoder.bi_encoder_model import BiEncoderEmbedding
11from ...models import ColConfig, DprConfig
12from ..base.packed_tensor import PackedTensor
13from ..base.searcher import ApproximateSearchConfig, ApproximateSearcher
14
15if TYPE_CHECKING:
16 from ...bi_encoder import BiEncoderModule
17
18
[docs]
19class FaissSearcher(ApproximateSearcher):
20 """FAISS-based searcher for approximate nearest neighbor retrieval in the Lightning IR framework."""
21
[docs]
22 def __init__(
23 self,
24 index_dir: Path | str,
25 search_config: FaissSearchConfig,
26 module: BiEncoderModule,
27 use_gpu: bool = False,
28 ) -> None:
29 """Initialize the FaissSearcher.
30
31 Args:
32 index_dir (Path | str): Directory containing the FAISS index files.
33 search_config (FaissSearchConfig): Configuration for the FAISS searcher.
34 module (BiEncoderModule): The bi-encoder module used for embeddings.
35 use_gpu (bool): Whether to use GPU for FAISS operations. Defaults to False.
36 """
37 import faiss
38
39 self.search_config: FaissSearchConfig
40 self.index = faiss.read_index(str(Path(index_dir) / "index.faiss"))
41 if use_gpu and hasattr(faiss, "index_cpu_to_all_gpus"):
42 self.index = faiss.index_cpu_to_all_gpus(self.index)
43 ivf_index = None
44 try:
45 ivf_index = faiss.extract_index_ivf(self.index)
46 except RuntimeError:
47 pass
48 if ivf_index is not None:
49 ivf_index.nprobe = search_config.n_probe
50 quantizer = getattr(ivf_index, "quantizer", None)
51 if quantizer is not None:
52 downcasted_quantizer = faiss.downcast_index(quantizer)
53 hnsw = getattr(downcasted_quantizer, "hnsw", None)
54 if hnsw is not None:
55 hnsw.efSearch = search_config.ef_search
56 super().__init__(index_dir, search_config, module, use_gpu)
57
58 def _candidate_retrieval(self, query_embeddings: BiEncoderEmbedding) -> Tuple[PackedTensor, PackedTensor]:
59 if query_embeddings.scoring_mask is None:
60 embeddings = query_embeddings.embeddings[:, 0]
61 else:
62 embeddings = query_embeddings.embeddings[query_embeddings.scoring_mask]
63 candidate_scores, candidate_idcs = self.index.search(embeddings.float().cpu(), self.search_config.candidate_k)
64 candidate_scores = torch.from_numpy(candidate_scores).view(-1)
65 candidate_idcs = torch.from_numpy(candidate_idcs).view(-1)
66 num_candidates_per_query_vector = [self.search_config.candidate_k] * embeddings.shape[0]
67 packed_candidate_scores = PackedTensor(candidate_scores, lengths=num_candidates_per_query_vector)
68 packed_candidate_idcs = PackedTensor(candidate_idcs, lengths=num_candidates_per_query_vector)
69 return packed_candidate_scores, packed_candidate_idcs
70
71 def _gather_doc_embeddings(self, idcs: torch.Tensor) -> torch.Tensor:
72 return torch.from_numpy(self.index.reconstruct_batch(idcs))
73
74
[docs]
75class FaissSearchConfig(ApproximateSearchConfig):
76 """Configuration class for FAISS-based searchers in the Lightning IR framework."""
77
78 search_class = FaissSearcher
79 SUPPORTED_MODELS = {ColConfig.model_type, DprConfig.model_type}
80
[docs]
81 def __init__(
82 self,
83 k: int = 10,
84 candidate_k: int = 100,
85 imputation_strategy: Literal["min", "gather", "zero"] = "gather",
86 n_probe: int = 1,
87 ef_search: int = 16,
88 ) -> None:
89 """Initialize the FaissSearchConfig.
90
91 Args:
92 k (int): Number of top results to return. Defaults to 10.
93 candidate_k (int): Number of candidates to retrieve before ranking. Defaults to 100.
94 imputation_strategy (Literal["min", "gather", "zero"]): Strategy for handling missing scores.
95 Defaults to "gather".
96 n_probe (int): Number of probes for the IVF index. Defaults to 1.
97 ef_search (int): Size of the dynamic list used during search. Defaults to 16.
98 """
99 super().__init__(k, candidate_k, imputation_strategy)
100 self.n_probe = n_probe
101 self.ef_search = ef_search