Source code for lightning_ir.retrieve.pytorch.dense_indexer

import array
from pathlib import Path

import torch

from ...bi_encoder import BiEncoderModule, BiEncoderOutput
from ...data import IndexBatch
from ...models import ColConfig, DprConfig
from ..base import IndexConfig, Indexer

class TorchDenseIndexer(Indexer):
    """Dense indexer that accumulates document embeddings in a flat CPU
    buffer and saves them as a single torch tensor."""

    def __init__(
        self,
        index_dir: Path,
        index_config: "TorchDenseIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        super().__init__(index_dir, index_config, module, verbose)
        # Flat float32 buffer; torch.frombuffer can later view it as a
        # 2D tensor without copying.
        self.embeddings = array.array("f")

    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        doc_embeddings = output.doc_embeddings
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")

        if doc_embeddings.scoring_mask is None:
            # Single-vector models: one embedding per document, so every
            # document has length 1 and only the first embedding is kept.
            doc_lengths = torch.ones(
                doc_embeddings.embeddings.shape[0], device=doc_embeddings.device, dtype=torch.int32
            )
            embeddings = doc_embeddings.embeddings[:, 0]
        else:
            # Multi-vector models: keep all embeddings selected by the
            # scoring mask and record each document's number of embeddings.
            doc_lengths = doc_embeddings.scoring_mask.sum(dim=1)
            embeddings = doc_embeddings.embeddings[doc_embeddings.scoring_mask]
        num_docs = len(index_batch.doc_ids)
        self.doc_ids.extend(index_batch.doc_ids)
        self.doc_lengths.extend(doc_lengths.int().cpu().tolist())
        self.num_embeddings += embeddings.shape[0]
        self.num_docs += num_docs
        self.embeddings.extend(embeddings.cpu().view(-1).float().tolist())

    def to_gpu(self) -> None:
        # The index is accumulated on the CPU; device moves are no-ops.
        pass

    def to_cpu(self) -> None:
        pass

    def save(self) -> None:
        super().save()
        # Reinterpret the flat buffer as a (num_embeddings, embedding_dim)
        # tensor (no copy) and save it to disk.
        index = torch.frombuffer(self.embeddings, dtype=torch.float32).view(self.num_embeddings, -1)
        torch.save(index, self.index_dir / "index.pt")


class TorchDenseIndexConfig(IndexConfig):
    indexer_class = TorchDenseIndexer
    # Dense indexing supports both multi-vector (Col) and single-vector
    # (DPR) bi-encoder models.
    SUPPORTED_MODELS = {ColConfig.model_type, DprConfig.model_type}
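
For orientation, a minimal usage sketch of how these classes fit together. This is a hypothetical helper, not part of the module: the build_index function and the no-argument TorchDenseIndexConfig() construction are assumptions (IndexConfig's actual signature lives in ..base), and producing the (IndexBatch, BiEncoderOutput) pairs is the job of a BiEncoderModule and data pipeline elsewhere. Only the indexer API defined above and the plain-tensor format written by save() are taken from this file.

from pathlib import Path
from typing import Iterable, Tuple

import torch

from lightning_ir.bi_encoder import BiEncoderModule, BiEncoderOutput
from lightning_ir.data import IndexBatch
from lightning_ir.retrieve.pytorch.dense_indexer import TorchDenseIndexConfig, TorchDenseIndexer


def build_index(
    module: BiEncoderModule,
    batches: Iterable[Tuple[IndexBatch, BiEncoderOutput]],
    index_dir: Path,
) -> torch.Tensor:
    # Hypothetical helper; assumes TorchDenseIndexConfig needs no arguments.
    indexer = TorchDenseIndexer(index_dir, TorchDenseIndexConfig(), module)
    for index_batch, output in batches:
        indexer.add(index_batch, output)
    indexer.save()
    # save() wrote the embeddings as a float32 tensor of shape
    # (num_embeddings, embedding_dim); load it back for inspection.
    return torch.load(index_dir / "index.pt")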