From cc794f66233e3439da176d8656f8e3696d9a9811 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20P=C3=B6schel?= Date: Wed, 21 Feb 2024 18:59:42 +0100 Subject: [PATCH] Remove hardcoded iteration number from data shuffler --- mala/datahandling/data_shuffler.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/mala/datahandling/data_shuffler.py b/mala/datahandling/data_shuffler.py index 0a655c00f..1a72e8549 100644 --- a/mala/datahandling/data_shuffler.py +++ b/mala/datahandling/data_shuffler.py @@ -94,6 +94,10 @@ def __shuffle_numpy(self, number_of_new_snapshots, shuffle_dimensions, mmap_mode="r")) # Do the actual shuffling. + target_name_openpmd = os.path.join(target_save_path, + save_name.replace("*", "%T")) + descriptor_name_openpmd = os.path.join(descriptor_save_path, + save_name.replace("*", "%T")) for i in range(0, number_of_new_snapshots): new_descriptors = np.zeros((int(np.prod(shuffle_dimensions)), self.input_dimension), @@ -150,13 +154,13 @@ def __shuffle_numpy(self, number_of_new_snapshots, shuffle_dimensions, self.target_calculator.grid_dimensions = \ list(shuffle_dimensions) self.descriptor_calculator.\ - write_to_openpmd_file(descriptor_name+".in."+file_ending, + write_to_openpmd_file(descriptor_name_openpmd+".in."+file_ending, new_descriptors, additional_attributes={"global_shuffling_seed": self.parameters.shuffling_seed, "local_shuffling_seed": i*self.parameters.shuffling_seed}, internal_iteration_number=i) self.target_calculator.\ - write_to_openpmd_file(target_name+".out."+file_ending, + write_to_openpmd_file(target_name_openpmd+".out."+file_ending, array=new_targets, additional_attributes={"global_shuffling_seed": self.parameters.shuffling_seed, "local_shuffling_seed": i*self.parameters.shuffling_seed}, @@ -265,11 +269,11 @@ def from_chunk_i(i, n, dset, slice_dimension=0): import json # Do the actual shuffling. + name_prefix = os.path.join(dot.save_path, + save_name.replace("*", "%T")) for i in range(my_items_start, my_items_end): # We check above that in the non-numpy case, OpenPMD will work. dot.calculator.grid_dimensions = list(shuffle_dimensions) - name_prefix = os.path.join(dot.save_path, - save_name.replace("*", str(i))) # do NOT open with MPI shuffled_snapshot_series = io.Series( name_prefix + dot.name_infix + file_ending,