diff --git a/_modules/mala/network/trainer.html b/_modules/mala/network/trainer.html index 1356fe1c..4b2f5934 100644 --- a/_modules/mala/network/trainer.html +++ b/_modules/mala/network/trainer.html @@ -769,47 +769,207 @@

Source code for mala.network.trainer

                         )
                     loader_id += 1
             else:
-                with torch.no_grad():
-                    for snapshot_number in trange(
-                        offset_snapshots,
-                        number_of_snapshots + offset_snapshots,
-                        desc="Validation",
-                        disable=self.parameters_full.verbosity < 2,
-                    ):
-                        # Get optimal batch size and number of batches per snapshotss
-                        grid_size = (
-                            self.data.parameters.snapshot_directories_list[
-                                snapshot_number
-                            ].grid_size
-                        )
+                # If only the LDOS is in the validation metrics (as is the
+                # case for, e.g., distributed network trainings), we can
+                # use a faster (or at least better parallelizing) code
 
-                        optimal_batch_size = self._correct_batch_size(
-                            grid_size, self.parameters.mini_batch_size
-                        )
-                        number_of_batches_per_snapshot = int(
-                            grid_size / optimal_batch_size
+                if (
+                    len(self.parameters.validation_metrics) == 1
+                    and self.parameters.validation_metrics[0] == "ldos"
+                ):
+
+                    errors[data_set_type]["ldos"] = (
+                        self.__calculate_validation_error_ldos_only(
+                            data_loaders
                         )
+                    )
 
-                        actual_outputs, predicted_outputs = (
-                            self._forward_entire_snapshot(
+                else:
+                    with torch.no_grad():
+                        for snapshot_number in trange(
+                            offset_snapshots,
+                            number_of_snapshots + offset_snapshots,
+                            desc="Validation",
+                            disable=self.parameters_full.verbosity < 2,
+                        ):
+                            # Get optimal batch size and number of batches per snapshotss
+                            grid_size = (
+                                self.data.parameters.snapshot_directories_list[
+                                    snapshot_number
+                                ].grid_size
+                            )
+
+                            optimal_batch_size = self._correct_batch_size(
+                                grid_size, self.parameters.mini_batch_size
+                            )
+                            number_of_batches_per_snapshot = int(
+                                grid_size / optimal_batch_size
+                            )
+
+                            actual_outputs, predicted_outputs = (
+                                self._forward_entire_snapshot(
+                                    snapshot_number,
+                                    data_sets[0],
+                                    data_set_type[0:2],
+                                    number_of_batches_per_snapshot,
+                                    optimal_batch_size,
+                                )
+                            )
+                            calculated_errors = self._calculate_errors(
+                                actual_outputs,
+                                predicted_outputs,
+                                metrics,
                                 snapshot_number,
-                                data_sets[0],
-                                data_set_type[0:2],
-                                number_of_batches_per_snapshot,
-                                optimal_batch_size,
                             )
+                            for metric in metrics:
+                                errors[data_set_type][metric].append(
+                                    calculated_errors[metric]
+                                )
+        return errors
+
+    def __calculate_validation_error_ldos_only(self, data_loaders):
+        validation_loss_sum = torch.zeros(
+            1, device=self.parameters._configuration["device"]
+        )
+        with torch.no_grad():
+            if self.parameters._configuration["gpu"]:
+                report_freq = self.parameters.training_log_interval
+                torch.cuda.synchronize(
+                    self.parameters._configuration["device"]
+                )
+                tsample = time.time()
+                batchid = 0
+                for loader in data_loaders:
+                    for x, y in loader:
+                        x = x.to(
+                            self.parameters._configuration["device"],
+                            non_blocking=True,
                         )
-                        calculated_errors = self._calculate_errors(
-                            actual_outputs,
-                            predicted_outputs,
-                            metrics,
-                            snapshot_number,
+                        y = y.to(
+                            self.parameters._configuration["device"],
+                            non_blocking=True,
                         )
-                        for metric in metrics:
-                            errors[data_set_type][metric].append(
-                                calculated_errors[metric]
+
+                        if (
+                            self.parameters.use_graphs
+                            and self._validation_graph is None
+                        ):
+                            printout(
+                                "Capturing CUDA graph for validation.",
+                                min_verbosity=2,
                             )
-        return errors
+                            s = torch.cuda.Stream(
+                                self.parameters._configuration["device"]
+                            )
+                            s.wait_stream(
+                                torch.cuda.current_stream(
+                                    self.parameters._configuration["device"]
+                                )
+                            )
+                            # Warmup for graphs
+                            with torch.cuda.stream(s):
+                                for _ in range(20):
+                                    with torch.cuda.amp.autocast(
+                                        enabled=self.parameters.use_mixed_precision
+                                    ):
+                                        prediction = self.network(x)
+                                        if self.parameters_full.use_ddp:
+                                            loss = self.network.module.calculate_loss(
+                                                prediction, y
+                                            )
+                                        else:
+                                            loss = self.network.calculate_loss(
+                                                prediction, y
+                                            )
+                            torch.cuda.current_stream(
+                                self.parameters._configuration["device"]
+                            ).wait_stream(s)
+
+                            # Create static entry point tensors to graph
+                            self.static_input_validation = torch.empty_like(x)
+                            self.static_target_validation = torch.empty_like(y)
+
+                            # Capture graph
+                            self._validation_graph = torch.cuda.CUDAGraph()
+                            with torch.cuda.graph(self._validation_graph):
+                                with torch.cuda.amp.autocast(
+                                    enabled=self.parameters.use_mixed_precision
+                                ):
+                                    self.static_prediction_validation = (
+                                        self.network(
+                                            self.static_input_validation
+                                        )
+                                    )
+                                    if self.parameters_full.use_ddp:
+                                        self.static_loss_validation = self.network.module.calculate_loss(
+                                            self.static_prediction_validation,
+                                            self.static_target_validation,
+                                        )
+                                    else:
+                                        self.static_loss_validation = self.network.calculate_loss(
+                                            self.static_prediction_validation,
+                                            self.static_target_validation,
+                                        )
+
+                        if self._validation_graph:
+                            self.static_input_validation.copy_(x)
+                            self.static_target_validation.copy_(y)
+                            self._validation_graph.replay()
+                            validation_loss_sum += self.static_loss_validation
+                        else:
+                            with torch.cuda.amp.autocast(
+                                enabled=self.parameters.use_mixed_precision
+                            ):
+                                prediction = self.network(x)
+                                if self.parameters_full.use_ddp:
+                                    loss = self.network.module.calculate_loss(
+                                        prediction, y
+                                    )
+                                else:
+                                    loss = self.network.calculate_loss(
+                                        prediction, y
+                                    )
+                                validation_loss_sum += loss
+                        if batchid != 0 and (batchid + 1) % report_freq == 0:
+                            torch.cuda.synchronize(
+                                self.parameters._configuration["device"]
+                            )
+                            sample_time = time.time() - tsample
+                            avg_sample_time = sample_time / report_freq
+                            avg_sample_tput = (
+                                report_freq * x.shape[0] / sample_time
+                            )
+                            printout(
+                                f"batch {batchid + 1}, "  # /{total_samples}, "
+                                f"validation avg time: {avg_sample_time} "
+                                f"validation avg throughput: {avg_sample_tput}",
+                                min_verbosity=2,
+                            )
+                            tsample = time.time()
+                        batchid += 1
+                torch.cuda.synchronize(
+                    self.parameters._configuration["device"]
+                )
+            else:
+                batchid = 0
+                for loader in data_loaders:
+                    for x, y in loader:
+                        x = x.to(self.parameters._configuration["device"])
+                        y = y.to(self.parameters._configuration["device"])
+                        prediction = self.network(x)
+                        if self.parameters_full.use_ddp:
+                            validation_loss_sum += (
+                                self.network.module.calculate_loss(
+                                    prediction, y
+                                ).item()
+                            )
+                        else:
+                            validation_loss_sum += self.network.calculate_loss(
+                                prediction, y
+                            ).item()
+                        batchid += 1
+
+        return validation_loss_sum.item() / batchid
 
     def __prepare_to_train(self, optimizer_dict):
         """Prepare everything for training."""
diff --git a/objects.inv b/objects.inv
index 29dd60c1..be43d64c 100644
Binary files a/objects.inv and b/objects.inv differ