Source code for lightning_ir.retrieve.seismic.seismic_format

 1import numpy as np
 2import torch
 3
 4
class SeismicFormatConverter:
    """Serialize sparse embedding batches into a flat binary record stream."""

    @staticmethod
    def convert_to_seismic_format(embeddings: torch.Tensor) -> bytes:
        """Encode the non-zero entries of a 2D tensor as consecutive binary records.

        Exactly one record is written per row of ``embeddings``, in row order.
        Each record consists of:

        1. the number of non-zero entries in the row (4 bytes, unsigned,
           little-endian),
        2. the column indices of those entries (little-endian int32 each),
        3. the corresponding values (little-endian float32 each).

        Rows without any non-zero entry are written as a record with a zero
        count and no payload, regardless of their position in the batch, so
        the number of records always equals ``embeddings.shape[0]``.

        Args:
            embeddings: 2D tensor whose rows are the vectors to serialize.

        Returns:
            The concatenated records as a single ``bytes`` object.

        Raises:
            ValueError: If ``embeddings`` is not a 2D tensor.
        """
        if embeddings.ndim != 2:
            raise ValueError("Expected 2D tensor")

        num_rows = embeddings.shape[0]
        # nonzero() yields indices in row-major order, so entries belonging to
        # the same row are contiguous and rows appear in ascending order.
        batch_idcs, term_idcs = embeddings.nonzero(as_tuple=True)
        # minlength ensures trailing rows without non-zero entries still get a
        # count of 0; without it bincount stops at the last non-empty row.
        lengths = torch.bincount(batch_idcs, minlength=num_rows).tolist()
        values = embeddings[batch_idcs, term_idcs]

        # Cast in torch first (NumPy lacks some torch dtypes, e.g. bfloat16) and
        # move to host memory once instead of once per row. The explicit "<"
        # dtypes keep the payload byte order consistent with the length prefix.
        term_array = term_idcs.to(torch.int32).numpy(force=True).astype("<i4", copy=False)
        value_array = values.to(torch.float32).numpy(force=True).astype("<f4", copy=False)

        chunks = []
        start = 0
        for length in lengths:
            end = start + length
            chunks.append(length.to_bytes(4, byteorder="little", signed=False))
            chunks.append(term_array[start:end].tobytes())
            chunks.append(value_array[start:end].tobytes())
            start = end
        # Join once at the end; repeated bytes concatenation is quadratic.
        return b"".join(chunks)