Source code for lightning_ir.retrieve.plaid.plaid_searcher

  1"""Plaid Searcher for Lightning IR Framework"""
  2
  3from __future__ import annotations
  4
  5from pathlib import Path
  6from typing import TYPE_CHECKING, List, Tuple
  7
  8import torch
  9
 10from ...bi_encoder.bi_encoder_model import BiEncoderEmbedding
 11from ...models import ColConfig
 12from ..base.packed_tensor import PackedTensor
 13from ..base.searcher import SearchConfig, Searcher
 14from .plaid_indexer import PlaidIndexConfig
 15from .residual_codec import ResidualCodec
 16
 17if TYPE_CHECKING:
 18    from ...bi_encoder import BiEncoderModule, BiEncoderOutput
 19
 20
class PlaidSearcher(Searcher):
    """Searcher for Plaid, a residual-based search method for efficient retrieval.

    Candidate documents are generated and filtered using only query-centroid scores. Only the final candidates are
    decompressed from their stored codes and residuals and scored with the full model.
    """

    def __init__(
        self, index_dir: Path | str, search_config: PlaidSearchConfig, module: BiEncoderModule, use_gpu: bool = False
    ) -> None:
        """Initialize the PlaidSearcher.

        Loads the residual codec, the per-embedding codes and the compressed residuals from ``index_dir`` and builds
        two inverted files: centroid -> documents (``code_to_doc_ivf``) and document -> centroids
        (``doc_to_code_ivf``).

        Args:
            index_dir (Path | str): Directory where the Plaid index is stored.
            search_config (PlaidSearchConfig): Configuration for the Plaid searcher.
            module (BiEncoderModule): The BiEncoder module used for searching.
            use_gpu (bool): Whether to use GPU for searching. Defaults to False.
        """
        super().__init__(index_dir, search_config, module, use_gpu)
        self.residual_codec = ResidualCodec.from_pretrained(
            PlaidIndexConfig.from_pretrained(self.index_dir), self.index_dir, device=self.device
        )

        # one code (centroid index) and one row of compressed residuals per indexed embedding
        self.codes = torch.load(self.index_dir / "codes.pt", weights_only=True).to(self.device)
        self.residuals = (
            torch.load(self.index_dir / "residuals.pt", weights_only=True).view(self.codes.shape[0], -1).to(self.device)
        )
        doc_lengths = self.doc_lengths.tolist()
        self.packed_codes = PackedTensor(self.codes, lengths=doc_lengths).to(self.device)
        self.packed_residuals = PackedTensor(self.residuals, lengths=doc_lengths).to(self.device)

        num_centroids = self.residual_codec.num_centroids

        # code_idx -> embedding_idcs: group embedding indices by the code they were assigned to
        sorted_codes, embedding_idcs = self.codes.sort()
        num_embeddings_per_code = torch.bincount(sorted_codes, minlength=num_centroids).tolist()

        # code_idx -> doc_idcs: map every embedding to its document and deduplicate documents per code
        embedding_idx_to_doc_idx = torch.arange(self.num_docs, device=self.device).repeat_interleave(self.doc_lengths)
        full_doc_ivf = embedding_idx_to_doc_idx[embedding_idcs]
        unique_doc_idcs = [doc_idcs.unique() for doc_idcs in full_doc_ivf.split(num_embeddings_per_code)]
        doc_ivf_lengths = [doc_idcs.shape[0] for doc_idcs in unique_doc_idcs]
        self.code_to_doc_ivf = PackedTensor(torch.cat(unique_doc_idcs), lengths=doc_ivf_lengths)

        # doc_idx -> code_idcs: invert code_to_doc_ivf by sorting its entries by document index and
        # recovering the code each entry belonged to
        sorted_doc_idcs, doc_idx_to_code_idx = torch.sort(self.code_to_doc_ivf)
        code_idcs = torch.arange(num_centroids, device=self.device).repeat_interleave(
            torch.tensor(self.code_to_doc_ivf.lengths, device=self.device)
        )[doc_idx_to_code_idx]
        num_codes_per_doc = torch.bincount(sorted_doc_idcs, minlength=self.num_docs)
        self.doc_to_code_ivf = PackedTensor(code_idcs, lengths=num_codes_per_doc.tolist())

        # narrows the type of the attribute set by the base class for type checkers
        self.search_config: PlaidSearchConfig

    def _centroid_candidate_retrieval(self, query_embeddings: BiEncoderEmbedding) -> Tuple[PackedTensor, PackedTensor]:
        """Generate and filter candidate documents per query using only query-centroid scores.

        Args:
            query_embeddings (BiEncoderEmbedding): Query embeddings together with their scoring mask.
        Returns:
            Tuple[PackedTensor, PackedTensor]: Approximate scores and document indices of the remaining
            candidates, packed per query.
        """
        query_scoring_mask = query_embeddings.scoring_mask
        # `num_queries x query_length x num_centroids`
        centroid_scores = (
            query_embeddings.embeddings.to(self.residual_codec.centroids)
            @ self.residual_codec.centroids.transpose(0, 1)[None]
        ).to(self.device)
        # zero the scores of masked (non-scoring) query vectors
        centroid_scores = centroid_scores.masked_fill(~query_scoring_mask[..., None], 0)

        # grab the top `n_cells` neighbour cells for every query embedding; clamp so topk never requests
        # more cells than there are centroids
        n_cells = min(self.search_config.n_cells, centroid_scores.shape[-1])
        _, codes = torch.topk(centroid_scores, n_cells, dim=-1, sorted=False)
        # keep only the cells of scoring query vectors, packed per query
        packed_codes = codes[query_scoring_mask].view(-1)
        code_lengths = (query_scoring_mask.sum(-1) * n_cells).tolist()

        # grab the unique document idcs of all selected cells, per query
        packed_doc_idcs = self.code_to_doc_ivf.lookup(packed_codes, code_lengths, unique=True)

        # Two filtering passes over centroid-based approximate scores. The first zeroes centroids whose best query
        # score is below `centroid_score_threshold` and keeps `candidate_k` documents. The second rescores those
        # documents without pruning and keeps a quarter of them (at least one, so small `candidate_k` values do not
        # silently yield zero candidates).
        # NOTE(review): this appears to mirror PLAID's stage 2 (pruned centroid interaction) and stage 3
        # (unpruned centroid interaction) — confirm against the PLAID paper.
        _, filtered_doc_idcs = self._filter_candidates(
            centroid_scores=centroid_scores,
            doc_idcs=packed_doc_idcs,
            threshold=self.search_config.centroid_score_threshold,
            k=self.search_config.candidate_k,
            query_scoring_mask=query_scoring_mask,
        )
        filtered_scores, filtered_doc_idcs = self._filter_candidates(
            centroid_scores=centroid_scores,
            doc_idcs=filtered_doc_idcs,
            threshold=None,
            k=max(self.search_config.candidate_k // 4, 1),
            query_scoring_mask=query_scoring_mask,
        )
        return filtered_scores, filtered_doc_idcs

    def _filter_candidates(
        self,
        centroid_scores: torch.Tensor,
        doc_idcs: PackedTensor,
        threshold: float | None,
        k: int,
        query_scoring_mask: torch.Tensor,
    ) -> Tuple[PackedTensor, PackedTensor]:
        """Approximately score candidate documents via their centroids and keep the top ``k`` per query.

        Args:
            centroid_scores (torch.Tensor): `num_queries x num_query_vecs x num_centroids` query-centroid scores.
            doc_idcs (PackedTensor): Candidate document indices, packed per query.
            threshold (float | None): Centroids whose maximum score over all query vectors is below this value are
                zeroed before scoring. ``None`` or ``0`` disables pruning.
            k (int): Number of documents to keep per query.
            query_scoring_mask (torch.Tensor): Mask of the query vectors that take part in scoring.
        Returns:
            Tuple[PackedTensor, PackedTensor]: Filtered approximate scores and document indices, packed per query.
        """
        num_query_vecs = centroid_scores.shape[1]
        num_centroids = centroid_scores.shape[-1]

        # apply the pruning threshold once per query, before the scores are repeated for every candidate document
        # (identical result to pruning after the expansion, without the redundant per-document work)
        if threshold:
            centroid_scores = centroid_scores.masked_fill(centroid_scores.amax(1, keepdim=True) < threshold, 0)

        # repeat query centroid scores for each candidate document and append one all-zero column, so that the
        # padding code `num_centroids` gathers a score of 0
        # `num_docs x num_query_vecs x num_centroids + 1`
        expanded_centroid_scores = torch.nn.functional.pad(
            centroid_scores.repeat_interleave(torch.tensor(doc_idcs.lengths, device=self.device), dim=0), (0, 1)
        )

        # grab the codes of each candidate document
        code_idcs = self.doc_to_code_ivf.lookup(doc_idcs, 1)
        # `num_docs x max_num_codes_per_doc`, padded with `num_centroids`
        padded_codes = code_idcs.to_padded_tensor(pad_value=num_centroids)
        doc_scoring_mask = padded_codes != num_centroids
        # `num_docs x num_query_vecs x max_num_codes_per_doc`
        padded_codes = padded_codes[:, None].expand(-1, num_query_vecs, -1)

        # colbert-style scoring, with centroid scores standing in for document embedding similarities
        # approx_similarity: `num_docs x num_query_vecs x max_num_codes_per_doc`
        approx_similarity = torch.gather(input=expanded_centroid_scores, dim=-1, index=padded_codes)
        approx_scores = self.module.model.aggregate_similarity(
            approx_similarity,
            query_scoring_mask=query_scoring_mask,
            doc_scoring_mask=doc_scoring_mask[:, None],
            num_docs=doc_idcs.lengths,
        )
        packed_approx_scores = PackedTensor(approx_scores, lengths=doc_idcs.lengths)
        filtered_scores, filtered_doc_idcs = self._filter_and_sort(packed_approx_scores, doc_idcs, k)
        return filtered_scores, filtered_doc_idcs

    def _reconstruct_doc_embeddings(self, candidate_doc_idcs: PackedTensor) -> BiEncoderEmbedding:
        """Decompress the embeddings of the candidate documents from their stored codes and residuals.

        Args:
            candidate_doc_idcs (PackedTensor): Candidate document indices, packed per query.
        Returns:
            BiEncoderEmbedding: Padded document embeddings (one row per candidate document) with a scoring mask
            that marks the real (non-padding) embedding positions.
        """
        doc_embedding_codes = self.packed_codes.lookup(candidate_doc_idcs, 1)
        doc_embedding_residuals = self.packed_residuals.lookup(candidate_doc_idcs, 1)
        doc_embeddings = self.residual_codec.decompress(doc_embedding_codes, doc_embedding_residuals)
        padded_doc_embeddings = doc_embeddings.to_padded_tensor()
        # derive the mask from the per-document lengths rather than from zero values, so a genuine embedding whose
        # first component happens to be exactly 0 is not masked out
        device = padded_doc_embeddings.device
        lengths = torch.tensor(doc_embeddings.lengths, dtype=torch.long, device=device)
        positions = torch.arange(padded_doc_embeddings.shape[1], device=device)
        doc_scoring_mask = positions[None] < lengths[:, None]
        return BiEncoderEmbedding(padded_doc_embeddings, doc_scoring_mask, None)

    def search(self, output: BiEncoderOutput) -> Tuple[PackedTensor, List[List[str]]]:
        """Search for relevant documents using the Plaid index.

        Note that ``output.doc_embeddings`` is overwritten with the reconstructed candidate embeddings.

        Args:
            output (BiEncoderOutput): The output from the BiEncoder module containing query embeddings.
        Returns:
            Tuple[PackedTensor, List[List[str]]]: A tuple containing the scores and the corresponding document IDs.
        Raises:
            ValueError: If the output does not contain query embeddings, or if the model returns no scores.
        """
        query_embeddings = output.query_embeddings
        if query_embeddings is None:
            raise ValueError("Expected query_embeddings in BiEncoderOutput")
        query_embeddings = query_embeddings.to(self.device)

        _, candidate_idcs = self._centroid_candidate_retrieval(query_embeddings)
        num_docs = candidate_idcs.lengths

        # score the candidates with the full model on their decompressed embeddings
        output.doc_embeddings = self._reconstruct_doc_embeddings(candidate_idcs)
        output = self.module.model.score(output, num_docs)
        scores = output.scores
        if scores is None:
            raise ValueError("Expected scores in BiEncoderOutput")

        scores, doc_idcs = self._filter_and_sort(PackedTensor(scores, lengths=num_docs), candidate_idcs)
        doc_ids = [
            [self.doc_ids[doc_idx] for doc_idx in query_doc_idcs.tolist()]
            for query_doc_idcs in doc_idcs.split(doc_idcs.lengths)
        ]
        return scores, doc_ids


class PlaidSearchConfig(SearchConfig):
    """Configuration class for Plaid searchers in the Lightning IR framework."""

    search_class = PlaidSearcher
    SUPPORTED_MODELS = {ColConfig.model_type}

    def __init__(
        self,
        k: int,
        candidate_k: int = 256,
        n_cells: int = 1,
        centroid_score_threshold: float = 0.5,
    ) -> None:
        """Initialize the PlaidSearchConfig.

        Args:
            k (int): Number of top documents to retrieve.
            candidate_k (int): Number of candidate documents kept after the first (pruned) centroid-based filtering
                pass; the second pass keeps a quarter of them. Defaults to 256.
            n_cells (int): Number of nearest centroid cells probed per query embedding. Defaults to 1.
            centroid_score_threshold (float): Centroids whose maximum score over all query vectors is below this
                threshold are ignored in the first filtering pass; 0 disables this pruning. Defaults to 0.5.
        Raises:
            ValueError: If ``candidate_k`` or ``n_cells`` is smaller than 1, since either would make every
                search return no candidates.
        """
        if candidate_k < 1:
            raise ValueError(f"candidate_k must be at least 1, got {candidate_k}")
        if n_cells < 1:
            raise ValueError(f"n_cells must be at least 1, got {n_cells}")
        super().__init__(k)
        self.candidate_k = candidate_k
        self.n_cells = n_cells
        self.centroid_score_threshold = centroid_score_threshold