Source code for lightning_ir.retrieve.plaid.residual_codec

  1from __future__ import annotations
  2
  3import pathlib
  4from itertools import product
  5from pathlib import Path
  6from typing import TYPE_CHECKING, Tuple
  7
  8import numpy as np
  9import torch
 10from torch.utils.cpp_extension import load
 11
 12from ..base.packed_tensor import PackedTensor
 13
 14if TYPE_CHECKING:
 15    from .plaid_indexer import PlaidIndexConfig
 16
 17
class ResidualCodec:
    """Encode embeddings as (centroid id, bit-packed bucketed residual) and decode them back.

    Each embedding is assigned to the centroid with the highest inner product. The remaining
    residual is quantized per dimension into ``len(bucket_weights)`` buckets (``2**n_bits`` when
    produced by :meth:`train`), and the bucket ids are packed into bytes with ``np.packbits``.
    """

    def __init__(
        self,
        index_config: PlaidIndexConfig,
        centroids: torch.Tensor,
        bucket_cutoffs: torch.Tensor,
        bucket_weights: torch.Tensor,
        verbose: bool = False,
    ) -> None:
        """Build a codec from already-trained centroids and bucket statistics.

        Args:
            index_config: Index configuration; ``n_bits`` sets the bits per residual dimension.
            centroids: Centroid matrix of shape ``(num_centroids, dim)``.
            bucket_cutoffs: Boundaries passed to ``torch.bucketize`` when quantizing residuals.
            bucket_weights: Value reconstructed for each bucket id during decompression.
            verbose: Stored for callers; not read by this class.
        """
        self.index_config = index_config
        self.verbose = verbose

        self.centroids = centroids
        self.bucket_cutoffs = bucket_cutoffs
        self.bucket_weights = bucket_weights

        # Bit positions 0..n_bits-1; `binarize` shifts bucket ids by these to extract each bit.
        self.arange_bits = torch.arange(0, self.index_config.n_bits, dtype=torch.uint8, device=self.centroids.device)
        # Byte -> byte with every n_bits-wide group bit-reversed (undoes the order used by `binarize`).
        self.reversed_bit_map = self._compute_reverse_bit_map()
        keys_per_byte = 8 // self.index_config.n_bits
        # Row r lists the bucket ids stored in byte value r (base-len(bucket_weights) digits of r,
        # most significant first). This covers every byte value when len(bucket_weights) == 2**n_bits.
        self.decompression_lookup_table = torch.tensor(
            list(product(range(len(self.bucket_weights)), repeat=keys_per_byte)),
            device=self.centroids.device,
            dtype=torch.uint8,
        )

        # Number of packed bytes produced per embedding by `binarize`.
        self.residual_dim = max(1, centroids.shape[-1] // 8 * index_config.n_bits)

        self._packbits_cpp = None

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(dim={self.dim}, num_centroids={self.num_centroids})"

    def __str__(self) -> str:
        return self.__repr__()

    @property
    def dim(self) -> int:
        """Embedding dimensionality (last dimension of the centroid matrix)."""
        return self.centroids.shape[-1]

    @property
    def num_centroids(self) -> int:
        """Number of centroids (first dimension of the centroid matrix)."""
        return self.centroids.shape[0]

    @classmethod
    def train(
        cls, index_config: PlaidIndexConfig, train_embeddings: torch.Tensor, verbose: bool = False
    ) -> "ResidualCodec":
        """Fit centroids with k-means, then estimate residual buckets on a held-out split.

        Args:
            index_config: Supplies ``num_centroids``, ``k_means_iters``, ``seed`` and ``n_bits``.
            train_embeddings: Training embeddings, one per row.
            verbose: Forwarded to faiss k-means and stored on the codec.

        Returns:
            A new :class:`ResidualCodec`.
        """
        train_embeddings = train_embeddings[torch.randperm(train_embeddings.shape[0])]
        # Hold out 5% of the shuffled embeddings (at most 2**15) for the bucket statistics.
        num_hold_out_embeddings = int(min(0.05 * train_embeddings.shape[0], 2**15))
        train_embeddings, holdout_embeddings = train_embeddings.split(
            [train_embeddings.shape[0] - num_hold_out_embeddings, num_hold_out_embeddings]
        )

        centroids = cls._train_kmeans(train_embeddings, index_config, verbose)
        bucket_cutoffs, bucket_weights = cls._compute_buckets(centroids, holdout_embeddings, index_config)

        return cls(index_config, centroids, bucket_cutoffs, bucket_weights, verbose)

    @staticmethod
    def _train_kmeans(embeddings: torch.Tensor, index_config: PlaidIndexConfig, verbose: bool = False) -> torch.Tensor:
        """Run faiss k-means and return the L2-normalized centroids as a CPU tensor."""
        import faiss

        kmeans = faiss.Kmeans(
            embeddings.shape[-1],
            index_config.num_centroids,
            niter=index_config.k_means_iters,
            gpu=torch.cuda.is_available(),
            verbose=verbose,
            seed=index_config.seed,
        )
        kmeans.train(embeddings.numpy())
        # NOTE(review): normalizing makes the inner-product assignment in `_compress_into_codes`
        # rank centroids by cosine similarity. This presumably matches unit-norm embeddings
        # (`decompress` re-normalizes its output), but confirm against the encoder.
        return torch.nn.functional.normalize(torch.from_numpy(kmeans.centroids), dim=-1)

    def _packbits(self, residuals: torch.Tensor) -> torch.Tensor:
        """Flatten ``residuals`` (0/1 values) and pack them into a uint8 tensor, MSB first.

        Raises:
            NotImplementedError: If ``residuals`` lives on a CUDA device.
        """
        # A tensor on the GPU reports ``cuda:0``, which is not equal to ``torch.device("cuda")``.
        # Compare the device *type* so the intended error is actually raised.
        if residuals.device.type == "cuda":
            raise NotImplementedError("CUDA not supported for packbits")
        return torch.from_numpy(np.packbits(np.asarray(residuals.contiguous().flatten())))

    @staticmethod
    def _compute_buckets(
        centroids: torch.Tensor, holdout_embeddings: torch.Tensor, index_config: PlaidIndexConfig
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Derive bucket cutoffs and weights from quantiles of the hold-out residuals.

        With ``k = 2**n_bits`` buckets, the cutoffs are the ``i/k`` quantiles (i = 1..k-1) and
        the weights are the ``(i + 0.5)/k`` quantiles (i = 0..k-1) of all residual values pooled.
        """
        holdout_embeddings_codes = ResidualCodec._compress_into_codes(centroids, holdout_embeddings)
        holdout_residual = (holdout_embeddings - centroids[holdout_embeddings_codes]).float()

        num_options = 2**index_config.n_bits
        quantiles = torch.arange(0, num_options, device=holdout_residual.device) * (1 / num_options)
        bucket_cutoffs_quantiles, bucket_weights_quantiles = quantiles[1:], quantiles + (0.5 / num_options)

        bucket_cutoffs = holdout_residual.quantile(bucket_cutoffs_quantiles)
        bucket_weights = holdout_residual.quantile(bucket_weights_quantiles)
        return bucket_cutoffs, bucket_weights

    def _compute_reverse_bit_map(self) -> torch.Tensor:
        """Return a 256-entry uint8 table mapping each byte to its group-wise bit-reversed form."""
        # We reverse the residual bits because arange_bits as
        # currently constructed produces results with the reverse
        # of the expected endianness

        reversed_bit_map = []
        mask = (1 << self.index_config.n_bits) - 1
        for i in range(256):
            # The reversed byte
            z = 0
            for j in range(8, 0, -self.index_config.n_bits):
                # Extract a subsequence of length n bits
                x = (i >> (j - self.index_config.n_bits)) & mask

                # Reverse the endianness of each bit subsequence (e.g. 10 -> 01)
                y = 0
                for k in range(self.index_config.n_bits - 1, -1, -1):
                    y += ((x >> (self.index_config.n_bits - k - 1)) & 1) * (2**k)

                # Set the corresponding bits in the output byte
                z |= y
                if j > self.index_config.n_bits:
                    z <<= self.index_config.n_bits
            reversed_bit_map.append(z)
        return torch.tensor(reversed_bit_map, dtype=torch.uint8, device=self.centroids.device)

    @classmethod
    def try_load_torch_extensions(cls, use_gpu: bool) -> None:
        """JIT-compile the ``decompress_residuals`` C++/CUDA extension once, only when ``use_gpu``."""
        if hasattr(cls, "loaded_extensions") or not use_gpu:
            return

        decompress_residuals_cpp = load(
            name="decompress_residuals_cpp",
            sources=[
                str(pathlib.Path(__file__).parent.resolve() / "csrc" / "decompress_residuals.cpp"),
                str(pathlib.Path(__file__).parent.resolve() / "csrc" / "decompress_residuals.cu"),
            ],
        )
        cls.decompress_residuals = decompress_residuals_cpp.decompress_residuals_cpp

        cls.loaded_extensions = True

    @classmethod
    def from_pretrained(
        cls, index_config: PlaidIndexConfig, index_dir: Path, device: torch.device | None = None
    ) -> "ResidualCodec":
        """Load a codec written by :meth:`save` from ``index_dir``.

        Tensors are mapped to ``device`` (CPU when ``None``). Centroids come back in the
        float16 format that :meth:`save` writes.
        """
        centroids_path = index_dir / "centroids.pt"
        buckets_path = index_dir / "buckets.pt"
        map_location = str(device) if device is not None else "cpu"

        centroids = torch.load(centroids_path, map_location=map_location, weights_only=True)
        bucket_cutoffs, bucket_weights = torch.load(buckets_path, map_location=map_location, weights_only=True)

        return cls(
            index_config=index_config,
            centroids=centroids,
            bucket_cutoffs=bucket_cutoffs,
            bucket_weights=bucket_weights,
        )

    def save(self, index_dir: Path):
        """Write ``centroids.pt`` (as float16) and ``buckets.pt`` (cutoffs, weights) to ``index_dir``."""
        index_dir.mkdir(parents=True, exist_ok=True)
        centroids_path = index_dir / "centroids.pt"
        buckets_path = index_dir / "buckets.pt"

        torch.save(self.centroids.half(), centroids_path)
        torch.save((self.bucket_cutoffs, self.bucket_weights), buckets_path)

    @staticmethod
    def _compress_into_codes(centroids: torch.Tensor, embeddings: torch.Tensor) -> torch.Tensor:
        """Return, for each embedding row, the index of the centroid with the largest inner product.

        Embeddings are processed in batches so that each score matrix holds about 2**29 elements.
        Centroids and embeddings are promoted to a common dtype first. `save` stores the
        centroids as float16, so a reloaded codec would otherwise fail on float32 embeddings.
        """
        dtype = torch.promote_types(centroids.dtype, embeddings.dtype)
        centroids = centroids.to(dtype)
        if embeddings.shape[0] == 0:
            return torch.empty(0, dtype=torch.long, device=centroids.device)
        batch_size = max(1, 2**29 // centroids.shape[0])
        codes = [
            (centroids @ batch.to(dtype).transpose(-1, -2)).argmax(dim=0) for batch in embeddings.split(batch_size)
        ]
        return torch.cat(codes)

    def compress_into_codes(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Assign each embedding to its nearest (max inner product) centroid of this codec."""
        return self._compress_into_codes(self.centroids, embeddings)

    def compress(self, embeddings: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode embeddings as ``(codes, packed_residuals)``.

        ``packed_residuals`` is a uint8 tensor of shape ``(num_embeddings, residual_dim)``.
        """
        embeddings = embeddings.to(self.centroids.device)
        codes = self.compress_into_codes(embeddings)
        centroids = self.centroids[codes]
        residuals = self.binarize(embeddings - centroids)
        return codes, residuals

    def binarize(self, residuals: torch.Tensor) -> torch.Tensor:
        """Bucketize residuals and bit-pack the bucket ids into ``(num_rows, residual_dim)`` uint8."""
        buckets = torch.bucketize(residuals.float(), self.bucket_cutoffs).to(dtype=torch.uint8)
        buckets_expanded = buckets.unsqueeze(-1).expand(*buckets.size(), self.index_config.n_bits)
        bucket_bits = buckets_expanded >> self.arange_bits  # divide by 2^bit for each bit position
        bucket_binary = bucket_bits & 1  # apply mod 2 to binarize (least significant bit first)

        residuals_packed = self._packbits(bucket_binary)
        return residuals_packed.reshape(residuals.size(0), self.residual_dim)

    def decompress(self, codes: PackedTensor, compressed_residuals: PackedTensor) -> PackedTensor:
        """Reconstruct L2-normalized embeddings from centroid codes and packed residuals.

        Each byte of ``compressed_residuals`` is bit-reversed group-wise, then expanded into
        bucket ids via the lookup table. Those ids are mapped to ``bucket_weights`` and added
        to the centroids. Presumably ``compressed_residuals`` has shape
        ``(num_embeddings, residual_dim)`` — TODO confirm with callers.
        """
        centroids = self.centroids[codes]
        residuals = self.reversed_bit_map[compressed_residuals.long().view(-1)].view_as(compressed_residuals)
        residuals = self.decompression_lookup_table[residuals.long()]
        residuals = residuals.view(residuals.shape[0], -1)
        residuals = self.bucket_weights[residuals.long()]
        embeddings = centroids + residuals
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
        return PackedTensor(embeddings, lengths=codes.lengths)