import warnings
from pathlib import Path
from typing import Type

import torch

from ...bi_encoder import BiEncoderModule, BiEncoderOutput
from ...data import IndexBatch
from ...models import ColConfig, DprConfig
from ..base import IndexConfig, Indexer

class FaissIndexer(Indexer):
    INDEX_FACTORY: str

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        super().__init__(index_dir, index_config, module, verbose)
        import faiss

        similarity_function = self.module.config.similarity_function
        if similarity_function in ("cosine", "dot"):
            self.metric_type = faiss.METRIC_INNER_PRODUCT
        else:
            raise ValueError(f"similarity_function {similarity_function} unknown")

        index_factory = self.INDEX_FACTORY.format(**index_config.to_dict())
        if similarity_function == "cosine":
            index_factory = "L2norm," + index_factory
        self.index = faiss.index_factory(self.module.config.embedding_dim, index_factory, self.metric_type)

        self.set_verbosity()

        if torch.cuda.is_available():
            self.to_gpu()

    def to_gpu(self) -> None:
        pass

    def to_cpu(self) -> None:
        pass

    def set_verbosity(self, verbose: bool | None = None) -> None:
        self.index.verbose = self.verbose if verbose is None else verbose

    def process_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        return embeddings

    def save(self) -> None:
        super().save()
        import faiss

        if self.num_embeddings != self.index.ntotal:
            raise ValueError("number of embeddings does not match index.ntotal")
        if torch.cuda.is_available() and hasattr(faiss, "index_gpu_to_cpu"):
            self.index = faiss.index_gpu_to_cpu(self.index)

        faiss.write_index(self.index, str(self.index_dir / "index.faiss"))

    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        doc_embeddings = output.doc_embeddings
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")
        if doc_embeddings.scoring_mask is None:
            doc_lengths = torch.ones(
                doc_embeddings.embeddings.shape[0], device=doc_embeddings.device, dtype=torch.int32
            )
            embeddings = doc_embeddings.embeddings[:, 0]
        else:
            doc_lengths = doc_embeddings.scoring_mask.sum(dim=1)
            embeddings = doc_embeddings.embeddings[doc_embeddings.scoring_mask]
        doc_ids = index_batch.doc_ids
        embeddings = self.process_embeddings(embeddings)

        if embeddings.shape[0]:
            self.index.add(embeddings.float().cpu())

        self.num_embeddings += embeddings.shape[0]
        self.num_docs += len(doc_ids)

        self.doc_lengths.extend(doc_lengths.int().cpu().tolist())
        self.doc_ids.extend(doc_ids)

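# A hedged sketch of how the factory string above is assembled (the values are
# illustrative, not defaults from this module): FaissIVFIndexer.INDEX_FACTORY
# is "IVF{num_centroids},Flat", so a config with num_centroids=1024 formats to
# "IVF1024,Flat"; with a cosine similarity function the string is prefixed to
# "L2norm,IVF1024,Flat", i.e., embeddings are L2-normalized so the
# inner-product metric is equivalent to cosine similarity:
#
#     index_factory = "IVF{num_centroids},Flat".format(num_centroids=1024)
#     index_factory = "L2norm," + index_factory  # -> "L2norm,IVF1024,Flat"
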
class FaissFlatIndexer(FaissIndexer):
    INDEX_FACTORY = "Flat"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissFlatIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        super().__init__(index_dir, index_config, module, verbose)
        self.index_config: FaissFlatIndexConfig

    def to_gpu(self) -> None:
        pass

    def to_cpu(self) -> None:
        pass

class _FaissTrainIndexer(FaissIndexer):

    INDEX_FACTORY = ""  # class only acts as a mixin

    def __init__(
        self,
        index_dir: Path,
        index_config: "_FaissTrainIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        super().__init__(index_dir, index_config, module, verbose)
        if index_config.num_train_embeddings is None:
            raise ValueError("num_train_embeddings must be set")
        self.num_train_embeddings = index_config.num_train_embeddings

        self._train_embeddings: torch.Tensor | None = torch.full(
            (self.num_train_embeddings, self.module.config.embedding_dim), torch.nan, dtype=torch.float32
        )

    def process_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        embeddings = self._grab_train_embeddings(embeddings)
        self._train()
        return embeddings

    def _grab_train_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        if self._train_embeddings is not None:
            # buffer training embeddings until num_train_embeddings is reached;
            # if a batch overflows the buffer, return the leftover embeddings
            start = self.num_embeddings
            end = min(self.num_train_embeddings, start + embeddings.shape[0])
            length = end - start
            self._train_embeddings[start:end] = embeddings[:length]
            self.num_embeddings += length
            embeddings = embeddings[length:]
        return embeddings

    def _train(self, force: bool = False):
        if self._train_embeddings is None:
            return
        if not force and self.num_embeddings < self.num_train_embeddings:
            return
        if torch.isnan(self._train_embeddings).any():
            warnings.warn(
                "Corpus contains fewer tokens/documents than num_train_embeddings. Removing NaN embeddings."
            )
            self._train_embeddings = self._train_embeddings[~torch.isnan(self._train_embeddings).any(dim=1)]
        self.index.train(self._train_embeddings)
        if torch.cuda.is_available():
            self.to_cpu()
        self.index.add(self._train_embeddings)
        self._train_embeddings = None
        self.set_verbosity(False)

    def save(self) -> None:
        if not self.index.is_trained:
            self._train(force=True)
        return super().save()

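# A hedged sketch of the training buffer implemented above: embeddings are
# copied into self._train_embeddings until num_train_embeddings rows are
# filled; the batch that fills the buffer triggers _train(), which trains the
# index on the buffered rows, adds them, and drops the buffer. For example,
# with num_train_embeddings=1000 and batches of 300 embeddings, batches 1-3
# are buffered whole; batch 4 contributes its first 100 rows to the buffer
# (triggering training on, and insertion of, all 1000 buffered rows), and its
# remaining 200 rows are returned by process_embeddings() and added to the
# now-trained index by FaissIndexer.add().
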
class FaissIVFIndexer(_FaissTrainIndexer):
    INDEX_FACTORY = "IVF{num_centroids},Flat"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIVFIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        # default faiss value, see
        # https://github.com/facebookresearch/faiss/blob/dafdff110489db7587b169a0afee8470f220d295/faiss/Clustering.h#L43
        max_points_per_centroid = 256
        index_config.num_train_embeddings = (
            index_config.num_train_embeddings or index_config.num_centroids * max_points_per_centroid
        )
        super().__init__(index_dir, index_config, module, verbose)

        import faiss

        ivf_index = faiss.extract_index_ivf(self.index)
        if hasattr(ivf_index, "quantizer"):
            quantizer = ivf_index.quantizer
            if hasattr(faiss.downcast_index(quantizer), "hnsw"):
                downcasted_quantizer = faiss.downcast_index(quantizer)
                downcasted_quantizer.hnsw.efConstruction = index_config.ef_construction

    def to_gpu(self) -> None:
        import faiss

        # clustering_index overrides the index used during clustering but leaves the quantizer on the gpu
        # https://faiss.ai/cpp_api/namespace/namespacefaiss_1_1gpu.html
        clustering_index = faiss.index_cpu_to_all_gpus(
            faiss.IndexFlat(self.module.config.embedding_dim, self.metric_type)
        )
        clustering_index.verbose = self.verbose
        index_ivf = faiss.extract_index_ivf(self.index)
        index_ivf.clustering_index = clustering_index

    def to_cpu(self) -> None:
        import faiss

        if torch.cuda.is_available() and hasattr(faiss, "index_gpu_to_cpu") and hasattr(faiss, "index_cpu_to_gpu"):
            self.index = faiss.index_gpu_to_cpu(self.index)

            # https://gist.github.com/mdouze/334ad6a979ac3637f6d95e9091356d3e
            # move the index to the cpu but leave the quantizer on the gpu
            index_ivf = faiss.extract_index_ivf(self.index)
            quantizer = index_ivf.quantizer
            gpu_quantizer = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, quantizer)
            index_ivf.quantizer = gpu_quantizer

    def set_verbosity(self, verbose: bool | None = None) -> None:
        import faiss

        verbose = verbose if verbose is not None else self.verbose
        index = faiss.extract_index_ivf(self.index)
        for elem in (index, index.quantizer):
            setattr(elem, "verbose", verbose)

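# Hedged note on the GPU handling above: k-means training of the coarse
# quantizer dominates IVF build time, so to_gpu() does not move the whole
# index; it only swaps in a GPU-backed clustering_index used during training.
# Afterwards, to_cpu() moves the index to CPU memory but (per the linked gist)
# keeps the coarse quantizer on the GPU, so cell assignment during add() stays
# fast while the inverted lists reside in CPU memory.
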
class FaissPQIndexer(_FaissTrainIndexer):

    INDEX_FACTORY = "OPQ{num_subquantizers},PQ{num_subquantizers}x{n_bits}"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissPQIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        super().__init__(index_dir, index_config, module, verbose)
        self.index_config: FaissPQIndexConfig

    def to_gpu(self) -> None:
        pass

    def to_cpu(self) -> None:
        pass

class FaissIVFPQIndexer(FaissIVFIndexer):
    INDEX_FACTORY = "OPQ{num_subquantizers},IVF{num_centroids}_HNSW32,PQ{num_subquantizers}x{n_bits}"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIVFPQIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        import faiss

        super().__init__(index_dir, index_config, module, verbose)
        self.index_config: FaissIVFPQIndexConfig

        index_ivf = faiss.extract_index_ivf(self.index)
        index_ivf.make_direct_map()

    def set_verbosity(self, verbose: bool | None = None) -> None:
        super().set_verbosity(verbose)
        import faiss

        verbose = verbose if verbose is not None else self.verbose
        index_ivf_pq = faiss.downcast_index(self.index.index)
        for elem in (
            index_ivf_pq.pq,
            index_ivf_pq.quantizer,
        ):
            setattr(elem, "verbose", verbose)

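# Hedged note: make_direct_map() above builds an id-to-vector direct map on
# the IVF index, which is what allows faiss reconstruct() calls to recover
# stored (quantized) embeddings by their index position; a searcher reading
# this index may rely on that to fetch individual document vectors.
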
class FaissIndexConfig(IndexConfig):
    SUPPORTED_MODELS = {ColConfig.model_type, DprConfig.model_type}
    indexer_class: Type[Indexer] = FaissIndexer

class FaissFlatIndexConfig(FaissIndexConfig):
    indexer_class = FaissFlatIndexer

class _FaissTrainIndexConfig(FaissIndexConfig):

    indexer_class = _FaissTrainIndexer

    def __init__(self, num_train_embeddings: int | None = None) -> None:
        super().__init__()
        self.num_train_embeddings = num_train_embeddings

class FaissIVFIndexConfig(_FaissTrainIndexConfig):
    indexer_class = FaissIVFIndexer

    def __init__(
        self,
        num_train_embeddings: int | None = None,
        num_centroids: int = 262144,
        ef_construction: int = 40,
    ) -> None:
        super().__init__(num_train_embeddings)
        self.num_centroids = num_centroids
        self.ef_construction = ef_construction

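# Hedged example of the num_train_embeddings default: when it is left unset,
# FaissIVFIndexer derives it from faiss' default of 256 points per centroid:
#
#     config = FaissIVFIndexConfig(num_centroids=1024)
#     # FaissIVFIndexer will set config.num_train_embeddings = 1024 * 256
#
# Pass num_train_embeddings explicitly to train on more or fewer embeddings.
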
class FaissPQIndexConfig(_FaissTrainIndexConfig):
    indexer_class = FaissPQIndexer

    def __init__(self, num_train_embeddings: int | None = None, num_subquantizers: int = 16, n_bits: int = 8) -> None:
        super().__init__(num_train_embeddings)
        self.num_subquantizers = num_subquantizers
        self.n_bits = n_bits

class FaissIVFPQIndexConfig(FaissIVFIndexConfig):
    indexer_class = FaissIVFPQIndexer

    def __init__(
        self,
        num_train_embeddings: int | None = None,
        num_centroids: int = 262144,
        ef_construction: int = 40,
        num_subquantizers: int = 16,
        n_bits: int = 8,
    ) -> None:
        super().__init__(num_train_embeddings, num_centroids, ef_construction)
        self.num_subquantizers = num_subquantizers
        self.n_bits = n_bits
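
# A hedged end-to-end sketch (the checkpoint name is hypothetical; everything
# else is defined above or in the imported modules):
#
#     from pathlib import Path
#
#     module = BiEncoderModule("bert-base-uncased")  # hypothetical checkpoint
#     config = FaissIVFPQIndexConfig(num_centroids=1024, num_subquantizers=16)
#     indexer = FaissIVFPQIndexer(Path("index-dir"), config, module)
#     # for each batch of encoded documents:
#     #     indexer.add(index_batch, output)
#     indexer.save()  # writes index-dir/index.faiss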