Source code for lightning_ir.retrieve.seismic.seismic_format
1import numpy as np
2import torch
3
4
[docs]
class SeismicFormatConverter:
    """Convert dense 2D weight tensors into a packed little-endian sparse binary layout.

    The per-row record layout written here is presumably the one read by the Seismic
    sparse-retrieval index. Any file-level header (e.g. a record count) is NOT produced
    by this class -- NOTE(review): confirm the caller writes it if Seismic requires it.
    """

    @staticmethod
    def convert_to_seismic_format(embeddings: torch.Tensor) -> bytes:
        """Serialize each row of a 2D tensor as one sparse (index, value) record.

        For every row of ``embeddings``, in row order, one record is emitted:

        * a little-endian ``uint32`` count ``n`` of non-zero entries in that row,
        * ``n`` little-endian ``int32`` column indices of those entries (ascending),
        * ``n`` little-endian ``float32`` values found at those columns.

        Rows with no non-zero entry, including trailing ones, still produce a record
        with ``n == 0``. The number of records therefore always equals
        ``embeddings.shape[0]``.

        Args:
            embeddings: 2D tensor. Rows are presumably documents/queries and columns
                vocabulary term ids (TODO confirm with callers). The tensor may live on
                any device, may require grad, and may use a reduced-precision dtype
                such as ``bfloat16``; values are cast to ``float32`` for output.

        Returns:
            All records concatenated into a single ``bytes`` object.

        Raises:
            ValueError: If ``embeddings`` is not 2-dimensional.
        """
        if embeddings.ndim != 2:
            raise ValueError("Expected 2D tensor")
        num_rows = embeddings.shape[0]

        # nonzero() returns indices in row-major (lexicographic) order, so each row's
        # entries form one contiguous run in row_idcs / term_idcs / values.
        row_idcs, term_idcs = embeddings.nonzero(as_tuple=True)
        values = embeddings[row_idcs, term_idcs]

        # minlength ensures trailing all-zero rows still get a (zero) count instead of
        # being silently dropped, which would misalign every later record.
        counts = torch.bincount(row_idcs, minlength=num_rows).tolist()

        # Transfer to host once for the whole batch rather than twice per row.
        # Cast values in torch first because numpy has no bfloat16. Use explicit
        # little-endian dtypes so the payload matches the little-endian length prefix
        # on any host.
        terms_np = term_idcs.numpy(force=True).astype("<i4", copy=False)
        values_np = values.to(torch.float32).numpy(force=True).astype("<f4", copy=False)

        # Collect pieces and join once; repeated bytes `+=` is quadratic.
        chunks = []
        offset = 0
        for count in counts:
            end = offset + count
            chunks.append(count.to_bytes(4, byteorder="little", signed=False))
            chunks.append(terms_np[offset:end].tobytes())
            chunks.append(values_np[offset:end].tobytes())
            offset = end
        return b"".join(chunks)