Source code for lightning_ir.retrieve.plaid.plaid_searcher

 1"""Plaid Searcher using fast-plaid library for Lightning IR Framework"""
 2
 3from pathlib import Path
 4
 5import torch
 6
 7from ...bi_encoder import BiEncoderModule, BiEncoderOutput
 8from ...models import ColConfig
 9from ..base.packed_tensor import PackedTensor
10from ..base.searcher import SearchConfig, Searcher
11
12
class PlaidSearcher(Searcher):
    """Searcher for Plaid using fast-plaid library."""

    def __init__(
        self,
        index_dir: Path,
        search_config: "PlaidSearchConfig",
        module: BiEncoderModule,
        use_gpu: bool = False,
    ) -> None:
        """Initialize the PlaidSearcher.

        Args:
            index_dir (Path): Directory where the Plaid index is stored.
            search_config (PlaidSearchConfig): Configuration for the Plaid searcher.
            module (BiEncoderModule): The BiEncoder module used for searching.
            use_gpu (bool): Whether to use GPU for searching. Defaults to False.
        """
        # Imported inside the constructor so that fast-plaid is only needed once a
        # PlaidSearcher is actually created.
        from fast_plaid import search

        super().__init__(index_dir, search_config, module, use_gpu)
        # Bare annotation: narrows the attribute type for type checkers, assigns nothing.
        self.search_config: PlaidSearchConfig

        self.index = search.FastPlaid(index=str(self.index_dir), device=self.device.type, preload_index=True)

    def search(self, output: BiEncoderOutput) -> tuple[PackedTensor, list[list[str]]]:
        """Search for relevant documents using the Plaid index.

        Args:
            output (BiEncoderOutput): The output from the BiEncoder module containing query embeddings.
        Returns:
            tuple[PackedTensor, list[list[str]]]: A tuple containing the scores and the corresponding document IDs.
                Scores are flattened across queries; the per-query result counts are passed as ``lengths``.
                A query without any result contributes a length of 0 and an empty list of document IDs.
        Raises:
            ValueError: If the output does not contain query embeddings.
        """
        query_embeddings = output.query_embeddings
        if query_embeddings is None:
            raise ValueError("Expected query_embeddings in BiEncoderOutput")

        # NOTE(review): only `k` from the search config is forwarded to the index here;
        # `candidate_k`, `n_cells` and `centroid_score_threshold` are not used in this method.
        results = self.index.search(
            queries_embeddings=query_embeddings.embeddings,
            top_k=self.search_config.k,
        )

        all_doc_ids: list[list[str]] = []
        flat_scores: list[float] = []
        lengths: list[int] = []
        for result in results:
            # Each result is iterated as (document index, score) pairs. Collecting them in a
            # single pass (instead of `zip(*result)`) keeps a query with no results from
            # raising on tuple unpacking.
            doc_ids: list[str] = []
            for doc_idx, doc_score in result:
                doc_ids.append(self.doc_ids[doc_idx])
                flat_scores.append(doc_score)
            all_doc_ids.append(doc_ids)
            lengths.append(len(doc_ids))

        return PackedTensor(torch.tensor(flat_scores), lengths=lengths), all_doc_ids
68 69
class PlaidSearchConfig(SearchConfig):
    """Search configuration for the Plaid searcher in the Lightning IR framework."""

    # Model types this configuration can be combined with.
    SUPPORTED_MODELS = {ColConfig.model_type}
    # Searcher implementation associated with this configuration.
    search_class = PlaidSearcher

    def __init__(
        self,
        k: int,
        candidate_k: int = 256,
        n_cells: int = 1,
        centroid_score_threshold: float = 0.5,
    ) -> None:
        """Create a new Plaid search configuration.

        Args:
            k (int): How many top documents to retrieve.
            candidate_k (int): How many candidate documents to consider for scoring. Defaults to 256.
            n_cells (int): How many cells to use for centroid retrieval. Defaults to 1.
            centroid_score_threshold (float): Centroid-score threshold used to filter candidates.
                Defaults to 0.5.
        """
        # The base configuration receives `k`; the Plaid-specific options are stored below.
        super().__init__(k)
        self.candidate_k = candidate_k
        self.n_cells = n_cells
        self.centroid_score_threshold = centroid_score_threshold