Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes to test.sh and pyproject.toml. #1110

Merged
merged 3 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@ docs/sg_execution_times.rst
.idea
.vscode

# Running tests
test_venv/
optax-*.whl
26 changes: 11 additions & 15 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,27 @@
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.

# pylint: disable=g-bad-import-order
# pylint: disable=g-import-not-at-top
# pylint: disable=invalid-name

import inspect
import os
import sys

from sphinxcontrib import katex
import optax


def _add_annotations_import(path):
"""Appends a future annotations import to the file at the given path."""
with open(path) as f:
with open(path, encoding='utf-8') as f:
contents = f.read()
if contents.startswith('from __future__ import annotations'):
# If we run sphinx multiple times then we will append the future import
# multiple times too.
return

assert contents.startswith('#'), (path, contents.split('\n')[0])
with open(path, 'w') as f:
with open(path, 'w', encoding='utf-8') as f:
# NOTE: This is subtle and not unit tested, we're prefixing the first line
# in each Python file with this future import. It is important to prefix
# not insert a newline such that source code locations are accurate (we link
Expand All @@ -64,9 +67,6 @@ def _recursive_add_annotations_import():
sys.path.insert(0, os.path.abspath('../'))
sys.path.append(os.path.abspath('ext'))

import optax
from sphinxcontrib import katex

# -- Project information -----------------------------------------------------

project = 'Optax'
Expand Down Expand Up @@ -237,8 +237,7 @@ def linkcode_resolve(domain, info):
obj = getattr(obj, attr)
except AttributeError:
return None
else:
obj = inspect.unwrap(obj)
obj = inspect.unwrap(obj)

try:
filename = inspect.getsourcefile(obj)
Expand All @@ -251,13 +250,10 @@ def linkcode_resolve(domain, info):
return None

# TODO(slebedev): support tags after we release an initial version.
path = os.path.relpath(filename, start=os.path.dirname(optax.__file__))
return (
'https://github.com/google-deepmind/optax/tree/main/optax/%s#L%d#L%d'
% (
os.path.relpath(filename, start=os.path.dirname(optax.__file__)),
lineno,
lineno + len(source) - 1,
)
'https://github.com/google-deepmind/optax/tree/main/optax/'
f'{path}#L{lineno}#L{lineno + len(source) - 1}'
)


Expand Down
12 changes: 9 additions & 3 deletions docs/ext/coverage_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
import types
from typing import Any, Sequence, Tuple

import optax
from sphinx import application
from sphinx import builders
from sphinx import errors
import optax


def find_internal_python_modules(
Expand Down Expand Up @@ -65,7 +65,7 @@ class OptaxCoverageCheck(builders.Builder):
def get_outdated_docs(self) -> str:
return "coverage_check"

def write(self, *ignored: Any) -> None:
def write(self, *ignored: Any) -> None: # pylint: disable=overridden-final-method
pass

def finish(self) -> None:
Expand All @@ -78,7 +78,13 @@ def finish(self) -> None:
"forget to add an entry to `api.rst`?\n"
f"Undocumented symbols: {undocumented_objects}")

def get_target_uri(self, docname, typ=None):
raise NotImplementedError

def write_doc(self, docname, doctree):
raise NotImplementedError


def setup(app: application.Sphinx) -> Mapping[str, Any]:
app.add_builder(OptaxCoverageCheck)
return dict(version=optax.__version__, parallel_read_safe=True)
return {"version": optax.__version__, "parallel_read_safe": True}
23 changes: 9 additions & 14 deletions optax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
# ==============================================================================
"""Optax: composable gradient processing and optimization, in JAX."""

# pylint: disable=wrong-import-position
# pylint: disable=g-importing-member
import typing as _typing

from optax import assignment
from optax import contrib
Expand Down Expand Up @@ -185,6 +184,14 @@
from optax._src.wrappers import skip_large_updates
from optax._src.wrappers import skip_not_finite

# TODO(mtthss): remove contrib aliases from flat namespace once users updated.
# Deprecated modules
from optax.contrib import differentially_private_aggregate as \
_deprecated_differentially_private_aggregate
from optax.contrib import DifferentiallyPrivateAggregateState as \
_deprecated_DifferentiallyPrivateAggregateState
from optax.contrib import dpsgd as _deprecated_dpsgd


# TODO(mtthss): remove tree_utils aliases after updates.
tree_map_params = tree_utils.tree_map_params
Expand Down Expand Up @@ -236,13 +243,6 @@
squared_error = losses.squared_error
sigmoid_focal_loss = losses.sigmoid_focal_loss

# pylint: disable=g-import-not-at-top
# TODO(mtthss): remove contrib aliases from flat namespace once users updated.
# Deprecated modules
from optax.contrib import differentially_private_aggregate as _deprecated_differentially_private_aggregate
from optax.contrib import DifferentiallyPrivateAggregateState as _deprecated_DifferentiallyPrivateAggregateState
from optax.contrib import dpsgd as _deprecated_dpsgd

_deprecations = {
# Added Apr 2024
"differentially_private_aggregate": (
Expand All @@ -269,8 +269,6 @@
_deprecated_dpsgd,
),
}
# pylint: disable=g-bad-import-order
import typing as _typing

if _typing.TYPE_CHECKING:
# pylint: disable=reimported
Expand All @@ -285,9 +283,6 @@
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del _typing
# pylint: enable=g-bad-import-order
# pylint: enable=g-import-not-at-top
# pylint: enable=g-importing-member


__version__ = "0.2.4.dev"
Expand Down
Loading
Loading