diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2ace24ede..a37bf0047 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -26,7 +26,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements.txt -r requirements-jax.txt -r requirements-test.txt + pip install -r requirements.txt -r requirements-jax.txt -r requirements-flax.txt -r requirements-test.txt pip install . pip install pytest pytest-xdist - name: Print installed dependencies diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index fb4d8e8fc..02a64ec33 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -27,7 +27,7 @@ jobs: run: | sudo apt install -y pandoc python -m pip install --upgrade pip - pip install -r requirements.txt -r requirements-jax.txt -r docs/requirements.txt + pip install -r requirements.txt -r requirements-jax.txt -r requirements-flax.txt -r docs/requirements.txt pip install . - name: Print installed dependencies run: | diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index cf2fbb342..209ab4905 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -26,7 +26,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements.txt -r requirements-jax.txt -r requirements-test.txt + pip install -r requirements.txt -r requirements-jax.txt -r requirements-flax.txt -r requirements-test.txt pip install . pip install pytest pytest-xdist - name: Print installed dependencies diff --git a/MANIFEST.in b/MANIFEST.in index 4cc82c5dd..809c60dc2 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -2,5 +2,6 @@ include README.md include LICENSE include requirements.txt include requirements-jax.txt +include requirements-flax.txt include requirements-test.txt include haiku/py.typed diff --git a/haiku/experimental/__init__.py b/haiku/experimental/__init__.py index 0e85e0809..de99adc68 100644 --- a/haiku/experimental/__init__.py +++ b/haiku/experimental/__init__.py @@ -52,9 +52,13 @@ from haiku._src.summarise import MethodInvocation from haiku._src.summarise import ModuleDetails from haiku._src.summarise import tabulate -from haiku.experimental import flax from haiku.experimental import jaxpr_info +try: + from haiku.experimental import flax # pylint: disable=g-import-not-at-top +except ImportError: + flax = None + # TODO(tomhennigan): Remove deprecated alias. ParamContext = GetterContext diff --git a/readthedocs.yml b/readthedocs.yml index 7673a0972..4a18649c4 100644 --- a/readthedocs.yml +++ b/readthedocs.yml @@ -17,6 +17,7 @@ python: install: - requirements: requirements.txt - requirements: requirements-jax.txt + - requirements: requirements-flax.txt - requirements: docs/requirements.txt # Additional formats of documentation to be built apart from HTML diff --git a/requirements-flax.txt b/requirements-flax.txt new file mode 100644 index 000000000..1d9051262 --- /dev/null +++ b/requirements-flax.txt @@ -0,0 +1 @@ +flax>=0.7.1 diff --git a/requirements.txt b/requirements.txt index 62588eb19..ed929b68f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,3 @@ absl-py>=0.7.1 jmp>=0.0.2 numpy>=1.18.0 tabulate>=0.8.9 -flax>=0.7.1 diff --git a/setup.py b/setup.py index 25b39b183..f98cb65e0 100644 --- a/setup.py +++ b/setup.py @@ -48,7 +48,10 @@ def _parse_requirements(requirements_txt_path): # Contained modules and scripts. packages=find_namespace_packages(exclude=['*_test.py', 'examples']), install_requires=_parse_requirements('requirements.txt'), - extras_require={'jax': _parse_requirements('requirements-jax.txt')}, + extras_require={ + 'jax': _parse_requirements('requirements-jax.txt'), + 'flax': _parse_requirements('requirements-flax.txt'), + }, tests_require=_parse_requirements('requirements-test.txt'), requires_python='>=3.9', include_package_data=True, diff --git a/test.sh b/test.sh index 7af30f4f7..ed0cfeb08 100755 --- a/test.sh +++ b/test.sh @@ -33,7 +33,11 @@ python --version # Install dependencies. python -m pip install --upgrade pip setuptools -pip install -r requirements.txt -r requirements-jax.txt -r requirements-test.txt +pip install \ + -r requirements.txt \ + -r requirements-jax.txt \ + -r requirements-flax.txt \ + -r requirements-test.txt python -c 'import jax; print(jax.__version__)' # Run setup.py to install Haiku itself.