diff --git a/mala/datahandling/data_shuffler.py b/mala/datahandling/data_shuffler.py index 6569198d..7de400d4 100644 --- a/mala/datahandling/data_shuffler.py +++ b/mala/datahandling/data_shuffler.py @@ -310,7 +310,6 @@ def get_extent(from_, to_): f"__shuffle_openpmd: Internal indexing error extent={extent} and y={y} must have the same length. This is a bug." ) - print(f"CALLED: {x} - {y}") # Recursive bottom cases # 1. No items left if y < x: @@ -351,8 +350,6 @@ def __resolve_flattened_index_into_ndim(idx: int, ndim_extent: list[int]): strides = [current_stride] + strides current_stride *= ext # sic!, the last one gets ignored - print("Strides", strides) - def worker(inner_idx, inner_strides): if not inner_strides: if inner_idx != 0: @@ -377,11 +374,15 @@ def __load_chunk_1D(mesh, arr, offset, extent): mesh.shape, start_idx, end_idx ) ) + # print(f"\n\nLOADING {offset}\t+{extent}\tFROM {mesh.shape}") current_offset = 0 # offset within arr for nd_offset, nd_extent in blocks_to_load: - flat_extent = prod(nd_extent) + flat_extent = np.prod(nd_extent) + # print( + # f"\t{nd_offset}\t-{nd_extent}\t->[{current_offset}:{current_offset + flat_extent}]" + # ) mesh.load_chunk( - arr[current_offset : current_offset + nd_extent], + arr[current_offset : current_offset + flat_extent], nd_offset, nd_extent, ) @@ -468,23 +469,12 @@ def __shuffle_openpmd( # This gets the offset and extent of the i'th such slice. # The extent is given as in openPMD, i.e. the size of the block # (not its upper coordinate). - def from_chunk_i(i, n, dset, slice_dimension=0): + def from_chunk_i(i, n, dset): if isinstance(dset, io.Dataset): dset = dset.extent - dset = list(dset) - offset = [0 for _ in dset] - extent = dset - extent_dim_0 = dset[slice_dimension] - if extent_dim_0 % n != 0: - raise Exception( - "Dataset {} cannot be split into {} chunks on dimension {}.".format( - dset, n, slice_dimension - ) - ) - single_chunk_len = extent_dim_0 // n - offset[slice_dimension] = i * single_chunk_len - extent[slice_dimension] = single_chunk_len - return offset, extent + flat_extent = np.prod(dset) + one_slice_extent = flat_extent // n + return i * one_slice_extent, one_slice_extent import json @@ -545,9 +535,10 @@ def from_chunk_i(i, n, dset, slice_dimension=0): i, number_of_new_snapshots, extent_in ) to_chunk_offset = to_chunk_extent - to_chunk_extent = to_chunk_offset + np.prod(from_chunk_extent) + to_chunk_extent = to_chunk_offset + from_chunk_extent for dimension in range(len(mesh_in)): - mesh_in[str(dimension)].load_chunk( + DataShuffler.__load_chunk_1D( + mesh_in[str(dimension)], new_array[dimension, to_chunk_offset:to_chunk_extent], from_chunk_offset, from_chunk_extent, @@ -651,24 +642,6 @@ def shuffle_snapshots( if number_of_shuffled_snapshots is None: number_of_shuffled_snapshots = self.nr_snapshots - # Currently, the openPMD interface is not feature-complete. - if snapshot_type == "openpmd" and np.any( - np.array( - [ - snapshot.grid_dimension[0] % number_of_shuffled_snapshots - for snapshot in self.parameters.snapshot_directories_list - ] - ) - != 0 - ): - raise ValueError( - "Shuffling from OpenPMD files currently only " - "supported if first dimension of all snapshots " - "can evenly be divided by number of snapshots. " - "Please select a different number of shuffled " - "snapshots or use the numpy interface. " - ) - shuffled_gridsizes = snapshot_size_list // number_of_shuffled_snapshots if np.any( diff --git a/test/shuffling_test.py b/test/shuffling_test.py index 2ac09801..40e34544 100644 --- a/test/shuffling_test.py +++ b/test/shuffling_test.py @@ -327,7 +327,9 @@ def test_training_openpmd(self): new_loss = test_trainer.final_validation_loss assert old_loss > new_loss - def test_arbitrary_number_snapshots(self): + def worker_arbitrary_number_snapshots( + self, ext_in, ext_out, snapshot_type + ): parameters = mala.Parameters() # This ensures reproducibility of the created data sets. @@ -337,20 +339,46 @@ def test_arbitrary_number_snapshots(self): for i in range(5): data_shuffler.add_snapshot( - "Be_snapshot0.in.npy", + f"Be_snapshot0.in.{ext_in}", data_path, - "Be_snapshot0.out.npy", + f"Be_snapshot0.out.{ext_in}", data_path, + snapshot_type=snapshot_type, ) data_shuffler.shuffle_snapshots( complete_save_path=".", - save_name="Be_shuffled*", + save_name=f"Be_shuffled*{ext_out}", number_of_shuffled_snapshots=5, ) - for i in range(4): + + def test_arbitrary_number_snapshots(self): + self.worker_arbitrary_number_snapshots("npy", "", "numpy") + for i in range(5): bispectrum = np.load("Be_shuffled" + str(i) + ".in.npy") ldos = np.load("Be_shuffled" + str(i) + ".out.npy") assert not np.any(np.where(np.all(ldos == 0, axis=-1).squeeze())) assert not np.any( np.where(np.all(bispectrum == 0, axis=-1).squeeze()) ) + self.worker_arbitrary_number_snapshots("h5", ".h5", "openpmd") + import openpmd_api as opmd + + bispectrum_series = opmd.Series( + "Be_shuffled%T.in.h5", opmd.Access.read_only + ) + ldos_series = opmd.Series( + "Be_shuffled%T.out.h5", opmd.Access.read_only + ) + for i in range(5): + for name, series in [("Bispectrum", bispectrum_series), ("LDOS", ldos_series)]: + loaded_array = [ + component.load_chunk().squeeze() + for _, component in series.iterations[i] + .meshes[name] + .items() + ] + series.flush() + loaded_array = np.array(loaded_array) + assert not np.any( + np.where(np.all(loaded_array == 0, axis=0).squeeze()) + )