Source code for lightning_ir.retrieve.pytorch.sparse_indexer

  1"""Torch-based Sparse Indexer for Lightning IR Framework"""
  2
  3import array
  4from pathlib import Path
  5
  6import torch
  7
  8from ...bi_encoder import BiEncoderModule, BiEncoderOutput
  9from ...data import IndexBatch
 10from ...models import SpladeConfig
 11from ..base import IndexConfig, Indexer
 12
 13
class TorchSparseIndexer(Indexer):
    """Sparse indexer for bi-encoder models using PyTorch.

    Embeddings are accumulated batch by batch as the three components of a CSR (compressed sparse row) matrix and
    written to ``index.pt`` as a single ``torch.sparse_csr_tensor`` when :meth:`save` is called.
    """

    def __init__(
        self,
        index_dir: Path,
        index_config: "TorchSparseIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the TorchSparseIndexer.

        Args:
            index_dir (Path): Directory to store the index.
            index_config (TorchSparseIndexConfig): Configuration for the sparse index.
            module (BiEncoderModule): The bi-encoder module to use for indexing.
            verbose (bool): Whether to print verbose output. Defaults to False.
        """
        super().__init__(index_dir, index_config, module, verbose)
        # Typecode "q" is always a signed 64-bit integer, so the buffers can be reinterpreted as torch.int64 in
        # `save` on every platform. ("L" is only 4 bytes wide where C `unsigned long` is 32-bit.)
        # CSR row pointers always start with 0.
        self.crow_indices = array.array("q", [0])
        self.col_indices = array.array("q")
        self.values = array.array("f")

    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        """Add embeddings to the sparse index.

        Args:
            index_batch (IndexBatch): The batch containing the embeddings to index.
            output (BiEncoderOutput): The output from the bi-encoder model containing embeddings.
        Raises:
            ValueError: If doc_embeddings are not present in the output.
        """
        doc_embeddings = output.doc_embeddings
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")

        if doc_embeddings.scoring_mask is None:
            # Without a mask every document contributes exactly one vector (the one at position 0).
            doc_lengths = torch.ones(
                doc_embeddings.embeddings.shape[0], device=doc_embeddings.device, dtype=torch.int32
            )
            embeddings = doc_embeddings.embeddings[:, 0]
        else:
            # With a mask, keep only the selected vectors; a document's length is its number of selected vectors.
            doc_lengths = doc_embeddings.scoring_mask.sum(dim=1)
            embeddings = doc_embeddings.embeddings[doc_embeddings.scoring_mask]
        num_docs = len(index_batch.doc_ids)
        # NOTE(review): doc_ids, doc_lengths, num_embeddings and num_docs are not created in this class;
        # presumably they are initialised by the Indexer base class -- confirm.
        self.doc_ids.extend(index_batch.doc_ids)

        crow_indices, col_indices, values = self.to_sparse_csr(embeddings)
        # Drop the leading 0 of this batch's row pointers and shift the rest by the number of non-zero entries
        # already stored, so they continue the running row-pointer array.
        crow_indices = crow_indices[1:] + self.crow_indices[-1]

        self.crow_indices.extend(crow_indices.cpu().tolist())
        self.col_indices.extend(col_indices.cpu().tolist())
        self.values.extend(values.cpu().tolist())

        self.doc_lengths.extend(doc_lengths.int().cpu().tolist())
        self.num_embeddings += embeddings.shape[0]
        self.num_docs += num_docs

    @staticmethod
    def to_sparse_csr(
        embeddings: torch.Tensor,
    ) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor]":
        """Convert embeddings to sparse CSR format.

        Args:
            embeddings (torch.Tensor): The two-dimensional (rows x dimensions) embeddings tensor to convert.
        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Crow indices (always ``embeddings.shape[0] + 1``
            entries, starting at 0), column indices, and values of the sparse matrix.
        """
        token_idcs, dim_idcs = torch.nonzero(embeddings, as_tuple=True)
        # Counting `row + 1` leaves slot 0 empty, so the cumulative sum starts at 0. `minlength` guarantees one
        # row pointer per row even when trailing rows (or all rows) contain no non-zero entry; without it the
        # result would be too short for such inputs.
        counts = torch.bincount(token_idcs + 1, minlength=embeddings.shape[0] + 1)
        crow_indices = counts.cumsum(0)
        values = embeddings[token_idcs, dim_idcs]
        return crow_indices, dim_idcs, values

    def to_gpu(self) -> None:
        """Move the index to GPU if available.

        This implementation is a no-op.
        """
        pass

    def to_cpu(self) -> None:
        """Move the index to CPU.

        This implementation is a no-op.
        """
        pass

    def save(self) -> None:
        """Save the sparse index to disk.

        Calls the base-class ``save`` first, then writes the accumulated CSR matrix to ``index.pt`` inside
        ``index_dir``.
        """
        super().save()
        # `torch.frombuffer` reinterprets the array buffers without copying; the requested dtypes match the
        # array typecodes ("q" -> int64, "f" -> float32).
        index = torch.sparse_csr_tensor(
            torch.frombuffer(self.crow_indices, dtype=torch.int64),
            torch.frombuffer(self.col_indices, dtype=torch.int64),
            torch.frombuffer(self.values, dtype=torch.float32),
            torch.Size([self.num_embeddings, self.module.config.embedding_dim]),
        )
        torch.save(index, self.index_dir / "index.pt")
109 110
class TorchSparseIndexConfig(IndexConfig):
    """Index configuration that selects the Torch-based sparse indexer."""

    # Indexer implementation associated with this configuration.
    indexer_class = TorchSparseIndexer
    # Model types this index configuration declares support for.
    SUPPORTED_MODELS = {SpladeConfig.model_type}