diff --git a/_modules/mala/network/trainer.html b/_modules/mala/network/trainer.html index 9f8a0ad6c..f07fefab3 100644 --- a/_modules/mala/network/trainer.html +++ b/_modules/mala/network/trainer.html @@ -355,7 +355,7 @@

Source code for mala.network.trainer

                 self.data.training_data_sets[0].shuffle()
 
             if self.parameters._configuration["gpu"]:
-                torch.cuda.synchronize()
+                torch.cuda.synchronize(self.parameters._configuration["device"])
                 tsample = time.time()
                 t0 = time.time()
                 batchid = 0
@@ -385,7 +385,7 @@ 

Source code for mala.network.trainer

                         training_loss_sum += loss
 
                         if batchid != 0 and (batchid + 1) % self.parameters.training_report_frequency == 0:
-                            torch.cuda.synchronize()
+                            torch.cuda.synchronize(self.parameters._configuration["device"])
                             sample_time = time.time() - tsample
                             avg_sample_time = sample_time / self.parameters.training_report_frequency
                             avg_sample_tput = self.parameters.training_report_frequency * inputs.shape[0] / sample_time
@@ -395,14 +395,14 @@ 

Source code for mala.network.trainer

                                      min_verbosity=2)
                             tsample = time.time()
                         batchid += 1
-                torch.cuda.synchronize()
+                torch.cuda.synchronize(self.parameters._configuration["device"])
                 t1 = time.time()
                 printout(f"training time: {t1 - t0}", min_verbosity=2)
 
                 training_loss = training_loss_sum.item() / batchid
 
                 # Calculate the validation loss. and output it.
-                torch.cuda.synchronize()
+                torch.cuda.synchronize(self.parameters._configuration["device"])
             else:
                 batchid = 0
                 for loader in self.training_data_loaders:
@@ -451,14 +451,14 @@ 

Source code for mala.network.trainer

                 self.tensor_board.close()
 
             if self.parameters._configuration["gpu"]:
-                torch.cuda.synchronize()
+                torch.cuda.synchronize(self.parameters._configuration["device"])
 
             # Mix the DataSets up (this function only does something
             # in the lazy loading case).
             if self.parameters.use_shuffling_for_samplers:
                 self.data.mix_datasets()
             if self.parameters._configuration["gpu"]:
-                torch.cuda.synchronize()
+                torch.cuda.synchronize(self.parameters._configuration["device"])
 
             # If a scheduler is used, update it.
             if self.scheduler is not None:
@@ -712,8 +712,8 @@ 

Source code for mala.network.trainer

         if self.parameters._configuration["gpu"]:
             if self.parameters.use_graphs and self.train_graph is None:
                 printout("Capturing CUDA graph for training.", min_verbosity=2)
-                s = torch.cuda.Stream()
-                s.wait_stream(torch.cuda.current_stream())
+                s = torch.cuda.Stream(self.parameters._configuration["device"])
+                s.wait_stream(torch.cuda.current_stream(self.parameters._configuration["device"]))
                 # Warmup for graphs
                 with torch.cuda.stream(s):
                     for _ in range(20):
@@ -727,7 +727,7 @@ 

Source code for mala.network.trainer

                             self.gradscaler.scale(loss).backward()
                         else:
                             loss.backward()
-                torch.cuda.current_stream().wait_stream(s)
+                torch.cuda.current_stream(self.parameters._configuration["device"]).wait_stream(s)
 
                 # Create static entry point tensors to graph
                 self.static_input_data = torch.empty_like(input_data)
@@ -818,7 +818,7 @@ 

Source code for mala.network.trainer

             with torch.no_grad():
                 if self.parameters._configuration["gpu"]:
                     report_freq = self.parameters.training_report_frequency
-                    torch.cuda.synchronize()
+                    torch.cuda.synchronize(self.parameters._configuration["device"])
                     tsample = time.time()
                     batchid = 0
                     for loader in data_loaders:
@@ -830,15 +830,15 @@ 

Source code for mala.network.trainer

 
                             if self.parameters.use_graphs and self.validation_graph is None:
                                 printout("Capturing CUDA graph for validation.", min_verbosity=2)
-                                s = torch.cuda.Stream()
-                                s.wait_stream(torch.cuda.current_stream())
+                                s = torch.cuda.Stream(self.parameters._configuration["device"])
+                                s.wait_stream(torch.cuda.current_stream(self.parameters._configuration["device"]))
                                 # Warmup for graphs
                                 with torch.cuda.stream(s):
                                     for _ in range(20):
                                         with torch.cuda.amp.autocast(enabled=self.parameters.use_mixed_precision):
                                             prediction = network(x)
                                             loss = network.calculate_loss(prediction, y)
-                                torch.cuda.current_stream().wait_stream(s)
+                                torch.cuda.current_stream(self.parameters._configuration["device"]).wait_stream(s)
 
                                 # Create static entry point tensors to graph
                                 self.static_input_validation = torch.empty_like(x)
@@ -862,7 +862,7 @@ 

Source code for mala.network.trainer

                                     loss = network.calculate_loss(prediction, y)
                                     validation_loss_sum += loss
                             if batchid != 0 and (batchid + 1) % report_freq == 0:
-                                torch.cuda.synchronize()
+                                torch.cuda.synchronize(self.parameters._configuration["device"])
                                 sample_time = time.time() - tsample
                                 avg_sample_time = sample_time / report_freq
                                 avg_sample_tput = report_freq * x.shape[0] / sample_time
@@ -872,7 +872,7 @@ 

Source code for mala.network.trainer

                                          min_verbosity=2)
                                 tsample = time.time()
                             batchid += 1
-                    torch.cuda.synchronize()
+                    torch.cuda.synchronize(self.parameters._configuration["device"])
                 else:
                     batchid = 0
                     for loader in data_loaders:
diff --git a/objects.inv b/objects.inv
index 1245cdc0e..0f7efe040 100644
Binary files a/objects.inv and b/objects.inv differ