Skip to content

Commit

Permalink
Merge pull request #506 from franzpoeschel/fix-hardcoded-iteration-fi…
Browse files Browse the repository at this point in the history
…lename

Remove hardcoded iteration number from data shuffler
  • Loading branch information
RandomDefaultUser authored May 30, 2024
2 parents 82c62bf + 845c021 commit 9f538c7
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
3 changes: 2 additions & 1 deletion mala/common/physical_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,8 @@ def write_to_openpmd_file(
import openpmd_api as io

if isinstance(path, str):
file_name = os.path.basename(path)
directory, file_name = os.path.split(path)
path = os.path.join(directory, file_name.replace("*", "%T"))
file_ending = file_name.split(".")[-1]
if file_name == file_ending:
path += ".h5"
Expand Down
14 changes: 9 additions & 5 deletions mala/datahandling/data_shuffler.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ def __shuffle_numpy(
)

# Do the actual shuffling.
target_name_openpmd = os.path.join(target_save_path,
save_name.replace("*", "%T"))
descriptor_name_openpmd = os.path.join(descriptor_save_path,
save_name.replace("*", "%T"))
for i in range(0, number_of_new_snapshots):
new_descriptors = np.zeros(
(int(np.prod(shuffle_dimensions)), self.input_dimension),
Expand Down Expand Up @@ -209,7 +213,7 @@ def __shuffle_numpy(
shuffle_dimensions
)
self.descriptor_calculator.write_to_openpmd_file(
descriptor_name + ".in." + file_ending,
descriptor_name_openpmd + ".in." + file_ending,
new_descriptors,
additional_attributes={
"global_shuffling_seed": self.parameters.shuffling_seed,
Expand All @@ -219,7 +223,7 @@ def __shuffle_numpy(
internal_iteration_number=i,
)
self.target_calculator.write_to_openpmd_file(
target_name + ".out." + file_ending,
target_name_openpmd + ".out." + file_ending,
array=new_targets,
additional_attributes={
"global_shuffling_seed": self.parameters.shuffling_seed,
Expand Down Expand Up @@ -359,12 +363,12 @@ def from_chunk_i(i, n, dset, slice_dimension=0):
import json

# Do the actual shuffling.
name_prefix = os.path.join(
dot.save_path, save_name.replace("*", "%T")
)
for i in range(my_items_start, my_items_end):
# We check above that in the non-numpy case, OpenPMD will work.
dot.calculator.grid_dimensions = list(shuffle_dimensions)
name_prefix = os.path.join(
dot.save_path, save_name.replace("*", str(i))
)
# do NOT open with MPI
shuffled_snapshot_series = io.Series(
name_prefix + dot.name_infix + file_ending,
Expand Down

0 comments on commit 9f538c7

Please sign in to comment.