diff --git a/mala/common/physical_data.py b/mala/common/physical_data.py index 26bb12675..e756e96d1 100644 --- a/mala/common/physical_data.py +++ b/mala/common/physical_data.py @@ -418,7 +418,8 @@ def write_to_openpmd_file( import openpmd_api as io if isinstance(path, str): - file_name = os.path.basename(path) + directory, file_name = os.path.split(path) + path = os.path.join(directory, file_name.replace("*", "%T")) file_ending = file_name.split(".")[-1] if file_name == file_ending: path += ".h5" diff --git a/mala/datahandling/data_shuffler.py b/mala/datahandling/data_shuffler.py index 935847276..62d6e11a3 100644 --- a/mala/datahandling/data_shuffler.py +++ b/mala/datahandling/data_shuffler.py @@ -131,6 +131,10 @@ def __shuffle_numpy( ) # Do the actual shuffling. + target_name_openpmd = os.path.join(target_save_path, + save_name.replace("*", "%T")) + descriptor_name_openpmd = os.path.join(descriptor_save_path, + save_name.replace("*", "%T")) for i in range(0, number_of_new_snapshots): new_descriptors = np.zeros( (int(np.prod(shuffle_dimensions)), self.input_dimension), @@ -209,7 +213,7 @@ def __shuffle_numpy( shuffle_dimensions ) self.descriptor_calculator.write_to_openpmd_file( - descriptor_name + ".in." + file_ending, + descriptor_name_openpmd + ".in." + file_ending, new_descriptors, additional_attributes={ "global_shuffling_seed": self.parameters.shuffling_seed, @@ -219,7 +223,7 @@ def __shuffle_numpy( internal_iteration_number=i, ) self.target_calculator.write_to_openpmd_file( - target_name + ".out." + file_ending, + target_name_openpmd + ".out." + file_ending, array=new_targets, additional_attributes={ "global_shuffling_seed": self.parameters.shuffling_seed, @@ -359,12 +363,12 @@ def from_chunk_i(i, n, dset, slice_dimension=0): import json # Do the actual shuffling. + name_prefix = os.path.join( + dot.save_path, save_name.replace("*", "%T") + ) for i in range(my_items_start, my_items_end): # We check above that in the non-numpy case, OpenPMD will work. dot.calculator.grid_dimensions = list(shuffle_dimensions) - name_prefix = os.path.join( - dot.save_path, save_name.replace("*", str(i)) - ) # do NOT open with MPI shuffled_snapshot_series = io.Series( name_prefix + dot.name_infix + file_ending,