From d915980eda4904fe7a7674bc930acf05846be91d Mon Sep 17 00:00:00 2001
From: RichieHakim
Date: Fri, 30 Aug 2024 10:16:18 -0400
Subject: [PATCH] feat: Add Dataset_TensorDict_concatenated class for loading
 slices from arrays within a TensorDict

---
 bnpm/torch_helpers.py | 83 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 83 insertions(+)

diff --git a/bnpm/torch_helpers.py b/bnpm/torch_helpers.py
index d6da368..d9237fa 100644
--- a/bnpm/torch_helpers.py
+++ b/bnpm/torch_helpers.py
@@ -11,6 +11,7 @@ from torch.utils.data import Dataset
 import numpy as np
 from tqdm.auto import tqdm
+import tensordict
 
 from . import indexing
 from . import misc
@@ -608,6 +609,72 @@ def close(self):
         self.X.close()
 
 
+class Dataset_TensorDict_concatenated(torch.utils.data.Dataset):
+    """
+    Dataset class for loading slices from arrays within a TensorDict.\n
+    Input is a TensorDict containing arrays with similar shapes.\n
+    Output is a Dataset where the queried index pulls slices from the
+    concatenated first dimension of all the input arrays.\n
+    RH 2024
+
+    Args:
+        tensor_dict (TensorDict):
+            A TensorDict with the following organization: \n
+                * One hierarchical level of fields.
+                * Each field is a tensor.
+                * Each tensor may have a different first dimension size, but
+                  all other dimensions must be the same.
+                * Example:
+                    TensorDict(
+                        fields={
+                            'array_1': tensor_1 (shape=[X1, M, N, P, ...]),
+                            'array_2': tensor_2 (shape=[X2, M, N, P, ...]),
+                            ...
+                        },
+                        batch_size=torch.Size([]),
+                    )
+    """
+    def __init__(self, tensor_dict: tensordict.TensorDict):
+        super(Dataset_TensorDict_concatenated, self).__init__()
+
+        assert isinstance(tensor_dict, tensordict.TensorDict), "tensor_dict must be a TensorDict."
+        self.tensor_dict = tensor_dict
+
+        ## Check that all arrays have the same shape except for the first
+        ## dimension. Trailing dims of 1-D tensors are empty and match trivially.
+        shapes = [arr.shape for arr in tensor_dict.values()]
+        check_shape = lambda shape1, shape2: shape1[1:] == shape2[1:]
+        assert all(check_shape(shape, shapes[0]) for shape in shapes), "All arrays must have the same shape except for the first dimension."
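+        ## e.g. shapes (10, 3, 4) and (20, 3, 4) are compatible and give a
+        ## dataset of length 30; (10, 3, 4) and (10, 5, 4) are not.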
+
+        self.n_samples = sum([shape[0] for shape in shapes])
+        self.shape = [self.n_samples] + list(shapes[0][1:])
+        self.fields = list(tensor_dict.keys())
+
+        ## Create an index-to-field mapping.
+        ### Use a binary search over the cumsum of the first-dimension sizes
+        ### to find the field that contains a given global index.
+        self.cumsum = torch.cumsum(torch.as_tensor([0] + [shape[0] for shape in shapes], dtype=torch.int64), dim=0)
+        self.idx_to_fieldIdx = lambda idx: torch.searchsorted(self.cumsum, idx, side='right') - 1
+
+    def __len__(self):
+        return self.n_samples
+
+    def __getitem__(self, idx):
+        ## Support negative indices, then bounds-check once so that indices
+        ## below -n_samples raise instead of silently wrapping around.
+        if idx < 0:
+            idx = self.n_samples + idx
+        if not (0 <= idx < self.n_samples):
+            raise IndexError(f"Index {idx} is out of bounds for dataset of length {self.n_samples}.")
+
+        fieldIdx = self.idx_to_fieldIdx(idx)
+        field = self.fields[fieldIdx]
+        idx_field = idx - self.cumsum[fieldIdx]
+        return self.tensor_dict[field][idx_field]
+
+
 class BatchRandomSampler(torch.utils.data.Sampler):
     """
     Creates a sampler similar to torch.utils.data.BatchSampler, but allows for
@@ -820,6 +887,22 @@ def send_to_device(batch, device, non_blocking, pin_memory):
 
 ##################################################################
 
+## Mapping from numpy scalar types to their corresponding torch dtypes.
+numpy_to_torch_dtype_dict = {
+    np.bool_      : torch.bool,
+    np.uint8      : torch.uint8,
+    np.int8       : torch.int8,
+    np.int16      : torch.int16,
+    np.int32      : torch.int32,
+    np.int64      : torch.int64,
+    np.float16    : torch.float16,
+    np.float32    : torch.float32,
+    np.float64    : torch.float64,
+    np.complex64  : torch.complex64,
+    np.complex128 : torch.complex128,
+}
+
+
 @misc.wrapper_flexible_args(['dim', 'axis'])
 @misc.wrapper_flexible_args(['keepdim', 'keepdims'])
 def nanvar(
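
Reviewer note (not part of the patch): a minimal usage sketch of the new
class, assuming the patch applies as above. The field names and tensor
sizes below are illustrative only.

    import torch
    import tensordict

    from bnpm.torch_helpers import Dataset_TensorDict_concatenated

    ## Two fields with different first-dimension sizes but identical
    ## trailing dimensions.
    td = tensordict.TensorDict(
        {
            'array_1': torch.randn(10, 3, 4),
            'array_2': torch.randn(20, 3, 4),
        },
        batch_size=[],
    )

    dataset = Dataset_TensorDict_concatenated(td)
    assert len(dataset) == 30
    sample = dataset[12]  ## lands in 'array_2' at local index 12 - 10 = 2
    assert tuple(sample.shape) == (3, 4)

    ## The dataset works with a standard DataLoader.
    loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)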