feat: Add Dataset_TensorDict_concatenated class for loading slices from arrays within a TensorDict
RichieHakim committed Aug 30, 2024
1 parent c108f11 commit d915980
Showing 1 changed file with 75 additions and 0 deletions.
bnpm/torch_helpers.py: 75 additions & 0 deletions
@@ -11,6 +11,7 @@
from torch.utils.data import Dataset
import numpy as np
from tqdm.auto import tqdm
import tensordict

from . import indexing
from . import misc
@@ -608,6 +609,65 @@ def close(self):
self.X.close()


class Dataset_TensorDict_concatenated(torch.utils.data.Dataset):
"""
Dataset class for loading slices from arrays within a TensorDict.\n
Input is a TensorDict containing arrays with similar shapes.\n
Output is a Dataset where the queried index pulls slices from the
concatenated first dimension indices of all the input arrays.\n
RH 2024
Args:
tensor_dict (TensorDict):
A TensorDict with the following organization: \n
* One hierarchical level of fields.
* Each field is a tensor.
* Each tensor may have different first dimension sizes, but all
other dimensions must be the same.
* Example:
TensorDict(
fields={
'array_1': tensor_1 (shape=[X1, M, N, P, ...]),
'array_2': tensor_2 (shape=[X2, M, N, P, ...]),
...
},
batch_size=torch.Size([]),
)
"""
def __init__(self, tensor_dict: tensordict.TensorDict):
super(Dataset_TensorDict_concatenated, self).__init__()

assert isinstance(tensor_dict, tensordict.TensorDict), "tensor_dict must be a TensorDict."
self.tensor_dict = tensor_dict

## Check that all arrays have the same shape except for the first dimension
shapes = [arr.shape for arr in tensor_dict.values()]
check_shape = lambda shape1, shape2: shape1[1:] == shape2[1:] if len(shape1) > 1 else shape1[0] == shape2[0]
assert all([check_shape(shape, shapes[0]) for shape in shapes]), "All arrays must have the same shape except for the first dimension."
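        ### E.g. (illustrative, not from the commit): shapes [3, 4, 5] and [7, 4, 5] pass; [3, 4, 5] and [7, 4, 6] do not.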
self.n_samples = sum([shape[0] for shape in shapes])
self.shape = [self.n_samples] + list(shapes[0][1:])
self.fields = list(tensor_dict.keys())

## Create an index to field mapping
### Use a binary search to find the field for a given index using the cumsum of the first dimensions
self.cumsum = torch.cumsum(torch.as_tensor([0] + [shape[0] for shape in shapes], dtype=torch.int64), dim=0)
self.idx_to_fieldIdx = lambda idx: torch.searchsorted(self.cumsum, idx, side='right') - 1
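        ### Worked example (illustrative, not from the commit): for first-dim sizes [3, 5],
        ### cumsum = [0, 3, 8]; idx=4 -> searchsorted(cumsum, 4, side='right') - 1 = 1
        ### (i.e. the second field), and the within-field index is 4 - cumsum[1] = 1.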

def __len__(self):
return self.n_samples

def __getitem__(self, idx):
if (idx < 0):
idx = self.n_samples + idx
elif (idx >= self.n_samples):
raise IndexError(f"Index {idx} is out of bounds for dataset of length {self.n_samples}.")

fieldIdx = self.idx_to_fieldIdx(idx)
field = self.fields[fieldIdx]
idx_field = idx - self.cumsum[fieldIdx]
return self.tensor_dict[field][idx_field]


class BatchRandomSampler(torch.utils.data.Sampler):
"""
Creates a sampler similar to torch.utils.data.BatchSampler, but allows for
@@ -820,6 +880,21 @@ def send_to_device(batch, device, non_blocking, pin_memory):
##################################################################


numpy_to_torch_dtype_dict = {
np.bool_ : torch.bool,
np.uint8 : torch.uint8,
np.int8 : torch.int8,
np.int16 : torch.int16,
np.int32 : torch.int32,
np.int64 : torch.int64,
np.float16 : torch.float16,
np.float32 : torch.float32,
np.float64 : torch.float64,
np.complex64 : torch.complex64,
np.complex128 : torch.complex128
}


@misc.wrapper_flexible_args(['dim', 'axis'])
@misc.wrapper_flexible_args(['keepdim', 'keepdims'])
def nanvar(
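For context, here is a minimal usage sketch of the new Dataset_TensorDict_concatenated class. It is illustrative only: the tensor names, shapes, DataLoader settings, and the bnpm.torch_helpers import path are assumptions, not part of the commit.

import torch
import tensordict
from bnpm.torch_helpers import Dataset_TensorDict_concatenated

## Two tensors that share every dimension except the first.
td = tensordict.TensorDict(
    {
        'array_1': torch.randn(3, 4, 5),
        'array_2': torch.randn(7, 4, 5),
    },
    batch_size=[],
)

dataset = Dataset_TensorDict_concatenated(td)
print(len(dataset))      ## 10 (= 3 + 7)
print(dataset[0].shape)  ## torch.Size([4, 5]), sliced from 'array_1'
print(dataset[5].shape)  ## torch.Size([4, 5]), sliced from 'array_2'

## The dataset works with a standard DataLoader.
loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)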

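The numpy_to_torch_dtype_dict added by the same commit maps NumPy scalar types to torch dtypes. A small sketch of one way it might be used, again assuming the bnpm.torch_helpers import path; the to_torch_dtype helper is hypothetical and not from the commit.

import numpy as np
import torch
from bnpm.torch_helpers import numpy_to_torch_dtype_dict

## Hypothetical helper: look up the torch dtype matching a NumPy array's dtype.
def to_torch_dtype(arr: np.ndarray) -> torch.dtype:
    return numpy_to_torch_dtype_dict[arr.dtype.type]

x = np.zeros((2, 3), dtype=np.float32)
x_torch = torch.as_tensor(x, dtype=to_torch_dtype(x))  ## dtype is torch.float32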