diff --git a/dynamax/linear_gaussian_ssm/models_test.py b/dynamax/linear_gaussian_ssm/models_test.py index 7f6af2bc..c4394858 100644 --- a/dynamax/linear_gaussian_ssm/models_test.py +++ b/dynamax/linear_gaussian_ssm/models_test.py @@ -20,5 +20,5 @@ def test_sample_and_fit(cls, kwargs, inputs): params, param_props = model.initialize(key1) states, emissions = model.sample(params, key2, num_timesteps=NUM_TIMESTEPS, inputs=inputs) fitted_params, lps = model.fit_em(params, param_props, emissions, inputs=inputs, num_iters=3) - assert monotonically_increasing(lps) # fails on TPU + assert monotonically_increasing(lps) fitted_params, lps = model.fit_sgd(params, param_props, emissions, inputs=inputs, num_epochs=3)