Source code for lightning_ir.retrieve.base.packed_tensor

 1"""PackedTensor class for handling tensors with variable segment lengths."""
 2
 3from typing import Sequence, Tuple
 4
 5import torch
 6
 7
[docs] 8class PackedTensor(torch.Tensor): 9 """A tensor that contains a sequence of tensors with varying lengths.""" 10 11 def __new__(cls, *args, lengths: Sequence[int] | None = None, **kwargs) -> "PackedTensor": 12 """Create a new PackedTensor instance. 13 14 Args: 15 lengths (Sequence[int] | None): A sequence of lengths for each segment in the tensor. If provided, the 16 tensor must be created with a total length equal to the sum of these lengths. Defaults to None. 17 Returns: 18 PackedTensor: A new instance of PackedTensor. 19 Raises: 20 ValueError: If lengths is None. 21 """ 22 if lengths is None: 23 raise ValueError("lengths must be provided") 24 return super().__new__(cls, *args, **kwargs) 25
[docs] 26 def __init__(self, *args, lengths: Sequence[int] | None = None, **kwargs) -> None: 27 """Initialize the PackedTensor instance. 28 29 Args: 30 lengths (Sequence[int] | None): A sequence of lengths for each segment in the tensor. If provided, the 31 tensor must be created with a total length equal to the sum of these lengths. Defaults to None. 32 Raises: 33 ValueError: If lengths is None. 34 ValueError: If the sum of lengths does not equal the length of the tensor. 35 """ 36 if lengths is None: 37 raise ValueError("lengths must be provided") 38 if sum(lengths) != len(self): 39 raise ValueError("Sum of lengths must equal the length of the tensor") 40 self.lengths = list(lengths) 41 self._segmented_tensor: Tuple[torch.Tensor, ...] | None = None
42 43 @property 44 def segmented_tensor(self) -> Tuple[torch.Tensor, ...]: 45 """Get the segmented tensor, which is a tuple of tensors split according to the specified lengths. 46 47 Returns: 48 Tuple[torch.Tensor, ...]: A tuple of tensors, each corresponding to a segment defined by the lengths. 49 """ 50 if self._segmented_tensor is None: 51 self._segmented_tensor = torch.split(self, self.lengths) 52 return self._segmented_tensor 53
[docs] 54 def lookup( 55 self, packed_idcs: torch.Tensor, idcs_lengths: Sequence[int] | int, unique: bool = False 56 ) -> "PackedTensor": 57 """Lookup segments in the packed tensor based on provided indices. 58 59 Args: 60 packed_idcs (torch.Tensor): A tensor containing indices to lookup in the packed tensor. 61 idcs_lengths (Sequence[int] | int): Lengths of the indices for each segment. If a single integer is 62 provided, it is assumed that all segments have the same length. 63 unique (bool): If True, returns only unique values from the segments. Defaults to False. 64 Returns: 65 PackedTensor: A new PackedTensor containing the concatenated segments corresponding to the provided indices. 66 """ 67 output_tensors = [] 68 lengths = [] 69 for lookup_idcs in torch.split(packed_idcs, idcs_lengths): 70 intermediate_tensors = [] 71 for idx in lookup_idcs: 72 intermediate_tensors.append(self.segmented_tensor[idx]) 73 74 cat_tensors = torch.cat(intermediate_tensors) 75 if unique: 76 cat_tensors = torch.unique(cat_tensors) 77 lengths.append(cat_tensors.shape[0]) 78 output_tensors.append(cat_tensors) 79 80 return PackedTensor(torch.cat(output_tensors), lengths=lengths)
81
[docs] 82 def to_padded_tensor(self, pad_value: int = 0) -> torch.Tensor: 83 """Convert the packed tensor to a padded tensor. 84 85 Args: 86 pad_value (int): The value to use for padding. Defaults to 0. 87 Returns: 88 torch.Tensor: A padded tensor where each segment is padded to the length of the longest segment 89 in the packed tensor. 90 """ 91 return torch.nn.utils.rnn.pad_sequence(self.segmented_tensor, batch_first=True, padding_value=pad_value)
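
A minimal usage sketch (illustrative only, not part of the module source above): pack three segments of lengths 2, 3, and 1 into one flat tensor, then recover the per-segment view and a padded matrix. The import path follows the module name shown at the top of this page; the data values are arbitrary.

import torch

from lightning_ir.retrieve.base.packed_tensor import PackedTensor

# Flat data of length 6 split into segments of lengths 2, 3 and 1.
packed = PackedTensor(torch.arange(6.0), lengths=[2, 3, 1])

# segmented_tensor lazily splits the flat data back into its segments:
# [0., 1.], [2., 3., 4.] and [5.]
segments = packed.segmented_tensor

# to_padded_tensor right-pads every segment to the longest one (length 3),
# yielding a (3, 3) matrix with rows [0, 1, 0], [2, 3, 4], [5, 0, 0].
padded = packed.to_padded_tensor(pad_value=0)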
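
Continuing the same sketch, lookup gathers whole segments by index. Below, segment indices 0 and 2 form the first group and indices 1 and 2 the second; passing idcs_lengths=2 tells lookup that every group contains two segment indices. The index values are illustrative assumptions, not taken from the library.

# Gather segments (0, 2) and (1, 2) from the packed tensor above.
idcs = torch.tensor([0, 2, 1, 2])
gathered = packed.lookup(idcs, idcs_lengths=2)

# gathered concatenates [0., 1., 5.] and [2., 3., 4., 5.] into one flat
# PackedTensor, so gathered.lengths == [3, 4]. With unique=True each group
# would be deduplicated (and sorted) before its length is recorded.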