Source code for lightning_ir.retrieve.plaid.plaid_searcher
1"""Plaid Searcher using fast-plaid library for Lightning IR Framework"""
2
3from pathlib import Path
4from typing import List, Tuple
5
6import torch
7
8from ...bi_encoder import BiEncoderModule, BiEncoderOutput
9from ...models import ColConfig
10from ..base.packed_tensor import PackedTensor
11from ..base.searcher import SearchConfig, Searcher
12
13
[docs]
class PlaidSearcher(Searcher):
    """Searcher for Plaid using fast-plaid library."""

    def __init__(
        self,
        index_dir: Path,
        search_config: "PlaidSearchConfig",
        module: BiEncoderModule,
        use_gpu: bool = False,
    ) -> None:
        """Initialize the PlaidSearcher.

        Args:
            index_dir (Path | str): Directory where the Plaid index is stored.
            search_config (PlaidSearchConfig): Configuration for the Plaid searcher.
            module (BiEncoderModule): The BiEncoder module used for searching.
            use_gpu (bool): Whether to use GPU for searching. Defaults to False.
        """
        # Imported inside __init__, so fast_plaid is only needed once a searcher is constructed.
        from fast_plaid import search

        super().__init__(index_dir, search_config, module, use_gpu)
        # Annotation only: narrows the attribute type for type checkers, assigns nothing.
        self.search_config: PlaidSearchConfig

        # `self.index_dir` and `self.device` are not set in this class; presumably the base
        # Searcher initializer provides them -- TODO confirm.
        self.index = search.FastPlaid(index=str(self.index_dir), device=self.device.type, preload_index=True)

    def search(self, output: BiEncoderOutput) -> Tuple[PackedTensor, List[List[str]]]:
        """Search for relevant documents using the Plaid index.

        Args:
            output (BiEncoderOutput): The output from the BiEncoder module containing query embeddings.
        Returns:
            Tuple[PackedTensor, List[List[str]]]: A tuple containing the scores and the corresponding document IDs.
                The scores of all queries are concatenated into one flat tensor; ``lengths`` holds the number of
                hits per query. A query without any hit contributes an empty ID list and a length of 0.
        Raises:
            ValueError: If the output does not contain query embeddings.
        """
        query_embeddings = output.query_embeddings
        if query_embeddings is None:
            raise ValueError("Expected query_embeddings in BiEncoderOutput")

        # NOTE(review): only `k` is forwarded to the index. The other PlaidSearchConfig fields
        # (candidate_k, n_cells, centroid_score_threshold) are not passed here -- confirm intended.
        results = self.index.search(
            queries_embeddings=query_embeddings.embeddings,
            top_k=self.search_config.k,
        )

        all_doc_ids: List[List[str]] = []
        flat_scores: List[float] = []
        lengths: List[int] = []
        for result in results:
            # Each result is consumed as (index, score) pairs. Iterating the pairs directly,
            # instead of unpacking `zip(*result)`, keeps a query with zero hits from raising
            # "not enough values to unpack".
            doc_ids = []
            for doc_index, doc_score in result:
                doc_ids.append(self.doc_ids[doc_index])
                flat_scores.append(doc_score)
            all_doc_ids.append(doc_ids)
            lengths.append(len(doc_ids))

        return PackedTensor(torch.tensor(flat_scores), lengths=lengths), all_doc_ids
69
70
[docs]
class PlaidSearchConfig(SearchConfig):
    """Search-time settings for the Plaid searcher in the Lightning IR framework.

    ``search_class`` ties this configuration to :class:`PlaidSearcher`, and
    ``SUPPORTED_MODELS`` contains the single model type ``ColConfig.model_type``.
    """

    search_class = PlaidSearcher
    SUPPORTED_MODELS = {ColConfig.model_type}

    def __init__(
        self,
        k: int,
        candidate_k: int = 256,
        n_cells: int = 1,
        centroid_score_threshold: float = 0.5,
    ) -> None:
        """Create a Plaid search configuration.

        Args:
            k (int): How many top documents to retrieve; passed on to the parent initializer.
            candidate_k (int): How many candidate documents to consider for scoring. Defaults to 256.
            n_cells (int): How many cells to use for centroid retrieval. Defaults to 1.
            centroid_score_threshold (float): Centroid-score threshold used to filter candidates.
                Defaults to 0.5.
        """
        super().__init__(k)
        # The Plaid-specific options are stored on the instance unchanged.
        (
            self.candidate_k,
            self.n_cells,
            self.centroid_score_threshold,
        ) = (candidate_k, n_cells, centroid_score_threshold)
96 self.centroid_score_threshold = centroid_score_threshold