1"""FAISS Indexer for Lightning IR Framework"""
2
3import warnings
4from pathlib import Path
5from typing import Type
6
7import torch
8
9from ...bi_encoder import BiEncoderModule, BiEncoderOutput
10from ...data import IndexBatch
11from ...models import ColConfig, DprConfig
12from ..base import IndexConfig, Indexer
13
14
class FaissIndexer(Indexer):
    """Base class for FAISS indexers in the Lightning IR framework.

    Subclasses set ``INDEX_FACTORY`` to a FAISS index factory string; its ``{placeholders}`` are filled from the
    index configuration's ``to_dict()`` values.
    """

    INDEX_FACTORY: str

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the FaissIndexer.

        Builds the FAISS index from ``INDEX_FACTORY`` and, if CUDA is available, calls :meth:`to_gpu`.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (FaissIndexConfig): Configuration for the FAISS index.
            module (BiEncoderModule): The BiEncoderModule to use for indexing.
            verbose (bool): Whether to enable verbose output. Defaults to False.
        Raises:
            ValueError: If the similarity function is not supported.
        """
        super().__init__(index_dir, index_config, module, verbose)
        import faiss

        similarity_function = self.module.config.similarity_function
        if similarity_function in ("cosine", "dot"):
            self.metric_type = faiss.METRIC_INNER_PRODUCT
        else:
            raise ValueError(f"similarity_function {similarity_function} unknown")

        index_factory = self.INDEX_FACTORY.format(**index_config.to_dict())
        if similarity_function == "cosine":
            # L2-normalizing vectors first makes the inner-product metric equal to cosine similarity
            index_factory = "L2norm," + index_factory
        self.index = faiss.index_factory(self.module.config.embedding_dim, index_factory, self.metric_type)

        self.set_verbosity()

        if torch.cuda.is_available():
            self.to_gpu()

    def to_gpu(self) -> None:
        """Move the FAISS index to GPU. No-op in the base class; subclasses may override it."""
        pass

    def to_cpu(self) -> None:
        """Move the FAISS index to CPU. No-op in the base class; subclasses may override it."""
        pass

    def set_verbosity(self, verbose: bool | None = None) -> None:
        """Set the verbosity of the FAISS index.

        Args:
            verbose (bool | None): Whether to enable verbose output. If None, the indexer's own ``verbose`` setting
                is used. Defaults to None.
        """
        self.index.verbose = self.verbose if verbose is None else verbose

    def process_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Process embeddings before adding them to the FAISS index.

        The base implementation returns the embeddings unchanged.

        Args:
            embeddings (torch.Tensor): The embeddings to process.
        Returns:
            torch.Tensor: The processed embeddings.
        """
        return embeddings

    def save(self) -> None:
        """Save the FAISS index to disk.

        The embedding count is validated before delegating to the base class's ``save``, so that step is skipped
        entirely for an inconsistent index.

        Raises:
            ValueError: If the number of embeddings does not match the index's total number of entries.
        """
        import faiss

        if self.num_embeddings != self.index.ntotal:
            raise ValueError("number of embeddings does not match index.ntotal")
        super().save()
        if torch.cuda.is_available() and hasattr(faiss, "index_gpu_to_cpu"):
            # subclasses may have moved (parts of) the index to the GPU; convert back before writing
            self.index = faiss.index_gpu_to_cpu(self.index)

        faiss.write_index(self.index, str(self.index_dir / "index.faiss"))

    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        """Add the document embeddings of a batch to the FAISS index.

        Args:
            index_batch (IndexBatch): The indexed batch; only its ``doc_ids`` are used here.
            output (BiEncoderOutput): The output from the bi-encoder module containing document embeddings.
        Raises:
            ValueError: If the document embeddings are not present in the output.
        """
        doc_embeddings = output.doc_embeddings
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")
        if doc_embeddings.scoring_mask is None:
            # no scoring mask: only the first embedding of every document is indexed (one vector per doc)
            doc_lengths = torch.ones(
                doc_embeddings.embeddings.shape[0], device=doc_embeddings.device, dtype=torch.int32
            )
            embeddings = doc_embeddings.embeddings[:, 0]
        else:
            # scoring mask: index every masked-in embedding and record how many each document contributed
            doc_lengths = doc_embeddings.scoring_mask.sum(dim=1)
            embeddings = doc_embeddings.embeddings[doc_embeddings.scoring_mask]
        doc_ids = index_batch.doc_ids
        # subclasses may consume part of the embeddings here (e.g. buffering them for training)
        embeddings = self.process_embeddings(embeddings)

        if embeddings.shape[0]:
            self.index.add(embeddings.float().cpu())

        self.num_embeddings += embeddings.shape[0]
        self.num_docs += len(doc_ids)

        self.doc_lengths.extend(doc_lengths.int().cpu().tolist())
        self.doc_ids.extend(doc_ids)


class FaissFlatIndexer(FaissIndexer):
    """FAISS indexer backed by a ``Flat`` index, i.e. exact (exhaustive) nearest neighbor search."""

    INDEX_FACTORY = "Flat"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissFlatIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Create a flat FAISS indexer.

        Args:
            index_dir (Path): Directory the index is written to.
            index_config (FaissFlatIndexConfig): Configuration of the flat FAISS index.
            module (BiEncoderModule): Bi-encoder module whose document embeddings are indexed.
            verbose (bool): Enable verbose FAISS output. Defaults to False.
        """
        super().__init__(index_dir, index_config, module, verbose)
        # narrow the declared type of the configuration for this subclass
        self.index_config: FaissFlatIndexConfig

    def to_gpu(self) -> None:
        """Intentionally a no-op: this indexer does not move the flat index to the GPU."""

    def to_cpu(self) -> None:
        """Intentionally a no-op: this indexer does not move the flat index to the CPU."""


class _FaissTrainIndexer(FaissIndexer):
    """Base class for FAISS indexers that require training on embeddings before indexing.

    The first ``num_train_embeddings`` embeddings are buffered. Once the buffer is full (or, at the latest, when the
    index is saved), the index is trained on the buffered embeddings, which are then added to the index as well.
    """

    INDEX_FACTORY = ""  # class only acts as mixin

    def __init__(
        self,
        index_dir: Path,
        index_config: "_FaissTrainIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the _FaissTrainIndexer.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (_FaissTrainIndexConfig): Configuration for the FAISS index that requires training.
            module (BiEncoderModule): The BiEncoderModule to use for indexing.
            verbose (bool): Whether to enable verbose output. Defaults to False.
        Raises:
            ValueError: If num_train_embeddings is not set in the index configuration.
        """
        super().__init__(index_dir, index_config, module, verbose)
        if index_config.num_train_embeddings is None:
            raise ValueError("num_train_embeddings must be set")
        self.num_train_embeddings = index_config.num_train_embeddings

        # NaN-initialized buffer: rows that are still NaN at training time were never filled and get dropped
        self._train_embeddings: torch.Tensor | None = torch.full(
            (self.num_train_embeddings, self.module.config.embedding_dim), torch.nan, dtype=torch.float32
        )

    def process_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Process embeddings before adding them to the FAISS index.

        Embeddings are first moved into the training buffer (while it is not full); the index is trained as soon
        as the buffer is full.

        Args:
            embeddings (torch.Tensor): The embeddings to process.
        Returns:
            torch.Tensor: The embeddings that were not consumed by the training buffer.
        """
        embeddings = self._grab_train_embeddings(embeddings)
        self._train()
        return embeddings

    def _grab_train_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Copy embeddings into the training buffer until it is full; return the ones that did not fit."""
        if self._train_embeddings is None:
            # already trained: nothing is buffered any more
            return embeddings
        start = self.num_embeddings
        end = min(self.num_train_embeddings, start + embeddings.shape[0])
        length = end - start
        self._train_embeddings[start:end] = embeddings[:length]
        # buffered embeddings are counted now; _train adds them to the index
        self.num_embeddings += length
        return embeddings[length:]

    def _train(self, force: bool = False) -> None:
        """Train the index on the buffered embeddings and add them to the index.

        Args:
            force (bool): Train even if the buffer is not full yet. Defaults to False.
        """
        if self._train_embeddings is None:
            return
        if not force and self.num_embeddings < self.num_train_embeddings:
            return
        nan_rows = torch.isnan(self._train_embeddings).any(dim=1)
        if nan_rows.any():
            warnings.warn("Corpus contains less tokens/documents than num_train_embeddings. Removing NaN embeddings.")
            self._train_embeddings = self._train_embeddings[~nan_rows]
        self.index.train(self._train_embeddings)
        if torch.cuda.is_available():
            self.to_cpu()
        self.index.add(self._train_embeddings)
        self._train_embeddings = None
        # turn verbose output off once training is done
        self.set_verbosity(False)

    def save(self) -> None:
        """Train the index on the buffered embeddings if that has not happened yet, then save it.

        Raises:
            ValueError: If the number of embeddings does not match the index's total number of entries.
        """
        if not self.index.is_trained:
            self._train(force=True)
        return super().save()


class FaissIVFIndexer(_FaissTrainIndexer):
    """FAISS IVF Indexer for approximate nearest neighbor search using FAISS with Inverted File System (IVF)."""

    INDEX_FACTORY = "IVF{num_centroids},Flat"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIVFIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the FaissIVFIndexer.

        If ``index_config.num_train_embeddings`` is unset (falsy), it is set in place on ``index_config`` to
        ``num_centroids * 256``.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (FaissIVFIndexConfig): Configuration for the FAISS IVF index.
            module (BiEncoderModule): The BiEncoderModule to use for indexing.
            verbose (bool): Whether to enable verbose output. Defaults to False.
        """
        # default faiss values
        # https://github.com/facebookresearch/faiss/blob/dafdff110489db7587b169a0afee8470f220d295/faiss/Clustering.h#L43
        max_points_per_centroid = 256
        index_config.num_train_embeddings = (
            index_config.num_train_embeddings or index_config.num_centroids * max_points_per_centroid
        )
        super().__init__(index_dir, index_config, module, verbose)

        import faiss

        ivf_index = faiss.extract_index_ivf(self.index)
        if hasattr(ivf_index, "quantizer"):
            quantizer = faiss.downcast_index(ivf_index.quantizer)
            # only HNSW-based quantizers (e.g. "IVF{n}_HNSW32") expose efConstruction
            if hasattr(quantizer, "hnsw"):
                quantizer.hnsw.efConstruction = index_config.ef_construction

    def to_gpu(self) -> None:
        """Use all GPUs for the clustering performed while training the IVF index.

        Does nothing when the installed FAISS build has no GPU support.
        """
        import faiss

        if not hasattr(faiss, "index_cpu_to_gpu"):
            # CPU-only FAISS build: the GPU cloning helpers cannot be used
            return

        # clustering_index overrides the index used during clustering but leaves the quantizer on the gpu
        # https://faiss.ai/cpp_api/namespace/namespacefaiss_1_1gpu.html
        clustering_index = faiss.index_cpu_to_all_gpus(
            faiss.IndexFlat(self.module.config.embedding_dim, self.metric_type)
        )
        clustering_index.verbose = self.verbose
        index_ivf = faiss.extract_index_ivf(self.index)
        index_ivf.clustering_index = clustering_index

    def to_cpu(self) -> None:
        """Move the FAISS IVF index to CPU while placing its quantizer on GPU 0.

        Does nothing unless CUDA is available and the installed FAISS build has GPU support.
        """
        import faiss

        if not (
            torch.cuda.is_available() and hasattr(faiss, "index_gpu_to_cpu") and hasattr(faiss, "index_cpu_to_gpu")
        ):
            return

        self.index = faiss.index_gpu_to_cpu(self.index)

        # https://gist.github.com/mdouze/334ad6a979ac3637f6d95e9091356d3e
        # move index to cpu but leave quantizer on gpu
        index_ivf = faiss.extract_index_ivf(self.index)
        quantizer = index_ivf.quantizer
        # NOTE(review): no Python reference to gpu_quantizer is kept after this method returns; confirm FAISS
        # ownership keeps the GPU quantizer alive, otherwise store it on self.
        gpu_quantizer = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, quantizer)
        index_ivf.quantizer = gpu_quantizer

    def set_verbosity(self, verbose: bool | None = None) -> None:
        """Set the verbosity of the IVF sub-index and its quantizer.

        Args:
            verbose (bool | None): Whether to enable verbose output. If None, the indexer's own ``verbose`` setting
                is used. Defaults to None.
        """
        import faiss

        verbose = verbose if verbose is not None else self.verbose
        index_ivf = faiss.extract_index_ivf(self.index)
        index_ivf.verbose = verbose
        index_ivf.quantizer.verbose = verbose


class FaissPQIndexer(_FaissTrainIndexer):
    """FAISS indexer using (optimized) Product Quantization for approximate nearest neighbor search."""

    INDEX_FACTORY = "OPQ{num_subquantizers},PQ{num_subquantizers}x{n_bits}"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissPQIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Create a product-quantization FAISS indexer.

        Args:
            index_dir (Path): Directory the index is written to.
            index_config (FaissPQIndexConfig): Configuration of the PQ FAISS index.
            module (BiEncoderModule): Bi-encoder module whose document embeddings are indexed.
            verbose (bool): Enable verbose FAISS output. Defaults to False.
        """
        super().__init__(index_dir, index_config, module, verbose)
        # narrow the declared type of the configuration for this subclass
        self.index_config: FaissPQIndexConfig

    def to_gpu(self) -> None:
        """Intentionally a no-op: this indexer does not move the PQ index to the GPU."""

    def to_cpu(self) -> None:
        """Intentionally a no-op: this indexer does not move the PQ index to the CPU."""


class FaissIVFPQIndexer(FaissIVFIndexer):
    """FAISS IVFPQ Indexer for approximate nearest neighbor search combining an Inverted File System (IVF) with an
    HNSW coarse quantizer and (optimized) Product Quantization (PQ)."""

    INDEX_FACTORY = "OPQ{num_subquantizers},IVF{num_centroids}_HNSW32,PQ{num_subquantizers}x{n_bits}"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIVFPQIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Create an IVFPQ FAISS indexer.

        Args:
            index_dir (Path): Directory the index is written to.
            index_config (FaissIVFPQIndexConfig): Configuration of the IVFPQ FAISS index.
            module (BiEncoderModule): Bi-encoder module whose document embeddings are indexed.
            verbose (bool): Enable verbose FAISS output. Defaults to False.
        """
        import faiss

        super().__init__(index_dir, index_config, module, verbose)
        # narrow the declared type of the configuration for this subclass
        self.index_config: FaissIVFPQIndexConfig

        # build a direct map on the IVF sub-index (presumably so entries can be looked up by id)
        faiss.extract_index_ivf(self.index).make_direct_map()

    def set_verbosity(self, verbose: bool | None = None) -> None:
        """Set the verbosity of the IVF index, its quantizer and its product quantizer.

        Args:
            verbose (bool | None): Whether to enable verbose output. If None, the indexer's own ``verbose`` setting
                is used. Defaults to None.
        """
        super().set_verbosity(verbose)
        import faiss

        flag = self.verbose if verbose is None else verbose
        # INDEX_FACTORY starts with an OPQ transform, so the IVFPQ index is wrapped and reached via `.index`
        ivf_pq = faiss.downcast_index(self.index.index)
        ivf_pq.pq.verbose = flag
        ivf_pq.quantizer.verbose = flag


class FaissIndexConfig(IndexConfig):
    """Shared configuration for all FAISS-based indexers in the Lightning IR framework."""

    # model types accepted by FAISS indexers
    SUPPORTED_MODELS = {ColConfig.model_type, DprConfig.model_type}
    # indexer class instantiated for this configuration
    indexer_class: Type[Indexer] = FaissIndexer


class FaissFlatIndexConfig(FaissIndexConfig):
    """Configuration for exact (flat) FAISS indexers in the Lightning IR framework."""

    # indexer class instantiated for this configuration
    indexer_class = FaissFlatIndexer


class _FaissTrainIndexConfig(FaissIndexConfig):
    """Shared configuration for FAISS indexers whose index has to be trained before embeddings are added."""

    indexer_class = _FaissTrainIndexer

    def __init__(self, num_train_embeddings: int | None = None) -> None:
        """Create a trainable-index configuration.

        Args:
            num_train_embeddings (int | None): How many embeddings to train the index on. If None, the value has to
                be provided later. Defaults to None.
        """
        super().__init__()
        self.num_train_embeddings = num_train_embeddings


class FaissIVFIndexConfig(_FaissTrainIndexConfig):
    """Configuration for FAISS inverted-file (IVF) indexers in the Lightning IR framework."""

    indexer_class = FaissIVFIndexer

    def __init__(
        self,
        num_train_embeddings: int | None = None,
        num_centroids: int = 262144,
        ef_construction: int = 40,
    ) -> None:
        """Create an IVF index configuration.

        Args:
            num_train_embeddings (int | None): How many embeddings to train the index on. If None, the value is
                determined later. Defaults to None.
            num_centroids (int): Number of IVF centroids. Defaults to 262144.
            ef_construction (int): Size of the dynamic candidate list used while building an HNSW quantizer, if the
                index uses one. Defaults to 40.
        """
        super().__init__(num_train_embeddings=num_train_embeddings)
        self.num_centroids = num_centroids
        self.ef_construction = ef_construction


class FaissPQIndexConfig(_FaissTrainIndexConfig):
    """Configuration for FAISS product-quantization (PQ) indexers in the Lightning IR framework."""

    indexer_class = FaissPQIndexer

    def __init__(
        self,
        num_train_embeddings: int | None = None,
        num_subquantizers: int = 16,
        n_bits: int = 8,
    ) -> None:
        """Create a PQ index configuration.

        Args:
            num_train_embeddings (int | None): How many embeddings to train the index on. If None, the value has to
                be provided later. Defaults to None.
            num_subquantizers (int): Number of PQ sub-quantizers. Defaults to 16.
            n_bits (int): Number of bits used by each PQ sub-quantizer. Defaults to 8.
        """
        super().__init__(num_train_embeddings=num_train_embeddings)
        self.num_subquantizers = num_subquantizers
        self.n_bits = n_bits


class FaissIVFPQIndexConfig(FaissIVFIndexConfig):
    """Configuration for FAISS IVF + product-quantization (IVFPQ) indexers in the Lightning IR framework."""

    indexer_class = FaissIVFPQIndexer

    def __init__(
        self,
        num_train_embeddings: int | None = None,
        num_centroids: int = 262144,
        ef_construction: int = 40,
        num_subquantizers: int = 16,
        n_bits: int = 8,
    ) -> None:
        """Create an IVFPQ index configuration.

        Args:
            num_train_embeddings (int | None): How many embeddings to train the index on. If None, the value is
                determined later. Defaults to None.
            num_centroids (int): Number of IVF centroids. Defaults to 262144.
            ef_construction (int): Size of the dynamic candidate list used while building the HNSW quantizer.
                Defaults to 40.
            num_subquantizers (int): Number of PQ sub-quantizers. Defaults to 16.
            n_bits (int): Number of bits used by each PQ sub-quantizer. Defaults to 8.
        """
        super().__init__(
            num_train_embeddings=num_train_embeddings,
            num_centroids=num_centroids,
            ef_construction=ef_construction,
        )
        self.num_subquantizers = num_subquantizers
        self.n_bits = n_bits