Source code for lightning_ir.retrieve.faiss.faiss_searcher
1from __future__ import annotations
2
3from pathlib import Path
4from typing import TYPE_CHECKING, Literal, Tuple
5
6import torch
7
8from ...bi_encoder.bi_encoder_model import BiEncoderEmbedding
9from ...models import ColConfig, DprConfig
10from ..base.packed_tensor import PackedTensor
11from ..base.searcher import ApproximateSearchConfig, ApproximateSearcher
12
13if TYPE_CHECKING:
14 from ...bi_encoder import BiEncoderModule
15
16
[docs]
class FaissSearcher(ApproximateSearcher):
    """Approximate searcher that retrieves candidates from a FAISS index on disk."""

    def __init__(
        self,
        index_dir: Path | str,
        search_config: FaissSearchConfig,
        module: BiEncoderModule,
        use_gpu: bool = False,
    ) -> None:
        """Load ``index.faiss`` from ``index_dir`` and apply the search parameters.

        Args:
            index_dir: Directory that contains the ``index.faiss`` file.
            search_config: Provides ``n_probe`` and ``ef_search``; also forwarded
                to the parent constructor.
            module: Forwarded unchanged to the parent constructor.
            use_gpu: When true and the installed ``faiss`` build exposes
                ``index_cpu_to_all_gpus``, the loaded index is moved to GPU.
        """
        # Imported lazily so that faiss is only required when this searcher is used.
        import faiss

        self.search_config: FaissSearchConfig
        index_file = Path(index_dir) / "index.faiss"
        self.index = faiss.read_index(str(index_file))
        if use_gpu and hasattr(faiss, "index_cpu_to_all_gpus"):
            self.index = faiss.index_cpu_to_all_gpus(self.index)

        # Best effort: if no IVF index can be extracted, the RuntimeError is
        # swallowed and no IVF/HNSW parameters are configured.
        try:
            ivf = faiss.extract_index_ivf(self.index)
        except RuntimeError:
            ivf = None

        if ivf is not None:
            ivf.nprobe = search_config.n_probe
            coarse_quantizer = getattr(ivf, "quantizer", None)
            if coarse_quantizer is not None:
                # Downcast so that an ``hnsw`` attribute, if the concrete
                # quantizer type has one, becomes reachable.
                concrete_quantizer = faiss.downcast_index(coarse_quantizer)
                hnsw_graph = getattr(concrete_quantizer, "hnsw", None)
                if hnsw_graph is not None:
                    hnsw_graph.efSearch = search_config.ef_search

        super().__init__(index_dir, search_config, module, use_gpu)

    def _candidate_retrieval(self, query_embeddings: BiEncoderEmbedding) -> Tuple[PackedTensor, PackedTensor]:
        """Search the index for ``candidate_k`` candidates per query vector.

        Args:
            query_embeddings: Query embeddings. Without a ``scoring_mask`` only
                index 0 along the second dimension is used; otherwise the
                embeddings are indexed by the mask.

        Returns:
            A pair ``(scores, indices)`` of packed tensors, each flattened and
            carrying ``candidate_k`` entries per query vector.
        """
        mask = query_embeddings.scoring_mask
        if mask is None:
            query_vectors = query_embeddings.embeddings[:, 0]
        else:
            query_vectors = query_embeddings.embeddings[mask]

        candidate_k = self.search_config.candidate_k
        raw_scores, raw_idcs = self.index.search(query_vectors.float().cpu(), candidate_k)
        flat_scores = torch.from_numpy(raw_scores).view(-1)
        flat_idcs = torch.from_numpy(raw_idcs).view(-1)

        # Every query vector contributes exactly ``candidate_k`` entries.
        lengths = [candidate_k] * query_vectors.shape[0]
        packed_scores = PackedTensor(flat_scores, lengths=lengths)
        packed_idcs = PackedTensor(flat_idcs, lengths=lengths)
        return packed_scores, packed_idcs

    def _gather_doc_embeddings(self, idcs: torch.Tensor) -> torch.Tensor:
        """Reconstruct the stored vectors for ``idcs`` from the index as a tensor."""
        reconstructed = self.index.reconstruct_batch(idcs)
        return torch.from_numpy(reconstructed)
61
62
[docs]
class FaissSearchConfig(ApproximateSearchConfig):
    """Search configuration consumed by :class:`FaissSearcher`."""

    search_class = FaissSearcher
    SUPPORTED_MODELS = {ColConfig.model_type, DprConfig.model_type}

    def __init__(
        self,
        k: int = 10,
        candidate_k: int = 100,
        imputation_strategy: Literal["min", "gather", "zero"] = "gather",
        n_probe: int = 1,
        ef_search: int = 16,
    ) -> None:
        """Store the FAISS-specific search parameters.

        Args:
            k: Forwarded positionally to the parent configuration.
            candidate_k: Forwarded positionally to the parent configuration.
            imputation_strategy: Forwarded positionally to the parent
                configuration.
            n_probe: Value the searcher assigns to ``nprobe`` of an IVF index.
            ef_search: Value the searcher assigns to ``efSearch`` of an HNSW
                quantizer, when the index has one.
        """
        super().__init__(k, candidate_k, imputation_strategy)
        self.n_probe, self.ef_search = n_probe, ef_search