Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split github workflows for lower latency, add ruff #1156

Merged
merged 1 commit into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 62 additions & 3 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,71 @@ on:
- cron: '0 3 * * *'

jobs:
linting:
name: "Lint check with flake8 and pylint"
runs-on: "ubuntu-latest"
steps:
- uses: "actions/checkout@v2"
- uses: "actions/setup-python@v4"
with:
python-version: "3.11"
cache: "pip"
cache-dependency-path: "pyproject.toml"
- name: Install linting dependencies
run: |
pip install -U pip setuptools wheel
pip install -U flake8 pytest-xdist pylint pylint-exit
- name: Lint with flake8
run: |
python3 -m flake8 --select=E9,F63,F7,F82,E225,E251 --show-source --statistics
- name: Lint module files with pylint
run: |
PYLINT_ARGS="-efail -wfail -cfail -rfail"
python3 -m pylint --rcfile=.pylintrc $(find optax -name '*.py' | grep -v 'test.py' | xargs) -d E1102 || pylint-exit $PYLINT_ARGS $?
- name: Lint test files with pylint
run: |
PYLINT_ARGS="-efail -wfail -cfail -rfail"
python3 -m pylint --rcfile=.pylintrc $(find optax -name '*_test.py' | xargs) -d W0212,E1102 || pylint-exit $PYLINT_ARGS $?
ruff-lint:
name: "Lint check with ruff"
runs-on: "ubuntu-latest"
steps:
- uses: "actions/checkout@v2"
- uses: "actions/setup-python@v4"
with:
python-version: "3.11"
cache: "pip"
cache-dependency-path: "pyproject.toml"
- name: Install ruff and lint check
run: |
pip install -U ruff
ruff check .
doctests:
needs: [linting, ruff-lint] # do not run doctests if linting fails
name: "Doctests on ${{ matrix.os }} with Python ${{ matrix.python-version }}"
runs-on: "${{ matrix.os }}"
strategy:
matrix:
python-version: ["3.11"] # only build docs with a somewhat latest python
os: [ubuntu-latest]
steps:
- uses: "actions/checkout@v2"
- uses: "actions/setup-python@v4"
with:
python-version: "${{ matrix.python-version }}"
cache: "pip"
cache-dependency-path: 'pyproject.toml'
- name: Build docs and run doctests
run: |
python3 -m pip install --quiet --editable ".[docs]"
cd docs
make html
make doctest # run doctests
shell: bash
build-and-test:
needs: [linting, ruff-lint] # do not run tests if linting fails
name: "Python ${{ matrix.python-version }} on ${{ matrix.os }} jax=${{ matrix.jax-version }}"
runs-on: "${{ matrix.os }}"

strategy:
matrix:
python-version: ["3.9", "3.10", "3.11"]
Expand All @@ -22,7 +83,6 @@ jobs:
- python-version: "3.9"
os: "ubuntu-latest"
jax-version: "0.4.27" # Keep version in sync with pyproject.toml and copy.bara.sky!

steps:
- uses: "actions/checkout@v2"
- uses: "actions/setup-python@v4"
Expand All @@ -39,7 +99,6 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v4

- name: Check links
uses: gaurav-nelson/github-action-markdown-link-check@v1
with:
Expand Down
4 changes: 2 additions & 2 deletions examples/contrib/reduce_on_plateau.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@
"source": [
"opt = optax.chain(\n",
" optax.adam(LEARNING_RATE),\n",
" reduce_on_plateau(\n",
" contrib.reduce_on_plateau(\n",
" patience=PATIENCE,\n",
" cooldown=COOLDOWN,\n",
" factor=FACTOR,\n",
Expand Down Expand Up @@ -759,7 +759,7 @@
}
],
"source": [
"transform = reduce_on_plateau(\n",
"transform = contrib.reduce_on_plateau(\n",
" patience=PATIENCE,\n",
" cooldown=COOLDOWN,\n",
" factor=FACTOR,\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/linear_assignment_problem.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
"outputs": [],
"source": [
"import networkx as nx\n",
"from jax import numpy as jnp, random\n",
"from jax import random\n",
"import optax\n",
"from matplotlib import pyplot as plt"
]
Expand Down
2 changes: 1 addition & 1 deletion examples/nanolm.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@
}
],
"source": [
"plt.title(f\"Convergence of adamw (train loss)\")\n",
"plt.title(\"Convergence of adamw (train loss)\")\n",
"plt.plot(all_train_losses, label=\"train\", lw=3)\n",
"plt.plot(\n",
" jnp.arange(0, len(all_eval_losses) * N_FREQ_EVAL, N_FREQ_EVAL),\n",
Expand Down
2 changes: 1 addition & 1 deletion optax/schedules/_inject.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

def _convert_floats(x, dtype):
"""Convert float-like inputs to dtype, rest pass through."""
if jax.dtypes.scalar_type_of(x) == float:
if jax.dtypes.scalar_type_of(x) is float:
return jnp.asarray(x, dtype=dtype)
return x

Expand Down
14 changes: 14 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,17 @@ dp-accounting = [
[tool.setuptools.packages.find]
include = ["README.md", "LICENSE"]
exclude = ["*_test.py"]

[tool.ruff.lint]
select = [
"F",
"E",
]
ignore = [
"E731", # lambdas are allowed
"E501", # don't check line lengths
"F401", # allow unused imports
"E402", # allow modules not at top of file
"E741", # allow "l" as a variable name
"E703", # allow semicolons (for jupyter notebooks)
]
3 changes: 3 additions & 0 deletions test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -101,4 +101,7 @@ make html
make doctest # run doctests
cd ..

pip install -U ruff
ruff check .

echo "All tests passed. Congrats!"
Loading