1"""FAISS Indexer for Lightning IR Framework"""
2
3import warnings
4from pathlib import Path
5
6import torch
7
8from ...bi_encoder import BiEncoderModule, BiEncoderOutput
9from ...data import IndexBatch
10from ...models import ColConfig, DprConfig
11from ..base import IndexConfig, Indexer
12
13
class FaissIndexer(Indexer):
    """Base class for FAISS indexers in the Lightning IR framework."""

    INDEX_FACTORY: str

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the FaissIndexer.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (FaissIndexConfig): Configuration for the FAISS index.
            module (BiEncoderModule): The BiEncoderModule to use for indexing.
            verbose (bool): Whether to enable verbose output. Defaults to False.
        Raises:
            ValueError: If the similarity function is not supported.
        """
        super().__init__(index_dir, index_config, module, verbose)
        import faiss

        similarity_function = self.module.config.similarity_function
        if similarity_function in ("cosine", "dot"):
            self.metric_type = faiss.METRIC_INNER_PRODUCT
        else:
            raise ValueError(f"similarity_function {similarity_function} unknown")

        index_factory = self.INDEX_FACTORY.format(**index_config.to_dict())
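        # for cosine similarity, prepend an L2-normalization transform to the factory string so that
        # maximum inner product search over the normalized vectors is equivalent to cosine similarity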
        if similarity_function == "cosine":
            index_factory = "L2norm," + index_factory
        self.index = faiss.index_factory(self.module.config.embedding_dim, index_factory, self.metric_type)

        self.set_verbosity()

        if torch.cuda.is_available():
            self.to_gpu()

    def to_gpu(self) -> None:
        """Move the FAISS index to GPU."""
        pass

    def to_cpu(self) -> None:
        """Move the FAISS index to CPU."""
        pass

    def set_verbosity(self, verbose: bool | None = None) -> None:
        """Set the verbosity of the FAISS index.

        Args:
            verbose (bool | None): Whether to enable verbose output. If None, falls back to the indexer's
                verbose setting. Defaults to None.
        """
        self.index.verbose = self.verbose if verbose is None else verbose

    def process_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Process embeddings before adding them to the FAISS index.

        Args:
            embeddings (torch.Tensor): The embeddings to process.
        Returns:
            torch.Tensor: The processed embeddings.
        """
        return embeddings

    def save(self) -> None:
        """Save the FAISS index to disk.

        Raises:
            ValueError: If the number of embeddings does not match the index's total number of entries.
        """
        super().save()
        import faiss

        if self.num_embeddings != self.index.ntotal:
            raise ValueError("number of embeddings does not match index.ntotal")
        if torch.cuda.is_available() and hasattr(faiss, "index_gpu_to_cpu"):
            self.index = faiss.index_gpu_to_cpu(self.index)

        faiss.write_index(self.index, str(self.index_dir / "index.faiss"))

    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        """Add embeddings to the FAISS index.

        Args:
            index_batch (IndexBatch): The batch containing the ids of the documents to index.
            output (BiEncoderOutput): The output from the bi-encoder module containing document embeddings.
        Raises:
            ValueError: If the document embeddings are not present in the output.
        """
        doc_embeddings = output.doc_embeddings
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")
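        # a missing scoring mask means a single embedding per document; otherwise every token embedding
        # selected by the mask is added to the index (multi-vector models)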
        if doc_embeddings.scoring_mask is None:
            doc_lengths = torch.ones(
                doc_embeddings.embeddings.shape[0], device=doc_embeddings.device, dtype=torch.int32
            )
            embeddings = doc_embeddings.embeddings[:, 0]
        else:
            doc_lengths = doc_embeddings.scoring_mask.sum(dim=1)
            embeddings = doc_embeddings.embeddings[doc_embeddings.scoring_mask]
        doc_ids = index_batch.doc_ids
        embeddings = self.process_embeddings(embeddings)

        if embeddings.shape[0]:
            self.index.add(embeddings.float().cpu())

        self.num_embeddings += embeddings.shape[0]
        self.num_docs += len(doc_ids)

        self.doc_lengths.extend(doc_lengths.int().cpu().tolist())
        self.doc_ids.extend(doc_ids)


class FaissFlatIndexer(FaissIndexer):
    """FAISS Flat Indexer for exact nearest neighbor search."""

    INDEX_FACTORY = "Flat"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissFlatIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the FaissFlatIndexer.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (FaissFlatIndexConfig): Configuration for the FAISS flat index.
            module (BiEncoderModule): The BiEncoderModule to use for indexing.
            verbose (bool): Whether to enable verbose output. Defaults to False.
        """
        super().__init__(index_dir, index_config, module, verbose)
        self.index_config: FaissFlatIndexConfig

    def to_gpu(self) -> None:
        """Move the FAISS flat index to GPU."""
        pass

    def to_cpu(self) -> None:
        """Move the FAISS flat index to CPU."""
        pass


class _FaissTrainIndexer(FaissIndexer):
    """Base class for FAISS indexers that require training on embeddings before indexing."""

    INDEX_FACTORY = ""  # class only acts as mixin

    def __init__(
        self,
        index_dir: Path,
        index_config: "_FaissTrainIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the _FaissTrainIndexer.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (_FaissTrainIndexConfig): Configuration for the FAISS index that requires training.
            module (BiEncoderModule): The BiEncoderModule to use for indexing.
            verbose (bool): Whether to enable verbose output. Defaults to False.
        Raises:
            ValueError: If num_train_embeddings is not set in the index configuration.
        """
        super().__init__(index_dir, index_config, module, verbose)
        if index_config.num_train_embeddings is None:
            raise ValueError("num_train_embeddings must be set")
        self.num_train_embeddings = index_config.num_train_embeddings

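        # buffer holding the embeddings used to train the index; rows are pre-filled with NaN and overwritten
        # by _grab_train_embeddings, so any rows still NaN at training time are dropped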
        self._train_embeddings: torch.Tensor | None = torch.full(
            (self.num_train_embeddings, self.module.config.embedding_dim), torch.nan, dtype=torch.float32
        )

    def process_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Process embeddings before adding them to the FAISS index.

        Args:
            embeddings (torch.Tensor): The embeddings to process.
        Returns:
            torch.Tensor: The processed embeddings.
        """
        embeddings = self._grab_train_embeddings(embeddings)
        self._train()
        return embeddings

    def _grab_train_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        if self._train_embeddings is not None:
            # save training embeddings until num_train_embeddings is reached
            # if num_train_embeddings overflows, save the remaining embeddings
            start = self.num_embeddings
            end = min(self.num_train_embeddings, start + embeddings.shape[0])
            length = end - start
            self._train_embeddings[start:end] = embeddings[:length]
            self.num_embeddings += length
            embeddings = embeddings[length:]
        return embeddings

    def _train(self, force: bool = False):
        if self._train_embeddings is None:
            return
        if not force and self.num_embeddings < self.num_train_embeddings:
            return
        if torch.isnan(self._train_embeddings).any():
            warnings.warn(
                "Corpus contains fewer tokens/documents than num_train_embeddings. Removing NaN embeddings.",
                stacklevel=2,
            )
            self._train_embeddings = self._train_embeddings[~torch.isnan(self._train_embeddings).any(dim=1)]
        self.index.train(self._train_embeddings)
        if torch.cuda.is_available():
            self.to_cpu()
        self.index.add(self._train_embeddings)
        self._train_embeddings = None
        self.set_verbosity(False)

    def save(self) -> None:
        if not self.index.is_trained:
            self._train(force=True)
        return super().save()


class FaissIVFIndexer(_FaissTrainIndexer):
    """FAISS IVF Indexer for approximate nearest neighbor search using an Inverted File System (IVF)."""

    INDEX_FACTORY = "IVF{num_centroids},Flat"
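    # factory string: an inverted file (IVF) index with {num_centroids} coarse clusters over uncompressed vectors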

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIVFIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the FaissIVFIndexer.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (FaissIVFIndexConfig): Configuration for the FAISS IVF index.
            module (BiEncoderModule): The BiEncoderModule to use for indexing.
            verbose (bool): Whether to enable verbose output. Defaults to False.
        """
        # default faiss values
        # https://github.com/facebookresearch/faiss/blob/dafdff110489db7587b169a0afee8470f220d295/faiss/Clustering.h#L43
        max_points_per_centroid = 256
        index_config.num_train_embeddings = (
            index_config.num_train_embeddings or index_config.num_centroids * max_points_per_centroid
        )
        super().__init__(index_dir, index_config, module, verbose)

        import faiss

        ivf_index = faiss.extract_index_ivf(self.index)
        if hasattr(ivf_index, "quantizer"):
            quantizer = ivf_index.quantizer
            if hasattr(faiss.downcast_index(quantizer), "hnsw"):
                downcasted_quantizer = faiss.downcast_index(quantizer)
                downcasted_quantizer.hnsw.efConstruction = index_config.ef_construction

    def to_gpu(self) -> None:
        """Move the FAISS IVF index to GPU."""
        import faiss

        # clustering_index overrides the index used during clustering but leaves the quantizer on the gpu
        # https://faiss.ai/cpp_api/namespace/namespacefaiss_1_1gpu.html
        if faiss.get_num_gpus() == 0:
            return
        clustering_index = faiss.index_cpu_to_all_gpus(
            faiss.IndexFlat(self.module.config.embedding_dim, self.metric_type)
        )
        clustering_index.verbose = self.verbose
        index_ivf = faiss.extract_index_ivf(self.index)
        index_ivf.clustering_index = clustering_index

    def to_cpu(self) -> None:
        """Move the FAISS IVF index to CPU."""
        import faiss

        if faiss.get_num_gpus() == 0:
            return
        self.index = faiss.index_gpu_to_cpu(self.index)

        # https://gist.github.com/mdouze/334ad6a979ac3637f6d95e9091356d3e
        # move index to cpu but leave quantizer on gpu
        index_ivf = faiss.extract_index_ivf(self.index)
        quantizer = index_ivf.quantizer
        gpu_quantizer = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, quantizer)
        index_ivf.quantizer = gpu_quantizer

    def set_verbosity(self, verbose: bool | None = None) -> None:
        """Set the verbosity of the FAISS IVF index.

        Args:
            verbose (bool | None): Whether to enable verbose output. Defaults to None.
        """
        import faiss

        verbose = verbose if verbose is not None else self.verbose
        index = faiss.extract_index_ivf(self.index)
        for elem in (index, index.quantizer):
            elem.verbose = verbose


class FaissPQIndexer(_FaissTrainIndexer):
    """FAISS PQ Indexer for approximate nearest neighbor search using Product Quantization (PQ)."""

    INDEX_FACTORY = "OPQ{num_subquantizers},PQ{num_subquantizers}x{n_bits}"
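    # factory string: an OPQ rotation (learned to make vectors easier to quantize) followed by product quantization
    # with {num_subquantizers} sub-vectors encoded with {n_bits} bits each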

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissPQIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the FaissPQIndexer.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (FaissPQIndexConfig): Configuration for the FAISS PQ index.
            module (BiEncoderModule): The BiEncoderModule to use for indexing.
            verbose (bool): Whether to enable verbose output. Defaults to False.
        """
        super().__init__(index_dir, index_config, module, verbose)
        self.index_config: FaissPQIndexConfig

    def to_gpu(self) -> None:
        """Move the FAISS PQ index to GPU."""
        pass

    def to_cpu(self) -> None:
        """Move the FAISS PQ index to CPU."""
        pass


class FaissIVFPQIndexer(FaissIVFIndexer):
    """FAISS IVFPQ Indexer for approximate nearest neighbor search using an Inverted File System (IVF) with Product
    Quantization (PQ)."""

    INDEX_FACTORY = "OPQ{num_subquantizers},IVF{num_centroids}_HNSW32,PQ{num_subquantizers}x{n_bits}"
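    # factory string: an OPQ rotation, an IVF index with {num_centroids} clusters whose coarse quantizer is an HNSW
    # graph (32 neighbors per node), and product quantization with {num_subquantizers} sub-vectors of {n_bits} bits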

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIVFPQIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the FaissIVFPQIndexer.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (FaissIVFPQIndexConfig): Configuration for the FAISS IVFPQ index.
            module (BiEncoderModule): The BiEncoderModule to use for indexing.
            verbose (bool): Whether to enable verbose output. Defaults to False.
        """
        import faiss

        super().__init__(index_dir, index_config, module, verbose)
        self.index_config: FaissIVFPQIndexConfig

        index_ivf = faiss.extract_index_ivf(self.index)
        index_ivf.make_direct_map()

    def set_verbosity(self, verbose: bool | None = None) -> None:
        """Set the verbosity of the FAISS IVFPQ index.

        Args:
            verbose (bool | None): Whether to enable verbose output. Defaults to None.
        """
        super().set_verbosity(verbose)
        import faiss

        verbose = verbose if verbose is not None else self.verbose
        index_ivf_pq = faiss.downcast_index(self.index.index)
        for elem in (
            index_ivf_pq.pq,
            index_ivf_pq.quantizer,
        ):
            elem.verbose = verbose


class FaissIndexConfig(IndexConfig):
    """Configuration class for FAISS indexers in the Lightning IR framework."""

    SUPPORTED_MODELS = {ColConfig.model_type, DprConfig.model_type}
    indexer_class: type[Indexer] = FaissIndexer


class FaissFlatIndexConfig(FaissIndexConfig):
    """Configuration class for FAISS flat indexers in the Lightning IR framework."""

    indexer_class = FaissFlatIndexer


class _FaissTrainIndexConfig(FaissIndexConfig):
    """Base configuration class for FAISS indexers that require training on embeddings before indexing."""

    indexer_class = _FaissTrainIndexer

    def __init__(self, num_train_embeddings: int | None = None) -> None:
        """Initialize the _FaissTrainIndexConfig.

        Args:
            num_train_embeddings (int | None): Number of embeddings to use for training the index. If None, it will
                be set later. Defaults to None.
        """
        super().__init__()
        self.num_train_embeddings = num_train_embeddings


class FaissIVFIndexConfig(_FaissTrainIndexConfig):
    """Configuration class for FAISS IVF indexers in the Lightning IR framework."""

    indexer_class = FaissIVFIndexer

    def __init__(
        self,
        num_train_embeddings: int | None = None,
        num_centroids: int = 262144,
        ef_construction: int = 40,
    ) -> None:
        """Initialize the FaissIVFIndexConfig.

        Args:
            num_train_embeddings (int | None): Number of embeddings to use for training the index. If None, it will be
                set later. Defaults to None.
            num_centroids (int): Number of centroids for the IVF index. Defaults to 262144.
            ef_construction (int): The size of the dynamic candidate list used when constructing an HNSW-based
                coarse quantizer. Defaults to 40.
        """
        super().__init__(num_train_embeddings)
        self.num_centroids = num_centroids
        self.ef_construction = ef_construction


class FaissPQIndexConfig(_FaissTrainIndexConfig):
    """Configuration class for FAISS PQ indexers in the Lightning IR framework."""

    indexer_class = FaissPQIndexer

    def __init__(self, num_train_embeddings: int | None = None, num_subquantizers: int = 16, n_bits: int = 8) -> None:
        """Initialize the FaissPQIndexConfig.

        Args:
            num_train_embeddings (int | None): Number of embeddings to use for training the index. If None, it will
                be set later. Defaults to None.
            num_subquantizers (int): Number of subquantizers for the PQ index. Defaults to 16.
            n_bits (int): Number of bits per subquantizer code. Defaults to 8.
        """
        super().__init__(num_train_embeddings)
        self.num_subquantizers = num_subquantizers
        self.n_bits = n_bits


class FaissIVFPQIndexConfig(FaissIVFIndexConfig):
    """Configuration class for FAISS IVFPQ indexers in the Lightning IR framework."""

    indexer_class = FaissIVFPQIndexer

    def __init__(
        self,
        num_train_embeddings: int | None = None,
        num_centroids: int = 262144,
        ef_construction: int = 40,
        num_subquantizers: int = 16,
        n_bits: int = 8,
    ) -> None:
        """Initialize the FaissIVFPQIndexConfig.

        Args:
            num_train_embeddings (int | None): Number of embeddings to use for training the index. If None, it will
                be set later. Defaults to None.
            num_centroids (int): Number of centroids for the IVF index. Defaults to 262144.
            ef_construction (int): The size of the dynamic candidate list used when constructing the HNSW-based
                coarse quantizer. Defaults to 40.
            num_subquantizers (int): Number of subquantizers for the PQ index. Defaults to 16.
            n_bits (int): Number of bits per subquantizer code. Defaults to 8.
        """
        super().__init__(num_train_embeddings, num_centroids, ef_construction)
        self.num_subquantizers = num_subquantizers
        self.n_bits = n_bits
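

# Usage sketch (illustrative only, not part of the public API): building and saving a flat index.
# Assumes `module` is an already-instantiated BiEncoderModule and `batches` yields
# (IndexBatch, BiEncoderOutput) pairs produced by encoding the document collection with that module.
#
#     indexer = FaissFlatIndexer(Path("path/to/index"), FaissFlatIndexConfig(), module, verbose=True)
#     for index_batch, output in batches:
#         indexer.add(index_batch, output)
#     indexer.save()  # writes "index.faiss" into the index directory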