Source code for lightning_ir.retrieve.faiss.faiss_indexer

import warnings
from pathlib import Path
from typing import Type

import torch

from ...bi_encoder import BiEncoderModule, BiEncoderOutput
from ...data import IndexBatch
from ...models import ColConfig, DprConfig
from ..base import IndexConfig, Indexer
class FaissIndexer(Indexer):
    INDEX_FACTORY: str

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        super().__init__(index_dir, index_config, module, verbose)
        import faiss

        similarity_function = self.module.config.similarity_function
        if similarity_function in ("cosine", "dot"):
            self.metric_type = faiss.METRIC_INNER_PRODUCT
        else:
            raise ValueError(f"similarity_function {similarity_function} unknown")

        index_factory = self.INDEX_FACTORY.format(**index_config.to_dict())
        if similarity_function == "cosine":
            index_factory = "L2norm," + index_factory
        self.index = faiss.index_factory(self.module.config.embedding_dim, index_factory, self.metric_type)

        self.set_verbosity()

        if torch.cuda.is_available():
            self.to_gpu()

    def to_gpu(self) -> None:
        pass

    def to_cpu(self) -> None:
        pass

    def set_verbosity(self, verbose: bool | None = None) -> None:
        self.index.verbose = self.verbose if verbose is None else verbose

    def process_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        return embeddings

    def save(self) -> None:
        super().save()
        import faiss

        if self.num_embeddings != self.index.ntotal:
            raise ValueError("number of embeddings does not match index.ntotal")
        if torch.cuda.is_available() and hasattr(faiss, "index_gpu_to_cpu"):
            self.index = faiss.index_gpu_to_cpu(self.index)

        faiss.write_index(self.index, str(self.index_dir / "index.faiss"))

    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        doc_embeddings = output.doc_embeddings
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")
        if doc_embeddings.scoring_mask is None:
            doc_lengths = torch.ones(
                doc_embeddings.embeddings.shape[0], device=doc_embeddings.device, dtype=torch.int32
            )
            embeddings = doc_embeddings.embeddings[:, 0]
        else:
            doc_lengths = doc_embeddings.scoring_mask.sum(dim=1)
            embeddings = doc_embeddings.embeddings[doc_embeddings.scoring_mask]
        doc_ids = index_batch.doc_ids
        embeddings = self.process_embeddings(embeddings)

        if embeddings.shape[0]:
            self.index.add(embeddings.float().cpu())

        self.num_embeddings += embeddings.shape[0]
        self.num_docs += len(doc_ids)

        self.doc_lengths.extend(doc_lengths.int().cpu().tolist())
        self.doc_ids.extend(doc_ids)
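
# Editor's sketch (not part of lightning-ir): a minimal illustration of the
# factory strings built above, using raw faiss with arbitrary dimensions and
# random data. The "L2norm," prefix prepends a normalization transform, which
# makes inner-product search equivalent to cosine similarity.
def _sketch_cosine_flat_index():
    import faiss
    import numpy as np

    dim = 128
    index = faiss.index_factory(dim, "L2norm,Flat", faiss.METRIC_INNER_PRODUCT)
    vectors = np.random.rand(1000, dim).astype("float32")
    index.add(vectors)  # the transform L2-normalizes vectors before they are stored
    scores, ids = index.search(vectors[:4], 10)  # top-10 cosine neighbors per query
    return scores, ids
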
class FaissFlatIndexer(FaissIndexer):
    INDEX_FACTORY = "Flat"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissFlatIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        super().__init__(index_dir, index_config, module, verbose)
        self.index_config: FaissFlatIndexConfig

    def to_gpu(self) -> None:
        pass

    def to_cpu(self) -> None:
        pass

class _FaissTrainIndexer(FaissIndexer):

    INDEX_FACTORY = ""  # class only acts as a mixin

    def __init__(
        self,
        index_dir: Path,
        index_config: "_FaissTrainIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        super().__init__(index_dir, index_config, module, verbose)
        if index_config.num_train_embeddings is None:
            raise ValueError("num_train_embeddings must be set")
        self.num_train_embeddings = index_config.num_train_embeddings

        self._train_embeddings: torch.Tensor | None = torch.full(
            (self.num_train_embeddings, self.module.config.embedding_dim), torch.nan, dtype=torch.float32
        )

    def process_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        embeddings = self._grab_train_embeddings(embeddings)
        self._train()
        return embeddings

    def _grab_train_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        if self._train_embeddings is not None:
            # save training embeddings until num_train_embeddings is reached;
            # if the batch overflows num_train_embeddings, pass the remaining
            # embeddings through
            start = self.num_embeddings
            end = min(self.num_train_embeddings, start + embeddings.shape[0])
            length = end - start
            self._train_embeddings[start:end] = embeddings[:length]
            self.num_embeddings += length
            embeddings = embeddings[length:]
        return embeddings

    def _train(self, force: bool = False):
        if self._train_embeddings is None:
            return
        if not force and self.num_embeddings < self.num_train_embeddings:
            return
        if torch.isnan(self._train_embeddings).any():
            warnings.warn(
                "Corpus contains fewer tokens/documents than num_train_embeddings. Removing NaN embeddings."
            )
            self._train_embeddings = self._train_embeddings[~torch.isnan(self._train_embeddings).any(dim=1)]
        self.index.train(self._train_embeddings)
        if torch.cuda.is_available():
            self.to_cpu()
        self.index.add(self._train_embeddings)
        self._train_embeddings = None
        self.set_verbosity(False)

    def save(self) -> None:
        if not self.index.is_trained:
            self._train(force=True)
        return super().save()
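
# Editor's sketch (not part of lightning-ir): the buffering in _FaissTrainIndexer
# exists because clustered faiss indices reject add() calls until they have been
# trained. A raw-faiss illustration with arbitrary sizes:
def _sketch_ivf_requires_training():
    import faiss
    import numpy as np

    dim = 64
    index = faiss.index_factory(dim, "IVF16,Flat", faiss.METRIC_INNER_PRODUCT)
    assert not index.is_trained  # add() before train() would raise an error

    train_vectors = np.random.rand(2048, dim).astype("float32")
    index.train(train_vectors)  # runs k-means to fit the 16 centroids
    index.add(train_vectors)  # as in _train above, the training vectors are also indexed
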
class FaissIVFIndexer(_FaissTrainIndexer):
    INDEX_FACTORY = "IVF{num_centroids},Flat"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIVFIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        # default faiss values
        # https://github.com/facebookresearch/faiss/blob/dafdff110489db7587b169a0afee8470f220d295/faiss/Clustering.h#L43
        max_points_per_centroid = 256
        index_config.num_train_embeddings = (
            index_config.num_train_embeddings or index_config.num_centroids * max_points_per_centroid
        )
        super().__init__(index_dir, index_config, module, verbose)

        import faiss

        ivf_index = faiss.extract_index_ivf(self.index)
        if hasattr(ivf_index, "quantizer"):
            quantizer = ivf_index.quantizer
            if hasattr(faiss.downcast_index(quantizer), "hnsw"):
                downcasted_quantizer = faiss.downcast_index(quantizer)
                downcasted_quantizer.hnsw.efConstruction = index_config.ef_construction

    def to_gpu(self) -> None:
        import faiss

        # clustering_index overrides the index used during clustering so that
        # the clustering step runs on the gpu
        # https://faiss.ai/cpp_api/namespace/namespacefaiss_1_1gpu.html
        clustering_index = faiss.index_cpu_to_all_gpus(
            faiss.IndexFlat(self.module.config.embedding_dim, self.metric_type)
        )
        clustering_index.verbose = self.verbose
        index_ivf = faiss.extract_index_ivf(self.index)
        index_ivf.clustering_index = clustering_index

    def to_cpu(self) -> None:
        import faiss

        if torch.cuda.is_available() and hasattr(faiss, "index_gpu_to_cpu") and hasattr(faiss, "index_cpu_to_gpu"):
            self.index = faiss.index_gpu_to_cpu(self.index)

            # https://gist.github.com/mdouze/334ad6a979ac3637f6d95e9091356d3e
            # move the index to cpu but leave the quantizer on the gpu
            index_ivf = faiss.extract_index_ivf(self.index)
            quantizer = index_ivf.quantizer
            gpu_quantizer = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, quantizer)
            index_ivf.quantizer = gpu_quantizer

    def set_verbosity(self, verbose: bool | None = None) -> None:
        import faiss

        verbose = verbose if verbose is not None else self.verbose
        index = faiss.extract_index_ivf(self.index)
        for elem in (index, index.quantizer):
            setattr(elem, "verbose", verbose)
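
# Editor's note: with faiss' default max_points_per_centroid of 256, leaving
# num_train_embeddings unset makes the training-buffer size scale with
# num_centroids. At the default num_centroids = 262144 (see FaissIVFIndexConfig
# below) that is 262144 * 256 = 67,108,864 embeddings, so smaller corpora will
# typically set num_train_embeddings explicitly.
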
class FaissPQIndexer(_FaissTrainIndexer):

    INDEX_FACTORY = "OPQ{num_subquantizers},PQ{num_subquantizers}x{n_bits}"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissPQIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        super().__init__(index_dir, index_config, module, verbose)
        self.index_config: FaissPQIndexConfig

    def to_gpu(self) -> None:
        pass

    def to_cpu(self) -> None:
        pass
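
# Editor's sketch (not part of lightning-ir): how the PQ factory template
# expands with the defaults of FaissPQIndexConfig below (num_subquantizers=16,
# n_bits=8).
def _sketch_pq_factory_string():
    template = "OPQ{num_subquantizers},PQ{num_subquantizers}x{n_bits}"
    factory = template.format(num_subquantizers=16, n_bits=8)
    # an OPQ rotation followed by product quantization with 16 sub-quantizers
    # of 8 bits each
    assert factory == "OPQ16,PQ16x8"
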
class FaissIVFPQIndexer(FaissIVFIndexer):
    INDEX_FACTORY = "OPQ{num_subquantizers},IVF{num_centroids}_HNSW32,PQ{num_subquantizers}x{n_bits}"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIVFPQIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        import faiss

        super().__init__(index_dir, index_config, module, verbose)
        self.index_config: FaissIVFPQIndexConfig

        index_ivf = faiss.extract_index_ivf(self.index)
        index_ivf.make_direct_map()

    def set_verbosity(self, verbose: bool | None = None) -> None:
        super().set_verbosity(verbose)
        import faiss

        verbose = verbose if verbose is not None else self.verbose
        index_ivf_pq = faiss.downcast_index(self.index.index)
        for elem in (
            index_ivf_pq.pq,
            index_ivf_pq.quantizer,
        ):
            setattr(elem, "verbose", verbose)
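
# Editor's sketch (not part of lightning-ir): make_direct_map, called in
# FaissIVFPQIndexer.__init__ above, lets an IVF index reconstruct stored
# vectors by sequential id. Raw-faiss illustration with arbitrary small sizes:
def _sketch_direct_map():
    import faiss
    import numpy as np

    dim = 32
    index = faiss.index_factory(dim, "IVF8,PQ4x4", faiss.METRIC_INNER_PRODUCT)
    vectors = np.random.rand(1024, dim).astype("float32")
    index.train(vectors)
    index.add(vectors)
    faiss.extract_index_ivf(index).make_direct_map()
    return index.reconstruct(0)  # PQ-decoded approximation of vectors[0]
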
class FaissIndexConfig(IndexConfig):
    SUPPORTED_MODELS = {ColConfig.model_type, DprConfig.model_type}
    indexer_class: Type[Indexer] = FaissIndexer


class FaissFlatIndexConfig(FaissIndexConfig):
    indexer_class = FaissFlatIndexer


class _FaissTrainIndexConfig(FaissIndexConfig):

    indexer_class = _FaissTrainIndexer

    def __init__(self, num_train_embeddings: int | None = None) -> None:
        super().__init__()
        self.num_train_embeddings = num_train_embeddings
class FaissIVFIndexConfig(_FaissTrainIndexConfig):
    indexer_class = FaissIVFIndexer

    def __init__(
        self,
        num_train_embeddings: int | None = None,
        num_centroids: int = 262144,
        ef_construction: int = 40,
    ) -> None:
        super().__init__(num_train_embeddings)
        self.num_centroids = num_centroids
        self.ef_construction = ef_construction
class FaissPQIndexConfig(_FaissTrainIndexConfig):
    indexer_class = FaissPQIndexer

    def __init__(self, num_train_embeddings: int | None = None, num_subquantizers: int = 16, n_bits: int = 8) -> None:
        super().__init__(num_train_embeddings)
        self.num_subquantizers = num_subquantizers
        self.n_bits = n_bits
class FaissIVFPQIndexConfig(FaissIVFIndexConfig):
    indexer_class = FaissIVFPQIndexer

    def __init__(
        self,
        num_train_embeddings: int | None = None,
        num_centroids: int = 262144,
        ef_construction: int = 40,
        num_subquantizers: int = 16,
        n_bits: int = 8,
    ) -> None:
        super().__init__(num_train_embeddings, num_centroids, ef_construction)
        self.num_subquantizers = num_subquantizers
        self.n_bits = n_bits
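
# Editor's sketch (not part of lightning-ir): how the configs above select and
# drive an indexer. `module` is assumed to be a loaded BiEncoderModule and
# `batches` an iterable of (IndexBatch, BiEncoderOutput) pairs; in practice
# lightning-ir's indexing callback performs this loop.
def _sketch_build_ivfpq_index(module, batches):
    config = FaissIVFPQIndexConfig(num_centroids=1024, num_subquantizers=8)
    indexer = config.indexer_class(Path("indexes/example"), config, module)  # FaissIVFPQIndexer
    for index_batch, output in batches:
        indexer.add(index_batch, output)
    indexer.save()  # trains the index if needed and writes index.faiss to the index directory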