diff --git a/mala/network/runner.py b/mala/network/runner.py index f7e0be697..83d97fc60 100644 --- a/mala/network/runner.py +++ b/mala/network/runner.py @@ -88,12 +88,12 @@ def save_run( optimizer_file = run_name + ".optimizer.pth" self.parameters_full.save(os.path.join(save_path, params_file)) - if hasattr(self.network, "save_network"): - self.network.save_network(os.path.join(save_path, model_file)) - else: + if hasattr(self.network, "module"): self.network.module.save_network( os.path.join(save_path, model_file) ) + else: + self.network.save_network(os.path.join(save_path, model_file)) self.data.input_data_scaler.save(os.path.join(save_path, iscaler_file)) self.data.output_data_scaler.save( os.path.join(save_path, oscaler_file) diff --git a/mala/network/trainer.py b/mala/network/trainer.py index 01632a380..3221041f6 100644 --- a/mala/network/trainer.py +++ b/mala/network/trainer.py @@ -810,12 +810,12 @@ def __process_mini_batch(self, network, input_data, target_data): loss = network.calculate_loss( prediction, target_data ) - if hasattr(network, "calculate_loss"): - loss = network.calculate_loss( + if hasattr(network, "module"): + loss = network.module.calculate_loss( prediction, target_data ) else: - loss = network.module.calculate_loss( + loss = network.calculate_loss( prediction, target_data ) @@ -846,12 +846,12 @@ def __process_mini_batch(self, network, input_data, target_data): self.static_prediction, self.static_target_data ) - if hasattr(network, "calculate_loss"): - self.static_loss = network.calculate_loss( + if hasattr(network, "module"): + self.static_loss = network.module.calculate_loss( self.static_prediction, self.static_target_data ) else: - self.static_loss = network.module.calculate_loss( + self.static_loss = network.calculate_loss( self.static_prediction, self.static_target_data ) @@ -879,12 +879,12 @@ def __process_mini_batch(self, network, input_data, target_data): torch.cuda.nvtx.range_pop() torch.cuda.nvtx.range_push("loss") - if hasattr(network, "calculate_loss"): - loss = network.calculate_loss(prediction, target_data) - else: + if hasattr(network, "module"): loss = network.module.calculate_loss( prediction, target_data ) + else: + loss = network.calculate_loss(prediction, target_data) # loss torch.cuda.nvtx.range_pop() @@ -907,10 +907,10 @@ def __process_mini_batch(self, network, input_data, target_data): return loss else: prediction = network(input_data) - if hasattr(network, "calculate_loss"): - loss = network.calculate_loss(prediction, target_data) - else: + if hasattr(network, "module"): loss = network.module.calculate_loss(prediction, target_data) + else: + loss = network.calculate_loss(prediction, target_data) loss.backward() self.optimizer.step() self.optimizer.zero_grad() @@ -987,13 +987,13 @@ def __validate_network(self, network, data_set_type, validation_type): ): prediction = network(x) if hasattr( - network, "calculate_loss" + network, "module" ): - loss = network.calculate_loss( + loss = network.module.calculate_loss( prediction, y ) else: - loss = network.module.calculate_loss( + loss = network.calculate_loss( prediction, y ) torch.cuda.current_stream( @@ -1023,13 +1023,13 @@ def __validate_network(self, network, data_set_type, validation_type): self.static_prediction_validation, self.static_target_validation, ) - if hasattr(network, "calculate_loss"): - self.static_loss_validation = network.calculate_loss( + if hasattr(network, "module"): + self.static_loss_validation = network.module.calculate_loss( self.static_prediction_validation, self.static_target_validation, ) else: - self.static_loss_validation = network.module.calculate_loss( + self.static_loss_validation = network.calculate_loss( self.static_prediction_validation, self.static_target_validation, ) @@ -1058,12 +1058,12 @@ def __validate_network(self, network, data_set_type, validation_type): enabled=self.parameters.use_mixed_precision ): prediction = network(x) - if hasattr(network, "calculate_loss"): - loss = network.calculate_loss( + if hasattr(network, "module"): + loss = network.module.calculate_loss( prediction, y ) else: - loss = network.module.calculate_loss( + loss = network.calculate_loss( prediction, y ) validation_loss_sum += loss @@ -1098,12 +1098,12 @@ def __validate_network(self, network, data_set_type, validation_type): y = y.to(self.parameters._configuration["device"]) prediction = network(x) - if hasattr(network, "calculate_loss"): - loss = network.calculate_loss(prediction, y) - else: + if hasattr(network, "module"): loss = network.module.calculate_loss( prediction, y ) + else: + loss = network.calculate_loss(prediction, y) validation_loss_sum += loss.item() batchid += 1