Source code for lightning_ir.retrieve.faiss.faiss_indexer

  1"""FAISS Indexer for Lightning IR Framework"""
  2
  3import warnings
  4from pathlib import Path
  5
  6import torch
  7
  8from ...bi_encoder import BiEncoderModule, BiEncoderOutput
  9from ...data import IndexBatch
 10from ...models import ColConfig, DprConfig
 11from ..base import IndexConfig, Indexer
 12
 13
class FaissIndexer(Indexer):
    """Common machinery shared by all FAISS-backed indexers in the Lightning IR framework."""

    INDEX_FACTORY: str

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Create the FAISS index described by ``INDEX_FACTORY`` and the index configuration.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (FaissIndexConfig): Configuration whose ``to_dict()`` values fill the placeholders of
                ``INDEX_FACTORY``.
            module (BiEncoderModule): The BiEncoderModule to use for indexing.
            verbose (bool): Whether to enable verbose output. Defaults to False.
        Raises:
            ValueError: If the module's similarity function is neither "cosine" nor "dot".
        """
        super().__init__(index_dir, index_config, module, verbose)
        import faiss

        model_config = self.module.config
        similarity_function = model_config.similarity_function
        if similarity_function not in ("cosine", "dot"):
            raise ValueError(f"similarity_function {similarity_function} unknown")
        # both supported similarity functions use the inner-product metric
        self.metric_type = faiss.METRIC_INNER_PRODUCT

        factory_string = self.INDEX_FACTORY.format(**index_config.to_dict())
        if similarity_function == "cosine":
            # cosine additionally gets the "L2norm" component in front of the factory string
            factory_string = f"L2norm,{factory_string}"
        self.index = faiss.index_factory(model_config.embedding_dim, factory_string, self.metric_type)

        self.set_verbosity()

        if torch.cuda.is_available():
            self.to_gpu()

    def to_gpu(self) -> None:
        """Move the FAISS index to GPU. The base implementation does nothing."""

    def to_cpu(self) -> None:
        """Move the FAISS index to CPU. The base implementation does nothing."""

    def set_verbosity(self, verbose: bool | None = None) -> None:
        """Set the ``verbose`` flag of the FAISS index.

        Args:
            verbose (bool | None): Verbosity to apply. If None, the indexer's own ``self.verbose`` setting is used.
                Defaults to None.
        """
        if verbose is None:
            verbose = self.verbose
        self.index.verbose = verbose

    def process_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Hook applied to embeddings right before they are added to the FAISS index.

        The base implementation returns the embeddings unchanged.

        Args:
            embeddings (torch.Tensor): The embeddings to process.
        Returns:
            torch.Tensor: The processed embeddings.
        """
        return embeddings

    def save(self) -> None:
        """Write the FAISS index to ``index.faiss`` inside the index directory.

        Raises:
            ValueError: If the number of embeddings does not match the index's total number of entries.
        """
        super().save()
        import faiss

        if self.num_embeddings != self.index.ntotal:
            raise ValueError("number of embeddings does not match index.ntotal")

        can_convert_to_cpu = torch.cuda.is_available() and hasattr(faiss, "index_gpu_to_cpu")
        if can_convert_to_cpu:
            self.index = faiss.index_gpu_to_cpu(self.index)

        index_path = self.index_dir / "index.faiss"
        faiss.write_index(self.index, str(index_path))

    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        """Add the document embeddings of one batch to the FAISS index and update the bookkeeping.

        Args:
            index_batch (IndexBatch): The batch providing the document ids.
            output (BiEncoderOutput): The output from the bi-encoder module containing document embeddings.
        Raises:
            ValueError: If the document embeddings are not present in the output.
        """
        doc_embeddings = output.doc_embeddings
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")

        scoring_mask = doc_embeddings.scoring_mask
        if scoring_mask is not None:
            # keep only the masked-in vectors; each document contributes as many as its mask selects
            doc_lengths = scoring_mask.sum(dim=1)
            embeddings = doc_embeddings.embeddings[scoring_mask]
        else:
            # without a mask, every document contributes exactly one vector (index 0 of the second dimension)
            batch_size = doc_embeddings.embeddings.shape[0]
            doc_lengths = torch.ones(batch_size, device=doc_embeddings.device, dtype=torch.int32)
            embeddings = doc_embeddings.embeddings[:, 0]
        doc_ids = index_batch.doc_ids
        embeddings = self.process_embeddings(embeddings)

        num_new_embeddings = embeddings.shape[0]
        if num_new_embeddings:
            self.index.add(embeddings.float().cpu())

        self.num_embeddings += num_new_embeddings
        self.num_docs += len(doc_ids)

        self.doc_lengths.extend(doc_lengths.int().cpu().tolist())
        self.doc_ids.extend(doc_ids)
129 130
class FaissFlatIndexer(FaissIndexer):
    """Indexer building a FAISS ``Flat`` index for exact nearest neighbor search."""

    INDEX_FACTORY = "Flat"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissFlatIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Set up the flat indexer; all work is delegated to the parent initializer.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (FaissFlatIndexConfig): Configuration for the FAISS flat index.
            module (BiEncoderModule): The BiEncoderModule to use for indexing.
            verbose (bool): Whether to enable verbose output. Defaults to False.
        """
        super().__init__(index_dir, index_config, module, verbose)
        # annotation only: narrows the declared type of the attribute
        self.index_config: FaissFlatIndexConfig

    def to_gpu(self) -> None:
        """Move the FAISS flat index to GPU. Currently does nothing."""

    def to_cpu(self) -> None:
        """Move the FAISS flat index to CPU. Currently does nothing."""
class _FaissTrainIndexer(FaissIndexer):
    """Base class for FAISS indexers that require training on embeddings before indexing."""

    INDEX_FACTORY = ""  # class only acts as mixin

    def __init__(
        self,
        index_dir: Path,
        index_config: "_FaissTrainIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the _FaissTrainIndexer.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (_FaissTrainIndexConfig): Configuration for the FAISS index that requires training.
            module (BiEncoderModule): The BiEncoderModule to use for indexing.
            verbose (bool): Whether to enable verbose output. Defaults to False.
        Raises:
            ValueError: If num_train_embeddings is not set in the index configuration.
        """
        super().__init__(index_dir, index_config, module, verbose)
        if index_config.num_train_embeddings is None:
            raise ValueError("num_train_embeddings must be set")
        self.num_train_embeddings = index_config.num_train_embeddings

        # Buffer for the training sample. It starts out filled with NaN so that rows which were never
        # written can be detected (and dropped) when training is forced on a small corpus.
        # It is set to None once the index has been trained.
        self._train_embeddings: torch.Tensor | None = torch.full(
            (self.num_train_embeddings, self.module.config.embedding_dim), torch.nan, dtype=torch.float32
        )

    def process_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Divert embeddings into the training buffer and train the index once the buffer is full.

        Args:
            embeddings (torch.Tensor): The embeddings to process.
        Returns:
            torch.Tensor: The embeddings that were not consumed by the training buffer.
        """
        embeddings = self._grab_train_embeddings(embeddings)
        self._train()
        return embeddings

    def _grab_train_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Copy leading embeddings into the training buffer until it holds ``num_train_embeddings`` rows.

        Args:
            embeddings (torch.Tensor): The embeddings of the current batch.
        Returns:
            torch.Tensor: The embeddings that did not fit into the training buffer. The input is returned
                unchanged if the index has already been trained.
        """
        if self._train_embeddings is not None:
            # save training embeddings until num_train_embeddings is reached
            # if num_train_embeddings overflows, save the remaining embeddings
            start = self.num_embeddings
            end = min(self.num_train_embeddings, start + embeddings.shape[0])
            length = end - start
            self._train_embeddings[start:end] = embeddings[:length]
            # buffered embeddings are counted here; they are added to the index in _train
            self.num_embeddings += length
            embeddings = embeddings[length:]
        return embeddings

    def _train(self, force: bool = False) -> None:
        """Train the index on the buffered embeddings and add them to the index.

        Does nothing if the index was already trained, or if the buffer is not yet full and ``force`` is False.

        Args:
            force (bool): Train even if fewer than ``num_train_embeddings`` embeddings were collected.
                Defaults to False.
        """
        if self._train_embeddings is None:
            return
        if not force and self.num_embeddings < self.num_train_embeddings:
            return
        # compute the per-row NaN mask once and reuse it for both the check and the filtering
        nan_rows = torch.isnan(self._train_embeddings).any(dim=1)
        if nan_rows.any():
            warnings.warn(
                "Corpus contains less tokens/documents than num_train_embeddings. Removing NaN embeddings.",
                stacklevel=2,
            )
            self._train_embeddings = self._train_embeddings[~nan_rows]
        self.index.train(self._train_embeddings)
        if torch.cuda.is_available():
            self.to_cpu()
        self.index.add(self._train_embeddings)
        # release the buffer; a None buffer marks the index as trained for the methods above
        self._train_embeddings = None
        self.set_verbosity(False)

    def save(self) -> None:
        """Save the FAISS index to disk, forcing training first if the index is still untrained.

        Raises:
            ValueError: If the number of embeddings does not match the index's total number of entries.
        """
        if not self.index.is_trained:
            self._train(force=True)
        return super().save()
class FaissIVFIndexer(_FaissTrainIndexer):
    """Indexer building a FAISS inverted-file (IVF) index for approximate nearest neighbor search."""

    INDEX_FACTORY = "IVF{num_centroids},Flat"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIVFIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Set up the IVF indexer and, if the quantizer exposes HNSW parameters, configure ``efConstruction``.

        Note that ``index_config.num_train_embeddings`` is filled in with a default when it is unset.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (FaissIVFIndexConfig): Configuration for the FAISS IVF index.
            module (BiEncoderModule): The BiEncoderModule to use for indexing.
            verbose (bool): Whether to enable verbose output. Defaults to False.
        """
        # default faiss values
        # https://github.com/facebookresearch/faiss/blob/dafdff110489db7587b169a0afee8470f220d295/faiss/Clustering.h#L43
        max_points_per_centroid = 256
        num_train_embeddings = index_config.num_train_embeddings
        if not num_train_embeddings:
            num_train_embeddings = index_config.num_centroids * max_points_per_centroid
        index_config.num_train_embeddings = num_train_embeddings
        super().__init__(index_dir, index_config, module, verbose)

        import faiss

        ivf_index = faiss.extract_index_ivf(self.index)
        if hasattr(ivf_index, "quantizer"):
            downcasted_quantizer = faiss.downcast_index(ivf_index.quantizer)
            if hasattr(downcasted_quantizer, "hnsw"):
                downcasted_quantizer.hnsw.efConstruction = index_config.ef_construction

    def to_gpu(self) -> None:
        """Attach a GPU flat index as the clustering index of the IVF index. No-op without GPUs."""
        import faiss

        # clustering_index overrides the index used during clustering but leaves the quantizer on the gpu
        # https://faiss.ai/cpp_api/namespace/namespacefaiss_1_1gpu.html
        if faiss.get_num_gpus() == 0:
            return
        flat_index = faiss.IndexFlat(self.module.config.embedding_dim, self.metric_type)
        clustering_index = faiss.index_cpu_to_all_gpus(flat_index)
        clustering_index.verbose = self.verbose
        index_ivf = faiss.extract_index_ivf(self.index)
        index_ivf.clustering_index = clustering_index

    def to_cpu(self) -> None:
        """Convert the index to a CPU index and place its quantizer on GPU 0. No-op without GPUs."""
        import faiss

        if faiss.get_num_gpus() == 0:
            return
        self.index = faiss.index_gpu_to_cpu(self.index)

        # https://gist.github.com/mdouze/334ad6a979ac3637f6d95e9091356d3e
        # move index to cpu but leave quantizer on gpu
        index_ivf = faiss.extract_index_ivf(self.index)
        cpu_quantizer = index_ivf.quantizer
        index_ivf.quantizer = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, cpu_quantizer)

    def set_verbosity(self, verbose: bool | None = None) -> None:
        """Set the ``verbose`` flag of the IVF index and of its quantizer.

        Args:
            verbose (bool | None): Verbosity to apply. If None, ``self.verbose`` is used. Defaults to None.
        """
        import faiss

        if verbose is None:
            verbose = self.verbose
        index_ivf = faiss.extract_index_ivf(self.index)
        quantizer = index_ivf.quantizer
        index_ivf.verbose = verbose
        quantizer.verbose = verbose
322 323
class FaissPQIndexer(_FaissTrainIndexer):
    """Indexer building a FAISS product-quantization (PQ) index for approximate nearest neighbor search."""

    INDEX_FACTORY = "OPQ{num_subquantizers},PQ{num_subquantizers}x{n_bits}"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissPQIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Set up the PQ indexer; all work is delegated to the parent initializer.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (FaissPQIndexConfig): Configuration for the FAISS PQ index.
            module (BiEncoderModule): The BiEncoderModule to use for indexing.
            verbose (bool): Whether to enable verbose output. Defaults to False.
        """
        super().__init__(index_dir, index_config, module, verbose)
        # annotation only: narrows the declared type of the attribute
        self.index_config: FaissPQIndexConfig

    def to_gpu(self) -> None:
        """Move the FAISS PQ index to GPU. Currently does nothing."""

    def to_cpu(self) -> None:
        """Move the FAISS PQ index to CPU. Currently does nothing."""
354 355
class FaissIVFPQIndexer(FaissIVFIndexer):
    """Indexer combining an inverted file (IVF) with product quantization (PQ) for approximate nearest neighbor
    search using FAISS."""

    INDEX_FACTORY = "OPQ{num_subquantizers},IVF{num_centroids}_HNSW32,PQ{num_subquantizers}x{n_bits}"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIVFPQIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Set up the IVFPQ indexer and create a direct map on the underlying IVF index.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (FaissIVFPQIndexConfig): Configuration for the FAISS IVFPQ index.
            module (BiEncoderModule): The BiEncoderModule to use for indexing.
            verbose (bool): Whether to enable verbose output. Defaults to False.
        """
        import faiss

        super().__init__(index_dir, index_config, module, verbose)
        # annotation only: narrows the declared type of the attribute
        self.index_config: FaissIVFPQIndexConfig

        faiss.extract_index_ivf(self.index).make_direct_map()

    def set_verbosity(self, verbose: bool | None = None) -> None:
        """Set the ``verbose`` flag of the IVFPQ index, its product quantizer and its coarse quantizer.

        Args:
            verbose (bool | None): Verbosity to apply. If None, ``self.verbose`` is used. Defaults to None.
        """
        super().set_verbosity(verbose)
        import faiss

        if verbose is None:
            verbose = self.verbose
        # self.index.index is the index wrapped by the top-level index
        index_ivf_pq = faiss.downcast_index(self.index.index)
        product_quantizer = index_ivf_pq.pq
        coarse_quantizer = index_ivf_pq.quantizer
        product_quantizer.verbose = verbose
        coarse_quantizer.verbose = verbose
401 402
class FaissIndexConfig(IndexConfig):
    """Base configuration shared by all FAISS index types in the Lightning IR framework."""

    # model types declared as supported by FAISS indexing
    SUPPORTED_MODELS = {ColConfig.model_type, DprConfig.model_type}
    # indexer class associated with this configuration
    indexer_class: type[Indexer] = FaissIndexer
408 409
class FaissFlatIndexConfig(FaissIndexConfig):
    """Configuration for the exact-search FAISS flat indexer in the Lightning IR framework."""

    # indexer class associated with this configuration
    indexer_class = FaissFlatIndexer
class _FaissTrainIndexConfig(FaissIndexConfig):
    """Shared configuration base for FAISS indexers whose index must be trained before embeddings are added."""

    # indexer class associated with this configuration
    indexer_class = _FaissTrainIndexer

    def __init__(self, num_train_embeddings: int | None = None) -> None:
        """Store the size of the training sample.

        Args:
            num_train_embeddings (int | None): Number of embeddings to use for training the index. If None, it will
                be set later. Defaults to None.
        """
        super().__init__()
        self.num_train_embeddings = num_train_embeddings
class FaissIVFIndexConfig(_FaissTrainIndexConfig):
    """Configuration for the FAISS inverted-file (IVF) indexer in the Lightning IR framework."""

    # indexer class associated with this configuration
    indexer_class = FaissIVFIndexer

    def __init__(
        self,
        num_train_embeddings: int | None = None,
        num_centroids: int = 262144,
        ef_construction: int = 40,
    ) -> None:
        """Store the IVF-specific settings on top of the training settings.

        Args:
            num_train_embeddings (int | None): Number of embeddings to use for training the index. If None, it will be
                set later. Defaults to None.
            num_centroids (int): Number of centroids for the IVF index. Defaults to 262144.
            ef_construction (int): The size of the dynamic list used during construction. Defaults to 40.
        """
        super().__init__(num_train_embeddings=num_train_embeddings)
        self.num_centroids = num_centroids
        self.ef_construction = ef_construction
454 455
class FaissPQIndexConfig(_FaissTrainIndexConfig):
    """Configuration for the FAISS product-quantization (PQ) indexer in the Lightning IR framework."""

    # indexer class associated with this configuration
    indexer_class = FaissPQIndexer

    def __init__(
        self,
        num_train_embeddings: int | None = None,
        num_subquantizers: int = 16,
        n_bits: int = 8,
    ) -> None:
        """Store the PQ-specific settings on top of the training settings.

        Args:
            num_train_embeddings (int | None): Number of embeddings to use for training the index. If None, it will
                be set later. Defaults to None.
            num_subquantizers (int): Number of subquantizers for the PQ index. Defaults to 16.
            n_bits (int): Number of bits for the PQ index. Defaults to 8.
        """
        super().__init__(num_train_embeddings=num_train_embeddings)
        self.num_subquantizers = num_subquantizers
        self.n_bits = n_bits
473 474
class FaissIVFPQIndexConfig(FaissIVFIndexConfig):
    """Configuration for the combined IVF + PQ FAISS indexer in the Lightning IR framework."""

    # indexer class associated with this configuration
    indexer_class = FaissIVFPQIndexer

    def __init__(
        self,
        num_train_embeddings: int | None = None,
        num_centroids: int = 262144,
        ef_construction: int = 40,
        num_subquantizers: int = 16,
        n_bits: int = 8,
    ) -> None:
        """Store the PQ-specific settings on top of the IVF and training settings.

        Args:
            num_train_embeddings (int | None): Number of embeddings to use for training the index. If None, it will
                be set later. Defaults to None.
            num_centroids (int): Number of centroids for the IVF index. Defaults to 262144.
            ef_construction (int): The size of the dynamic list used during construction. Defaults to 40.
            num_subquantizers (int): Number of subquantizers for the PQ index. Defaults to 16.
            n_bits (int): Number of bits for the PQ index. Defaults to 8.
        """
        super().__init__(
            num_train_embeddings=num_train_embeddings,
            num_centroids=num_centroids,
            ef_construction=ef_construction,
        )
        self.num_subquantizers = num_subquantizers
        self.n_bits = n_bits