diff --git a/mala/network/trainer.py b/mala/network/trainer.py
index bc3cfc544..0fafb67be 100644
--- a/mala/network/trainer.py
+++ b/mala/network/trainer.py
@@ -636,8 +636,8 @@ def __process_mini_batch(self, network, input_data, target_data):
         if self.parameters._configuration["gpu"]:
             if self.parameters.use_graphs and self.train_graph is None:
                 printout("Capturing CUDA graph for training.", min_verbosity=2)
-                s = torch.cuda.Stream()
-                s.wait_stream(torch.cuda.current_stream())
+                s = torch.cuda.Stream(self.parameters._configuration["device"])
+                s.wait_stream(torch.cuda.current_stream(self.parameters._configuration["device"]))
                 # Warmup for graphs
                 with torch.cuda.stream(s):
                     for _ in range(20):
@@ -651,7 +651,7 @@ def __process_mini_batch(self, network, input_data, target_data):
                             self.gradscaler.scale(loss).backward()
                         else:
                             loss.backward()
-                torch.cuda.current_stream().wait_stream(s)
+                torch.cuda.current_stream(self.parameters._configuration["device"]).wait_stream(s)

                 # Create static entry point tensors to graph
                 self.static_input_data = torch.empty_like(input_data)
@@ -754,15 +754,15 @@ def __validate_network(self, network, data_set_type, validation_type):

                         if self.parameters.use_graphs and self.validation_graph is None:
                             printout("Capturing CUDA graph for validation.", min_verbosity=2)
-                            s = torch.cuda.Stream()
-                            s.wait_stream(torch.cuda.current_stream())
+                            s = torch.cuda.Stream(self.parameters._configuration["device"])
+                            s.wait_stream(torch.cuda.current_stream(self.parameters._configuration["device"]))
                             # Warmup for graphs
                             with torch.cuda.stream(s):
                                 for _ in range(20):
                                     with torch.cuda.amp.autocast(enabled=self.parameters.use_mixed_precision):
                                         prediction = network(x)
                                         loss = network.calculate_loss(prediction, y)
-                            torch.cuda.current_stream().wait_stream(s)
+                            torch.cuda.current_stream(self.parameters._configuration["device"]).wait_stream(s)

                             # Create static entry point tensors to graph
                             self.static_input_validation = torch.empty_like(x)