Source code for lightning_ir.retrieve.pytorch.sparse_indexer
1import array
2from pathlib import Path
3
4import torch
5
6from ...bi_encoder import BiEncoderModule, BiEncoderOutput
7from ...data import IndexBatch
8from ...models import SpladeConfig
9from ..base import IndexConfig, Indexer
10
11
[docs]
class TorchSparseIndexer(Indexer):
    """Indexer that accumulates document embeddings as one sparse CSR matrix.

    Every call to :meth:`add` appends the non-zero entries of a batch of
    embeddings to three growing buffers (compressed row pointers, column
    indices and values). :meth:`save` assembles the buffers into a
    ``torch.sparse_csr_tensor`` and writes it to ``index_dir / "index.pt"``.
    """

    def __init__(
        self,
        index_dir: Path,
        index_config: "TorchSparseIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Create an empty sparse index.

        All arguments are forwarded unchanged to the base ``Indexer``.

        Args:
            index_dir: Forwarded to the base class; :meth:`save` writes
                ``index.pt`` below ``self.index_dir``.
            index_config: Configuration of the index.
            module: Bi-encoder module; :meth:`save` reads
                ``module.config.embedding_dim`` for the number of columns.
            verbose: Forwarded to the base class.
        """
        super().__init__(index_dir, index_config, module, verbose)
        # Type code "q" is a signed 64-bit integer on every platform, which is
        # what `save` assumes when it reinterprets the buffers as torch.int64.
        # ("L" is only 32 bits wide on some platforms, e.g. Windows.)
        self.crow_indices = array.array("q", [0])  # CSR row pointers always start at 0
        self.col_indices = array.array("q")
        self.values = array.array("f")

    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        """Append the document embeddings of one batch to the index buffers.

        Args:
            index_batch: Batch whose ``doc_ids`` are recorded.
            output: Model output; its ``doc_embeddings`` are indexed.

        Raises:
            ValueError: If ``output.doc_embeddings`` is ``None``.
        """
        doc_embeddings = output.doc_embeddings
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")

        if doc_embeddings.scoring_mask is None:
            # No mask: use the vector at position 0 of every document and
            # count one embedding per document.
            doc_lengths = torch.ones(
                doc_embeddings.embeddings.shape[0], device=doc_embeddings.device, dtype=torch.int32
            )
            embeddings = doc_embeddings.embeddings[:, 0]
        else:
            # Masked: keep only the selected vectors, flattened over documents.
            doc_lengths = doc_embeddings.scoring_mask.sum(dim=1)
            embeddings = doc_embeddings.embeddings[doc_embeddings.scoring_mask]
        num_docs = len(index_batch.doc_ids)
        self.doc_ids.extend(index_batch.doc_ids)

        crow_indices, col_indices, values = self.to_sparse_csr(embeddings)
        # Drop the batch-local leading 0 and shift the remaining row pointers
        # by the number of values already stored, so they continue the global
        # row-pointer sequence.
        crow_indices = crow_indices[1:] + self.crow_indices[-1]

        self.crow_indices.extend(crow_indices.cpu().tolist())
        self.col_indices.extend(col_indices.cpu().tolist())
        self.values.extend(values.cpu().tolist())

        self.doc_lengths.extend(doc_lengths.int().cpu().tolist())
        self.num_embeddings += embeddings.shape[0]
        self.num_docs += num_docs

    @staticmethod
    def to_sparse_csr(
        embeddings: torch.Tensor,
    ) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor]":
        """Convert a dense 2-D tensor into CSR components.

        Args:
            embeddings: Dense matrix with one embedding per row.

        Returns:
            A tuple ``(crow_indices, col_indices, values)``. ``crow_indices``
            always has ``embeddings.shape[0] + 1`` entries and starts with 0.
        """
        token_idcs, dim_idcs = torch.nonzero(embeddings, as_tuple=True)
        # Counting `row + 1` leaves bin 0 empty, so the cumulative sum is the
        # row-pointer vector. `minlength` keeps one entry per row even when
        # trailing rows (or the whole batch) contain no non-zero value;
        # without it the pointers of those rows would be missing.
        num_rows = embeddings.shape[0]
        crow_indices = (token_idcs + 1).bincount(minlength=num_rows + 1).cumsum(0)
        values = embeddings[token_idcs, dim_idcs]
        return crow_indices, dim_idcs, values

    def to_gpu(self) -> None:
        """Do nothing; this indexer moves no data between devices."""
        pass

    def to_cpu(self) -> None:
        """Do nothing; this indexer moves no data between devices."""
        pass

    @staticmethod
    def _buffer_to_tensor(buffer: array.array, dtype: torch.dtype) -> torch.Tensor:
        """View ``buffer`` as a 1-D tensor of ``dtype`` without copying."""
        # torch.frombuffer rejects zero-length buffers, so an empty buffer is
        # mapped to an explicit empty tensor instead.
        if len(buffer) == 0:
            return torch.empty(0, dtype=dtype)
        return torch.frombuffer(buffer, dtype=dtype)

    def save(self) -> None:
        """Run the base-class save, then write the CSR index to ``index.pt``."""
        super().save()
        index = torch.sparse_csr_tensor(
            self._buffer_to_tensor(self.crow_indices, torch.int64),
            self._buffer_to_tensor(self.col_indices, torch.int64),
            self._buffer_to_tensor(self.values, torch.float32),
            torch.Size([self.num_embeddings, self.module.config.embedding_dim]),
        )
        torch.save(index, self.index_dir / "index.pt")
78
79
[docs]
class TorchSparseIndexConfig(IndexConfig):
    """Index configuration that selects :class:`TorchSparseIndexer`."""

    # Indexer implementation associated with this configuration.
    indexer_class = TorchSparseIndexer
    # Model types this index accepts: only the SPLADE model type.
    SUPPORTED_MODELS = {SpladeConfig.model_type}