diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c67649c..63104b6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,13 +10,18 @@ on: jobs: build-and-test: - name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}" + name: "Python ${{ matrix.python-version }} on ${{ matrix.os }} jax=${{ matrix.jax-version}}" runs-on: "${{ matrix.os }}" strategy: matrix: python-version: ["3.9", "3.10", "3.11"] os: [ubuntu-latest] + jax-version: ["newest"] + include: + - python-version: "3.9" + os: "ubuntu-latest" + jax-version: "0.4.27" # Keep this in sync with version in pyproject.toml steps: - uses: "actions/checkout@v2" @@ -26,5 +31,5 @@ jobs: cache: "pip" cache-dependency-path: '**/requirements*.txt' - name: Run CI tests - run: bash test.sh + run: JAX_VERSION="${{ matrix.jax-version }}" bash test.sh shell: bash diff --git a/requirements/requirements.txt b/requirements/requirements.txt index cc86154..c46e256 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,7 +1,7 @@ absl-py>=0.9.0 typing_extensions>=4.2.0 -jax>=0.4.16 -jaxlib>=0.1.37 +jax>=0.4.27 +jaxlib>=0.4.27 numpy>=1.24.1 setuptools;python_version>="3.12" toolz>=0.9.0 diff --git a/test.sh b/test.sh index 9761ae6..dd701a5 100755 --- a/test.sh +++ b/test.sh @@ -32,6 +32,15 @@ pip install flake8 pytest-xdist pylint pylint-exit pip install -r requirements/requirements.txt pip install -r requirements/requirements-test.txt +# Install the requested JAX version +if [ "$JAX_VERSION" = "" ]; then + : # use version installed in requirements above +elif [ "$JAX_VERSION" = "newest" ]; then + pip install -U jax jaxlib +else + pip install "jax==${JAX_VERSION}" "jaxlib==${JAX_VERSION}" +fi + # Lint with flake8. flake8 `find chex -name '*.py' | xargs` --count --select=E9,F63,F7,F82,E225,E251 --show-source --statistics