Source code for lightning_ir.retrieve.seismic.seismic_format
1"""SeismicFormatConverter class for converting embeddings to a seismic format."""
2
3import numpy as np
4import torch
5
6
[docs]
class SeismicFormatConverter:
    """Converter for embeddings to a seismic format."""

    @staticmethod
    def convert_to_seismic_format(embeddings: torch.Tensor) -> bytes:
        """Convert embeddings to a seismic format.

        Every row of ``embeddings`` is serialized as one record, in row order:

        1. the number of non-zero entries in the row (unsigned 32-bit,
           little-endian),
        2. the column indices of those entries (int32, little-endian),
        3. the corresponding values (float32, little-endian).

        Rows without any non-zero entry still produce a record (a zero count
        with no payload), so the number of records always equals
        ``embeddings.shape[0]``.

        Args:
            embeddings (torch.Tensor): The embeddings to convert.
        Returns:
            bytes: The converted embeddings in seismic format.
        Raises:
            ValueError: If the embeddings are not 2D.
        """
        if embeddings.ndim != 2:
            raise ValueError("Expected 2D tensor")

        # nonzero() yields indices in row-major order, so all entries of a row
        # are contiguous and the per-row counts delimit consecutive slices.
        row_idcs, term_idcs = embeddings.nonzero(as_tuple=True)
        # minlength keeps trailing all-zero rows (and the all-zero tensor case)
        # in the output; without it those rows would be silently dropped.
        counts = torch.bincount(row_idcs, minlength=embeddings.shape[0]).tolist()

        # Convert once for the whole tensor instead of once per row. Values are
        # cast to float32 in torch first so dtypes NumPy cannot represent
        # (e.g. bfloat16) are supported. Explicit little-endian dtypes keep the
        # payload consistent with the little-endian count header.
        terms = term_idcs.numpy(force=True).astype("<i4")
        values = embeddings[row_idcs, term_idcs].to(torch.float32).numpy(force=True).astype("<f4")

        # Collect the pieces and join once; repeated bytes += is quadratic.
        chunks = []
        start = 0
        for count in counts:
            end = start + count
            chunks.append(count.to_bytes(4, byteorder="little", signed=False))
            chunks.append(terms[start:end].tobytes())
            chunks.append(values[start:end].tobytes())
            start = end
        return b"".join(chunks)