Source code for lightning_ir.retrieve.faiss.faiss_indexer

  1"""FAISS Indexer for Lightning IR Framework"""
  2
  3import warnings
  4from pathlib import Path
  5from typing import Type
  6
  7import torch
  8
  9from ...bi_encoder import BiEncoderModule, BiEncoderOutput
 10from ...data import IndexBatch
 11from ...models import ColConfig, DprConfig
 12from ..base import IndexConfig, Indexer
 13
 14
class FaissIndexer(Indexer):
    """Base class for FAISS indexers in the Lightning IR framework.

    Subclasses set ``INDEX_FACTORY`` to a ``faiss.index_factory`` description string. It is formatted with the
    values of ``index_config.to_dict()`` before the index is built.
    """

    INDEX_FACTORY: str

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the FaissIndexer.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (FaissIndexConfig): Configuration for the FAISS index.
            module (BiEncoderModule): The BiEncoderModule to use for indexing.
            verbose (bool): Whether to enable verbose output. Defaults to False.
        Raises:
            ValueError: If the similarity function is not supported.
        """
        super().__init__(index_dir, index_config, module, verbose)
        import faiss

        similarity_function = self.module.config.similarity_function
        if similarity_function in ("cosine", "dot"):
            # both are served by an inner-product index; cosine additionally normalizes vectors (see below)
            self.metric_type = faiss.METRIC_INNER_PRODUCT
        else:
            raise ValueError(f"similarity_function {similarity_function} unknown")

        index_factory = self.INDEX_FACTORY.format(**index_config.to_dict())
        if similarity_function == "cosine":
            # L2-normalize vectors before they reach the index so that inner product equals cosine similarity
            index_factory = "L2norm," + index_factory
        self.index = faiss.index_factory(self.module.config.embedding_dim, index_factory, self.metric_type)

        self.set_verbosity()

        if torch.cuda.is_available():
            self.to_gpu()

    def to_gpu(self) -> None:
        """Move the FAISS index to GPU. The base implementation does nothing."""

    def to_cpu(self) -> None:
        """Move the FAISS index to CPU. The base implementation does nothing."""

    def set_verbosity(self, verbose: bool | None = None) -> None:
        """Set the verbosity of the FAISS index.

        Args:
            verbose (bool | None): Whether to enable verbose output. If None, the indexer's own ``verbose``
                attribute is used. Defaults to None.
        """
        self.index.verbose = self.verbose if verbose is None else verbose

    def process_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Process embeddings before adding them to the FAISS index.

        The base implementation returns the embeddings unchanged. Subclasses may consume some of them; only the
        returned embeddings are added to the index by :meth:`add`.

        Args:
            embeddings (torch.Tensor): The embeddings to process.
        Returns:
            torch.Tensor: The processed embeddings.
        """
        return embeddings

    def save(self) -> None:
        """Save the FAISS index to disk.

        The embedding count is validated before the base-class ``save`` runs, so an inconsistent index is
        rejected before any saving takes place.

        Raises:
            ValueError: If the number of embeddings does not match the index's total number of entries.
        """
        import faiss

        if self.num_embeddings != self.index.ntotal:
            raise ValueError("number of embeddings does not match index.ntotal")
        super().save()
        if torch.cuda.is_available() and hasattr(faiss, "index_gpu_to_cpu"):
            # faiss.write_index serializes CPU indexes, so bring any GPU parts back first
            self.index = faiss.index_gpu_to_cpu(self.index)

        faiss.write_index(self.index, str(self.index_dir / "index.faiss"))

    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        """Add embeddings to the FAISS index.

        Args:
            index_batch (IndexBatch): The batch containing document indices and embeddings.
            output (BiEncoderOutput): The output from the bi-encoder module containing document embeddings.
        Raises:
            ValueError: If the document embeddings are not present in the output.
        """
        doc_embeddings = output.doc_embeddings
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")
        if doc_embeddings.scoring_mask is None:
            # without a scoring mask only the first embedding of every document is indexed
            doc_lengths = torch.ones(
                doc_embeddings.embeddings.shape[0], device=doc_embeddings.device, dtype=torch.int32
            )
            embeddings = doc_embeddings.embeddings[:, 0]
        else:
            # keep only masked-in embeddings; the per-document mask count is that document's length
            doc_lengths = doc_embeddings.scoring_mask.sum(dim=1)
            embeddings = doc_embeddings.embeddings[doc_embeddings.scoring_mask]
        doc_ids = index_batch.doc_ids
        embeddings = self.process_embeddings(embeddings)

        if embeddings.shape[0]:
            self.index.add(embeddings.float().cpu())

        self.num_embeddings += embeddings.shape[0]
        self.num_docs += len(doc_ids)

        self.doc_lengths.extend(doc_lengths.int().cpu().tolist())
        self.doc_ids.extend(doc_ids)
class FaissFlatIndexer(FaissIndexer):
    """Exact (brute-force) nearest neighbor indexer backed by a FAISS ``Flat`` index."""

    INDEX_FACTORY = "Flat"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissFlatIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Create a flat FAISS indexer.

        Args:
            index_dir (Path): Where the index files are written.
            index_config (FaissFlatIndexConfig): Settings for the flat index.
            module (BiEncoderModule): Bi-encoder that produces the embeddings to index.
            verbose (bool): Enable verbose FAISS output. Defaults to False.
        """
        super().__init__(index_dir, index_config, module, verbose)
        self.index_config: FaissFlatIndexConfig

    def to_gpu(self) -> None:
        """Does nothing; the flat index is not moved to the GPU."""

    def to_cpu(self) -> None:
        """Does nothing; the flat index is never moved off the CPU by this indexer."""
class _FaissTrainIndexer(FaissIndexer):
    """Base class for FAISS indexers that require training on embeddings before indexing.

    The first ``num_train_embeddings`` embeddings are buffered instead of being added to the index. Once the buffer
    is full, or when the index is saved, the index is trained on the buffered embeddings. Those embeddings are then
    added to the index and the buffer is released.
    """

    INDEX_FACTORY = ""  # class only acts as mixin

    def __init__(
        self,
        index_dir: Path,
        index_config: "_FaissTrainIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the _FaissTrainIndexer.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (_FaissTrainIndexConfig): Configuration for the FAISS index that requires training.
            module (BiEncoderModule): The BiEncoderModule to use for indexing.
            verbose (bool): Whether to enable verbose output. Defaults to False.
        Raises:
            ValueError: If num_train_embeddings is not set in the index configuration.
        """
        super().__init__(index_dir, index_config, module, verbose)
        if index_config.num_train_embeddings is None:
            raise ValueError("num_train_embeddings must be set")
        self.num_train_embeddings = index_config.num_train_embeddings

        # NaN-filled training buffer; rows are overwritten in arrival order, so unused rows stay NaN.
        # Set to None once the index has been trained.
        self._train_embeddings: torch.Tensor | None = torch.full(
            (self.num_train_embeddings, self.module.config.embedding_dim), torch.nan, dtype=torch.float32
        )

    def process_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Process embeddings before adding them to the FAISS index.

        Embeddings are first moved into the training buffer while it has room. The index is trained as soon as the
        buffer is full.

        Args:
            embeddings (torch.Tensor): The embeddings to process.
        Returns:
            torch.Tensor: The embeddings that were not consumed by the training buffer.
        """
        embeddings = self._grab_train_embeddings(embeddings)
        self._train()
        return embeddings

    def _grab_train_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Copy as many embeddings as still fit into the training buffer.

        Args:
            embeddings (torch.Tensor): Incoming embeddings.
        Returns:
            torch.Tensor: The embeddings that did not fit (all of them once training has happened).
        """
        if self._train_embeddings is None:
            return embeddings
        # while buffering, every counted embedding lives in the buffer, so num_embeddings is the next free row
        start = self.num_embeddings
        end = min(self.num_train_embeddings, start + embeddings.shape[0])
        length = end - start
        self._train_embeddings[start:end] = embeddings[:length]
        self.num_embeddings += length
        return embeddings[length:]

    def _train(self, force: bool = False) -> None:
        """Train the index on the buffered embeddings and add them to the index.

        Args:
            force (bool): Train even if the buffer is not full yet. Defaults to False.
        Raises:
            ValueError: If there are no embeddings to train on.
        """
        if self._train_embeddings is None:
            return
        if not force and self.num_embeddings < self.num_train_embeddings:
            return
        nan_rows = torch.isnan(self._train_embeddings).any(dim=1)
        if nan_rows.any():
            # buffer was not filled completely: drop the rows that never received an embedding
            warnings.warn("Corpus contains less tokens/documents than num_train_embeddings. Removing NaN embeddings.")
            self._train_embeddings = self._train_embeddings[~nan_rows]
        if self._train_embeddings.shape[0] == 0:
            raise ValueError("No embeddings available to train the index")
        self.index.train(self._train_embeddings)
        if torch.cuda.is_available():
            self.to_cpu()
        self.index.add(self._train_embeddings)
        self._train_embeddings = None
        self.set_verbosity(False)

    def save(self) -> None:
        """Train the index on whatever has been buffered (if still untrained), then save it."""
        if not self.index.is_trained:
            self._train(force=True)
        return super().save()
class FaissIVFIndexer(_FaissTrainIndexer):
    """FAISS IVF Indexer for approximate nearest neighbor search using FAISS with Inverted File System (IVF)."""

    INDEX_FACTORY = "IVF{num_centroids},Flat"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIVFIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the FaissIVFIndexer.

        If ``index_config.num_train_embeddings`` is unset (or 0), it is set on the config to
        ``num_centroids * 256``. 256 matches faiss' default ``max_points_per_centroid``.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (FaissIVFIndexConfig): Configuration for the FAISS IVF index.
            module (BiEncoderModule): The BiEncoderModule to use for indexing.
            verbose (bool): Whether to enable verbose output. Defaults to False.
        """
        # default faiss values
        # https://github.com/facebookresearch/faiss/blob/dafdff110489db7587b169a0afee8470f220d295/faiss/Clustering.h#L43
        max_points_per_centroid = 256
        index_config.num_train_embeddings = (
            index_config.num_train_embeddings or index_config.num_centroids * max_points_per_centroid
        )
        super().__init__(index_dir, index_config, module, verbose)

        import faiss

        ivf_index = faiss.extract_index_ivf(self.index)
        if hasattr(ivf_index, "quantizer"):
            quantizer = faiss.downcast_index(ivf_index.quantizer)
            # only HNSW-based coarse quantizers expose an efConstruction setting
            if hasattr(quantizer, "hnsw"):
                quantizer.hnsw.efConstruction = index_config.ef_construction

    def to_gpu(self) -> None:
        """Use GPUs for IVF clustering. Does nothing if faiss has no GPU support."""
        import faiss

        if not hasattr(faiss, "index_cpu_to_gpu"):
            # faiss build without GPU cloning functions: stay on the CPU
            return

        # clustering_index overrides the index used during clustering but leaves the quantizer on the gpu
        # https://faiss.ai/cpp_api/namespace/namespacefaiss_1_1gpu.html
        clustering_index = faiss.index_cpu_to_all_gpus(
            faiss.IndexFlat(self.module.config.embedding_dim, self.metric_type)
        )
        clustering_index.verbose = self.verbose
        index_ivf = faiss.extract_index_ivf(self.index)
        index_ivf.clustering_index = clustering_index

    def to_cpu(self) -> None:
        """Move the FAISS IVF index to CPU while keeping its coarse quantizer on GPU 0.

        Does nothing unless CUDA is available and faiss provides the GPU cloning functions.
        """
        import faiss

        if not (
            torch.cuda.is_available() and hasattr(faiss, "index_gpu_to_cpu") and hasattr(faiss, "index_cpu_to_gpu")
        ):
            return
        self.index = faiss.index_gpu_to_cpu(self.index)

        # https://gist.github.com/mdouze/334ad6a979ac3637f6d95e9091356d3e
        # move index to cpu but leave quantizer on gpu
        index_ivf = faiss.extract_index_ivf(self.index)
        gpu_quantizer = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, index_ivf.quantizer)
        index_ivf.quantizer = gpu_quantizer

    def set_verbosity(self, verbose: bool | None = None) -> None:
        """Set the verbosity of the IVF index and its quantizer.

        Args:
            verbose (bool | None): Whether to enable verbose output. If None, the indexer's own ``verbose``
                attribute is used. Defaults to None.
        """
        import faiss

        verbose = verbose if verbose is not None else self.verbose
        index = faiss.extract_index_ivf(self.index)
        for elem in (index, index.quantizer):
            setattr(elem, "verbose", verbose)
class FaissPQIndexer(_FaissTrainIndexer):
    """Approximate nearest neighbor indexer using FAISS Product Quantization (PQ).

    The factory string prefixes the PQ index with an OPQ transform using the same number of subquantizers.
    """

    INDEX_FACTORY = "OPQ{num_subquantizers},PQ{num_subquantizers}x{n_bits}"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissPQIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Create a PQ FAISS indexer.

        Args:
            index_dir (Path): Where the index files are written.
            index_config (FaissPQIndexConfig): Settings for the PQ index.
            module (BiEncoderModule): Bi-encoder that produces the embeddings to index.
            verbose (bool): Enable verbose FAISS output. Defaults to False.
        """
        super().__init__(index_dir, index_config, module, verbose)
        self.index_config: FaissPQIndexConfig

    def to_gpu(self) -> None:
        """Does nothing; the PQ index is not moved to the GPU."""

    def to_cpu(self) -> None:
        """Does nothing; the PQ index is never moved off the CPU by this indexer."""
class FaissIVFPQIndexer(FaissIVFIndexer):
    """Approximate nearest neighbor indexer combining an Inverted File (IVF) with Product Quantization (PQ).

    The factory string uses an OPQ transform, an HNSW32 coarse quantizer for the IVF centroids, and PQ codes.
    """

    INDEX_FACTORY = "OPQ{num_subquantizers},IVF{num_centroids}_HNSW32,PQ{num_subquantizers}x{n_bits}"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIVFPQIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Create an IVFPQ FAISS indexer and enable a direct map on its IVF part.

        Args:
            index_dir (Path): Where the index files are written.
            index_config (FaissIVFPQIndexConfig): Settings for the IVFPQ index.
            module (BiEncoderModule): Bi-encoder that produces the embeddings to index.
            verbose (bool): Enable verbose FAISS output. Defaults to False.
        """
        import faiss

        super().__init__(index_dir, index_config, module, verbose)
        self.index_config: FaissIVFPQIndexConfig

        faiss.extract_index_ivf(self.index).make_direct_map()

    def set_verbosity(self, verbose: bool | None = None) -> None:
        """Propagate the verbosity flag to the IVF index, its PQ, and its quantizer.

        Args:
            verbose (bool | None): Desired verbosity; None falls back to ``self.verbose``. Defaults to None.
        """
        super().set_verbosity(verbose)
        import faiss

        flag = self.verbose if verbose is None else verbose
        # NOTE(review): assumes self.index wraps the IVFPQ index as `.index` (pre-transform chain) — confirm
        ivf_pq = faiss.downcast_index(self.index.index)
        ivf_pq.pq.verbose = flag
        ivf_pq.quantizer.verbose = flag
class FaissIndexConfig(IndexConfig):
    """Base configuration shared by all FAISS indexers of Lightning IR.

    ``SUPPORTED_MODELS`` contains the model types of ``ColConfig`` and ``DprConfig``; ``indexer_class`` names the
    indexer this configuration builds.
    """

    SUPPORTED_MODELS = {ColConfig.model_type, DprConfig.model_type}
    indexer_class: Type[Indexer] = FaissIndexer
class FaissFlatIndexConfig(FaissIndexConfig):
    """Configuration that builds a :class:`FaissFlatIndexer` (exact search)."""

    indexer_class = FaissFlatIndexer
class _FaissTrainIndexConfig(FaissIndexConfig):
    """Shared configuration for FAISS indexers whose index must be trained before embeddings are added."""

    indexer_class = _FaissTrainIndexer

    def __init__(self, num_train_embeddings: int | None = None) -> None:
        """Store the size of the training sample.

        Args:
            num_train_embeddings (int | None): How many embeddings to train the index on. None leaves the value to
                be filled in later (e.g. by the indexer). Defaults to None.
        """
        super().__init__()
        self.num_train_embeddings = num_train_embeddings
class FaissIVFIndexConfig(_FaissTrainIndexConfig):
    """Configuration that builds a :class:`FaissIVFIndexer`."""

    indexer_class = FaissIVFIndexer

    def __init__(
        self,
        num_train_embeddings: int | None = None,
        num_centroids: int = 262144,
        ef_construction: int = 40,
    ) -> None:
        """Store the IVF settings.

        Args:
            num_train_embeddings (int | None): How many embeddings to train the index on. None leaves the value to
                be filled in later. Defaults to None.
            num_centroids (int): Number of IVF centroids (inverted lists). Defaults to 262144.
            ef_construction (int): ``efConstruction`` applied to an HNSW coarse quantizer, if one is used.
                Defaults to 40.
        """
        super().__init__(num_train_embeddings=num_train_embeddings)
        self.num_centroids = num_centroids
        self.ef_construction = ef_construction
class FaissPQIndexConfig(_FaissTrainIndexConfig):
    """Configuration that builds a :class:`FaissPQIndexer`."""

    indexer_class = FaissPQIndexer

    def __init__(self, num_train_embeddings: int | None = None, num_subquantizers: int = 16, n_bits: int = 8) -> None:
        """Store the product-quantization settings.

        Args:
            num_train_embeddings (int | None): How many embeddings to train the index on. None leaves the value to
                be filled in later. Defaults to None.
            num_subquantizers (int): Number of PQ subquantizers. Defaults to 16.
            n_bits (int): Bits per subquantizer code. Defaults to 8.
        """
        super().__init__(num_train_embeddings=num_train_embeddings)
        self.num_subquantizers = num_subquantizers
        self.n_bits = n_bits
class FaissIVFPQIndexConfig(FaissIVFIndexConfig):
    """Configuration that builds a :class:`FaissIVFPQIndexer`."""

    indexer_class = FaissIVFPQIndexer

    def __init__(
        self,
        num_train_embeddings: int | None = None,
        num_centroids: int = 262144,
        ef_construction: int = 40,
        num_subquantizers: int = 16,
        n_bits: int = 8,
    ) -> None:
        """Store the IVF and product-quantization settings.

        Args:
            num_train_embeddings (int | None): How many embeddings to train the index on. None leaves the value to
                be filled in later. Defaults to None.
            num_centroids (int): Number of IVF centroids (inverted lists). Defaults to 262144.
            ef_construction (int): ``efConstruction`` applied to an HNSW coarse quantizer, if one is used.
                Defaults to 40.
            num_subquantizers (int): Number of PQ subquantizers. Defaults to 16.
            n_bits (int): Bits per subquantizer code. Defaults to 8.
        """
        super().__init__(
            num_train_embeddings=num_train_embeddings,
            num_centroids=num_centroids,
            ef_construction=ef_construction,
        )
        self.num_subquantizers = num_subquantizers
        self.n_bits = n_bits