Merge pull request #1110 from carlosgmartin:improve_tests
PiperOrigin-RevId: 707682384
OptaxDev committed Dec 18, 2024
2 parents 5ded749 + 70a9241 commit ee883b4
Showing 39 changed files with 730 additions and 719 deletions.
17 changes: 7 additions & 10 deletions docs/conf.py
@@ -24,24 +24,22 @@
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.

# pylint: disable=g-bad-import-order
# pylint: disable=g-import-not-at-top
import inspect
import os
import sys


def _add_annotations_import(path):
"""Appends a future annotations import to the file at the given path."""
with open(path) as f:
with open(path, encoding='utf-8') as f:
contents = f.read()
if contents.startswith('from __future__ import annotations'):
# If we run sphinx multiple times then we will append the future import
# multiple times too.
return

assert contents.startswith('#'), (path, contents.split('\n')[0])
with open(path, 'w') as f:
with open(path, 'w', encoding='utf-8') as f:
# NOTE: This is subtle and not unit tested, we're prefixing the first line
# in each Python file with this future import. It is important to prefix
# not insert a newline such that source code locations are accurate (we link
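The invariant this NOTE describes can be illustrated with a minimal sketch (the file contents below are invented): the future import is joined onto the existing first comment line with no newline added, so every original line keeps its number and the linkcode line anchors stay accurate.

contents = '# Copyright 2024.\nimport os\n'  # hypothetical file contents
assert contents.startswith('#')  # the old first line must be a comment
# Prefixing on the same physical line turns the old first line into a
# trailing comment; `import os` is still line 2 afterwards.
patched = 'from __future__ import annotations ' + contents
assert patched.splitlines()[1] == 'import os'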
@@ -64,8 +62,10 @@ def _recursive_add_annotations_import():
sys.path.insert(0, os.path.abspath('../'))
sys.path.append(os.path.abspath('ext'))

# pylint: disable=g-import-not-at-top
import optax
from sphinxcontrib import katex
# pylint: enable=g-import-not-at-top

# -- Project information -----------------------------------------------------

@@ -251,13 +251,10 @@ def linkcode_resolve(domain, info):
return None

# TODO(slebedev): support tags after we release an initial version.
path = os.path.relpath(filename, start=os.path.dirname(optax.__file__))
return (
'https://github.com/google-deepmind/optax/tree/main/optax/%s#L%d#L%d'
% (
os.path.relpath(filename, start=os.path.dirname(optax.__file__)),
lineno,
lineno + len(source) - 1,
)
'https://github.com/google-deepmind/optax/tree/main/optax/'
f'{path}#L{lineno}#L{lineno + len(source) - 1}'
)
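The rewrite binds the relative path once and switches from %-formatting to an f-string. With hypothetical inputs, the link it builds looks like this:

path, lineno, source = '_src/alias.py', 100, ['l1', 'l2', 'l3']  # invented values
url = (
    'https://github.com/google-deepmind/optax/tree/main/optax/'
    f'{path}#L{lineno}#L{lineno + len(source) - 1}'
)
# url == 'https://github.com/google-deepmind/optax/tree/main/optax/_src/alias.py#L100#L102'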


12 changes: 9 additions & 3 deletions docs/ext/coverage_check.py
@@ -65,11 +65,11 @@ class OptaxCoverageCheck(builders.Builder):
def get_outdated_docs(self) -> str:
return "coverage_check"

def write(self, *ignored: Any) -> None:
def write(self, *ignored: Any) -> None: # pylint: disable=overridden-final-method
pass

def finish(self) -> None:
documented_objects = frozenset(self.env.domaindata["py"]["objects"])
documented_objects = frozenset(self.env.domaindata["py"]["objects"]) # pytype: disable=attribute-error
undocumented_objects = set(optax_public_symbols()) - documented_objects
if undocumented_objects:
undocumented_objects = tuple(sorted(undocumented_objects))
@@ -78,7 +78,13 @@ def finish(self) -> None:
"forget to add an entry to `api.rst`?\n"
f"Undocumented symbols: {undocumented_objects}")

def get_target_uri(self, docname, typ=None):
raise NotImplementedError

def write_doc(self, docname, doctree):
raise NotImplementedError


def setup(app: application.Sphinx) -> Mapping[str, Any]:
app.add_builder(OptaxCoverageCheck)
return dict(version=optax.__version__, parallel_read_safe=True)
return {"version": optax.__version__, "parallel_read_safe": True}
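Two notes on this hunk: `get_target_uri` and `write_doc` are part of Sphinx's `Builder` interface that concrete builders are expected to implement, so a coverage-only builder that never emits documents can safely stub them with `NotImplementedError`; and the final change merely swaps a `dict(...)` call for the equivalent literal, which linters generally prefer:

assert {'version': '0.2.0', 'parallel_read_safe': True} == dict(
    version='0.2.0', parallel_read_safe=True
)  # same mapping; only the spelling of the constructor changed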
23 changes: 12 additions & 11 deletions optax/__init__.py
@@ -17,6 +17,8 @@
# pylint: disable=wrong-import-position
# pylint: disable=g-importing-member

import typing as _typing

from optax import assignment
from optax import contrib
from optax import losses
@@ -141,7 +143,14 @@
from optax._src.utils import scale_gradient
from optax._src.utils import value_and_grad_from_state

# TODO(mtthss): remove tree_utils aliases after updates.
# TODO(mtthss): remove contrib aliases from flat namespace once users updated.
# Deprecated modules
from optax.contrib import differentially_private_aggregate as _deprecated_differentially_private_aggregate
from optax.contrib import DifferentiallyPrivateAggregateState as _deprecated_DifferentiallyPrivateAggregateState
from optax.contrib import dpsgd as _deprecated_dpsgd


# TODO(mtthss): remove aliases after updates.
adaptive_grad_clip = transforms.adaptive_grad_clip
AdaptiveGradClipState = EmptyState
clip = transforms.clip
@@ -192,7 +201,7 @@
update_moment = tree_utils.tree_update_moment
update_moment_per_elem_norm = tree_utils.tree_update_moment_per_elem_norm

# TODO(mtthss): remove schedules alises from flat namespaces after user updates.
# TODO(mtthss): remove schedules aliases from flat namespaces after user updates
constant_schedule = schedules.constant_schedule
cosine_decay_schedule = schedules.cosine_decay_schedule
cosine_onecycle_schedule = schedules.cosine_onecycle_schedule
@@ -235,13 +244,6 @@
squared_error = losses.squared_error
sigmoid_focal_loss = losses.sigmoid_focal_loss

# pylint: disable=g-import-not-at-top
# TODO(mtthss): remove contrib aliases from flat namespace once users updated.
# Deprecated modules
from optax.contrib import differentially_private_aggregate as _deprecated_differentially_private_aggregate
from optax.contrib import DifferentiallyPrivateAggregateState as _deprecated_DifferentiallyPrivateAggregateState
from optax.contrib import dpsgd as _deprecated_dpsgd

_deprecations = {
# Added Apr 2024
"differentially_private_aggregate": (
@@ -268,9 +270,8 @@
_deprecated_dpsgd,
),
}
# pylint: disable=g-import-not-at-top
# pylint: disable=g-bad-import-order
import typing as _typing

if _typing.TYPE_CHECKING:
# pylint: disable=reimported
from optax.contrib import differentially_private_aggregate
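The `_deprecations` table pairs each legacy flat-namespace name with a warning message and the object it now aliases; a module-level `__getattr__` (PEP 562, installed from `optax._src.deprecations`, whose tail appears later in this diff) serves the alias and warns. A minimal sketch of the pattern, with an invented entry:

import warnings

_deprecations = {
    'dpsgd': ('optax.dpsgd is deprecated: use optax.contrib.dpsgd.',
              'stand-in object'),  # invented message/value for illustration
}

def __getattr__(name):
    # Called only for attributes not found by normal lookup (PEP 562).
    if name in _deprecations:
        message, value = _deprecations[name]
        warnings.warn(message, DeprecationWarning, stacklevel=2)
        return value
    raise AttributeError(f'module {__name__!r} has no attribute {name!r}')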
171 changes: 88 additions & 83 deletions optax/_src/alias_test.py
@@ -46,44 +46,50 @@


_OPTIMIZERS_UNDER_TEST = (
dict(opt_name='sgd', opt_kwargs=dict(learning_rate=1e-3, momentum=0.9)),
dict(opt_name='adadelta', opt_kwargs=dict(learning_rate=0.1)),
dict(opt_name='adafactor', opt_kwargs=dict(learning_rate=5e-3)),
dict(opt_name='adagrad', opt_kwargs=dict(learning_rate=1.0)),
dict(opt_name='adam', opt_kwargs=dict(learning_rate=1e-1)),
dict(opt_name='adamw', opt_kwargs=dict(learning_rate=1e-1)),
dict(opt_name='adamax', opt_kwargs=dict(learning_rate=1e-1)),
dict(opt_name='adamaxw', opt_kwargs=dict(learning_rate=1e-1)),
dict(opt_name='adan', opt_kwargs=dict(learning_rate=1e-1)),
dict(opt_name='amsgrad', opt_kwargs=dict(learning_rate=1e-1)),
dict(opt_name='lars', opt_kwargs=dict(learning_rate=1.0)),
dict(opt_name='lamb', opt_kwargs=dict(learning_rate=1e-3)),
dict(
opt_name='lion',
opt_kwargs=dict(learning_rate=1e-2, weight_decay=1e-4),
),
dict(opt_name='nadam', opt_kwargs=dict(learning_rate=1e-2)),
dict(opt_name='nadamw', opt_kwargs=dict(learning_rate=1e-2)),
dict(opt_name='noisy_sgd', opt_kwargs=dict(learning_rate=1e-3, eta=1e-4)),
dict(opt_name='novograd', opt_kwargs=dict(learning_rate=1e-3)),
dict(
opt_name='optimistic_gradient_descent',
opt_kwargs=dict(learning_rate=2e-3, alpha=0.7, beta=0.1),
),
dict(
opt_name='optimistic_adam',
opt_kwargs=dict(learning_rate=2e-3),
),
dict(opt_name='rmsprop', opt_kwargs=dict(learning_rate=5e-3)),
dict(opt_name='rmsprop', opt_kwargs=dict(learning_rate=5e-3, momentum=0.9)),
dict(opt_name='sign_sgd', opt_kwargs=dict(learning_rate=1e-1)),
dict(opt_name='fromage', opt_kwargs=dict(learning_rate=5e-3)),
dict(opt_name='adabelief', opt_kwargs=dict(learning_rate=1e-2)),
dict(opt_name='radam', opt_kwargs=dict(learning_rate=5e-3)),
dict(opt_name='rprop', opt_kwargs=dict(learning_rate=1e-1)),
dict(opt_name='sm3', opt_kwargs=dict(learning_rate=1.0)),
dict(opt_name='yogi', opt_kwargs=dict(learning_rate=1e-1)),
dict(opt_name='polyak_sgd', opt_kwargs=dict(max_learning_rate=1.0)),
{'opt_name': 'sgd', 'opt_kwargs': {'learning_rate': 1e-3, 'momentum': 0.9}},
{'opt_name': 'adadelta', 'opt_kwargs': {'learning_rate': 0.1}},
{'opt_name': 'adafactor', 'opt_kwargs': {'learning_rate': 5e-3}},
{'opt_name': 'adagrad', 'opt_kwargs': {'learning_rate': 1.0}},
{'opt_name': 'adam', 'opt_kwargs': {'learning_rate': 1e-1}},
{'opt_name': 'adamw', 'opt_kwargs': {'learning_rate': 1e-1}},
{'opt_name': 'adamax', 'opt_kwargs': {'learning_rate': 1e-1}},
{'opt_name': 'adamaxw', 'opt_kwargs': {'learning_rate': 1e-1}},
{'opt_name': 'adan', 'opt_kwargs': {'learning_rate': 1e-1}},
{'opt_name': 'amsgrad', 'opt_kwargs': {'learning_rate': 1e-1}},
{'opt_name': 'lars', 'opt_kwargs': {'learning_rate': 1.0}},
{'opt_name': 'lamb', 'opt_kwargs': {'learning_rate': 1e-3}},
{
'opt_name': 'lion',
'opt_kwargs': {'learning_rate': 1e-2, 'weight_decay': 1e-4},
},
{'opt_name': 'nadam', 'opt_kwargs': {'learning_rate': 1e-2}},
{'opt_name': 'nadamw', 'opt_kwargs': {'learning_rate': 1e-2}},
{
'opt_name': 'noisy_sgd',
'opt_kwargs': {'learning_rate': 1e-3, 'eta': 1e-4},
},
{'opt_name': 'novograd', 'opt_kwargs': {'learning_rate': 1e-3}},
{
'opt_name': 'optimistic_gradient_descent',
'opt_kwargs': {'learning_rate': 2e-3, 'alpha': 0.7, 'beta': 0.1},
},
{
'opt_name': 'optimistic_adam',
'opt_kwargs': {'learning_rate': 2e-3},
},
{'opt_name': 'rmsprop', 'opt_kwargs': {'learning_rate': 5e-3}},
{
'opt_name': 'rmsprop',
'opt_kwargs': {'learning_rate': 5e-3, 'momentum': 0.9},
},
{'opt_name': 'sign_sgd', 'opt_kwargs': {'learning_rate': 1e-1}},
{'opt_name': 'fromage', 'opt_kwargs': {'learning_rate': 5e-3}},
{'opt_name': 'adabelief', 'opt_kwargs': {'learning_rate': 1e-2}},
{'opt_name': 'radam', 'opt_kwargs': {'learning_rate': 5e-3}},
{'opt_name': 'rprop', 'opt_kwargs': {'learning_rate': 1e-1}},
{'opt_name': 'sm3', 'opt_kwargs': {'learning_rate': 1.0}},
{'opt_name': 'yogi', 'opt_kwargs': {'learning_rate': 1e-1}},
{'opt_name': 'polyak_sgd', 'opt_kwargs': {'max_learning_rate': 1.0}},
)
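Each entry names a factory in the flat `optax` namespace plus its constructor kwargs, so a parameterized test can instantiate every optimizer uniformly. A sketch of that consumption (the loop below is illustrative, not the file's actual test body):

import jax.numpy as jnp
import optax

params = {'w': jnp.zeros(3)}
for spec in _OPTIMIZERS_UNDER_TEST:
    opt = getattr(optax, spec['opt_name'])(**spec['opt_kwargs'])
    state = opt.init(params)  # each factory returns a gradient transformation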


@@ -373,14 +379,13 @@ def _materialize_approx_inv_hessian(
rhos = jnp.roll(rhos, -k, axis=0)

id_mat = jnp.eye(d, d)
# pylint: disable=invalid-name
P = id_mat
p = id_mat
safe_dot = lambda x, y: jnp.dot(x, y, precision=jax.lax.Precision.HIGHEST)

for j in range(m):
V = id_mat - rhos[j] * jnp.outer(dus[j], dws[j])
P = safe_dot(V.T, safe_dot(P, V)) + rhos[j] * jnp.outer(dws[j], dws[j])
# pylint: enable=invalid-name
precond_mat = P
v = id_mat - rhos[j] * jnp.outer(dus[j], dws[j])
p = safe_dot(v.T, safe_dot(p, v)) + rhos[j] * jnp.outer(dws[j], dws[j])
precond_mat = p
return precond_mat
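Renaming aside, the loop is the standard inverse-Hessian recursion from BFGS/L-BFGS: starting from $p = I$, it accumulates, for each stored pair,

$$p \leftarrow v_j^\top\, p\, v_j + \rho_j\, dw_j\, dw_j^\top, \qquad v_j = I - \rho_j\, du_j\, dw_j^\top,$$

where the $dw_j$ are parameter differences, the $du_j$ gradient differences, and the $\rho_j$ the stored curvature reciprocals (conventionally $1/(du_j^\top dw_j)$). Materializing the full $d \times d$ matrix this way is only practical at test-scale dimensions.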


@@ -524,44 +529,44 @@ def zakharov(x, xnp):
answer = sum1 + sum2**2 + sum2**4
return answer
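For reference, this is the tail of the standard Zakharov benchmark, assuming `sum1` and `sum2` are the usual partial sums:

$$f(x) = \sum_{i=1}^{d} x_i^2 + \Big(\tfrac{1}{2}\sum_{i=1}^{d} i\,x_i\Big)^2 + \Big(\tfrac{1}{2}\sum_{i=1}^{d} i\,x_i\Big)^4,$$

whose global minimum is $f(0) = 0$, matching the `zakharov` entry in the table below.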

problems = dict(
rosenbrock=dict(
fun=lambda x: rosenbrock(x, jnp),
numpy_fun=lambda x: rosenbrock(x, np),
init=np.zeros(2),
minimum=0.0,
minimizer=np.ones(2),
),
himmelblau=dict(
fun=himmelblau,
numpy_fun=himmelblau,
init=np.ones(2),
minimum=0.0,
problems = {
'rosenbrock': {
'fun': lambda x: rosenbrock(x, jnp),
'numpy_fun': lambda x: rosenbrock(x, np),
'init': np.zeros(2),
'minimum': 0.0,
'minimizer': np.ones(2),
},
'himmelblau': {
'fun': himmelblau,
'numpy_fun': himmelblau,
'init': np.ones(2),
'minimum': 0.0,
# himmelblau actually has multiple minimizers; we simply consider one.
minimizer=np.array([3.0, 2.0]),
),
matyas=dict(
fun=matyas,
numpy_fun=matyas,
init=np.ones(2) * 6.0,
minimum=0.0,
minimizer=np.zeros(2),
),
eggholder=dict(
fun=lambda x: eggholder(x, jnp),
numpy_fun=lambda x: eggholder(x, np),
init=np.ones(2) * 6.0,
minimum=-959.6407,
minimizer=np.array([512.0, 404.22319]),
),
zakharov=dict(
fun=lambda x: zakharov(x, jnp),
numpy_fun=lambda x: zakharov(x, np),
init=np.array([600.0, 700.0, 200.0, 100.0, 90.0, 1e3]),
minimum=0.0,
minimizer=np.zeros(6),
),
)
'minimizer': np.array([3.0, 2.0]),
},
'matyas': {
'fun': matyas,
'numpy_fun': matyas,
'init': np.ones(2) * 6.0,
'minimum': 0.0,
'minimizer': np.zeros(2),
},
'eggholder': {
'fun': lambda x: eggholder(x, jnp),
'numpy_fun': lambda x: eggholder(x, np),
'init': np.ones(2) * 6.0,
'minimum': -959.6407,
'minimizer': np.array([512.0, 404.22319]),
},
'zakharov': {
'fun': lambda x: zakharov(x, jnp),
'numpy_fun': lambda x: zakharov(x, np),
'init': np.array([600.0, 700.0, 200.0, 100.0, 90.0, 1e3]),
'minimum': 0.0,
'minimizer': np.zeros(6),
},
}
return problems[name]
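Each problem bundles a JAX objective, a NumPy reference implementation, a starting point, and the known optimum, so a test can run any optimizer and check the result against `minimum`/`minimizer`. A sketch of that usage, assuming the enclosing accessor is exposed as `get_problem` (its real name is outside this hunk):

import jax
import jax.numpy as jnp
import optax

problem = get_problem('rosenbrock')
params = jnp.asarray(problem['init'])
opt = optax.adam(learning_rate=1e-2)
state = opt.init(params)
for _ in range(1_000):  # iteration count chosen arbitrarily
    grads = jax.grad(problem['fun'])(params)
    updates, state = opt.update(grads, state, params)
    params = optax.apply_updates(params, updates)
# params should now be near problem['minimizer'] on the easier problems.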


@@ -633,11 +638,11 @@ def test_preconditioning_by_lbfgs_on_trees(self, idx: int):
)

flat_dws = [
flatten_util.ravel_pytree(jax.tree.map(lambda dw: dw[i], dws))[0] # pylint: disable=cell-var-from-loop
flatten_util.ravel_pytree(jax.tree.map(lambda dw, i=i: dw[i], dws))[0]
for i in range(m)
]
flat_dus = [
flatten_util.ravel_pytree(jax.tree.map(lambda du: du[i], dus))[0] # pylint: disable=cell-var-from-loop
flatten_util.ravel_pytree(jax.tree.map(lambda du, i=i: du[i], dus))[0]
for i in range(m)
]
flat_dws, flat_dus = jnp.stack(flat_dws), jnp.stack(flat_dus)
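The `i=i` default argument replaces the pylint suppressions with a real fix: a default is evaluated when the lambda is defined, whereas a free variable is looked up when it is called, so without the default every lambda would see the loop variable's final value. A self-contained demonstration:

late = [lambda: i for i in range(3)]
bound = [lambda i=i: i for i in range(3)]
assert [f() for f in late] == [2, 2, 2]   # all closures share the final i
assert [f() for f in bound] == [0, 1, 2]  # each captured its own i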
2 changes: 0 additions & 2 deletions optax/_src/deprecations.py
@@ -54,5 +54,3 @@ def _getattr(name):
raise AttributeError(f"module {module!r} has no attribute {name!r}")

return _getattr


13 changes: 6 additions & 7 deletions optax/_src/factorized.py
@@ -142,13 +142,12 @@ def _init(param):
v_col=jnp.zeros(vc_shape, dtype=dtype),
v=jnp.zeros((1,), dtype=dtype),
)
else:
return _UpdateResult(
update=jnp.zeros((1,), dtype=dtype),
v_row=jnp.zeros((1,), dtype=dtype),
v_col=jnp.zeros((1,), dtype=dtype),
v=jnp.zeros(param.shape, dtype=dtype),
)
return _UpdateResult(
update=jnp.zeros((1,), dtype=dtype),
v_row=jnp.zeros((1,), dtype=dtype),
v_col=jnp.zeros((1,), dtype=dtype),
v=jnp.zeros(param.shape, dtype=dtype),
)

return _to_state(jnp.zeros([], jnp.int32), jax.tree.map(_init, params))
