1"""Residual Codec for Plaid Indexing and Retrieval"""
2
3from __future__ import annotations
4
5import pathlib
6from itertools import product
7from pathlib import Path
8from typing import TYPE_CHECKING, Tuple
9
10import numpy as np
11import torch
12from torch.utils.cpp_extension import load
13
14from ..base.packed_tensor import PackedTensor
15
16if TYPE_CHECKING:
17 from .plaid_indexer import PlaidIndexConfig
18
19
class ResidualCodec:
    """Residual Codec for Plaid, a residual-based search method for efficient retrieval."""

    def __init__(
        self,
        index_config: PlaidIndexConfig,
        centroids: torch.Tensor,
        bucket_cutoffs: torch.Tensor,
        bucket_weights: torch.Tensor,
        verbose: bool = False,
    ) -> None:
        """Initialize the ResidualCodec.

        Args:
            index_config (PlaidIndexConfig): Configuration for the Plaid indexer.
            centroids (torch.Tensor): The centroids used for indexing.
            bucket_cutoffs (torch.Tensor): The cutoffs for the residual buckets.
            bucket_weights (torch.Tensor): The weights for the residual buckets.
            verbose (bool): Whether to print verbose output. Defaults to False.
        """
        self.index_config = index_config
        self.verbose = verbose

        self.centroids = centroids
        self.bucket_cutoffs = bucket_cutoffs
        self.bucket_weights = bucket_weights

        self.arange_bits = torch.arange(0, self.index_config.n_bits, dtype=torch.uint8, device=self.centroids.device)
        self.reversed_bit_map = self._compute_reverse_bit_map()
        keys_per_byte = 8 // self.index_config.n_bits
        self.decompression_lookup_table = torch.tensor(
            list(product(list(range(len(self.bucket_weights))), repeat=keys_per_byte)),
            device=self.centroids.device,
            dtype=torch.uint8,
        )

        self.residual_dim = max(1, centroids.shape[-1] // 8 * index_config.n_bits)

        self._packbits_cpp = None
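
        # Worked size example (illustrative, not part of the original code):
        # with n_bits = 2 and 128-dimensional centroids there are 2**2 = 4
        # bucket weights and keys_per_byte = 8 // 2 = 4, so the decompression
        # lookup table enumerates all 4**4 = 256 byte values as rows of 4
        # bucket ids, and residual_dim = 128 // 8 * 2 = 32 packed bytes per
        # embedding.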

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(dim={self.dim}, num_centroids={self.num_centroids})"

    def __str__(self) -> str:
        return self.__repr__()

    @property
    def dim(self) -> int:
        """Get the dimensionality of the centroids."""
        return self.centroids.shape[-1]

    @property
    def num_centroids(self) -> int:
        """Get the number of centroids."""
        return self.centroids.shape[0]

    @classmethod
    def train(
        cls, index_config: PlaidIndexConfig, train_embeddings: torch.Tensor, verbose: bool = False
    ) -> "ResidualCodec":
        """Train the ResidualCodec using the provided training embeddings.

        Args:
            index_config (PlaidIndexConfig): Configuration for the Plaid indexer.
            train_embeddings (torch.Tensor): The embeddings to use for training the codec.
            verbose (bool): Whether to print verbose output. Defaults to False.

        Returns:
            ResidualCodec: An instance of the ResidualCodec trained on the provided embeddings.
        """
        train_embeddings = train_embeddings[torch.randperm(train_embeddings.shape[0])]
        num_hold_out_embeddings = int(min(0.05 * train_embeddings.shape[0], 2**15))
        train_embeddings, holdout_embeddings = train_embeddings.split(
            [train_embeddings.shape[0] - num_hold_out_embeddings, num_hold_out_embeddings]
        )

        centroids = cls._train_kmeans(train_embeddings, index_config, verbose)
        bucket_cutoffs, bucket_weights = cls._compute_buckets(centroids, holdout_embeddings, index_config)

        return cls(index_config, centroids, bucket_cutoffs, bucket_weights, verbose)
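
        # Usage sketch (illustrative; `config` and `embeddings` are
        # placeholder names):
        #
        #     codec = ResidualCodec.train(config, embeddings, verbose=True)
        #
        # where `config` is a PlaidIndexConfig and `embeddings` is a float
        # tensor of shape (num_embeddings, dim). A holdout of
        # min(5% of the embeddings, 2**15) vectors is used to estimate the
        # residual buckets; the remaining vectors train the k-means centroids.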

    @staticmethod
    def _train_kmeans(embeddings: torch.Tensor, index_config: PlaidIndexConfig, verbose: bool = False) -> torch.Tensor:
        import faiss

        kmeans = faiss.Kmeans(
            embeddings.shape[-1],
            index_config.num_centroids,
            niter=index_config.k_means_iters,
            gpu=torch.cuda.is_available(),
            verbose=verbose,
            seed=index_config.seed,
        )
        # TODO why normalize?
        kmeans.train(embeddings.numpy())
        return torch.nn.functional.normalize(torch.from_numpy(kmeans.centroids), dim=-1)
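
        # Note (assumption, not stated in the original code): normalizing the
        # centroids presumably makes the dot-product argmax in
        # _compress_into_codes behave like a cosine-similarity assignment for
        # (near-)unit-norm embeddings.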

    def _packbits(self, residuals: torch.Tensor) -> torch.Tensor:
        # Compare the device type; comparing against torch.device("cuda") would
        # not match indexed devices such as "cuda:0".
        if residuals.device.type == "cuda":
            raise NotImplementedError("CUDA not supported for packbits")
        residuals_packed = torch.from_numpy(np.packbits(np.asarray(residuals.contiguous().flatten())))
        return residuals_packed

    @staticmethod
    def _compute_buckets(
        centroids: torch.Tensor, holdout_embeddings: torch.Tensor, index_config: PlaidIndexConfig
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        holdout_embeddings_codes = ResidualCodec._compress_into_codes(centroids, holdout_embeddings)
        holdout_embeddings_centroids = centroids[holdout_embeddings_codes]

        holdout_residual = holdout_embeddings - holdout_embeddings_centroids
        avg_residual = holdout_residual.abs().mean(dim=0)

        num_options = 2**index_config.n_bits
        quantiles = torch.arange(0, num_options, device=avg_residual.device) * (1 / num_options)
        bucket_cutoffs_quantiles, bucket_weights_quantiles = quantiles[1:], quantiles + (0.5 / num_options)

        bucket_cutoffs = holdout_residual.float().quantile(bucket_cutoffs_quantiles)
        bucket_weights = holdout_residual.float().quantile(bucket_weights_quantiles)
        return bucket_cutoffs, bucket_weights
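
        # Worked example (illustrative): for n_bits = 2, num_options = 4 and
        # quantiles = [0.00, 0.25, 0.50, 0.75], so bucket_cutoffs are the
        # 0.25 / 0.50 / 0.75 quantiles of the holdout residuals and
        # bucket_weights are the 0.125 / 0.375 / 0.625 / 0.875 quantiles,
        # i.e. representative values near the middle of each bucket.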

    def _compute_reverse_bit_map(self) -> torch.Tensor:
        # We reverse the residual bits because arange_bits as
        # currently constructed produces results with the reverse
        # of the expected endianness

        reversed_bit_map = []
        mask = (1 << self.index_config.n_bits) - 1
        for i in range(256):
            # The reversed byte
            z = 0
            for j in range(8, 0, -self.index_config.n_bits):
                # Extract a subsequence of length n bits
                x = (i >> (j - self.index_config.n_bits)) & mask

                # Reverse the endianness of each bit subsequence (e.g. 10 -> 01)
                y = 0
                for k in range(self.index_config.n_bits - 1, -1, -1):
                    y += ((x >> (self.index_config.n_bits - k - 1)) & 1) * (2**k)

                # Set the corresponding bits in the output byte
                z |= y
                if j > self.index_config.n_bits:
                    z <<= self.index_config.n_bits
            reversed_bit_map.append(z)
        return torch.tensor(reversed_bit_map, dtype=torch.uint8, device=self.centroids.device)
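
        # Worked example (illustrative): for n_bits = 2 the byte 0b00011011
        # holds the 2-bit keys 00, 01, 10, 11; reversing each key's bits in
        # place gives 0b00100111, i.e.
        # reversed_bit_map[0b00011011] == 0b00100111.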
    @classmethod
    def try_load_torch_extensions(cls, use_gpu: bool) -> None:
        """Load the necessary C++ extensions for the ResidualCodec.

        Does nothing if the extensions are already loaded or if use_gpu is False.

        Args:
            use_gpu (bool): Whether to use the GPU for the extensions.
        """
        if hasattr(cls, "loaded_extensions") or not use_gpu:
            return

        decompress_residuals_cpp = load(
            name="decompress_residuals_cpp",
            sources=[
                str(pathlib.Path(__file__).parent.resolve() / "csrc" / "decompress_residuals.cpp"),
                str(pathlib.Path(__file__).parent.resolve() / "csrc" / "decompress_residuals.cu"),
            ],
        )
        cls.decompress_residuals = decompress_residuals_cpp.decompress_residuals_cpp

        cls.loaded_extensions = True
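
        # The extension is JIT-compiled from the .cpp/.cu sources in csrc on
        # first use and cached on the class via `loaded_extensions`, so
        # repeated calls (and CPU-only runs) return immediately.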
    @classmethod
    def from_pretrained(
        cls, index_config: PlaidIndexConfig, index_dir: Path, device: torch.device | None = None
    ) -> "ResidualCodec":
        """Load a ResidualCodec from the specified directory.

        Args:
            index_config (PlaidIndexConfig): Configuration for the Plaid indexer.
            index_dir (Path): Directory containing the saved codec files.
            device (torch.device | None): Device to load the codec onto. Defaults to None, which uses the CPU.

        Returns:
            ResidualCodec: An instance of the ResidualCodec loaded from the specified directory.
        """
        centroids_path = index_dir / "centroids.pt"
        buckets_path = index_dir / "buckets.pt"

        centroids = torch.load(
            centroids_path, map_location=str(device) if device is not None else "cpu", weights_only=True
        )
        bucket_cutoffs, bucket_weights = torch.load(
            buckets_path, map_location=str(device) if device is not None else "cpu", weights_only=True
        )

        return cls(
            index_config=index_config,
            centroids=centroids,
            bucket_cutoffs=bucket_cutoffs,
            bucket_weights=bucket_weights,
        )

    def save(self, index_dir: Path) -> None:
        """Save the ResidualCodec to the specified directory.

        Args:
            index_dir (Path): Directory to save the codec files.
        """
        index_dir.mkdir(parents=True, exist_ok=True)
        centroids_path = index_dir / "centroids.pt"
        buckets_path = index_dir / "buckets.pt"

        torch.save(self.centroids.half(), centroids_path)
        torch.save((self.bucket_cutoffs, self.bucket_weights), buckets_path)
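
        # Round-trip sketch (illustrative; the path is a placeholder):
        #
        #     codec.save(Path("plaid_index"))
        #     codec = ResidualCodec.from_pretrained(config, Path("plaid_index"))
        #
        # Note that the centroids are written in half precision, so reloaded
        # centroids may differ slightly from the in-memory ones.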
    @staticmethod
    def _compress_into_codes(centroids: torch.Tensor, embeddings: torch.Tensor) -> torch.Tensor:
        codes = []
        batch_size = 2**29 // centroids.shape[0]
        for batch in embeddings.split(batch_size):
            indices = (centroids @ batch.transpose(-1, -2)).argmax(dim=0)
            codes.append(indices)
        return torch.cat(codes)
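
        # Each embedding is assigned to its highest-scoring centroid by inner
        # product; the batch size caps the centroids-by-batch score matrix at
        # roughly 2**29 elements to bound peak memory.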
    def compress_into_codes(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Compress embeddings into codes using the centroids.

        Args:
            embeddings (torch.Tensor): The embeddings to compress.

        Returns:
            torch.Tensor: The compressed codes.
        """
        return self._compress_into_codes(self.centroids, embeddings)

    def compress(self, embeddings: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compress embeddings into codes and residuals.

        Args:
            embeddings (torch.Tensor): The embeddings to compress.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing the compressed codes and residuals.
        """
        embeddings = embeddings.to(self.centroids.device)
        codes = self.compress_into_codes(embeddings)
        centroids = self.centroids[codes]
        residuals = self.binarize(embeddings - centroids)
        return codes, residuals

    def binarize(self, residuals: torch.Tensor) -> torch.Tensor:
        """Binarize the residuals using the bucket cutoffs and weights.

        Args:
            residuals (torch.Tensor): The residuals to binarize.

        Returns:
            torch.Tensor: The binarized residuals.
        """
        buckets = torch.bucketize(residuals.float(), self.bucket_cutoffs).to(dtype=torch.uint8)
        buckets_expanded = buckets.unsqueeze(-1).expand(*buckets.size(), self.index_config.n_bits)
        bucket_bits = buckets_expanded >> self.arange_bits  # divide by 2^bit for each bit position
        bucket_binary = bucket_bits & 1  # apply mod 2 to binarize

        residuals_packed = self._packbits(bucket_binary)
        residuals_packed = residuals_packed.reshape(residuals.size(0), max(1, self.dim // 8 * self.index_config.n_bits))

        return residuals_packed
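
        # Packing sketch (illustrative): with n_bits = 2 each bucket id in
        # {0, 1, 2, 3} contributes 2 bits, so a 128-dimensional residual packs
        # into 128 // 8 * 2 = 32 bytes; with n_bits = 1 it packs into 16 bytes.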
    def decompress(self, codes: PackedTensor, compressed_residuals: PackedTensor) -> PackedTensor:
        """Decompress the codes and residuals into embeddings.

        Args:
            codes (PackedTensor): The packed tensor containing the codes.
            compressed_residuals (PackedTensor): The packed tensor containing the compressed residuals.

        Returns:
            PackedTensor: The decompressed embeddings.
        """
        centroids = self.centroids[codes]
        residuals = self.reversed_bit_map[compressed_residuals.long().view(-1)].view_as(compressed_residuals)
        residuals = self.decompression_lookup_table[residuals.long()]
        residuals = residuals.view(residuals.shape[0], -1)
        residuals = self.bucket_weights[residuals.long()]
        embeddings = centroids + residuals
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
        return PackedTensor(embeddings, lengths=codes.lengths)
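

# End-to-end sketch (illustrative; `codec`, `doc_embeddings` and `doc_lengths`
# are placeholder names): compress document embeddings into centroid codes plus
# packed residuals, then reconstruct approximate embeddings.
#
#     codes, residuals = codec.compress(doc_embeddings)
#     packed_codes = PackedTensor(codes, lengths=doc_lengths)
#     packed_residuals = PackedTensor(residuals, lengths=doc_lengths)
#     approximate_embeddings = codec.decompress(packed_codes, packed_residuals)
#
# Decompression adds each quantized residual back onto its centroid and
# re-normalizes, so the result is a lossy approximation of `doc_embeddings`.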