Source code for lightning_ir.retrieve.base.packed_tensor
1from typing import Sequence, Tuple
2
3import torch
4
5
[docs]
6class PackedTensor(torch.Tensor):
7
8 def __new__(cls, *args, lengths: Sequence[int] | None = None, **kwargs) -> "PackedTensor":
9 if lengths is None:
10 raise ValueError("lengths must be provided")
11 return super().__new__(cls, *args, **kwargs)
12
[docs]
13 def __init__(self, *args, lengths: Sequence[int] | None = None, **kwargs) -> None:
14 if lengths is None:
15 raise ValueError("lengths must be provided")
16 if sum(lengths) != len(self):
17 raise ValueError("Sum of lengths must equal the length of the tensor")
18 self.lengths = list(lengths)
19 self._segmented_tensor: Tuple[torch.Tensor, ...] | None = None
20
21 @property
22 def segmented_tensor(self) -> Tuple[torch.Tensor, ...]:
23 if self._segmented_tensor is None:
24 self._segmented_tensor = torch.split(self, self.lengths)
25 return self._segmented_tensor
26
27 def lookup(
28 self, packed_idcs: torch.Tensor, idcs_lengths: Sequence[int] | int, unique: bool = False
29 ) -> "PackedTensor":
30 output_tensors = []
31 lengths = []
32 for lookup_idcs in torch.split(packed_idcs, idcs_lengths):
33 intermediate_tensors = []
34 for idx in lookup_idcs:
35 intermediate_tensors.append(self.segmented_tensor[idx])
36
37 cat_tensors = torch.cat(intermediate_tensors)
38 if unique:
39 cat_tensors = torch.unique(cat_tensors)
40 lengths.append(cat_tensors.shape[0])
41 output_tensors.append(cat_tensors)
42
43 return PackedTensor(torch.cat(output_tensors), lengths=lengths)
44
[docs]
45 def to_padded_tensor(self, pad_value: int = 0) -> torch.Tensor:
46 return torch.nn.utils.rnn.pad_sequence(self.segmented_tensor, batch_first=True, padding_value=pad_value)