Source code for lightning_ir.retrieve.base.packed_tensor
1"""PackedTensor class for handling tensors with variable segment lengths."""
2
3from typing import Sequence, Tuple
4
5import torch
6
7
class PackedTensor(torch.Tensor):
    """A tensor that contains a sequence of tensors with varying lengths."""

    def __new__(cls, *args, lengths: Sequence[int] | None = None, **kwargs) -> "PackedTensor":
        """Create a new PackedTensor instance.

        Args:
            lengths (Sequence[int] | None): A sequence of lengths for each segment in the tensor. If provided, the
                tensor must be created with a total length equal to the sum of these lengths. Defaults to None.
        Returns:
            PackedTensor: A new instance of PackedTensor.
        Raises:
            ValueError: If lengths is None.
        """
        if lengths is None:
            raise ValueError("lengths must be provided")
        return super().__new__(cls, *args, **kwargs)

    def __init__(self, *args, lengths: Sequence[int] | None = None, **kwargs) -> None:
        """Initialize the PackedTensor instance.

        Args:
            lengths (Sequence[int] | None): A sequence of lengths for each segment in the tensor. If provided, the
                tensor must be created with a total length equal to the sum of these lengths. Defaults to None.
        Raises:
            ValueError: If lengths is None.
            ValueError: If the sum of lengths does not equal the length of the tensor.
        """
        if lengths is None:
            raise ValueError("lengths must be provided")
        if sum(lengths) != len(self):
            raise ValueError("Sum of lengths must equal the length of the tensor")
        self.lengths = list(lengths)
        self._segmented_tensor: Tuple[torch.Tensor, ...] | None = None

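    # Minimal usage sketch (variable names are illustrative only): six values packed
    # into three segments of lengths 2, 3 and 1; the flat length must equal sum(lengths):
    #
    #   packed = PackedTensor(torch.arange(6), lengths=[2, 3, 1])
    #   packed.lengths  # -> [2, 3, 1]
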
    @property
    def segmented_tensor(self) -> Tuple[torch.Tensor, ...]:
        """Get the segmented tensor, which is a tuple of tensors split according to the specified lengths.

        Returns:
            Tuple[torch.Tensor, ...]: A tuple of tensors, each corresponding to a segment defined by the lengths.
        """
        if self._segmented_tensor is None:
            self._segmented_tensor = torch.split(self, self.lengths)
        return self._segmented_tensor

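    # Continuing the sketch above: the property lazily splits the flat tensor into its
    # segments and caches the result, so repeated access does not re-split:
    #
    #   packed.segmented_tensor  # -> three tensors with 2, 3 and 1 elements respectively
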
    def lookup(
        self, packed_idcs: torch.Tensor, idcs_lengths: Sequence[int] | int, unique: bool = False
    ) -> "PackedTensor":
        """Look up segments in the packed tensor based on the provided indices.

        Args:
            packed_idcs (torch.Tensor): A tensor containing indices to look up in the packed tensor.
            idcs_lengths (Sequence[int] | int): Lengths of the indices for each segment. If a single integer is
                provided, it is assumed that all segments have the same length.
            unique (bool): If True, returns only unique values from the segments. Defaults to False.
        Returns:
            PackedTensor: A new PackedTensor containing the concatenated segments corresponding to the provided
                indices.
        """
        output_tensors = []
        lengths = []
        for lookup_idcs in torch.split(packed_idcs, idcs_lengths):
            intermediate_tensors = []
            for idx in lookup_idcs:
                intermediate_tensors.append(self.segmented_tensor[idx])

            cat_tensors = torch.cat(intermediate_tensors)
            if unique:
                cat_tensors = torch.unique(cat_tensors)
            lengths.append(cat_tensors.shape[0])
            output_tensors.append(cat_tensors)

        return PackedTensor(torch.cat(output_tensors), lengths=lengths)

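    # Continuing the sketch: packed_idcs is itself packed, so indices [0, 2, 1] with
    # idcs_lengths=[2, 1] gather segments 0 and 2 for the first entry and segment 1
    # for the second, yielding a PackedTensor with lengths [3, 3]:
    #
    #   gathered = packed.lookup(torch.tensor([0, 2, 1]), idcs_lengths=[2, 1])
    #   gathered.lengths  # -> [3, 3]
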
    def to_padded_tensor(self, pad_value: int = 0) -> torch.Tensor:
        """Convert the packed tensor to a padded tensor.

        Args:
            pad_value (int): The value to use for padding. Defaults to 0.
        Returns:
            torch.Tensor: A padded tensor where each segment is padded to the length of the longest segment
                in the packed tensor.
        """
        return torch.nn.utils.rnn.pad_sequence(self.segmented_tensor, batch_first=True, padding_value=pad_value)
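
    # Continuing the sketch: segments of lengths 2, 3 and 1 are padded to the longest
    # segment, giving a regular (3, 3) tensor; positions past each segment's length
    # hold pad_value:
    #
    #   padded = packed.to_padded_tensor(pad_value=-1)
    #   padded.shape  # -> torch.Size([3, 3])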