Source code for lightning_ir.retrieve.plaid.plaid_searcher

 1"""Plaid Searcher using fast-plaid library for Lightning IR Framework"""
 2
 3from pathlib import Path
 4from typing import List, Tuple
 5
 6import torch
 7
 8from ...bi_encoder import BiEncoderModule, BiEncoderOutput
 9from ...models import ColConfig
10from ..base.packed_tensor import PackedTensor
11from ..base.searcher import SearchConfig, Searcher
12
13
class PlaidSearcher(Searcher):
    """Searcher for Plaid using the fast-plaid library."""

    def __init__(
        self,
        index_dir: Path,
        search_config: "PlaidSearchConfig",
        module: BiEncoderModule,
        use_gpu: bool = False,
    ) -> None:
        """Initialize the PlaidSearcher.

        Args:
            index_dir (Path | str): Directory where the Plaid index is stored.
            search_config (PlaidSearchConfig): Configuration for the Plaid searcher.
            module (BiEncoderModule): The BiEncoder module used for searching.
            use_gpu (bool): Whether to use GPU for searching. Defaults to False.
        """
        # Imported here rather than at module level, so importing this module does not require fast-plaid;
        # only constructing a PlaidSearcher does.
        from fast_plaid import search

        super().__init__(index_dir, search_config, module, use_gpu)
        # Annotation only (no assignment): narrows the type of ``search_config`` for type checkers.
        self.search_config: PlaidSearchConfig

        # The index is opened eagerly (preload_index=True) on the device type of this searcher.
        self.index = search.FastPlaid(index=str(self.index_dir), device=self.device.type, preload_index=True)

    def search(self, output: BiEncoderOutput) -> Tuple[PackedTensor, List[List[str]]]:
        """Search for relevant documents using the Plaid index.

        Args:
            output (BiEncoderOutput): The output from the BiEncoder module containing query embeddings.
        Returns:
            Tuple[PackedTensor, List[List[str]]]: The scores packed into one segment per query (in the
            order fast-plaid returns them) and, per query, the document IDs matching those scores. A
            query without any hits yields an empty segment and an empty ID list.
        Raises:
            ValueError: If the output does not contain query embeddings.
        """
        query_embeddings = output.query_embeddings
        if query_embeddings is None:
            raise ValueError("Expected query_embeddings in BiEncoderOutput")

        # NOTE(review): only ``k`` is passed to fast-plaid; candidate_k, n_cells and
        # centroid_score_threshold from PlaidSearchConfig are currently not forwarded. Confirm whether
        # they should map onto fast-plaid search options.
        results = self.index.search(
            queries_embeddings=query_embeddings.embeddings,
            top_k=self.search_config.k,
        )

        all_doc_ids: List[List[str]] = []
        all_scores: List[List[float]] = []
        for result in results:
            # Each per-query result is a sequence of (doc_index, score) pairs. Unpacking pair-by-pair
            # instead of ``zip(*result)`` keeps a query with zero hits from raising a ValueError.
            pairs = list(result)
            all_doc_ids.append([self.doc_ids[doc_idx] for doc_idx, _ in pairs])
            all_scores.append([score for _, score in pairs])

        lengths = [len(query_scores) for query_scores in all_scores]
        flat_scores = [score for query_scores in all_scores for score in query_scores]
        return PackedTensor(torch.tensor(flat_scores), lengths=lengths), all_doc_ids
class PlaidSearchConfig(SearchConfig):
    """Search settings for the fast-plaid backed :class:`PlaidSearcher` in Lightning IR.

    Instances are tied to :class:`PlaidSearcher` through ``search_class``. ``SUPPORTED_MODELS`` lists the
    model types this configuration accepts; here only ``ColConfig.model_type``.
    """

    search_class = PlaidSearcher
    SUPPORTED_MODELS = {ColConfig.model_type}

    def __init__(
        self,
        k: int,
        candidate_k: int = 256,
        n_cells: int = 1,
        centroid_score_threshold: float = 0.5,
    ) -> None:
        """Create a Plaid search configuration.

        Args:
            k (int): How many top-ranked documents to return per query.
            candidate_k (int): How many candidate documents to consider for scoring. Defaults to 256.
            n_cells (int): How many cells to use for centroid retrieval. Defaults to 1.
            centroid_score_threshold (float): Centroid-score threshold used to filter candidates.
                Defaults to 0.5.
        """
        super().__init__(k)
        # Stored in the same order as before: candidate_k, then n_cells, then the threshold.
        self.candidate_k, self.n_cells, self.centroid_score_threshold = (
            candidate_k,
            n_cells,
            centroid_score_threshold,
        )