Source code for lightning_ir.retrieve.faiss.faiss_indexer

  1"""FAISS Indexer for Lightning IR Framework"""
  2
  3import warnings
  4from pathlib import Path
  5from typing import Type
  6
  7import torch
  8
  9from ...bi_encoder import BiEncoderModule, BiEncoderOutput
 10from ...data import IndexBatch
 11from ...models import ColConfig, DprConfig
 12from ..base import IndexConfig, Indexer
 13
 14
class FaissIndexer(Indexer):
    """Base class for FAISS indexers in the Lightning IR framework."""

    INDEX_FACTORY: str

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the FaissIndexer.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (FaissIndexConfig): Configuration for the FAISS index.
            module (BiEncoderModule): The BiEncoderModule to use for indexing.
            verbose (bool): Whether to enable verbose output. Defaults to False.
        Raises:
            ValueError: If the similarity function is not supported.
        """
        super().__init__(index_dir, index_config, module, verbose)
        import faiss

        similarity_function = self.module.config.similarity_function
        if similarity_function in ("cosine", "dot"):
            self.metric_type = faiss.METRIC_INNER_PRODUCT
        else:
            raise ValueError(f"similarity_function {similarity_function} unknown")

        index_factory = self.INDEX_FACTORY.format(**index_config.to_dict())
        if similarity_function == "cosine":
            index_factory = "L2norm," + index_factory
        self.index = faiss.index_factory(self.module.config.embedding_dim, index_factory, self.metric_type)

        self.set_verbosity()

        if torch.cuda.is_available():
            self.to_gpu()

    def to_gpu(self) -> None:
        """Move the FAISS index to GPU."""
        pass

    def to_cpu(self) -> None:
        """Move the FAISS index to CPU."""
        pass

    def set_verbosity(self, verbose: bool | None = None) -> None:
        """Set the verbosity of the FAISS index.

        Args:
            verbose (bool | None): Whether to enable verbose output. If None, uses the index's current verbosity
                setting. Defaults to None.
        """
        self.index.verbose = self.verbose if verbose is None else verbose

    def process_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Process embeddings before adding them to the FAISS index.

        Args:
            embeddings (torch.Tensor): The embeddings to process.
        Returns:
            torch.Tensor: The processed embeddings.
        """
        return embeddings

    def save(self) -> None:
        """Save the FAISS index to disk.

        Raises:
            ValueError: If the number of embeddings does not match the index's total number of entries.
        """
        super().save()
        import faiss

        if self.num_embeddings != self.index.ntotal:
            raise ValueError("number of embeddings does not match index.ntotal")
        if torch.cuda.is_available() and hasattr(faiss, "index_gpu_to_cpu"):
            self.index = faiss.index_gpu_to_cpu(self.index)

        faiss.write_index(self.index, str(self.index_dir / "index.faiss"))

    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        """Add embeddings to the FAISS index.

        Args:
            index_batch (IndexBatch): The batch containing document indices and embeddings.
            output (BiEncoderOutput): The output from the bi-encoder module containing document embeddings.
        Raises:
            ValueError: If the document embeddings are not present in the output.
        """
        doc_embeddings = output.doc_embeddings
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")
        if doc_embeddings.scoring_mask is None:
            doc_lengths = torch.ones(
                doc_embeddings.embeddings.shape[0], device=doc_embeddings.device, dtype=torch.int32
            )
            embeddings = doc_embeddings.embeddings[:, 0]
        else:
            doc_lengths = doc_embeddings.scoring_mask.sum(dim=1)
            embeddings = doc_embeddings.embeddings[doc_embeddings.scoring_mask]
        doc_ids = index_batch.doc_ids
        embeddings = self.process_embeddings(embeddings)

        if embeddings.shape[0]:
            self.index.add(embeddings.float().cpu())

        self.num_embeddings += embeddings.shape[0]
        self.num_docs += len(doc_ids)

        self.doc_lengths.extend(doc_lengths.int().cpu().tolist())
        self.doc_ids.extend(doc_ids)

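# Illustrative sketch (not part of the original module): how FaissIndexer.__init__ composes the
# faiss factory string from the subclass's INDEX_FACTORY template and the similarity function.
# The embedding dimension (128) and num_centroids (1024) below are made-up example values.
#
#     import faiss
#
#     template = "IVF{num_centroids},Flat"                # FaissIVFIndexer.INDEX_FACTORY
#     factory = template.format(num_centroids=1024)       # -> "IVF1024,Flat"
#     factory = "L2norm," + factory                       # prepended only for cosine similarity
#     index = faiss.index_factory(128, factory, faiss.METRIC_INNER_PRODUCT)
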
class FaissFlatIndexer(FaissIndexer):
    """FAISS Flat Indexer for exact nearest neighbor search using FAISS."""

    INDEX_FACTORY = "Flat"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissFlatIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the FaissFlatIndexer.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (FaissFlatIndexConfig): Configuration for the FAISS flat index.
            module (BiEncoderModule): The BiEncoderModule to use for indexing.
            verbose (bool): Whether to enable verbose output. Defaults to False.
        """
        super().__init__(index_dir, index_config, module, verbose)
        self.index_config: FaissFlatIndexConfig

    def to_gpu(self) -> None:
        """Move the FAISS flat index to GPU."""
        pass

    def to_cpu(self) -> None:
        """Move the FAISS flat index to CPU."""
        pass


class _FaissTrainIndexer(FaissIndexer):
    """Base class for FAISS indexers that require training on embeddings before indexing."""

    INDEX_FACTORY = ""  # class only acts as mixin

    def __init__(
        self,
        index_dir: Path,
        index_config: "_FaissTrainIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the _FaissTrainIndexer.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (_FaissTrainIndexConfig): Configuration for the FAISS index that requires training.
            module (BiEncoderModule): The BiEncoderModule to use for indexing.
            verbose (bool): Whether to enable verbose output. Defaults to False.
        Raises:
            ValueError: If num_train_embeddings is not set in the index configuration.
        """
        super().__init__(index_dir, index_config, module, verbose)
        if index_config.num_train_embeddings is None:
            raise ValueError("num_train_embeddings must be set")
        self.num_train_embeddings = index_config.num_train_embeddings

        self._train_embeddings: torch.Tensor | None = torch.full(
            (self.num_train_embeddings, self.module.config.embedding_dim), torch.nan, dtype=torch.float32
        )

    def process_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Process embeddings before adding them to the FAISS index.

        Args:
            embeddings (torch.Tensor): The embeddings to process.
        Returns:
            torch.Tensor: The processed embeddings.
        """
        embeddings = self._grab_train_embeddings(embeddings)
        self._train()
        return embeddings

    def _grab_train_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        if self._train_embeddings is not None:
            # buffer training embeddings until num_train_embeddings is reached;
            # embeddings that do not fit into the buffer are returned and indexed as usual
            start = self.num_embeddings
            end = min(self.num_train_embeddings, start + embeddings.shape[0])
            length = end - start
            self._train_embeddings[start:end] = embeddings[:length]
            self.num_embeddings += length
            embeddings = embeddings[length:]
        return embeddings

    def _train(self, force: bool = False):
        if self._train_embeddings is None:
            return
        if not force and self.num_embeddings < self.num_train_embeddings:
            return
        if torch.isnan(self._train_embeddings).any():
            warnings.warn(
                "Corpus contains fewer tokens/documents than num_train_embeddings. Removing NaN embeddings."
            )
            self._train_embeddings = self._train_embeddings[~torch.isnan(self._train_embeddings).any(dim=1)]
        self.index.train(self._train_embeddings)
        if torch.cuda.is_available():
            self.to_cpu()
        self.index.add(self._train_embeddings)
        self._train_embeddings = None
        self.set_verbosity(False)

    def save(self) -> None:
        if not self.index.is_trained:
            self._train(force=True)
        return super().save()

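# Illustrative sketch (not part of the original module): how _FaissTrainIndexer._grab_train_embeddings
# splits an incoming batch between the training buffer and the embeddings that are indexed directly.
# The sizes below (num_train_embeddings=4, a batch of 2, 3 slots already filled) are made-up.
#
#     buffer = torch.full((4, dim), torch.nan)       # training buffer with 3 of 4 slots filled
#     batch = torch.randn(2, dim)                    # incoming embeddings
#     start, end = 3, min(4, 3 + 2)                  # only one free slot remains, so end == 4
#     buffer[start:end] = batch[: end - start]       # the first embedding completes the buffer
#     leftover = batch[end - start :]                # the second embedding is added to the index by
#                                                    # FaissIndexer.add as usual
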
class FaissIVFIndexer(_FaissTrainIndexer):
    """FAISS IVF Indexer for approximate nearest neighbor search using FAISS with Inverted File System (IVF)."""

    INDEX_FACTORY = "IVF{num_centroids},Flat"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIVFIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the FaissIVFIndexer.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (FaissIVFIndexConfig): Configuration for the FAISS IVF index.
            module (BiEncoderModule): The BiEncoderModule to use for indexing.
            verbose (bool): Whether to enable verbose output. Defaults to False.
        """
        # default faiss values
        # https://github.com/facebookresearch/faiss/blob/dafdff110489db7587b169a0afee8470f220d295/faiss/Clustering.h#L43
        max_points_per_centroid = 256
        index_config.num_train_embeddings = (
            index_config.num_train_embeddings or index_config.num_centroids * max_points_per_centroid
        )
        super().__init__(index_dir, index_config, module, verbose)

        import faiss

        ivf_index = faiss.extract_index_ivf(self.index)
        if hasattr(ivf_index, "quantizer"):
            quantizer = ivf_index.quantizer
            if hasattr(faiss.downcast_index(quantizer), "hnsw"):
                downcasted_quantizer = faiss.downcast_index(quantizer)
                downcasted_quantizer.hnsw.efConstruction = index_config.ef_construction

    def to_gpu(self) -> None:
        """Move the FAISS IVF index to GPU."""
        import faiss

        # clustering_index overrides the index used during clustering but leaves the quantizer on the gpu
        # https://faiss.ai/cpp_api/namespace/namespacefaiss_1_1gpu.html
        if faiss.get_num_gpus() == 0:
            return
        clustering_index = faiss.index_cpu_to_all_gpus(
            faiss.IndexFlat(self.module.config.embedding_dim, self.metric_type)
        )
        clustering_index.verbose = self.verbose
        index_ivf = faiss.extract_index_ivf(self.index)
        index_ivf.clustering_index = clustering_index

    def to_cpu(self) -> None:
        """Move the FAISS IVF index to CPU."""
        import faiss

        if faiss.get_num_gpus() == 0:
            return
        self.index = faiss.index_gpu_to_cpu(self.index)

        # https://gist.github.com/mdouze/334ad6a979ac3637f6d95e9091356d3e
        # move index to cpu but leave quantizer on gpu
        index_ivf = faiss.extract_index_ivf(self.index)
        quantizer = index_ivf.quantizer
        gpu_quantizer = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, quantizer)
        index_ivf.quantizer = gpu_quantizer

    def set_verbosity(self, verbose: bool | None = None) -> None:
        """Set the verbosity of the FAISS IVF index.

        Args:
            verbose (bool | None): Whether to enable verbose output. Defaults to None.
        """
        import faiss

        verbose = verbose if verbose is not None else self.verbose
        index = faiss.extract_index_ivf(self.index)
        for elem in (index, index.quantizer):
            setattr(elem, "verbose", verbose)

class FaissPQIndexer(_FaissTrainIndexer):
    """FAISS PQ Indexer for approximate nearest neighbor search using Product Quantization (PQ)."""

    INDEX_FACTORY = "OPQ{num_subquantizers},PQ{num_subquantizers}x{n_bits}"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissPQIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the FaissPQIndexer.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (FaissPQIndexConfig): Configuration for the FAISS PQ index.
            module (BiEncoderModule): The BiEncoderModule to use for indexing.
            verbose (bool): Whether to enable verbose output. Defaults to False.
        """
        super().__init__(index_dir, index_config, module, verbose)
        self.index_config: FaissPQIndexConfig

    def to_gpu(self) -> None:
        """Move the FAISS PQ index to GPU."""
        pass

    def to_cpu(self) -> None:
        """Move the FAISS PQ index to CPU."""
        pass


class FaissIVFPQIndexer(FaissIVFIndexer):
    """FAISS IVFPQ Indexer for approximate nearest neighbor search using Inverted File System (IVF) with Product
    Quantization (PQ)."""

    INDEX_FACTORY = "OPQ{num_subquantizers},IVF{num_centroids}_HNSW32,PQ{num_subquantizers}x{n_bits}"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIVFPQIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the FaissIVFPQIndexer.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (FaissIVFPQIndexConfig): Configuration for the FAISS IVFPQ index.
            module (BiEncoderModule): The BiEncoderModule to use for indexing.
            verbose (bool): Whether to enable verbose output. Defaults to False.
        """
        import faiss

        super().__init__(index_dir, index_config, module, verbose)
        self.index_config: FaissIVFPQIndexConfig

        index_ivf = faiss.extract_index_ivf(self.index)
        index_ivf.make_direct_map()

    def set_verbosity(self, verbose: bool | None = None) -> None:
        """Set the verbosity of the FAISS IVFPQ index.

        Args:
            verbose (bool | None): Whether to enable verbose output. Defaults to None.
        """
        super().set_verbosity(verbose)
        import faiss

        verbose = verbose if verbose is not None else self.verbose
        index_ivf_pq = faiss.downcast_index(self.index.index)
        for elem in (
            index_ivf_pq.pq,
            index_ivf_pq.quantizer,
        ):
            setattr(elem, "verbose", verbose)

class FaissIndexConfig(IndexConfig):
    """Configuration class for FAISS indexers in the Lightning IR framework."""

    SUPPORTED_MODELS = {ColConfig.model_type, DprConfig.model_type}
    indexer_class: Type[Indexer] = FaissIndexer


class FaissFlatIndexConfig(FaissIndexConfig):
    """Configuration class for FAISS flat indexers in the Lightning IR framework."""

    indexer_class = FaissFlatIndexer


class _FaissTrainIndexConfig(FaissIndexConfig):
    """Base configuration class for FAISS indexers that require training on embeddings before indexing."""

    indexer_class = _FaissTrainIndexer

    def __init__(self, num_train_embeddings: int | None = None) -> None:
        """Initialize the _FaissTrainIndexConfig.

        Args:
            num_train_embeddings (int | None): Number of embeddings to use for training the index. If None, it will
                be set later. Defaults to None.
        """
        super().__init__()
        self.num_train_embeddings = num_train_embeddings


class FaissIVFIndexConfig(_FaissTrainIndexConfig):
    """Configuration class for FAISS IVF indexers in the Lightning IR framework."""

    indexer_class = FaissIVFIndexer

    def __init__(
        self,
        num_train_embeddings: int | None = None,
        num_centroids: int = 262144,
        ef_construction: int = 40,
    ) -> None:
        """Initialize the FaissIVFIndexConfig.

        Args:
            num_train_embeddings (int | None): Number of embeddings to use for training the index. If None, it will
                be set later. Defaults to None.
            num_centroids (int): Number of centroids for the IVF index. Defaults to 262144.
            ef_construction (int): The size of the dynamic list used during construction. Defaults to 40.
        """
        super().__init__(num_train_embeddings)
        self.num_centroids = num_centroids
        self.ef_construction = ef_construction


class FaissPQIndexConfig(_FaissTrainIndexConfig):
    """Configuration class for FAISS PQ indexers in the Lightning IR framework."""

    indexer_class = FaissPQIndexer

    def __init__(self, num_train_embeddings: int | None = None, num_subquantizers: int = 16, n_bits: int = 8) -> None:
        """Initialize the FaissPQIndexConfig.

        Args:
            num_train_embeddings (int | None): Number of embeddings to use for training the index. If None, it will
                be set later. Defaults to None.
            num_subquantizers (int): Number of subquantizers for the PQ index. Defaults to 16.
            n_bits (int): Number of bits for the PQ index. Defaults to 8.
        """
        super().__init__(num_train_embeddings)
        self.num_subquantizers = num_subquantizers
        self.n_bits = n_bits


class FaissIVFPQIndexConfig(FaissIVFIndexConfig):
    """Configuration class for FAISS IVFPQ indexers in the Lightning IR framework."""

    indexer_class = FaissIVFPQIndexer

    def __init__(
        self,
        num_train_embeddings: int | None = None,
        num_centroids: int = 262144,
        ef_construction: int = 40,
        num_subquantizers: int = 16,
        n_bits: int = 8,
    ) -> None:
        """Initialize the FaissIVFPQIndexConfig.

        Args:
            num_train_embeddings (int | None): Number of embeddings to use for training the index. If None, it will
                be set later. Defaults to None.
            num_centroids (int): Number of centroids for the IVF index. Defaults to 262144.
            ef_construction (int): The size of the dynamic list used during construction. Defaults to 40.
            num_subquantizers (int): Number of subquantizers for the PQ index. Defaults to 16.
            n_bits (int): Number of bits for the PQ index. Defaults to 8.
        """
        super().__init__(num_train_embeddings, num_centroids, ef_construction)
        self.num_subquantizers = num_subquantizers
        self.n_bits = n_bits
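
A minimal usage sketch, not part of the module above: it pairs a FaissIVFIndexConfig with its indexer_class as defined in this file. The checkpoint name is a placeholder for any bi-encoder checkpoint with a supported model type (Col or Dpr), BiEncoderModule is assumed to be importable from the top-level lightning_ir package, and the indexer is constructed directly here rather than through the callback-driven indexing pipeline.

from pathlib import Path

from lightning_ir import BiEncoderModule
from lightning_ir.retrieve.faiss.faiss_indexer import FaissIVFIndexConfig

# Placeholder checkpoint; substitute any bi-encoder checkpoint with a supported model type.
module = BiEncoderModule(model_name_or_path="examples/some-bi-encoder")

# A small IVF index: 1024 centroids instead of the 262144 default intended for large corpora.
index_config = FaissIVFIndexConfig(num_centroids=1024, ef_construction=100)
indexer = index_config.indexer_class(Path("./faiss-index"), index_config, module, verbose=True)

# During indexing, each encoded batch would be passed to indexer.add(index_batch, output),
# and indexer.save() writes index.faiss into the index directory.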