1"""Plaid Searcher for Lightning IR Framework"""
2
3from __future__ import annotations
4
5from pathlib import Path
6from typing import TYPE_CHECKING, List, Tuple
7
8import torch
9
10from ...bi_encoder.bi_encoder_model import BiEncoderEmbedding
11from ...models import ColConfig
12from ..base.packed_tensor import PackedTensor
13from ..base.searcher import SearchConfig, Searcher
14from .plaid_indexer import PlaidIndexConfig
15from .residual_codec import ResidualCodec
16
17if TYPE_CHECKING:
18 from ...bi_encoder import BiEncoderModule, BiEncoderOutput
19
20
[docs]
class PlaidSearcher(Searcher):
    """Searcher that retrieves and scores documents from a Plaid (residual-compressed) index."""

    def __init__(
        self, index_dir: Path | str, search_config: PlaidSearchConfig, module: BiEncoderModule, use_gpu: bool = False
    ) -> None:
        """Load the Plaid index from disk and build the inverted lookups used during search.

        Args:
            index_dir (Path | str): Directory where the Plaid index is stored.
            search_config (PlaidSearchConfig): Configuration for the Plaid searcher.
            module (BiEncoderModule): The BiEncoder module used for searching.
            use_gpu (bool): Whether to use GPU for searching. Defaults to False.
        """
        super().__init__(index_dir, search_config, module, use_gpu)
        index_config = PlaidIndexConfig.from_pretrained(self.index_dir)
        self.residual_codec = ResidualCodec.from_pretrained(index_config, self.index_dir, device=self.device)

        # one centroid code per stored embedding; residuals are reshaped to one row per code
        self.codes = torch.load(self.index_dir / "codes.pt", weights_only=True).to(self.device)
        stored_residuals = torch.load(self.index_dir / "residuals.pt", weights_only=True)
        self.residuals = stored_residuals.view(self.codes.shape[0], -1).to(self.device)
        # the same data, grouped per document via the document lengths
        self.packed_codes = PackedTensor(self.codes, lengths=self.doc_lengths.tolist()).to(self.device)
        self.packed_residuals = PackedTensor(self.residuals, lengths=self.doc_lengths.tolist()).to(self.device)

        num_centroids = self.residual_codec.num_centroids

        # code -> embedding positions: sorting the codes groups embedding indices by centroid
        codes_in_order, embedding_order = self.codes.sort()
        embeddings_per_code = torch.bincount(codes_in_order, minlength=num_centroids).tolist()

        # code -> documents: map every embedding to its document, then deduplicate within each centroid
        doc_of_embedding = torch.arange(self.num_docs, device=self.device).repeat_interleave(self.doc_lengths)
        doc_of_sorted_embedding = doc_of_embedding[embedding_order]
        docs_per_code = [chunk.unique() for chunk in doc_of_sorted_embedding.split(embeddings_per_code)]
        docs_per_code_lengths = [docs.shape[0] for docs in docs_per_code]
        self.code_to_doc_ivf = PackedTensor(torch.cat(docs_per_code), lengths=docs_per_code_lengths)

        # document -> codes: invert the mapping above by sorting its entries by document index
        docs_in_order, entry_order = torch.sort(self.code_to_doc_ivf)
        code_of_entry = torch.arange(num_centroids, device=self.device).repeat_interleave(
            torch.tensor(self.code_to_doc_ivf.lengths, device=self.device)
        )
        codes_by_doc = code_of_entry[entry_order]
        codes_per_doc = torch.bincount(docs_in_order, minlength=self.num_docs)
        self.doc_to_code_ivf = PackedTensor(codes_by_doc, lengths=codes_per_doc.tolist())

        self.search_config: PlaidSearchConfig

    def _centroid_candidate_retrieval(self, query_embeddings: BiEncoderEmbedding) -> Tuple[PackedTensor, PackedTensor]:
        """Select candidate documents per query using only query-to-centroid scores.

        Args:
            query_embeddings (BiEncoderEmbedding): Query embeddings together with their scoring mask.
        Returns:
            Tuple[PackedTensor, PackedTensor]: Approximate scores and document indices of the candidates that
                survive both filter passes.
        """
        n_cells = self.search_config.n_cells
        centroids = self.residual_codec.centroids

        # `num_queries x query_length x num_centroids`
        similarity = query_embeddings.embeddings.to(centroids) @ centroids.transpose(0, 1)[None]
        query_scoring_mask = query_embeddings.scoring_mask
        # query vectors excluded by the scoring mask contribute a score of 0 for every centroid
        centroid_scores = similarity.to(self.device).masked_fill(~query_scoring_mask[..., None], 0)

        # the `n_cells` best centroids of every query vector; only vectors inside the scoring mask are kept
        _, top_codes = torch.topk(centroid_scores, n_cells, dim=-1, sorted=False)
        packed_codes = top_codes[query_scoring_mask].view(-1)
        code_lengths = (query_scoring_mask.sum(-1) * n_cells).tolist()

        # documents listed under any of the selected centroids, per query
        packed_doc_idcs = self.code_to_doc_ivf.lookup(packed_codes, code_lengths, unique=True)

        # Two passes over the approximate (centroid-only) scores: the first applies the configured centroid
        # score threshold and keeps `candidate_k` documents, the second re-scores the survivors without a
        # threshold and keeps `candidate_k // 4`.
        # NOTE(review): this looks like the staged candidate filtering of the PLAID paper -- confirm.
        _, first_pass_doc_idcs = self._filter_candidates(
            centroid_scores=centroid_scores,
            doc_idcs=packed_doc_idcs,
            threshold=self.search_config.centroid_score_threshold,
            k=self.search_config.candidate_k,
            query_scoring_mask=query_scoring_mask,
        )
        return self._filter_candidates(
            centroid_scores=centroid_scores,
            doc_idcs=first_pass_doc_idcs,
            threshold=None,
            k=self.search_config.candidate_k // 4,
            query_scoring_mask=query_scoring_mask,
        )

    def _filter_candidates(
        self,
        centroid_scores: torch.Tensor,
        doc_idcs: PackedTensor,
        threshold: float | None,
        k: int,
        query_scoring_mask: torch.Tensor,
    ) -> Tuple[PackedTensor, PackedTensor]:
        """Score candidate documents from centroid scores alone and keep the best ones per query.

        Args:
            centroid_scores (torch.Tensor): Query-to-centroid scores, `num_queries x query_length x num_centroids`.
            doc_idcs (PackedTensor): Candidate document indices, packed per query.
            threshold (float | None): Pruning threshold for centroid scores; pruning is skipped when this is
                None or otherwise falsy (e.g. 0).
            k (int): Number of candidates passed on to `_filter_and_sort`.
            query_scoring_mask (torch.Tensor): Scoring mask of the query embeddings.
        Returns:
            Tuple[PackedTensor, PackedTensor]: The filtered approximate scores and document indices.
        """
        num_query_vecs = centroid_scores.shape[1]
        num_centroids = centroid_scores.shape[-1]

        # repeat each query's centroid scores once per candidate document and append one zero column
        # `num_docs x num_query_vecs x num_centroids + 1`
        # NOTE padded code slots use index `num_centroids`, so they gather from the appended zero column
        docs_per_query = torch.tensor(doc_idcs.lengths, device=self.device)
        doc_centroid_scores = torch.nn.functional.pad(centroid_scores.repeat_interleave(docs_per_query, dim=0), (0, 1))

        # centroid codes of every candidate document
        # `num_docs x max_num_codes_per_doc`
        doc_codes = self.doc_to_code_ivf.lookup(doc_idcs, 1).to_padded_tensor(pad_value=num_centroids)
        code_mask = doc_codes != num_centroids
        # `num_docs x num_query_vecs x max_num_codes_per_doc`
        gather_index = doc_codes[:, None].expand(-1, num_query_vecs, -1)

        if threshold:
            # zero out centroids whose best score over all query vectors stays below the threshold
            below_threshold = doc_centroid_scores.amax(1, keepdim=True) < threshold
            doc_centroid_scores = doc_centroid_scores.masked_fill(below_threshold, 0)

        # NOTE this is colbert-style scoring with centroid scores standing in for the document embeddings
        # approx_similarity: `num_docs x num_query_vecs x max_num_codes_per_doc`
        approx_similarity = torch.gather(input=doc_centroid_scores, dim=-1, index=gather_index)
        approx_scores = self.module.model.aggregate_similarity(
            approx_similarity,
            query_scoring_mask=query_scoring_mask,
            doc_scoring_mask=code_mask[:, None],
            num_docs=doc_idcs.lengths,
        )
        packed_scores = PackedTensor(approx_scores, lengths=doc_idcs.lengths)
        return self._filter_and_sort(packed_scores, doc_idcs, k)

    def _reconstruct_doc_embeddings(self, candidate_doc_idcs: PackedTensor) -> BiEncoderEmbedding:
        """Decompress the stored codes and residuals of the candidate documents into padded embeddings.

        Args:
            candidate_doc_idcs (PackedTensor): Document indices of the candidates, packed per query.
        Returns:
            BiEncoderEmbedding: Padded document embeddings together with their scoring mask.
        """
        candidate_codes = self.packed_codes.lookup(candidate_doc_idcs, 1)
        candidate_residuals = self.packed_residuals.lookup(candidate_doc_idcs, 1)
        decompressed = self.residual_codec.decompress(candidate_codes, candidate_residuals)
        padded_embeddings = decompressed.to_padded_tensor()
        # positions whose first component is exactly 0 are treated as padding
        # NOTE(review): assumes `to_padded_tensor` pads with zeros -- confirm
        scoring_mask = padded_embeddings[..., 0] != 0
        return BiEncoderEmbedding(padded_embeddings, scoring_mask, None)

    def search(self, output: BiEncoderOutput) -> Tuple[PackedTensor, List[List[str]]]:
        """Retrieve and score documents for the query embeddings in `output`.

        The reconstructed candidate document embeddings are stored on `output.doc_embeddings` before scoring.

        Args:
            output (BiEncoderOutput): The output from the BiEncoder module containing query embeddings.
        Returns:
            Tuple[PackedTensor, List[List[str]]]: A tuple containing the scores and the corresponding document IDs.
        Raises:
            ValueError: If the output does not contain query embeddings, or if scoring yields no scores.
        """
        if output.query_embeddings is None:
            raise ValueError("Expected query_embeddings in BiEncoderOutput")
        query_embeddings = output.query_embeddings.to(self.device)

        # cheap centroid-based candidate generation; the approximate scores are not needed afterwards
        _, candidate_idcs = self._centroid_candidate_retrieval(query_embeddings)
        num_docs = candidate_idcs.lengths

        # exact scoring on the decompressed candidate embeddings
        output.doc_embeddings = self._reconstruct_doc_embeddings(candidate_idcs)
        output = self.module.model.score(output, num_docs)
        if output.scores is None:
            raise ValueError("Expected scores in BiEncoderOutput")

        packed_scores = PackedTensor(output.scores, lengths=candidate_idcs.lengths)
        scores, doc_idcs = self._filter_and_sort(packed_scores, candidate_idcs)

        # translate document indices back into document ids, one list per query
        doc_ids = []
        for query_doc_idcs in doc_idcs.split(doc_idcs.lengths):
            doc_ids.append([self.doc_ids[doc_idx] for doc_idx in query_doc_idcs.tolist()])
        return scores, doc_ids
192
193
[docs]
class PlaidSearchConfig(SearchConfig):
    """Search-time settings consumed by `PlaidSearcher`."""

    search_class = PlaidSearcher
    SUPPORTED_MODELS = {ColConfig.model_type}

    def __init__(
        self,
        k: int,
        candidate_k: int = 256,
        n_cells: int = 1,
        centroid_score_threshold: float = 0.5,
    ) -> None:
        """Store the Plaid-specific search settings.

        Args:
            k (int): Number of top documents to retrieve.
            candidate_k (int): Number of candidate documents kept by the first candidate filter pass; the
                second pass keeps `candidate_k // 4`. Defaults to 256.
            n_cells (int): Number of best-scoring centroids (cells) selected per query vector. Defaults to 1.
            centroid_score_threshold (float): Pruning threshold on centroid scores, applied in the first
                candidate filter pass only. Defaults to 0.5.
        """
        super().__init__(k)
        self.n_cells = n_cells
        self.centroid_score_threshold = centroid_score_threshold
        self.candidate_k = candidate_k