Skip to content

Commit

Permalink
Change hasattr check to module
Browse files Browse the repository at this point in the history
  • Loading branch information
nerkulec committed Apr 25, 2024
1 parent 65d11b1 commit e4f2eed
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 27 deletions.
6 changes: 3 additions & 3 deletions mala/network/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,12 @@ def save_run(
optimizer_file = run_name + ".optimizer.pth"

self.parameters_full.save(os.path.join(save_path, params_file))
if hasattr(self.network, "save_network"):
self.network.save_network(os.path.join(save_path, model_file))
else:
if hasattr(self.network, "module"):
self.network.module.save_network(
os.path.join(save_path, model_file)
)
else:
self.network.save_network(os.path.join(save_path, model_file))
self.data.input_data_scaler.save(os.path.join(save_path, iscaler_file))
self.data.output_data_scaler.save(
os.path.join(save_path, oscaler_file)
Expand Down
48 changes: 24 additions & 24 deletions mala/network/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,12 +810,12 @@ def __process_mini_batch(self, network, input_data, target_data):
loss = network.calculate_loss(
prediction, target_data
)
if hasattr(network, "calculate_loss"):
loss = network.calculate_loss(
if hasattr(network, "module"):
loss = network.module.calculate_loss(
prediction, target_data
)
else:
loss = network.module.calculate_loss(
loss = network.calculate_loss(
prediction, target_data
)

Expand Down Expand Up @@ -846,12 +846,12 @@ def __process_mini_batch(self, network, input_data, target_data):
self.static_prediction, self.static_target_data
)

if hasattr(network, "calculate_loss"):
self.static_loss = network.calculate_loss(
if hasattr(network, "module"):
self.static_loss = network.module.calculate_loss(
self.static_prediction, self.static_target_data
)
else:
self.static_loss = network.module.calculate_loss(
self.static_loss = network.calculate_loss(
self.static_prediction, self.static_target_data
)

Expand Down Expand Up @@ -879,12 +879,12 @@ def __process_mini_batch(self, network, input_data, target_data):
torch.cuda.nvtx.range_pop()

torch.cuda.nvtx.range_push("loss")
if hasattr(network, "calculate_loss"):
loss = network.calculate_loss(prediction, target_data)
else:
if hasattr(network, "module"):
loss = network.module.calculate_loss(
prediction, target_data
)
else:
loss = network.calculate_loss(prediction, target_data)
# loss
torch.cuda.nvtx.range_pop()

Expand All @@ -907,10 +907,10 @@ def __process_mini_batch(self, network, input_data, target_data):
return loss
else:
prediction = network(input_data)
if hasattr(network, "calculate_loss"):
loss = network.calculate_loss(prediction, target_data)
else:
if hasattr(network, "module"):
loss = network.module.calculate_loss(prediction, target_data)
else:
loss = network.calculate_loss(prediction, target_data)
loss.backward()
self.optimizer.step()
self.optimizer.zero_grad()
Expand Down Expand Up @@ -987,13 +987,13 @@ def __validate_network(self, network, data_set_type, validation_type):
):
prediction = network(x)
if hasattr(
network, "calculate_loss"
network, "module"
):
loss = network.calculate_loss(
loss = network.module.calculate_loss(
prediction, y
)
else:
loss = network.module.calculate_loss(
loss = network.calculate_loss(
prediction, y
)
torch.cuda.current_stream(
Expand Down Expand Up @@ -1023,13 +1023,13 @@ def __validate_network(self, network, data_set_type, validation_type):
self.static_prediction_validation,
self.static_target_validation,
)
if hasattr(network, "calculate_loss"):
self.static_loss_validation = network.calculate_loss(
if hasattr(network, "module"):
self.static_loss_validation = network.module.calculate_loss(
self.static_prediction_validation,
self.static_target_validation,
)
else:
self.static_loss_validation = network.module.calculate_loss(
self.static_loss_validation = network.calculate_loss(
self.static_prediction_validation,
self.static_target_validation,
)
Expand Down Expand Up @@ -1058,12 +1058,12 @@ def __validate_network(self, network, data_set_type, validation_type):
enabled=self.parameters.use_mixed_precision
):
prediction = network(x)
if hasattr(network, "calculate_loss"):
loss = network.calculate_loss(
if hasattr(network, "module"):
loss = network.module.calculate_loss(
prediction, y
)
else:
loss = network.module.calculate_loss(
loss = network.calculate_loss(
prediction, y
)
validation_loss_sum += loss
Expand Down Expand Up @@ -1098,12 +1098,12 @@ def __validate_network(self, network, data_set_type, validation_type):
y = y.to(self.parameters._configuration["device"])
prediction = network(x)

if hasattr(network, "calculate_loss"):
loss = network.calculate_loss(prediction, y)
else:
if hasattr(network, "module"):
loss = network.module.calculate_loss(
prediction, y
)
else:
loss = network.calculate_loss(prediction, y)

validation_loss_sum += loss.item()
batchid += 1
Expand Down

0 comments on commit e4f2eed

Please sign in to comment.