Skip to content

Commit

Permalink
Seems to work?
Browse files Browse the repository at this point in the history
  • Loading branch information
franzpoeschel committed Dec 5, 2024
1 parent 3612774 commit 3a48b47
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 45 deletions.
53 changes: 13 additions & 40 deletions mala/datahandling/data_shuffler.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,6 @@ def get_extent(from_, to_):
f"__shuffle_openpmd: Internal indexing error extent={extent} and y={y} must have the same length. This is a bug."
)

print(f"CALLED: {x} - {y}")
# Recursive bottom cases
# 1. No items left
if y < x:
Expand Down Expand Up @@ -351,8 +350,6 @@ def __resolve_flattened_index_into_ndim(idx: int, ndim_extent: list[int]):
strides = [current_stride] + strides
current_stride *= ext # sic!, the last one gets ignored

print("Strides", strides)

def worker(inner_idx, inner_strides):
if not inner_strides:
if inner_idx != 0:
Expand All @@ -377,11 +374,15 @@ def __load_chunk_1D(mesh, arr, offset, extent):
mesh.shape, start_idx, end_idx
)
)
# print(f"\n\nLOADING {offset}\t+{extent}\tFROM {mesh.shape}")
current_offset = 0 # offset within arr
for nd_offset, nd_extent in blocks_to_load:
flat_extent = prod(nd_extent)
flat_extent = np.prod(nd_extent)
# print(
# f"\t{nd_offset}\t-{nd_extent}\t->[{current_offset}:{current_offset + flat_extent}]"
# )
mesh.load_chunk(
arr[current_offset : current_offset + nd_extent],
arr[current_offset : current_offset + flat_extent],
nd_offset,
nd_extent,
)
Expand Down Expand Up @@ -468,23 +469,12 @@ def __shuffle_openpmd(
# This gets the offset and extent of the i'th such slice.
# The extent is given as in openPMD, i.e. the size of the block
# (not its upper coordinate).
def from_chunk_i(i, n, dset, slice_dimension=0):
def from_chunk_i(i, n, dset):
if isinstance(dset, io.Dataset):
dset = dset.extent
dset = list(dset)
offset = [0 for _ in dset]
extent = dset
extent_dim_0 = dset[slice_dimension]
if extent_dim_0 % n != 0:
raise Exception(
"Dataset {} cannot be split into {} chunks on dimension {}.".format(
dset, n, slice_dimension
)
)
single_chunk_len = extent_dim_0 // n
offset[slice_dimension] = i * single_chunk_len
extent[slice_dimension] = single_chunk_len
return offset, extent
flat_extent = np.prod(dset)
one_slice_extent = flat_extent // n
return i * one_slice_extent, one_slice_extent

import json

Expand Down Expand Up @@ -545,9 +535,10 @@ def from_chunk_i(i, n, dset, slice_dimension=0):
i, number_of_new_snapshots, extent_in
)
to_chunk_offset = to_chunk_extent
to_chunk_extent = to_chunk_offset + np.prod(from_chunk_extent)
to_chunk_extent = to_chunk_offset + from_chunk_extent
for dimension in range(len(mesh_in)):
mesh_in[str(dimension)].load_chunk(
DataShuffler.__load_chunk_1D(
mesh_in[str(dimension)],
new_array[dimension, to_chunk_offset:to_chunk_extent],
from_chunk_offset,
from_chunk_extent,
Expand Down Expand Up @@ -651,24 +642,6 @@ def shuffle_snapshots(
if number_of_shuffled_snapshots is None:
number_of_shuffled_snapshots = self.nr_snapshots

# Currently, the openPMD interface is not feature-complete.
if snapshot_type == "openpmd" and np.any(
np.array(
[
snapshot.grid_dimension[0] % number_of_shuffled_snapshots
for snapshot in self.parameters.snapshot_directories_list
]
)
!= 0
):
raise ValueError(
"Shuffling from OpenPMD files currently only "
"supported if first dimension of all snapshots "
"can evenly be divided by number of snapshots. "
"Please select a different number of shuffled "
"snapshots or use the numpy interface. "
)

shuffled_gridsizes = snapshot_size_list // number_of_shuffled_snapshots

if np.any(
Expand Down
38 changes: 33 additions & 5 deletions test/shuffling_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,9 @@ def test_training_openpmd(self):
new_loss = test_trainer.final_validation_loss
assert old_loss > new_loss

def test_arbitrary_number_snapshots(self):
def worker_arbitrary_number_snapshots(
self, ext_in, ext_out, snapshot_type
):
parameters = mala.Parameters()

# This ensures reproducibility of the created data sets.
Expand All @@ -337,20 +339,46 @@ def test_arbitrary_number_snapshots(self):

for i in range(5):
data_shuffler.add_snapshot(
"Be_snapshot0.in.npy",
f"Be_snapshot0.in.{ext_in}",
data_path,
"Be_snapshot0.out.npy",
f"Be_snapshot0.out.{ext_in}",
data_path,
snapshot_type=snapshot_type,
)
data_shuffler.shuffle_snapshots(
complete_save_path=".",
save_name="Be_shuffled*",
save_name=f"Be_shuffled*{ext_out}",
number_of_shuffled_snapshots=5,
)
for i in range(4):

def test_arbitrary_number_snapshots(self):
self.worker_arbitrary_number_snapshots("npy", "", "numpy")
for i in range(5):
bispectrum = np.load("Be_shuffled" + str(i) + ".in.npy")
ldos = np.load("Be_shuffled" + str(i) + ".out.npy")
assert not np.any(np.where(np.all(ldos == 0, axis=-1).squeeze()))
assert not np.any(
np.where(np.all(bispectrum == 0, axis=-1).squeeze())
)
self.worker_arbitrary_number_snapshots("h5", ".h5", "openpmd")
import openpmd_api as opmd

bispectrum_series = opmd.Series(
"Be_shuffled%T.in.h5", opmd.Access.read_only
)
ldos_series = opmd.Series(
"Be_shuffled%T.out.h5", opmd.Access.read_only
)
for i in range(5):
for name, series in [("Bispectrum", bispectrum_series), ("LDOS", ldos_series)]:
loaded_array = [
component.load_chunk().squeeze()
for _, component in series.iterations[i]
.meshes[name]
.items()
]
series.flush()
loaded_array = np.array(loaded_array)
assert not np.any(
np.where(np.all(loaded_array == 0, axis=0).squeeze())
)

0 comments on commit 3a48b47

Please sign in to comment.