Source code for lightning_ir.retrieve.pytorch.sparse_indexer
1"""Torch-based Sparse Indexer for Lightning IR Framework"""
2
import array
from pathlib import Path
from typing import Tuple

import torch

from ...bi_encoder import BiEncoderModule, BiEncoderOutput
from ...data import IndexBatch
from ...models import SpladeConfig
from ..base import IndexConfig, Indexer
12
13
[docs]
class TorchSparseIndexer(Indexer):
    """Sparse indexer for bi-encoder models using PyTorch.

    Embeddings are accumulated in compressed sparse row (CSR) layout: one row per indexed embedding and one column
    per embedding dimension. The three CSR components (row pointers, column indices, values) are kept in
    ``array.array`` buffers while indexing and written as a single ``torch.sparse_csr_tensor`` by :meth:`save`.
    """

    def __init__(
        self,
        index_dir: Path,
        index_config: "TorchSparseIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the TorchSparseIndexer.

        Args:
            index_dir (Path): Directory to store the index.
            index_config (TorchSparseIndexConfig): Configuration for the sparse index.
            module (BiEncoderModule): The bi-encoder module to use for indexing.
            verbose (bool): Whether to print verbose output. Defaults to False.
        """
        super().__init__(index_dir, index_config, module, verbose)
        # "q" (signed long long) is 64-bit on all mainstream platforms and so matches the torch.int64 used in
        # save(). "L" (unsigned long) would be only 4 bytes on Windows.
        self.crow_indices = array.array("q")
        self.crow_indices.append(0)  # CSR row pointers always start at 0
        self.col_indices = array.array("q")
        self.values = array.array("f")

    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        """Add embeddings to the sparse index.

        Each selected embedding becomes one CSR row; only its non-zero dimensions are stored.

        Args:
            index_batch (IndexBatch): The batch containing the embeddings to index.
            output (BiEncoderOutput): The output from the bi-encoder model containing embeddings.
        Raises:
            ValueError: If doc_embeddings are not present in the output.
        """
        doc_embeddings = output.doc_embeddings
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")

        if doc_embeddings.scoring_mask is None:
            # Without a scoring mask each document contributes exactly one embedding: its first position.
            doc_lengths = torch.ones(
                doc_embeddings.embeddings.shape[0], device=doc_embeddings.device, dtype=torch.int32
            )
            embeddings = doc_embeddings.embeddings[:, 0]
        else:
            # Keep only masked positions; the per-document count of kept positions is the document length.
            doc_lengths = doc_embeddings.scoring_mask.sum(dim=1)
            embeddings = doc_embeddings.embeddings[doc_embeddings.scoring_mask]
        num_docs = len(index_batch.doc_ids)
        self.doc_ids.extend(index_batch.doc_ids)

        crow_indices, col_indices, values = self.to_sparse_csr(embeddings)
        # Drop the batch-local leading 0 and shift the row pointers so they continue from the number of values
        # already stored in the global index.
        crow_indices = crow_indices[1:] + self.crow_indices[-1]

        self.crow_indices.extend(crow_indices.cpu().tolist())
        self.col_indices.extend(col_indices.cpu().tolist())
        self.values.extend(values.cpu().tolist())

        self.doc_lengths.extend(doc_lengths.int().cpu().tolist())
        self.num_embeddings += embeddings.shape[0]
        self.num_docs += num_docs

    @staticmethod
    def to_sparse_csr(
        embeddings: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Convert embeddings to sparse CSR format.

        Args:
            embeddings (torch.Tensor): The 2-D embeddings tensor to convert (rows x dimensions).
        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Crow indices (length ``rows + 1``, starting at 0),
            column indices, and values of the non-zero entries of the matrix.
        """
        token_idcs, dim_idcs = torch.nonzero(embeddings, as_tuple=True)
        # minlength guarantees exactly rows + 1 row pointers. Without it, bincount stops at the last row holding
        # a non-zero value, which drops trailing all-zero rows and returns nothing for an all-zero/empty input.
        crow_indices = (token_idcs + 1).bincount(minlength=embeddings.shape[0] + 1).cumsum(0)
        values = embeddings[token_idcs, dim_idcs]
        return crow_indices, dim_idcs, values

    def to_gpu(self) -> None:
        """Move the index to GPU if available.

        No-op: the index is accumulated in host-side ``array.array`` buffers.
        """
        pass

    def to_cpu(self) -> None:
        """Move the index to CPU.

        No-op: the index is accumulated in host-side ``array.array`` buffers.
        """
        pass

    @staticmethod
    def _buffer_to_tensor(buffer: array.array, dtype: torch.dtype) -> torch.Tensor:
        """Convert an ``array.array`` buffer into a 1-D tensor of ``dtype``.

        ``torch.frombuffer`` is zero-copy but raises on empty buffers and blindly reinterprets raw bytes. It is
        therefore only used when the buffer is non-empty and its item size matches ``dtype``; otherwise the values
        are copied.
        """
        if len(buffer) > 0 and buffer.itemsize == torch.empty((), dtype=dtype).element_size():
            return torch.frombuffer(buffer, dtype=dtype)
        return torch.tensor(buffer.tolist(), dtype=dtype)

    def save(self) -> None:
        """Save the sparse index to disk.

        Calls the base class ``save`` and then writes the accumulated CSR buffers as one ``torch.sparse_csr_tensor``
        of shape ``(num_embeddings, module.config.embedding_dim)`` to ``index_dir / "index.pt"``.
        """
        super().save()
        index = torch.sparse_csr_tensor(
            self._buffer_to_tensor(self.crow_indices, torch.int64),
            self._buffer_to_tensor(self.col_indices, torch.int64),
            self._buffer_to_tensor(self.values, torch.float32),
            torch.Size([self.num_embeddings, self.module.config.embedding_dim]),
        )
        torch.save(index, self.index_dir / "index.pt")
110
[docs]
class TorchSparseIndexConfig(IndexConfig):
    """Configuration for the Torch-based sparse indexer.

    Associates :class:`TorchSparseIndexer` with this config and declares SPLADE as the only supported model type.
    """

    # Indexer implementation paired with this config. Presumably instantiated by the IndexConfig/indexing
    # machinery, which is not visible here; confirm against the base class.
    indexer_class = TorchSparseIndexer
    # Model types this config accepts; currently only SpladeConfig.model_type. NOTE(review): presumably checked
    # against the model's config by the base class; verify.
    SUPPORTED_MODELS = {SpladeConfig.model_type}