diff --git a/benchmarks/roman_pots/train_dense_neural_network.py b/benchmarks/roman_pots/train_dense_neural_network.py index b74dcbec..3956b82d 100644 --- a/benchmarks/roman_pots/train_dense_neural_network.py +++ b/benchmarks/roman_pots/train_dense_neural_network.py @@ -45,11 +45,11 @@ def standardize(x): standardized_tensor = (x - mean) / std return standardized_tensor, mean, std -def train_model(input_tensor, target_tensor, model): +def train_model(input_tensor, target_tensor, model, num_epochs, learning_rate): # Define the loss function and optimizer criterion = torch.nn.HuberLoss(reduction='mean', delta=1.0) - optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) # Create a learning rate scheduler scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,'min',patience=100,cooldown=100,factor=0.5,threshold=1e-4,verbose=True)