1from __future__ import annotations
2
3import pathlib
4from itertools import product
5from pathlib import Path
6from typing import TYPE_CHECKING, Tuple
7
8import numpy as np
9import torch
10from torch.utils.cpp_extension import load
11
12from ..base.packed_tensor import PackedTensor
13
14if TYPE_CHECKING:
15 from .plaid_indexer import PlaidIndexConfig
16
17
[docs]
class ResidualCodec:
    """Compress embeddings into centroid codes plus bit-packed quantized residuals.

    ``compress`` assigns every embedding to the centroid with the highest dot
    product and stores the difference to that centroid as one bucket id of
    ``n_bits`` bits per dimension, packed into ``uint8`` bytes. ``decompress``
    undoes this by adding the bucket weights back onto the selected centroids.
    """

    def __init__(
        self,
        index_config: "PlaidIndexConfig",
        centroids: torch.Tensor,
        bucket_cutoffs: torch.Tensor,
        bucket_weights: torch.Tensor,
        verbose: bool = False,
    ) -> None:
        """Store the codec parameters and precompute the lookup tables.

        Args:
            index_config: Index configuration; only ``n_bits`` is read here.
            centroids: Centroid matrix. The first axis indexes centroids, the
                last axis is the embedding dimension.
            bucket_cutoffs: Boundaries handed to ``torch.bucketize`` when
                residuals are quantized.
            bucket_weights: Value substituted for each bucket id when
                decompressing.
            verbose: Stored on the instance.

        Raises:
            ValueError: If ``n_bits`` is not 1, 2, 4 or 8. Bucket ids have to
                tile a byte exactly for the packing tables to be well defined.
        """
        n_bits = index_config.n_bits
        if n_bits not in (1, 2, 4, 8):
            raise ValueError(f"n_bits must be 1, 2, 4 or 8 so that bucket ids tile a byte, got {n_bits}")

        self.index_config = index_config
        self.verbose = verbose

        self.centroids = centroids
        self.bucket_cutoffs = bucket_cutoffs
        self.bucket_weights = bucket_weights

        device = self.centroids.device
        # Bit positions (least significant first) used by ``binarize``.
        self.arange_bits = torch.arange(0, n_bits, dtype=torch.uint8, device=device)
        self.reversed_bit_map = self._compute_reverse_bit_map()
        keys_per_byte = 8 // n_bits
        # Row ``b`` holds the ``keys_per_byte`` bucket ids encoded by byte value
        # ``b``. NOTE(review): this assumes len(bucket_weights) == 2 ** n_bits.
        self.decompression_lookup_table = torch.tensor(
            list(product(range(len(self.bucket_weights)), repeat=keys_per_byte)),
            device=device,
            dtype=torch.uint8,
        )

        # Number of packed bytes stored per embedding (at least one).
        self.residual_dim = max(1, centroids.shape[-1] // 8 * n_bits)

        # Not read by any method in this class.
        self._packbits_cpp = None

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(dim={self.dim}, num_centroids={self.num_centroids})"

    def __str__(self) -> str:
        return self.__repr__()

    @property
    def dim(self) -> int:
        """Embedding dimension (size of the last centroid axis)."""
        return self.centroids.shape[-1]

    @property
    def num_centroids(self) -> int:
        """Number of centroids (size of the first centroid axis)."""
        return self.centroids.shape[0]

    @classmethod
    def train(
        cls, index_config: "PlaidIndexConfig", train_embeddings: torch.Tensor, verbose: bool = False
    ) -> "ResidualCodec":
        """Fit centroids and residual buckets on ``train_embeddings``.

        The embeddings are shuffled and 5% of them (capped at 2**15) are held
        out; k-means runs on the remainder and the bucket cutoffs/weights are
        estimated from the residuals of the held-out part.

        Args:
            index_config: Index configuration forwarded to k-means and the
                bucket estimation.
            train_embeddings: Embeddings to train on, one per row.
            verbose: Forwarded to k-means and to the returned codec.

        Returns:
            A codec built from the trained centroids and buckets.
        """
        num_embeddings = train_embeddings.shape[0]
        train_embeddings = train_embeddings[torch.randperm(num_embeddings)]
        num_hold_out_embeddings = int(min(0.05 * num_embeddings, 2**15))
        if num_hold_out_embeddings == 0 and num_embeddings > 1:
            # With fewer than 20 embeddings the 5% share rounds down to zero
            # and the quantile estimation would fail on an empty tensor.
            num_hold_out_embeddings = 1
        train_embeddings, holdout_embeddings = train_embeddings.split(
            [num_embeddings - num_hold_out_embeddings, num_hold_out_embeddings]
        )

        centroids = cls._train_kmeans(train_embeddings, index_config, verbose)
        bucket_cutoffs, bucket_weights = cls._compute_buckets(centroids, holdout_embeddings, index_config)

        return cls(index_config, centroids, bucket_cutoffs, bucket_weights, verbose)

    @staticmethod
    def _train_kmeans(
        embeddings: torch.Tensor, index_config: "PlaidIndexConfig", verbose: bool = False
    ) -> torch.Tensor:
        """Run faiss k-means on ``embeddings`` and return L2-normalized centroids.

        Args:
            embeddings: Training vectors, one per row.
            index_config: Supplies ``num_centroids``, ``k_means_iters`` and ``seed``.
            verbose: Forwarded to faiss.

        Returns:
            The centroids found by faiss, normalized along the last axis.
        """
        import faiss

        kmeans = faiss.Kmeans(
            embeddings.shape[-1],
            index_config.num_centroids,
            niter=index_config.k_means_iters,
            gpu=torch.cuda.is_available(),
            verbose=verbose,
            seed=index_config.seed,
        )
        # ``Tensor.numpy()`` rejects tensors that require grad or live off the
        # CPU; the conversions below are no-ops for plain CPU float32 input.
        kmeans.train(embeddings.detach().float().cpu().contiguous().numpy())
        # TODO why normalize?
        return torch.nn.functional.normalize(torch.from_numpy(kmeans.centroids), dim=-1)

    def _packbits(self, residuals: torch.Tensor) -> torch.Tensor:
        """Pack a tensor of 0/1 values into bytes with ``numpy.packbits``.

        Args:
            residuals: Tensor of bits; it is flattened before packing.

        Returns:
            A flat ``uint8`` CPU tensor holding the packed bits.

        Raises:
            NotImplementedError: If ``residuals`` lives on a CUDA device.
        """
        # A CUDA tensor reports ``cuda:<index>`` as its device, which never
        # compares equal to ``torch.device("cuda")``; ``is_cuda`` covers all indices.
        if residuals.is_cuda:
            raise NotImplementedError("CUDA not supported for packbits")
        residuals_packed = torch.from_numpy(np.packbits(np.asarray(residuals.contiguous().flatten())))
        return residuals_packed

    @staticmethod
    def _compute_buckets(
        centroids: torch.Tensor, holdout_embeddings: torch.Tensor, index_config: "PlaidIndexConfig"
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Estimate bucket cutoffs and weights from held-out residuals.

        The residuals of the held-out embeddings are pooled over all
        dimensions. With ``k = 2 ** n_bits`` the cutoffs are the ``i / k``
        quantiles for ``i = 1..k-1`` and the weights are the
        ``(i + 0.5) / k`` quantiles for ``i = 0..k-1``.

        Args:
            centroids: Centroid matrix.
            holdout_embeddings: Embeddings not used for k-means, one per row.
            index_config: Supplies ``n_bits``.

        Returns:
            ``(bucket_cutoffs, bucket_weights)``.
        """
        codes = ResidualCodec._compress_into_codes(centroids, holdout_embeddings)
        holdout_residual = (holdout_embeddings - centroids[codes]).float()

        num_options = 2**index_config.n_bits
        quantiles = torch.arange(0, num_options, device=holdout_residual.device) * (1 / num_options)

        bucket_cutoffs = holdout_residual.quantile(quantiles[1:])
        bucket_weights = holdout_residual.quantile(quantiles + (0.5 / num_options))
        return bucket_cutoffs, bucket_weights

    def _compute_reverse_bit_map(self) -> torch.Tensor:
        """Build a 256-entry table that flips the bit order inside each ``n_bits`` group.

        ``binarize`` emits the bits of every bucket id least significant bit
        first, whereas ``decompression_lookup_table`` is indexed with the
        most significant bit first. Mapping a packed byte through this table
        keeps the groups in place and reverses the bits within each group
        (e.g. for 2 bits: ``10`` -> ``01``).

        Returns:
            A ``uint8`` tensor of length 256 on the centroids' device.
        """
        n_bits = self.index_config.n_bits
        mask = (1 << n_bits) - 1

        reversed_bit_map = []
        for byte in range(256):
            reversed_byte = 0
            # Walk the groups from the most to the least significant one.
            for shift in range(8 - n_bits, -1, -n_bits):
                group = (byte >> shift) & mask
                flipped = 0
                for bit in range(n_bits):
                    flipped |= ((group >> bit) & 1) << (n_bits - 1 - bit)
                reversed_byte |= flipped << shift
            reversed_bit_map.append(reversed_byte)
        return torch.tensor(reversed_bit_map, dtype=torch.uint8, device=self.centroids.device)

    @classmethod
    def try_load_torch_extensions(cls, use_gpu: bool) -> None:
        """Compile and attach the residual decompression extension once per class.

        Does nothing when ``use_gpu`` is falsy or when the class already has a
        ``loaded_extensions`` attribute.

        Args:
            use_gpu: Whether the extension should be built and loaded.
        """
        if hasattr(cls, "loaded_extensions") or not use_gpu:
            return

        csrc_dir = pathlib.Path(__file__).parent.resolve() / "csrc"
        decompress_residuals_cpp = load(
            name="decompress_residuals_cpp",
            sources=[
                str(csrc_dir / "decompress_residuals.cpp"),
                str(csrc_dir / "decompress_residuals.cu"),
            ],
        )
        cls.decompress_residuals = decompress_residuals_cpp.decompress_residuals_cpp

        cls.loaded_extensions = True

    @classmethod
    def from_pretrained(
        cls, index_config: "PlaidIndexConfig", index_dir: Path, device: "torch.device | None" = None
    ) -> "ResidualCodec":
        """Load a codec from the ``centroids.pt`` and ``buckets.pt`` files in ``index_dir``.

        Args:
            index_config: Index configuration for the codec.
            index_dir: Directory containing the files written by ``save``.
            device: Device to map the tensors to; defaults to the CPU.

        Returns:
            The codec built from the stored tensors.
        """
        map_location = "cpu" if device is None else str(device)

        centroids = torch.load(index_dir / "centroids.pt", map_location=map_location, weights_only=True)
        bucket_cutoffs, bucket_weights = torch.load(
            index_dir / "buckets.pt", map_location=map_location, weights_only=True
        )

        return cls(
            index_config=index_config,
            centroids=centroids,
            bucket_cutoffs=bucket_cutoffs,
            bucket_weights=bucket_weights,
        )

    def save(self, index_dir: Path) -> None:
        """Write the codec to ``index_dir``, creating the directory if needed.

        The centroids are stored in half precision in ``centroids.pt``; the
        bucket cutoffs and weights are stored as a tuple in ``buckets.pt``.
        """
        index_dir.mkdir(parents=True, exist_ok=True)
        centroids_path = index_dir / "centroids.pt"
        buckets_path = index_dir / "buckets.pt"

        torch.save(self.centroids.half(), centroids_path)
        torch.save((self.bucket_cutoffs, self.bucket_weights), buckets_path)

    @staticmethod
    def _compress_into_codes(centroids: torch.Tensor, embeddings: torch.Tensor) -> torch.Tensor:
        """Return, for every embedding, the index of the centroid with the largest dot product.

        Args:
            centroids: Centroid matrix, one centroid per row.
            embeddings: Embeddings to assign, one per row.

        Returns:
            A 1-D tensor of centroid indices, one per embedding.
        """
        # ``save`` stores the centroids in half precision while embeddings are
        # typically float32; matmul does not promote dtypes, so align them first.
        if centroids.dtype != embeddings.dtype:
            centroids = centroids.to(embeddings.dtype)

        # Bound the size of the intermediate score matrix; never drop to a
        # batch size of zero, which ``split`` rejects.
        batch_size = max(1, 2**29 // centroids.shape[0])
        codes = [
            (centroids @ batch.transpose(-1, -2)).argmax(dim=0) for batch in embeddings.split(batch_size)
        ]
        return torch.cat(codes)

    def compress_into_codes(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Assign every embedding to its best-matching centroid of this codec."""
        return self._compress_into_codes(self.centroids, embeddings)

    def compress(self, embeddings: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compress embeddings into centroid codes and packed residuals.

        Args:
            embeddings: Embeddings to compress, one per row; moved to the
                centroids' device first.

        Returns:
            ``(codes, residuals)`` where ``codes`` holds one centroid index per
            embedding and ``residuals`` the bit-packed quantized differences.
        """
        embeddings = embeddings.to(self.centroids.device)
        codes = self.compress_into_codes(embeddings)
        centroids = self.centroids[codes]
        residuals = self.binarize(embeddings - centroids)
        return codes, residuals

    def binarize(self, residuals: torch.Tensor) -> torch.Tensor:
        """Quantize residuals into buckets and pack the bucket ids into bytes.

        Args:
            residuals: Differences between embeddings and their centroids, one
                embedding per row.

        Returns:
            A ``uint8`` tensor of shape ``(residuals.size(0), residual_dim)``.
        """
        n_bits = self.index_config.n_bits
        buckets = torch.bucketize(residuals.float(), self.bucket_cutoffs).to(dtype=torch.uint8)
        buckets_expanded = buckets.unsqueeze(-1).expand(*buckets.size(), n_bits)
        bucket_bits = buckets_expanded >> self.arange_bits  # divide by 2^bit for each bit position
        bucket_binary = bucket_bits & 1  # apply mod 2 to binarize

        residuals_packed = self._packbits(bucket_binary)
        return residuals_packed.reshape(residuals.size(0), self.residual_dim)

    def decompress(self, codes: "PackedTensor", compressed_residuals: "PackedTensor") -> "PackedTensor":
        """Reconstruct normalized embeddings from codes and packed residuals.

        Args:
            codes: Centroid indices; ``codes.lengths`` is passed on to the result.
            compressed_residuals: Packed residual bytes as produced by ``binarize``.

        Returns:
            A ``PackedTensor`` with the L2-normalized reconstructed embeddings.
        """
        centroids = self.centroids[codes]
        # Undo the bit order used when packing, then expand every byte into
        # the bucket ids it encodes and replace the ids by their weights.
        residuals = self.reversed_bit_map[compressed_residuals.long().view(-1)].view_as(compressed_residuals)
        residuals = self.decompression_lookup_table[residuals.long()]
        residuals = residuals.view(residuals.shape[0], -1)
        residuals = self.bucket_weights[residuals.long()]
        embeddings = centroids + residuals
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
        return PackedTensor(embeddings, lengths=codes.lengths)