1"""FAISS Indexer for Lightning IR Framework"""
2
3import warnings
4from pathlib import Path
5from typing import Type
6
7import torch
8
9from ...bi_encoder import BiEncoderModule, BiEncoderOutput
10from ...data import IndexBatch
11from ...models import ColConfig, DprConfig
12from ..base import IndexConfig, Indexer
13
14
[docs]
class FaissIndexer(Indexer):
    """Shared implementation for all FAISS-backed indexers of the Lightning IR framework.

    Subclasses set ``INDEX_FACTORY`` to a FAISS index factory string. The string may contain ``str.format``
    placeholders, which are filled from ``index_config.to_dict()`` when the index is created.
    """

    INDEX_FACTORY: str

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Create the underlying FAISS index from the class' factory string.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (FaissIndexConfig): Configuration for the FAISS index.
            module (BiEncoderModule): The BiEncoderModule to use for indexing.
            verbose (bool): Whether to enable verbose output. Defaults to False.
        Raises:
            ValueError: If the similarity function is not supported.
        """
        super().__init__(index_dir, index_config, module, verbose)
        import faiss

        model_config = self.module.config
        similarity_function = model_config.similarity_function
        if similarity_function not in ("cosine", "dot"):
            raise ValueError(f"similarity_function {similarity_function} unknown")
        # both supported similarity functions are searched with the inner product metric
        self.metric_type = faiss.METRIC_INNER_PRODUCT

        factory_string = self.INDEX_FACTORY.format(**index_config.to_dict())
        if similarity_function == "cosine":
            # cosine similarity: prepend an L2 normalization step to the index pipeline
            factory_string = f"L2norm,{factory_string}"
        self.index = faiss.index_factory(model_config.embedding_dim, factory_string, self.metric_type)

        self.set_verbosity()

        if torch.cuda.is_available():
            self.to_gpu()

    def to_gpu(self) -> None:
        """Hook for moving the FAISS index to GPU. The base implementation does nothing."""

    def to_cpu(self) -> None:
        """Hook for moving the FAISS index to CPU. The base implementation does nothing."""

    def set_verbosity(self, verbose: bool | None = None) -> None:
        """Set the ``verbose`` flag of the FAISS index.

        Args:
            verbose (bool | None): Whether to enable verbose output. If None, the indexer's own ``verbose``
                setting is used. Defaults to None.
        """
        if verbose is None:
            verbose = self.verbose
        self.index.verbose = verbose

    def process_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Hook that is applied to the embeddings right before they are added to the index.

        The base implementation hands the embeddings back unchanged.

        Args:
            embeddings (torch.Tensor): The embeddings to process.
        Returns:
            torch.Tensor: The processed embeddings.
        """
        return embeddings

    def save(self) -> None:
        """Write the FAISS index to ``index.faiss`` inside the index directory.

        Raises:
            ValueError: If the number of embeddings does not match the index's total number of entries.
        """
        super().save()
        import faiss

        if self.num_embeddings != self.index.ntotal:
            raise ValueError("number of embeddings does not match index.ntotal")

        # the index is converted with index_gpu_to_cpu before writing whenever CUDA is available
        if torch.cuda.is_available() and hasattr(faiss, "index_gpu_to_cpu"):
            self.index = faiss.index_gpu_to_cpu(self.index)

        index_path = self.index_dir / "index.faiss"
        faiss.write_index(self.index, str(index_path))

    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        """Add the document embeddings of a batch to the FAISS index and update the bookkeeping.

        Args:
            index_batch (IndexBatch): The batch containing the document ids.
            output (BiEncoderOutput): The output from the bi-encoder module containing document embeddings.
        Raises:
            ValueError: If the document embeddings are not present in the output.
        """
        doc_embeddings = output.doc_embeddings
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")

        if doc_embeddings.scoring_mask is not None:
            # keep only the embeddings selected by the mask; the per-document length is the mask row sum
            doc_lengths = doc_embeddings.scoring_mask.sum(dim=1)
            embeddings = doc_embeddings.embeddings[doc_embeddings.scoring_mask]
        else:
            # without a mask every document contributes exactly one embedding (the first one)
            num_docs_in_batch = doc_embeddings.embeddings.shape[0]
            doc_lengths = torch.ones(num_docs_in_batch, device=doc_embeddings.device, dtype=torch.int32)
            embeddings = doc_embeddings.embeddings[:, 0]

        doc_ids = index_batch.doc_ids
        embeddings = self.process_embeddings(embeddings)

        num_new_embeddings = embeddings.shape[0]
        if num_new_embeddings:
            self.index.add(embeddings.float().cpu())

        self.num_embeddings += num_new_embeddings
        self.num_docs += len(doc_ids)

        self.doc_lengths.extend(doc_lengths.int().cpu().tolist())
        self.doc_ids.extend(doc_ids)
130
131
[docs]
class FaissFlatIndexer(FaissIndexer):
    """Indexer that builds a FAISS ``Flat`` index (exact nearest neighbor search)."""

    INDEX_FACTORY = "Flat"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissFlatIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Set up the flat indexer; all work is delegated to the parent constructor.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (FaissFlatIndexConfig): Configuration for the FAISS flat index.
            module (BiEncoderModule): The BiEncoderModule to use for indexing.
            verbose (bool): Whether to enable verbose output. Defaults to False.
        """
        super().__init__(index_dir, index_config, module, verbose)
        # annotation only: narrows the attribute type for static type checkers
        self.index_config: FaissFlatIndexConfig

    def to_gpu(self) -> None:
        """Move the FAISS flat index to GPU. Intentionally a no-op."""

    def to_cpu(self) -> None:
        """Move the FAISS flat index to CPU. Intentionally a no-op."""
162
163
class _FaissTrainIndexer(FaissIndexer):
    """Base class for FAISS indexers that require training on embeddings before indexing.

    The first ``num_train_embeddings`` embeddings are collected in a buffer instead of being added to the index
    directly. Once the buffer is full (or when saving is requested), the index is trained on the buffer and the
    buffered embeddings are then added to the index.
    """

    INDEX_FACTORY = ""  # class only acts as mixin

    def __init__(
        self,
        index_dir: Path,
        index_config: "_FaissTrainIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the _FaissTrainIndexer.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (_FaissTrainIndexConfig): Configuration for the FAISS index that requires training.
            module (BiEncoderModule): The BiEncoderModule to use for indexing.
            verbose (bool): Whether to enable verbose output. Defaults to False.
        Raises:
            ValueError: If num_train_embeddings is not set in the index configuration.
        """
        super().__init__(index_dir, index_config, module, verbose)
        if index_config.num_train_embeddings is None:
            raise ValueError("num_train_embeddings must be set")
        self.num_train_embeddings = index_config.num_train_embeddings

        # NaN marks rows that have not been filled yet; the buffer is set to None once training is done
        self._train_embeddings: torch.Tensor | None = torch.full(
            (self.num_train_embeddings, self.module.config.embedding_dim), torch.nan, dtype=torch.float32
        )

    def process_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Divert embeddings into the training buffer and train the index once the buffer is full.

        Args:
            embeddings (torch.Tensor): The embeddings to process.
        Returns:
            torch.Tensor: The embeddings that were not consumed by the training buffer and still need to be
                added to the index by the caller.
        """
        embeddings = self._grab_train_embeddings(embeddings)
        self._train()
        return embeddings

    def _grab_train_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Copy as many embeddings as still fit into the training buffer.

        ``self.num_embeddings`` is advanced by the number of buffered embeddings.

        Args:
            embeddings (torch.Tensor): The incoming embeddings.
        Returns:
            torch.Tensor: The embeddings that did not fit into the buffer (all of them once training is done).
        """
        if self._train_embeddings is None:
            return embeddings
        # save training embeddings until num_train_embeddings is reached
        # if num_train_embeddings overflows, hand the remaining embeddings back to the caller
        start = self.num_embeddings
        end = min(self.num_train_embeddings, start + embeddings.shape[0])
        length = end - start
        self._train_embeddings[start:end] = embeddings[:length]
        self.num_embeddings += length
        return embeddings[length:]

    def _train(self, force: bool = False) -> None:
        """Train the index on the buffered embeddings and add them to the index.

        Does nothing if training already happened, or if the buffer is not yet full and ``force`` is False.

        Args:
            force (bool): Train even if fewer than ``num_train_embeddings`` embeddings were collected.
                Defaults to False.
        """
        if self._train_embeddings is None:
            return
        if not force and self.num_embeddings < self.num_train_embeddings:
            return
        # rows that still contain NaN are dropped before training
        nan_rows = torch.isnan(self._train_embeddings).any(dim=1)
        if nan_rows.any():
            warnings.warn("Corpus contains less tokens/documents than num_train_embeddings. Removing NaN embeddings.")
            self._train_embeddings = self._train_embeddings[~nan_rows]
        self.index.train(self._train_embeddings)
        if torch.cuda.is_available():
            self.to_cpu()
        self.index.add(self._train_embeddings)
        self._train_embeddings = None
        self.set_verbosity(False)

    def save(self) -> None:
        """Save the index to disk, forcing training first if the index has not been trained yet."""
        if not self.index.is_trained:
            self._train(force=True)
        return super().save()
239
240
[docs]
class FaissIVFIndexer(_FaissTrainIndexer):
    """FAISS IVF Indexer for approximate nearest neighbor search using FAISS with Inverted File System (IVF)."""

    INDEX_FACTORY = "IVF{num_centroids},Flat"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIVFIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the FaissIVFIndexer.

        If ``index_config.num_train_embeddings`` is not set, it is set (on the passed config object) to
        ``num_centroids * 256``.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (FaissIVFIndexConfig): Configuration for the FAISS IVF index.
            module (BiEncoderModule): The BiEncoderModule to use for indexing.
            verbose (bool): Whether to enable verbose output. Defaults to False.
        """
        # default faiss values
        # https://github.com/facebookresearch/faiss/blob/dafdff110489db7587b169a0afee8470f220d295/faiss/Clustering.h#L43
        max_points_per_centroid = 256
        index_config.num_train_embeddings = (
            index_config.num_train_embeddings or index_config.num_centroids * max_points_per_centroid
        )
        super().__init__(index_dir, index_config, module, verbose)

        import faiss

        ivf_index = faiss.extract_index_ivf(self.index)
        if hasattr(ivf_index, "quantizer"):
            # downcast once; only quantizers exposing an hnsw attribute get ef_construction applied
            quantizer = faiss.downcast_index(ivf_index.quantizer)
            if hasattr(quantizer, "hnsw"):
                quantizer.hnsw.efConstruction = index_config.ef_construction

    def to_gpu(self) -> None:
        """Use the GPUs for clustering during training. Does nothing if FAISS reports no GPUs."""
        import faiss

        # clustering_index overrides the index used during clustering but leaves the quantizer on the gpu
        # https://faiss.ai/cpp_api/namespace/namespacefaiss_1_1gpu.html
        if faiss.get_num_gpus() == 0:
            return
        clustering_index = faiss.index_cpu_to_all_gpus(
            faiss.IndexFlat(self.module.config.embedding_dim, self.metric_type)
        )
        clustering_index.verbose = self.verbose
        index_ivf = faiss.extract_index_ivf(self.index)
        index_ivf.clustering_index = clustering_index
        # Training happens long after this method returned. Keep a Python reference so the clustering index is
        # not garbage-collected while the IVF index still points to it.
        self._clustering_index = clustering_index

    def to_cpu(self) -> None:
        """Move the FAISS IVF index to CPU but keep its quantizer on GPU. Does nothing if there are no GPUs."""
        import faiss

        if faiss.get_num_gpus() == 0:
            return
        self.index = faiss.index_gpu_to_cpu(self.index)

        # https://gist.github.com/mdouze/334ad6a979ac3637f6d95e9091356d3e
        # move index to cpu but leave quantizer on gpu
        index_ivf = faiss.extract_index_ivf(self.index)
        gpu_resources = faiss.StandardGpuResources()
        gpu_quantizer = faiss.index_cpu_to_gpu(gpu_resources, 0, index_ivf.quantizer)
        index_ivf.quantizer = gpu_quantizer
        # Keep Python references so that neither the GPU quantizer nor its resources are garbage-collected
        # while the IVF index still uses them.
        # NOTE(review): object ownership between the IVF index and these Python proxies was not verified on a
        # GPU build -- confirm there is no double ownership of the quantizer.
        self._gpu_resources = gpu_resources
        self._gpu_quantizer = gpu_quantizer

    def set_verbosity(self, verbose: bool | None = None) -> None:
        """Set the verbosity of the FAISS IVF index and of its quantizer.

        Args:
            verbose (bool | None): Whether to enable verbose output. If None, the indexer's own ``verbose``
                setting is used. Defaults to None.
        """
        import faiss

        verbose = verbose if verbose is not None else self.verbose
        index = faiss.extract_index_ivf(self.index)
        for elem in (index, index.quantizer):
            setattr(elem, "verbose", verbose)
320
321
[docs]
class FaissPQIndexer(_FaissTrainIndexer):
    """Indexer that builds a FAISS product quantization (PQ) index preceded by an OPQ transform."""

    INDEX_FACTORY = "OPQ{num_subquantizers},PQ{num_subquantizers}x{n_bits}"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissPQIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Set up the PQ indexer; all work is delegated to the parent constructor.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (FaissPQIndexConfig): Configuration for the FAISS PQ index.
            module (BiEncoderModule): The BiEncoderModule to use for indexing.
            verbose (bool): Whether to enable verbose output. Defaults to False.
        """
        super().__init__(index_dir, index_config, module, verbose)
        # annotation only: narrows the attribute type for static type checkers
        self.index_config: FaissPQIndexConfig

    def to_gpu(self) -> None:
        """Move the FAISS PQ index to GPU. Intentionally a no-op."""

    def to_cpu(self) -> None:
        """Move the FAISS PQ index to CPU. Intentionally a no-op."""
352
353
[docs]
class FaissIVFPQIndexer(FaissIVFIndexer):
    """Indexer combining an inverted file (IVF) with product quantization (PQ) for approximate nearest neighbor
    search using FAISS."""

    INDEX_FACTORY = "OPQ{num_subquantizers},IVF{num_centroids}_HNSW32,PQ{num_subquantizers}x{n_bits}"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIVFPQIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Create the IVFPQ index and build the direct map of its IVF component.

        Args:
            index_dir (Path): Directory where the index will be stored.
            index_config (FaissIVFPQIndexConfig): Configuration for the FAISS IVFPQ index.
            module (BiEncoderModule): The BiEncoderModule to use for indexing.
            verbose (bool): Whether to enable verbose output. Defaults to False.
        """
        import faiss

        super().__init__(index_dir, index_config, module, verbose)
        # annotation only: narrows the attribute type for static type checkers
        self.index_config: FaissIVFPQIndexConfig

        faiss.extract_index_ivf(self.index).make_direct_map()

    def set_verbosity(self, verbose: bool | None = None) -> None:
        """Set the verbosity of the IVF index, its quantizer, and the product quantizer.

        Args:
            verbose (bool | None): Whether to enable verbose output. If None, the indexer's own ``verbose``
                setting is used. Defaults to None.
        """
        super().set_verbosity(verbose)
        import faiss

        effective_verbose = self.verbose if verbose is None else verbose
        # self.index wraps the actual IVFPQ index; downcast it to reach the pq and quantizer members
        ivf_pq = faiss.downcast_index(self.index.index)
        product_quantizer, coarse_quantizer = ivf_pq.pq, ivf_pq.quantizer
        product_quantizer.verbose = effective_verbose
        coarse_quantizer.verbose = effective_verbose
399
400
[docs]
class FaissIndexConfig(IndexConfig):
    """Base configuration shared by all FAISS indexers in the Lightning IR framework."""

    # model types for which a FAISS index can be built
    SUPPORTED_MODELS = {ColConfig.model_type, DprConfig.model_type}
    # indexer class associated with this configuration
    indexer_class: Type[Indexer] = FaissIndexer
406
407
[docs]
class FaissFlatIndexConfig(FaissIndexConfig):
    """Configuration for the FAISS flat indexer (:class:`FaissFlatIndexer`)."""

    # indexer class associated with this configuration
    indexer_class = FaissFlatIndexer
412
413
class _FaissTrainIndexConfig(FaissIndexConfig):
    """Shared configuration base for FAISS indexers whose index must be trained before embeddings are added."""

    # indexer class associated with this configuration
    indexer_class = _FaissTrainIndexer

    def __init__(self, num_train_embeddings: int | None = None) -> None:
        """Store the size of the training sample.

        Args:
            num_train_embeddings (int | None): Number of embeddings to use for training the index. If None, it
                will be set later. Defaults to None.
        """
        super().__init__()
        self.num_train_embeddings = num_train_embeddings
428
429
[docs]
class FaissIVFIndexConfig(_FaissTrainIndexConfig):
    """Configuration for the FAISS IVF indexer (:class:`FaissIVFIndexer`)."""

    # indexer class associated with this configuration
    indexer_class = FaissIVFIndexer

    def __init__(
        self,
        num_train_embeddings: int | None = None,
        num_centroids: int = 262144,
        ef_construction: int = 40,
    ) -> None:
        """Store the IVF specific settings on top of the training settings.

        Args:
            num_train_embeddings (int | None): Number of embeddings to use for training the index. If None, it
                will be set later. Defaults to None.
            num_centroids (int): Number of centroids for the IVF index. Defaults to 262144.
            ef_construction (int): The size of the dynamic list used during construction. Defaults to 40.
        """
        super().__init__(num_train_embeddings=num_train_embeddings)
        self.num_centroids = num_centroids
        self.ef_construction = ef_construction
452
453
[docs]
class FaissPQIndexConfig(_FaissTrainIndexConfig):
    """Configuration for the FAISS PQ indexer (:class:`FaissPQIndexer`)."""

    # indexer class associated with this configuration
    indexer_class = FaissPQIndexer

    def __init__(
        self,
        num_train_embeddings: int | None = None,
        num_subquantizers: int = 16,
        n_bits: int = 8,
    ) -> None:
        """Store the product quantization settings on top of the training settings.

        Args:
            num_train_embeddings (int | None): Number of embeddings to use for training the index. If None, it
                will be set later. Defaults to None.
            num_subquantizers (int): Number of subquantizers for the PQ index. Defaults to 16.
            n_bits (int): Number of bits for the PQ index. Defaults to 8.
        """
        super().__init__(num_train_embeddings=num_train_embeddings)
        self.num_subquantizers = num_subquantizers
        self.n_bits = n_bits
471
472
[docs]
class FaissIVFPQIndexConfig(FaissIVFIndexConfig):
    """Configuration for the FAISS IVFPQ indexer (:class:`FaissIVFPQIndexer`)."""

    # indexer class associated with this configuration
    indexer_class = FaissIVFPQIndexer

    def __init__(
        self,
        num_train_embeddings: int | None = None,
        num_centroids: int = 262144,
        ef_construction: int = 40,
        num_subquantizers: int = 16,
        n_bits: int = 8,
    ) -> None:
        """Store the product quantization settings on top of the IVF settings.

        Args:
            num_train_embeddings (int | None): Number of embeddings to use for training the index. If None, it
                will be set later. Defaults to None.
            num_centroids (int): Number of centroids for the IVF index. Defaults to 262144.
            ef_construction (int): The size of the dynamic list used during construction. Defaults to 40.
            num_subquantizers (int): Number of subquantizers for the PQ index. Defaults to 16.
            n_bits (int): Number of bits for the PQ index. Defaults to 8.
        """
        super().__init__(
            num_train_embeddings=num_train_embeddings,
            num_centroids=num_centroids,
            ef_construction=ef_construction,
        )
        self.num_subquantizers = num_subquantizers
        self.n_bits = n_bits