Skip to content

Commit

Permalink
speed op rosenborck test
Browse files Browse the repository at this point in the history
  • Loading branch information
fzimmermann89 committed Jan 4, 2025
1 parent f2852bc commit 35e7242
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions tests/algorithms/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,17 @@

@pytest.mark.parametrize('enforce_bounds_on_x1', [True, False])
@pytest.mark.parametrize(
('optimizer', 'optimizer_kwargs'), [(adam, {'lr': 0.02, 'max_iter': 10000}), (lbfgs, {'lr': 1.0})]
('optimizer', 'optimizer_kwargs'),
[(adam, {'lr': 0.02, 'max_iter': 2000, 'betas': (0.8, 0.999)}), (lbfgs, {'lr': 1.0, 'max_iter': 20})],
)
def test_optimizers_rosenbrock(optimizer, enforce_bounds_on_x1, optimizer_kwargs):
# use Rosenbrock function as test case with 2D test data
a, b = 1.0, 100.0
rosen_brock = Rosenbrock(a, b)

# initial point of optimization
x1 = torch.tensor([a / 3.14])
x2 = torch.tensor([3.14])
x1 = torch.tensor([a / 1.23])
x2 = torch.tensor([1.23])
x1.grad = torch.tensor([2.78])
x2.grad = torch.tensor([-1.0])
params_init = [x1, x2]
Expand All @@ -45,7 +46,7 @@ def test_optimizers_rosenbrock(optimizer, enforce_bounds_on_x1, optimizer_kwargs
params_result = constrain_op(*params_result)

# obtained solution should match analytical
torch.testing.assert_close(torch.tensor(params_result), analytical_solution)
torch.testing.assert_close(torch.tensor(params_result), analytical_solution, rtol=1e-4, atol=0)

for p, before, grad_before in zip(params_init, params_init_before, params_init_grad_before, strict=True):
assert p == before, 'the initial parameter should not have changed during optimization'
Expand Down

0 comments on commit 35e7242

Please sign in to comment.