diff --git a/mala/network/trainer.py b/mala/network/trainer.py
index 98dc291b8..0fafb67be 100644
--- a/mala/network/trainer.py
+++ b/mala/network/trainer.py
@@ -279,7 +279,7 @@ def train_network(self):
                 self.data.training_data_sets[0].shuffle()
 
             if self.parameters._configuration["gpu"]:
-                torch.cuda.synchronize()
+                torch.cuda.synchronize(self.parameters._configuration["device"])
                 tsample = time.time()
                 t0 = time.time()
                 batchid = 0
@@ -309,7 +309,7 @@ def train_network(self):
                         training_loss_sum += loss
 
                         if batchid != 0 and (batchid + 1) % self.parameters.training_report_frequency == 0:
-                            torch.cuda.synchronize()
+                            torch.cuda.synchronize(self.parameters._configuration["device"])
                             sample_time = time.time() - tsample
                             avg_sample_time = sample_time / self.parameters.training_report_frequency
                             avg_sample_tput = self.parameters.training_report_frequency * inputs.shape[0] / sample_time
@@ -319,14 +319,14 @@ def train_network(self):
                                      min_verbosity=2)
                             tsample = time.time()
                         batchid += 1
-                torch.cuda.synchronize()
+                torch.cuda.synchronize(self.parameters._configuration["device"])
                 t1 = time.time()
                 printout(f"training time: {t1 - t0}", min_verbosity=2)
 
                 training_loss = training_loss_sum.item() / batchid
 
                 # Calculate the validation loss. and output it.
-                torch.cuda.synchronize()
+                torch.cuda.synchronize(self.parameters._configuration["device"])
             else:
                 batchid = 0
                 for loader in self.training_data_loaders:
@@ -375,14 +375,14 @@ def train_network(self):
                 self.tensor_board.close()
 
             if self.parameters._configuration["gpu"]:
-                torch.cuda.synchronize()
+                torch.cuda.synchronize(self.parameters._configuration["device"])
 
             # Mix the DataSets up (this function only does something
             # in the lazy loading case).
             if self.parameters.use_shuffling_for_samplers:
                 self.data.mix_datasets()
             if self.parameters._configuration["gpu"]:
-                torch.cuda.synchronize()
+                torch.cuda.synchronize(self.parameters._configuration["device"])
 
             # If a scheduler is used, update it.
             if self.scheduler is not None:
@@ -636,8 +636,8 @@ def __process_mini_batch(self, network, input_data, target_data):
         if self.parameters._configuration["gpu"]:
             if self.parameters.use_graphs and self.train_graph is None:
                 printout("Capturing CUDA graph for training.", min_verbosity=2)
-                s = torch.cuda.Stream()
-                s.wait_stream(torch.cuda.current_stream())
+                s = torch.cuda.Stream(self.parameters._configuration["device"])
+                s.wait_stream(torch.cuda.current_stream(self.parameters._configuration["device"]))
                 # Warmup for graphs
                 with torch.cuda.stream(s):
                     for _ in range(20):
@@ -651,7 +651,7 @@ def __process_mini_batch(self, network, input_data, target_data):
                             self.gradscaler.scale(loss).backward()
                         else:
                             loss.backward()
-                torch.cuda.current_stream().wait_stream(s)
+                torch.cuda.current_stream(self.parameters._configuration["device"]).wait_stream(s)
 
                 # Create static entry point tensors to graph
                 self.static_input_data = torch.empty_like(input_data)
@@ -742,7 +742,7 @@ def __validate_network(self, network, data_set_type, validation_type):
         with torch.no_grad():
             if self.parameters._configuration["gpu"]:
                 report_freq = self.parameters.training_report_frequency
-                torch.cuda.synchronize()
+                torch.cuda.synchronize(self.parameters._configuration["device"])
                 tsample = time.time()
                 batchid = 0
                 for loader in data_loaders:
@@ -754,15 +754,15 @@ def __validate_network(self, network, data_set_type, validation_type):
 
                         if self.parameters.use_graphs and self.validation_graph is None:
                             printout("Capturing CUDA graph for validation.", min_verbosity=2)
-                            s = torch.cuda.Stream()
-                            s.wait_stream(torch.cuda.current_stream())
+                            s = torch.cuda.Stream(self.parameters._configuration["device"])
+                            s.wait_stream(torch.cuda.current_stream(self.parameters._configuration["device"]))
                             # Warmup for graphs
                             with torch.cuda.stream(s):
                                 for _ in range(20):
                                     with torch.cuda.amp.autocast(enabled=self.parameters.use_mixed_precision):
                                         prediction = network(x)
                                         loss = network.calculate_loss(prediction, y)
-                            torch.cuda.current_stream().wait_stream(s)
+                            torch.cuda.current_stream(self.parameters._configuration["device"]).wait_stream(s)
 
                             # Create static entry point tensors to graph
                             self.static_input_validation = torch.empty_like(x)
@@ -786,7 +786,7 @@ def __validate_network(self, network, data_set_type, validation_type):
                                         loss = network.calculate_loss(prediction, y)
                                         validation_loss_sum += loss
                         if batchid != 0 and (batchid + 1) % report_freq == 0:
-                            torch.cuda.synchronize()
+                            torch.cuda.synchronize(self.parameters._configuration["device"])
                             sample_time = time.time() - tsample
                             avg_sample_time = sample_time / report_freq
                             avg_sample_tput = report_freq * x.shape[0] / sample_time
@@ -796,7 +796,7 @@ def __validate_network(self, network, data_set_type, validation_type):
                                      min_verbosity=2)
                             tsample = time.time()
                         batchid += 1
-                torch.cuda.synchronize()
+                torch.cuda.synchronize(self.parameters._configuration["device"])
             else:
                 batchid = 0
                 for loader in data_loaders:
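
The change is the same in every hunk: CUDA calls that previously acted on the implicit current device (torch.cuda.synchronize(), torch.cuda.Stream(), torch.cuda.current_stream()) now receive the device stored in self.parameters._configuration["device"]. For context, a minimal standalone sketch of that pattern (illustrative only, not part of the patch; "cuda:0" is an assumed example value for the device):

    import torch

    # Assumed example device; in the trainer this value comes from
    # self.parameters._configuration["device"].
    device = torch.device("cuda:0")

    # Synchronize this specific device rather than whichever device happens
    # to be current on the calling thread.
    torch.cuda.synchronize(device)

    # Side stream created on the same device and ordered against that
    # device's current stream, as done around the CUDA graph warmup above.
    s = torch.cuda.Stream(device)
    s.wait_stream(torch.cuda.current_stream(device))
    with torch.cuda.stream(s):
        pass  # warmup iterations would run here
    torch.cuda.current_stream(device).wait_stream(s)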