Source code for lightning_ir.retrieve.pytorch.sparse_indexer

 1import array
 2from pathlib import Path
 3
 4import torch
 5
 6from ...bi_encoder import BiEncoderModule, BiEncoderOutput
 7from ...data import IndexBatch
 8from ...models import SpladeConfig
 9from ..base import IndexConfig, Indexer
10
11
class TorchSparseIndexer(Indexer):
    """Indexer that accumulates sparse document embeddings into a single CSR matrix.

    Each indexed embedding becomes one row of the matrix; only its non-zero entries are kept (as column indices
    and values). The matrix is assembled and written to ``index_dir / "index.pt"`` by :meth:`save`.
    """

    def __init__(
        self,
        index_dir: Path,
        index_config: "TorchSparseIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the indexer with empty CSR buffers.

        Args:
            index_dir (Path): Directory the index is written to.
            index_config (TorchSparseIndexConfig): Configuration of the index.
            module (BiEncoderModule): Bi-encoder module whose config provides the embedding dimension.
            verbose (bool): Passed through to the base indexer. Defaults to False.
        """
        super().__init__(index_dir, index_config, module, verbose)
        # Typecode "q" is a fixed 8-byte signed integer, matching the torch.int64 reinterpretation in save().
        # "L" has a platform-dependent size (minimum 4 bytes) and would be misread where it is not 8 bytes.
        self.crow_indices = array.array("q")
        self.crow_indices.append(0)  # CSR row pointers always start at 0
        self.col_indices = array.array("q")
        self.values = array.array("f")

    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        """Append a batch of document embeddings to the index.

        Args:
            index_batch (IndexBatch): Batch whose ``doc_ids`` identify the embedded documents.
            output (BiEncoderOutput): Model output holding the document embeddings.
        Raises:
            ValueError: If ``output`` contains no document embeddings.
        """
        doc_embeddings = output.doc_embeddings
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")

        if doc_embeddings.scoring_mask is None:
            # Without a scoring mask only the first embedding of every document is indexed.
            doc_lengths = torch.ones(
                doc_embeddings.embeddings.shape[0], device=doc_embeddings.device, dtype=torch.int32
            )
            embeddings = doc_embeddings.embeddings[:, 0]
        else:
            # Each document contributes one row per unmasked embedding.
            doc_lengths = doc_embeddings.scoring_mask.sum(dim=1)
            embeddings = doc_embeddings.embeddings[doc_embeddings.scoring_mask]
        num_docs = len(index_batch.doc_ids)
        self.doc_ids.extend(index_batch.doc_ids)

        crow_indices, col_indices, values = self.to_sparse_csr(embeddings)
        # Drop the leading 0 (already stored) and shift the row pointers so they continue after existing rows.
        crow_indices = crow_indices[1:] + self.crow_indices[-1]

        self.crow_indices.extend(crow_indices.cpu().tolist())
        self.col_indices.extend(col_indices.cpu().tolist())
        self.values.extend(values.cpu().tolist())

        self.doc_lengths.extend(doc_lengths.int().cpu().tolist())
        self.num_embeddings += embeddings.shape[0]
        self.num_docs += num_docs

    @staticmethod
    def to_sparse_csr(
        embeddings: torch.Tensor,
    ) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor]":
        """Convert a 2D embedding matrix into CSR components.

        Args:
            embeddings (torch.Tensor): Matrix with one row per embedding.
        Returns:
            tuple: ``(crow_indices, col_indices, values)``. ``crow_indices`` starts at 0 and has
            ``embeddings.shape[0] + 1`` entries, including for rows without any non-zero value.
        """
        token_idcs, dim_idcs = torch.nonzero(embeddings, as_tuple=True)
        # Counting at row + 1 keeps slot 0 empty so the cumulative sum starts at 0. ``minlength`` guarantees one
        # slot per row even when trailing rows (or all rows) are entirely zero; without it those rows would be
        # dropped and the row pointers would no longer match the number of embeddings.
        crow_indices = (token_idcs + 1).bincount(minlength=embeddings.shape[0] + 1).cumsum(0)
        values = embeddings[token_idcs, dim_idcs]
        return crow_indices, dim_idcs, values

    def to_gpu(self) -> None:
        """No-op: the index is accumulated in host-side ``array.array`` buffers."""
        pass

    def to_cpu(self) -> None:
        """No-op: the index is accumulated in host-side ``array.array`` buffers."""
        pass

    def save(self) -> None:
        """Save the base indexer state and write the accumulated CSR matrix to ``index_dir / "index.pt"``."""
        super().save()
        index = torch.sparse_csr_tensor(
            self._buffer_to_tensor(self.crow_indices, torch.int64),
            self._buffer_to_tensor(self.col_indices, torch.int64),
            self._buffer_to_tensor(self.values, torch.float32),
            torch.Size([self.num_embeddings, self.module.config.embedding_dim]),
        )
        torch.save(index, self.index_dir / "index.pt")

    @staticmethod
    def _buffer_to_tensor(buffer: array.array, dtype: torch.dtype) -> torch.Tensor:
        """Wrap an ``array.array`` as a 1D tensor of ``dtype``, or return an empty tensor for an empty buffer.

        ``torch.frombuffer`` rejects zero-length buffers, which occur when no non-zero entries were indexed.
        """
        if not buffer:
            return torch.empty(0, dtype=dtype)
        return torch.frombuffer(buffer, dtype=dtype)
78 79
class TorchSparseIndexConfig(IndexConfig):
    """Index configuration paired with :class:`TorchSparseIndexer`.

    Only models whose type is ``SpladeConfig.model_type`` are listed as supported.
    """

    # Indexer class associated with this configuration.
    indexer_class = TorchSparseIndexer
    # Model types this configuration declares support for.
    SUPPORTED_MODELS = {SpladeConfig.model_type}