Source code for lightning_ir.retrieve.seismic.seismic_format

 1"""SeismicFormatConverter class for converting embeddings to a seismic format."""
 2
 3import numpy as np
 4import torch
 5
 6
class SeismicFormatConverter:
    """Converter for embeddings to a seismic format."""

    @staticmethod
    def convert_to_seismic_format(embeddings: torch.Tensor) -> bytes:
        """Convert embeddings to a seismic format.

        Every row of ``embeddings`` becomes one record, in row order. A record is:

        * the number of non-zero entries in the row, as a 4-byte little-endian
          unsigned integer,
        * the column indices of those entries, as little-endian int32,
        * the corresponding values, as little-endian float32.

        Rows with no non-zero entries, including trailing ones, produce a record
        with a count of 0 and no indices or values. As a result, the number of
        records always equals ``embeddings.shape[0]``.

        Args:
            embeddings (torch.Tensor): The embeddings to convert. May live on any
                device and may require grad. Any floating dtype is accepted,
                including bfloat16 and float16.
        Returns:
            bytes: The converted embeddings in seismic format.
        Raises:
            ValueError: If the embeddings are not 2D.
        """
        if embeddings.ndim != 2:
            raise ValueError("Expected 2D tensor")
        batch_idcs, term_idcs = embeddings.nonzero(as_tuple=True)
        # minlength guarantees one count per row. Without it, bincount stops at
        # the last non-empty row and trailing all-zero rows are silently dropped.
        lengths = torch.bincount(batch_idcs, minlength=embeddings.shape[0]).tolist()
        values = embeddings[(batch_idcs, term_idcs)]

        # Cast in torch before going to numpy. numpy cannot represent bfloat16,
        # and doing the conversion once avoids a per-row device/dtype round-trip.
        # The explicit "<" byte order matches the little-endian length header.
        terms_np = term_idcs.to(torch.int32).numpy(force=True).astype("<i4", copy=False)
        values_np = values.to(torch.float32).numpy(force=True).astype("<f4", copy=False)

        # Collect the pieces and join once. Repeated bytes += is quadratic.
        chunks = []
        start = 0
        for length in lengths:
            end = start + length
            chunks.append(length.to_bytes(4, byteorder="little", signed=False))
            chunks.append(terms_np[start:end].tobytes())
            chunks.append(values_np[start:end].tobytes())
            start = end
        return b"".join(chunks)