Skip to content

Commit

Permalink
Lazy Loading training works now
Browse files Browse the repository at this point in the history
  • Loading branch information
RandomDefaultUser committed Jan 6, 2025
1 parent e154d2a commit 88a5edb
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 16 deletions.
51 changes: 35 additions & 16 deletions mala/datahandling/data_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,21 @@ def get_snapshot_calculation_output(self, snapshot_number):
snapshot_number
].calculation_output

def calculate_temporary_inputs(self, snapshot: Snapshot):
    """
    Compute descriptors for a JSON snapshot and cache them as a temp file.

    The descriptors are calculated from the snapshot's JSON input file and
    written to a uniquely named ``.in.npy`` file next to it. The path of
    that file is stored in ``snapshot.temporary_input_file`` so later
    loading stages can read the precomputed data instead of recalculating.

    Parameters
    ----------
    snapshot : Snapshot
        Snapshot whose input is a JSON file; its ``temporary_input_file``
        attribute is set as a side effect.
    """
    # NamedTemporaryFile is used only to reserve a unique file name;
    # close the handle immediately so np.save can reopen the path cleanly
    # (leaving it open leaks a descriptor and breaks on Windows).
    temp_file = tempfile.NamedTemporaryFile(
        delete=False,
        prefix=snapshot.input_npy_file.split(".")[0],
        suffix=".in.npy",
        dir=snapshot.input_npy_directory,
    )
    temp_file.close()
    snapshot.temporary_input_file = temp_file.name
    # Grid dimensions returned alongside the descriptors are not needed
    # here; only the descriptor array is cached.
    tmp, _ = self.descriptor_calculator.calculate_from_json(
        os.path.join(
            snapshot.input_npy_directory,
            snapshot.input_npy_file,
        )
    )
    np.save(snapshot.temporary_input_file, tmp)

# Debugging
######################

Expand Down Expand Up @@ -613,22 +628,8 @@ def __load_data(self, function, data_type):
# If the input for the descriptors is actually a JSON
# file then we need to calculate the descriptors.
if snapshot.snapshot_type == "json+numpy":
snapshot.temporary_input_file = (
tempfile.NamedTemporaryFile(
delete=False,
prefix=snapshot.input_npy_file.split(".")[0],
suffix=".in.npy",
dir=snapshot.input_npy_directory,
).name
)
descriptors, grid = (
self.descriptor_calculator.calculate_from_json(
file
)
)
np.save(snapshot.temporary_input_file, descriptors)
self.calculate_temporary_inputs(snapshot)
file = snapshot.temporary_input_file

else:
file = os.path.join(
snapshot.output_npy_directory,
Expand Down Expand Up @@ -753,11 +754,20 @@ def __build_datasets(self):
self.training_data_sets[0].add_snapshot_to_dataset(
snapshot
)
# For training snapshots, temporary files (if needed) have
# already been built during parametrization, for all other
# snapshot types, this has to be done here.
if snapshot.snapshot_function == "va":
if snapshot.snapshot_type == "json+numpy":
self.calculate_temporary_inputs(snapshot)

self.validation_data_sets[0].add_snapshot_to_dataset(
snapshot
)
if snapshot.snapshot_function == "te":
if snapshot.snapshot_type == "json+numpy":
self.calculate_temporary_inputs(snapshot)

self.test_data_sets[0].add_snapshot_to_dataset(snapshot)

# I don't think we need to mix them here. We can use the standard
Expand Down Expand Up @@ -915,6 +925,12 @@ def __parametrize_scalers(self):
)
)
)
elif snapshot.snapshot_type == "json+numpy":
self.calculate_temporary_inputs(snapshot)
tmp = self.descriptor_calculator.read_from_numpy_file(
snapshot.temporary_input_file,
units=snapshot.input_units,
)
else:
raise Exception("Unknown snapshot file type.")

Expand Down Expand Up @@ -956,7 +972,10 @@ def __parametrize_scalers(self):
for snapshot in self.parameters.snapshot_directories_list:
# Data scaling is only performed on the training data sets.
if snapshot.snapshot_function == "tr":
if snapshot.snapshot_type == "numpy":
if (
snapshot.snapshot_type == "numpy"
or snapshot.snapshot_type == "json+numpy"
):
tmp = self.target_calculator.read_from_numpy_file(
os.path.join(
snapshot.output_npy_directory,
Expand Down
12 changes: 12 additions & 0 deletions mala/datahandling/lazy_load_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,18 @@ def get_new_data(self, file_index):
),
units=self._snapshot_list[file_index].output_units,
)
elif self._snapshot_list[file_index].snapshot_type == "json+numpy":
self.input_data = self._descriptor_calculator.read_from_numpy_file(
self._snapshot_list[file_index].temporary_input_file,
units=self._snapshot_list[file_index].input_units,
)
self.output_data = self._target_calculator.read_from_numpy_file(
os.path.join(
self._snapshot_list[file_index].output_npy_directory,
self._snapshot_list[file_index].output_npy_file,
),
units=self._snapshot_list[file_index].output_units,
)

elif self._snapshot_list[file_index].snapshot_type == "openpmd":
self.input_data = (
Expand Down

0 comments on commit 88a5edb

Please sign in to comment.