from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, List, Tuple

import torch

from ...bi_encoder.bi_encoder_model import BiEncoderEmbedding
from ...models import ColConfig
from ..base.packed_tensor import PackedTensor
from ..base.searcher import SearchConfig, Searcher
from .plaid_indexer import PlaidIndexConfig
from .residual_codec import ResidualCodec

if TYPE_CHECKING:
    from ...bi_encoder import BiEncoderModule, BiEncoderOutput


class PlaidSearcher(Searcher):
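    """Searcher for indexes built with the PLAID indexer.

    Follows the staged PLAID retrieval pipeline (Santhanam et al., 2022): candidate
    generation via an inverted file mapping centroid codes to documents, progressive
    pruning of candidates using approximate centroid scores, and exact re-scoring on
    embeddings decompressed by the residual codec.
    """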
    def __init__(
        self, index_dir: Path | str, search_config: PlaidSearchConfig, module: BiEncoderModule, use_gpu: bool = False
    ) -> None:
        super().__init__(index_dir, search_config, module, use_gpu)
        self.residual_codec = ResidualCodec.from_pretrained(
            PlaidIndexConfig.from_pretrained(self.index_dir), self.index_dir, device=self.device
        )

        self.codes = torch.load(self.index_dir / "codes.pt", weights_only=True).to(self.device)
        self.residuals = (
            torch.load(self.index_dir / "residuals.pt", weights_only=True)
            .view(self.codes.shape[0], -1)
            .to(self.device)
        )
        self.packed_codes = PackedTensor(self.codes, lengths=self.doc_lengths.tolist()).to(self.device)
        self.packed_residuals = PackedTensor(self.residuals, lengths=self.doc_lengths.tolist()).to(self.device)

        # code_idx to embedding_idcs mapping
        sorted_codes, embedding_idcs = self.codes.sort()
        num_embeddings_per_code = torch.bincount(sorted_codes, minlength=self.residual_codec.num_centroids).tolist()
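        # For example (illustrative toy values, not from a real index): with
        # codes = [2, 0, 2] and 3 centroids, sorted_codes = [0, 2, 2],
        # embedding_idcs = [1, 0, 2], and num_embeddings_per_code = [1, 0, 2]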

        # code_idx to doc_idcs mapping
        embedding_idx_to_doc_idx = torch.arange(self.num_docs, device=self.device).repeat_interleave(self.doc_lengths)
        full_doc_ivf = embedding_idx_to_doc_idx[embedding_idcs]
        doc_ivf_lengths = []
        unique_doc_idcs = []
        for doc_idcs in full_doc_ivf.split(num_embeddings_per_code):
            unique_doc_idcs.append(doc_idcs.unique())
            doc_ivf_lengths.append(unique_doc_idcs[-1].shape[0])
        self.code_to_doc_ivf = PackedTensor(torch.cat(unique_doc_idcs), lengths=doc_ivf_lengths)
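        # Continuing the toy example with doc_lengths = [2, 1] (embeddings 0-1 belong to
        # doc 0, embedding 2 to doc 1): full_doc_ivf = [0, 0, 1], the per-code splits are
        # code 0 -> [0], code 1 -> [], code 2 -> [0, 1], so code_to_doc_ivf packs
        # [0, 0, 1] with lengths [1, 0, 2]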

        # doc_idx to code_idcs mapping
        sorted_doc_idcs, doc_idx_to_code_idx = torch.sort(self.code_to_doc_ivf)
        code_idcs = torch.arange(self.residual_codec.num_centroids, device=self.device).repeat_interleave(
            torch.tensor(self.code_to_doc_ivf.lengths, device=self.device)
        )[doc_idx_to_code_idx]
        num_codes_per_doc = torch.bincount(sorted_doc_idcs, minlength=self.num_docs)
        self.doc_to_code_ivf = PackedTensor(code_idcs, lengths=num_codes_per_doc.tolist())
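        # In the toy example this inverts the mapping: doc 0 -> codes [0, 2] and
        # doc 1 -> codes [2], so doc_to_code_ivf packs [0, 2, 2] with lengths [2, 1]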

        self.search_config: PlaidSearchConfig

    def _centroid_candidate_retrieval(self, query_embeddings: BiEncoderEmbedding) -> Tuple[PackedTensor, PackedTensor]:
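        """Generate candidate documents from the `n_cells` nearest centroid cells of each
        query vector and prune them in two steps using approximate centroid scores.
        """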
        # grab the top `n_cells` neighbor cells for all query embeddings
        # `num_queries x query_length x num_centroids`
        centroid_scores = (
            query_embeddings.embeddings.to(self.residual_codec.centroids)
            @ self.residual_codec.centroids.transpose(0, 1)[None]
        ).to(self.device)
        query_scoring_mask = query_embeddings.scoring_mask
        centroid_scores = centroid_scores.masked_fill(~query_scoring_mask[..., None], 0)
        _, codes = torch.topk(centroid_scores, self.search_config.n_cells, dim=-1, sorted=False)
        packed_codes = codes[query_scoring_mask].view(-1)
        code_lengths = (query_scoring_mask.sum(-1) * self.search_config.n_cells).tolist()

        # grab document idcs for all cells
        packed_doc_idcs = self.code_to_doc_ivf.lookup(packed_codes, code_lengths, unique=True)

        # NOTE the two filter steps follow PLAID's staged candidate pruning: the first
        # prunes cheaply with the centroid score threshold, the second re-scores the
        # survivors without a threshold and keeps only the top `candidate_k // 4`
        # filter step 1
        _, filtered_doc_idcs = self._filter_candidates(
            centroid_scores=centroid_scores,
            doc_idcs=packed_doc_idcs,
            threshold=self.search_config.centroid_score_threshold,
            k=self.search_config.candidate_k,
            query_scoring_mask=query_scoring_mask,
        )
        # filter step 2
        filtered_scores, filtered_doc_idcs = self._filter_candidates(
            centroid_scores=centroid_scores,
            doc_idcs=filtered_doc_idcs,
            threshold=None,
            k=self.search_config.candidate_k // 4,
            query_scoring_mask=query_scoring_mask,
        )
        return filtered_scores, filtered_doc_idcs

    def _filter_candidates(
        self,
        centroid_scores: torch.Tensor,
        doc_idcs: PackedTensor,
        threshold: float | None,
        k: int,
        query_scoring_mask: torch.Tensor,
    ) -> Tuple[PackedTensor, PackedTensor]:
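        """Approximately score the candidate documents using only their centroid scores and
        keep the top `k` per query. If `threshold` is set, centroids whose maximum score
        over all query vectors falls below it are zeroed out before aggregation.
        """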
        num_query_vecs = centroid_scores.shape[1]
        num_centroids = centroid_scores.shape[-1]

        # repeat query centroid scores for each document
        # `num_docs x num_query_vecs x num_centroids + 1`
        # NOTE we append a zero column so that codes padded with `num_centroids` gather a score of 0
        expanded_centroid_scores = torch.nn.functional.pad(
            centroid_scores.repeat_interleave(torch.tensor(doc_idcs.lengths, device=self.device), dim=0), (0, 1)
        )

        # grab codes for each document
        code_idcs = self.doc_to_code_ivf.lookup(doc_idcs, 1)
        # `num_docs x max_num_codes_per_doc`
        padded_codes = code_idcs.to_padded_tensor(pad_value=num_centroids)
        mask = padded_codes != num_centroids
        # `num_docs x num_query_vecs x max_num_codes_per_doc`
        padded_codes = padded_codes[:, None].expand(-1, num_query_vecs, -1)

        # apply the pruning threshold: zero out centroids whose maximum score over all
        # query vectors falls below it (skipped when threshold is None or 0)
        if threshold:
            expanded_centroid_scores = expanded_centroid_scores.masked_fill(
                expanded_centroid_scores.amax(1, keepdim=True) < threshold, 0
            )

        # NOTE this is ColBERT scoring, but instead of using the doc embeddings we gather precomputed centroid scores
        # expanded_centroid_scores: `num_docs x num_query_vecs x num_centroids + 1`
        # padded_codes: `num_docs x num_query_vecs x max_num_codes_per_doc`
        # approx_similarity: `num_docs x num_query_vecs x max_num_codes_per_doc`
        approx_similarity = torch.gather(input=expanded_centroid_scores, dim=-1, index=padded_codes)
        approx_scores = self.module.model.aggregate_similarity(
            approx_similarity,
            query_scoring_mask=query_scoring_mask,
            doc_scoring_mask=mask[:, None],
            num_docs=doc_idcs.lengths,
        )
        packed_approx_scores = PackedTensor(approx_scores, lengths=doc_idcs.lengths)
        filtered_scores, filtered_doc_idcs = self._filter_and_sort(packed_approx_scores, doc_idcs, k)
        return filtered_scores, filtered_doc_idcs

    def _reconstruct_doc_embeddings(self, candidate_doc_idcs: PackedTensor) -> BiEncoderEmbedding:
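        """Decompress the candidate documents' quantized embeddings (codes plus residuals)
        back into dense embeddings via the residual codec.
        """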
        doc_embedding_codes = self.packed_codes.lookup(candidate_doc_idcs, 1)
        doc_embedding_residuals = self.packed_residuals.lookup(candidate_doc_idcs, 1)
        doc_embeddings = self.residual_codec.decompress(doc_embedding_codes, doc_embedding_residuals)
        padded_doc_embeddings = doc_embeddings.to_padded_tensor()
        doc_scoring_mask = padded_doc_embeddings[..., 0] != 0
        return BiEncoderEmbedding(padded_doc_embeddings, doc_scoring_mask, None)

    def search(self, output: BiEncoderOutput) -> Tuple[PackedTensor, List[List[str]]]:
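        """Retrieve candidates via centroid scoring, re-score them exactly on decompressed
        embeddings, and return the top-k scores and document ids per query.
        """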
        query_embeddings = output.query_embeddings
        if query_embeddings is None:
            raise ValueError("Expected query_embeddings in BiEncoderOutput")
        query_embeddings = query_embeddings.to(self.device)

        _, candidate_idcs = self._centroid_candidate_retrieval(query_embeddings)
        num_docs = candidate_idcs.lengths

        # compute exact scores on the decompressed candidate embeddings
        doc_embeddings = self._reconstruct_doc_embeddings(candidate_idcs)
        scores = self.module.model.score(query_embeddings, doc_embeddings, num_docs)

        scores, doc_idcs = self._filter_and_sort(PackedTensor(scores, lengths=candidate_idcs.lengths), candidate_idcs)
        doc_ids = [
            [self.doc_ids[doc_idx] for doc_idx in _doc_ids.tolist()] for _doc_ids in doc_idcs.split(doc_idcs.lengths)
        ]
        return scores, doc_ids


class PlaidSearchConfig(SearchConfig):
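    """Configuration for :class:`PlaidSearcher`."""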

    search_class = PlaidSearcher
    SUPPORTED_MODELS = {ColConfig.model_type}
173
    def __init__(
        self,
        k: int,
        candidate_k: int = 256,
        n_cells: int = 1,
        centroid_score_threshold: float = 0.5,
    ) -> None:
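        """Initialize a PLAID search configuration.

        :param k: number of documents to return per query
        :param candidate_k: number of candidate documents kept by the first centroid filter
            step (the second step keeps candidate_k // 4)
        :param n_cells: number of nearest centroid cells probed per query vector
        :param centroid_score_threshold: pruning threshold applied in the first filter step
        """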
        super().__init__(k)
        self.candidate_k = candidate_k
        self.n_cells = n_cells
        self.centroid_score_threshold = centroid_score_threshold
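

# NOTE a minimal usage sketch (illustrative, not part of the module: the index path,
# the `module` object, and the query batch are assumptions; the index must have been
# built with the PLAID indexer for a Col-style bi-encoder):
#
#     searcher = PlaidSearcher(
#         index_dir="path/to/plaid_index",
#         search_config=PlaidSearchConfig(k=10, candidate_k=256, n_cells=1),
#         module=module,  # a BiEncoderModule wrapping a Col model
#     )
#     output = module.forward(query_batch)  # BiEncoderOutput with query_embeddings
#     scores, doc_ids = searcher.search(output)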