diff --git a/_modules/mala/network/trainer.html b/_modules/mala/network/trainer.html index 1356fe1c..4b2f5934 100644 --- a/_modules/mala/network/trainer.html +++ b/_modules/mala/network/trainer.html @@ -769,47 +769,207 @@
)
loader_id += 1
else:
- with torch.no_grad():
- for snapshot_number in trange(
- offset_snapshots,
- number_of_snapshots + offset_snapshots,
- desc="Validation",
- disable=self.parameters_full.verbosity < 2,
- ):
- # Get optimal batch size and number of batches per snapshotss
- grid_size = (
- self.data.parameters.snapshot_directories_list[
- snapshot_number
- ].grid_size
- )
+ # If only the LDOS is in the validation metrics (as is the
+ # case for, e.g., distributed network trainings), we can
+ # use a faster (or at least better parallelizing) code
- optimal_batch_size = self._correct_batch_size(
- grid_size, self.parameters.mini_batch_size
- )
- number_of_batches_per_snapshot = int(
- grid_size / optimal_batch_size
+ if (
+ len(self.parameters.validation_metrics) == 1
+ and self.parameters.validation_metrics[0] == "ldos"
+ ):
+
+ errors[data_set_type]["ldos"] = (
+ self.__calculate_validation_error_ldos_only(
+ data_loaders
)
+ )
- actual_outputs, predicted_outputs = (
- self._forward_entire_snapshot(
+ else:
+ with torch.no_grad():
+ for snapshot_number in trange(
+ offset_snapshots,
+ number_of_snapshots + offset_snapshots,
+ desc="Validation",
+ disable=self.parameters_full.verbosity < 2,
+ ):
+ # Get optimal batch size and number of batches per snapshotss
+ grid_size = (
+ self.data.parameters.snapshot_directories_list[
+ snapshot_number
+ ].grid_size
+ )
+
+ optimal_batch_size = self._correct_batch_size(
+ grid_size, self.parameters.mini_batch_size
+ )
+ number_of_batches_per_snapshot = int(
+ grid_size / optimal_batch_size
+ )
+
+ actual_outputs, predicted_outputs = (
+ self._forward_entire_snapshot(
+ snapshot_number,
+ data_sets[0],
+ data_set_type[0:2],
+ number_of_batches_per_snapshot,
+ optimal_batch_size,
+ )
+ )
+ calculated_errors = self._calculate_errors(
+ actual_outputs,
+ predicted_outputs,
+ metrics,
snapshot_number,
- data_sets[0],
- data_set_type[0:2],
- number_of_batches_per_snapshot,
- optimal_batch_size,
)
+ for metric in metrics:
+ errors[data_set_type][metric].append(
+ calculated_errors[metric]
+ )
+ return errors
+
+ def __calculate_validation_error_ldos_only(self, data_loaders):
+ validation_loss_sum = torch.zeros(
+ 1, device=self.parameters._configuration["device"]
+ )
+ with torch.no_grad():
+ if self.parameters._configuration["gpu"]:
+ report_freq = self.parameters.training_log_interval
+ torch.cuda.synchronize(
+ self.parameters._configuration["device"]
+ )
+ tsample = time.time()
+ batchid = 0
+ for loader in data_loaders:
+ for x, y in loader:
+ x = x.to(
+ self.parameters._configuration["device"],
+ non_blocking=True,
)
- calculated_errors = self._calculate_errors(
- actual_outputs,
- predicted_outputs,
- metrics,
- snapshot_number,
+ y = y.to(
+ self.parameters._configuration["device"],
+ non_blocking=True,
)
- for metric in metrics:
- errors[data_set_type][metric].append(
- calculated_errors[metric]
+
+ if (
+ self.parameters.use_graphs
+ and self._validation_graph is None
+ ):
+ printout(
+ "Capturing CUDA graph for validation.",
+ min_verbosity=2,
)
- return errors
+ s = torch.cuda.Stream(
+ self.parameters._configuration["device"]
+ )
+ s.wait_stream(
+ torch.cuda.current_stream(
+ self.parameters._configuration["device"]
+ )
+ )
+ # Warmup for graphs
+ with torch.cuda.stream(s):
+ for _ in range(20):
+ with torch.cuda.amp.autocast(
+ enabled=self.parameters.use_mixed_precision
+ ):
+ prediction = self.network(x)
+ if self.parameters_full.use_ddp:
+ loss = self.network.module.calculate_loss(
+ prediction, y
+ )
+ else:
+ loss = self.network.calculate_loss(
+ prediction, y
+ )
+ torch.cuda.current_stream(
+ self.parameters._configuration["device"]
+ ).wait_stream(s)
+
+ # Create static entry point tensors to graph
+ self.static_input_validation = torch.empty_like(x)
+ self.static_target_validation = torch.empty_like(y)
+
+ # Capture graph
+ self._validation_graph = torch.cuda.CUDAGraph()
+ with torch.cuda.graph(self._validation_graph):
+ with torch.cuda.amp.autocast(
+ enabled=self.parameters.use_mixed_precision
+ ):
+ self.static_prediction_validation = (
+ self.network(
+ self.static_input_validation
+ )
+ )
+ if self.parameters_full.use_ddp:
+ self.static_loss_validation = self.network.module.calculate_loss(
+ self.static_prediction_validation,
+ self.static_target_validation,
+ )
+ else:
+ self.static_loss_validation = self.network.calculate_loss(
+ self.static_prediction_validation,
+ self.static_target_validation,
+ )
+
+ if self._validation_graph:
+ self.static_input_validation.copy_(x)
+ self.static_target_validation.copy_(y)
+ self._validation_graph.replay()
+ validation_loss_sum += self.static_loss_validation
+ else:
+ with torch.cuda.amp.autocast(
+ enabled=self.parameters.use_mixed_precision
+ ):
+ prediction = self.network(x)
+ if self.parameters_full.use_ddp:
+ loss = self.network.module.calculate_loss(
+ prediction, y
+ )
+ else:
+ loss = self.network.calculate_loss(
+ prediction, y
+ )
+ validation_loss_sum += loss
+ if batchid != 0 and (batchid + 1) % report_freq == 0:
+ torch.cuda.synchronize(
+ self.parameters._configuration["device"]
+ )
+ sample_time = time.time() - tsample
+ avg_sample_time = sample_time / report_freq
+ avg_sample_tput = (
+ report_freq * x.shape[0] / sample_time
+ )
+ printout(
+ f"batch {batchid + 1}, " # /{total_samples}, "
+ f"validation avg time: {avg_sample_time} "
+ f"validation avg throughput: {avg_sample_tput}",
+ min_verbosity=2,
+ )
+ tsample = time.time()
+ batchid += 1
+ torch.cuda.synchronize(
+ self.parameters._configuration["device"]
+ )
+ else:
+ batchid = 0
+ for loader in data_loaders:
+ for x, y in loader:
+ x = x.to(self.parameters._configuration["device"])
+ y = y.to(self.parameters._configuration["device"])
+ prediction = self.network(x)
+ if self.parameters_full.use_ddp:
+ validation_loss_sum += (
+ self.network.module.calculate_loss(
+ prediction, y
+ ).item()
+ )
+ else:
+ validation_loss_sum += self.network.calculate_loss(
+ prediction, y
+ ).item()
+ batchid += 1
+
+ return validation_loss_sum.item() / batchid
def __prepare_to_train(self, optimizer_dict):
"""Prepare everything for training."""
diff --git a/objects.inv b/objects.inv
index 29dd60c1..be43d64c 100644
Binary files a/objects.inv and b/objects.inv differ