Skip to content

Commit

Permalink
Remove hardcoded iteration number from data shuffler
Browse files Browse the repository at this point in the history
  • Loading branch information
franzpoeschel committed Feb 21, 2024
1 parent 59f67ff commit cc794f6
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions mala/datahandling/data_shuffler.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ def __shuffle_numpy(self, number_of_new_snapshots, shuffle_dimensions,
mmap_mode="r"))

# Do the actual shuffling.
target_name_openpmd = os.path.join(target_save_path,
save_name.replace("*", "%T"))
descriptor_name_openpmd = os.path.join(descriptor_save_path,
save_name.replace("*", "%T"))
for i in range(0, number_of_new_snapshots):
new_descriptors = np.zeros((int(np.prod(shuffle_dimensions)),
self.input_dimension),
Expand Down Expand Up @@ -150,13 +154,13 @@ def __shuffle_numpy(self, number_of_new_snapshots, shuffle_dimensions,
self.target_calculator.grid_dimensions = \
list(shuffle_dimensions)
self.descriptor_calculator.\
write_to_openpmd_file(descriptor_name+".in."+file_ending,
write_to_openpmd_file(descriptor_name_openpmd+".in."+file_ending,
new_descriptors,
additional_attributes={"global_shuffling_seed": self.parameters.shuffling_seed,
"local_shuffling_seed": i*self.parameters.shuffling_seed},
internal_iteration_number=i)
self.target_calculator.\
write_to_openpmd_file(target_name+".out."+file_ending,
write_to_openpmd_file(target_name_openpmd+".out."+file_ending,
array=new_targets,
additional_attributes={"global_shuffling_seed": self.parameters.shuffling_seed,
"local_shuffling_seed": i*self.parameters.shuffling_seed},
Expand Down Expand Up @@ -265,11 +269,11 @@ def from_chunk_i(i, n, dset, slice_dimension=0):
import json

# Do the actual shuffling.
name_prefix = os.path.join(dot.save_path,
save_name.replace("*", "%T"))
for i in range(my_items_start, my_items_end):
# We check above that in the non-numpy case, OpenPMD will work.
dot.calculator.grid_dimensions = list(shuffle_dimensions)
name_prefix = os.path.join(dot.save_path,
save_name.replace("*", str(i)))
# do NOT open with MPI
shuffled_snapshot_series = io.Series(
name_prefix + dot.name_infix + file_ending,
Expand Down

0 comments on commit cc794f6

Please sign in to comment.