Source code for lightning_ir.retrieve.plaid.plaid_searcher
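PLAID searcher for lightning-ir: ColBERT-style late-interaction retrieval over an index compressed with a residual codec (centroid codes plus quantized residuals).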

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, List, Tuple

import torch

from ...bi_encoder.bi_encoder_model import BiEncoderEmbedding
from ...models import ColConfig
from ..base.packed_tensor import PackedTensor
from ..base.searcher import SearchConfig, Searcher
from .plaid_indexer import PlaidIndexConfig
from .residual_codec import ResidualCodec

if TYPE_CHECKING:
    from ...bi_encoder import BiEncoderModule, BiEncoderOutput


class PlaidSearcher(Searcher):
    def __init__(
        self, index_dir: Path | str, search_config: PlaidSearchConfig, module: BiEncoderModule, use_gpu: bool = False
    ) -> None:
        super().__init__(index_dir, search_config, module, use_gpu)
        self.residual_codec = ResidualCodec.from_pretrained(
            PlaidIndexConfig.from_pretrained(self.index_dir), self.index_dir, device=self.device
        )

        self.codes = torch.load(self.index_dir / "codes.pt", weights_only=True).to(self.device)
        self.residuals = (
            torch.load(self.index_dir / "residuals.pt", weights_only=True).view(self.codes.shape[0], -1).to(self.device)
        )
        self.packed_codes = PackedTensor(self.codes, lengths=self.doc_lengths.tolist()).to(self.device)
        self.packed_residuals = PackedTensor(self.residuals, lengths=self.doc_lengths.tolist()).to(self.device)

        # code_idx to embedding_idcs mapping
        sorted_codes, embedding_idcs = self.codes.sort()
        num_embeddings_per_code = torch.bincount(sorted_codes, minlength=self.residual_codec.num_centroids).tolist()

        # code_idx to doc_idcs mapping
        embedding_idx_to_doc_idx = torch.arange(self.num_docs, device=self.device).repeat_interleave(self.doc_lengths)
        full_doc_ivf = embedding_idx_to_doc_idx[embedding_idcs]
        doc_ivf_lengths = []
        unique_doc_idcs = []
        for doc_idcs in full_doc_ivf.split(num_embeddings_per_code):
            unique_doc_idcs.append(doc_idcs.unique())
            doc_ivf_lengths.append(unique_doc_idcs[-1].shape[0])
        self.code_to_doc_ivf = PackedTensor(torch.cat(unique_doc_idcs), lengths=doc_ivf_lengths)

        # doc_idx to code_idcs mapping
        sorted_doc_idcs, doc_idx_to_code_idx = torch.sort(self.code_to_doc_ivf)
        code_idcs = torch.arange(self.residual_codec.num_centroids, device=self.device).repeat_interleave(
            torch.tensor(self.code_to_doc_ivf.lengths, device=self.device)
        )[doc_idx_to_code_idx]
        num_codes_per_doc = torch.bincount(sorted_doc_idcs, minlength=self.num_docs)
        self.doc_to_code_ivf = PackedTensor(code_idcs, lengths=num_codes_per_doc.tolist())

        self.search_config: PlaidSearchConfig
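
    # The candidate retrieval below follows PLAID's staged pipeline
    # (Santhanam et al., 2022):
    #   1. route each query vector to its top `n_cells` centroids
    #   2. approximately score the candidates from the inverted file, pruning
    #      centroid scores below a threshold (`_filter_candidates`)
    #   3. re-score the survivors without the threshold
    #   4. score the decompressed document embeddings exactly (`search`)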
    def _centroid_candidate_retrieval(self, query_embeddings: BiEncoderEmbedding) -> Tuple[PackedTensor, PackedTensor]:
        # grab top `n_cells` neighbor cells for all query embeddings
        # `num_queries x query_length x num_centroids`
        centroid_scores = (
            query_embeddings.embeddings.to(self.residual_codec.centroids)
            @ self.residual_codec.centroids.transpose(0, 1)[None]
        ).to(self.device)
        query_scoring_mask = query_embeddings.scoring_mask
        centroid_scores = centroid_scores.masked_fill(~query_scoring_mask[..., None], 0)
        _, codes = torch.topk(centroid_scores, self.search_config.n_cells, dim=-1, sorted=False)
        packed_codes = codes[query_embeddings.scoring_mask].view(-1)
        code_lengths = (query_embeddings.scoring_mask.sum(-1) * self.search_config.n_cells).tolist()

        # grab document idcs for all cells
        packed_doc_idcs = self.code_to_doc_ivf.lookup(packed_codes, code_lengths, unique=True)

        # NOTE two filter steps as in PLAID: the first prunes candidates using a
        # centroid score threshold, the second re-scores the survivors without a
        # threshold and keeps only `candidate_k // 4` of them
        # filter step 1
        _, filtered_doc_idcs = self._filter_candidates(
            centroid_scores=centroid_scores,
            doc_idcs=packed_doc_idcs,
            threshold=self.search_config.centroid_score_threshold,
            k=self.search_config.candidate_k,
            query_scoring_mask=query_scoring_mask,
        )
        # filter step 2
        filtered_scores, filtered_doc_idcs = self._filter_candidates(
            centroid_scores=centroid_scores,
            doc_idcs=filtered_doc_idcs,
            threshold=None,
            k=self.search_config.candidate_k // 4,
            query_scoring_mask=query_scoring_mask,
        )
        return filtered_scores, filtered_doc_idcs

    def _filter_candidates(
        self,
        centroid_scores: torch.Tensor,
        doc_idcs: PackedTensor,
        threshold: float | None,
        k: int,
        query_scoring_mask: torch.Tensor,
    ) -> Tuple[PackedTensor, PackedTensor]:
        num_query_vecs = centroid_scores.shape[1]
        num_centroids = centroid_scores.shape[-1]

        # repeat query centroid scores for each document
        # `num_docs x num_query_vecs x num_centroids + 1`
        # NOTE an extra zero column is appended so that padded codes
        # (pad value `num_centroids`) gather a score of 0
        expanded_centroid_scores = torch.nn.functional.pad(
            centroid_scores.repeat_interleave(torch.tensor(doc_idcs.lengths, device=self.device), dim=0), (0, 1)
        )

        # grab codes for each document
        code_idcs = self.doc_to_code_ivf.lookup(doc_idcs, 1)
        # `num_docs x max_num_codes_per_doc`
        padded_codes = code_idcs.to_padded_tensor(pad_value=num_centroids)
        mask = padded_codes != num_centroids
        # `num_docs x max_num_query_vecs x max_num_codes_per_doc`
        padded_codes = padded_codes[:, None].expand(-1, num_query_vecs, -1)

        # apply pruning threshold
        if threshold is not None and threshold:
            expanded_centroid_scores = expanded_centroid_scores.masked_fill(
                expanded_centroid_scores.amax(1, keepdim=True) < threshold, 0
            )

        # NOTE this is ColBERT scoring, but instead of the doc embeddings we use the centroid scores
        # expanded_centroid_scores: `num_docs x max_num_query_vecs x num_centroids + 1`
        # padded_codes: `num_docs x max_num_query_vecs x max_num_codes_per_doc`
        # approx_similarity: `num_docs x max_num_query_vecs x max_num_codes_per_doc`
        approx_similarity = torch.gather(input=expanded_centroid_scores, dim=-1, index=padded_codes)
        approx_scores = self.module.model.aggregate_similarity(
            approx_similarity,
            query_scoring_mask=query_scoring_mask,
            doc_scoring_mask=mask[:, None],
            num_docs=doc_idcs.lengths,
        )
        packed_approx_scores = PackedTensor(approx_scores, lengths=doc_idcs.lengths)
        filtered_scores, filtered_doc_idcs = self._filter_and_sort(packed_approx_scores, doc_idcs, k)
        return filtered_scores, filtered_doc_idcs

    def _reconstruct_doc_embeddings(self, candidate_doc_idcs: PackedTensor) -> BiEncoderEmbedding:
        doc_embedding_codes = self.packed_codes.lookup(candidate_doc_idcs, 1)
        doc_embedding_residuals = self.packed_residuals.lookup(candidate_doc_idcs, 1)
        doc_embeddings = self.residual_codec.decompress(doc_embedding_codes, doc_embedding_residuals)
        padded_doc_embeddings = doc_embeddings.to_padded_tensor()
        doc_scoring_mask = padded_doc_embeddings[..., 0] != 0
        return BiEncoderEmbedding(padded_doc_embeddings, doc_scoring_mask, None)

    def search(self, output: BiEncoderOutput) -> Tuple[PackedTensor, List[List[str]]]:
        query_embeddings = output.query_embeddings
        if query_embeddings is None:
            raise ValueError("Expected query_embeddings in BiEncoderOutput")
        query_embeddings = query_embeddings.to(self.device)

        _, candidate_idcs = self._centroid_candidate_retrieval(query_embeddings)
        num_docs = candidate_idcs.lengths

        # compute exact scores on the decompressed document embeddings
        doc_embeddings = self._reconstruct_doc_embeddings(candidate_idcs)
        scores = self.module.model.score(query_embeddings, doc_embeddings, num_docs)

        scores, doc_idcs = self._filter_and_sort(PackedTensor(scores, lengths=candidate_idcs.lengths), candidate_idcs)
        doc_ids = [
            [self.doc_ids[doc_idx] for doc_idx in _doc_ids.tolist()] for _doc_ids in doc_idcs.split(doc_idcs.lengths)
        ]
        return scores, doc_ids


class PlaidSearchConfig(SearchConfig):

    search_class = PlaidSearcher
    SUPPORTED_MODELS = {ColConfig.model_type}

    def __init__(
        self,
        k: int,
        candidate_k: int = 256,
        n_cells: int = 1,
        centroid_score_threshold: float = 0.5,
    ) -> None:
        super().__init__(k)
        self.candidate_k = candidate_k
        self.n_cells = n_cells
        self.centroid_score_threshold = centroid_score_threshold
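
For orientation, a minimal usage sketch follows. It is not part of the module source: the import paths, model name, and index path are illustrative, and it assumes a PLAID index was built beforehand with the matching PlaidIndexConfig.

from lightning_ir import BiEncoderModule
from lightning_ir.retrieve.plaid import PlaidSearchConfig, PlaidSearcher

# load a ColBERT-style bi-encoder checkpoint (PlaidSearchConfig supports Col models)
module = BiEncoderModule(model_name_or_path="colbert-ir/colbertv2.0")

# k: final hits per query; candidate_k: candidates kept by the first filter step;
# n_cells: centroids probed per query vector; centroid_score_threshold: pruning cutoff
search_config = PlaidSearchConfig(k=10, candidate_k=256, n_cells=2, centroid_score_threshold=0.5)
searcher = PlaidSearcher("path/to/plaid-index", search_config, module)

# searcher.search(output) takes a BiEncoderOutput holding query embeddings and
# returns packed scores plus the ranked document ids for each query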