Source code for lightning_ir.retrieve.plaid.plaid_searcher
1"""Plaid Searcher using fast-plaid library for Lightning IR Framework"""
2
3from pathlib import Path
4
5import torch
6
7from ...bi_encoder import BiEncoderModule, BiEncoderOutput
8from ...models import ColConfig
9from ..base.packed_tensor import PackedTensor
10from ..base.searcher import SearchConfig, Searcher
11
12
[docs]
class PlaidSearcher(Searcher):
    """Searcher for Plaid using fast-plaid library."""

    def __init__(
        self,
        index_dir: Path,
        search_config: "PlaidSearchConfig",
        module: BiEncoderModule,
        use_gpu: bool = False,
    ) -> None:
        """Initialize the PlaidSearcher and open the fast-plaid index.

        Args:
            index_dir (Path): Directory where the Plaid index is stored.
            search_config (PlaidSearchConfig): Configuration for the Plaid searcher.
            module (BiEncoderModule): The BiEncoder module used for searching.
            use_gpu (bool): Whether to use GPU for searching. Defaults to False.
        """
        # Function-scope import: fast_plaid is only imported when a searcher is instantiated.
        from fast_plaid import search

        super().__init__(index_dir, search_config, module, use_gpu)
        # Bare annotation (no assignment): narrows the attribute type for type checkers only.
        self.search_config: PlaidSearchConfig

        # NOTE(review): self.index_dir and self.device are not assigned in this method;
        # presumably they are set by Searcher.__init__ -- confirm against the base class.
        self.index = search.FastPlaid(index=str(self.index_dir), device=self.device.type, preload_index=True)

    def search(self, output: BiEncoderOutput) -> tuple[PackedTensor, list[list[str]]]:
        """Search for relevant documents using the Plaid index.

        Args:
            output (BiEncoderOutput): The output from the BiEncoder module containing query embeddings.
        Returns:
            tuple[PackedTensor, list[list[str]]]: The flattened scores of all queries packed together with the
                number of hits per query, and the corresponding document IDs grouped per query. A query without
                any hits contributes a length of 0 and an empty list of document IDs.
        Raises:
            ValueError: If the output does not contain query embeddings.
        """
        query_embeddings = output.query_embeddings
        if query_embeddings is None:
            raise ValueError("Expected query_embeddings in BiEncoderOutput")

        # NOTE(review): only `k` is forwarded to the index; candidate_k, n_cells and
        # centroid_score_threshold of the search config are not used here -- confirm whether
        # they should map onto fast-plaid search parameters.
        results = self.index.search(
            queries_embeddings=query_embeddings.embeddings,
            top_k=self.search_config.k,
        )

        all_doc_ids: list[list[str]] = []
        flat_scores: list[float] = []
        lengths: list[int] = []
        for result in results:
            # Walk the (index, score) pairs directly instead of `zip(*result)`: unpacking the
            # transposed result into two names raises ValueError when a query has no hits.
            doc_ids = []
            for doc_idx, doc_score in result:
                doc_ids.append(self.doc_ids[doc_idx])
                flat_scores.append(doc_score)
            all_doc_ids.append(doc_ids)
            lengths.append(len(doc_ids))

        return PackedTensor(torch.tensor(flat_scores), lengths=lengths), all_doc_ids
68
69
[docs]
class PlaidSearchConfig(SearchConfig):
    """Search-time settings for :class:`PlaidSearcher` in the Lightning IR framework."""

    # Searcher implementation associated with this configuration.
    search_class = PlaidSearcher
    # Model types this configuration may be combined with.
    SUPPORTED_MODELS = {ColConfig.model_type}

    def __init__(
        self,
        k: int,
        candidate_k: int = 256,
        n_cells: int = 1,
        centroid_score_threshold: float = 0.5,
    ) -> None:
        """Store the Plaid search settings.

        Args:
            k (int): Number of top documents to retrieve.
            candidate_k (int): Number of candidate documents to consider for scoring. Defaults to 256.
            n_cells (int): Number of cells to use for centroid retrieval. Defaults to 1.
            centroid_score_threshold (float): Threshold for filtering candidates based on centroid scores.
                Defaults to 0.5.
        """
        # `k` is handled by the base configuration; the Plaid-specific values are kept as attributes.
        super().__init__(k)
        self.candidate_k = candidate_k
        self.n_cells = n_cells
        self.centroid_score_threshold = centroid_score_threshold