Source code for lightning_ir.retrieve.pytorch.dense_indexer
import array
from pathlib import Path

import torch

from ...bi_encoder import BiEncoderModule, BiEncoderOutput
from ...data import IndexBatch
from ...models import ColConfig, DprConfig
from ..base import IndexConfig, Indexer


class TorchDenseIndexer(Indexer):
    """Dense indexer that buffers document embeddings in an in-memory float buffer and serializes them with torch."""

    def __init__(
        self,
        index_dir: Path,
        index_config: "TorchDenseIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        super().__init__(index_dir, index_config, module, verbose)
        # Flat float32 buffer to which every batch of embeddings is appended.
        self.embeddings = array.array("f")

    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        doc_embeddings = output.doc_embeddings
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")

        if doc_embeddings.scoring_mask is None:
            # Single-vector models (e.g., DPR): one embedding per document.
            doc_lengths = torch.ones(
                doc_embeddings.embeddings.shape[0], device=doc_embeddings.device, dtype=torch.int32
            )
            embeddings = doc_embeddings.embeddings[:, 0]
        else:
            # Multi-vector models (e.g., ColBERT): keep the embeddings selected by
            # the scoring mask and record the number of vectors per document.
            doc_lengths = doc_embeddings.scoring_mask.sum(dim=1)
            embeddings = doc_embeddings.embeddings[doc_embeddings.scoring_mask]
        num_docs = len(index_batch.doc_ids)
        self.doc_ids.extend(index_batch.doc_ids)
        self.doc_lengths.extend(doc_lengths.int().cpu().tolist())
        self.num_embeddings += embeddings.shape[0]
        self.num_docs += num_docs
        self.embeddings.extend(embeddings.cpu().view(-1).float().tolist())

    # The index is held in CPU memory, so device transfers are no-ops.
    def to_gpu(self) -> None:
        pass

    def to_cpu(self) -> None:
        pass

    def save(self) -> None:
        super().save()
        # Reinterpret the flat buffer as a (num_embeddings, embedding_dim) matrix.
        index = torch.frombuffer(self.embeddings, dtype=torch.float32).view(self.num_embeddings, -1)
        torch.save(index, self.index_dir / "index.pt")

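The file written by save() holds a single float32 tensor of shape (num_embeddings, embedding_dim). A minimal sketch of reading it back for inspection, assuming an index was previously saved under ./index (the path is illustrative):

import torch

index = torch.load("index/index.pt")  # float32 tensor, (num_embeddings, embedding_dim)
print(index.shape)
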
class TorchDenseIndexConfig(IndexConfig):
    """Configuration for TorchDenseIndexer; supports ColBERT- and DPR-style bi-encoders."""

    indexer_class = TorchDenseIndexer
    SUPPORTED_MODELS = {ColConfig.model_type, DprConfig.model_type}
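
A minimal sketch of driving this indexer end to end. The checkpoint name and index_dataloader are illustrative assumptions, as is default construction of TorchDenseIndexConfig; only TorchDenseIndexer and its add/save calls are taken from this module:

from pathlib import Path

from lightning_ir.bi_encoder import BiEncoderModule
from lightning_ir.retrieve.pytorch.dense_indexer import (
    TorchDenseIndexConfig,
    TorchDenseIndexer,
)

module = BiEncoderModule(model_name_or_path="webis/bert-bi-encoder")  # hypothetical checkpoint name
indexer = TorchDenseIndexer(Path("index"), TorchDenseIndexConfig(), module, verbose=True)

for index_batch in index_dataloader:  # assumed iterable yielding IndexBatch objects
    output = module.forward(index_batch)  # assumed to return a BiEncoderOutput with doc_embeddings
    indexer.add(index_batch, output)

indexer.save()  # writes doc ids and lengths (base class) plus index.pt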