diff --git a/logix/logging/log_loader.py b/logix/logging/log_loader.py index c166159..4b03c02 100644 --- a/logix/logging/log_loader.py +++ b/logix/logging/log_loader.py @@ -1,12 +1,8 @@ from collections import OrderedDict -from functools import reduce - -import numpy as np -import torch from torch.utils.data import Dataset from logix.logging.log_loader_utils import ( - get_entry_metadata, + unflatten_tensor, get_flatten_item, get_mmap_data, get_mmap_metadata, @@ -49,29 +45,25 @@ def __getitem__(self, index): nested_dict = {} mmap = self.memmaps[chunk_idx] offset = entry["offset"] + flat_tensor = get_flatten_item( + mmap, offset, entry["block_size"], entry["dtype"] + ) if self.flatten: - return data_id, get_flatten_item( - mmap, offset, entry["block_size"], entry["dtype"] - ) - dtype = entry["dtype"] + return data_id, flat_tensor + start = 0 for i in range(len(entry["path"])): path = entry["path"][i] shape = tuple(entry["shape"][i]) - tensor = torch.from_numpy( - np.ndarray(shape, dtype, buffer=mmap, offset=offset, order="C") - ).clone() - + tensor, start = unflatten_tensor(flat_tensor, shape, start) current_level = nested_dict for key in path[:-1]: if key not in current_level: current_level[key] = {} current_level = current_level[key] current_level[path[-1]] = tensor - offset += reduce(lambda x, y: x * y, shape) * np.dtype(dtype).itemsize - assert ( - offset == entry["offset"] + entry["block_size"] * np.dtype(dtype).itemsize - ), f"the block_size does not match the shape for data_id: {entry['data_id']}" + entry["block_size"] == start + ), f"block_size does not match with the shape for data_id: {entry['data_id']}" return data_id, nested_dict def __len__(self): diff --git a/logix/logging/log_loader_utils.py b/logix/logging/log_loader_utils.py index 8991ca7..1682074 100644 --- a/logix/logging/log_loader_utils.py +++ b/logix/logging/log_loader_utils.py @@ -114,6 +114,13 @@ def get_flatten_item(mmap, offset, block_size, dtype="float32"): return 
def unflatten_tensor(flat_tensor, shape, start):
    """Carve one tensor of ``shape`` out of ``flat_tensor`` starting at ``start``.

    Args:
        flat_tensor: Tensor that is sliced along its first dimension.
            Presumably the 1-D buffer produced by ``get_flatten_item`` --
            TODO confirm against the caller.
        shape: Iterable of ints giving the target shape. An empty shape
            denotes a scalar and consumes exactly one element.
        start: Index in ``flat_tensor`` of the first element to consume.

    Returns:
        A ``(tensor, end)`` tuple. ``tensor`` is a view of
        ``flat_tensor[start:end]`` reshaped to ``shape``; ``end`` is the index
        one past the last consumed element, i.e. the ``start`` to pass for the
        next tensor packed in the same buffer.
    """
    shape = tuple(shape)

    # Multiply the dims starting from 1 so that an empty shape (scalar) gives
    # one element instead of failing the way an initial-value-less reduce does.
    num_elements = 1
    for dim in shape:
        num_elements *= dim

    end = start + num_elements
    # Pass the shape as a single tuple: ``view(*())`` would call ``view()``
    # with no arguments and fail for scalar entries.
    unflattened_tensor = flat_tensor[start:end].view(shape)
    return unflattened_tensor, end