Skip to content

Commit

Permalink
Make flax an optional dependency.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 635381166
  • Loading branch information
tomhennigan authored and copybara-github committed May 20, 2024
1 parent bdaabf9 commit a4304f7
Show file tree
Hide file tree
Showing 10 changed files with 20 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt -r requirements-jax.txt -r requirements-test.txt
pip install -r requirements.txt -r requirements-jax.txt -r requirements-flax.txt -r requirements-test.txt
pip install .
pip install pytest pytest-xdist
- name: Print installed dependencies
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
run: |
sudo apt install -y pandoc
python -m pip install --upgrade pip
pip install -r requirements.txt -r requirements-jax.txt -r docs/requirements.txt
pip install -r requirements.txt -r requirements-jax.txt -r requirements-flax.txt -r docs/requirements.txt
pip install .
- name: Print installed dependencies
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt -r requirements-jax.txt -r requirements-test.txt
pip install -r requirements.txt -r requirements-jax.txt -r requirements-flax.txt -r requirements-test.txt
pip install .
pip install pytest pytest-xdist
- name: Print installed dependencies
Expand Down
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@ include README.md
include LICENSE
include requirements.txt
include requirements-jax.txt
include requirements-flax.txt
include requirements-test.txt
include haiku/py.typed
6 changes: 5 additions & 1 deletion haiku/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,13 @@
from haiku._src.summarise import MethodInvocation
from haiku._src.summarise import ModuleDetails
from haiku._src.summarise import tabulate
from haiku.experimental import flax
from haiku.experimental import jaxpr_info

try:
from haiku.experimental import flax # pylint: disable=g-import-not-at-top
except ImportError:
flax = None

# TODO(tomhennigan): Remove deprecated alias.
ParamContext = GetterContext

Expand Down
1 change: 1 addition & 0 deletions readthedocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ python:
install:
- requirements: requirements.txt
- requirements: requirements-jax.txt
- requirements: requirements-flax.txt
- requirements: docs/requirements.txt

# Additional formats of documentation to be built apart from HTML
Expand Down
1 change: 1 addition & 0 deletions requirements-flax.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
flax>=0.7.1
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,3 @@ absl-py>=0.7.1
jmp>=0.0.2
numpy>=1.18.0
tabulate>=0.8.9
flax>=0.7.1
5 changes: 4 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ def _parse_requirements(requirements_txt_path):
# Contained modules and scripts.
packages=find_namespace_packages(exclude=['*_test.py', 'examples']),
install_requires=_parse_requirements('requirements.txt'),
extras_require={'jax': _parse_requirements('requirements-jax.txt')},
extras_require={
'jax': _parse_requirements('requirements-jax.txt'),
'flax': _parse_requirements('requirements-flax.txt'),
},
tests_require=_parse_requirements('requirements-test.txt'),
requires_python='>=3.9',
include_package_data=True,
Expand Down
6 changes: 5 additions & 1 deletion test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ python --version

# Install dependencies.
python -m pip install --upgrade pip setuptools
pip install -r requirements.txt -r requirements-jax.txt -r requirements-test.txt
pip install \
-r requirements.txt \
-r requirements-jax.txt \
-r requirements-flax.txt \
-r requirements-test.txt
python -c 'import jax; print(jax.__version__)'

# Run setup.py to install Haiku itself.
Expand Down

0 comments on commit a4304f7

Please sign in to comment.