Source code for lightning_ir.retrieve.seismic.seismic_indexer

import os
from pathlib import Path

import numpy as np
import torch

try:
    _seismic_available = True
    import seismic
    from seismic import SeismicDataset, SeismicIndex

    STRING_TYPE = seismic.get_seismic_string()
except ImportError:
    STRING_TYPE = None
    _seismic_available = False
    SeismicIndex = SeismicDataset = None


from ...bi_encoder import BiEncoderModule, BiEncoderOutput
from ...data import IndexBatch
from ...models import SpladeConfig
from ..base import IndexConfig, Indexer

class SeismicIndexer(Indexer):
    """Indexer that builds a seismic inverted index from sparse (SPLADE) document embeddings."""

    def __init__(
        self,
        index_dir: Path,
        index_config: "SeismicIndexConfig",
        module: BiEncoderModule,
        verbose: bool = False,
    ) -> None:
        super().__init__(index_dir, index_config, module, verbose)
        if not _seismic_available:
            raise ImportError(
                "Please install the seismic package to use the SeismicIndexer. "
                "Instructions can be found at "
                "https://github.com/TusKANNy/seismic?tab=readme-ov-file#using-the-python-interface"
            )
        self.index_config: SeismicIndexConfig
        assert SeismicDataset is not None
        self.seismic_dataset = SeismicDataset()

    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        doc_embeddings = output.doc_embeddings
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")

        if doc_embeddings.scoring_mask is None:
            # Without a scoring mask, each document contributes a single embedding.
            doc_lengths = torch.ones(
                doc_embeddings.embeddings.shape[0], device=doc_embeddings.device, dtype=torch.int32
            )
            embeddings = doc_embeddings.embeddings[:, 0]
        else:
            doc_lengths = doc_embeddings.scoring_mask.sum(dim=1)
            embeddings = doc_embeddings.embeddings[doc_embeddings.scoring_mask]

        # Track the per-document statistics maintained by the base Indexer.
        num_docs = len(index_batch.doc_ids)
        self.doc_ids.extend(index_batch.doc_ids)
        self.doc_lengths.extend(doc_lengths.int().cpu().tolist())
        self.num_embeddings += embeddings.shape[0]
        self.num_docs += num_docs

        for idx, doc_id in enumerate(index_batch.doc_ids):
            # Convert the sparse embedding into seismic's (tokens, values) representation.
            non_zero = embeddings[idx].nonzero().view(-1)
            values = embeddings[idx][non_zero].float().numpy(force=True)
            tokens = np.array(self.module.tokenizer.convert_ids_to_tokens(non_zero), dtype="U30")
            self.seismic_dataset.add_document(doc_id, tokens, values)

    def save(self) -> None:
        super().save()

        assert SeismicIndex is not None
        index = SeismicIndex.build_from_dataset(
            self.seismic_dataset,
            n_postings=self.index_config.num_postings,
            centroid_fraction=self.index_config.centroid_fraction,
            min_cluster_size=self.index_config.min_cluster_size,
            summary_energy=self.index_config.summary_energy,
            nknn=self.index_config.num_k_nearest_neighbors,
            batched_indexing=self.index_config.batch_size,
            num_threads=self.index_config.num_threads,
        )
        index.save(str(self.index_dir) + os.path.sep)
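
# Illustration (hypothetical tensor; not part of the module) of the
# sparse-to-seismic conversion performed in SeismicIndexer.add above:
#
#     import torch
#     embedding = torch.tensor([0.0, 1.3, 0.0, 0.7])  # vocabulary-sized SPLADE vector
#     non_zero = embedding.nonzero().view(-1)          # tensor([1, 3]): active dimensions
#     values = embedding[non_zero].float().numpy(force=True)  # array([1.3, 0.7], dtype=float32)
#     # token strings for the active dimensions come from the module's tokenizer:
#     # tokens = np.array(tokenizer.convert_ids_to_tokens(non_zero), dtype="U30")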


class SeismicIndexConfig(IndexConfig):
    """Configuration for the SeismicIndexer."""

    indexer_class = SeismicIndexer
    SUPPORTED_MODELS = {SpladeConfig.model_type}

    def __init__(
        self,
        num_postings: int = 3_500,
        centroid_fraction: float = 0.1,
        min_cluster_size: int = 2,
        summary_energy: float = 0.4,
        num_k_nearest_neighbors: int = 0,
        batch_size: int | None = None,
        num_threads: int = 0,
    ) -> None:
        super().__init__()
        self.num_postings = num_postings
        self.centroid_fraction = centroid_fraction
        self.summary_energy = summary_energy
        self.min_cluster_size = min_cluster_size
        self.num_k_nearest_neighbors = num_k_nearest_neighbors
        self.batch_size = batch_size
        self.num_threads = num_threads
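
# A minimal usage sketch (assumes `module` is a BiEncoderModule wrapping a
# SPLADE-style model and that `batches` yields (IndexBatch, BiEncoderOutput)
# pairs from an indexing loop; both names are placeholders):
#
#     config = SeismicIndexConfig(num_postings=3_500, num_threads=4)
#     indexer = SeismicIndexer(Path("indexes/seismic"), config, module)
#     for index_batch, output in batches:
#         indexer.add(index_batch, output)
#     indexer.save()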