# Source code for lightning_ir.retrieve.pytorch.dense_indexer
import array
from pathlib import Path

import torch

from ...bi_encoder import BiEncoderModule, BiEncoderOutput
from ...data import IndexBatch
from ...models import ColConfig, DprConfig
from ..base import IndexConfig, Indexer

class TorchDenseIndexer(Indexer):
    """Indexer that accumulates dense document embeddings in memory and saves
    them to disk as a single PyTorch tensor (``index.pt``)."""

    def __init__(
        self,
        index_dir: Path,
        index_config: "TorchDenseIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the TorchDenseIndexer.

        Args:
            index_dir (Path): Directory to store the index.
            index_config (TorchDenseIndexConfig): Configuration for the dense index.
            module (BiEncoderModule): Bi-encoder module to use for indexing.
            verbose (bool): Whether to print verbose output. Defaults to False.
        """
        super().__init__(index_dir, index_config, module, verbose)
        # Flat float32 buffer of all embedding values; reshaped to
        # (num_embeddings, embedding_dim) when the index is saved.
        self.embeddings = array.array("f")

    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        """Add embeddings from the output to the index.

        Args:
            index_batch (IndexBatch): Batch containing the index data.
            output (BiEncoderOutput): Output from the Bi-encoder model containing embeddings.

        Raises:
            ValueError: If output does not contain document embeddings.
        """
        doc_embeddings = output.doc_embeddings
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")

        if doc_embeddings.scoring_mask is None:
            # No scoring mask: treat each document as a single vector
            # (the first embedding along dim 1).
            doc_lengths = torch.ones(
                doc_embeddings.embeddings.shape[0], device=doc_embeddings.device, dtype=torch.int32
            )
            embeddings = doc_embeddings.embeddings[:, 0]
        else:
            # Scoring mask present: keep only the masked-in embeddings and
            # record how many vectors each document contributes.
            doc_lengths = doc_embeddings.scoring_mask.sum(dim=1)
            embeddings = doc_embeddings.embeddings[doc_embeddings.scoring_mask]
        num_docs = len(index_batch.doc_ids)
        self.doc_ids.extend(index_batch.doc_ids)
        self.doc_lengths.extend(doc_lengths.int().cpu().tolist())
        self.num_embeddings += embeddings.shape[0]
        self.num_docs += num_docs
        # Use reshape(-1) rather than view(-1): the ``[:, 0]`` slice taken in
        # the single-vector branch is non-contiguous, and Tensor.view raises a
        # RuntimeError on non-contiguous tensors; reshape copies when needed.
        self.embeddings.extend(embeddings.cpu().reshape(-1).float().tolist())

    def to_gpu(self) -> None:
        """Convert the index to GPU format (no-op for this indexer)."""
        pass

    def to_cpu(self) -> None:
        """Convert the index to CPU format (no-op for this indexer)."""
        pass

    def save(self) -> None:
        """Save the index to the specified directory as ``index.pt``."""
        super().save()
        if self.num_embeddings == 0:
            # torch.frombuffer + view(0, -1) cannot handle an empty buffer;
            # write an explicitly empty index instead of crashing.
            index = torch.empty(0, 0, dtype=torch.float32)
        else:
            index = torch.frombuffer(self.embeddings, dtype=torch.float32).view(self.num_embeddings, -1)
        torch.save(index, self.index_dir / "index.pt")

75
class TorchDenseIndexConfig(IndexConfig):
    """Configuration for the :class:`TorchDenseIndexer`."""

    # Indexer implementation that this configuration instantiates.
    indexer_class = TorchDenseIndexer
    # Model types whose embeddings this index can store.
    SUPPORTED_MODELS = {config.model_type for config in (ColConfig, DprConfig)}