From 5e4a792c656a7a314b3d2f4723a7ffee23083e94 Mon Sep 17 00:00:00 2001 From: carlosgmartin Date: Thu, 17 Oct 2024 19:46:35 -0400 Subject: [PATCH 1/3] Changes to test.sh and pyproject.toml. --- .flake8 | 14 ++ .gitignore | 3 + .pylintrc | 400 ------------------------------------------------- docs/conf.py | 19 +-- pyproject.toml | 69 +++++++++ test.sh | 82 +++++----- 6 files changed, 135 insertions(+), 452 deletions(-) create mode 100644 .flake8 delete mode 100644 .pylintrc diff --git a/.flake8 b/.flake8 new file mode 100644 index 000000000..7e9b93fd0 --- /dev/null +++ b/.flake8 @@ -0,0 +1,14 @@ +[flake8] +select = + E9, + F63, + F7, + F82, + E225, + E251, +show-source = true +statistics = true +exclude = + build, + dist, + test_venv diff --git a/.gitignore b/.gitignore index 1c77aaa10..623ca9b8f 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,6 @@ docs/sg_execution_times.rst .idea .vscode +# Running tests +test_venv/ +optax-*.whl diff --git a/.pylintrc b/.pylintrc deleted file mode 100644 index b26aeee4f..000000000 --- a/.pylintrc +++ /dev/null @@ -1,400 +0,0 @@ -# This Pylint rcfile contains a best-effort configuration to uphold the -# best-practices and style described in the Google Python style guide: -# https://google.github.io/styleguide/pyguide.html -# -# Its canonical open-source location is: -# https://google.github.io/styleguide/pylintrc - -[MAIN] - -# Files or directories to be skipped. They should be base names, not paths. -ignore=third_party - -# Files or directories matching the regex patterns are skipped. The regex -# matches against base names, not paths. -ignore-patterns= - -# Pickle collected data for later comparisons. -persistent=no - -# List of plugins (as comma separated values of python modules names) to load, -# usually to register additional checkers. -load-plugins= - -# Use multiple processes to speed up Pylint. -jobs=4 - -# Allow loading of arbitrary C extensions. Extensions are imported into the -# active Python interpreter and may run arbitrary code. -unsafe-load-any-extension=no - - -[MESSAGES CONTROL] - -# Only show warnings with the listed confidence levels. Leave empty to show -# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED -confidence= - -# Enable the message, report, category or checker with the given id(s). You can -# either give multiple identifier separated by comma (,) or put this option -# multiple time (only on the command line, not in the configuration file where -# it should appear only once). See also the "--disable" option for examples. -#enable= - -# Disable the message, report, category or checker with the given id(s). You -# can either give multiple identifiers separated by comma (,) or put this -# option multiple times (only on the command line, not in the configuration -# file where it should appear only once).You can also use "--disable=all" to -# disable everything first and then reenable specific checks. For example, if -# you want to run only the similarities checker, you can use "--disable=all -# --enable=similarities". If you want to run only the classes checker, but have -# no Warning level messages displayed, use"--disable=all --enable=classes -# --disable=W" -disable=R, - abstract-method, - apply-builtin, - arguments-differ, - attribute-defined-outside-init, - backtick, - bad-option-value, - basestring-builtin, - buffer-builtin, - c-extension-no-member, - consider-using-enumerate, - cmp-builtin, - cmp-method, - coerce-builtin, - coerce-method, - delslice-method, - div-method, - eq-without-hash, - execfile-builtin, - file-builtin, - filter-builtin-not-iterating, - fixme, - getslice-method, - global-statement, - hex-method, - idiv-method, - implicit-str-concat, - import-error, - import-self, - import-star-module-level, - input-builtin, - intern-builtin, - invalid-str-codec, - locally-disabled, - long-builtin, - long-suffix, - map-builtin-not-iterating, - misplaced-comparison-constant, - missing-function-docstring, - metaclass-assignment, - next-method-called, - next-method-defined, - no-absolute-import, - no-init, # added - no-member, - no-name-in-module, - no-self-use, - nonzero-method, - oct-method, - old-division, - old-ne-operator, - old-octal-literal, - old-raise-syntax, - parameter-unpacking, - print-statement, - raising-string, - range-builtin-not-iterating, - raw_input-builtin, - rdiv-method, - reduce-builtin, - relative-import, - reload-builtin, - round-builtin, - setslice-method, - signature-differs, - standarderror-builtin, - suppressed-message, - sys-max-int, - trailing-newlines, - unichr-builtin, - unicode-builtin, - unnecessary-pass, - unpacking-in-except, - useless-else-on-loop, - useless-suppression, - using-cmp-argument, - wrong-import-order, - xrange-builtin, - zip-builtin-not-iterating, - - -[REPORTS] - -# Set the output format. Available formats are text, parseable, colorized, msvs -# (visual studio) and html. You can also give a reporter class, eg -# mypackage.mymodule.MyReporterClass. -output-format=text - -# Tells whether to display a full report or only the messages -reports=no - -# Python expression which should return a note less than 10 (10 is the highest -# note). You have access to the variables errors warning, statement which -# respectively contain the number of errors / warnings messages and the total -# number of statements analyzed. This is used by the global evaluation report -# (RP0004). -evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) - -# Template used to display messages. This is a python new-style format string -# used to format the message information. See doc for all details -#msg-template= - - -[BASIC] - -# Good variable names which should always be accepted, separated by a comma -good-names=main,_ - -# Bad variable names which should always be refused, separated by a comma -bad-names= - -# Colon-delimited sets of names that determine each other's naming style when -# the name regexes allow several styles. -name-group= - -# Include a hint for the correct naming format with invalid-name -include-naming-hint=no - -# List of decorators that produce properties, such as abc.abstractproperty. Add -# to this list to register other decorators that produce valid properties. -property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl - -# Regular expression matching correct function names -function-rgx=^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ - -# Regular expression matching correct variable names -variable-rgx=^[a-z][a-z0-9_]*$ - -# Regular expression matching correct constant names -const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ - -# Regular expression matching correct attribute names -attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ - -# Regular expression matching correct argument names -argument-rgx=^[a-z][a-z0-9_]*$ - -# Regular expression matching correct class attribute names -class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ - -# Regular expression matching correct inline iteration names -inlinevar-rgx=^[a-z][a-z0-9_]*$ - -# Regular expression matching correct class names -class-rgx=^_?[A-Z][a-zA-Z0-9]*$ - -# Regular expression matching correct module names -module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$ - -# Regular expression matching correct method names -method-rgx=(?x)^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ - -# Regular expression which should only match function or class names that do -# not require a docstring. -no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$ - -# Minimum line length for functions/classes that require docstrings, shorter -# ones are exempt. -docstring-min-length=12 - - -[TYPECHECK] - -# List of decorators that produce context managers, such as -# contextlib.contextmanager. Add to this list to register other decorators that -# produce valid context managers. -contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager - -# List of module names for which member attributes should not be checked -# (useful for modules/projects where namespaces are manipulated during runtime -# and thus existing member attributes cannot be deduced by static analysis. It -# supports qualified module names, as well as Unix pattern matching. -ignored-modules= - -# List of class names for which member attributes should not be checked (useful -# for classes with dynamically set attributes). This supports the use of -# qualified names. -ignored-classes=optparse.Values,thread._local,_thread._local - -# List of members which are set dynamically and missed by pylint inference -# system, and so shouldn't trigger E1101 when accessed. Python regular -# expressions are accepted. -generated-members= - - -[FORMAT] - -# Maximum number of characters on a single line. -max-line-length=80 - -# TODO(https://github.com/pylint-dev/pylint/issues/3352): Direct pylint to exempt -# lines made too long by directives to pytype. - -# Regexp for a line that is allowed to be longer than the limit. -ignore-long-lines=(?x)( - ^\s*(\#\ )??$| - ^\s*(from\s+\S+\s+)?import\s+.+$) - -# Allow the body of an if to be on the same line as the test if there is no -# else. -single-line-if-stmt=yes - -# Maximum number of lines in a module -max-module-lines=99999 - -# String used as indentation unit. The internal Google style guide mandates 2 -# spaces. Google's externaly-published style guide says 4, consistent with -# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google -# projects (like TensorFlow). -indent-string=' ' - -# Number of spaces of indent required inside a hanging or continued line. -indent-after-paren=4 - -# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. -expected-line-ending-format= - - -[MISCELLANEOUS] - -# List of note tags to take in consideration, separated by a comma. -notes=TODO - - -[STRING] - -# This flag controls whether inconsistent-quotes generates a warning when the -# character used as a quote delimiter is used inconsistently within a module. -check-quote-consistency=yes - - -[VARIABLES] - -# Tells whether we should check for unused import in __init__ files. -init-import=no - -# A regular expression matching the name of dummy variables (i.e. expectedly -# not used). -dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) - -# List of additional names supposed to be defined in builtins. Remember that -# you should avoid to define new builtins when possible. -additional-builtins= - -# List of strings which can identify a callback function by name. A callback -# name must start or end with one of those strings. -callbacks=cb_,_cb - -# List of qualified module names which can have objects that can redefine -# builtins. -redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools - - -[LOGGING] - -# Logging modules to check that the string format arguments are in logging -# function parameter format -logging-modules=logging,absl.logging,tensorflow.io.logging - - -[SIMILARITIES] - -# Minimum lines number of a similarity. -min-similarity-lines=4 - -# Ignore comments when computing similarities. -ignore-comments=yes - -# Ignore docstrings when computing similarities. -ignore-docstrings=yes - -# Ignore imports when computing similarities. -ignore-imports=no - - -[SPELLING] - -# Spelling dictionary name. Available dictionaries: none. To make it working -# install python-enchant package. -spelling-dict= - -# List of comma separated words that should not be checked. -spelling-ignore-words= - -# A path to a file that contains private dictionary; one word per line. -spelling-private-dict-file= - -# Tells whether to store unknown words to indicated private dictionary in -# --spelling-private-dict-file option instead of raising a message. -spelling-store-unknown-words=no - - -[IMPORTS] - -# Deprecated modules which should not be used, separated by a comma -deprecated-modules=regsub, - TERMIOS, - Bastion, - rexec, - sets - -# Create a graph of every (i.e. internal and external) dependencies in the -# given file (report RP0402 must not be disabled) -import-graph= - -# Create a graph of external dependencies in the given file (report RP0402 must -# not be disabled) -ext-import-graph= - -# Create a graph of internal dependencies in the given file (report RP0402 must -# not be disabled) -int-import-graph= - -# Force import order to recognize a module as part of the standard -# compatibility libraries. -known-standard-library= - -# Force import order to recognize a module as part of a third party library. -known-third-party=enchant, absl - -# Analyse import fallback blocks. This can be used to support both Python 2 and -# 3 compatible code, which means that the block might have code that exists -# only in one or another interpreter, leading to false positives when analysed. -analyse-fallback-blocks=no - - -[CLASSES] - -# List of method names used to declare (i.e. assign) instance attributes. -defining-attr-methods=__init__, - __new__, - setUp - -# List of member names, which should be excluded from the protected access -# warning. -exclude-protected=_asdict, - _fields, - _replace, - _source, - _make - -# List of valid names for the first argument in a class method. -valid-classmethod-first-arg=cls, - class_ - -# List of valid names for the first argument in a metaclass class method. -valid-metaclass-classmethod-first-arg=mcs -disable=unnecessary-lambda-assignment,no-value-for-parameter,use-dict-literal diff --git a/docs/conf.py b/docs/conf.py index ed57f8484..7c485d2c0 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -30,10 +30,13 @@ import os import sys +import optax +from sphinxcontrib import katex + def _add_annotations_import(path): """Appends a future annotations import to the file at the given path.""" - with open(path) as f: + with open(path, encoding='utf-8') as f: contents = f.read() if contents.startswith('from __future__ import annotations'): # If we run sphinx multiple times then we will append the future import @@ -41,7 +44,7 @@ def _add_annotations_import(path): return assert contents.startswith('#'), (path, contents.split('\n')[0]) - with open(path, 'w') as f: + with open(path, 'w', encoding='utf-8') as f: # NOTE: This is subtle and not unit tested, we're prefixing the first line # in each Python file with this future import. It is important to prefix # not insert a newline such that source code locations are accurate (we link @@ -64,9 +67,6 @@ def _recursive_add_annotations_import(): sys.path.insert(0, os.path.abspath('../')) sys.path.append(os.path.abspath('ext')) -import optax -from sphinxcontrib import katex - # -- Project information ----------------------------------------------------- project = 'Optax' @@ -251,13 +251,10 @@ def linkcode_resolve(domain, info): return None # TODO(slebedev): support tags after we release an initial version. + path = os.path.relpath(filename, start=os.path.dirname(optax.__file__)) return ( - 'https://github.com/google-deepmind/optax/tree/main/optax/%s#L%d#L%d' - % ( - os.path.relpath(filename, start=os.path.dirname(optax.__file__)), - lineno, - lineno + len(source) - 1, - ) + 'https://github.com/google-deepmind/optax/tree/main/optax/' + f'{path}#L{lineno}#L{lineno + len(source) - 1}' ) diff --git a/pyproject.toml b/pyproject.toml index 7bccb6db3..40ddd826b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,6 +84,75 @@ dp-accounting = [ "scipy>=1.7.1" ] +dev = [ + "flake8", + "pylint", + "pytype", + "pytest", + "pytest-xdist", + "typing_extensions", +] + [tool.setuptools.packages.find] include = ["README.md", "LICENSE"] exclude = ["*_test.py"] + +[tool.pytype] +inputs = ["optax"] +disable = ["import-error"] +keep_going = true + +[tool.pytest.ini_options] +addopts = "--pyargs optax --numprocesses auto" + +[tool.pylint.main] +ignore-patterns = [ + ".*.pyi", +] +ignore = [ + "test_venv", +] + +[tool.pylint.messages_control] +disable = [ + "bad-indentation", + "invalid-name", + "missing-class-docstring", + "missing-function-docstring", + "unnecessary-lambda-assignment", + "no-member", + "use-dict-literal", + "too-many-locals", + "too-many-arguments", + "too-many-positional-arguments", + "unused-argument", + "unused-import", + "line-too-long", + "no-value-for-parameter", + "wrong-import-order", + "no-else-return", + "too-many-lines", + "missing-module-docstring", + "too-few-public-methods", + "abstract-method", + "used-before-assignment", + "too-many-statements", + "protected-access", + "too-many-instance-attributes", + "inconsistent-return-statements", + "trailing-newlines", + "consider-using-generator", + "no-else-raise", + "unknown-option-value", + "not-callable", + "consider-merging-isinstance", + "duplicate-code", + "import-error", + "overridden-final-method", +] + +[tool.pylint.similarities] +ignore-comments = true +ignore-docstrings = true +ignore-imports = true +min-similarity-lines = 4 diff --git a/test.sh b/test.sh index e6d258535..1134ada3a 100755 --- a/test.sh +++ b/test.sh @@ -13,79 +13,78 @@ # limitations under the License. # ============================================================================== +set -o errexit +set -o nounset +set -o pipefail + function cleanup { deactivate - rm -r "${TEMP_DIR}" } trap cleanup EXIT -REPO_DIR=$(pwd) -TEMP_DIR=$(mktemp --directory) +echo "Deleting test environment (if it exists)" +rm -rf test_venv -set -o errexit -set -o nounset -set -o pipefail +echo "Creating test environment" +python3 -m venv test_venv -# Install deps in a virtual env. -python3 -m venv "${TEMP_DIR}/test_venv" -source "${TEMP_DIR}/test_venv/bin/activate" +echo "Activating test environment" +source test_venv/bin/activate -# Install dependencies. -python3 -m pip install --quiet --upgrade pip setuptools wheel -python3 -m pip install --quiet --upgrade flake8 pytest-xdist pylint pylint-exit -python3 -m pip install --quiet --editable ".[test, examples]" +for requirement in "pip" "setuptools" "wheel" "build" ".[test]" ".[examples]" ".[dev]" ".[docs]" +do + echo "Installing" $requirement + python3 -m pip install -qU $requirement +done # Dp-accounting specifies exact minor versions as requirements which sometimes # become incompatible with other libraries optax needs. We therefore install # dependencies for dp-accounting manually. # TODO(b/239416992): Remove this workaround if dp-accounting switches to minimum # version requirements. -python3 -m pip install --quiet --editable ".[dp-accounting]" -python3 -m pip install --quiet --no-deps "dp-accounting>=0.1.1" +echo "Installing .[dp-accounting]" +python3 -m pip install -qU --editable ".[dp-accounting]" +python3 -m pip install -qU --no-deps "dp-accounting>=0.1.1" +echo "Installing requested JAX version" # Install the requested JAX version if [ -z "${JAX_VERSION-}" ]; then : # use version installed in requirements above elif [ "$JAX_VERSION" = "newest" ]; then - python3 -m pip install --quiet --upgrade jax jaxlib + python3 -m pip install -qU jax jaxlib else - python3 -m pip install --quiet "jax==${JAX_VERSION}" "jaxlib==${JAX_VERSION}" + python3 -m pip install -qU "jax==${JAX_VERSION}" "jaxlib==${JAX_VERSION}" fi # Ensure optax was not installed by one of the dependencies above, # since if it is, the tests below will be run against that version instead of # the branch build. -python3 -m pip uninstall --quiet --yes optax +echo "Uninstalling optax (if already installed)" +python3 -m pip uninstall -q --yes optax -# Lint with flake8. -python3 -m flake8 --select=E9,F63,F7,F82,E225,E251 --show-source --statistics +echo "Linting with flake8" +flake8 -# Lint with pylint. -PYLINT_ARGS="-efail -wfail -cfail -rfail" -# Append specific config lines. -# Lint modules and tests separately. -python3 -m pylint --rcfile=.pylintrc $(find optax -name '*.py' | grep -v 'test.py' | xargs) -d E1102 || pylint-exit $PYLINT_ARGS $? -# Disable protected-access warnings for tests. -python3 -m pylint --rcfile=.pylintrc $(find optax -name '*_test.py' | xargs) -d W0212,E1102 || pylint-exit $PYLINT_ARGS $? +echo "Linting with pylint" +pylint . --rcfile pyproject.toml -# Build the package. -python3 -m pip install --quiet build +echo "Building the package" python3 -m build -python3 -m pip wheel --no-deps dist/optax-*.tar.gz --wheel-dir "${TEMP_DIR}" -python3 -m pip install --quiet "${TEMP_DIR}/optax-"*.whl -# Check types with pytype. -python3 -m pip install --quiet pytype -pytype "optax" --keep-going --disable import-error +echo "Building wheel" +python3 -m pip wheel --no-deps dist/optax-*.tar.gz + +echo "Installing the wheel" +python3 -m pip install -qU optax-*.whl + +echo "Checking types with pytype" +pytype -# Run tests using pytest. -# Change directory to avoid importing the package from repo root. -cd "${TEMP_DIR}" -python3 -m pytest --numprocesses auto --pyargs optax -cd "${REPO_DIR}" +echo "Running tests with pytest" +pytest +echo "Building sphinx docs" # Build Sphinx docs. -python3 -m pip install --quiet --editable ".[docs]" # NOTE(vroulet) We have dependencies issues: # tensorflow > 2.13.1 requires ml-dtypes <= 0.3.2 # but jax requires ml-dtypes >= 0.4.0 @@ -95,9 +94,10 @@ python3 -m pip install --quiet --editable ".[docs]" # bug (which issues conflict warnings but runs fine). # A long term solution is probably to fully remove tensorflow from our # dependencies. -python3 -m pip install --upgrade --verbose typing_extensions cd docs +echo "make html" make html +echo "make doctest" make doctest # run doctests cd .. From 15299571c86899fb8ab81e315d11ef9268511dab Mon Sep 17 00:00:00 2001 From: carlosgmartin Date: Thu, 24 Oct 2024 14:53:49 -0400 Subject: [PATCH 2/3] Re-enable several pylint messages after editing repo. --- docs/conf.py | 9 +- docs/ext/coverage_check.py | 12 +- optax/__init__.py | 23 +- optax/_src/alias_test.py | 186 ++++++------ optax/_src/base.py | 7 +- optax/_src/deprecations.py | 2 - optax/_src/factorized.py | 13 +- optax/_src/factorized_test.py | 6 +- optax/_src/float64_test.py | 34 +-- optax/_src/linear_algebra.py | 3 +- optax/_src/linear_algebra_test.py | 15 +- optax/_src/linesearch_test.py | 283 +++++++++--------- optax/_src/numerics_test.py | 7 +- optax/_src/transform.py | 35 +-- optax/_src/transform_test.py | 6 +- optax/_src/update_test.py | 10 +- optax/_src/utils.py | 13 +- optax/_src/utils_test.py | 2 +- optax/assignment/__init__.py | 2 - optax/assignment/_hungarian_algorithm.py | 3 +- optax/assignment/_hungarian_algorithm_test.py | 3 +- optax/contrib/__init__.py | 2 - optax/contrib/_common_test.py | 135 ++++----- optax/contrib/_complex_valued.py | 6 +- optax/contrib/_reduce_on_plateau.py | 4 +- optax/contrib/_sam_test.py | 8 +- optax/contrib/_schedule_free_test.py | 6 +- optax/losses/__init__.py | 2 - optax/losses/_classification.py | 7 +- optax/losses/_classification_test.py | 219 +++++++------- optax/losses/_ranking_test.py | 6 +- optax/losses/_regression_test.py | 10 +- optax/monte_carlo/control_variates_test.py | 244 ++++++--------- .../stochastic_gradient_estimators.py | 3 +- .../stochastic_gradient_estimators_test.py | 139 ++++----- optax/perturbations/__init__.py | 3 - optax/perturbations/_make_pert.py | 12 +- optax/perturbations/_make_pert_test.py | 19 +- optax/projections/_projections_test.py | 16 +- optax/schedules/_inject_test.py | 4 +- optax/schedules/_schedule_test.py | 16 +- optax/second_order/_hessian.py | 6 +- optax/transforms/_accumulation.py | 12 +- optax/transforms/_accumulation_test.py | 28 +- optax/transforms/_adding.py | 2 +- optax/transforms/_adding_test.py | 36 +-- optax/transforms/_clipping_test.py | 3 +- optax/transforms/_masking.py | 5 +- optax/transforms/_masking_test.py | 16 +- optax/tree_utils/__init__.py | 2 - optax/tree_utils/_casting.py | 33 +- optax/tree_utils/_casting_test.py | 8 +- optax/tree_utils/_random_test.py | 11 +- optax/tree_utils/_state_utils.py | 99 +++--- optax/tree_utils/_state_utils_test.py | 69 +++-- optax/tree_utils/_tree_math.py | 17 +- optax/tree_utils/_tree_math_test.py | 48 +-- pyproject.toml | 20 +- 58 files changed, 939 insertions(+), 1011 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 7c485d2c0..bd923ece2 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -24,14 +24,14 @@ # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -# pylint: disable=g-bad-import-order -# pylint: disable=g-import-not-at-top +# pylint: disable=invalid-name + import inspect import os import sys -import optax from sphinxcontrib import katex +import optax def _add_annotations_import(path): @@ -237,8 +237,7 @@ def linkcode_resolve(domain, info): obj = getattr(obj, attr) except AttributeError: return None - else: - obj = inspect.unwrap(obj) + obj = inspect.unwrap(obj) try: filename = inspect.getsourcefile(obj) diff --git a/docs/ext/coverage_check.py b/docs/ext/coverage_check.py index 0f47e8bea..301d5cb7d 100644 --- a/docs/ext/coverage_check.py +++ b/docs/ext/coverage_check.py @@ -19,10 +19,10 @@ import types from typing import Any, Sequence, Tuple -import optax from sphinx import application from sphinx import builders from sphinx import errors +import optax def find_internal_python_modules( @@ -65,7 +65,7 @@ class OptaxCoverageCheck(builders.Builder): def get_outdated_docs(self) -> str: return "coverage_check" - def write(self, *ignored: Any) -> None: + def write(self, *ignored: Any) -> None: # pylint: disable=overridden-final-method pass def finish(self) -> None: @@ -78,7 +78,13 @@ def finish(self) -> None: "forget to add an entry to `api.rst`?\n" f"Undocumented symbols: {undocumented_objects}") + def get_target_uri(self, docname, typ=None): + raise NotImplementedError + + def write_doc(self, docname, doctree): + raise NotImplementedError + def setup(app: application.Sphinx) -> Mapping[str, Any]: app.add_builder(OptaxCoverageCheck) - return dict(version=optax.__version__, parallel_read_safe=True) + return {"version": optax.__version__, "parallel_read_safe": True} diff --git a/optax/__init__.py b/optax/__init__.py index 7316e017f..2051b8480 100644 --- a/optax/__init__.py +++ b/optax/__init__.py @@ -14,8 +14,7 @@ # ============================================================================== """Optax: composable gradient processing and optimization, in JAX.""" -# pylint: disable=wrong-import-position -# pylint: disable=g-importing-member +import typing as _typing from optax import assignment from optax import contrib @@ -185,6 +184,14 @@ from optax._src.wrappers import skip_large_updates from optax._src.wrappers import skip_not_finite +# TODO(mtthss): remove contrib aliases from flat namespace once users updated. +# Deprecated modules +from optax.contrib import differentially_private_aggregate as \ + _deprecated_differentially_private_aggregate +from optax.contrib import DifferentiallyPrivateAggregateState as \ + _deprecated_DifferentiallyPrivateAggregateState +from optax.contrib import dpsgd as _deprecated_dpsgd + # TODO(mtthss): remove tree_utils aliases after updates. tree_map_params = tree_utils.tree_map_params @@ -236,13 +243,6 @@ squared_error = losses.squared_error sigmoid_focal_loss = losses.sigmoid_focal_loss -# pylint: disable=g-import-not-at-top -# TODO(mtthss): remove contrib aliases from flat namespace once users updated. -# Deprecated modules -from optax.contrib import differentially_private_aggregate as _deprecated_differentially_private_aggregate -from optax.contrib import DifferentiallyPrivateAggregateState as _deprecated_DifferentiallyPrivateAggregateState -from optax.contrib import dpsgd as _deprecated_dpsgd - _deprecations = { # Added Apr 2024 "differentially_private_aggregate": ( @@ -269,8 +269,6 @@ _deprecated_dpsgd, ), } -# pylint: disable=g-bad-import-order -import typing as _typing if _typing.TYPE_CHECKING: # pylint: disable=reimported @@ -285,9 +283,6 @@ __getattr__ = _deprecation_getattr(__name__, _deprecations) del _deprecation_getattr del _typing -# pylint: enable=g-bad-import-order -# pylint: enable=g-import-not-at-top -# pylint: enable=g-importing-member __version__ = "0.2.4.dev" diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py index 0969386ba..22fedf59d 100644 --- a/optax/_src/alias_test.py +++ b/optax/_src/alias_test.py @@ -25,6 +25,10 @@ import jax.numpy as jnp import jax.random as jrd import numpy as np +import scipy.optimize as scipy_optimize +from sklearn import datasets +from sklearn import linear_model + from optax._src import alias from optax._src import base from optax._src import linesearch as _linesearch @@ -35,9 +39,6 @@ from optax.schedules import _inject from optax.transforms import _accumulation import optax.tree_utils as otu -import scipy.optimize as scipy_optimize -from sklearn import datasets -from sklearn import linear_model ############## @@ -46,44 +47,50 @@ _OPTIMIZERS_UNDER_TEST = ( - dict(opt_name='sgd', opt_kwargs=dict(learning_rate=1e-3, momentum=0.9)), - dict(opt_name='adadelta', opt_kwargs=dict(learning_rate=0.1)), - dict(opt_name='adafactor', opt_kwargs=dict(learning_rate=5e-3)), - dict(opt_name='adagrad', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='adam', opt_kwargs=dict(learning_rate=1e-1)), - dict(opt_name='adamw', opt_kwargs=dict(learning_rate=1e-1)), - dict(opt_name='adamax', opt_kwargs=dict(learning_rate=1e-1)), - dict(opt_name='adamaxw', opt_kwargs=dict(learning_rate=1e-1)), - dict(opt_name='adan', opt_kwargs=dict(learning_rate=1e-1)), - dict(opt_name='amsgrad', opt_kwargs=dict(learning_rate=1e-1)), - dict(opt_name='lars', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='lamb', opt_kwargs=dict(learning_rate=1e-3)), - dict( - opt_name='lion', - opt_kwargs=dict(learning_rate=1e-2, weight_decay=1e-4), - ), - dict(opt_name='nadam', opt_kwargs=dict(learning_rate=1e-2)), - dict(opt_name='nadamw', opt_kwargs=dict(learning_rate=1e-2)), - dict(opt_name='noisy_sgd', opt_kwargs=dict(learning_rate=1e-3, eta=1e-4)), - dict(opt_name='novograd', opt_kwargs=dict(learning_rate=1e-3)), - dict( - opt_name='optimistic_gradient_descent', - opt_kwargs=dict(learning_rate=2e-3, alpha=0.7, beta=0.1), - ), - dict( - opt_name='optimistic_adam', - opt_kwargs=dict(learning_rate=2e-3), - ), - dict(opt_name='rmsprop', opt_kwargs=dict(learning_rate=5e-3)), - dict(opt_name='rmsprop', opt_kwargs=dict(learning_rate=5e-3, momentum=0.9)), - dict(opt_name='sign_sgd', opt_kwargs=dict(learning_rate=1e-1)), - dict(opt_name='fromage', opt_kwargs=dict(learning_rate=5e-3)), - dict(opt_name='adabelief', opt_kwargs=dict(learning_rate=1e-2)), - dict(opt_name='radam', opt_kwargs=dict(learning_rate=5e-3)), - dict(opt_name='rprop', opt_kwargs=dict(learning_rate=1e-1)), - dict(opt_name='sm3', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='yogi', opt_kwargs=dict(learning_rate=1e-1)), - dict(opt_name='polyak_sgd', opt_kwargs=dict(max_learning_rate=1.0)), + {"opt_name": 'sgd', "opt_kwargs": {"learning_rate": 1e-3, "momentum": 0.9}}, + {"opt_name": 'adadelta', "opt_kwargs": {"learning_rate": 0.1}}, + {"opt_name": 'adafactor', "opt_kwargs": {"learning_rate": 5e-3}}, + {"opt_name": 'adagrad', "opt_kwargs": {"learning_rate": 1.0}}, + {"opt_name": 'adam', "opt_kwargs": {"learning_rate": 1e-1}}, + {"opt_name": 'adamw', "opt_kwargs": {"learning_rate": 1e-1}}, + {"opt_name": 'adamax', "opt_kwargs": {"learning_rate": 1e-1}}, + {"opt_name": 'adamaxw', "opt_kwargs": {"learning_rate": 1e-1}}, + {"opt_name": 'adan', "opt_kwargs": {"learning_rate": 1e-1}}, + {"opt_name": 'amsgrad', "opt_kwargs": {"learning_rate": 1e-1}}, + {"opt_name": 'lars', "opt_kwargs": {"learning_rate": 1.0}}, + {"opt_name": 'lamb', "opt_kwargs": {"learning_rate": 1e-3}}, + { + "opt_name": 'lion', + "opt_kwargs": {"learning_rate": 1e-2, "weight_decay": 1e-4}, + }, + {"opt_name": 'nadam', "opt_kwargs": {"learning_rate": 1e-2}}, + {"opt_name": 'nadamw', "opt_kwargs": {"learning_rate": 1e-2}}, + { + "opt_name": 'noisy_sgd', + "opt_kwargs": {"learning_rate": 1e-3, "eta": 1e-4}, + }, + {"opt_name": 'novograd', "opt_kwargs": {"learning_rate": 1e-3}}, + { + "opt_name": 'optimistic_gradient_descent', + "opt_kwargs": {"learning_rate": 2e-3, "alpha": 0.7, "beta": 0.1}, + }, + { + "opt_name": 'optimistic_adam', + "opt_kwargs": {"learning_rate": 2e-3}, + }, + {"opt_name": 'rmsprop', "opt_kwargs": {"learning_rate": 5e-3}}, + { + "opt_name": 'rmsprop', + "opt_kwargs": {"learning_rate": 5e-3, "momentum": 0.9}, + }, + {"opt_name": 'sign_sgd', "opt_kwargs": {"learning_rate": 1e-1}}, + {"opt_name": 'fromage', "opt_kwargs": {"learning_rate": 5e-3}}, + {"opt_name": 'adabelief', "opt_kwargs": {"learning_rate": 1e-2}}, + {"opt_name": 'radam', "opt_kwargs": {"learning_rate": 5e-3}}, + {"opt_name": 'rprop', "opt_kwargs": {"learning_rate": 1e-1}}, + {"opt_name": 'sm3', "opt_kwargs": {"learning_rate": 1.0}}, + {"opt_name": 'yogi', "opt_kwargs": {"learning_rate": 1e-1}}, + {"opt_name": 'polyak_sgd', "opt_kwargs": {"max_learning_rate": 1.0}}, ) @@ -271,7 +278,8 @@ def test_preserve_dtype(self, opt_name, opt_kwargs, dtype): dtype = jnp.dtype(dtype) opt_factory = getattr(alias, opt_name) opt = opt_factory(**opt_kwargs) - fun = lambda x: jnp.sum(x**2) + def fun(x): + return jnp.sum(x**2) params = jnp.array([1.0, 2.0], dtype=dtype) grads = jax.grad(fun)(params) @@ -295,7 +303,8 @@ def test_gradient_accumulation(self, opt_name, opt_kwargs, dtype): base_opt = opt_factory(**opt_kwargs) opt = _accumulation.MultiSteps(base_opt, every_k_schedule=4) - fun = lambda x: jnp.sum(x**2) + def fun(x): + return jnp.sum(x**2) params = jnp.array([1.0, 2.0], dtype=dtype) grads = jax.grad(fun)(params) @@ -372,14 +381,13 @@ def _materialize_approx_inv_hessian( rhos = jnp.roll(rhos, -k, axis=0) id_mat = jnp.eye(d, d) - # pylint: disable=invalid-name - P = id_mat - safe_dot = lambda x, y: jnp.dot(x, y, precision=jax.lax.Precision.HIGHEST) + p = id_mat + def safe_dot(x, y): + return jnp.dot(x, y, precision=jax.lax.Precision.HIGHEST) for j in range(m): - V = id_mat - rhos[j] * jnp.outer(dus[j], dws[j]) - P = safe_dot(V.T, safe_dot(P, V)) + rhos[j] * jnp.outer(dws[j], dws[j]) - # pylint: enable=invalid-name - precond_mat = P + v = id_mat - rhos[j] * jnp.outer(dus[j], dws[j]) + p = safe_dot(v.T, safe_dot(p, v)) + rhos[j] * jnp.outer(dws[j], dws[j]) + precond_mat = p return precond_mat @@ -520,44 +528,44 @@ def zakharov(x, xnp): answer = sum1 + sum2**2 + sum2**4 return answer - problems = dict( - rosenbrock=dict( - fun=lambda x: rosenbrock(x, jnp), - numpy_fun=lambda x: rosenbrock(x, np), - init=np.zeros(2), - minimum=0.0, - minimizer=np.ones(2), - ), - himmelblau=dict( - fun=himmelblau, - numpy_fun=himmelblau, - init=np.ones(2), - minimum=0.0, + problems = { + "rosenbrock": { + "fun": lambda x: rosenbrock(x, jnp), + "numpy_fun": lambda x: rosenbrock(x, np), + "init": np.zeros(2), + "minimum": 0.0, + "minimizer": np.ones(2), + }, + "himmelblau": { + "fun": himmelblau, + "numpy_fun": himmelblau, + "init": np.ones(2), + "minimum": 0.0, # himmelblau has actually multiple minimizers, we simply consider one. - minimizer=np.array([3.0, 2.0]), - ), - matyas=dict( - fun=matyas, - numpy_fun=matyas, - init=np.ones(2) * 6.0, - minimum=0.0, - minimizer=np.zeros(2), - ), - eggholder=dict( - fun=lambda x: eggholder(x, jnp), - numpy_fun=lambda x: eggholder(x, np), - init=np.ones(2) * 6.0, - minimum=-959.6407, - minimizer=np.array([512.0, 404.22319]), - ), - zakharov=dict( - fun=lambda x: zakharov(x, jnp), - numpy_fun=lambda x: zakharov(x, np), - init=np.array([600.0, 700.0, 200.0, 100.0, 90.0, 1e3]), - minimum=0.0, - minimizer=np.zeros(6), - ), - ) + "minimizer": np.array([3.0, 2.0]), + }, + "matyas": { + "fun": matyas, + "numpy_fun": matyas, + "init": np.ones(2) * 6.0, + "minimum": 0.0, + "minimizer": np.zeros(2), + }, + "eggholder": { + "fun": lambda x: eggholder(x, jnp), + "numpy_fun": lambda x: eggholder(x, np), + "init": np.ones(2) * 6.0, + "minimum": -959.6407, + "minimizer": np.array([512.0, 404.22319]), + }, + "zakharov": { + "fun": lambda x: zakharov(x, jnp), + "numpy_fun": lambda x: zakharov(x, np), + "init": np.array([600.0, 700.0, 200.0, 100.0, 90.0, 1e3]), + "minimum": 0.0, + "minimizer": np.zeros(6), + }, + } return problems[name] @@ -629,11 +637,11 @@ def test_preconditioning_by_lbfgs_on_trees(self, idx: int): ) flat_dws = [ - flatten_util.ravel_pytree(jax.tree.map(lambda dw: dw[i], dws))[0] # pylint: disable=cell-var-from-loop + flatten_util.ravel_pytree(jax.tree.map(lambda dw, i=i: dw[i], dws))[0] for i in range(m) ] flat_dus = [ - flatten_util.ravel_pytree(jax.tree.map(lambda du: du[i], dus))[0] # pylint: disable=cell-var-from-loop + flatten_util.ravel_pytree(jax.tree.map(lambda du, i=i: du[i], dus))[0] for i in range(m) ] flat_dws, flat_dus = jnp.stack(flat_dws), jnp.stack(flat_dus) diff --git a/optax/_src/base.py b/optax/_src/base.py index ddac5a53b..8934c3bd5 100644 --- a/optax/_src/base.py +++ b/optax/_src/base.py @@ -313,9 +313,10 @@ def update_fn(updates, state, params=None): del state if params is not None: return jax.tree.map(f, updates, params), EmptyState() - else: - f_ = lambda u: f(u, None) - return jax.tree.map(f_, updates), EmptyState() + + def f_(u): + return f(u, None) + return jax.tree.map(f_, updates), EmptyState() return GradientTransformation(init_empty_state, update_fn) diff --git a/optax/_src/deprecations.py b/optax/_src/deprecations.py index 9b49876f1..9d5c7031a 100644 --- a/optax/_src/deprecations.py +++ b/optax/_src/deprecations.py @@ -54,5 +54,3 @@ def _getattr(name): raise AttributeError(f"module {module!r} has no attribute {name!r}") return _getattr - - diff --git a/optax/_src/factorized.py b/optax/_src/factorized.py index 6a513f298..63b0bbf30 100644 --- a/optax/_src/factorized.py +++ b/optax/_src/factorized.py @@ -142,13 +142,12 @@ def _init(param): v_col=jnp.zeros(vc_shape, dtype=dtype), v=jnp.zeros((1,), dtype=dtype), ) - else: - return _UpdateResult( - update=jnp.zeros((1,), dtype=dtype), - v_row=jnp.zeros((1,), dtype=dtype), - v_col=jnp.zeros((1,), dtype=dtype), - v=jnp.zeros(param.shape, dtype=dtype), - ) + return _UpdateResult( + update=jnp.zeros((1,), dtype=dtype), + v_row=jnp.zeros((1,), dtype=dtype), + v_col=jnp.zeros((1,), dtype=dtype), + v=jnp.zeros(param.shape, dtype=dtype), + ) return _to_state(jnp.zeros([], jnp.int32), jax.tree.map(_init, params)) diff --git a/optax/_src/factorized_test.py b/optax/_src/factorized_test.py index 1f57756be..2ddfd5e1d 100644 --- a/optax/_src/factorized_test.py +++ b/optax/_src/factorized_test.py @@ -53,7 +53,8 @@ def test_preserve_dtype(self, factorized_dims: bool, dtype: str): """Test that the optimizer returns updates of same dtype as params.""" dtype = jnp.dtype(dtype) opt = factorized.scale_by_factored_rms() - fun = lambda x: jnp.sum(x**2) + def fun(x): + return jnp.sum(x**2) if factorized_dims: # The updates are factored only for large enough parameters @@ -77,7 +78,8 @@ def test_gradient_accumulation(self, factorized_dims, dtype): base_opt = factorized.scale_by_factored_rms() opt = _accumulation.MultiSteps(base_opt, every_k_schedule=4) - fun = lambda x: jnp.sum(x**2) + def fun(x): + return jnp.sum(x**2) if factorized_dims: # The updates are factored only for large enough parameters diff --git a/optax/_src/float64_test.py b/optax/_src/float64_test.py index dbcb0aa9a..3641014b4 100644 --- a/optax/_src/float64_test.py +++ b/optax/_src/float64_test.py @@ -28,37 +28,37 @@ ALL_MODULES = [ ('identity', base.identity, {}), - ('clip', clipping.clip, dict(max_delta=1.0)), - ('clip_by_global_norm', clipping.clip_by_global_norm, dict(max_norm=1.0)), - ('trace', transform.trace, dict(decay=0.5, nesterov=False)), - ('trace_with_nesterov', transform.trace, dict(decay=0.5, nesterov=True)), + ('clip', clipping.clip, {"max_delta": 1.0}), + ('clip_by_global_norm', clipping.clip_by_global_norm, {"max_norm": 1.0}), + ('trace', transform.trace, {"decay": 0.5, "nesterov": False}), + ('trace_with_nesterov', transform.trace, {"decay": 0.5, "nesterov": True}), ('scale_by_rss', transform.scale_by_rss, {}), ('scale_by_rms', transform.scale_by_rms, {}), ('scale_by_stddev', transform.scale_by_stddev, {}), ('adam', transform.scale_by_adam, {}), - ('scale', transform.scale, dict(step_size=3.0)), + ('scale', transform.scale, {"step_size": 3.0}), ( 'add_decayed_weights', transform.add_decayed_weights, - dict(weight_decay=0.1), + {"weight_decay": 0.1}, ), ( 'scale_by_schedule', transform.scale_by_schedule, - dict(step_size_fn=lambda x: x * 0.1), + {"step_size_fn": lambda x: x * 0.1}, ), ('scale_by_trust_ratio', transform.scale_by_trust_ratio, {}), - ('add_noise', transform.add_noise, dict(eta=1.0, gamma=0.1, seed=42)), + ('add_noise', transform.add_noise, {"eta": 1.0, "gamma": 0.1, "seed": 42}), ('apply_every_k', transform.apply_every, {}), - ('adagrad', alias.adagrad, dict(learning_rate=0.1)), - ('adam', alias.adam, dict(learning_rate=0.1)), - ('adamw', alias.adamw, dict(learning_rate=0.1)), - ('fromage', alias.fromage, dict(learning_rate=0.1)), - ('lamb', alias.lamb, dict(learning_rate=0.1)), - ('noisy_sgd', alias.noisy_sgd, dict(learning_rate=0.1)), - ('rmsprop', alias.rmsprop, dict(learning_rate=0.1)), - ('sgd', alias.sgd, dict(learning_rate=0.1)), - ('sign_sgd', alias.sgd, dict(learning_rate=0.1)), + ('adagrad', alias.adagrad, {"learning_rate": 0.1}), + ('adam', alias.adam, {"learning_rate": 0.1}), + ('adamw', alias.adamw, {"learning_rate": 0.1}), + ('fromage', alias.fromage, {"learning_rate": 0.1}), + ('lamb', alias.lamb, {"learning_rate": 0.1}), + ('noisy_sgd', alias.noisy_sgd, {"learning_rate": 0.1}), + ('rmsprop', alias.rmsprop, {"learning_rate": 0.1}), + ('sgd', alias.sgd, {"learning_rate": 0.1}), + ('sign_sgd', alias.sgd, {"learning_rate": 0.1}), ] diff --git a/optax/_src/linear_algebra.py b/optax/_src/linear_algebra.py index f05d5b2be..f6d0e8cb8 100644 --- a/optax/_src/linear_algebra.py +++ b/optax/_src/linear_algebra.py @@ -102,7 +102,8 @@ def power_iteration( # v0 must be given as we don't know the underlying pytree structure. raise ValueError('v0 must be provided when `matrix` is a callable.') else: - mvp = lambda v: jnp.matmul(matrix, v, precision=precision) + def mvp(v): + return jnp.matmul(matrix, v, precision=precision) if v0 is None: if key is None: key = jax.random.PRNGKey(0) diff --git a/optax/_src/linear_algebra_test.py b/optax/_src/linear_algebra_test.py index a2da037e9..d9ee9cd97 100644 --- a/optax/_src/linear_algebra_test.py +++ b/optax/_src/linear_algebra_test.py @@ -24,9 +24,10 @@ import jax import jax.numpy as jnp import numpy as np +import scipy.stats + from optax._src import linear_algebra import optax.tree_utils as otu -import scipy.stats class MLP(nn.Module): @@ -46,10 +47,10 @@ class LinearAlgebraTest(chex.TestCase): def test_global_norm(self): flat_updates = jnp.array([2.0, 4.0, 3.0, 5.0], dtype=jnp.float32) - nested_updates = dict( - a=jnp.array([2.0, 4.0], dtype=jnp.float32), - b=jnp.array([3.0, 5.0], dtype=jnp.float32), - ) + nested_updates = { + "a": jnp.array([2.0, 4.0], dtype=jnp.float32), + "b": jnp.array([3.0, 5.0], dtype=jnp.float32), + } np.testing.assert_array_equal( jnp.sqrt(jnp.sum(flat_updates**2)), linear_algebra.global_norm(nested_updates), @@ -78,8 +79,8 @@ def test_power_iteration_cond_fun(self, dim=6): @chex.all_variants @parameterized.parameters( - dict(implicit=True), - dict(implicit=False), + {"implicit": True}, + {"implicit": False}, ) def test_power_iteration(self, implicit, dim=6, tol=1e-3, num_iters=100): """Test power_iteration by comparing to numpy.linalg.eigh.""" diff --git a/optax/_src/linesearch_test.py b/optax/_src/linesearch_test.py index b82ed77e8..e315bc9fc 100644 --- a/optax/_src/linesearch_test.py +++ b/optax/_src/linesearch_test.py @@ -28,14 +28,15 @@ import jax import jax.numpy as jnp import jax.random as jrd +import scipy.optimize as scipy_optimize + +import optax.tree_utils as otu from optax._src import alias from optax._src import base from optax._src import combine from optax._src import linesearch as _linesearch from optax._src import update from optax._src import utils -import optax.tree_utils as otu -import scipy.optimize as scipy_optimize def get_problem(name: str): @@ -70,16 +71,16 @@ def zakharov(x): sum2 = (0.5 * ii * x).sum() return sum1 + sum2**2 + sum2**4 - problems = dict( - polynomial=dict(fn=polynomial, input_shape=()), - exponential=dict(fn=exponential, input_shape=()), - sinusoidal=dict(fn=sinusoidal, input_shape=()), - rosenbrock=dict(fn=rosenbrock, input_shape=(16,)), - himmelblau=dict(fn=himmelblau, input_shape=(2,)), - matyas=dict(fn=matyas, input_shape=(2,)), - eggholder=dict(fn=eggholder, input_shape=(2,)), - zakharov=dict(fn=zakharov, input_shape=(6,)), - ) + problems = { + "polynomial": {"fn": polynomial, "input_shape": ()}, + "exponential": {"fn": exponential, "input_shape": ()}, + "sinusoidal": {"fn": sinusoidal, "input_shape": ()}, + "rosenbrock": {"fn": rosenbrock, "input_shape": (16,)}, + "himmelblau": {"fn": himmelblau, "input_shape": (2,)}, + "matyas": {"fn": matyas, "input_shape": (2,)}, + "eggholder": {"fn": eggholder, "input_shape": (2,)}, + "zakharov": {"fn": zakharov, "input_shape": (6,)}, + } return problems[name] @@ -108,7 +109,8 @@ def _check_decrease_conditions( @chex.all_variants() def test_linesearch_with_jax_variants(self): """Test backtracking linesearch with jax variants (jit etc...).""" - fun = lambda x: jnp.sum(x**2) + def fun(x): + return jnp.sum(x**2) params = jnp.zeros(2) updates = -jax.grad(fun)(params) @@ -156,13 +158,13 @@ def test_linesearch( base_opt = alias.sgd(learning_rate=1.0) descent_dir = -jax.grad(fn)(init_params) - opt_args = dict( - max_backtracking_steps=50, - slope_rtol=slope_rtol, - increase_factor=increase_factor, - atol=atol, - rtol=rtol, - ) + opt_args = { + "max_backtracking_steps": 50, + "slope_rtol": slope_rtol, + "increase_factor": increase_factor, + "atol": atol, + "rtol": rtol, + } solver = combine.chain( base_opt, @@ -373,11 +375,11 @@ def _check_linesearch_conditions( potentially_failed = False slope_init = otu.tree_vdot(updates, grad_init) slope_final = otu.tree_vdot(updates, grad_final) - default_opt_args = dict( - slope_rtol=1e-4, - curv_rtol=0.9, - tol=0.0, - ) + default_opt_args = { + "slope_rtol": 1e-4, + "curv_rtol": 0.9, + "tol": 0.0, + } opt_args = default_opt_args | opt_args slope_rtol, curv_rtol, tol = ( opt_args['slope_rtol'], @@ -405,7 +407,8 @@ def _check_linesearch_conditions( @chex.all_variants() def test_linesearch_with_jax_variants(self): """Test zoom linesearch with jax variants (jit etc...).""" - fun = lambda x: jnp.sum(x**2) + def fun(x): + return jnp.sum(x**2) params = jnp.zeros(2) updates = -jax.grad(fun)(params) @@ -458,13 +461,13 @@ def test_linesearch(self, problem_name: str, seed: int): # (non-negativity ensures that we keep a descent direction) init_updates = -precond_vec * jax.grad(fn)(init_params) - opt_args = dict( - max_linesearch_steps=30, - slope_rtol=slope_rtol, - curv_rtol=curv_rtol, - tol=tol, - max_learning_rate=None, - ) + opt_args = { + "max_linesearch_steps": 30, + "slope_rtol": slope_rtol, + "curv_rtol": curv_rtol, + "tol": tol, + "max_learning_rate": None, + } opt = _linesearch.scale_by_zoom_linesearch(**opt_args) final_params, final_state = _run_linesearch( @@ -497,30 +500,30 @@ def test_failure_descent_direction(self): # program if jax.default_backend() in ['tpu', 'gpu']: return - else: - # For this f and p, starting at a point on axis 0, the strong Wolfe - # condition 2 is met if and only if the step length s satisfies - # |x + s| <= c2 * |x| - def fn(w): - return jnp.dot(w, w) - - u = jnp.array([1.0, 0.0]) - w = 60 * u - - # Test that the line search fails for p not a descent direction - # For high maxiter, still finds a decrease error because of - # the approximate Wolfe condition so we reduced maxiter - opt = _linesearch.scale_by_zoom_linesearch( - max_linesearch_steps=18, curv_rtol=0.5, verbose=True - ) - stdout = io.StringIO() - with contextlib.redirect_stdout(stdout): - _, state = _run_linesearch(opt, fn, w, u, stepsize_guess=1.0) - stepsize = otu.tree_get(state, 'learning_rate') - # Check that we were not able to make a step or an infinitesimal one - self.assertLess(stepsize, 1e-5) - self.assertIn(_linesearch.FLAG_NOT_A_DESCENT_DIRECTION, stdout.getvalue()) - self.assertIn(_linesearch.FLAG_NO_STEPSIZE_FOUND, stdout.getvalue()) + + # For this f and p, starting at a point on axis 0, the strong Wolfe + # condition 2 is met if and only if the step length s satisfies + # |x + s| <= c2 * |x| + def fn(w): + return jnp.dot(w, w) + + u = jnp.array([1.0, 0.0]) + w = 60 * u + + # Test that the line search fails for p not a descent direction + # For high maxiter, still finds a decrease error because of + # the approximate Wolfe condition so we reduced maxiter + opt = _linesearch.scale_by_zoom_linesearch( + max_linesearch_steps=18, curv_rtol=0.5, verbose=True + ) + stdout = io.StringIO() + with contextlib.redirect_stdout(stdout): + _, state = _run_linesearch(opt, fn, w, u, stepsize_guess=1.0) + stepsize = otu.tree_get(state, 'learning_rate') + # Check that we were not able to make a step or an infinitesimal one + self.assertLess(stepsize, 1e-5) + self.assertIn(_linesearch.FLAG_NOT_A_DESCENT_DIRECTION, stdout.getvalue()) + self.assertIn(_linesearch.FLAG_NO_STEPSIZE_FOUND, stdout.getvalue()) def test_failure_too_small_max_stepsize(self): """Check failure when the max stepsize is too small.""" @@ -531,32 +534,31 @@ def test_failure_too_small_max_stepsize(self): # program if jax.default_backend() in ['tpu', 'gpu']: return - else: - - def fn(x): - return jnp.dot(x, x) - - u = jnp.array([1.0, 0.0]) - w = -60 * u - # Test that the line search fails if the maximum stepsize is too small - # Here, smallest s satisfying strong Wolfe conditions for c2=0.5 is 30 - opt = _linesearch.scale_by_zoom_linesearch( - max_linesearch_steps=10, - curv_rtol=0.5, - verbose=True, - max_learning_rate=10.0, - ) - stdout = io.StringIO() - with contextlib.redirect_stdout(stdout): - _, state = _run_linesearch(opt, fn, w, u, stepsize_guess=1.0) - stepsize = otu.tree_get(state, 'learning_rate') - # Check that we still made a step - self.assertEqual(stepsize, 10.0) - self.assertIn(_linesearch.FLAG_INTERVAL_NOT_FOUND, stdout.getvalue()) - self.assertIn( - _linesearch.FLAG_CURVATURE_COND_NOT_SATISFIED, stdout.getvalue() - ) + def fn(x): + return jnp.dot(x, x) + + u = jnp.array([1.0, 0.0]) + w = -60 * u + + # Test that the line search fails if the maximum stepsize is too small + # Here, smallest s satisfying strong Wolfe conditions for c2=0.5 is 30 + opt = _linesearch.scale_by_zoom_linesearch( + max_linesearch_steps=10, + curv_rtol=0.5, + verbose=True, + max_learning_rate=10.0, + ) + stdout = io.StringIO() + with contextlib.redirect_stdout(stdout): + _, state = _run_linesearch(opt, fn, w, u, stepsize_guess=1.0) + stepsize = otu.tree_get(state, 'learning_rate') + # Check that we still made a step + self.assertEqual(stepsize, 10.0) + self.assertIn(_linesearch.FLAG_INTERVAL_NOT_FOUND, stdout.getvalue()) + self.assertIn( + _linesearch.FLAG_CURVATURE_COND_NOT_SATISFIED, stdout.getvalue() + ) def test_failure_not_enough_iter(self): """Check failure for a very small number of iterations.""" @@ -567,47 +569,46 @@ def test_failure_not_enough_iter(self): # program if jax.default_backend() in ['tpu', 'gpu']: return - else: - def fn(x): - return jnp.dot(x, x) + def fn(x): + return jnp.dot(x, x) - u = jnp.array([1.0, 0.0]) - w = -60 * u + u = jnp.array([1.0, 0.0]) + w = -60 * u - curv_rtol = 0.5 - # s=30 will only be tried on the 6th iteration, so this fails because - # the maximum number of iterations is reached. - opt = _linesearch.scale_by_zoom_linesearch( - max_linesearch_steps=5, curv_rtol=curv_rtol, verbose=True - ) - stdout = io.StringIO() - with contextlib.redirect_stdout(stdout): - _, final_state = _run_linesearch(opt, fn, w, u, stepsize_guess=1.0) - stepsize = otu.tree_get(final_state, 'learning_rate') - # Check that we still made a step - self.assertEqual(stepsize, 16.0) - decrease_error = otu.tree_get(final_state, 'decrease_error') - curvature_error = otu.tree_get(final_state, 'curvature_error') - success = (decrease_error <= 0.0) and (curvature_error <= 0.0) - self.assertFalse(success) - # Here the error should not be that we haven't had a descent direction - self.assertNotIn( - _linesearch.FLAG_NOT_A_DESCENT_DIRECTION, stdout.getvalue() - ) + curv_rtol = 0.5 + # s=30 will only be tried on the 6th iteration, so this fails because + # the maximum number of iterations is reached. + opt = _linesearch.scale_by_zoom_linesearch( + max_linesearch_steps=5, curv_rtol=curv_rtol, verbose=True + ) + stdout = io.StringIO() + with contextlib.redirect_stdout(stdout): + _, final_state = _run_linesearch(opt, fn, w, u, stepsize_guess=1.0) + stepsize = otu.tree_get(final_state, 'learning_rate') + # Check that we still made a step + self.assertEqual(stepsize, 16.0) + decrease_error = otu.tree_get(final_state, 'decrease_error') + curvature_error = otu.tree_get(final_state, 'curvature_error') + success = (decrease_error <= 0.0) and (curvature_error <= 0.0) + self.assertFalse(success) + # Here the error should not be that we haven't had a descent direction + self.assertNotIn( + _linesearch.FLAG_NOT_A_DESCENT_DIRECTION, stdout.getvalue() + ) - # Check if it works normally - opt = _linesearch.scale_by_zoom_linesearch( - max_linesearch_steps=30, curv_rtol=curv_rtol - ) - final_params, final_state = _run_linesearch( - opt, fn, w, u, stepsize_guess=1.0 - ) - s = otu.tree_get(final_state, 'learning_rate') - self._check_linesearch_conditions( - fn, w, u, final_params, final_state, dict(curv_rtol=curv_rtol) - ) - self.assertGreaterEqual(s, 30.0) + # Check if it works normally + opt = _linesearch.scale_by_zoom_linesearch( + max_linesearch_steps=30, curv_rtol=curv_rtol + ) + final_params, final_state = _run_linesearch( + opt, fn, w, u, stepsize_guess=1.0 + ) + s = otu.tree_get(final_state, 'learning_rate') + self._check_linesearch_conditions( + fn, w, u, final_params, final_state, {"curv_rtol": curv_rtol} + ) + self.assertGreaterEqual(s, 30.0) def test_failure_flat_fun(self): """Check failure for a very flat function.""" @@ -618,20 +619,19 @@ def test_failure_flat_fun(self): # program if jax.default_backend() in ['tpu', 'gpu']: return - else: - def fun_flat(x): - return jnp.exp(-1 / x**2) + def fun_flat(x): + return jnp.exp(-1 / x**2) - w = jnp.asarray(-0.2) - u = -jax.grad(fun_flat)(w) - opt = _linesearch.scale_by_zoom_linesearch( - max_linesearch_steps=30, verbose=True - ) - stdout = io.StringIO() - with contextlib.redirect_stdout(stdout): - _, _ = _run_linesearch(opt, fun_flat, w, u, stepsize_guess=1.0) - self.assertIn(_linesearch.FLAG_INTERVAL_TOO_SMALL, stdout.getvalue()) + w = jnp.asarray(-0.2) + u = -jax.grad(fun_flat)(w) + opt = _linesearch.scale_by_zoom_linesearch( + max_linesearch_steps=30, verbose=True + ) + stdout = io.StringIO() + with contextlib.redirect_stdout(stdout): + _, _ = _run_linesearch(opt, fun_flat, w, u, stepsize_guess=1.0) + self.assertIn(_linesearch.FLAG_INTERVAL_TOO_SMALL, stdout.getvalue()) def test_failure_inf_value(self): """Check behavior for inf/nan values.""" @@ -642,19 +642,18 @@ def test_failure_inf_value(self): # program if jax.default_backend() in ['tpu', 'gpu']: return - else: - def fun_inf(x): - return jnp.log(x) + def fun_inf(x): + return jnp.log(x) - w = jnp.asarray(1.0) - u = jnp.asarray(-2.0) - opt = _linesearch.scale_by_zoom_linesearch( - max_linesearch_steps=30, verbose=True - ) - _, state = _run_linesearch(opt, fun_inf, w, u, stepsize_guess=1.0) - stepsize = otu.tree_get(state, 'learning_rate') - self.assertGreater(stepsize, 0.0) + w = jnp.asarray(1.0) + u = jnp.asarray(-2.0) + opt = _linesearch.scale_by_zoom_linesearch( + max_linesearch_steps=30, verbose=True + ) + _, state = _run_linesearch(opt, fun_inf, w, u, stepsize_guess=1.0) + stepsize = otu.tree_get(state, 'learning_rate') + self.assertGreater(stepsize, 0.0) def test_high_smaller_than_low(self): # See google/jax/issues/16236 diff --git a/optax/_src/numerics_test.py b/optax/_src/numerics_test.py index 18cd02811..a8c4a3fb4 100644 --- a/optax/_src/numerics_test.py +++ b/optax/_src/numerics_test.py @@ -28,8 +28,11 @@ _ALL_ORDS = [None, np.inf, -np.inf, "fro", "nuc", 0, 1, 2, -2, -2, -1.5, 1.5] -int32_array = lambda i: jnp.array(i, dtype=jnp.int32) -float32_array = lambda i: jnp.array(i, dtype=jnp.float32) +def int32_array(i): + return jnp.array(i, dtype=jnp.int32) + +def float32_array(i): + return jnp.array(i, dtype=jnp.float32) def _invalid_ord_axis_inputs(ord_axis_keepdims): diff --git a/optax/_src/transform.py b/optax/_src/transform.py index 69b24bdc1..1a4be919a 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -124,8 +124,7 @@ def init_fn(params): nu = otu.tree_full_like(params, initial_scale) # second moment if bias_correction: return ScaleByRmsWithCountState(count=jnp.zeros([], jnp.int32), nu=nu) - else: - return ScaleByRmsState(nu=nu) + return ScaleByRmsState(nu=nu) def update_fn(updates, state, params=None): del params @@ -198,8 +197,7 @@ def init_fn(params): return ScaleByRStdDevWithCountState( count=jnp.zeros([], jnp.int32), mu=mu, nu=nu ) - else: - return ScaleByRStdDevState(mu=mu, nu=nu) + return ScaleByRStdDevState(mu=mu, nu=nu) def update_fn(updates, state, params=None): del params @@ -738,6 +736,9 @@ def scale_by_yogi( Returns: A :class:`optax.GradientTransformation` object. + + References: + [Zaheer et al, 2018](https://papers.nips.cc/paper/2018/hash/90365351ccc7437a1309dc64e4db32a3-Abstract.html) # pylint: disable=line-too-long """ def init_fn(params): @@ -1076,8 +1077,7 @@ def update_fn(updates, state, params=None): def _subtract_mean(g): if len(g.shape) > 1: return g - g.mean(tuple(range(1, len(g.shape))), keepdims=True) - else: - return g + return g CentralState = base.EmptyState @@ -1147,29 +1147,26 @@ def _expanded_shape(shape, axis): def _new_accum(g, v): coeffs = ((1.0 - b2) if b2 != 1.0 else 1.0, b2) if g.ndim < 2: - return coeffs[0] * g**2 + coeffs[1] * v[0] - else: - return coeffs[0] * g**2 + coeffs[1] * functools.reduce(jnp.minimum, v) + return coeffs[0]*g**2 + coeffs[1]*v[0] + return coeffs[0]*g**2 + coeffs[1]*functools.reduce(jnp.minimum, v) def _new_mu(g, i): if g.ndim < 2: return g - else: - return jnp.max(g, axis=other_axes(i, g.ndim)) + return jnp.max(g, axis=other_axes(i, g.ndim)) def other_axes(idx, ndim): return list(range(idx)) + list(range(idx + 1, ndim)) def update_fn(updates, state, params=None): del params - mu = jax.tree.map( - lambda g, v: [ # pylint:disable=g-long-lambda - jnp.reshape(v[i], _expanded_shape(g.shape, i)) - for i in range(g.ndim) - ], - updates, - state.mu, - ) + + def f(g, v): + return [ + jnp.reshape(v[i], _expanded_shape(g.shape, i)) for i in range(g.ndim) + ] + + mu = jax.tree.map(f, updates, state.mu) accum = jax.tree.map(_new_accum, updates, mu) accum_inv_sqrt = jax.tree.map( lambda t: jnp.where(t > 0, jax.lax.rsqrt(t + eps), 0.0), accum diff --git a/optax/_src/transform_test.py b/optax/_src/transform_test.py index 001c463a4..6ee24a992 100644 --- a/optax/_src/transform_test.py +++ b/optax/_src/transform_test.py @@ -180,7 +180,8 @@ def test_scale_by_polyak_l1_norm(self, tol=1e-10): """Polyak step-size on L1 norm.""" # for this objective, the Polyak step-size has an exact model and should # converge to the minimizer in one step - objective = lambda x: jnp.abs(x).sum() + def objective(x): + return jnp.abs(x).sum() init_params = jnp.array([1.0, -1.0]) polyak = transform.scale_by_polyak() @@ -197,7 +198,8 @@ def test_scale_by_polyak_l1_norm(self, tol=1e-10): def test_rms_match_adam(self): """Test scale_by_rms add_eps_in_sqrt=False matches scale_by_adam(b1=0).""" - fun = lambda x: otu.tree_l2_norm(x, squared=True) + def fun(x): + return otu.tree_l2_norm(x, squared=True) rms = transform.scale_by_rms( decay=0.999, eps_in_sqrt=False, bias_correction=True diff --git a/optax/_src/update_test.py b/optax/_src/update_test.py index c3bde6d50..d9a2f270c 100644 --- a/optax/_src/update_test.py +++ b/optax/_src/update_test.py @@ -77,11 +77,11 @@ def test_periodic_update(self): chex.assert_trees_all_close(params_2, new_params, atol=1e-10, rtol=1e-5) @parameterized.named_parameters( - dict(testcase_name='apply_updates', operation=update.apply_updates), - dict( - testcase_name='incremental_update', - operation=lambda x, y: update.incremental_update(x, y, 1), - ), + {"testcase_name": 'apply_updates', "operation": update.apply_updates}, + { + "testcase_name": 'incremental_update', + "operation": lambda x, y: update.incremental_update(x, y, 1), + }, ) def test_none_argument(self, operation): x = jnp.array([1.0, 2.0, 3.0]) diff --git a/optax/_src/utils.py b/optax/_src/utils.py index 3c59cad3f..0e03d9a11 100644 --- a/optax/_src/utils.py +++ b/optax/_src/utils.py @@ -24,16 +24,16 @@ from etils import epy import jax import jax.numpy as jnp + +with epy.lazy_imports(): + import jax.scipy.stats.norm as multivariate_normal + from optax import tree_utils as otu from optax._src import base from optax._src import linear_algebra from optax._src import numerics -with epy.lazy_imports(): - import jax.scipy.stats.norm as multivariate_normal # pylint: disable=g-import-not-at-top,ungrouped-imports - - def tile_second_to_last_dim(a: chex.Array) -> chex.Array: ones = jnp.ones_like(a) a = jnp.expand_dims(a, axis=-1) @@ -171,10 +171,9 @@ def scale_gradient(inputs: chex.ArrayTree, scale: float) -> chex.ArrayTree: # Special case scales of 1. and 0. for more efficiency. if scale == 1.0: return inputs - elif scale == 0.0: + if scale == 0.0: return jax.lax.stop_gradient(inputs) - else: - return _scale_gradient(inputs, scale) + return _scale_gradient(inputs, scale) def _extract_fns_kwargs( diff --git a/optax/_src/utils_test.py b/optax/_src/utils_test.py index 1b5c3dc42..1c96e6784 100644 --- a/optax/_src/utils_test.py +++ b/optax/_src/utils_test.py @@ -43,7 +43,7 @@ def fn(inputs): outputs = jax.tree.map(lambda x: x**2, outputs) return sum(jax.tree.leaves(outputs)) - inputs = dict(a=-1.0, b=dict(c=(2.0,), d=0.0)) + inputs = {"a": -1.0, "b": {"c": (2.0,), "d": 0.0}} grad = jax.grad(fn) grads = grad(inputs) diff --git a/optax/assignment/__init__.py b/optax/assignment/__init__.py index 9aecb2365..cc140c6ca 100644 --- a/optax/assignment/__init__.py +++ b/optax/assignment/__init__.py @@ -14,6 +14,4 @@ # ============================================================================== """The assignment sub-package.""" -# pylint:disable=g-importing-member - from optax.assignment._hungarian_algorithm import hungarian_algorithm diff --git a/optax/assignment/_hungarian_algorithm.py b/optax/assignment/_hungarian_algorithm.py index 7bac4a793..5e4b6a8d1 100644 --- a/optax/assignment/_hungarian_algorithm.py +++ b/optax/assignment/_hungarian_algorithm.py @@ -127,8 +127,7 @@ def hungarian_algorithm(cost_matrix): if transpose: i = col4row.argsort() return col4row[i], i - else: - return jnp.arange(cost_matrix.shape[0]), col4row + return jnp.arange(cost_matrix.shape[0]), col4row def _find_short_augpath_while_body_inner_for(it, val): diff --git a/optax/assignment/_hungarian_algorithm_test.py b/optax/assignment/_hungarian_algorithm_test.py index 7cf3aeb59..a0399e210 100644 --- a/optax/assignment/_hungarian_algorithm_test.py +++ b/optax/assignment/_hungarian_algorithm_test.py @@ -19,9 +19,10 @@ import jax import jax.numpy as jnp import jax.random as jrd -from optax.assignment import _hungarian_algorithm import scipy +from optax.assignment import _hungarian_algorithm + class HungarianAlgorithmTest(parameterized.TestCase): diff --git a/optax/contrib/__init__.py b/optax/contrib/__init__.py index a310cc23b..ed4f120ee 100644 --- a/optax/contrib/__init__.py +++ b/optax/contrib/__init__.py @@ -14,8 +14,6 @@ # ============================================================================== """Contributed optimizers in Optax.""" -# pylint: disable=g-importing-member - from optax.contrib._acprop import acprop from optax.contrib._acprop import scale_by_acprop from optax.contrib._cocob import cocob diff --git a/optax/contrib/_common_test.py b/optax/contrib/_common_test.py index b20118483..720c77c98 100644 --- a/optax/contrib/_common_test.py +++ b/optax/contrib/_common_test.py @@ -36,23 +36,23 @@ # Testing contributions coded as GradientTransformations _MAIN_OPTIMIZERS_UNDER_TEST = [ - dict(opt_name='acprop', opt_kwargs=dict(learning_rate=1e-3)), - dict(opt_name='cocob', opt_kwargs={}), - dict(opt_name='cocob', opt_kwargs=dict(weight_decay=1e-2)), - dict(opt_name='dadapt_adamw', opt_kwargs=dict(learning_rate=1e-1)), - dict(opt_name='dog', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='dowg', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='momo', opt_kwargs=dict(learning_rate=1e-1)), - dict(opt_name='momo_adam', opt_kwargs=dict(learning_rate=1e-1)), - dict(opt_name='prodigy', opt_kwargs=dict(learning_rate=1e-1)), - dict( - opt_name='schedule_free_sgd', - opt_kwargs=dict(learning_rate=1e-2, warmup_steps=5000), - ), - dict( - opt_name='schedule_free_adamw', - opt_kwargs=dict(learning_rate=1e-2, warmup_steps=5000), - ), + {"opt_name": 'acprop', "opt_kwargs": {"learning_rate": 1e-3}}, + {"opt_name": 'cocob', "opt_kwargs": {}}, + {"opt_name": 'cocob', "opt_kwargs": {"weight_decay": 1e-2}}, + {"opt_name": 'dadapt_adamw', "opt_kwargs": {"learning_rate": 1e-1}}, + {"opt_name": 'dog', "opt_kwargs": {"learning_rate": 1.0}}, + {"opt_name": 'dowg', "opt_kwargs": {"learning_rate": 1.0}}, + {"opt_name": 'momo', "opt_kwargs": {"learning_rate": 1e-1}}, + {"opt_name": 'momo_adam', "opt_kwargs": {"learning_rate": 1e-1}}, + {"opt_name": 'prodigy', "opt_kwargs": {"learning_rate": 1e-1}}, + { + "opt_name": 'schedule_free_sgd', + "opt_kwargs": {"learning_rate": 1e-2, "warmup_steps": 5000}, + }, + { + "opt_name": 'schedule_free_adamw', + "opt_kwargs": {"learning_rate": 1e-2, "warmup_steps": 5000}, + }, ] for optimizer in _MAIN_OPTIMIZERS_UNDER_TEST: optimizer['wrapper_name'] = None @@ -61,59 +61,59 @@ # Testing contributions coded as wrappers # (just with sgd as we just want the behavior of the wrapper) _MAIN_OPTIMIZERS_UNDER_TEST += [ - dict( - opt_name='sgd', - opt_kwargs=dict(learning_rate=1e-1), - wrapper_name='mechanize', - wrapper_kwargs=dict(weight_decay=0.0), - ), - dict( - opt_name='sgd', - opt_kwargs=dict(learning_rate=1e-2), - wrapper_name='schedule_free', - wrapper_kwargs=dict(learning_rate=1e-2), - ), - dict( - opt_name='sgd', - opt_kwargs=dict(learning_rate=1e-3), - wrapper_name='reduce_on_plateau', - wrapper_kwargs={}, - ), + { + "opt_name": 'sgd', + "opt_kwargs": {"learning_rate": 1e-1}, + "wrapper_name": 'mechanize', + "wrapper_kwargs": {"weight_decay": 0.0}, + }, + { + "opt_name": 'sgd', + "opt_kwargs": {"learning_rate": 1e-2}, + "wrapper_name": 'schedule_free', + "wrapper_kwargs": {"learning_rate": 1e-2}, + }, + { + "opt_name": 'sgd', + "opt_kwargs": {"learning_rate": 1e-3}, + "wrapper_name": 'reduce_on_plateau', + "wrapper_kwargs": {}, + }, ] # Adding here instantiations of wrappers with any base optimizer _BASE_OPTIMIZERS = [ - dict(opt_name='sgd', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='sgd', opt_kwargs=dict(learning_rate=1.0, momentum=0.9)), - dict(opt_name='adam', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='adamw', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='adamax', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='adamaxw', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='adan', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='amsgrad', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='lamb', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='lion', opt_kwargs=dict(learning_rate=1.0, b1=0.99)), - dict(opt_name='noisy_sgd', opt_kwargs=dict(learning_rate=1.0, eta=1e-4)), - dict(opt_name='novograd', opt_kwargs=dict(learning_rate=1.0)), - dict( - opt_name='optimistic_gradient_descent', - opt_kwargs=dict(learning_rate=1.0, alpha=0.7, beta=0.1), - ), - dict(opt_name='rmsprop', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='rmsprop', opt_kwargs=dict(learning_rate=1.0, momentum=0.9)), - dict(opt_name='adabelief', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='radam', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='sm3', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='yogi', opt_kwargs=dict(learning_rate=1.0, b1=0.99)), + {"opt_name": 'sgd', "opt_kwargs": {"learning_rate": 1.0}}, + {"opt_name": 'sgd', "opt_kwargs": {"learning_rate": 1.0, "momentum": 0.9}}, + {"opt_name": 'adam', "opt_kwargs": {"learning_rate": 1.0}}, + {"opt_name": 'adamw', "opt_kwargs": {"learning_rate": 1.0}}, + {"opt_name": 'adamax', "opt_kwargs": {"learning_rate": 1.0}}, + {"opt_name": 'adamaxw', "opt_kwargs": {"learning_rate": 1.0}}, + {"opt_name": 'adan', "opt_kwargs": {"learning_rate": 1.0}}, + {"opt_name": 'amsgrad', "opt_kwargs": {"learning_rate": 1.0}}, + {"opt_name": 'lamb', "opt_kwargs": {"learning_rate": 1.0}}, + {"opt_name": 'lion', "opt_kwargs": {"learning_rate": 1.0, "b1": 0.99}}, + {"opt_name": 'noisy_sgd', "opt_kwargs": {"learning_rate": 1.0, "eta": 1e-4}}, + {"opt_name": 'novograd', "opt_kwargs": {"learning_rate": 1.0}}, + { + "opt_name": 'optimistic_gradient_descent', + "opt_kwargs": {"learning_rate": 1.0, "alpha": 0.7, "beta": 0.1}, + }, + {"opt_name": 'rmsprop', "opt_kwargs": {"learning_rate": 1.0}}, + {"opt_name": 'rmsprop', "opt_kwargs": {"learning_rate": 1.0, "momentum": 0.9}}, + {"opt_name": 'adabelief', "opt_kwargs": {"learning_rate": 1.0}}, + {"opt_name": 'radam', "opt_kwargs": {"learning_rate": 1.0}}, + {"opt_name": 'sm3', "opt_kwargs": {"learning_rate": 1.0}}, + {"opt_name": 'yogi', "opt_kwargs": {"learning_rate": 1.0, "b1": 0.99}}, ] # TODO(harshm): make LARS and Fromage work with mechanic. _OTHER_OPTIMIZERS_UNDER_TEST = [ - dict( - opt_name=base_opt['opt_name'], - opt_kwargs=base_opt['opt_kwargs'], - wrapper_name='mechanize', - wrapper_kwargs=dict(weight_decay=0.0), - ) + { + "opt_name": base_opt['opt_name'], + "opt_kwargs": base_opt['opt_kwargs'], + "wrapper_name": 'mechanize', + "wrapper_kwargs": {"weight_decay": 0.0}, + } for base_opt in _BASE_OPTIMIZERS ] @@ -135,8 +135,7 @@ def _get_opt_factory(opt_name): def _wrap_opt(opt, wrapper_name, wrapper_kwargs): if wrapper_name == 'reduce_on_plateau': return combine.chain(opt, contrib.reduce_on_plateau(**wrapper_kwargs)) - else: - return getattr(contrib, wrapper_name)(opt, **wrapper_kwargs) + return getattr(contrib, wrapper_name)(opt, **wrapper_kwargs) def _setup_parabola(dtype): @@ -311,7 +310,8 @@ def test_preserve_dtype( opt = _get_opt_factory(opt_name)(**opt_kwargs) if wrapper_name is not None: opt = _wrap_opt(opt, wrapper_name, wrapper_kwargs) - fun = lambda x: jnp.sum(x**2) + def fun(x): + return jnp.sum(x**2) params = jnp.array([1.0, 2.0], dtype=dtype) value, grads = jax.value_and_grad(fun)(params) @@ -341,7 +341,8 @@ def test_gradient_accumulation( opt = _wrap_opt(opt, wrapper_name, wrapper_kwargs) opt = _accumulation.MultiSteps(opt, every_k_schedule=4) - fun = lambda x: jnp.sum(x**2) + def fun(x): + return jnp.sum(x**2) params = jnp.array([1.0, 2.0], dtype=dtype) value, grads = jax.value_and_grad(fun)(params) diff --git a/optax/contrib/_complex_valued.py b/optax/contrib/_complex_valued.py index 0674eff26..24b4b441e 100644 --- a/optax/contrib/_complex_valued.py +++ b/optax/contrib/_complex_valued.py @@ -56,8 +56,7 @@ def _complex_to_real_pair( """ if jnp.iscomplexobj(x): return SplitRealAndImaginaryArrays(x.real, x.imag) - else: - return x + return x def _real_pair_to_complex( @@ -75,8 +74,7 @@ def _real_pair_to_complex( """ if isinstance(x, SplitRealAndImaginaryArrays): return x.real + x.imaginary * 1j - else: - return x + return x class SplitRealAndImaginaryState(NamedTuple): diff --git a/optax/contrib/_reduce_on_plateau.py b/optax/contrib/_reduce_on_plateau.py index 8fcb15fc5..be554f5cb 100644 --- a/optax/contrib/_reduce_on_plateau.py +++ b/optax/contrib/_reduce_on_plateau.py @@ -86,12 +86,12 @@ def reduce_on_plateau( "Both rtol and atol must be non-negative, got " f"rtol = {rtol} and atol = {atol}." ) - elif rtol == 0.0 and atol == 0.0: + if rtol == 0.0 and atol == 0.0: raise ValueError( "At least one of rtol or atol must be positive, got " f"rtol = {rtol} and atol = {atol}." ) - elif rtol > 1.0: + if rtol > 1.0: raise ValueError( f"rtol must be less than or equal to 1.0, got rtol = {rtol}." ) diff --git a/optax/contrib/_sam_test.py b/optax/contrib/_sam_test.py index c3284cad3..04167fe26 100644 --- a/optax/contrib/_sam_test.py +++ b/optax/contrib/_sam_test.py @@ -27,11 +27,11 @@ from optax.tree_utils import _state_utils _BASE_OPTIMIZERS_UNDER_TEST = [ - dict(base_opt_name='sgd', base_opt_kwargs=dict(learning_rate=1e-3)), + {"base_opt_name": 'sgd', "base_opt_kwargs": {"learning_rate": 1e-3}}, ] _ADVERSARIAL_OPTIMIZERS_UNDER_TEST = [ - dict(adv_opt_name='sgd', adv_opt_kwargs=dict(learning_rate=1e-5)), - dict(adv_opt_name='adam', adv_opt_kwargs=dict(learning_rate=1e-4)), + {"adv_opt_name": 'sgd', "adv_opt_kwargs": {"learning_rate": 1e-5}}, + {"adv_opt_name": 'adam', "adv_opt_kwargs": {"learning_rate": 1e-4}}, ] @@ -79,7 +79,7 @@ def test_optimization( initial_params, final_params, get_updates = target(dtype) if opaque_mode: - update_kwargs = dict(grad_fn=lambda p, _: get_updates(p)) + update_kwargs = {"grad_fn": lambda p, _: get_updates(p)} else: update_kwargs = {} diff --git a/optax/contrib/_schedule_free_test.py b/optax/contrib/_schedule_free_test.py index 571616763..d2b80e48d 100644 --- a/optax/contrib/_schedule_free_test.py +++ b/optax/contrib/_schedule_free_test.py @@ -44,7 +44,8 @@ def test_learning_rate_zero(self): base_opt = alias.sgd(learning_rate=0.0, momentum=0.0) opt = _schedule_free.schedule_free(base_opt, learning_rate=0.0) initial_params = jnp.array([1.0, 2.0]) - fun = lambda x: jnp.sum(x**2) + def fun(x): + return jnp.sum(x**2) @jax.jit def step(params, state): @@ -66,7 +67,8 @@ def step(params, state): def test_schedule_free_adamw(self): initial_params = jnp.array([1.0, 2.0]) - fun = lambda x: jnp.sum(x**2) + def fun(x): + return jnp.sum(x**2) def step(params, state, opt): updates = jax.grad(fun)(params) diff --git a/optax/losses/__init__.py b/optax/losses/__init__.py index 8d0a99bbc..6771f0866 100644 --- a/optax/losses/__init__.py +++ b/optax/losses/__init__.py @@ -14,8 +14,6 @@ # ============================================================================== """The losses sub-package.""" -# pylint:disable=g-importing-member - from optax.losses._classification import convex_kl_divergence from optax.losses._classification import ctc_loss from optax.losses._classification import ctc_loss_with_forward_probs diff --git a/optax/losses/_classification.py b/optax/losses/_classification.py index 4c38d7626..087e1ae4f 100644 --- a/optax/losses/_classification.py +++ b/optax/losses/_classification.py @@ -789,12 +789,9 @@ def sigmoid_focal_loss( p_t = p * labels + (1 - p) * (1 - labels) loss = ce_loss * ((1 - p_t) ** gamma) - weighted = ( - lambda loss_arg: (alpha * labels + (1 - alpha) * (1 - labels)) * loss_arg - ) - not_weighted = lambda loss_arg: loss_arg + weighted = (alpha * labels + (1 - alpha) * (1 - labels)) * loss - loss = jax.lax.cond(alpha >= 0, weighted, not_weighted, loss) + loss = jnp.where(alpha >= 0, weighted, loss) return loss diff --git a/optax/losses/_classification_test.py b/optax/losses/_classification_test.py index a2665d2b5..e44dc4fd5 100644 --- a/optax/losses/_classification_test.py +++ b/optax/losses/_classification_test.py @@ -81,7 +81,7 @@ def test_gradient(self): order=1, ) - @parameterized.parameters(dict(size=5), dict(size=10)) + @parameterized.parameters({"size": 5}, {"size": 10}) def test_mask(self, size): preds = np.random.normal(size=size) targets = np.random.dirichlet(np.ones(size)) @@ -91,9 +91,9 @@ def test_mask(self, size): np.testing.assert_allclose(x, y, atol=1e-4) @parameterized.parameters( - dict(axis=0, shape=[4, 5, 6]), - dict(axis=1, shape=[4, 5, 6]), - dict(axis=2, shape=[4, 5, 6]), + {"axis": 0, "shape": [4, 5, 6]}, + {"axis": 1, "shape": [4, 5, 6]}, + {"axis": 2, "shape": [4, 5, 6]}, ) def test_axis(self, shape, axis): preds = np.random.normal(size=shape) @@ -230,9 +230,9 @@ def test_gradient(self): ) @parameterized.parameters( - dict(axis=0, shape=[4, 5, 6]), - dict(axis=1, shape=[4, 5, 6]), - dict(axis=2, shape=[4, 5, 6]), + {"axis": 0, "shape": [4, 5, 6]}, + {"axis": 1, "shape": [4, 5, 6]}, + {"axis": 2, "shape": [4, 5, 6]}, ) def test_axis(self, shape, axis): preds = np.random.normal(size=shape) @@ -251,58 +251,58 @@ def test_axis(self, shape, axis): class SigmoidCrossEntropyTest(parameterized.TestCase): @parameterized.parameters( - dict( - preds=np.array([-1e09, -1e-09]), - labels=np.array([1.0, 0.0]), - expected=5e08, - ), - dict( - preds=np.array([-1e09, -1e-09]), - labels=np.array([0.0, 1.0]), - expected=0.3465736, - ), - dict( - preds=np.array([1e09, 1e-09]), - labels=np.array([1.0, 0.0]), - expected=0.3465736, - ), - dict( - preds=np.array([1e09, 1e-09]), - labels=np.array([0.0, 1.0]), - expected=5e08, - ), - dict( - preds=np.array([-1e09, 1e-09]), - labels=np.array([1.0, 0.0]), - expected=5e08, - ), - dict( - preds=np.array([-1e09, 1e-09]), - labels=np.array([0.0, 1.0]), - expected=0.3465736, - ), - dict( - preds=np.array([1e09, -1e-09]), - labels=np.array([1.0, 0.0]), - expected=0.3465736, - ), - dict( - preds=np.array([1e09, -1e-09]), - labels=np.array([0.0, 1.0]), - expected=5e08, - ), - dict( - preds=np.array([0.0, 0.0]), - labels=np.array([1.0, 0.0]), - expected=0.6931472, - ), - dict( - preds=np.array([0.0, 0.0]), - labels=np.array([0.0, 1.0]), - expected=0.6931472, - ), + { + "preds": np.array([-1e09, -1e-09]), + "labels": np.array([1.0, 0.0]), + "expected": 5e08, + }, + { + "preds": np.array([-1e09, -1e-09]), + "labels": np.array([0.0, 1.0]), + "expected": 0.3465736, + }, + { + "preds": np.array([1e09, 1e-09]), + "labels": np.array([1.0, 0.0]), + "expected": 0.3465736, + }, + { + "preds": np.array([1e09, 1e-09]), + "labels": np.array([0.0, 1.0]), + "expected": 5e08, + }, + { + "preds": np.array([-1e09, 1e-09]), + "labels": np.array([1.0, 0.0]), + "expected": 5e08, + }, + { + "preds": np.array([-1e09, 1e-09]), + "labels": np.array([0.0, 1.0]), + "expected": 0.3465736, + }, + { + "preds": np.array([1e09, -1e-09]), + "labels": np.array([1.0, 0.0]), + "expected": 0.3465736, + }, + { + "preds": np.array([1e09, -1e-09]), + "labels": np.array([0.0, 1.0]), + "expected": 5e08, + }, + { + "preds": np.array([0.0, 0.0]), + "labels": np.array([1.0, 0.0]), + "expected": 0.6931472, + }, + { + "preds": np.array([0.0, 0.0]), + "labels": np.array([0.0, 1.0]), + "expected": 0.6931472, + }, ) - def testSigmoidCrossEntropy(self, preds, labels, expected): + def test_sigmoid_cross_entropy(self, preds, labels, expected): tested = jnp.mean( _classification.sigmoid_binary_cross_entropy(preds, labels) ) @@ -323,14 +323,14 @@ def setUp(self): @chex.all_variants @parameterized.parameters( - dict(eps=2, expected=4.5317), - dict(eps=1, expected=3.7153), - dict(eps=-1, expected=2.0827), - dict(eps=0, expected=2.8990), - dict(eps=-0.5, expected=2.4908), - dict(eps=1.15, expected=3.8378), - dict(eps=1.214, expected=3.8900), - dict(eps=5.45, expected=7.3480), + {"eps": 2, "expected": 4.5317}, + {"eps": 1, "expected": 3.7153}, + {"eps": -1, "expected": 2.0827}, + {"eps": 0, "expected": 2.8990}, + {"eps": -0.5, "expected": 2.4908}, + {"eps": 1.15, "expected": 3.8378}, + {"eps": 1.214, "expected": 3.8900}, + {"eps": 5.45, "expected": 7.3480}, ) def test_scalar(self, eps, expected): np.testing.assert_allclose( @@ -343,13 +343,13 @@ def test_scalar(self, eps, expected): @chex.all_variants @parameterized.parameters( - dict(eps=2, expected=np.array([0.4823, 1.2567])), - dict(eps=1, expected=np.array([0.3261, 1.0407])), - dict(eps=0, expected=np.array([0.1698, 0.8247])), - dict(eps=-0.5, expected=np.array([0.0917, 0.7168])), - dict(eps=1.15, expected=np.array([0.3495, 1.0731])), - dict(eps=1.214, expected=np.array([0.3595, 1.0870])), - dict(eps=5.45, expected=np.array([1.0211, 2.0018])), + {"eps": 2, "expected": np.array([0.4823, 1.2567])}, + {"eps": 1, "expected": np.array([0.3261, 1.0407])}, + {"eps": 0, "expected": np.array([0.1698, 0.8247])}, + {"eps": -0.5, "expected": np.array([0.0917, 0.7168])}, + {"eps": 1.15, "expected": np.array([0.3495, 1.0731])}, + {"eps": 1.214, "expected": np.array([0.3595, 1.0870])}, + {"eps": 5.45, "expected": np.array([1.0211, 2.0018])}, ) def test_batched(self, eps, expected): np.testing.assert_allclose( @@ -362,28 +362,28 @@ def test_batched(self, eps, expected): @chex.all_variants @parameterized.parameters( - dict( - logits=np.array( + { + "logits": np.array( [[4.0, 2.0, 1.0], [0.0, 5.0, 1.0], [0.134, 1.234, 3.235]] ), - labels=np.array( + "labels": np.array( [[1.0, 0.0, 0.0], [0.0, 0.8, 0.2], [0.34, 0.33, 0.33]] ), - ), - dict( - logits=np.array([[4.0, 2.0, 1.0], [0.0, 5.0, 1.0]]), - labels=np.array([[1.0, 0.0, 0.0], [0.0, 0.8, 0.2]]), - ), - dict( - logits=np.array( + }, + { + "logits": np.array([[4.0, 2.0, 1.0], [0.0, 5.0, 1.0]]), + "labels": np.array([[1.0, 0.0, 0.0], [0.0, 0.8, 0.2]]), + }, + { + "logits": np.array( [[4.0, 2.0, 1.0, 0.134, 1.3515], [0.0, 5.0, 1.0, 0.5215, 5.616]] ), - labels=np.array( + "labels": np.array( [[0.5, 0.0, 0.0, 0.0, 0.5], [0.0, 0.12, 0.2, 0.56, 0.12]] ), - ), - dict(logits=np.array([1.89, 2.39]), labels=np.array([0.34, 0.66])), - dict(logits=np.array([0.314]), labels=np.array([1.0])), + }, + {"logits": np.array([1.89, 2.39]), "labels": np.array([0.34, 0.66])}, + {"logits": np.array([0.314]), "labels": np.array([1.0])}, ) def test_equals_to_cross_entropy_when_eps0(self, logits, labels): np.testing.assert_allclose( @@ -394,7 +394,7 @@ def test_equals_to_cross_entropy_when_eps0(self, logits, labels): atol=1e-4, ) - @parameterized.parameters(dict(size=5), dict(size=10)) + @parameterized.parameters({"size": 5}, {"size": 10}) def test_mask(self, size): preds = np.random.normal(size=size) targets = np.random.dirichlet(np.ones(size)) @@ -404,9 +404,9 @@ def test_mask(self, size): np.testing.assert_allclose(x, y, atol=1e-4) @parameterized.parameters( - dict(axis=0, shape=[4, 5, 6]), - dict(axis=1, shape=[4, 5, 6]), - dict(axis=2, shape=[4, 5, 6]), + {"axis": 0, "shape": [4, 5, 6]}, + {"axis": 1, "shape": [4, 5, 6]}, + {"axis": 2, "shape": [4, 5, 6]}, ) def test_axis(self, shape, axis): preds = np.random.normal(size=shape) @@ -482,10 +482,9 @@ def reference_impl(label, logit): scores = -(2 * label - 1) * logit if scores <= -1.0: return 0.0 - elif scores >= 1.0: + if scores >= 1.0: return scores - else: - return (scores + 1.0) ** 2 / 4 + return (scores + 1.0) ** 2 / 4 expected = reference_impl(label, score) result = _classification.sparsemax_loss( @@ -501,10 +500,9 @@ def reference_impl(label, logit): scores = -(2 * label - 1) * logit if scores <= -1.0: return 0.0 - elif scores >= 1.0: + if scores >= 1.0: return scores - else: - return (scores + 1.0) ** 2 / 4 + return (scores + 1.0) ** 2 / 4 expected = jnp.asarray([ reference_impl(labels[0], scores[0]), @@ -569,7 +567,7 @@ def test_batched(self): atol=1e-4, ) - @parameterized.parameters(dict(size=5), dict(size=10)) + @parameterized.parameters({"size": 5}, {"size": 10}) def test_mask(self, size): preds = np.random.normal(size=size) targets = np.random.dirichlet(np.ones(size)) @@ -579,9 +577,9 @@ def test_mask(self, size): np.testing.assert_allclose(x, y, atol=1e-4) @parameterized.parameters( - dict(axis=0, shape=[4, 5, 6]), - dict(axis=1, shape=[4, 5, 6]), - dict(axis=2, shape=[4, 5, 6]), + {"axis": 0, "shape": [4, 5, 6]}, + {"axis": 1, "shape": [4, 5, 6]}, + {"axis": 2, "shape": [4, 5, 6]}, ) def test_axis(self, shape, axis): preds = np.random.normal(size=shape) @@ -675,7 +673,7 @@ def test_batched(self): atol=1e-4, ) - @parameterized.parameters(dict(size=5), dict(size=10)) + @parameterized.parameters({"size": 5}, {"size": 10}) def test_mask(self, size): preds = np.random.normal(size=size) targets = np.random.dirichlet(np.ones(size)) @@ -685,9 +683,9 @@ def test_mask(self, size): np.testing.assert_allclose(x, y, atol=1e-4) @parameterized.parameters( - dict(axis=0, shape=[4, 5, 6]), - dict(axis=1, shape=[4, 5, 6]), - dict(axis=2, shape=[4, 5, 6]), + {"axis": 0, "shape": [4, 5, 6]}, + {"axis": 1, "shape": [4, 5, 6]}, + {"axis": 2, "shape": [4, 5, 6]}, ) def test_axis(self, shape, axis): preds = np.random.normal(size=shape) @@ -735,7 +733,7 @@ def test_batched(self): atol=1e-4, ) - @parameterized.parameters(dict(size=5), dict(size=10)) + @parameterized.parameters({"size": 5}, {"size": 10}) def test_mask(self, size): preds = np.random.normal(size=size) targets = np.log(np.random.dirichlet(np.ones(size))) @@ -746,9 +744,9 @@ def test_mask(self, size): np.testing.assert_allclose(x, y, atol=1e-4) @parameterized.parameters( - dict(axis=0, shape=[4, 5, 6]), - dict(axis=1, shape=[4, 5, 6]), - dict(axis=2, shape=[4, 5, 6]), + {"axis": 0, "shape": [4, 5, 6]}, + {"axis": 1, "shape": [4, 5, 6]}, + {"axis": 2, "shape": [4, 5, 6]}, ) def test_axis(self, shape, axis): preds = np.random.normal(size=shape) @@ -934,7 +932,8 @@ def setUp(self): self.ts = np.array([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]) self._rtol = 5e-3 if jax.default_backend() != 'cpu' else 1e-6 - logit = lambda x: jnp.log(x / (1.0 - x)) + def logit(x): + return jnp.log(x / (1.0 - x)) self.large_ys = logit(jnp.array([0.9, 0.98, 0.3, 0.99])) self.small_ys = logit(jnp.array([0.1, 0.02, 0.09, 0.15])) self.ones_ts = jnp.array([1.0, 1.0, 1.0, 1.0]) diff --git a/optax/losses/_ranking_test.py b/optax/losses/_ranking_test.py index 1765b958a..2368c79cc 100644 --- a/optax/losses/_ranking_test.py +++ b/optax/losses/_ranking_test.py @@ -28,8 +28,10 @@ # Export symbols from math for conciser test value definitions. exp = math.exp log = math.log -logloss = lambda x: log(1.0 + exp(-x)) -sigmoid = lambda x: 1.0 / (1.0 + exp(-x)) +def logloss(x): + return log(1.0 + exp(-x)) +def sigmoid(x): + return 1.0 / (1.0 + exp(-x)) class RankingLossesTest(parameterized.TestCase): diff --git a/optax/losses/_regression_test.py b/optax/losses/_regression_test.py index 08cf6b8a0..4bd8b0415 100644 --- a/optax/losses/_regression_test.py +++ b/optax/losses/_regression_test.py @@ -185,7 +185,7 @@ def test_batched_similarity(self): atol=1e-4, ) - @parameterized.parameters(dict(size=5), dict(size=10)) + @parameterized.parameters({"size": 5}, {"size": 10}) def test_mask_distance(self, size): preds = np.random.normal(size=size) targets = np.random.normal(size=size) @@ -194,7 +194,7 @@ def test_mask_distance(self, size): y = _regression.cosine_distance(preds, targets, where=mask) np.testing.assert_allclose(x, y, atol=1e-4) - @parameterized.parameters(dict(size=5), dict(size=10)) + @parameterized.parameters({"size": 5}, {"size": 10}) def test_mask_similarity(self, size): preds = np.random.normal(size=size) targets = np.random.normal(size=size) @@ -204,9 +204,9 @@ def test_mask_similarity(self, size): np.testing.assert_allclose(x, y, atol=1e-4) @parameterized.parameters( - dict(axis=0, shape=[4, 5, 6]), - dict(axis=1, shape=[4, 5, 6]), - dict(axis=2, shape=[4, 5, 6]), + {"axis": 0, "shape": [4, 5, 6]}, + {"axis": 1, "shape": [4, 5, 6]}, + {"axis": 2, "shape": [4, 5, 6]}, ) def test_axis(self, shape, axis): preds = np.random.normal(size=shape) diff --git a/optax/monte_carlo/control_variates_test.py b/optax/monte_carlo/control_variates_test.py index cb34ab5ce..e59aec84b 100644 --- a/optax/monte_carlo/control_variates_test.py +++ b/optax/monte_carlo/control_variates_test.py @@ -74,7 +74,7 @@ class DeltaControlVariateTest(chex.TestCase): @chex.all_variants @parameterized.parameters([(1.0, 0.5)]) - def testQuadraticFunction(self, effective_mean, effective_log_scale): + def test_quadratic_function(self, effective_mean, effective_log_scale): data_dims = 20 num_samples = 10**6 rng = jax.random.PRNGKey(1) @@ -87,7 +87,8 @@ def testQuadraticFunction(self, effective_mean, effective_log_scale): dist = utils.multi_normal(*params) dist_samples = dist.sample((num_samples,), rng) - function = lambda x: jnp.sum(x**2) + def function(x): + return jnp.sum(x**2) cv, expected_cv, _ = control_variates.control_delta_method(function) avg_cv = jnp.mean(_map_variant(self.variant)(cv, params, dist_samples)) @@ -100,7 +101,7 @@ def testQuadraticFunction(self, effective_mean, effective_log_scale): @chex.all_variants @parameterized.parameters([(1.0, 1.0)]) - def testPolynomialFunction(self, effective_mean, effective_log_scale): + def test_polynomial_function(self, effective_mean, effective_log_scale): data_dims = 10 num_samples = 10**3 @@ -113,7 +114,8 @@ def testPolynomialFunction(self, effective_mean, effective_log_scale): dist = utils.multi_normal(*params) rng = jax.random.PRNGKey(1) dist_samples = dist.sample((num_samples,), rng) - function = lambda x: jnp.sum(x**5) + def function(x): + return jnp.sum(x**5) cv, expected_cv, _ = control_variates.control_delta_method(function) avg_cv = jnp.mean(_map_variant(self.variant)(cv, params, dist_samples)) @@ -123,7 +125,7 @@ def testPolynomialFunction(self, effective_mean, effective_log_scale): _assert_equal(avg_cv, expected_cv(params, None), rtol=1e-1, atol=1e-3) @chex.all_variants - def testNonPolynomialFunction(self): + def test_non_polynomial_function(self): data_dims = 10 num_samples = 10**3 @@ -134,7 +136,8 @@ def testNonPolynomialFunction(self): rng = jax.random.PRNGKey(1) dist = utils.multi_normal(*params) dist_samples = dist.sample((num_samples,), rng) - function = lambda x: jnp.sum(jnp.log(x**2)) + def function(x): + return jnp.sum(jnp.log(x**2)) cv, expected_cv, _ = control_variates.control_delta_method(function) avg_cv = jnp.mean(_map_variant(self.variant)(cv, params, dist_samples)) @@ -154,8 +157,8 @@ class MovingAverageBaselineTest(chex.TestCase): @chex.all_variants @parameterized.parameters([(1.0, 0.5, 0.9), (1.0, 0.5, 0.99)]) - def testLinearFunction(self, effective_mean, effective_log_scale, decay): - weights = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32) + def test_linear_function(self, effective_mean, effective_log_scale, decay): + weights = jnp.array([1., 2., 3.], dtype=jnp.float32) num_samples = 10**4 data_dims = len(weights) @@ -165,7 +168,8 @@ def testLinearFunction(self, effective_mean, effective_log_scale, decay): ) params = [mean, log_scale] - function = lambda x: jnp.sum(weights * x) + def function(x): + return jnp.sum(weights * x) rng = jax.random.PRNGKey(1) dist = utils.multi_normal(*params) @@ -204,10 +208,10 @@ def testLinearFunction(self, effective_mean, effective_log_scale, decay): @chex.all_variants @parameterized.parameters([(1.0, 0.5, 0.9), (1.0, 0.5, 0.99)]) - def testLinearFunctionWithHeuristic( + def test_linear_function_with_heuristic( self, effective_mean, effective_log_scale, decay ): - weights = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32) + weights = jnp.array([1., 2., 3.], dtype=jnp.float32) num_samples = 10**5 data_dims = len(weights) @@ -217,7 +221,8 @@ def testLinearFunctionWithHeuristic( ) params = [mean, log_scale] - function = lambda x: jnp.sum(weights * x) + def function(x): + return jnp.sum(weights * x) rng = jax.random.PRNGKey(1) dist = utils.multi_normal(*params) @@ -259,10 +264,10 @@ def testLinearFunctionWithHeuristic( ) @parameterized.parameters([(1.0, 0.5, 0.9), (1.0, 0.5, 0.99)]) - def testLinearFunctionZeroDebias( + def test_linear_function_zero_debias( self, effective_mean, effective_log_scale, decay ): - weights = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32) + weights = jnp.array([1., 2., 3.], dtype=jnp.float32) num_samples = 10**5 data_dims = len(weights) @@ -272,7 +277,8 @@ def testLinearFunctionZeroDebias( ) params = [mean, log_scale] - function = lambda x: jnp.sum(weights * x) + def function(x): + return jnp.sum(weights * x) rng = jax.random.PRNGKey(1) dist = utils.multi_normal(*params) @@ -306,35 +312,22 @@ class DeltaMethodAnalyticalExpectedGrads(chex.TestCase): @chex.all_variants @parameterized.named_parameters( - chex.params_product( - [ - ( - '_score_function_jacobians', - 1.0, - 1.0, - sge.score_function_jacobians, - ), - ('_pathwise_jacobians', 1.0, 1.0, sge.pathwise_jacobians), - ( - '_measure_valued_jacobians', - 1.0, - 1.0, - sge.measure_valued_jacobians, - ), - ], - [ - ('estimate_cv_coeffs', True), - ('no_estimate_cv_coeffs', False), - ], - named=True, - ) + chex.params_product([ + ('_score_function_jacobians', 1.0, 1.0, sge.score_function_jacobians), + ('_pathwise_jacobians', 1.0, 1.0, sge.pathwise_jacobians), + ('_measure_valued_jacobians', 1.0, 1.0, sge.measure_valued_jacobians), + ], [ + ('estimate_cv_coeffs', True), + ('no_estimate_cv_coeffs', False), + ], + named=True) ) - def testQuadraticFunction( - self, - effective_mean, - effective_log_scale, - grad_estimator, - estimate_cv_coeffs, + def test_quadratic_function( + self, + effective_mean, + effective_log_scale, + grad_estimator, + estimate_cv_coeffs, ): data_dims = 3 num_samples = 10**3 @@ -345,7 +338,8 @@ def testQuadraticFunction( ) params = [mean, log_scale] - function = lambda x: jnp.sum(x**2) + def function(x): + return jnp.sum(x**2) rng = jax.random.PRNGKey(1) jacobians = _cv_jac_variant(self.variant)( @@ -389,30 +383,17 @@ def testQuadraticFunction( @chex.all_variants @parameterized.named_parameters( - chex.params_product( - [ - ( - '_score_function_jacobians', - 1.0, - 1.0, - sge.score_function_jacobians, - ), - ('_pathwise_jacobians', 1.0, 1.0, sge.pathwise_jacobians), - ( - '_measure_valued_jacobians', - 1.0, - 1.0, - sge.measure_valued_jacobians, - ), - ], - [ - ('estimate_cv_coeffs', True), - ('no_estimate_cv_coeffs', False), - ], - named=True, - ) + chex.params_product([ + ('_score_function_jacobians', 1.0, 1.0, sge.score_function_jacobians), + ('_pathwise_jacobians', 1.0, 1.0, sge.pathwise_jacobians), + ('_measure_valued_jacobians', 1.0, 1.0, sge.measure_valued_jacobians), + ], [ + ('estimate_cv_coeffs', True), + ('no_estimate_cv_coeffs', False), + ], + named=True), ) - def testCubicFunction( + def test_cubic_function( self, effective_mean, effective_log_scale, @@ -428,7 +409,8 @@ def testCubicFunction( ) params = [mean, log_scale] - function = lambda x: jnp.sum(x**3) + def function(x): + return jnp.sum(x**3) rng = jax.random.PRNGKey(1) jacobians = _cv_jac_variant(self.variant)( @@ -476,30 +458,17 @@ def testCubicFunction( @chex.all_variants @parameterized.named_parameters( - chex.params_product( - [ - ( - '_score_function_jacobians', - 1.0, - 1.0, - sge.score_function_jacobians, - ), - ('_pathwise_jacobians', 1.0, 1.0, sge.pathwise_jacobians), - ( - '_measure_valued_jacobians', - 1.0, - 1.0, - sge.measure_valued_jacobians, - ), - ], - [ - ('estimate_cv_coeffs', True), - ('no_estimate_cv_coeffs', False), - ], - named=True, - ) + chex.params_product([ + ('_score_function_jacobians', 1.0, 1.0, sge.score_function_jacobians), + ('_pathwise_jacobians', 1.0, 1.0, sge.pathwise_jacobians), + ('_measure_valued_jacobians', 1.0, 1.0, sge.measure_valued_jacobians), + ], [ + ('estimate_cv_coeffs', True), + ('no_estimate_cv_coeffs', False), + ], + named=True), ) - def testForthPowerFunction( + def test_forth_power_function( self, effective_mean, effective_log_scale, @@ -515,7 +484,8 @@ def testForthPowerFunction( ) params = [mean, log_scale] - function = lambda x: jnp.sum(x**4) + def function(x): + return jnp.sum(x**4) rng = jax.random.PRNGKey(1) jacobians = _cv_jac_variant(self.variant)( @@ -573,34 +543,27 @@ class ConsistencyWithStandardEstimators(chex.TestCase): @chex.all_variants @parameterized.named_parameters( - chex.params_product( - [ - ('_score_function_jacobians', 1, 1, sge.score_function_jacobians), - ('_pathwise_jacobians', 1, 1, sge.pathwise_jacobians), - ('_measure_valued_jacobians', 1, 1, sge.measure_valued_jacobians), - ], - [ - ( - 'control_delta_method', - 10**5, - control_variates.control_delta_method, - ), - ( - 'moving_avg_baseline', - 10**6, - control_variates.moving_avg_baseline, - ), - ], - named=True, - ) + chex.params_product([ + ('_score_function_jacobians', 1, 1, sge.score_function_jacobians), + ('_pathwise_jacobians', 1, 1, sge.pathwise_jacobians), + ('_measure_valued_jacobians', 1, 1, sge.measure_valued_jacobians), + ], [ + ( + 'control_delta_method', + 10**5, + control_variates.control_delta_method + ), + ('moving_avg_baseline', 10**6, control_variates.moving_avg_baseline), + ], + named=True), ) - def testWeightedLinearFunction( - self, - effective_mean, - effective_log_scale, - grad_estimator, - num_samples, - control_variate_from_function, + def test_weighted_linear_function( + self, + effective_mean, + effective_log_scale, + grad_estimator, + num_samples, + control_variate_from_function, ): """Check that the gradients are consistent between estimators.""" weights = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32) @@ -612,7 +575,8 @@ def testWeightedLinearFunction( ) params = [mean, log_scale] - function = lambda x: jnp.sum(weights * x) + def function(x): + return jnp.sum(weights * x) rng = jax.random.PRNGKey(1) cv_rng, ge_rng = jax.random.split(rng) @@ -659,32 +623,19 @@ def testWeightedLinearFunction( @chex.all_variants @parameterized.named_parameters( - chex.params_product( - [ - ( - '_score_function_jacobians', - 1, - 1, - sge.score_function_jacobians, - 10**5, - ), - ('_pathwise_jacobians', 1, 1, sge.pathwise_jacobians, 10**5), - ( - '_measure_valued_jacobians', - 1, - 1, - sge.measure_valued_jacobians, - 10**5, - ), - ], - [ - ('control_delta_method', control_variates.control_delta_method), - ('moving_avg_baseline', control_variates.moving_avg_baseline), - ], - named=True, - ) + chex.params_product([ + ('_score_function_jacobians', 1, 1, sge.score_function_jacobians, + 10**5), + ('_pathwise_jacobians', 1, 1, sge.pathwise_jacobians, 10**5), + ('_measure_valued_jacobians', 1, 1, sge.measure_valued_jacobians, + 10**5), + ], [ + ('control_delta_method', control_variates.control_delta_method), + ('moving_avg_baseline', control_variates.moving_avg_baseline), + ], + named=True), ) - def testNonPolynomialFunction( + def test_non_polynomial_function( self, effective_mean, effective_log_scale, @@ -701,7 +652,8 @@ def testNonPolynomialFunction( ) params = [mean, log_scale] - function = lambda x: jnp.log(jnp.sum(x**2)) + def function(x): + return jnp.log(jnp.sum(x**2)) rng = jax.random.PRNGKey(1) cv_rng, ge_rng = jax.random.split(rng) diff --git a/optax/monte_carlo/stochastic_gradient_estimators.py b/optax/monte_carlo/stochastic_gradient_estimators.py index 541b697a3..18dd543bb 100644 --- a/optax/monte_carlo/stochastic_gradient_estimators.py +++ b/optax/monte_carlo/stochastic_gradient_estimators.py @@ -84,7 +84,8 @@ def score_function_jacobians( def surrogate(params): dist = dist_builder(*params) - one_sample_surrogate_fn = lambda x: function(x) * dist.log_prob(x) + def one_sample_surrogate_fn(x): + return function(x) * dist.log_prob(x) samples = jax.lax.stop_gradient(dist.sample((num_samples,), seed=rng)) # We vmap the function application over samples - this ensures that the # function we use does not have to be vectorized itself. diff --git a/optax/monte_carlo/stochastic_gradient_estimators_test.py b/optax/monte_carlo/stochastic_gradient_estimators_test.py index b71dde808..0fb756731 100644 --- a/optax/monte_carlo/stochastic_gradient_estimators_test.py +++ b/optax/monte_carlo/stochastic_gradient_estimators_test.py @@ -76,21 +76,18 @@ class GradientEstimatorsTest(chex.TestCase): @chex.all_variants @parameterized.named_parameters( - chex.params_product( - [ - ('_score_function_jacobians', sge.score_function_jacobians), - ('_pathwise_jacobians', sge.pathwise_jacobians), - ('_measure_valued_jacobians', sge.measure_valued_jacobians), - ], - [ - ('0.1', 0.1), - ('0.5', 0.5), - ('0.9', 0.9), - ], - named=True, - ) + chex.params_product([ + ('_score_function_jacobians', sge.score_function_jacobians), + ('_pathwise_jacobians', sge.pathwise_jacobians), + ('_measure_valued_jacobians', sge.measure_valued_jacobians), + ], [ + ('0.1', 0.1), + ('0.5', 0.5), + ('0.9', 0.9), + ], + named=True), ) - def testConstantFunction(self, estimator, constant): + def test_constant_function(self, estimator, constant): data_dims = 3 num_samples = _estimator_to_num_samples[estimator] @@ -125,21 +122,18 @@ def testConstantFunction(self, estimator, constant): @chex.all_variants @parameterized.named_parameters( - chex.params_product( - [ - ('_score_function_jacobians', sge.score_function_jacobians), - ('_pathwise_jacobians', sge.pathwise_jacobians), - ('_measure_valued_jacobians', sge.measure_valued_jacobians), - ], - [ - ('0.5_-1.', 0.5, -1.0), - ('0.7_0.0)', 0.7, 0.0), - ('0.8_0.1', 0.8, 0.1), - ], - named=True, - ) + chex.params_product([ + ('_score_function_jacobians', sge.score_function_jacobians), + ('_pathwise_jacobians', sge.pathwise_jacobians), + ('_measure_valued_jacobians', sge.measure_valued_jacobians), + ], [ + ('0.5_-1.', 0.5, -1.), + ('0.7_0.0)', 0.7, 0.0), + ('0.8_0.1', 0.8, 0.1), + ], + named=True), ) - def testLinearFunction(self, estimator, effective_mean, effective_log_scale): + def test_linear_function(self, estimator, effective_mean, effective_log_scale): data_dims = 3 num_samples = _estimator_to_num_samples[estimator] rng = jax.random.PRNGKey(1) @@ -166,19 +160,16 @@ def testLinearFunction(self, estimator, effective_mean, effective_log_scale): @chex.all_variants @parameterized.named_parameters( - chex.params_product( - [ - ('_score_function_jacobians', sge.score_function_jacobians), - ('_pathwise_jacobians', sge.pathwise_jacobians), - ('_measure_valued_jacobians', sge.measure_valued_jacobians), - ], - [ - ('1.0_0.3', 1.0, 0.3), - ], - named=True, - ) + chex.params_product([ + ('_score_function_jacobians', sge.score_function_jacobians), + ('_pathwise_jacobians', sge.pathwise_jacobians), + ('_measure_valued_jacobians', sge.measure_valued_jacobians), + ], [ + ('1.0_0.3', 1.0, 0.3), + ], + named=True), ) - def testQuadraticFunction( + def test_quadratic_function( self, estimator, effective_mean, effective_log_scale ): data_dims = 3 @@ -213,21 +204,18 @@ def testQuadraticFunction( @chex.all_variants @parameterized.named_parameters( - chex.params_product( - [ - ('_score_function_jacobians', sge.score_function_jacobians), - ('_pathwise_jacobians', sge.pathwise_jacobians), - ('_measure_valued_jacobians', sge.measure_valued_jacobians), - ], - [ - ('case_1', [1.0, 2.0, 3.0], [-1.0, 0.3, -2.0], [1.0, 1.0, 1.0]), - ('case_2', [1.0, 2.0, 3.0], [-1.0, 0.3, -2.0], [4.0, 2.0, 3.0]), - ('case_3', [1.0, 2.0, 3.0], [0.1, 0.2, 0.1], [10.0, 5.0, 1.0]), - ], - named=True, - ) + chex.params_product([ + ('_score_function_jacobians', sge.score_function_jacobians), + ('_pathwise_jacobians', sge.pathwise_jacobians), + ('_measure_valued_jacobians', sge.measure_valued_jacobians), + ], [ + ('case_1', [1.0, 2.0, 3.], [-1., 0.3, -2.], [1., 1., 1.]), + ('case_2', [1.0, 2.0, 3.], [-1., 0.3, -2.], [4., 2., 3.]), + ('case_3', [1.0, 2.0, 3.], [0.1, 0.2, 0.1], [10., 5., 1.]), + ], + named=True), ) - def testWeightedLinear( + def test_weighted_linear( self, estimator, effective_mean, effective_log_scale, weights ): num_samples = _weighted_estimator_to_num_samples[estimator] @@ -239,7 +227,8 @@ def testWeightedLinear( data_dims = len(effective_mean) - function = lambda x: jnp.sum(x * weights) + def function(x): + return jnp.sum(x * weights) jacobians = _estimator_variant(self.variant, estimator)( function, [mean, log_scale], utils.multi_normal, rng, num_samples ) @@ -260,21 +249,17 @@ def testWeightedLinear( @chex.all_variants @parameterized.named_parameters( - chex.params_product( - [ - ('_score_function_jacobians', sge.score_function_jacobians), - ('_pathwise_jacobians', sge.pathwise_jacobians), - ('_measure_valued_jacobians', sge.measure_valued_jacobians), - ], - [ - ('case_1', [1.0, 2.0, 3.0], [-1.0, 0.3, -2.0], [1.0, 1.0, 1.0]), - ('case_2', [1.0, 2.0, 3.0], [-1.0, 0.3, -2.0], [4.0, 2.0, 3.0]), - ('case_3', [1.0, 2.0, 3.0], [0.1, 0.2, 0.1], [3.0, 5.0, 1.0]), - ], - named=True, - ) - ) - def testWeightedQuadratic( + chex.params_product([ + ('_score_function_jacobians', sge.score_function_jacobians), + ('_pathwise_jacobians', sge.pathwise_jacobians), + ('_measure_valued_jacobians', sge.measure_valued_jacobians), + ], [ + ('case_1', [1.0, 2.0, 3.], [-1., 0.3, -2.], [1., 1., 1.]), + ('case_2', [1.0, 2.0, 3.], [-1., 0.3, -2.], [4., 2., 3.]), + ('case_3', [1.0, 2.0, 3.], [0.1, 0.2, 0.1], [3., 5., 1.]), + ], + named=True)) + def test_weighted_quadratic( self, estimator, effective_mean, effective_log_scale, weights ): num_samples = _weighted_estimator_to_num_samples[estimator] @@ -286,7 +271,8 @@ def testWeightedQuadratic( data_dims = len(effective_mean) - function = lambda x: jnp.sum(x * weights) ** 2 + def function(x): + return jnp.sum(x * weights) ** 2 jacobians = _estimator_variant(self.variant, estimator)( function, [mean, log_scale], utils.multi_normal, rng, num_samples ) @@ -333,10 +319,8 @@ def testWeightedQuadratic( ('coupling', True), ('nocoupling', False), ], - named=True, - ) - ) - def testNonPolynomialFunctionConsistencyWithPathwise( + named=True)) + def test_non_polynomial_function_consistency_with_pathwise( self, effective_mean, effective_log_scale, function, coupling ): num_samples = 10**5 @@ -401,11 +385,12 @@ class MeasuredValuedEstimatorsTest(chex.TestCase): @chex.all_variants @parameterized.parameters([True, False]) - def testRaisesErrorForNonGaussian(self, coupling): + def test_raises_error_for_non_gaussian(self, coupling): num_samples = 10**5 rng = jax.random.PRNGKey(1) - function = lambda x: jnp.sum(x) ** 2 + def function(x): + return jnp.sum(x) ** 2 mean = jnp.array(0, dtype=jnp.float32) log_scale = jnp.array(0.0, dtype=jnp.float32) diff --git a/optax/perturbations/__init__.py b/optax/perturbations/__init__.py index 521eefa2d..e64130af2 100644 --- a/optax/perturbations/__init__.py +++ b/optax/perturbations/__init__.py @@ -14,9 +14,6 @@ # ============================================================================== """The perturbations sub-package.""" -# pylint: disable=g-importing-member - from optax.perturbations._make_pert import Gumbel from optax.perturbations._make_pert import make_perturbed_fun from optax.perturbations._make_pert import Normal - diff --git a/optax/perturbations/_make_pert.py b/optax/perturbations/_make_pert.py index ae711f5d2..0978603a5 100644 --- a/optax/perturbations/_make_pert.py +++ b/optax/perturbations/_make_pert.py @@ -77,7 +77,8 @@ def _tree_mean_across(trees: Sequence[chex.ArrayTree]) -> chex.ArrayTree: ... ) {'first': [3, 4], 'last': 5} """ - mean_fun = lambda x: sum(x) / len(trees) + def mean_fun(x): + return sum(x) / len(trees) return jtu.tree_map(lambda *leaves: mean_fun(leaves), *trees) @@ -87,7 +88,8 @@ def _tree_vmap( ) -> chex.ArrayTree: """Applies a function to a list of trees, akin to a vmap.""" tree_def_in = jtu.tree_structure(trees[0]) - has_in_structure = lambda x: jtu.tree_structure(x) == tree_def_in + def has_in_structure(x): + return jtu.tree_structure(x) == tree_def_in return jtu.tree_map(fun, trees, is_leaf=has_in_structure) @@ -172,11 +174,13 @@ def fun_perturb_jvp( The jacobian vector product. """ outputs_pert, samples = _compute_residuals(inputs, rng) - array_sum_log_prob_func = lambda x: jnp.sum(noise.log_prob(x)) + def array_sum_log_prob_func(x): + return jnp.sum(noise.log_prob(x)) array_grad_log_prob_func = jax.grad(array_sum_log_prob_func) # computes [grad log_prob(Z_1), ... , grad log_prob(Z_num_samples)] tree_sum_log_probs = jtu.tree_map(array_grad_log_prob_func, samples) - fun_dot_prod = lambda z: jax.tree_util.tree_map(jnp.dot, z, tangent) + def fun_dot_prod(z): + return jax.tree_util.tree_map(jnp.dot, z, tangent) list_tree_dot_prods = _tree_vmap(fun_dot_prod, tree_sum_log_probs) # computes [, .. , ] list_dot_prods = _tree_vmap( diff --git a/optax/perturbations/_make_pert_test.py b/optax/perturbations/_make_pert_test.py index 394776b32..911cc650a 100644 --- a/optax/perturbations/_make_pert_test.py +++ b/optax/perturbations/_make_pert_test.py @@ -33,8 +33,8 @@ def one_hot_argmax(inputs: jnp.ndarray) -> jnp.ndarray: flat_one_hot = jax.nn.one_hot(jnp.argmax(inputs_flat), inputs_flat.shape[0]) return jnp.reshape(flat_one_hot, inputs.shape) - -argmax_tree = lambda x: jtu.tree_map(one_hot_argmax, x) +def argmax_tree(x): + return jtu.tree_map(one_hot_argmax, x) class MakePertTest(absltest.TestCase): @@ -67,12 +67,10 @@ def setUp(self): example_tree = [] for i in range(2): - example_tree.append( - dict( - weights=jnp.ones(weight_shapes[i]), - biases=jnp.ones(biases_shapes[i]), - ) - ) + example_tree.append({ + "weights": jnp.ones(weight_shapes[i]), + "biases": jnp.ones(biases_shapes[i]), + }) self.example_tree = example_tree self.element = jnp.array([1.0, 2.0, 3.0, 4.0]) @@ -92,7 +90,8 @@ def test_pert_close_array(self): argmax_tree, self.num_samples, self.sigma ) expected = pert_argmax_fun(self.array_a_jax, self.rng_jax) - softmax_fun = lambda x: jax.nn.softmax(x / self.sigma) + def softmax_fun(x): + return jax.nn.softmax(x / self.sigma) got = jtu.tree_map(softmax_fun, self.array_a_jax) np.testing.assert_array_almost_equal(expected, got, decimal=1) pert_argmax_fun_small = _make_pert.make_perturbed_fun( @@ -134,7 +133,7 @@ def apply_element_tree(tree): apply_tree = jtu.Partial(apply_both, tree) leaves, _ = jtu.tree_flatten(tree) return_tree = jtu.tree_map(apply_tree, self.element_tree) - return_tree.append(sum([jnp.sum(leaf) for leaf in leaves])) + return_tree.append(sum(jnp.sum(leaf) for leaf in leaves)) return return_tree tree_out = apply_element_tree(self.example_tree) diff --git a/optax/projections/_projections_test.py b/optax/projections/_projections_test.py index 0fa8c0447..2c79a08fc 100644 --- a/optax/projections/_projections_test.py +++ b/optax/projections/_projections_test.py @@ -39,12 +39,16 @@ def setUp(self): array_1d = jnp.array([0.5, 2.1, -3.5]) array_2d = jnp.array([[0.5, 2.1, -3.5], [1.0, 2.0, 3.0]]) tree = (array_1d, array_1d) - self.data = dict(array_1d=array_1d, array_2d=array_2d, tree=tree) - self.fns = dict( - l1=(proj.projection_l1_ball, otu.tree_l1_norm), - l2=(proj.projection_l2_ball, otu.tree_l2_norm), - linf=(proj.projection_linf_ball, otu.tree_linf_norm), - ) + self.data = { + "array_1d": array_1d, + "array_2d": array_2d, + "tree": tree, + } + self.fns = { + "l1": (proj.projection_l1_ball, otu.tree_l1_norm), + "l2": (proj.projection_l2_ball, otu.tree_l2_norm), + "linf": (proj.projection_linf_ball, otu.tree_linf_norm), + } def test_projection_non_negative(self): with self.subTest('with an array'): diff --git a/optax/schedules/_inject_test.py b/optax/schedules/_inject_test.py index 726d22de0..08d6993e5 100644 --- a/optax/schedules/_inject_test.py +++ b/optax/schedules/_inject_test.py @@ -222,7 +222,7 @@ def test_wrap_stateless_schedule(self): my_schedule(count), my_wrapped_schedule(state), atol=0.0 ) count = count + 1 - extra_args = dict(loss=jnp.ones([], dtype=jnp.float32)) + extra_args = {"loss": jnp.ones([], dtype=jnp.float32)} state = my_wrapped_schedule.update(state, **extra_args) np.testing.assert_allclose(count, state, atol=0.0) @@ -240,7 +240,7 @@ def test_inject_stateful_hyperparams(self): ) state = self.variant(tx.init)(params) - extra_args = dict(addendum=0.3 * jnp.ones((), dtype=jnp.float32)) + extra_args = {"addendum": 0.3 * jnp.ones((), dtype=jnp.float32)} _, state = self.variant(tx.update)( grads, state, params=params, **extra_args ) diff --git a/optax/schedules/_schedule_test.py b/optax/schedules/_schedule_test.py index 4dd243d02..baa46e73d 100644 --- a/optax/schedules/_schedule_test.py +++ b/optax/schedules/_schedule_test.py @@ -552,15 +552,13 @@ def test_limits(self, lr0, lr1, lr2): """Check cosine schedule decay for the entire training schedule.""" lr_kwargs = [] for step, lr in zip([2e3, 3e3, 5e3], [lr0, lr1, lr2]): - lr_kwargs += [ - dict( - decay_steps=int(step), - peak_value=lr, - init_value=0, - end_value=0.0, - warmup_steps=500, - ) - ] + lr_kwargs += [{ + "decay_steps": int(step), + "peak_value": lr, + "init_value": 0, + "end_value": 0.0, + "warmup_steps": 500, + }] schedule_fn = self.variant(_schedule.sgdr_schedule(lr_kwargs)) np.testing.assert_allclose(lr0, schedule_fn(500)) np.testing.assert_allclose(lr1, schedule_fn(2500)) diff --git a/optax/second_order/_hessian.py b/optax/second_order/_hessian.py index 943f34e1a..31c266e5a 100644 --- a/optax/second_order/_hessian.py +++ b/optax/second_order/_hessian.py @@ -52,7 +52,8 @@ def hvp( evaluated at `(params, inputs, targets)`. """ _, unravel_fn = flatten_util.ravel_pytree(params) - loss_fn = lambda p: loss(p, inputs, targets) + def loss_fn(p): + return loss(p, inputs, targets) return jax.jvp(jax.grad(loss_fn), [params], [unravel_fn(v)])[1] @@ -75,5 +76,6 @@ def hessian_diag( evaluated at `(params, inputs, targets)`. """ vs = jnp.eye(_ravel(params).size) - comp = lambda v: jnp.vdot(v, _ravel(hvp(loss, v, params, inputs, targets))) + def comp(v): + return jnp.vdot(v, _ravel(hvp(loss, v, params, inputs, targets))) return jax.vmap(comp)(vs) diff --git a/optax/transforms/_accumulation.py b/optax/transforms/_accumulation.py index 07c70ef0c..72c2982bd 100644 --- a/optax/transforms/_accumulation.py +++ b/optax/transforms/_accumulation.py @@ -65,7 +65,8 @@ def init_fn(params): def update_fn(updates, state, params=None): del params - f = lambda g, t: g + decay * t + def f(g, t): + return g + decay * t new_trace = jax.tree.map(f, updates, state.trace) updates = jax.tree.map(f, updates, new_trace) if nesterov else new_trace new_trace = otu.tree_cast(new_trace, accumulator_dtype) @@ -176,9 +177,10 @@ def skip_not_finite( ] num_not_finite = jnp.sum(jnp.array(all_is_finite)) should_skip = num_not_finite > 0 - return should_skip, dict( - should_skip=should_skip, num_not_finite=num_not_finite - ) + return should_skip, { + "should_skip": should_skip, + "num_not_finite": num_not_finite, + } def skip_large_updates( @@ -208,7 +210,7 @@ def skip_large_updates( ) # This will also return True if `norm_sq` is NaN. should_skip = jnp.logical_not(norm_sq < max_squared_norm) - return should_skip, dict(should_skip=should_skip, norm_squared=norm_sq) + return should_skip, {"should_skip": should_skip, "norm_squared": norm_sq} class MultiStepsState(NamedTuple): diff --git a/optax/transforms/_accumulation_test.py b/optax/transforms/_accumulation_test.py index a7fc17180..93cceb7a9 100644 --- a/optax/transforms/_accumulation_test.py +++ b/optax/transforms/_accumulation_test.py @@ -215,9 +215,9 @@ def test_multi_steps_every_k_schedule(self): alias.sgd(1e-4), lambda grad_step: jnp.where(grad_step < 2, 1, 3) ) opt_init, opt_update = ms_opt.gradient_transformation() - params = dict(a=jnp.zeros([])) + params = {"a": jnp.zeros([])} opt_state = opt_init(params) - grad = dict(a=jnp.zeros([])) + grad = {"a": jnp.zeros([])} self.assertFalse(ms_opt.has_updated(opt_state)) # First two steps have 1 mini-step per update. for _ in range(2): @@ -239,9 +239,9 @@ def test_multi_steps_zero_nans(self): every_k_schedule=2, ) opt_init, opt_update = ms_opt.gradient_transformation() - params = dict(a=jnp.zeros([])) + params = {"a": jnp.zeros([])} opt_state = opt_init(params) - grad = dict(a=jnp.zeros([])) + grad = {"a": jnp.zeros([])} opt_update(grad, opt_state, params) def test_multi_steps_computes_mean(self): @@ -250,9 +250,9 @@ def test_multi_steps_computes_mean(self): transform.scale(1.0), k_steps, use_grad_mean=True ) opt_init, opt_update = ms_opt.gradient_transformation() - params = dict(a=jnp.zeros([])) + params = {"a": jnp.zeros([])} opt_state = opt_init(params) - grads = [dict(a=jnp.ones([]) * i) for i in [1, 2, 3, 4]] + grads = [{"a": jnp.ones([]) * i} for i in [1, 2, 3, 4]] self.assertFalse(ms_opt.has_updated(opt_state)) # First 3 steps don't update. @@ -275,39 +275,37 @@ def test_multi_steps_skip_not_finite(self): opt_init, opt_update = ms_opt.gradient_transformation() opt_init = jax.jit(opt_init) opt_update = jax.jit(opt_update) - params = dict(a=jnp.zeros([])) + params = {"a": jnp.zeros([])} opt_state = opt_init(params) with self.subTest('test_good_updates'): - updates, opt_state = opt_update(dict(a=jnp.ones([])), opt_state, params) + updates, opt_state = opt_update({"a": jnp.ones([])}, opt_state, params) self.assertEqual(int(opt_state.mini_step), 1) params = update.apply_updates(params, updates) - updates, opt_state = opt_update(dict(a=jnp.ones([])), opt_state, params) + updates, opt_state = opt_update({"a": jnp.ones([])}, opt_state, params) self.assertEqual(int(opt_state.mini_step), 0) params = update.apply_updates(params, updates) np.testing.assert_array_equal(params['a'], jnp.negative(jnp.ones([]))) with self.subTest('test_inf_updates'): updates, opt_state = opt_update( - dict(a=jnp.array(float('inf'))), opt_state, params - ) + {"a": jnp.array(float('inf'))}, opt_state, params) self.assertEqual(int(opt_state.mini_step), 0) # No increase in mini_step params = update.apply_updates(params, updates) np.testing.assert_array_equal(params['a'], jnp.negative(jnp.ones([]))) with self.subTest('test_nan_updates'): updates, opt_state = opt_update( - dict(a=jnp.full([], float('nan'))), opt_state, params - ) + {"a": jnp.full([], float('nan'))}, opt_state, params) self.assertEqual(int(opt_state.mini_step), 0) # No increase in mini_step params = update.apply_updates(params, updates) np.testing.assert_array_equal(params['a'], jnp.negative(jnp.ones([]))) with self.subTest('test_final_good_updates'): - updates, opt_state = opt_update(dict(a=jnp.ones([])), opt_state, params) + updates, opt_state = opt_update({"a": jnp.ones([])}, opt_state, params) self.assertEqual(int(opt_state.mini_step), 1) params = update.apply_updates(params, updates) - updates, opt_state = opt_update(dict(a=jnp.ones([])), opt_state, params) + updates, opt_state = opt_update({"a": jnp.ones([])}, opt_state, params) self.assertEqual(int(opt_state.mini_step), 0) params = update.apply_updates(params, updates) np.testing.assert_array_equal( diff --git a/optax/transforms/_adding.py b/optax/transforms/_adding.py index cfc5c93fe..5fd759aab 100644 --- a/optax/transforms/_adding.py +++ b/optax/transforms/_adding.py @@ -94,7 +94,7 @@ def init_fn(params): count=jnp.zeros([], jnp.int32), rng_key=jax.random.PRNGKey(seed) ) - def update_fn(updates, state, params=None): # pylint: disable=missing-docstring + def update_fn(updates, state, params=None): del params count_inc = numerics.safe_increment(state.count) standard_deviation = jnp.sqrt(eta / count_inc**gamma) diff --git a/optax/transforms/_adding_test.py b/optax/transforms/_adding_test.py index a1cc3e924..61305d2be 100644 --- a/optax/transforms/_adding_test.py +++ b/optax/transforms/_adding_test.py @@ -35,31 +35,31 @@ def test_add_decayed_weights(self): # Define a transform that add decayed weights. # We can define a mask either as a pytree, or as a function that # returns the pytree. Below we define the pytree directly. - mask = (True, dict(a=True, b=False)) + mask = (True, {"a": True, "b": False}) tx = _adding.add_decayed_weights(0.1, mask=mask) # Define input updates and weights. updates = ( jnp.zeros((2,), dtype=jnp.float32), - dict( - a=jnp.zeros((2,), dtype=jnp.float32), - b=jnp.zeros((2,), dtype=jnp.float32), - ), + { + "a": jnp.zeros((2,), dtype=jnp.float32), + "b": jnp.zeros((2,), dtype=jnp.float32), + }, ) weights = ( jnp.ones((2,), dtype=jnp.float32), - dict( - a=jnp.ones((2,), dtype=jnp.float32), - b=jnp.ones((2,), dtype=jnp.float32), - ), + { + "a": jnp.ones((2,), dtype=jnp.float32), + "b": jnp.ones((2,), dtype=jnp.float32), + }, ) # This mask means that we will add decayed weights to the first two # terms in the input updates, but not to the last element. expected_tx_updates = ( - 0.1 * jnp.ones((2,), dtype=jnp.float32), - dict( - a=0.1 * jnp.ones((2,), dtype=jnp.float32), - b=jnp.zeros((2,), dtype=jnp.float32), - ), + 0.1*jnp.ones((2,), dtype=jnp.float32), + { + "a": 0.1*jnp.ones((2,), dtype=jnp.float32), + "b": jnp.zeros((2,), dtype=jnp.float32), + }, ) # Apply transform state = tx.init(weights) @@ -99,10 +99,10 @@ def test_add_noise_has_correct_variance_scaling(self): def test_none_argument(self): weights = ( jnp.ones((2,), dtype=jnp.float32), - dict( - a=jnp.ones((2,), dtype=jnp.float32), - b=jnp.ones((2,), dtype=jnp.float32), - ), + { + "a": jnp.ones((2,), dtype=jnp.float32), + "b": jnp.ones((2,), dtype=jnp.float32), + }, ) tf = _adding.add_decayed_weights(0.1, mask=None) tf.update(None, 0, weights) diff --git a/optax/transforms/_clipping_test.py b/optax/transforms/_clipping_test.py index 573df1511..11fba679f 100644 --- a/optax/transforms/_clipping_test.py +++ b/optax/transforms/_clipping_test.py @@ -48,7 +48,8 @@ def test_clip(self): ) def test_clip_by_block_rms(self): - rmf_fn = lambda t: jnp.sqrt(jnp.mean(t**2)) + def rmf_fn(t): + return jnp.sqrt(jnp.mean(t**2)) updates = self.per_step_updates for i in range(1, STEPS + 1): clipper = _clipping.clip_by_block_rms(1.0 / i) diff --git a/optax/transforms/_masking.py b/optax/transforms/_masking.py index bde1c3384..2e1ff4d3b 100644 --- a/optax/transforms/_masking.py +++ b/optax/transforms/_masking.py @@ -41,7 +41,7 @@ def _mask_callable( mask: Union[base.PyTree, Callable[[base.Params], base.PyTree]], ): callable_leaves = jax.tree.leaves(jax.tree.map(callable, mask)) - return (len(callable_leaves) > 0) and all(callable_leaves) # pylint:disable=g-explicit-length-test + return (len(callable_leaves) > 0) and all(callable_leaves) def masked( @@ -103,8 +103,7 @@ def _maybe_mask(pytree): jax.tree.structure(pytree) == base_structure ): return mask_pytree(pytree, mask_tree) - else: - return pytree + return pytree return {k: _maybe_mask(v) for k, v in pytree_dict.items()} diff --git a/optax/transforms/_masking_test.py b/optax/transforms/_masking_test.py index 92594c975..4e687de10 100644 --- a/optax/transforms/_masking_test.py +++ b/optax/transforms/_masking_test.py @@ -172,8 +172,9 @@ def increment_dim_1(v): ) def test_masked(self, opt_builder, use_fn): mask = {'a': True, 'b': [False, True], 'c': {'d': True, 'e': (False, True)}} - mask_arg = lambda _: mask if use_fn else mask - params = {'a': 1.0, 'b': [2.0, 3.0], 'c': {'d': 4.0, 'e': (5.0, 6.0)}} + def mask_arg(_): + return mask if use_fn else mask + params = {'a': 1., 'b': [2., 3.], 'c': {'d': 4., 'e': (5., 6.)}} params = jax.tree.map(jnp.asarray, params) input_updates = jax.tree.map(lambda x: x / 10.0, params) @@ -280,8 +281,9 @@ def test_empty(self, container): ) def test_tree_mismatch_fails(self, extra_key_in_mask, use_fn): mask = {'a': True, 'b': [False, True], 'c': {'d': True, 'e': (False, True)}} - mask_arg = lambda _: mask if use_fn else mask - params = {'a': 1.0, 'b': [2.0, 3.0], 'c': {'d': 4.0, 'e': (5.0, 6.0)}} + def mask_arg(_): + return mask if use_fn else mask + params = {'a': 1., 'b': [2., 3.], 'c': {'d': 4., 'e': (5., 6.)}} params = jax.tree.map(jnp.asarray, params) if extra_key_in_mask: @@ -296,7 +298,8 @@ def test_tree_mismatch_fails(self, extra_key_in_mask, use_fn): @chex.all_variants def test_mask_fn(self): params = {'a': jnp.ones((1, 2)), 'b': (jnp.ones((1,)), np.ones((1, 2, 3)))} - mask_fn = lambda p: jax.tree.map(lambda x: x.ndim > 1, p) + def mask_fn(p): + return jax.tree.map(lambda x: x.ndim > 1, p) init_fn, update_fn = _masking.masked( transform.add_decayed_weights(0.1), mask_fn ) @@ -324,7 +327,8 @@ def test_nested_mask(self, opt_builder): 'linear_3': {'w': jnp.zeros((2, 3)), 'b': jnp.zeros(3)}, } - outer_mask = lambda p: jax.tree.map(lambda x: x.ndim > 1, p) + def outer_mask(p): + return jax.tree.map(lambda x: x.ndim > 1, p) inner_mask = jax.tree.map(lambda _: True, params) inner_mask['linear_2'] = False diff --git a/optax/tree_utils/__init__.py b/optax/tree_utils/__init__.py index 130b9d909..6c067486c 100644 --- a/optax/tree_utils/__init__.py +++ b/optax/tree_utils/__init__.py @@ -14,8 +14,6 @@ # ============================================================================== """The tree_utils sub-package.""" -# pylint: disable=g-importing-member - from optax.tree_utils._casting import tree_cast from optax.tree_utils._casting import tree_dtype from optax.tree_utils._random import tree_random_like diff --git a/optax/tree_utils/_casting.py b/optax/tree_utils/_casting.py index 99cb07e28..d34e2be04 100644 --- a/optax/tree_utils/_casting.py +++ b/optax/tree_utils/_casting.py @@ -44,8 +44,7 @@ def tree_cast( """ if dtype is not None: return jax.tree.map(lambda t: t.astype(dtype), tree) - else: - return tree + return tree def tree_dtype( @@ -118,26 +117,25 @@ def tree_dtype( dtype = jnp.asarray(leaves[0]).dtype _tree_assert_all_dtypes_equal(tree, dtype) return dtype - elif mixed_dtype_handler == 'promote': + if mixed_dtype_handler == 'promote': promoted_dtype = functools.reduce( jnp.promote_types, [jnp.asarray(x).dtype for x in leaves] ) return promoted_dtype - elif mixed_dtype_handler == 'highest': + if mixed_dtype_handler == 'highest': highest_dtype = functools.reduce( _higher_dtype, [jnp.asarray(x).dtype for x in leaves] ) return highest_dtype - elif mixed_dtype_handler == 'lowest': + if mixed_dtype_handler == 'lowest': lowest_dtype = functools.reduce( _lower_dtype, [jnp.asarray(x).dtype for x in leaves] ) return lowest_dtype - else: - raise ValueError( - f'Invalid value for {mixed_dtype_handler=}, possible values are: None,' - ' "promote", "highest", "lowest".' - ) + raise ValueError( + f'Invalid value for {mixed_dtype_handler=}, possible values are: None,' + ' "promote", "highest", "lowest".' + ) def _tree_assert_all_dtypes_equal( @@ -158,6 +156,7 @@ def _assert_dtypes_equal(path, x): if x_dtype != dtype: err_msg = f'Expected {dtype=} for {path} but got {x_dtype}.' return err_msg + return None err_msgs = jax.tree.leaves( jax.tree_util.tree_map_with_path(_assert_dtypes_equal, tree) @@ -184,13 +183,12 @@ def _lower_dtype( """ if jnp.promote_types(dtype1, dtype2) == dtype1: return dtype2 - elif jnp.promote_types(dtype1, dtype2) == dtype2: + if jnp.promote_types(dtype1, dtype2) == dtype2: return dtype1 - else: - raise ValueError( - f'Cannot compare dtype of {dtype1=} and {dtype2=}.' - f' Neither {dtype1} nor {dtype2} can be promoted to the other.' - ) + raise ValueError( + f'Cannot compare dtype of {dtype1=} and {dtype2=}.' + f' Neither {dtype1} nor {dtype2} can be promoted to the other.' + ) def _higher_dtype( @@ -210,5 +208,4 @@ def _higher_dtype( """ if _lower_dtype(dtype1, dtype2) == dtype1: return dtype2 - else: - return dtype1 + return dtype1 diff --git a/optax/tree_utils/_casting_test.py b/optax/tree_utils/_casting_test.py index 1ca145af6..96358d444 100644 --- a/optax/tree_utils/_casting_test.py +++ b/optax/tree_utils/_casting_test.py @@ -89,10 +89,10 @@ def test_tree_dtype(self): self.assertRaises(ValueError, otu.tree_dtype, tree, 'highest') @parameterized.named_parameters( - dict(testcase_name='empty_dict', tree={}), - dict(testcase_name='empty_list', tree=[]), - dict(testcase_name='empty_tuple', tree=()), - dict(testcase_name='empty_none', tree=None), + {"testcase_name": 'empty_dict', "tree": {}}, + {"testcase_name": 'empty_list', "tree": []}, + {"testcase_name": 'empty_tuple', "tree": ()}, + {"testcase_name": 'empty_none', "tree": None}, ) def test_tree_dtype_utilities_with_empty_trees(self, tree): """Test tree data type utilities on empty trees.""" diff --git a/optax/tree_utils/_random_test.py b/optax/tree_utils/_random_test.py index 8995c5329..fa7bd0d74 100644 --- a/optax/tree_utils/_random_test.py +++ b/optax/tree_utils/_random_test.py @@ -28,11 +28,11 @@ # We consider samplers with varying input dtypes, we do not test all possible # samplers from `jax.random`. _SAMPLER_DTYPES = ( - dict(sampler=jrd.normal, dtype=None), - dict(sampler=jrd.normal, dtype='bfloat16'), - dict(sampler=jrd.normal, dtype='float32'), - dict(sampler=jrd.rademacher, dtype='int32'), - dict(sampler=jrd.bits, dtype='uint32'), + {"sampler": jrd.normal, "dtype": None}, + {"sampler": jrd.normal, "dtype": 'bfloat16'}, + {"sampler": jrd.normal, "dtype": 'float32'}, + {"sampler": jrd.rademacher, "dtype": 'int32'}, + {"sampler": jrd.bits, "dtype": 'uint32'}, ) @@ -45,6 +45,7 @@ def get_variable(type_var: str): if type_var == 'pytree': pytree = {'k1': 1.0, 'k2': (2.0, 3.0), 'k3': jnp.asarray([4.0, 5.0])} return jax.tree.map(jnp.asarray, pytree) + raise ValueError(f'invalid type_var {type_var}') class RandomTest(chex.TestCase): diff --git a/optax/tree_utils/_state_utils.py b/optax/tree_utils/_state_utils.py index 98528945c..de1c22f43 100644 --- a/optax/tree_utils/_state_utils.py +++ b/optax/tree_utils/_state_utils.py @@ -152,10 +152,9 @@ def tree_map_params( def map_params(maybe_placeholder_value, value): if isinstance(maybe_placeholder_value, _ParamsPlaceholder): return jax.tree.map(f, value, *rest, is_leaf=is_leaf) - elif transform_non_params is not None: + if transform_non_params is not None: return transform_non_params(value) - else: - return value + return value return jax.tree.map( map_params, @@ -377,10 +376,9 @@ def tree_get( ) if len(found_values_with_path) > 1: raise KeyError(f"Found multiple values for '{key}' in {tree}.") - elif not found_values_with_path: + if not found_values_with_path: return default - else: - return found_values_with_path[0][1] + return found_values_with_path[0][1] def tree_set( @@ -472,8 +470,7 @@ def tree_set( f"Found no values matching '{key}' given the filtering operation in" f" {tree}" ) - else: - raise KeyError(f"Found no values matching '{key}' in {tree}") + raise KeyError(f"Found no values matching '{key}' in {tree}") has_any_key = functools.partial(_node_has_keys, keys=tuple(kwargs.keys())) @@ -487,35 +484,30 @@ def _replace(path: _KeyPath, node: Any) -> Any: ): # The node itself is a named tuple we wanted to replace return kwargs[node.__class__.__name__] - else: - # The node contains one of the keys we want to replace - children_with_path = _get_children_with_path(path, node) - new_children_with_keys = {} - for child_path, child in children_with_path: - # Scan each child of that node - key = _get_key(child_path[-1]) - if key in kwargs and ( - filtering is None or filtering(child_path, child) - ): - # If the child matches a given key given the filtering operation - # replaces with the new value - new_children_with_keys.update({key: kwargs[key]}) + # The node contains one of the keys we want to replace + children_with_path = _get_children_with_path(path, node) + new_children_with_keys = {} + for child_path, child in children_with_path: + # Scan each child of that node + key = _get_key(child_path[-1]) + if key in kwargs and ( + filtering is None or filtering(child_path, child) + ): + # If the child matches a given key given the filtering operation + # replaces with the new value + new_children_with_keys.update({key: kwargs[key]}) + else: + if isinstance(child, (dict, list, tuple)): + # If the child is itself a pytree, further search in the child to + # replace the given value + new_children_with_keys.update({key: _replace(child_path, child)}) else: - if ( - isinstance(child, tuple) - or isinstance(child, dict) - or isinstance(child, list) - ): - # If the child is itself a pytree, further search in the child to - # replace the given value - new_children_with_keys.update({key: _replace(child_path, child)}) - else: - # If the child is just a leaf that does not contain the key or - # satisfies the filtering operation, just return the child. - new_children_with_keys.update({key: child}) - return _set_children(node, new_children_with_keys) - else: - return node + # If the child is just a leaf that does not contain the key or + # satisfies the filtering operation, just return the child. + new_children_with_keys.update({key: child}) + return _set_children(node, new_children_with_keys) + + return node # Mimics jax.tree_util.tree_map_with_path(_replace, tree, is_leaf) # except that the paths we consider can contain NamedTupleKeys @@ -667,12 +659,11 @@ def _node_has_keys(node: Any, keys: tuple[Any, ...]) -> bool: """ if _is_named_tuple(node) and any(key in node._fields for key in keys): return True - elif _is_named_tuple(node) and (node.__class__.__name__ in keys): + if _is_named_tuple(node) and (node.__class__.__name__ in keys): return True - elif isinstance(node, dict) and any(key in node for key in keys): + if isinstance(node, dict) and any(key in node for key in keys): return True - else: - return False + return False def _flatten_to_key( @@ -696,13 +687,11 @@ def _flatten_to_key( if _is_named_tuple(node): if key == node.__class__.__name__: return (path, node) - else: - path_to_key = (*path, NamedTupleKey(node.__class__.__name__, key)) - return (path_to_key, getattr(node, key)) - elif isinstance(node, dict) and key in node: + path_to_key = (*path, NamedTupleKey(node.__class__.__name__, key)) + return (path_to_key, getattr(node, key)) + if isinstance(node, dict) and key in node: return (*path, jax.tree_util.DictKey(key)), node[key] - else: - return path, node + return path, node def _get_children_with_path( @@ -732,15 +721,14 @@ def _get_children_with_path( ) for field in node._fields ] - elif isinstance(node, dict): + if isinstance(node, dict): return [ ((*path, jax.tree_util.DictKey(key)), value) for key, value in node.items() ] - else: - raise ValueError( - f"Subtree must be a dict or a NamedTuple. Got {type(node)}" - ) + raise ValueError( + f"Subtree must be a dict or a NamedTuple. Got {type(node)}" + ) def _set_children(node: Any, children_with_keys: dict[Any, Any]) -> Any: @@ -762,12 +750,11 @@ def _set_children(node: Any, children_with_keys: dict[Any, Any]) -> Any: """ if _is_named_tuple(node): return node._replace(**children_with_keys) - elif isinstance(node, dict): + if isinstance(node, dict): return children_with_keys - else: - raise ValueError( - f"Subtree must be a dict or a NamedTuple. Got {type(node)}" - ) + raise ValueError( + f"Subtree must be a dict or a NamedTuple. Got {type(node)}" + ) def _get_key(key: _KeyEntry) -> Union[int, str]: diff --git a/optax/tree_utils/_state_utils_test.py b/optax/tree_utils/_state_utils_test.py index 04e60f122..e80475fe0 100644 --- a/optax/tree_utils/_state_utils_test.py +++ b/optax/tree_utils/_state_utils_test.py @@ -307,16 +307,17 @@ def test_tree_get_all_with_path(self): self.assertEqual(found_values, expected_result) with self.subTest('Test with optional filtering'): - state = dict(hparams=dict(learning_rate=1.0), learning_rate='foo') + state = {"hparams": {"learning_rate": 1.0}, "learning_rate": 'foo'} # Without filtering two values are found found_values = _state_utils.tree_get_all_with_path(state, 'learning_rate') self.assertLen(found_values, 2) # With filtering only the float entry is returned - filtering = lambda _, value: isinstance(value, float) found_values = _state_utils.tree_get_all_with_path( - state, 'learning_rate', filtering=filtering + state, + 'learning_rate', + filtering=lambda _, value: isinstance(value, float), ) self.assertLen(found_values, 1) expected_result = [( @@ -327,10 +328,11 @@ def test_tree_get_all_with_path(self): with self.subTest('Test to get a subtree (here hyperparams_states)'): opt = _inject.inject_hyperparams(alias.sgd)(learning_rate=lambda x: x) - filtering = lambda _, value: isinstance(value, tuple) state = opt.init(params) found_values = _state_utils.tree_get_all_with_path( - state, 'learning_rate', filtering=filtering + state, + 'learning_rate', + filtering=lambda _, value: isinstance(value, tuple), ) expected_result = [( ( @@ -346,7 +348,7 @@ def test_tree_get_all_with_path(self): self.assertEqual(found_values, expected_result) with self.subTest('Test with nested tree containing a key'): - tree = dict(a=dict(a=1.0)) + tree = {"a": {"a": 1.0}} found_values = _state_utils.tree_get_all_with_path(tree, 'a') expected_result = [ ((jtu.DictKey(key='a'),), {'a': 1.0}), @@ -384,12 +386,13 @@ def test_tree_get(self): learning_rate=lambda x: 1 / (x + 1) ) state = opt.init(params) - filtering = lambda _, value: isinstance(value, jnp.ndarray) @jax.jit def get_learning_rate(state): return _state_utils.tree_get( - state, 'learning_rate', filtering=filtering + state, + 'learning_rate', + filtering=lambda _, value: isinstance(value, jnp.ndarray) ) for i in range(4): @@ -399,14 +402,17 @@ def get_learning_rate(state): self.assertEqual(lr, 1 / (i + 1)) with self.subTest('Test with optional filtering'): - state = dict(hparams=dict(learning_rate=1.0), learning_rate='foo') + state = {"hparams": {"learning_rate": 1.0}, "learning_rate": 'foo'} # Without filtering raises an error self.assertRaises(KeyError, _state_utils.tree_get, state, 'learning_rate') # With filtering, fetches the float entry - filtering = lambda path, value: isinstance(value, float) - lr = _state_utils.tree_get(state, 'learning_rate', filtering=filtering) + lr = _state_utils.tree_get( + state, + 'learning_rate', + filtering=lambda _, value: isinstance(value, float), + ) self.assertEqual(lr, 1.0) with self.subTest('Test filtering for specific state'): @@ -415,10 +421,9 @@ def get_learning_rate(state): ) state = opt.init(params) - filtering = ( - lambda path, _: isinstance(path[-1], _state_utils.NamedTupleKey) + def filtering(path, _): + return isinstance(path[-1], _state_utils.NamedTupleKey) \ and path[-1].tuple_name == 'ScaleByAdamState' - ) count = _state_utils.tree_get(state, 'count', filtering=filtering) self.assertEqual(count, jnp.asarray(0, dtype=jnp.dtype('int32'))) @@ -490,9 +495,12 @@ def set_learning_rate(state, lr): self.assertEqual(value, 2.0) with self.subTest('Test with optional filtering'): - state = dict(hparams=dict(learning_rate=1.0), learning_rate='foo') - filtering = lambda _, value: isinstance(value, float) - new_state = _state_utils.tree_set(state, filtering, learning_rate=0.5) + state = {"hparams": {"learning_rate": 1.0}, "learning_rate": 'foo'} + new_state = _state_utils.tree_set( + state, + lambda _, value: isinstance(value, float), + learning_rate=0.5, + ) found_values = _state_utils.tree_get_all_with_path( new_state, 'learning_rate' ) @@ -503,17 +511,23 @@ def set_learning_rate(state, lr): self.assertEqual(found_values, expected_result) with self.subTest('Test with nested trees and filtering'): - tree = dict(a=dict(a=1.0), b=dict(a=1)) - filtering = lambda _, value: isinstance(value, float) - new_tree = _state_utils.tree_set(tree, filtering, a=2.0) - expected_result = dict(a=dict(a=2.0), b=dict(a=1)) + tree = {"a": {"a": 1.0}, "b": {"a": 1}} + new_tree = _state_utils.tree_set( + tree, + lambda _, value: isinstance(value, float), + a=2.0, + ) + expected_result = {"a": {"a": 2.0}, "b": {"a": 1}} self.assertEqual(new_tree, expected_result) with self.subTest('Test setting a subtree'): - tree = dict(a=dict(a=1.0), b=dict(a=1)) - filtering = lambda _, value: isinstance(value, dict) - new_tree = _state_utils.tree_set(tree, filtering, a=dict(c=0.0)) - expected_result = dict(a=dict(c=0.0), b=dict(a=1)) + tree = {"a": {"a": 1.0}, "b": {"a": 1}} + new_tree = _state_utils.tree_set( + tree, + lambda _, value: isinstance(value, dict), + a={"c": 0.0}, + ) + expected_result = {"a": {"c": 0.0}, "b": {"a": 1}} self.assertEqual(new_tree, expected_result) with self.subTest('Test setting a specific state'): @@ -522,10 +536,9 @@ def set_learning_rate(state, lr): ) state = opt.init(params) - filtering = ( - lambda path, _: isinstance(path[-1], _state_utils.NamedTupleKey) + def filtering(path, _): + return isinstance(path[-1], _state_utils.NamedTupleKey) \ and path[-1].tuple_name == 'ScaleByAdamState' - ) new_state = _state_utils.tree_set(state, filtering, count=jnp.array(42)) expected_result = ( diff --git a/optax/tree_utils/_tree_math.py b/optax/tree_utils/_tree_math.py index 37e0ea27c..bdcd97f82 100644 --- a/optax/tree_utils/_tree_math.py +++ b/optax/tree_utils/_tree_math.py @@ -201,8 +201,7 @@ def tree_l2_norm(tree: Any, squared: bool = False) -> chex.Numeric: sqnorm = tree_sum(squared_tree) if squared: return sqnorm - else: - return jnp.sqrt(sqnorm) + return jnp.sqrt(sqnorm) def tree_l1_norm(tree: Any) -> chex.Numeric: @@ -330,13 +329,13 @@ def tree_update_moment_per_elem_norm(updates, moments, decay, order): def orderth_norm(g): if jnp.isrealobj(g): - return g**order - else: - half_order = order / 2 - # JAX generates different HLO for int and float `order` - if half_order.is_integer(): - half_order = int(half_order) - return numerics.abs_sq(g) ** half_order + return g ** order + + half_order = order / 2 + # JAX generates different HLO for int and float `order` + if half_order.is_integer(): + half_order = int(half_order) + return numerics.abs_sq(g) ** half_order return jax.tree.map( lambda g, t: ( diff --git a/optax/tree_utils/_tree_math_test.py b/optax/tree_utils/_tree_math_test.py index c1908e620..b6082ea5c 100644 --- a/optax/tree_utils/_tree_math_test.py +++ b/optax/tree_utils/_tree_math_test.py @@ -44,14 +44,14 @@ def setUp(self): self.tree_a_dict_jax = jax.tree.map(jnp.array, self.tree_a_dict) self.tree_b_dict_jax = jax.tree.map(jnp.array, self.tree_b_dict) - self.data = dict( - tree_a=self.tree_a, - tree_b=self.tree_b, - tree_a_dict=self.tree_a_dict, - tree_b_dict=self.tree_b_dict, - array_a=self.array_a, - array_b=self.array_b, - ) + self.data = { + "tree_a": self.tree_a, + "tree_b": self.tree_b, + "tree_a_dict": self.tree_a_dict, + "tree_b_dict": self.tree_b_dict, + "array_a": self.array_a, + "array_b": self.array_b, + } def test_tree_add(self): expected = self.array_a + self.array_b @@ -268,24 +268,24 @@ def test_empty_tree_reduce(self): self.assertEqual(tu.tree_vdot(tree, tree), 0) @parameterized.named_parameters( - dict( - testcase_name='tree_add_scalar_mul', - operation=lambda m: tu.tree_add_scalar_mul(None, 1, m), - ), - dict( - testcase_name='tree_update_moment', - operation=lambda m: tu.tree_update_moment(None, m, 1, 1), - ), - dict( - testcase_name='tree_update_infinity_moment', - operation=lambda m: tu.tree_update_infinity_moment(None, m, 1, 1), - ), - dict( - testcase_name='tree_update_moment_per_elem_norm', - operation=lambda m: tu.tree_update_moment_per_elem_norm( + { + "testcase_name": 'tree_add_scalar_mul', + "operation": lambda m: tu.tree_add_scalar_mul(None, 1, m), + }, + { + "testcase_name": 'tree_update_moment', + "operation": lambda m: tu.tree_update_moment(None, m, 1, 1), + }, + { + "testcase_name": 'tree_update_infinity_moment', + "operation": lambda m: tu.tree_update_infinity_moment(None, m, 1, 1), + }, + { + "testcase_name": 'tree_update_moment_per_elem_norm', + "operation": lambda m: tu.tree_update_moment_per_elem_norm( None, m, 1, 1 ), - ), + }, ) def test_none_arguments(self, operation): m = jnp.array([1.0, 2.0, 3.0]) diff --git a/pyproject.toml b/pyproject.toml index 40ddd826b..4943090bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -116,39 +116,23 @@ ignore = [ [tool.pylint.messages_control] disable = [ "bad-indentation", - "invalid-name", "missing-class-docstring", "missing-function-docstring", - "unnecessary-lambda-assignment", "no-member", - "use-dict-literal", "too-many-locals", "too-many-arguments", "too-many-positional-arguments", - "unused-argument", - "unused-import", - "line-too-long", "no-value-for-parameter", - "wrong-import-order", - "no-else-return", "too-many-lines", "missing-module-docstring", "too-few-public-methods", - "abstract-method", - "used-before-assignment", + "too-many-public-methods", "too-many-statements", "protected-access", "too-many-instance-attributes", - "inconsistent-return-statements", - "trailing-newlines", - "consider-using-generator", - "no-else-raise", - "unknown-option-value", "not-callable", - "consider-merging-isinstance", "duplicate-code", - "import-error", - "overridden-final-method", + "unknown-option-value", ] [tool.pylint.similarities] From 70a924168a286fd572f0342a3e4dbf9a4337d9ea Mon Sep 17 00:00:00 2001 From: carlosgmartin Date: Wed, 20 Nov 2024 21:15:20 -0500 Subject: [PATCH 3/3] Restore original test.sh and .pylintrc. --- .flake8 | 14 - .pylintrc | 400 ++++++++++++++++++ optax/_src/alias_test.py | 128 +++--- optax/_src/float64_test.py | 34 +- optax/_src/linear_algebra_test.py | 8 +- optax/_src/linesearch_test.py | 44 +- optax/_src/update_test.py | 6 +- optax/_src/utils_test.py | 2 +- optax/contrib/_common_test.py | 104 ++--- optax/contrib/_sam_test.py | 8 +- optax/losses/_classification_test.py | 152 +++---- optax/losses/_regression_test.py | 10 +- optax/monte_carlo/control_variates_test.py | 90 ++-- .../stochastic_gradient_estimators_test.py | 4 +- optax/perturbations/_make_pert_test.py | 4 +- optax/projections/_projections_test.py | 12 +- optax/schedules/_inject_test.py | 4 +- optax/schedules/_schedule_test.py | 10 +- optax/transforms/_accumulation.py | 6 +- optax/transforms/_accumulation_test.py | 26 +- optax/tree_utils/_casting_test.py | 8 +- optax/tree_utils/_random_test.py | 10 +- optax/tree_utils/_state_utils_test.py | 18 +- optax/tree_utils/_tree_math_test.py | 28 +- test.sh | 82 ++-- 25 files changed, 803 insertions(+), 409 deletions(-) delete mode 100644 .flake8 create mode 100644 .pylintrc diff --git a/.flake8 b/.flake8 deleted file mode 100644 index 7e9b93fd0..000000000 --- a/.flake8 +++ /dev/null @@ -1,14 +0,0 @@ -[flake8] -select = - E9, - F63, - F7, - F82, - E225, - E251, -show-source = true -statistics = true -exclude = - build, - dist, - test_venv diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 000000000..b26aeee4f --- /dev/null +++ b/.pylintrc @@ -0,0 +1,400 @@ +# This Pylint rcfile contains a best-effort configuration to uphold the +# best-practices and style described in the Google Python style guide: +# https://google.github.io/styleguide/pyguide.html +# +# Its canonical open-source location is: +# https://google.github.io/styleguide/pylintrc + +[MAIN] + +# Files or directories to be skipped. They should be base names, not paths. +ignore=third_party + +# Files or directories matching the regex patterns are skipped. The regex +# matches against base names, not paths. +ignore-patterns= + +# Pickle collected data for later comparisons. +persistent=no + +# List of plugins (as comma separated values of python modules names) to load, +# usually to register additional checkers. +load-plugins= + +# Use multiple processes to speed up Pylint. +jobs=4 + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED +confidence= + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +#enable= + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once).You can also use "--disable=all" to +# disable everything first and then reenable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use"--disable=all --enable=classes +# --disable=W" +disable=R, + abstract-method, + apply-builtin, + arguments-differ, + attribute-defined-outside-init, + backtick, + bad-option-value, + basestring-builtin, + buffer-builtin, + c-extension-no-member, + consider-using-enumerate, + cmp-builtin, + cmp-method, + coerce-builtin, + coerce-method, + delslice-method, + div-method, + eq-without-hash, + execfile-builtin, + file-builtin, + filter-builtin-not-iterating, + fixme, + getslice-method, + global-statement, + hex-method, + idiv-method, + implicit-str-concat, + import-error, + import-self, + import-star-module-level, + input-builtin, + intern-builtin, + invalid-str-codec, + locally-disabled, + long-builtin, + long-suffix, + map-builtin-not-iterating, + misplaced-comparison-constant, + missing-function-docstring, + metaclass-assignment, + next-method-called, + next-method-defined, + no-absolute-import, + no-init, # added + no-member, + no-name-in-module, + no-self-use, + nonzero-method, + oct-method, + old-division, + old-ne-operator, + old-octal-literal, + old-raise-syntax, + parameter-unpacking, + print-statement, + raising-string, + range-builtin-not-iterating, + raw_input-builtin, + rdiv-method, + reduce-builtin, + relative-import, + reload-builtin, + round-builtin, + setslice-method, + signature-differs, + standarderror-builtin, + suppressed-message, + sys-max-int, + trailing-newlines, + unichr-builtin, + unicode-builtin, + unnecessary-pass, + unpacking-in-except, + useless-else-on-loop, + useless-suppression, + using-cmp-argument, + wrong-import-order, + xrange-builtin, + zip-builtin-not-iterating, + + +[REPORTS] + +# Set the output format. Available formats are text, parseable, colorized, msvs +# (visual studio) and html. You can also give a reporter class, eg +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Tells whether to display a full report or only the messages +reports=no + +# Python expression which should return a note less than 10 (10 is the highest +# note). You have access to the variables errors warning, statement which +# respectively contain the number of errors / warnings messages and the total +# number of statements analyzed. This is used by the global evaluation report +# (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details +#msg-template= + + +[BASIC] + +# Good variable names which should always be accepted, separated by a comma +good-names=main,_ + +# Bad variable names which should always be refused, separated by a comma +bad-names= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Include a hint for the correct naming format with invalid-name +include-naming-hint=no + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl + +# Regular expression matching correct function names +function-rgx=^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ + +# Regular expression matching correct variable names +variable-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct constant names +const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression matching correct attribute names +attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ + +# Regular expression matching correct argument names +argument-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct class attribute names +class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression matching correct inline iteration names +inlinevar-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct class names +class-rgx=^_?[A-Z][a-zA-Z0-9]*$ + +# Regular expression matching correct module names +module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$ + +# Regular expression matching correct method names +method-rgx=(?x)^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$ + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=12 + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis. It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + + +[FORMAT] + +# Maximum number of characters on a single line. +max-line-length=80 + +# TODO(https://github.com/pylint-dev/pylint/issues/3352): Direct pylint to exempt +# lines made too long by directives to pytype. + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=(?x)( + ^\s*(\#\ )??$| + ^\s*(from\s+\S+\s+)?import\s+.+$) + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=yes + +# Maximum number of lines in a module +max-module-lines=99999 + +# String used as indentation unit. The internal Google style guide mandates 2 +# spaces. Google's externaly-published style guide says 4, consistent with +# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google +# projects (like TensorFlow). +indent-string=' ' + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=TODO + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=yes + + +[VARIABLES] + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# A regular expression matching the name of dummy variables (i.e. expectedly +# not used). +dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid to define new builtins when possible. +additional-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_,_cb + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools + + +[LOGGING] + +# Logging modules to check that the string format arguments are in logging +# function parameter format +logging-modules=logging,absl.logging,tensorflow.io.logging + + +[SIMILARITIES] + +# Minimum lines number of a similarity. +min-similarity-lines=4 + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + + +[SPELLING] + +# Spelling dictionary name. Available dictionaries: none. To make it working +# install python-enchant package. +spelling-dict= + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to indicated private dictionary in +# --spelling-private-dict-file option instead of raising a message. +spelling-store-unknown-words=no + + +[IMPORTS] + +# Deprecated modules which should not be used, separated by a comma +deprecated-modules=regsub, + TERMIOS, + Bastion, + rexec, + sets + +# Create a graph of every (i.e. internal and external) dependencies in the +# given file (report RP0402 must not be disabled) +import-graph= + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled) +ext-import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled) +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant, absl + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + + +[CLASSES] + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls, + class_ + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs +disable=unnecessary-lambda-assignment,no-value-for-parameter,use-dict-literal diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py index 22fedf59d..265d40073 100644 --- a/optax/_src/alias_test.py +++ b/optax/_src/alias_test.py @@ -47,50 +47,50 @@ _OPTIMIZERS_UNDER_TEST = ( - {"opt_name": 'sgd', "opt_kwargs": {"learning_rate": 1e-3, "momentum": 0.9}}, - {"opt_name": 'adadelta', "opt_kwargs": {"learning_rate": 0.1}}, - {"opt_name": 'adafactor', "opt_kwargs": {"learning_rate": 5e-3}}, - {"opt_name": 'adagrad', "opt_kwargs": {"learning_rate": 1.0}}, - {"opt_name": 'adam', "opt_kwargs": {"learning_rate": 1e-1}}, - {"opt_name": 'adamw', "opt_kwargs": {"learning_rate": 1e-1}}, - {"opt_name": 'adamax', "opt_kwargs": {"learning_rate": 1e-1}}, - {"opt_name": 'adamaxw', "opt_kwargs": {"learning_rate": 1e-1}}, - {"opt_name": 'adan', "opt_kwargs": {"learning_rate": 1e-1}}, - {"opt_name": 'amsgrad', "opt_kwargs": {"learning_rate": 1e-1}}, - {"opt_name": 'lars', "opt_kwargs": {"learning_rate": 1.0}}, - {"opt_name": 'lamb', "opt_kwargs": {"learning_rate": 1e-3}}, + {'opt_name': 'sgd', 'opt_kwargs': {'learning_rate': 1e-3, 'momentum': 0.9}}, + {'opt_name': 'adadelta', 'opt_kwargs': {'learning_rate': 0.1}}, + {'opt_name': 'adafactor', 'opt_kwargs': {'learning_rate': 5e-3}}, + {'opt_name': 'adagrad', 'opt_kwargs': {'learning_rate': 1.0}}, + {'opt_name': 'adam', 'opt_kwargs': {'learning_rate': 1e-1}}, + {'opt_name': 'adamw', 'opt_kwargs': {'learning_rate': 1e-1}}, + {'opt_name': 'adamax', 'opt_kwargs': {'learning_rate': 1e-1}}, + {'opt_name': 'adamaxw', 'opt_kwargs': {'learning_rate': 1e-1}}, + {'opt_name': 'adan', 'opt_kwargs': {'learning_rate': 1e-1}}, + {'opt_name': 'amsgrad', 'opt_kwargs': {'learning_rate': 1e-1}}, + {'opt_name': 'lars', 'opt_kwargs': {'learning_rate': 1.0}}, + {'opt_name': 'lamb', 'opt_kwargs': {'learning_rate': 1e-3}}, { - "opt_name": 'lion', - "opt_kwargs": {"learning_rate": 1e-2, "weight_decay": 1e-4}, + 'opt_name': 'lion', + 'opt_kwargs': {'learning_rate': 1e-2, 'weight_decay': 1e-4}, }, - {"opt_name": 'nadam', "opt_kwargs": {"learning_rate": 1e-2}}, - {"opt_name": 'nadamw', "opt_kwargs": {"learning_rate": 1e-2}}, + {'opt_name': 'nadam', 'opt_kwargs': {'learning_rate': 1e-2}}, + {'opt_name': 'nadamw', 'opt_kwargs': {'learning_rate': 1e-2}}, { - "opt_name": 'noisy_sgd', - "opt_kwargs": {"learning_rate": 1e-3, "eta": 1e-4}, + 'opt_name': 'noisy_sgd', + 'opt_kwargs': {'learning_rate': 1e-3, 'eta': 1e-4}, }, - {"opt_name": 'novograd', "opt_kwargs": {"learning_rate": 1e-3}}, + {'opt_name': 'novograd', 'opt_kwargs': {'learning_rate': 1e-3}}, { - "opt_name": 'optimistic_gradient_descent', - "opt_kwargs": {"learning_rate": 2e-3, "alpha": 0.7, "beta": 0.1}, + 'opt_name': 'optimistic_gradient_descent', + 'opt_kwargs': {'learning_rate': 2e-3, 'alpha': 0.7, 'beta': 0.1}, }, { - "opt_name": 'optimistic_adam', - "opt_kwargs": {"learning_rate": 2e-3}, + 'opt_name': 'optimistic_adam', + 'opt_kwargs': {'learning_rate': 2e-3}, }, - {"opt_name": 'rmsprop', "opt_kwargs": {"learning_rate": 5e-3}}, + {'opt_name': 'rmsprop', 'opt_kwargs': {'learning_rate': 5e-3}}, { - "opt_name": 'rmsprop', - "opt_kwargs": {"learning_rate": 5e-3, "momentum": 0.9}, + 'opt_name': 'rmsprop', + 'opt_kwargs': {'learning_rate': 5e-3, 'momentum': 0.9}, }, - {"opt_name": 'sign_sgd', "opt_kwargs": {"learning_rate": 1e-1}}, - {"opt_name": 'fromage', "opt_kwargs": {"learning_rate": 5e-3}}, - {"opt_name": 'adabelief', "opt_kwargs": {"learning_rate": 1e-2}}, - {"opt_name": 'radam', "opt_kwargs": {"learning_rate": 5e-3}}, - {"opt_name": 'rprop', "opt_kwargs": {"learning_rate": 1e-1}}, - {"opt_name": 'sm3', "opt_kwargs": {"learning_rate": 1.0}}, - {"opt_name": 'yogi', "opt_kwargs": {"learning_rate": 1e-1}}, - {"opt_name": 'polyak_sgd', "opt_kwargs": {"max_learning_rate": 1.0}}, + {'opt_name': 'sign_sgd', 'opt_kwargs': {'learning_rate': 1e-1}}, + {'opt_name': 'fromage', 'opt_kwargs': {'learning_rate': 5e-3}}, + {'opt_name': 'adabelief', 'opt_kwargs': {'learning_rate': 1e-2}}, + {'opt_name': 'radam', 'opt_kwargs': {'learning_rate': 5e-3}}, + {'opt_name': 'rprop', 'opt_kwargs': {'learning_rate': 1e-1}}, + {'opt_name': 'sm3', 'opt_kwargs': {'learning_rate': 1.0}}, + {'opt_name': 'yogi', 'opt_kwargs': {'learning_rate': 1e-1}}, + {'opt_name': 'polyak_sgd', 'opt_kwargs': {'max_learning_rate': 1.0}}, ) @@ -529,41 +529,41 @@ def zakharov(x, xnp): return answer problems = { - "rosenbrock": { - "fun": lambda x: rosenbrock(x, jnp), - "numpy_fun": lambda x: rosenbrock(x, np), - "init": np.zeros(2), - "minimum": 0.0, - "minimizer": np.ones(2), + 'rosenbrock': { + 'fun': lambda x: rosenbrock(x, jnp), + 'numpy_fun': lambda x: rosenbrock(x, np), + 'init': np.zeros(2), + 'minimum': 0.0, + 'minimizer': np.ones(2), }, - "himmelblau": { - "fun": himmelblau, - "numpy_fun": himmelblau, - "init": np.ones(2), - "minimum": 0.0, + 'himmelblau': { + 'fun': himmelblau, + 'numpy_fun': himmelblau, + 'init': np.ones(2), + 'minimum': 0.0, # himmelblau has actually multiple minimizers, we simply consider one. - "minimizer": np.array([3.0, 2.0]), + 'minimizer': np.array([3.0, 2.0]), }, - "matyas": { - "fun": matyas, - "numpy_fun": matyas, - "init": np.ones(2) * 6.0, - "minimum": 0.0, - "minimizer": np.zeros(2), + 'matyas': { + 'fun': matyas, + 'numpy_fun': matyas, + 'init': np.ones(2) * 6.0, + 'minimum': 0.0, + 'minimizer': np.zeros(2), }, - "eggholder": { - "fun": lambda x: eggholder(x, jnp), - "numpy_fun": lambda x: eggholder(x, np), - "init": np.ones(2) * 6.0, - "minimum": -959.6407, - "minimizer": np.array([512.0, 404.22319]), + 'eggholder': { + 'fun': lambda x: eggholder(x, jnp), + 'numpy_fun': lambda x: eggholder(x, np), + 'init': np.ones(2) * 6.0, + 'minimum': -959.6407, + 'minimizer': np.array([512.0, 404.22319]), }, - "zakharov": { - "fun": lambda x: zakharov(x, jnp), - "numpy_fun": lambda x: zakharov(x, np), - "init": np.array([600.0, 700.0, 200.0, 100.0, 90.0, 1e3]), - "minimum": 0.0, - "minimizer": np.zeros(6), + 'zakharov': { + 'fun': lambda x: zakharov(x, jnp), + 'numpy_fun': lambda x: zakharov(x, np), + 'init': np.array([600.0, 700.0, 200.0, 100.0, 90.0, 1e3]), + 'minimum': 0.0, + 'minimizer': np.zeros(6), }, } return problems[name] diff --git a/optax/_src/float64_test.py b/optax/_src/float64_test.py index 3641014b4..1eadc1c09 100644 --- a/optax/_src/float64_test.py +++ b/optax/_src/float64_test.py @@ -28,37 +28,37 @@ ALL_MODULES = [ ('identity', base.identity, {}), - ('clip', clipping.clip, {"max_delta": 1.0}), - ('clip_by_global_norm', clipping.clip_by_global_norm, {"max_norm": 1.0}), - ('trace', transform.trace, {"decay": 0.5, "nesterov": False}), - ('trace_with_nesterov', transform.trace, {"decay": 0.5, "nesterov": True}), + ('clip', clipping.clip, {'max_delta': 1.0}), + ('clip_by_global_norm', clipping.clip_by_global_norm, {'max_norm': 1.0}), + ('trace', transform.trace, {'decay': 0.5, 'nesterov': False}), + ('trace_with_nesterov', transform.trace, {'decay': 0.5, 'nesterov': True}), ('scale_by_rss', transform.scale_by_rss, {}), ('scale_by_rms', transform.scale_by_rms, {}), ('scale_by_stddev', transform.scale_by_stddev, {}), ('adam', transform.scale_by_adam, {}), - ('scale', transform.scale, {"step_size": 3.0}), + ('scale', transform.scale, {'step_size': 3.0}), ( 'add_decayed_weights', transform.add_decayed_weights, - {"weight_decay": 0.1}, + {'weight_decay': 0.1}, ), ( 'scale_by_schedule', transform.scale_by_schedule, - {"step_size_fn": lambda x: x * 0.1}, + {'step_size_fn': lambda x: x * 0.1}, ), ('scale_by_trust_ratio', transform.scale_by_trust_ratio, {}), - ('add_noise', transform.add_noise, {"eta": 1.0, "gamma": 0.1, "seed": 42}), + ('add_noise', transform.add_noise, {'eta': 1.0, 'gamma': 0.1, 'seed': 42}), ('apply_every_k', transform.apply_every, {}), - ('adagrad', alias.adagrad, {"learning_rate": 0.1}), - ('adam', alias.adam, {"learning_rate": 0.1}), - ('adamw', alias.adamw, {"learning_rate": 0.1}), - ('fromage', alias.fromage, {"learning_rate": 0.1}), - ('lamb', alias.lamb, {"learning_rate": 0.1}), - ('noisy_sgd', alias.noisy_sgd, {"learning_rate": 0.1}), - ('rmsprop', alias.rmsprop, {"learning_rate": 0.1}), - ('sgd', alias.sgd, {"learning_rate": 0.1}), - ('sign_sgd', alias.sgd, {"learning_rate": 0.1}), + ('adagrad', alias.adagrad, {'learning_rate': 0.1}), + ('adam', alias.adam, {'learning_rate': 0.1}), + ('adamw', alias.adamw, {'learning_rate': 0.1}), + ('fromage', alias.fromage, {'learning_rate': 0.1}), + ('lamb', alias.lamb, {'learning_rate': 0.1}), + ('noisy_sgd', alias.noisy_sgd, {'learning_rate': 0.1}), + ('rmsprop', alias.rmsprop, {'learning_rate': 0.1}), + ('sgd', alias.sgd, {'learning_rate': 0.1}), + ('sign_sgd', alias.sgd, {'learning_rate': 0.1}), ] diff --git a/optax/_src/linear_algebra_test.py b/optax/_src/linear_algebra_test.py index d9ee9cd97..bcb0821ce 100644 --- a/optax/_src/linear_algebra_test.py +++ b/optax/_src/linear_algebra_test.py @@ -48,8 +48,8 @@ class LinearAlgebraTest(chex.TestCase): def test_global_norm(self): flat_updates = jnp.array([2.0, 4.0, 3.0, 5.0], dtype=jnp.float32) nested_updates = { - "a": jnp.array([2.0, 4.0], dtype=jnp.float32), - "b": jnp.array([3.0, 5.0], dtype=jnp.float32), + 'a': jnp.array([2.0, 4.0], dtype=jnp.float32), + 'b': jnp.array([3.0, 5.0], dtype=jnp.float32), } np.testing.assert_array_equal( jnp.sqrt(jnp.sum(flat_updates**2)), @@ -79,8 +79,8 @@ def test_power_iteration_cond_fun(self, dim=6): @chex.all_variants @parameterized.parameters( - {"implicit": True}, - {"implicit": False}, + {'implicit': True}, + {'implicit': False}, ) def test_power_iteration(self, implicit, dim=6, tol=1e-3, num_iters=100): """Test power_iteration by comparing to numpy.linalg.eigh.""" diff --git a/optax/_src/linesearch_test.py b/optax/_src/linesearch_test.py index e315bc9fc..5f617466e 100644 --- a/optax/_src/linesearch_test.py +++ b/optax/_src/linesearch_test.py @@ -72,14 +72,14 @@ def zakharov(x): return sum1 + sum2**2 + sum2**4 problems = { - "polynomial": {"fn": polynomial, "input_shape": ()}, - "exponential": {"fn": exponential, "input_shape": ()}, - "sinusoidal": {"fn": sinusoidal, "input_shape": ()}, - "rosenbrock": {"fn": rosenbrock, "input_shape": (16,)}, - "himmelblau": {"fn": himmelblau, "input_shape": (2,)}, - "matyas": {"fn": matyas, "input_shape": (2,)}, - "eggholder": {"fn": eggholder, "input_shape": (2,)}, - "zakharov": {"fn": zakharov, "input_shape": (6,)}, + 'polynomial': {'fn': polynomial, 'input_shape': ()}, + 'exponential': {'fn': exponential, 'input_shape': ()}, + 'sinusoidal': {'fn': sinusoidal, 'input_shape': ()}, + 'rosenbrock': {'fn': rosenbrock, 'input_shape': (16,)}, + 'himmelblau': {'fn': himmelblau, 'input_shape': (2,)}, + 'matyas': {'fn': matyas, 'input_shape': (2,)}, + 'eggholder': {'fn': eggholder, 'input_shape': (2,)}, + 'zakharov': {'fn': zakharov, 'input_shape': (6,)}, } return problems[name] @@ -159,11 +159,11 @@ def test_linesearch( descent_dir = -jax.grad(fn)(init_params) opt_args = { - "max_backtracking_steps": 50, - "slope_rtol": slope_rtol, - "increase_factor": increase_factor, - "atol": atol, - "rtol": rtol, + 'max_backtracking_steps': 50, + 'slope_rtol': slope_rtol, + 'increase_factor': increase_factor, + 'atol': atol, + 'rtol': rtol, } solver = combine.chain( @@ -376,9 +376,9 @@ def _check_linesearch_conditions( slope_init = otu.tree_vdot(updates, grad_init) slope_final = otu.tree_vdot(updates, grad_final) default_opt_args = { - "slope_rtol": 1e-4, - "curv_rtol": 0.9, - "tol": 0.0, + 'slope_rtol': 1e-4, + 'curv_rtol': 0.9, + 'tol': 0.0, } opt_args = default_opt_args | opt_args slope_rtol, curv_rtol, tol = ( @@ -462,11 +462,11 @@ def test_linesearch(self, problem_name: str, seed: int): init_updates = -precond_vec * jax.grad(fn)(init_params) opt_args = { - "max_linesearch_steps": 30, - "slope_rtol": slope_rtol, - "curv_rtol": curv_rtol, - "tol": tol, - "max_learning_rate": None, + 'max_linesearch_steps': 30, + 'slope_rtol': slope_rtol, + 'curv_rtol': curv_rtol, + 'tol': tol, + 'max_learning_rate': None, } opt = _linesearch.scale_by_zoom_linesearch(**opt_args) @@ -606,7 +606,7 @@ def fn(x): ) s = otu.tree_get(final_state, 'learning_rate') self._check_linesearch_conditions( - fn, w, u, final_params, final_state, {"curv_rtol": curv_rtol} + fn, w, u, final_params, final_state, {'curv_rtol': curv_rtol} ) self.assertGreaterEqual(s, 30.0) diff --git a/optax/_src/update_test.py b/optax/_src/update_test.py index d9a2f270c..bc76f6648 100644 --- a/optax/_src/update_test.py +++ b/optax/_src/update_test.py @@ -77,10 +77,10 @@ def test_periodic_update(self): chex.assert_trees_all_close(params_2, new_params, atol=1e-10, rtol=1e-5) @parameterized.named_parameters( - {"testcase_name": 'apply_updates', "operation": update.apply_updates}, + {'testcase_name': 'apply_updates', 'operation': update.apply_updates}, { - "testcase_name": 'incremental_update', - "operation": lambda x, y: update.incremental_update(x, y, 1), + 'testcase_name': 'incremental_update', + 'operation': lambda x, y: update.incremental_update(x, y, 1), }, ) def test_none_argument(self, operation): diff --git a/optax/_src/utils_test.py b/optax/_src/utils_test.py index 1c96e6784..218d8fc6b 100644 --- a/optax/_src/utils_test.py +++ b/optax/_src/utils_test.py @@ -43,7 +43,7 @@ def fn(inputs): outputs = jax.tree.map(lambda x: x**2, outputs) return sum(jax.tree.leaves(outputs)) - inputs = {"a": -1.0, "b": {"c": (2.0,), "d": 0.0}} + inputs = {'a': -1.0, 'b': {'c': (2.0,), 'd': 0.0}} grad = jax.grad(fn) grads = grad(inputs) diff --git a/optax/contrib/_common_test.py b/optax/contrib/_common_test.py index 720c77c98..ac2e5cab5 100644 --- a/optax/contrib/_common_test.py +++ b/optax/contrib/_common_test.py @@ -36,22 +36,22 @@ # Testing contributions coded as GradientTransformations _MAIN_OPTIMIZERS_UNDER_TEST = [ - {"opt_name": 'acprop', "opt_kwargs": {"learning_rate": 1e-3}}, - {"opt_name": 'cocob', "opt_kwargs": {}}, - {"opt_name": 'cocob', "opt_kwargs": {"weight_decay": 1e-2}}, - {"opt_name": 'dadapt_adamw', "opt_kwargs": {"learning_rate": 1e-1}}, - {"opt_name": 'dog', "opt_kwargs": {"learning_rate": 1.0}}, - {"opt_name": 'dowg', "opt_kwargs": {"learning_rate": 1.0}}, - {"opt_name": 'momo', "opt_kwargs": {"learning_rate": 1e-1}}, - {"opt_name": 'momo_adam', "opt_kwargs": {"learning_rate": 1e-1}}, - {"opt_name": 'prodigy', "opt_kwargs": {"learning_rate": 1e-1}}, + {'opt_name': 'acprop', 'opt_kwargs': {'learning_rate': 1e-3}}, + {'opt_name': 'cocob', 'opt_kwargs': {}}, + {'opt_name': 'cocob', 'opt_kwargs': {'weight_decay': 1e-2}}, + {'opt_name': 'dadapt_adamw', 'opt_kwargs': {'learning_rate': 1e-1}}, + {'opt_name': 'dog', 'opt_kwargs': {'learning_rate': 1.0}}, + {'opt_name': 'dowg', 'opt_kwargs': {'learning_rate': 1.0}}, + {'opt_name': 'momo', 'opt_kwargs': {'learning_rate': 1e-1}}, + {'opt_name': 'momo_adam', 'opt_kwargs': {'learning_rate': 1e-1}}, + {'opt_name': 'prodigy', 'opt_kwargs': {'learning_rate': 1e-1}}, { - "opt_name": 'schedule_free_sgd', - "opt_kwargs": {"learning_rate": 1e-2, "warmup_steps": 5000}, + 'opt_name': 'schedule_free_sgd', + 'opt_kwargs': {'learning_rate': 1e-2, 'warmup_steps': 5000}, }, { - "opt_name": 'schedule_free_adamw', - "opt_kwargs": {"learning_rate": 1e-2, "warmup_steps": 5000}, + 'opt_name': 'schedule_free_adamw', + 'opt_kwargs': {'learning_rate': 1e-2, 'warmup_steps': 5000}, }, ] for optimizer in _MAIN_OPTIMIZERS_UNDER_TEST: @@ -62,57 +62,63 @@ # (just with sgd as we just want the behavior of the wrapper) _MAIN_OPTIMIZERS_UNDER_TEST += [ { - "opt_name": 'sgd', - "opt_kwargs": {"learning_rate": 1e-1}, - "wrapper_name": 'mechanize', - "wrapper_kwargs": {"weight_decay": 0.0}, + 'opt_name': 'sgd', + 'opt_kwargs': {'learning_rate': 1e-1}, + 'wrapper_name': 'mechanize', + 'wrapper_kwargs': {'weight_decay': 0.0}, }, { - "opt_name": 'sgd', - "opt_kwargs": {"learning_rate": 1e-2}, - "wrapper_name": 'schedule_free', - "wrapper_kwargs": {"learning_rate": 1e-2}, + 'opt_name': 'sgd', + 'opt_kwargs': {'learning_rate': 1e-2}, + 'wrapper_name': 'schedule_free', + 'wrapper_kwargs': {'learning_rate': 1e-2}, }, { - "opt_name": 'sgd', - "opt_kwargs": {"learning_rate": 1e-3}, - "wrapper_name": 'reduce_on_plateau', - "wrapper_kwargs": {}, + 'opt_name': 'sgd', + 'opt_kwargs': {'learning_rate': 1e-3}, + 'wrapper_name': 'reduce_on_plateau', + 'wrapper_kwargs': {}, }, ] # Adding here instantiations of wrappers with any base optimizer _BASE_OPTIMIZERS = [ - {"opt_name": 'sgd', "opt_kwargs": {"learning_rate": 1.0}}, - {"opt_name": 'sgd', "opt_kwargs": {"learning_rate": 1.0, "momentum": 0.9}}, - {"opt_name": 'adam', "opt_kwargs": {"learning_rate": 1.0}}, - {"opt_name": 'adamw', "opt_kwargs": {"learning_rate": 1.0}}, - {"opt_name": 'adamax', "opt_kwargs": {"learning_rate": 1.0}}, - {"opt_name": 'adamaxw', "opt_kwargs": {"learning_rate": 1.0}}, - {"opt_name": 'adan', "opt_kwargs": {"learning_rate": 1.0}}, - {"opt_name": 'amsgrad', "opt_kwargs": {"learning_rate": 1.0}}, - {"opt_name": 'lamb', "opt_kwargs": {"learning_rate": 1.0}}, - {"opt_name": 'lion', "opt_kwargs": {"learning_rate": 1.0, "b1": 0.99}}, - {"opt_name": 'noisy_sgd', "opt_kwargs": {"learning_rate": 1.0, "eta": 1e-4}}, - {"opt_name": 'novograd', "opt_kwargs": {"learning_rate": 1.0}}, + {'opt_name': 'sgd', 'opt_kwargs': {'learning_rate': 1.0}}, + {'opt_name': 'sgd', 'opt_kwargs': {'learning_rate': 1.0, 'momentum': 0.9}}, + {'opt_name': 'adam', 'opt_kwargs': {'learning_rate': 1.0}}, + {'opt_name': 'adamw', 'opt_kwargs': {'learning_rate': 1.0}}, + {'opt_name': 'adamax', 'opt_kwargs': {'learning_rate': 1.0}}, + {'opt_name': 'adamaxw', 'opt_kwargs': {'learning_rate': 1.0}}, + {'opt_name': 'adan', 'opt_kwargs': {'learning_rate': 1.0}}, + {'opt_name': 'amsgrad', 'opt_kwargs': {'learning_rate': 1.0}}, + {'opt_name': 'lamb', 'opt_kwargs': {'learning_rate': 1.0}}, + {'opt_name': 'lion', 'opt_kwargs': {'learning_rate': 1.0, 'b1': 0.99}}, { - "opt_name": 'optimistic_gradient_descent', - "opt_kwargs": {"learning_rate": 1.0, "alpha": 0.7, "beta": 0.1}, + 'opt_name': 'noisy_sgd', + 'opt_kwargs': {'learning_rate': 1.0, 'eta': 1e-4}, }, - {"opt_name": 'rmsprop', "opt_kwargs": {"learning_rate": 1.0}}, - {"opt_name": 'rmsprop', "opt_kwargs": {"learning_rate": 1.0, "momentum": 0.9}}, - {"opt_name": 'adabelief', "opt_kwargs": {"learning_rate": 1.0}}, - {"opt_name": 'radam', "opt_kwargs": {"learning_rate": 1.0}}, - {"opt_name": 'sm3', "opt_kwargs": {"learning_rate": 1.0}}, - {"opt_name": 'yogi', "opt_kwargs": {"learning_rate": 1.0, "b1": 0.99}}, + {'opt_name': 'novograd', 'opt_kwargs': {'learning_rate': 1.0}}, + { + 'opt_name': 'optimistic_gradient_descent', + 'opt_kwargs': {'learning_rate': 1.0, 'alpha': 0.7, 'beta': 0.1}, + }, + {'opt_name': 'rmsprop', 'opt_kwargs': {'learning_rate': 1.0}}, + { + 'opt_name': 'rmsprop', + 'opt_kwargs': {'learning_rate': 1.0, 'momentum': 0.9}, + }, + {'opt_name': 'adabelief', 'opt_kwargs': {'learning_rate': 1.0}}, + {'opt_name': 'radam', 'opt_kwargs': {'learning_rate': 1.0}}, + {'opt_name': 'sm3', 'opt_kwargs': {'learning_rate': 1.0}}, + {'opt_name': 'yogi', 'opt_kwargs': {'learning_rate': 1.0, 'b1': 0.99}}, ] # TODO(harshm): make LARS and Fromage work with mechanic. _OTHER_OPTIMIZERS_UNDER_TEST = [ { - "opt_name": base_opt['opt_name'], - "opt_kwargs": base_opt['opt_kwargs'], - "wrapper_name": 'mechanize', - "wrapper_kwargs": {"weight_decay": 0.0}, + 'opt_name': base_opt['opt_name'], + 'opt_kwargs': base_opt['opt_kwargs'], + 'wrapper_name': 'mechanize', + 'wrapper_kwargs': {'weight_decay': 0.0}, } for base_opt in _BASE_OPTIMIZERS ] diff --git a/optax/contrib/_sam_test.py b/optax/contrib/_sam_test.py index 04167fe26..40bfc603c 100644 --- a/optax/contrib/_sam_test.py +++ b/optax/contrib/_sam_test.py @@ -27,11 +27,11 @@ from optax.tree_utils import _state_utils _BASE_OPTIMIZERS_UNDER_TEST = [ - {"base_opt_name": 'sgd', "base_opt_kwargs": {"learning_rate": 1e-3}}, + {'base_opt_name': 'sgd', 'base_opt_kwargs': {'learning_rate': 1e-3}}, ] _ADVERSARIAL_OPTIMIZERS_UNDER_TEST = [ - {"adv_opt_name": 'sgd', "adv_opt_kwargs": {"learning_rate": 1e-5}}, - {"adv_opt_name": 'adam', "adv_opt_kwargs": {"learning_rate": 1e-4}}, + {'adv_opt_name': 'sgd', 'adv_opt_kwargs': {'learning_rate': 1e-5}}, + {'adv_opt_name': 'adam', 'adv_opt_kwargs': {'learning_rate': 1e-4}}, ] @@ -79,7 +79,7 @@ def test_optimization( initial_params, final_params, get_updates = target(dtype) if opaque_mode: - update_kwargs = {"grad_fn": lambda p, _: get_updates(p)} + update_kwargs = {'grad_fn': lambda p, _: get_updates(p)} else: update_kwargs = {} diff --git a/optax/losses/_classification_test.py b/optax/losses/_classification_test.py index e44dc4fd5..448847b03 100644 --- a/optax/losses/_classification_test.py +++ b/optax/losses/_classification_test.py @@ -81,7 +81,7 @@ def test_gradient(self): order=1, ) - @parameterized.parameters({"size": 5}, {"size": 10}) + @parameterized.parameters({'size': 5}, {'size': 10}) def test_mask(self, size): preds = np.random.normal(size=size) targets = np.random.dirichlet(np.ones(size)) @@ -91,9 +91,9 @@ def test_mask(self, size): np.testing.assert_allclose(x, y, atol=1e-4) @parameterized.parameters( - {"axis": 0, "shape": [4, 5, 6]}, - {"axis": 1, "shape": [4, 5, 6]}, - {"axis": 2, "shape": [4, 5, 6]}, + {'axis': 0, 'shape': [4, 5, 6]}, + {'axis': 1, 'shape': [4, 5, 6]}, + {'axis': 2, 'shape': [4, 5, 6]}, ) def test_axis(self, shape, axis): preds = np.random.normal(size=shape) @@ -230,9 +230,9 @@ def test_gradient(self): ) @parameterized.parameters( - {"axis": 0, "shape": [4, 5, 6]}, - {"axis": 1, "shape": [4, 5, 6]}, - {"axis": 2, "shape": [4, 5, 6]}, + {'axis': 0, 'shape': [4, 5, 6]}, + {'axis': 1, 'shape': [4, 5, 6]}, + {'axis': 2, 'shape': [4, 5, 6]}, ) def test_axis(self, shape, axis): preds = np.random.normal(size=shape) @@ -252,54 +252,54 @@ class SigmoidCrossEntropyTest(parameterized.TestCase): @parameterized.parameters( { - "preds": np.array([-1e09, -1e-09]), - "labels": np.array([1.0, 0.0]), - "expected": 5e08, + 'preds': np.array([-1e09, -1e-09]), + 'labels': np.array([1.0, 0.0]), + 'expected': 5e08, }, { - "preds": np.array([-1e09, -1e-09]), - "labels": np.array([0.0, 1.0]), - "expected": 0.3465736, + 'preds': np.array([-1e09, -1e-09]), + 'labels': np.array([0.0, 1.0]), + 'expected': 0.3465736, }, { - "preds": np.array([1e09, 1e-09]), - "labels": np.array([1.0, 0.0]), - "expected": 0.3465736, + 'preds': np.array([1e09, 1e-09]), + 'labels': np.array([1.0, 0.0]), + 'expected': 0.3465736, }, { - "preds": np.array([1e09, 1e-09]), - "labels": np.array([0.0, 1.0]), - "expected": 5e08, + 'preds': np.array([1e09, 1e-09]), + 'labels': np.array([0.0, 1.0]), + 'expected': 5e08, }, { - "preds": np.array([-1e09, 1e-09]), - "labels": np.array([1.0, 0.0]), - "expected": 5e08, + 'preds': np.array([-1e09, 1e-09]), + 'labels': np.array([1.0, 0.0]), + 'expected': 5e08, }, { - "preds": np.array([-1e09, 1e-09]), - "labels": np.array([0.0, 1.0]), - "expected": 0.3465736, + 'preds': np.array([-1e09, 1e-09]), + 'labels': np.array([0.0, 1.0]), + 'expected': 0.3465736, }, { - "preds": np.array([1e09, -1e-09]), - "labels": np.array([1.0, 0.0]), - "expected": 0.3465736, + 'preds': np.array([1e09, -1e-09]), + 'labels': np.array([1.0, 0.0]), + 'expected': 0.3465736, }, { - "preds": np.array([1e09, -1e-09]), - "labels": np.array([0.0, 1.0]), - "expected": 5e08, + 'preds': np.array([1e09, -1e-09]), + 'labels': np.array([0.0, 1.0]), + 'expected': 5e08, }, { - "preds": np.array([0.0, 0.0]), - "labels": np.array([1.0, 0.0]), - "expected": 0.6931472, + 'preds': np.array([0.0, 0.0]), + 'labels': np.array([1.0, 0.0]), + 'expected': 0.6931472, }, { - "preds": np.array([0.0, 0.0]), - "labels": np.array([0.0, 1.0]), - "expected": 0.6931472, + 'preds': np.array([0.0, 0.0]), + 'labels': np.array([0.0, 1.0]), + 'expected': 0.6931472, }, ) def test_sigmoid_cross_entropy(self, preds, labels, expected): @@ -323,14 +323,14 @@ def setUp(self): @chex.all_variants @parameterized.parameters( - {"eps": 2, "expected": 4.5317}, - {"eps": 1, "expected": 3.7153}, - {"eps": -1, "expected": 2.0827}, - {"eps": 0, "expected": 2.8990}, - {"eps": -0.5, "expected": 2.4908}, - {"eps": 1.15, "expected": 3.8378}, - {"eps": 1.214, "expected": 3.8900}, - {"eps": 5.45, "expected": 7.3480}, + {'eps': 2, 'expected': 4.5317}, + {'eps': 1, 'expected': 3.7153}, + {'eps': -1, 'expected': 2.0827}, + {'eps': 0, 'expected': 2.8990}, + {'eps': -0.5, 'expected': 2.4908}, + {'eps': 1.15, 'expected': 3.8378}, + {'eps': 1.214, 'expected': 3.8900}, + {'eps': 5.45, 'expected': 7.3480}, ) def test_scalar(self, eps, expected): np.testing.assert_allclose( @@ -343,13 +343,13 @@ def test_scalar(self, eps, expected): @chex.all_variants @parameterized.parameters( - {"eps": 2, "expected": np.array([0.4823, 1.2567])}, - {"eps": 1, "expected": np.array([0.3261, 1.0407])}, - {"eps": 0, "expected": np.array([0.1698, 0.8247])}, - {"eps": -0.5, "expected": np.array([0.0917, 0.7168])}, - {"eps": 1.15, "expected": np.array([0.3495, 1.0731])}, - {"eps": 1.214, "expected": np.array([0.3595, 1.0870])}, - {"eps": 5.45, "expected": np.array([1.0211, 2.0018])}, + {'eps': 2, 'expected': np.array([0.4823, 1.2567])}, + {'eps': 1, 'expected': np.array([0.3261, 1.0407])}, + {'eps': 0, 'expected': np.array([0.1698, 0.8247])}, + {'eps': -0.5, 'expected': np.array([0.0917, 0.7168])}, + {'eps': 1.15, 'expected': np.array([0.3495, 1.0731])}, + {'eps': 1.214, 'expected': np.array([0.3595, 1.0870])}, + {'eps': 5.45, 'expected': np.array([1.0211, 2.0018])}, ) def test_batched(self, eps, expected): np.testing.assert_allclose( @@ -363,27 +363,27 @@ def test_batched(self, eps, expected): @chex.all_variants @parameterized.parameters( { - "logits": np.array( + 'logits': np.array( [[4.0, 2.0, 1.0], [0.0, 5.0, 1.0], [0.134, 1.234, 3.235]] ), - "labels": np.array( + 'labels': np.array( [[1.0, 0.0, 0.0], [0.0, 0.8, 0.2], [0.34, 0.33, 0.33]] ), }, { - "logits": np.array([[4.0, 2.0, 1.0], [0.0, 5.0, 1.0]]), - "labels": np.array([[1.0, 0.0, 0.0], [0.0, 0.8, 0.2]]), + 'logits': np.array([[4.0, 2.0, 1.0], [0.0, 5.0, 1.0]]), + 'labels': np.array([[1.0, 0.0, 0.0], [0.0, 0.8, 0.2]]), }, { - "logits": np.array( + 'logits': np.array( [[4.0, 2.0, 1.0, 0.134, 1.3515], [0.0, 5.0, 1.0, 0.5215, 5.616]] ), - "labels": np.array( + 'labels': np.array( [[0.5, 0.0, 0.0, 0.0, 0.5], [0.0, 0.12, 0.2, 0.56, 0.12]] ), }, - {"logits": np.array([1.89, 2.39]), "labels": np.array([0.34, 0.66])}, - {"logits": np.array([0.314]), "labels": np.array([1.0])}, + {'logits': np.array([1.89, 2.39]), 'labels': np.array([0.34, 0.66])}, + {'logits': np.array([0.314]), 'labels': np.array([1.0])}, ) def test_equals_to_cross_entropy_when_eps0(self, logits, labels): np.testing.assert_allclose( @@ -394,7 +394,7 @@ def test_equals_to_cross_entropy_when_eps0(self, logits, labels): atol=1e-4, ) - @parameterized.parameters({"size": 5}, {"size": 10}) + @parameterized.parameters({'size': 5}, {'size': 10}) def test_mask(self, size): preds = np.random.normal(size=size) targets = np.random.dirichlet(np.ones(size)) @@ -404,9 +404,9 @@ def test_mask(self, size): np.testing.assert_allclose(x, y, atol=1e-4) @parameterized.parameters( - {"axis": 0, "shape": [4, 5, 6]}, - {"axis": 1, "shape": [4, 5, 6]}, - {"axis": 2, "shape": [4, 5, 6]}, + {'axis': 0, 'shape': [4, 5, 6]}, + {'axis': 1, 'shape': [4, 5, 6]}, + {'axis': 2, 'shape': [4, 5, 6]}, ) def test_axis(self, shape, axis): preds = np.random.normal(size=shape) @@ -567,7 +567,7 @@ def test_batched(self): atol=1e-4, ) - @parameterized.parameters({"size": 5}, {"size": 10}) + @parameterized.parameters({'size': 5}, {'size': 10}) def test_mask(self, size): preds = np.random.normal(size=size) targets = np.random.dirichlet(np.ones(size)) @@ -577,9 +577,9 @@ def test_mask(self, size): np.testing.assert_allclose(x, y, atol=1e-4) @parameterized.parameters( - {"axis": 0, "shape": [4, 5, 6]}, - {"axis": 1, "shape": [4, 5, 6]}, - {"axis": 2, "shape": [4, 5, 6]}, + {'axis': 0, 'shape': [4, 5, 6]}, + {'axis': 1, 'shape': [4, 5, 6]}, + {'axis': 2, 'shape': [4, 5, 6]}, ) def test_axis(self, shape, axis): preds = np.random.normal(size=shape) @@ -673,7 +673,7 @@ def test_batched(self): atol=1e-4, ) - @parameterized.parameters({"size": 5}, {"size": 10}) + @parameterized.parameters({'size': 5}, {'size': 10}) def test_mask(self, size): preds = np.random.normal(size=size) targets = np.random.dirichlet(np.ones(size)) @@ -683,9 +683,9 @@ def test_mask(self, size): np.testing.assert_allclose(x, y, atol=1e-4) @parameterized.parameters( - {"axis": 0, "shape": [4, 5, 6]}, - {"axis": 1, "shape": [4, 5, 6]}, - {"axis": 2, "shape": [4, 5, 6]}, + {'axis': 0, 'shape': [4, 5, 6]}, + {'axis': 1, 'shape': [4, 5, 6]}, + {'axis': 2, 'shape': [4, 5, 6]}, ) def test_axis(self, shape, axis): preds = np.random.normal(size=shape) @@ -733,7 +733,7 @@ def test_batched(self): atol=1e-4, ) - @parameterized.parameters({"size": 5}, {"size": 10}) + @parameterized.parameters({'size': 5}, {'size': 10}) def test_mask(self, size): preds = np.random.normal(size=size) targets = np.log(np.random.dirichlet(np.ones(size))) @@ -744,9 +744,9 @@ def test_mask(self, size): np.testing.assert_allclose(x, y, atol=1e-4) @parameterized.parameters( - {"axis": 0, "shape": [4, 5, 6]}, - {"axis": 1, "shape": [4, 5, 6]}, - {"axis": 2, "shape": [4, 5, 6]}, + {'axis': 0, 'shape': [4, 5, 6]}, + {'axis': 1, 'shape': [4, 5, 6]}, + {'axis': 2, 'shape': [4, 5, 6]}, ) def test_axis(self, shape, axis): preds = np.random.normal(size=shape) diff --git a/optax/losses/_regression_test.py b/optax/losses/_regression_test.py index 4bd8b0415..2c14c16f8 100644 --- a/optax/losses/_regression_test.py +++ b/optax/losses/_regression_test.py @@ -185,7 +185,7 @@ def test_batched_similarity(self): atol=1e-4, ) - @parameterized.parameters({"size": 5}, {"size": 10}) + @parameterized.parameters({'size': 5}, {'size': 10}) def test_mask_distance(self, size): preds = np.random.normal(size=size) targets = np.random.normal(size=size) @@ -194,7 +194,7 @@ def test_mask_distance(self, size): y = _regression.cosine_distance(preds, targets, where=mask) np.testing.assert_allclose(x, y, atol=1e-4) - @parameterized.parameters({"size": 5}, {"size": 10}) + @parameterized.parameters({'size': 5}, {'size': 10}) def test_mask_similarity(self, size): preds = np.random.normal(size=size) targets = np.random.normal(size=size) @@ -204,9 +204,9 @@ def test_mask_similarity(self, size): np.testing.assert_allclose(x, y, atol=1e-4) @parameterized.parameters( - {"axis": 0, "shape": [4, 5, 6]}, - {"axis": 1, "shape": [4, 5, 6]}, - {"axis": 2, "shape": [4, 5, 6]}, + {'axis': 0, 'shape': [4, 5, 6]}, + {'axis': 1, 'shape': [4, 5, 6]}, + {'axis': 2, 'shape': [4, 5, 6]}, ) def test_axis(self, shape, axis): preds = np.random.normal(size=shape) diff --git a/optax/monte_carlo/control_variates_test.py b/optax/monte_carlo/control_variates_test.py index e59aec84b..2cd8d4c65 100644 --- a/optax/monte_carlo/control_variates_test.py +++ b/optax/monte_carlo/control_variates_test.py @@ -37,7 +37,7 @@ def _assert_equal(actual, expected, rtol=1e-2, atol=1e-2): # Scalar. if not actual.shape: np.testing.assert_allclose( - np.asarray(actual), np.asarray(expected), rtol, atol + np.asarray(actual), np.asarray(expected), rtol, atol ) return @@ -81,14 +81,14 @@ def test_quadratic_function(self, effective_mean, effective_log_scale): mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32) log_scale = effective_log_scale * jnp.ones( - shape=(data_dims), dtype=jnp.float32 + shape=(data_dims), dtype=jnp.float32 ) params = [mean, log_scale] dist = utils.multi_normal(*params) dist_samples = dist.sample((num_samples,), rng) def function(x): - return jnp.sum(x**2) + return jnp.sum(x**2) cv, expected_cv, _ = control_variates.control_delta_method(function) avg_cv = jnp.mean(_map_variant(self.variant)(cv, params, dist_samples)) @@ -107,7 +107,7 @@ def test_polynomial_function(self, effective_mean, effective_log_scale): mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32) log_scale = effective_log_scale * jnp.ones( - shape=(data_dims), dtype=jnp.float32 + shape=(data_dims), dtype=jnp.float32 ) params = [mean, log_scale] @@ -115,7 +115,7 @@ def test_polynomial_function(self, effective_mean, effective_log_scale): rng = jax.random.PRNGKey(1) dist_samples = dist.sample((num_samples,), rng) def function(x): - return jnp.sum(x**5) + return jnp.sum(x**5) cv, expected_cv, _ = control_variates.control_delta_method(function) avg_cv = jnp.mean(_map_variant(self.variant)(cv, params, dist_samples)) @@ -137,7 +137,7 @@ def test_non_polynomial_function(self): dist = utils.multi_normal(*params) dist_samples = dist.sample((num_samples,), rng) def function(x): - return jnp.sum(jnp.log(x**2)) + return jnp.sum(jnp.log(x**2)) cv, expected_cv, _ = control_variates.control_delta_method(function) avg_cv = jnp.mean(_map_variant(self.variant)(cv, params, dist_samples)) @@ -149,7 +149,7 @@ def function(x): # Second order expansion is log(\mu**2) + 1/2 * \sigma**2 (-2 / \mu**2) expected_cv_val = -np.exp(1.0) ** 2 * data_dims _assert_equal( - expected_cv(params, None), expected_cv_val, rtol=1e-1, atol=1e-3 + expected_cv(params, None), expected_cv_val, rtol=1e-1, atol=1e-3 ) @@ -164,46 +164,46 @@ def test_linear_function(self, effective_mean, effective_log_scale, decay): mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32) log_scale = effective_log_scale * jnp.ones( - shape=(data_dims), dtype=jnp.float32 + shape=(data_dims), dtype=jnp.float32 ) params = [mean, log_scale] def function(x): - return jnp.sum(weights * x) + return jnp.sum(weights * x) rng = jax.random.PRNGKey(1) dist = utils.multi_normal(*params) dist_samples = dist.sample((num_samples,), rng) cv, expected_cv, update_state = control_variates.moving_avg_baseline( - function, - decay=decay, - zero_debias=False, - use_decay_early_training_heuristic=False, + function, + decay=decay, + zero_debias=False, + use_decay_early_training_heuristic=False, ) state_1 = jnp.array(1.0) avg_cv = jnp.mean( - _map_variant(self.variant)(cv, params, dist_samples, (state_1, 0)) + _map_variant(self.variant)(cv, params, dist_samples, (state_1, 0)) ) _assert_equal(avg_cv, state_1) _assert_equal(expected_cv(params, (state_1, 0)), state_1) state_2 = jnp.array(2.0) avg_cv = jnp.mean( - _map_variant(self.variant)(cv, params, dist_samples, (state_2, 0)) + _map_variant(self.variant)(cv, params, dist_samples, (state_2, 0)) ) _assert_equal(avg_cv, state_2) _assert_equal(expected_cv(params, (state_2, 0)), state_2) update_state_1 = update_state(params, dist_samples, (state_1, 0))[0] _assert_equal( - update_state_1, decay * state_1 + (1 - decay) * function(mean) + update_state_1, decay * state_1 + (1 - decay) * function(mean) ) update_state_2 = update_state(params, dist_samples, (state_2, 0))[0] _assert_equal( - update_state_2, decay * state_2 + (1 - decay) * function(mean) + update_state_2, decay * state_2 + (1 - decay) * function(mean) ) @chex.all_variants @@ -217,34 +217,34 @@ def test_linear_function_with_heuristic( mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32) log_scale = effective_log_scale * jnp.ones( - shape=(data_dims), dtype=jnp.float32 + shape=(data_dims), dtype=jnp.float32 ) params = [mean, log_scale] def function(x): - return jnp.sum(weights * x) + return jnp.sum(weights * x) rng = jax.random.PRNGKey(1) dist = utils.multi_normal(*params) dist_samples = dist.sample((num_samples,), rng) cv, expected_cv, update_state = control_variates.moving_avg_baseline( - function, - decay=decay, - zero_debias=False, - use_decay_early_training_heuristic=True, + function, + decay=decay, + zero_debias=False, + use_decay_early_training_heuristic=True, ) state_1 = jnp.array(1.0) avg_cv = jnp.mean( - _map_variant(self.variant)(cv, params, dist_samples, (state_1, 0)) + _map_variant(self.variant)(cv, params, dist_samples, (state_1, 0)) ) _assert_equal(avg_cv, state_1) _assert_equal(expected_cv(params, (state_1, 0)), state_1) state_2 = jnp.array(2.0) avg_cv = jnp.mean( - _map_variant(self.variant)(cv, params, dist_samples, (state_2, 0)) + _map_variant(self.variant)(cv, params, dist_samples, (state_2, 0)) ) _assert_equal(avg_cv, state_2) _assert_equal(expected_cv(params, (state_2, 0)), state_2) @@ -252,15 +252,15 @@ def function(x): first_step_decay = 0.1 update_state_1 = update_state(params, dist_samples, (state_1, 0))[0] _assert_equal( - update_state_1, - first_step_decay * state_1 + (1 - first_step_decay) * function(mean), + update_state_1, + first_step_decay * state_1 + (1 - first_step_decay) * function(mean), ) second_step_decay = 2.0 / 11 update_state_2 = update_state(params, dist_samples, (state_2, 1))[0] _assert_equal( - update_state_2, - second_step_decay * state_2 + (1 - second_step_decay) * function(mean), + update_state_2, + second_step_decay * state_2 + (1 - second_step_decay) * function(mean), ) @parameterized.parameters([(1.0, 0.5, 0.9), (1.0, 0.5, 0.99)]) @@ -273,36 +273,36 @@ def test_linear_function_zero_debias( mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32) log_scale = effective_log_scale * jnp.ones( - shape=(data_dims), dtype=jnp.float32 + shape=(data_dims), dtype=jnp.float32 ) params = [mean, log_scale] def function(x): - return jnp.sum(weights * x) + return jnp.sum(weights * x) rng = jax.random.PRNGKey(1) dist = utils.multi_normal(*params) dist_samples = dist.sample((num_samples,), rng) update_state = control_variates.moving_avg_baseline( - function, - decay=decay, - zero_debias=False, - use_decay_early_training_heuristic=False, + function, + decay=decay, + zero_debias=False, + use_decay_early_training_heuristic=False, )[-1] update_state_zero_debias = control_variates.moving_avg_baseline( - function, - decay=decay, - zero_debias=True, - use_decay_early_training_heuristic=False, + function, + decay=decay, + zero_debias=True, + use_decay_early_training_heuristic=False, )[-1] updated_state = update_state(params, dist_samples, (jnp.array(0.0), 0))[0] _assert_equal(updated_state, (1 - decay) * function(mean)) updated_state_zero_debias = update_state_zero_debias( - params, dist_samples, (jnp.array(0.0), 0) + params, dist_samples, (jnp.array(0.0), 0) )[0] _assert_equal(updated_state_zero_debias, function(mean)) @@ -339,7 +339,7 @@ def test_quadratic_function( params = [mean, log_scale] def function(x): - return jnp.sum(x**2) + return jnp.sum(x**2) rng = jax.random.PRNGKey(1) jacobians = _cv_jac_variant(self.variant)( @@ -410,7 +410,7 @@ def test_cubic_function( params = [mean, log_scale] def function(x): - return jnp.sum(x**3) + return jnp.sum(x**3) rng = jax.random.PRNGKey(1) jacobians = _cv_jac_variant(self.variant)( @@ -485,7 +485,7 @@ def test_forth_power_function( params = [mean, log_scale] def function(x): - return jnp.sum(x**4) + return jnp.sum(x**4) rng = jax.random.PRNGKey(1) jacobians = _cv_jac_variant(self.variant)( @@ -576,7 +576,7 @@ def test_weighted_linear_function( params = [mean, log_scale] def function(x): - return jnp.sum(weights * x) + return jnp.sum(weights * x) rng = jax.random.PRNGKey(1) cv_rng, ge_rng = jax.random.split(rng) @@ -653,7 +653,7 @@ def test_non_polynomial_function( params = [mean, log_scale] def function(x): - return jnp.log(jnp.sum(x**2)) + return jnp.log(jnp.sum(x**2)) rng = jax.random.PRNGKey(1) cv_rng, ge_rng = jax.random.split(rng) diff --git a/optax/monte_carlo/stochastic_gradient_estimators_test.py b/optax/monte_carlo/stochastic_gradient_estimators_test.py index 0fb756731..da95b8be7 100644 --- a/optax/monte_carlo/stochastic_gradient_estimators_test.py +++ b/optax/monte_carlo/stochastic_gradient_estimators_test.py @@ -133,7 +133,9 @@ def test_constant_function(self, estimator, constant): ], named=True), ) - def test_linear_function(self, estimator, effective_mean, effective_log_scale): + def test_linear_function( + self, estimator, effective_mean, effective_log_scale + ): data_dims = 3 num_samples = _estimator_to_num_samples[estimator] rng = jax.random.PRNGKey(1) diff --git a/optax/perturbations/_make_pert_test.py b/optax/perturbations/_make_pert_test.py index 911cc650a..857110316 100644 --- a/optax/perturbations/_make_pert_test.py +++ b/optax/perturbations/_make_pert_test.py @@ -68,8 +68,8 @@ def setUp(self): for i in range(2): example_tree.append({ - "weights": jnp.ones(weight_shapes[i]), - "biases": jnp.ones(biases_shapes[i]), + 'weights': jnp.ones(weight_shapes[i]), + 'biases': jnp.ones(biases_shapes[i]), }) self.example_tree = example_tree diff --git a/optax/projections/_projections_test.py b/optax/projections/_projections_test.py index 2c79a08fc..efe13fd0e 100644 --- a/optax/projections/_projections_test.py +++ b/optax/projections/_projections_test.py @@ -40,14 +40,14 @@ def setUp(self): array_2d = jnp.array([[0.5, 2.1, -3.5], [1.0, 2.0, 3.0]]) tree = (array_1d, array_1d) self.data = { - "array_1d": array_1d, - "array_2d": array_2d, - "tree": tree, + 'array_1d': array_1d, + 'array_2d': array_2d, + 'tree': tree, } self.fns = { - "l1": (proj.projection_l1_ball, otu.tree_l1_norm), - "l2": (proj.projection_l2_ball, otu.tree_l2_norm), - "linf": (proj.projection_linf_ball, otu.tree_linf_norm), + 'l1': (proj.projection_l1_ball, otu.tree_l1_norm), + 'l2': (proj.projection_l2_ball, otu.tree_l2_norm), + 'linf': (proj.projection_linf_ball, otu.tree_linf_norm), } def test_projection_non_negative(self): diff --git a/optax/schedules/_inject_test.py b/optax/schedules/_inject_test.py index 08d6993e5..a2add6c9b 100644 --- a/optax/schedules/_inject_test.py +++ b/optax/schedules/_inject_test.py @@ -222,7 +222,7 @@ def test_wrap_stateless_schedule(self): my_schedule(count), my_wrapped_schedule(state), atol=0.0 ) count = count + 1 - extra_args = {"loss": jnp.ones([], dtype=jnp.float32)} + extra_args = {'loss': jnp.ones([], dtype=jnp.float32)} state = my_wrapped_schedule.update(state, **extra_args) np.testing.assert_allclose(count, state, atol=0.0) @@ -240,7 +240,7 @@ def test_inject_stateful_hyperparams(self): ) state = self.variant(tx.init)(params) - extra_args = {"addendum": 0.3 * jnp.ones((), dtype=jnp.float32)} + extra_args = {'addendum': 0.3 * jnp.ones((), dtype=jnp.float32)} _, state = self.variant(tx.update)( grads, state, params=params, **extra_args ) diff --git a/optax/schedules/_schedule_test.py b/optax/schedules/_schedule_test.py index baa46e73d..3e0ed90b1 100644 --- a/optax/schedules/_schedule_test.py +++ b/optax/schedules/_schedule_test.py @@ -553,11 +553,11 @@ def test_limits(self, lr0, lr1, lr2): lr_kwargs = [] for step, lr in zip([2e3, 3e3, 5e3], [lr0, lr1, lr2]): lr_kwargs += [{ - "decay_steps": int(step), - "peak_value": lr, - "init_value": 0, - "end_value": 0.0, - "warmup_steps": 500, + 'decay_steps': int(step), + 'peak_value': lr, + 'init_value': 0, + 'end_value': 0.0, + 'warmup_steps': 500, }] schedule_fn = self.variant(_schedule.sgdr_schedule(lr_kwargs)) np.testing.assert_allclose(lr0, schedule_fn(500)) diff --git a/optax/transforms/_accumulation.py b/optax/transforms/_accumulation.py index 72c2982bd..b2646e26c 100644 --- a/optax/transforms/_accumulation.py +++ b/optax/transforms/_accumulation.py @@ -178,8 +178,8 @@ def skip_not_finite( num_not_finite = jnp.sum(jnp.array(all_is_finite)) should_skip = num_not_finite > 0 return should_skip, { - "should_skip": should_skip, - "num_not_finite": num_not_finite, + 'should_skip': should_skip, + 'num_not_finite': num_not_finite, } @@ -210,7 +210,7 @@ def skip_large_updates( ) # This will also return True if `norm_sq` is NaN. should_skip = jnp.logical_not(norm_sq < max_squared_norm) - return should_skip, {"should_skip": should_skip, "norm_squared": norm_sq} + return should_skip, {'should_skip': should_skip, 'norm_squared': norm_sq} class MultiStepsState(NamedTuple): diff --git a/optax/transforms/_accumulation_test.py b/optax/transforms/_accumulation_test.py index 93cceb7a9..2eab10bf0 100644 --- a/optax/transforms/_accumulation_test.py +++ b/optax/transforms/_accumulation_test.py @@ -215,9 +215,9 @@ def test_multi_steps_every_k_schedule(self): alias.sgd(1e-4), lambda grad_step: jnp.where(grad_step < 2, 1, 3) ) opt_init, opt_update = ms_opt.gradient_transformation() - params = {"a": jnp.zeros([])} + params = {'a': jnp.zeros([])} opt_state = opt_init(params) - grad = {"a": jnp.zeros([])} + grad = {'a': jnp.zeros([])} self.assertFalse(ms_opt.has_updated(opt_state)) # First two steps have 1 mini-step per update. for _ in range(2): @@ -239,9 +239,9 @@ def test_multi_steps_zero_nans(self): every_k_schedule=2, ) opt_init, opt_update = ms_opt.gradient_transformation() - params = {"a": jnp.zeros([])} + params = {'a': jnp.zeros([])} opt_state = opt_init(params) - grad = {"a": jnp.zeros([])} + grad = {'a': jnp.zeros([])} opt_update(grad, opt_state, params) def test_multi_steps_computes_mean(self): @@ -250,9 +250,9 @@ def test_multi_steps_computes_mean(self): transform.scale(1.0), k_steps, use_grad_mean=True ) opt_init, opt_update = ms_opt.gradient_transformation() - params = {"a": jnp.zeros([])} + params = {'a': jnp.zeros([])} opt_state = opt_init(params) - grads = [{"a": jnp.ones([]) * i} for i in [1, 2, 3, 4]] + grads = [{'a': jnp.ones([]) * i} for i in [1, 2, 3, 4]] self.assertFalse(ms_opt.has_updated(opt_state)) # First 3 steps don't update. @@ -275,37 +275,37 @@ def test_multi_steps_skip_not_finite(self): opt_init, opt_update = ms_opt.gradient_transformation() opt_init = jax.jit(opt_init) opt_update = jax.jit(opt_update) - params = {"a": jnp.zeros([])} + params = {'a': jnp.zeros([])} opt_state = opt_init(params) with self.subTest('test_good_updates'): - updates, opt_state = opt_update({"a": jnp.ones([])}, opt_state, params) + updates, opt_state = opt_update({'a': jnp.ones([])}, opt_state, params) self.assertEqual(int(opt_state.mini_step), 1) params = update.apply_updates(params, updates) - updates, opt_state = opt_update({"a": jnp.ones([])}, opt_state, params) + updates, opt_state = opt_update({'a': jnp.ones([])}, opt_state, params) self.assertEqual(int(opt_state.mini_step), 0) params = update.apply_updates(params, updates) np.testing.assert_array_equal(params['a'], jnp.negative(jnp.ones([]))) with self.subTest('test_inf_updates'): updates, opt_state = opt_update( - {"a": jnp.array(float('inf'))}, opt_state, params) + {'a': jnp.array(float('inf'))}, opt_state, params) self.assertEqual(int(opt_state.mini_step), 0) # No increase in mini_step params = update.apply_updates(params, updates) np.testing.assert_array_equal(params['a'], jnp.negative(jnp.ones([]))) with self.subTest('test_nan_updates'): updates, opt_state = opt_update( - {"a": jnp.full([], float('nan'))}, opt_state, params) + {'a': jnp.full([], float('nan'))}, opt_state, params) self.assertEqual(int(opt_state.mini_step), 0) # No increase in mini_step params = update.apply_updates(params, updates) np.testing.assert_array_equal(params['a'], jnp.negative(jnp.ones([]))) with self.subTest('test_final_good_updates'): - updates, opt_state = opt_update({"a": jnp.ones([])}, opt_state, params) + updates, opt_state = opt_update({'a': jnp.ones([])}, opt_state, params) self.assertEqual(int(opt_state.mini_step), 1) params = update.apply_updates(params, updates) - updates, opt_state = opt_update({"a": jnp.ones([])}, opt_state, params) + updates, opt_state = opt_update({'a': jnp.ones([])}, opt_state, params) self.assertEqual(int(opt_state.mini_step), 0) params = update.apply_updates(params, updates) np.testing.assert_array_equal( diff --git a/optax/tree_utils/_casting_test.py b/optax/tree_utils/_casting_test.py index 96358d444..48b4aace3 100644 --- a/optax/tree_utils/_casting_test.py +++ b/optax/tree_utils/_casting_test.py @@ -89,10 +89,10 @@ def test_tree_dtype(self): self.assertRaises(ValueError, otu.tree_dtype, tree, 'highest') @parameterized.named_parameters( - {"testcase_name": 'empty_dict', "tree": {}}, - {"testcase_name": 'empty_list', "tree": []}, - {"testcase_name": 'empty_tuple', "tree": ()}, - {"testcase_name": 'empty_none', "tree": None}, + {'testcase_name': 'empty_dict', 'tree': {}}, + {'testcase_name': 'empty_list', 'tree': []}, + {'testcase_name': 'empty_tuple', 'tree': ()}, + {'testcase_name': 'empty_none', 'tree': None}, ) def test_tree_dtype_utilities_with_empty_trees(self, tree): """Test tree data type utilities on empty trees.""" diff --git a/optax/tree_utils/_random_test.py b/optax/tree_utils/_random_test.py index fa7bd0d74..190206fb2 100644 --- a/optax/tree_utils/_random_test.py +++ b/optax/tree_utils/_random_test.py @@ -28,11 +28,11 @@ # We consider samplers with varying input dtypes, we do not test all possible # samplers from `jax.random`. _SAMPLER_DTYPES = ( - {"sampler": jrd.normal, "dtype": None}, - {"sampler": jrd.normal, "dtype": 'bfloat16'}, - {"sampler": jrd.normal, "dtype": 'float32'}, - {"sampler": jrd.rademacher, "dtype": 'int32'}, - {"sampler": jrd.bits, "dtype": 'uint32'}, + {'sampler': jrd.normal, 'dtype': None}, + {'sampler': jrd.normal, 'dtype': 'bfloat16'}, + {'sampler': jrd.normal, 'dtype': 'float32'}, + {'sampler': jrd.rademacher, 'dtype': 'int32'}, + {'sampler': jrd.bits, 'dtype': 'uint32'}, ) diff --git a/optax/tree_utils/_state_utils_test.py b/optax/tree_utils/_state_utils_test.py index e80475fe0..5ffd906aa 100644 --- a/optax/tree_utils/_state_utils_test.py +++ b/optax/tree_utils/_state_utils_test.py @@ -307,7 +307,7 @@ def test_tree_get_all_with_path(self): self.assertEqual(found_values, expected_result) with self.subTest('Test with optional filtering'): - state = {"hparams": {"learning_rate": 1.0}, "learning_rate": 'foo'} + state = {'hparams': {'learning_rate': 1.0}, 'learning_rate': 'foo'} # Without filtering two values are found found_values = _state_utils.tree_get_all_with_path(state, 'learning_rate') @@ -348,7 +348,7 @@ def test_tree_get_all_with_path(self): self.assertEqual(found_values, expected_result) with self.subTest('Test with nested tree containing a key'): - tree = {"a": {"a": 1.0}} + tree = {'a': {'a': 1.0}} found_values = _state_utils.tree_get_all_with_path(tree, 'a') expected_result = [ ((jtu.DictKey(key='a'),), {'a': 1.0}), @@ -402,7 +402,7 @@ def get_learning_rate(state): self.assertEqual(lr, 1 / (i + 1)) with self.subTest('Test with optional filtering'): - state = {"hparams": {"learning_rate": 1.0}, "learning_rate": 'foo'} + state = {'hparams': {'learning_rate': 1.0}, 'learning_rate': 'foo'} # Without filtering raises an error self.assertRaises(KeyError, _state_utils.tree_get, state, 'learning_rate') @@ -495,7 +495,7 @@ def set_learning_rate(state, lr): self.assertEqual(value, 2.0) with self.subTest('Test with optional filtering'): - state = {"hparams": {"learning_rate": 1.0}, "learning_rate": 'foo'} + state = {'hparams': {'learning_rate': 1.0}, 'learning_rate': 'foo'} new_state = _state_utils.tree_set( state, lambda _, value: isinstance(value, float), @@ -511,23 +511,23 @@ def set_learning_rate(state, lr): self.assertEqual(found_values, expected_result) with self.subTest('Test with nested trees and filtering'): - tree = {"a": {"a": 1.0}, "b": {"a": 1}} + tree = {'a': {'a': 1.0}, 'b': {'a': 1}} new_tree = _state_utils.tree_set( tree, lambda _, value: isinstance(value, float), a=2.0, ) - expected_result = {"a": {"a": 2.0}, "b": {"a": 1}} + expected_result = {'a': {'a': 2.0}, 'b': {'a': 1}} self.assertEqual(new_tree, expected_result) with self.subTest('Test setting a subtree'): - tree = {"a": {"a": 1.0}, "b": {"a": 1}} + tree = {'a': {'a': 1.0}, 'b': {'a': 1}} new_tree = _state_utils.tree_set( tree, lambda _, value: isinstance(value, dict), - a={"c": 0.0}, + a={'c': 0.0}, ) - expected_result = {"a": {"c": 0.0}, "b": {"a": 1}} + expected_result = {'a': {'c': 0.0}, 'b': {'a': 1}} self.assertEqual(new_tree, expected_result) with self.subTest('Test setting a specific state'): diff --git a/optax/tree_utils/_tree_math_test.py b/optax/tree_utils/_tree_math_test.py index b6082ea5c..7c753cf7c 100644 --- a/optax/tree_utils/_tree_math_test.py +++ b/optax/tree_utils/_tree_math_test.py @@ -45,12 +45,12 @@ def setUp(self): self.tree_b_dict_jax = jax.tree.map(jnp.array, self.tree_b_dict) self.data = { - "tree_a": self.tree_a, - "tree_b": self.tree_b, - "tree_a_dict": self.tree_a_dict, - "tree_b_dict": self.tree_b_dict, - "array_a": self.array_a, - "array_b": self.array_b, + 'tree_a': self.tree_a, + 'tree_b': self.tree_b, + 'tree_a_dict': self.tree_a_dict, + 'tree_b_dict': self.tree_b_dict, + 'array_a': self.array_a, + 'array_b': self.array_b, } def test_tree_add(self): @@ -269,20 +269,20 @@ def test_empty_tree_reduce(self): @parameterized.named_parameters( { - "testcase_name": 'tree_add_scalar_mul', - "operation": lambda m: tu.tree_add_scalar_mul(None, 1, m), + 'testcase_name': 'tree_add_scalar_mul', + 'operation': lambda m: tu.tree_add_scalar_mul(None, 1, m), }, { - "testcase_name": 'tree_update_moment', - "operation": lambda m: tu.tree_update_moment(None, m, 1, 1), + 'testcase_name': 'tree_update_moment', + 'operation': lambda m: tu.tree_update_moment(None, m, 1, 1), }, { - "testcase_name": 'tree_update_infinity_moment', - "operation": lambda m: tu.tree_update_infinity_moment(None, m, 1, 1), + 'testcase_name': 'tree_update_infinity_moment', + 'operation': lambda m: tu.tree_update_infinity_moment(None, m, 1, 1), }, { - "testcase_name": 'tree_update_moment_per_elem_norm', - "operation": lambda m: tu.tree_update_moment_per_elem_norm( + 'testcase_name': 'tree_update_moment_per_elem_norm', + 'operation': lambda m: tu.tree_update_moment_per_elem_norm( None, m, 1, 1 ), }, diff --git a/test.sh b/test.sh index 1134ada3a..e6d258535 100755 --- a/test.sh +++ b/test.sh @@ -13,78 +13,79 @@ # limitations under the License. # ============================================================================== -set -o errexit -set -o nounset -set -o pipefail - function cleanup { deactivate + rm -r "${TEMP_DIR}" } trap cleanup EXIT -echo "Deleting test environment (if it exists)" -rm -rf test_venv +REPO_DIR=$(pwd) +TEMP_DIR=$(mktemp --directory) -echo "Creating test environment" -python3 -m venv test_venv +set -o errexit +set -o nounset +set -o pipefail -echo "Activating test environment" -source test_venv/bin/activate +# Install deps in a virtual env. +python3 -m venv "${TEMP_DIR}/test_venv" +source "${TEMP_DIR}/test_venv/bin/activate" -for requirement in "pip" "setuptools" "wheel" "build" ".[test]" ".[examples]" ".[dev]" ".[docs]" -do - echo "Installing" $requirement - python3 -m pip install -qU $requirement -done +# Install dependencies. +python3 -m pip install --quiet --upgrade pip setuptools wheel +python3 -m pip install --quiet --upgrade flake8 pytest-xdist pylint pylint-exit +python3 -m pip install --quiet --editable ".[test, examples]" # Dp-accounting specifies exact minor versions as requirements which sometimes # become incompatible with other libraries optax needs. We therefore install # dependencies for dp-accounting manually. # TODO(b/239416992): Remove this workaround if dp-accounting switches to minimum # version requirements. -echo "Installing .[dp-accounting]" -python3 -m pip install -qU --editable ".[dp-accounting]" -python3 -m pip install -qU --no-deps "dp-accounting>=0.1.1" +python3 -m pip install --quiet --editable ".[dp-accounting]" +python3 -m pip install --quiet --no-deps "dp-accounting>=0.1.1" -echo "Installing requested JAX version" # Install the requested JAX version if [ -z "${JAX_VERSION-}" ]; then : # use version installed in requirements above elif [ "$JAX_VERSION" = "newest" ]; then - python3 -m pip install -qU jax jaxlib + python3 -m pip install --quiet --upgrade jax jaxlib else - python3 -m pip install -qU "jax==${JAX_VERSION}" "jaxlib==${JAX_VERSION}" + python3 -m pip install --quiet "jax==${JAX_VERSION}" "jaxlib==${JAX_VERSION}" fi # Ensure optax was not installed by one of the dependencies above, # since if it is, the tests below will be run against that version instead of # the branch build. -echo "Uninstalling optax (if already installed)" -python3 -m pip uninstall -q --yes optax +python3 -m pip uninstall --quiet --yes optax -echo "Linting with flake8" -flake8 +# Lint with flake8. +python3 -m flake8 --select=E9,F63,F7,F82,E225,E251 --show-source --statistics -echo "Linting with pylint" -pylint . --rcfile pyproject.toml +# Lint with pylint. +PYLINT_ARGS="-efail -wfail -cfail -rfail" +# Append specific config lines. +# Lint modules and tests separately. +python3 -m pylint --rcfile=.pylintrc $(find optax -name '*.py' | grep -v 'test.py' | xargs) -d E1102 || pylint-exit $PYLINT_ARGS $? +# Disable protected-access warnings for tests. +python3 -m pylint --rcfile=.pylintrc $(find optax -name '*_test.py' | xargs) -d W0212,E1102 || pylint-exit $PYLINT_ARGS $? -echo "Building the package" +# Build the package. +python3 -m pip install --quiet build python3 -m build +python3 -m pip wheel --no-deps dist/optax-*.tar.gz --wheel-dir "${TEMP_DIR}" +python3 -m pip install --quiet "${TEMP_DIR}/optax-"*.whl -echo "Building wheel" -python3 -m pip wheel --no-deps dist/optax-*.tar.gz - -echo "Installing the wheel" -python3 -m pip install -qU optax-*.whl - -echo "Checking types with pytype" -pytype +# Check types with pytype. +python3 -m pip install --quiet pytype +pytype "optax" --keep-going --disable import-error -echo "Running tests with pytest" -pytest +# Run tests using pytest. +# Change directory to avoid importing the package from repo root. +cd "${TEMP_DIR}" +python3 -m pytest --numprocesses auto --pyargs optax +cd "${REPO_DIR}" -echo "Building sphinx docs" # Build Sphinx docs. +python3 -m pip install --quiet --editable ".[docs]" # NOTE(vroulet) We have dependencies issues: # tensorflow > 2.13.1 requires ml-dtypes <= 0.3.2 # but jax requires ml-dtypes >= 0.4.0 @@ -94,10 +95,9 @@ echo "Building sphinx docs" # bug (which issues conflict warnings but runs fine). # A long term solution is probably to fully remove tensorflow from our # dependencies. +python3 -m pip install --upgrade --verbose typing_extensions cd docs -echo "make html" make html -echo "make doctest" make doctest # run doctests cd ..