Source code for lightning_ir.retrieve.plaid.residual_codec

  1"""Residual Codec for Plaid Indexing and Retrieval"""
  2
  3from __future__ import annotations
  4
  5import pathlib
  6from itertools import product
  7from pathlib import Path
  8from typing import TYPE_CHECKING, Tuple
  9
 10import numpy as np
 11import torch
 12from torch.utils.cpp_extension import load
 13
 14from ..base.packed_tensor import PackedTensor
 15
 16if TYPE_CHECKING:
 17    from .plaid_indexer import PlaidIndexConfig
 18
 19
class ResidualCodec:
    """Residual Codec for Plaid, a residual-based search method for efficient retrieval."""
    def __init__(
        self,
        index_config: PlaidIndexConfig,
        centroids: torch.Tensor,
        bucket_cutoffs: torch.Tensor,
        bucket_weights: torch.Tensor,
        verbose: bool = False,
    ) -> None:
        """Initialize the ResidualCodec.

        Args:
            index_config (PlaidIndexConfig): Configuration for the Plaid indexer.
            centroids (torch.Tensor): The centroids used for indexing.
            bucket_cutoffs (torch.Tensor): The cutoffs for the residual buckets.
            bucket_weights (torch.Tensor): The weights for the residual buckets.
            verbose (bool): Whether to print verbose output. Defaults to False.
        """
        self.index_config = index_config
        self.verbose = verbose

        self.centroids = centroids
        self.bucket_cutoffs = bucket_cutoffs
        self.bucket_weights = bucket_weights

        self.arange_bits = torch.arange(0, self.index_config.n_bits, dtype=torch.uint8, device=self.centroids.device)
        self.reversed_bit_map = self._compute_reverse_bit_map()
        keys_per_byte = 8 // self.index_config.n_bits
        self.decompression_lookup_table = torch.tensor(
            list(product(list(range(len(self.bucket_weights))), repeat=keys_per_byte)),
            device=self.centroids.device,
            dtype=torch.uint8,
        )

        self.residual_dim = max(1, centroids.shape[-1] // 8 * index_config.n_bits)

        self._packbits_cpp = None
    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(dim={self.dim}, num_centroids={self.num_centroids})"

    def __str__(self) -> str:
        return self.__repr__()

    @property
    def dim(self) -> int:
        """Get the dimensionality of the centroids."""
        return self.centroids.shape[-1]

    @property
    def num_centroids(self) -> int:
        """Get the number of centroids."""
        return self.centroids.shape[0]
    @classmethod
    def train(
        cls, index_config: PlaidIndexConfig, train_embeddings: torch.Tensor, verbose: bool = False
    ) -> "ResidualCodec":
        """Train the ResidualCodec using the provided training embeddings.

        Args:
            index_config (PlaidIndexConfig): Configuration for the Plaid indexer.
            train_embeddings (torch.Tensor): The embeddings to use for training the codec.
            verbose (bool): Whether to print verbose output. Defaults to False.
        Returns:
            ResidualCodec: An instance of the ResidualCodec trained on the provided embeddings.
        """
        train_embeddings = train_embeddings[torch.randperm(train_embeddings.shape[0])]
        num_hold_out_embeddings = int(min(0.05 * train_embeddings.shape[0], 2**15))
        train_embeddings, holdout_embeddings = train_embeddings.split(
            [train_embeddings.shape[0] - num_hold_out_embeddings, num_hold_out_embeddings]
        )

        centroids = cls._train_kmeans(train_embeddings, index_config, verbose)
        bucket_cutoffs, bucket_weights = cls._compute_buckets(centroids, holdout_embeddings, index_config)

        return cls(index_config, centroids, bucket_cutoffs, bucket_weights, verbose)
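    # Usage sketch (illustrative, not part of the class): training a codec from a sample of
    # token embeddings. The exact PlaidIndexConfig constructor is not shown in this module;
    # only the fields accessed here (num_centroids, n_bits, k_means_iters, seed) are assumed.
    #
    #     config = PlaidIndexConfig(...)  # must provide num_centroids, n_bits, k_means_iters, seed
    #     sample = torch.nn.functional.normalize(torch.randn(100_000, 128), dim=-1)
    #     codec = ResidualCodec.train(config, sample)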
    @staticmethod
    def _train_kmeans(embeddings: torch.Tensor, index_config: PlaidIndexConfig, verbose: bool = False) -> torch.Tensor:
        import faiss

        kmeans = faiss.Kmeans(
            embeddings.shape[-1],
            index_config.num_centroids,
            niter=index_config.k_means_iters,
            gpu=torch.cuda.is_available(),
            verbose=verbose,
            seed=index_config.seed,
        )
        # TODO why normalize?
        kmeans.train(embeddings.numpy())
        return torch.nn.functional.normalize(torch.from_numpy(kmeans.centroids), dim=-1)

    def _packbits(self, residuals: torch.Tensor) -> torch.Tensor:
        # np.packbits only operates on CPU data, so reject CUDA tensors explicitly
        if residuals.device.type == "cuda":
            raise NotImplementedError("CUDA not supported for packbits")
        residuals_packed = torch.from_numpy(np.packbits(np.asarray(residuals.contiguous().flatten())))
        return residuals_packed

    @staticmethod
    def _compute_buckets(
        centroids: torch.Tensor, holdout_embeddings: torch.Tensor, index_config: PlaidIndexConfig
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        holdout_embeddings_codes = ResidualCodec._compress_into_codes(centroids, holdout_embeddings)
        holdout_embeddings_centroids = centroids[holdout_embeddings_codes]

        holdout_residual = holdout_embeddings - holdout_embeddings_centroids
        avg_residual = holdout_residual.abs().mean(dim=0)

        num_options = 2**index_config.n_bits
        quantiles = torch.arange(0, num_options, device=avg_residual.device) * (1 / num_options)
        bucket_cutoffs_quantiles, bucket_weights_quantiles = quantiles[1:], quantiles + (0.5 / num_options)

        bucket_cutoffs = holdout_residual.float().quantile(bucket_cutoffs_quantiles)
        bucket_weights = holdout_residual.float().quantile(bucket_weights_quantiles)
        return bucket_cutoffs, bucket_weights

    def _compute_reverse_bit_map(self) -> torch.Tensor:
        # We reverse the residual bits because arange_bits as
        # currently constructed produces results with the reverse
        # of the expected endianness

        reversed_bit_map = []
        mask = (1 << self.index_config.n_bits) - 1
        for i in range(256):
            # The reversed byte
            z = 0
            for j in range(8, 0, -self.index_config.n_bits):
                # Extract a subsequence of length n bits
                x = (i >> (j - self.index_config.n_bits)) & mask

                # Reverse the endianness of each bit subsequence (e.g. 10 -> 01)
                y = 0
                for k in range(self.index_config.n_bits - 1, -1, -1):
                    y += ((x >> (self.index_config.n_bits - k - 1)) & 1) * (2**k)

                # Set the corresponding bits in the output byte
                z |= y
                if j > self.index_config.n_bits:
                    z <<= self.index_config.n_bits
            reversed_bit_map.append(z)
        return torch.tensor(reversed_bit_map, dtype=torch.uint8, device=self.centroids.device)
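    # Worked example (illustrative): with n_bits = 2 there are 2**2 = 4 buckets.
    # _compute_buckets then takes the 0.25 / 0.5 / 0.75 quantiles of the holdout residuals as
    # bucket_cutoffs and the 0.125 / 0.375 / 0.625 / 0.875 quantiles (the bucket midpoints) as
    # bucket_weights, so each residual entry is later stored as a 2-bit bucket id and
    # reconstructed as its bucket's weight.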
    @classmethod
    def try_load_torch_extensions(cls, use_gpu):
        """Load the necessary C++ extensions for the ResidualCodec.

        Args:
            cls: The class to load the extensions for.
            use_gpu (bool): Whether to use GPU for the extensions.
        """
        if hasattr(cls, "loaded_extensions") or not use_gpu:
            return

        decompress_residuals_cpp = load(
            name="decompress_residuals_cpp",
            sources=[
                str(pathlib.Path(__file__).parent.resolve() / "csrc" / "decompress_residuals.cpp"),
                str(pathlib.Path(__file__).parent.resolve() / "csrc" / "decompress_residuals.cu"),
            ],
        )
        cls.decompress_residuals = decompress_residuals_cpp.decompress_residuals_cpp

        cls.loaded_extensions = True
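    # Usage sketch (illustrative): the CUDA decompression extension is optional and only built
    # on demand, e.g.
    #
    #     ResidualCodec.try_load_torch_extensions(use_gpu=torch.cuda.is_available())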
    @classmethod
    def from_pretrained(
        cls, index_config: PlaidIndexConfig, index_dir: Path, device: torch.device | None = None
    ) -> "ResidualCodec":
        """Load a ResidualCodec from the specified directory.

        Args:
            index_config (PlaidIndexConfig): Configuration for the Plaid indexer.
            index_dir (Path): Directory containing the saved codec files.
            device (torch.device | None): Device to load the codec onto. Defaults to None, which uses the CPU.
        Returns:
            ResidualCodec: An instance of the ResidualCodec loaded from the specified directory.
        """
        centroids_path = index_dir / "centroids.pt"
        buckets_path = index_dir / "buckets.pt"

        centroids = torch.load(
            centroids_path, map_location=str(device) if device is not None else "cpu", weights_only=True
        )
        bucket_cutoffs, bucket_weights = torch.load(
            buckets_path, map_location=str(device) if device is not None else "cpu", weights_only=True
        )

        return cls(
            index_config=index_config,
            centroids=centroids,
            bucket_cutoffs=bucket_cutoffs,
            bucket_weights=bucket_weights,
        )
    def save(self, index_dir: Path):
        """Save the ResidualCodec to the specified directory.

        Args:
            index_dir (Path): Directory to save the codec files.
        """
        index_dir.mkdir(parents=True, exist_ok=True)
        centroids_path = index_dir / "centroids.pt"
        buckets_path = index_dir / "buckets.pt"

        torch.save(self.centroids.half(), centroids_path)
        torch.save((self.bucket_cutoffs, self.bucket_weights), buckets_path)
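    # Usage sketch (illustrative): persisting and restoring a codec. `save` writes centroids.pt
    # (half precision) and buckets.pt into index_dir; `from_pretrained` expects the same
    # PlaidIndexConfig that the codec was trained with.
    #
    #     codec.save(Path("plaid_index"))
    #     codec = ResidualCodec.from_pretrained(config, Path("plaid_index"))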
    @staticmethod
    def _compress_into_codes(centroids: torch.Tensor, embeddings: torch.Tensor) -> torch.Tensor:
        codes = []
        batch_size = 2**29 // centroids.shape[0]
        for batch in embeddings.split(batch_size):
            indices = (centroids @ batch.transpose(-1, -2)).argmax(dim=0)
            codes.append(indices)
        return torch.cat(codes)
    def compress_into_codes(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Compress embeddings into codes using the centroids.

        Args:
            embeddings (torch.Tensor): The embeddings to compress.
        Returns:
            torch.Tensor: The compressed codes.
        """
        return self._compress_into_codes(self.centroids, embeddings)
    def compress(self, embeddings: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compress embeddings into codes and residuals.

        Args:
            embeddings (torch.Tensor): The embeddings to compress.
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing the compressed codes and residuals.
        """
        embeddings = embeddings.to(self.centroids.device)
        codes = self.compress_into_codes(embeddings)
        centroids = self.centroids[codes]
        residuals = self.binarize(embeddings - centroids)
        return codes, residuals
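    # Usage sketch (illustrative, assuming a trained codec with dim = 128 and n_bits = 2):
    #
    #     embeddings = torch.nn.functional.normalize(torch.randn(1000, 128), dim=-1)
    #     codes, residuals = codec.compress(embeddings)
    #     # codes:     shape (1000,), centroid ids (torch.long)
    #     # residuals: shape (1000, 32), packed uint8 bucket ids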
    def binarize(self, residuals: torch.Tensor) -> torch.Tensor:
        """Binarize the residuals using the bucket cutoffs and weights.

        Args:
            residuals (torch.Tensor): The residuals to binarize.
        Returns:
            torch.Tensor: The binarized residuals.
        """
        buckets = torch.bucketize(residuals.float(), self.bucket_cutoffs).to(dtype=torch.uint8)
        buckets_expanded = buckets.unsqueeze(-1).expand(*buckets.size(), self.index_config.n_bits)
        bucket_bits = buckets_expanded >> self.arange_bits  # divide by 2^bit for each bit position
        bucket_binary = bucket_bits & 1  # apply mod 2 to binarize

        residuals_packed = self._packbits(bucket_binary)
        residuals_packed = residuals_packed.reshape(residuals.size(0), max(1, self.dim // 8 * self.index_config.n_bits))

        return residuals_packed
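    # Worked example (illustrative): with dim = 128 and n_bits = 2, each residual entry becomes a
    # 2-bit bucket id, so one vector packs into 128 * 2 / 8 = 32 bytes, matching the reshape above
    # (dim // 8 * n_bits). Because arange_bits extracts the low bit first while np.packbits fills
    # bytes MSB-first, the bit order inside each 2-bit group ends up reversed; decompress undoes
    # this with reversed_bit_map.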
    def decompress(self, codes: PackedTensor, compressed_residuals: PackedTensor) -> PackedTensor:
        """Decompress the codes and residuals into embeddings.

        Args:
            codes (PackedTensor): The packed tensor containing the codes.
            compressed_residuals (PackedTensor): The packed tensor containing the compressed residuals.
        Returns:
            PackedTensor: The decompressed embeddings.
        """
        centroids = self.centroids[codes]
        residuals = self.reversed_bit_map[compressed_residuals.long().view(-1)].view_as(compressed_residuals)
        residuals = self.decompression_lookup_table[residuals.long()]
        residuals = residuals.view(residuals.shape[0], -1)
        residuals = self.bucket_weights[residuals.long()]
        embeddings = centroids + residuals
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
        return PackedTensor(embeddings, lengths=codes.lengths)
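# Usage sketch (illustrative): round-tripping embeddings through the codec. Wrapping the plain
# tensors returned by compress in PackedTensor mirrors the construction used in decompress above;
# the lengths argument (tokens per document) is an assumption about how callers use this class.
#
#     codes, residuals = codec.compress(embeddings)              # embeddings: (5, dim)
#     packed_codes = PackedTensor(codes, lengths=[3, 2])
#     packed_residuals = PackedTensor(residuals, lengths=[3, 2])
#     approx = codec.decompress(packed_codes, packed_residuals)  # (5, dim), L2-normalized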