Commit

Merge pull request #11 from sangkeun00/no-for-loop
Eliminates the per-entry for loop used to look up `data_id`.
eatpk authored Nov 19, 2023
2 parents a2a3100 + 7464229 commit beaf0d3
Showing 2 changed files with 46 additions and 39 deletions.
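The core of the change is in `DefaultLogDataset`: instead of mapping each `data_id` to a bare chunk index and rescanning every metadata entry on each `__getitem__` call, the dataset now groups entries per `data_id` once at load time. A minimal sketch of the new indexing follows; the sample `metadata` entries are made up for illustration (fields trimmed for brevity):

    from collections import OrderedDict

    # Hypothetical metadata entries, shaped like the ones
    # MemoryMapHandler.write emits.
    metadata = [
        {"data_id": "a", "path": ["grad"], "offset": 0},
        {"data_id": "a", "path": ["act"], "offset": 64},
        {"data_id": "b", "path": ["grad"], "offset": 128},
    ]
    chunk_index = 0

    # New scheme: data_id -> (chunk_index, [entries]), built once.
    data_id_to_chunk = OrderedDict()
    for entry in metadata:
        data_id = entry["data_id"]
        if data_id in data_id_to_chunk:
            data_id_to_chunk[data_id][1].append(entry)
            continue
        data_id_to_chunk[data_id] = (chunk_index, [entry])

    # Lookup is now a single dict access rather than a scan over
    # every entry in the chunk's metadata.
    chunk_idx, entries = data_id_to_chunk["a"]
    assert chunk_idx == 0 and len(entries) == 2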
analog/storage/default.py (65 changes: 33 additions & 32 deletions)
@@ -113,7 +113,9 @@ def _flush_serialized(self) -> str:
         if len(self.buffer) == 0:
             return self.log_dir
         buffer_list = [(k, v) for k, v in self.buffer.items()]
-        self.mmap_handler.write(buffer_list, self.file_prefix + f"{self.push_count}.mmap")
+        self.mmap_handler.write(
+            buffer_list, self.file_prefix + f"{self.push_count}.mmap"
+        )
 
         self.push_count += 1
         del buffer_list
@@ -201,18 +203,17 @@ def build_log_dataloader(self, batch_size=16, num_workers=0):
 
 class DefaultLogDataset(Dataset):
     def __init__(self, mmap_handler):
-        self.schemas = []
         self.memmaps = []
         self.data_id_to_chunk = OrderedDict()
         self.mmap_handler = mmap_handler
 
         # Find all chunk indices
         self.chunk_indices = self._find_chunk_indices(self.mmap_handler.get_path())
 
-        # Add schemas and mmap files for all indices
+        # Add metadata and mmap files for all indices
         for chunk_index in self.chunk_indices:
             mmap_filename = f"log_chunk_{chunk_index}.mmap"
-            self._add_schema_and_mmap(mmap_filename, chunk_index)
+            self._add_metadata_and_mmap(mmap_filename, chunk_index)
 
     def _find_chunk_indices(self, directory):
         chunk_indices = []
@@ -224,45 +225,45 @@ def _find_chunk_indices(self, directory):
                 chunk_indices.append(int(chunk_index))
         return sorted(chunk_indices)
 
-    def _add_schema_and_mmap(
-        self, mmap_filename, chunk_index
-    ):
+    def _add_metadata_and_mmap(self, mmap_filename, chunk_index):
         # Load the memmap file
-        mmap, schema = self.mmap_handler.read(mmap_filename)
+        mmap, metadata = self.mmap_handler.read(mmap_filename)
         self.memmaps.append(mmap)
-        self.schemas.append(schema)
 
         # Update the mapping from data_id to chunk
-        for entry in schema:
-            self.data_id_to_chunk[entry["data_id"]] = chunk_index
+        for entry in metadata:
+            data_id = entry["data_id"]
+
+            if data_id in self.data_id_to_chunk:
+                # Append to the existing list for this data_id
+                self.data_id_to_chunk[data_id][1].append(entry)
+                continue
+            self.data_id_to_chunk[data_id] = (chunk_index, [entry])
 
     def __getitem__(self, index):
         data_id = list(self.data_id_to_chunk.keys())[index]
-        chunk_idx = self.data_id_to_chunk[data_id]
+        chunk_idx, entries = self.data_id_to_chunk[data_id]
 
         nested_dict = {}
 
         mmap = self.memmaps[chunk_idx]
-        schema = self.schemas[chunk_idx]
-        for entry in schema:
-            if entry["data_id"] == data_id:
-                # Read the data and put it into the nested dictionary
-                path = entry["path"]
-                offset = entry["offset"]
-                shape = tuple(entry["shape"])
-                dtype = np.dtype(entry["dtype"])
-
-                array = np.ndarray(shape, dtype, buffer=mmap, offset=offset, order="C")
-                tensor = torch.Tensor(array)
-
-                # Place the tensor in the correct location within the nested dictionary
-                current_level = nested_dict
-                for key in path[:-1]:
-                    if key not in current_level:
-                        current_level[key] = {}
-                    current_level = current_level[key]
-                current_level[path[-1]] = tensor
-
+        for entry in entries:
+            # Read the data and put it into the nested dictionary
+            path = entry["path"]
+            offset = entry["offset"]
+            shape = tuple(entry["shape"])
+            dtype = np.dtype(entry["dtype"])
+
+            array = np.ndarray(shape, dtype, buffer=mmap, offset=offset, order="C")
+            tensor = torch.Tensor(array)
+
+            # Place the tensor in the correct location within the nested dictionary
+            current_level = nested_dict
+            for key in path[:-1]:
+                if key not in current_level:
+                    current_level[key] = {}
+                current_level = current_level[key]
+            current_level[path[-1]] = tensor
         return data_id, nested_dict
 
     def __len__(self):
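For reference, the `__getitem__` body above rebuilds a nested dictionary from the flat `path` sequences stored in the metadata. A standalone sketch of that reconstruction step, with illustrative paths and plain floats standing in for the memmap-backed tensors:

    # Illustrative (path, value) pairs; in the dataset these values are
    # torch tensors read out of the memmap at each entry's offset.
    entries = [
        (["layer1", "grad"], 1.0),
        (["layer1", "act"], 2.0),
        (["layer2", "grad"], 3.0),
    ]

    nested_dict = {}
    for path, value in entries:
        # Walk/create intermediate dicts, then set the leaf.
        current_level = nested_dict
        for key in path[:-1]:
            if key not in current_level:
                current_level[key] = {}
            current_level = current_level[key]
        current_level[path[-1]] = value

    assert nested_dict == {
        "layer1": {"grad": 1.0, "act": 2.0},
        "layer2": {"grad": 3.0},
    }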
analog/storage/utils.py (20 changes: 13 additions & 7 deletions)
@@ -27,10 +27,10 @@ def extract_arrays(obj, base_path=()):
 
 
 class MemoryMapHandler:
-    def __init__(self, log_dir, mmap_dtype='uint8'):
+    def __init__(self, log_dir, mmap_dtype="uint8"):
         """
         Args:
-            save_path (str): The directory of the path to write and read the binaries and the schema.
+            save_path (str): The directory of the path to write and read the binaries and the metadata.
             mmap_dtype: The data type that will be used to save the binary into the memory map.
         """
         self.save_path = log_dir
@@ -47,16 +47,20 @@ def write(self, data_buffer, filename):
         mmap_filename = os.path.join(self.save_path, filename)
         metadata_filename = os.path.join(self.save_path, file_root + "_metadata.json")
 
-        total_size = sum(arr.nbytes for _, d in data_buffer for _, arr in extract_arrays(d))
-        mmap = np.memmap(mmap_filename, dtype=self.mmap_dtype, mode="w+", shape=(total_size,))
+        total_size = sum(
+            arr.nbytes for _, d in data_buffer for _, arr in extract_arrays(d)
+        )
+        mmap = np.memmap(
+            mmap_filename, dtype=self.mmap_dtype, mode="w+", shape=(total_size,)
+        )
 
         metadata = []
         offset = 0
 
         for data_id, nested_dict in data_buffer:
             for path, arr in extract_arrays(nested_dict):
                 bytes = arr.nbytes
-                mmap[offset: offset + bytes] = arr.ravel().view(self.mmap_dtype)
+                mmap[offset : offset + bytes] = arr.ravel().view(self.mmap_dtype)
                 metadata.append(
                     {
                         "data_id": data_id,
@@ -77,7 +81,7 @@ def write(self, data_buffer, filename):
 
     def read(self, filename):
         """
-        read reads the file by chunk index, it will return the data_buffer with schema.
+        read reads the file by chunk index, it will return the data_buffer with metadata.
         Arg:
             filename (str): filename for the path to mmap.
         Returns:
@@ -88,7 +92,9 @@ def read(self, filename):
         if file_ext == "":
             filename += ".mmap"
 
-        mmap = np.memmap(os.path.join(self.save_path, filename), dtype=self.mmap_dtype, mode="r")
+        mmap = np.memmap(
+            os.path.join(self.save_path, filename), dtype=self.mmap_dtype, mode="r"
+        )
         metadata = self.read_metafile(file_root + "_metadata.json")
         return mmap, metadata

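To make the `write`/`read` pair above concrete, here is a self-contained round trip through the same layout: arrays are raveled into one flat uint8 memmap, and a metadata entry records where each array lives. The filename and sample array are made up; this is a sketch of the technique under those assumptions, not the handler's exact API:

    import numpy as np

    # One sample array; MemoryMapHandler handles nested dicts of these.
    arr = np.arange(6, dtype=np.float32).reshape(2, 3)

    # Write: one flat byte buffer plus offset/shape/dtype bookkeeping.
    total_size = arr.nbytes
    mmap = np.memmap("example.mmap", dtype="uint8", mode="w+", shape=(total_size,))
    mmap[0 : arr.nbytes] = arr.ravel().view("uint8")
    mmap.flush()
    entry = {"offset": 0, "shape": arr.shape, "dtype": str(arr.dtype)}

    # Read: reinterpret the bytes at the recorded offset, zero-copy.
    read_mmap = np.memmap("example.mmap", dtype="uint8", mode="r")
    out = np.ndarray(
        tuple(entry["shape"]),
        np.dtype(entry["dtype"]),
        buffer=read_mmap,
        offset=entry["offset"],
        order="C",
    )
    assert np.array_equal(out, arr)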
