Skip to content

Commit

Permalink
Merge pull request #603 from RandomDefaultUser/fix_train_graphs
Browse files Browse the repository at this point in the history
Adapt batch size in case of GPU graphs
  • Loading branch information
RandomDefaultUser authored Nov 14, 2024
2 parents b3d117e + 7263abb commit db6ecbd
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 14 deletions.
2 changes: 1 addition & 1 deletion mala/network/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def _forward_snap_descriptors(
# Only predict if there is something to predict.
# Elsewise, we just wait at the barrier down below.
if local_data_size > 0:
optimal_batch_size = self._correct_batch_size_for_testing(
optimal_batch_size = self._correct_batch_size(
local_data_size, self.parameters.mini_batch_size
)
if optimal_batch_size != self.parameters.mini_batch_size:
Expand Down
2 changes: 1 addition & 1 deletion mala/network/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,7 +932,7 @@ def _forward_entire_snapshot(
return actual_outputs, predicted_outputs

@staticmethod
def _correct_batch_size_for_testing(datasize, batchsize):
def _correct_batch_size(datasize, batchsize):
"""
Get the correct batch size for testing.
Expand Down
14 changes: 9 additions & 5 deletions mala/network/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,10 @@ def test_snapshot(self, snapshot_number, data_type="te"):
snapshot_number,
)
return results

def get_energy_targets_and_predictions(self, snapshot_number, data_type="te"):

def get_energy_targets_and_predictions(
self, snapshot_number, data_type="te"
):
"""
Get the energy targets and predictions for a single snapshot.
Expand All @@ -145,8 +147,10 @@ def get_energy_targets_and_predictions(self, snapshot_number, data_type="te"):
actual_outputs, predicted_outputs = self.predict_targets(
snapshot_number, data_type=data_type
)

energy_metrics = [metric for metric in self.observables_to_test if "energy" in metric]

energy_metrics = [
metric for metric in self.observables_to_test if "energy" in metric
]
targets, predictions = self._calculate_energy_targets_and_predictions(
actual_outputs,
predicted_outputs,
Expand Down Expand Up @@ -219,7 +223,7 @@ def __prepare_to_test(self, snapshot_number):
break
test_snapshot += 1

optimal_batch_size = self._correct_batch_size_for_testing(
optimal_batch_size = self._correct_batch_size(
grid_size, self.parameters.mini_batch_size
)
if optimal_batch_size != self.parameters.mini_batch_size:
Expand Down
30 changes: 23 additions & 7 deletions mala/network/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,13 +435,15 @@ def train_network(self):
)
batchid += 1
total_batch_id += 1

dataset_fractions = ["validation"]
if self.parameters.validate_on_training_data:
dataset_fractions.append("train")
validation_metrics = ["ldos"]
if (epoch != 0 and
(epoch - 1) % self.parameters.validate_every_n_epochs == 0):
if (
epoch != 0
and (epoch - 1) % self.parameters.validate_every_n_epochs == 0
):
validation_metrics = self.parameters.validation_metrics
errors = self._validate_network(
dataset_fractions, validation_metrics
Expand Down Expand Up @@ -678,10 +680,8 @@ def _validate_network(self, data_set_fractions, metrics):
].grid_size
)

optimal_batch_size = (
self._correct_batch_size_for_testing(
grid_size, self.parameters.mini_batch_size
)
optimal_batch_size = self._correct_batch_size(
grid_size, self.parameters.mini_batch_size
)
number_of_batches_per_snapshot = int(
grid_size / optimal_batch_size
Expand Down Expand Up @@ -828,6 +828,22 @@ def __prepare_to_train(self, optimizer_dict):
):
do_shuffle = False

# To use graphs, our batch size has to be an even divisor of the data
# set size.
if self.parameters.use_graphs:
optimal_batch_size = self._correct_batch_size(
self.data.nr_training_data, self.parameters.mini_batch_size
)
if optimal_batch_size != self.parameters.mini_batch_size:
printout(
"Had to readjust batch size from",
self.parameters.mini_batch_size,
"to",
optimal_batch_size,
min_verbosity=0,
)
self.parameters.mini_batch_size = optimal_batch_size

# Prepare data loaders.(look into mini-batch size)
if isinstance(self.data.training_data_sets[0], FastTensorDataset):
# Not shuffling in loader.
Expand Down

0 comments on commit db6ecbd

Please sign in to comment.