Source code for lightning_ir.retrieve.pytorch.dense_indexer

 1import array
 2from pathlib import Path
 3
 4import torch
 5
 6from ...bi_encoder import BiEncoderModule, BiEncoderOutput
 7from ...data import IndexBatch
 8from ...models import ColConfig, DprConfig
 9from ..base import IndexConfig, Indexer
10
11
class TorchDenseIndexer(Indexer):
    """Indexer for dense embeddings using PyTorch.

    Embeddings are accumulated in a flat ``array.array`` of C floats (compact,
    unboxed storage) and materialized as a 2-D float32 tensor only at save time.
    """

    def __init__(
        self,
        index_dir: Path,
        index_config: "TorchDenseIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        """Initialize the TorchDenseIndexer.

        Args:
            index_dir (Path): Directory to store the index.
            index_config (TorchDenseIndexConfig): Configuration for the dense index.
            module (BiEncoderModule): Bi-encoder module to use for indexing.
            verbose (bool): Whether to print verbose output. Defaults to False.
        """
        super().__init__(index_dir, index_config, module, verbose)
        # Flat float32 buffer; reshaped to (num_embeddings, dim) in save().
        self.embeddings = array.array("f")

    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        """Add embeddings from the output to the index.

        Args:
            index_batch (IndexBatch): Batch containing the index data.
            output (BiEncoderOutput): Output from the Bi-encoder model containing embeddings.

        Raises:
            ValueError: If output does not contain document embeddings.
        """
        doc_embeddings = output.doc_embeddings
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")

        if doc_embeddings.scoring_mask is None:
            # No scoring mask: each document contributes exactly one vector,
            # so record a length of 1 per document and take the first vector.
            doc_lengths = torch.ones(
                doc_embeddings.embeddings.shape[0], device=doc_embeddings.device, dtype=torch.int32
            )
            embeddings = doc_embeddings.embeddings[:, 0]
        else:
            # Masked multi-vector case: per-document length is the number of
            # mask-selected vectors; boolean indexing flattens them across docs.
            doc_lengths = doc_embeddings.scoring_mask.sum(dim=1)
            embeddings = doc_embeddings.embeddings[doc_embeddings.scoring_mask]
        num_docs = len(index_batch.doc_ids)
        self.doc_ids.extend(index_batch.doc_ids)
        self.doc_lengths.extend(doc_lengths.int().cpu().tolist())
        self.num_embeddings += embeddings.shape[0]
        self.num_docs += num_docs
        # Fix: use reshape instead of view. The ``[:, 0]`` slice above is
        # non-contiguous whenever the sequence dimension is > 1, and ``view``
        # raises on non-contiguous tensors while ``reshape`` copies as needed.
        self.embeddings.extend(embeddings.cpu().reshape(-1).float().tolist())

    def to_gpu(self) -> None:
        """Convert the index to GPU format (no-op: this index stays in host memory)."""
        pass

    def to_cpu(self) -> None:
        """Convert the index to CPU format (no-op: this index is already on CPU)."""
        pass

    def save(self) -> None:
        """Save the index to ``index.pt`` in the index directory."""
        super().save()
        # torch.frombuffer shares memory with self.embeddings (zero-copy); the
        # tensor is serialized immediately, so the aliasing is harmless here.
        index = torch.frombuffer(self.embeddings, dtype=torch.float32).view(self.num_embeddings, -1)
        torch.save(index, self.index_dir / "index.pt")
74 75
[docs] 76class TorchDenseIndexConfig(IndexConfig): 77 """Configuration for the TorchDenseIndexer.""" 78 79 indexer_class = TorchDenseIndexer 80 SUPPORTED_MODELS = {ColConfig.model_type, DprConfig.model_type}