diff --git a/src/surrogate_model.py b/src/surrogate_model.py index 970e07a..3e2275b 100644 --- a/src/surrogate_model.py +++ b/src/surrogate_model.py @@ -210,7 +210,7 @@ def conv_batchnorm_relu( final_model.compile( loss=masked_mae, - optimizer="adam", + optimizer=tf.keras.optimizers.Adam() # metrics=[mape], ) return final_model