
Commit

Remove Flax optimisers equivalence tests since this module has been replaced with Optax.

PiperOrigin-RevId: 561147610
hbq1 authored and OptaxDev committed Aug 29, 2023
1 parent 1fa1fe0 commit 4c04ca5
Showing 2 changed files with 0 additions and 84 deletions.
83 changes: 0 additions & 83 deletions optax/_src/equivalence_test.py
@@ -18,7 +18,6 @@
from absl.testing import parameterized

import chex
from flax import optim
from jax.example_libraries import optimizers
import jax.numpy as jnp

@@ -90,87 +89,5 @@ def step(updates, state):
chex.assert_trees_all_close(jax_params, optax_params, rtol=rtol)


class FlaxOptimizersEquivalenceTest(chex.TestCase):

def setUp(self):
super().setUp()
self.init_params = (
jnp.array([1., 0.1, 1., 2.]), jnp.array([3., 4.]))
self.per_step_updates = (
jnp.array([0., 0.3, 500., 5.]), jnp.array([300., 3.]))

@parameterized.named_parameters(
('sgd',
alias.sgd(LR),
optim.GradientDescent(LR)),
('momentum',
alias.sgd(LR, momentum=0.9),
optim.Momentum(LR, beta=0.9)), # Different names.
('nesterov_momentum',
alias.sgd(LR, momentum=0.9, nesterov=True),
optim.Momentum(LR, beta=0.9, nesterov=True)),
('rmsprop',
alias.rmsprop(LR),
optim.RMSProp(LR)),
('centered_rmsprop',
alias.rmsprop(LR, centered=True),
optim.RMSProp(LR, centered=True)),
('adam',
alias.adam(LR),
optim.Adam(LR)),
('adam_w',
alias.adamw(LR, weight_decay=1e-4),
optim.Adam(LR, weight_decay=1e-4)), # Different name.
('adagrad',
alias.adagrad(LR, initial_accumulator_value=0.), # Different default!
optim.Adagrad(LR)),
('lamb',
alias.lamb(LR),
optim.LAMB(LR)),
('lars',
alias.lars(
LR, weight_decay=.5, trust_coefficient=0.003,
momentum=0.9, eps=1e-3),
optim.LARS(
LR, weight_decay=.5, trust_coefficient=0.003,
beta=0.9, eps=1e-3)),
('adafactor',
alias.adafactor(
learning_rate=LR / 10.,
factored=True,
multiply_by_parameter_scale=True,
clipping_threshold=1.0,
decay_rate=0.8,
min_dim_size_to_factor=2),
optim.Adafactor(
learning_rate=LR / 10.,
factored=True,
multiply_by_parameter_scale=True,
clipping_threshold=1.0,
decay_rate=0.8,
min_dim_size_to_factor=2)),
)
def test_flax_optim_equivalence(self, optax_optimizer, flax_optimizer):

# flax/optim
flax_params = self.init_params
flax_optimizer = flax_optimizer.create(flax_params)
for _ in range(STEPS):
flax_optimizer = flax_optimizer.apply_gradient(
self.per_step_updates)
flax_params = flax_optimizer.target

# optax
optax_params = self.init_params
state = optax_optimizer.init(optax_params)
for _ in range(STEPS):
updates, state = optax_optimizer.update(
self.per_step_updates, state, optax_params)
optax_params = update.apply_updates(optax_params, updates)

# Check equivalence.
chex.assert_trees_all_close(flax_params, optax_params, rtol=2e-4)


if __name__ == '__main__':
absltest.main()
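
For reference, a minimal sketch of the plain optax update loop that replaces the flax.optim pattern exercised in the deleted test (assuming a current optax release; the LR, STEPS, and parameter values below are illustrative, mirrored from the test's setUp):

import jax.numpy as jnp
import optax

LR = 1e-2    # illustrative learning rate
STEPS = 5    # illustrative number of update steps

# Parameter and gradient pytrees, as in the deleted test's setUp.
params = (jnp.array([1., 0.1, 1., 2.]), jnp.array([3., 4.]))
grads = (jnp.array([0., 0.3, 500., 5.]), jnp.array([300., 3.]))

optimizer = optax.sgd(LR, momentum=0.9)  # any optax GradientTransformation works here
state = optimizer.init(params)
for _ in range(STEPS):
  updates, state = optimizer.update(grads, state, params)
  params = optax.apply_updates(params, updates)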
1 change: 0 additions & 1 deletion requirements/requirements-test.txt
@@ -1,3 +1,2 @@
dm-haiku>=0.0.3
dm-tree>=0.1.7
flax==0.5.3
