Commit 0e4ddfe

Checkpointing works now as well

RandomDefaultUser committed Jan 7, 2025
1 parent 88a5edb commit 0e4ddfe
Showing 2 changed files with 54 additions and 29 deletions.
mala/datahandling/data_handler.py (81 changes: 53 additions, 28 deletions)

@@ -7,7 +7,7 @@
 import torch
 from torch.utils.data import TensorDataset

-from mala.common.parallelizer import printout, barrier
+from mala.common.parallelizer import printout, barrier, get_rank
 from mala.common.parameters import Parameters, DEFAULT_NP_DATA_DTYPE
 from mala.datahandling.data_handler_base import DataHandlerBase
 from mala.datahandling.data_scaler import DataScaler
@@ -170,18 +170,20 @@ def clear_data(self):
         self.output_data_scaler.reset()
         super(DataHandler, self).clear_data()

-    def delete_temporary_data(self):
+    def _delete_temporary_data(self):
         """
         Delete temporary data files.

         These may have been created during a training or testing process
         when using atomic positions for on-the-fly calculation of descriptors
         rather than precomputed data files.
         """
-        for snapshot in self.parameters.snapshot_directories_list:
-            if snapshot.temporary_input_file is not None:
-                if os.path.isfile(snapshot.temporary_input_file):
-                    os.remove(snapshot.temporary_input_file)
+        if get_rank() == 0:
+            for snapshot in self.parameters.snapshot_directories_list:
+                if snapshot.temporary_input_file is not None:
+                    if os.path.isfile(snapshot.temporary_input_file):
+                        os.remove(snapshot.temporary_input_file)
+        barrier()

     # Preparing data
     ######################
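The added guard is the standard rank-0-plus-barrier idiom for cleanup on a shared filesystem in an MPI run: exactly one process deletes the files, and the barrier stops the remaining ranks from running ahead before the deletion is complete. A minimal standalone sketch of the same pattern, written against mpi4py directly rather than MALA's get_rank()/barrier() wrappers:

import os

from mpi4py import MPI

comm = MPI.COMM_WORLD


def delete_shared_files(paths):
    """Delete files on a shared filesystem exactly once across ranks."""
    if comm.Get_rank() == 0:
        # Only rank 0 touches the filesystem, so there are no duplicate
        # deletions and no races between ranks.
        for path in paths:
            if path is not None and os.path.isfile(path):
                os.remove(path)
    # Every rank blocks here until the deletion has actually happened.
    comm.Barrier()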
@@ -241,16 +243,8 @@ def prepare_data(self, reparametrize_scaler=True):
printout("Initializing the data scalers.", min_verbosity=1)
self.__parametrize_scalers()
printout("Data scalers initialized.", min_verbosity=0)
elif (
self.parameters.use_lazy_loading is False
and self.nr_training_data != 0
):
printout(
"Data scalers already initilized, loading data to RAM.",
min_verbosity=0,
)
self.__load_data("training", "inputs")
self.__load_data("training", "outputs")
elif self.nr_training_data != 0:
self.__parametrized_load_training_data()

# Build Datasets.
printout("Build datasets.", min_verbosity=1)
@@ -267,6 +261,11 @@
         # allows for parallel I/O.
         barrier()

+        # In the RAM case, there is no reason not to delete all temporary
+        # files now.
+        if self.parameters.use_lazy_loading is False:
+            self._delete_temporary_data()
+
     def prepare_for_testing(self):
         """
         Prepare DataHandler for usage within Tester class.
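The comment added above is the whole rationale: with lazy loading off, prepare_data() has already copied every snapshot into RAM, so the on-disk temporaries are dead weight; with lazy loading on, they are read batch-by-batch for the rest of training and must survive until train_network() finishes. A toy illustration of that distinction using plain NumPy (the file and data here are made up):

import os
import tempfile

import numpy as np

fd, path = tempfile.mkstemp(suffix=".npy")
os.close(fd)
np.save(path, np.arange(10))

data = np.load(path)  # eager load: a full copy now lives in RAM
os.remove(path)       # safe immediately; nothing reads the file again
print(data.sum())     # prints 45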
@@ -351,19 +350,24 @@ def get_snapshot_calculation_output(self, snapshot_number):
         ].calculation_output

     def calculate_temporary_inputs(self, snapshot: Snapshot):
-        snapshot.temporary_input_file = tempfile.NamedTemporaryFile(
-            delete=False,
-            prefix=snapshot.input_npy_file.split(".")[0],
-            suffix=".in.npy",
-            dir=snapshot.input_npy_directory,
-        ).name
-        tmp, grid = self.descriptor_calculator.calculate_from_json(
-            os.path.join(
-                snapshot.input_npy_directory,
-                snapshot.input_npy_file,
-            )
-        )
-        np.save(snapshot.temporary_input_file, tmp)
+        if snapshot.temporary_input_file is not None:
+            if not os.path.isfile(snapshot.temporary_input_file):
+                snapshot.temporary_input_file = None
+
+        if snapshot.temporary_input_file is None:
+            snapshot.temporary_input_file = tempfile.NamedTemporaryFile(
+                delete=False,
+                prefix=snapshot.input_npy_file.split(".")[0],
+                suffix=".in.npy",
+                dir=snapshot.input_npy_directory,
+            ).name
+            tmp, grid = self.descriptor_calculator.calculate_from_json(
+                os.path.join(
+                    snapshot.input_npy_directory,
+                    snapshot.input_npy_file,
+                )
+            )
+            np.save(snapshot.temporary_input_file, tmp)

     # Debugging
     ######################
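This reuse-or-recreate check is what makes checkpointing work: a restarted run can resume with snapshot objects whose temporary_input_file still points at a file the previous run already deleted, and the method now resets the stale path and regenerates the descriptors instead of failing. A standalone sketch of the pattern under assumed names (ensure_cached and build are illustrative, not MALA API):

import os
import tempfile

import numpy as np


def ensure_cached(path, build):
    """Return the path to a valid cache file, rebuilding it if needed."""
    # A recorded path whose file has vanished (e.g. deleted at the end
    # of a previous run before a checkpoint restart) counts as no cache.
    if path is not None and not os.path.isfile(path):
        path = None
    if path is None:
        fd, path = tempfile.mkstemp(suffix=".npy")
        os.close(fd)
        np.save(path, build())  # expensive recomputation happens here
    return path

Calling ensure_cached twice in a row recomputes nothing; deleting the file in between triggers exactly one rebuild.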
@@ -1014,6 +1018,27 @@ def __parametrize_scalers(self):

printout("Output scaler parametrized.", min_verbosity=1)

def __parametrized_load_training_data(self):
if self.parameters.use_lazy_loading:
printout(
"Data scalers already initilized, preparing input data.",
min_verbosity=0,
)
for snapshot in self.parameters.snapshot_directories_list:
# Data scaling is only performed on the training data sets.
if (
snapshot.snapshot_function == "tr"
and snapshot.snapshot_type == "json+numpy"
):
self.calculate_temporary_inputs(snapshot)
else:
printout(
"Data scalers already initilized, loading data to RAM.",
min_verbosity=0,
)
self.__load_data("training", "inputs")
self.__load_data("training", "outputs")

def __raw_numpy_to_converted_numpy(
self, numpy_array, data_type="in", units=None
):
mala/network/trainer.py (2 changes: 1 addition, 1 deletion)

@@ -598,7 +598,7 @@ def train_network(self):
         self.final_validation_loss = vloss

         # Cleaning up temporary data files.
-        self.data.delete_temporary_data()
+        self.data._delete_temporary_data()

         # Clean-up for pre-fetching lazy loading.
         if self.data.parameters.use_lazy_loading_prefetch:
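With cleanup now performed inside train_network() (and the method renamed with a leading underscore to mark it as internal), user scripts no longer call it themselves: a checkpointed run can simply resume, and any temporary descriptor files removed by the previous run's cleanup are regenerated on demand. A hedged resume sketch in the style of MALA's checkpointing examples (the run name is made up):

import mala

# Resume a previously checkpointed training run; the names follow
# MALA's example scripts, the checkpoint itself is hypothetical.
parameters, network, data_handler, trainer = mala.Trainer.load_run(
    "my_checkpointed_run"
)
trainer.train_network()  # temporary files are cleaned up internally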
