Source code for lightning_ir.retrieve.base.packed_tensor

 1from typing import Sequence, Tuple
 2
 3import torch
 4
 5
[docs] 6class PackedTensor(torch.Tensor): 7 8 def __new__(cls, *args, lengths: Sequence[int] | None = None, **kwargs) -> "PackedTensor": 9 if lengths is None: 10 raise ValueError("lengths must be provided") 11 return super().__new__(cls, *args, **kwargs) 12
[docs] 13 def __init__(self, *args, lengths: Sequence[int] | None = None, **kwargs) -> None: 14 if lengths is None: 15 raise ValueError("lengths must be provided") 16 if sum(lengths) != len(self): 17 raise ValueError("Sum of lengths must equal the length of the tensor") 18 self.lengths = list(lengths) 19 self._segmented_tensor: Tuple[torch.Tensor, ...] | None = None
20 21 @property 22 def segmented_tensor(self) -> Tuple[torch.Tensor, ...]: 23 if self._segmented_tensor is None: 24 self._segmented_tensor = torch.split(self, self.lengths) 25 return self._segmented_tensor 26 27 def lookup( 28 self, packed_idcs: torch.Tensor, idcs_lengths: Sequence[int] | int, unique: bool = False 29 ) -> "PackedTensor": 30 output_tensors = [] 31 lengths = [] 32 for lookup_idcs in torch.split(packed_idcs, idcs_lengths): 33 intermediate_tensors = [] 34 for idx in lookup_idcs: 35 intermediate_tensors.append(self.segmented_tensor[idx]) 36 37 cat_tensors = torch.cat(intermediate_tensors) 38 if unique: 39 cat_tensors = torch.unique(cat_tensors) 40 lengths.append(cat_tensors.shape[0]) 41 output_tensors.append(cat_tensors) 42 43 return PackedTensor(torch.cat(output_tensors), lengths=lengths) 44
[docs] 45 def to_padded_tensor(self, pad_value: int = 0) -> torch.Tensor: 46 return torch.nn.utils.rnn.pad_sequence(self.segmented_tensor, batch_first=True, padding_value=pad_value)