Skip to content

Commit

Permalink
Add JAX array to registry (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
janosg authored Jun 15, 2022
1 parent 6ff956b commit aa0e218
Show file tree
Hide file tree
Showing 11 changed files with 149 additions and 27 deletions.
28 changes: 27 additions & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
strategy:
fail-fast: false
matrix:
os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
os: ['ubuntu-latest', 'macos-latest']
python-version: ['3.7', '3.8', '3.9', '3.10']

steps:
Expand All @@ -46,6 +46,32 @@ jobs:
with:
token: ${{ secrets.CODECOV_TOKEN }}

run-tests-windows:

name: Run tests for ${{ matrix.os }} on ${{ matrix.python-version }}
runs-on: ${{ matrix.os }}

strategy:
fail-fast: false
matrix:
os: ['windows-latest']
python-version: ['3.7', '3.8', '3.9', '3.10']

steps:
- uses: actions/checkout@v2
- uses: conda-incubator/setup-miniconda@v2
with:
auto-update-conda: true
python-version: ${{ matrix.python-version }}

- name: Install core dependencies.
shell: bash -l {0}
run: conda install -c conda-forge tox-conda

- name: Run pytest.
shell: bash -l {0}
run: tox -e pytest-windows -- -m "not slow"

docs:

name: Run documentation.
Expand Down
14 changes: 7 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.2.0
rev: v4.3.0
hooks:
- id: check-merge-conflict
- id: debug-statements
Expand All @@ -11,7 +11,7 @@ repos:
- id: reorder-python-imports
types: [python]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.2.0
rev: v4.3.0
hooks:
- id: check-added-large-files
args: ['--maxkb=100']
Expand Down Expand Up @@ -42,13 +42,13 @@ repos:
rev: v1.12.1
hooks:
- id: blacken-docs
additional_dependencies: [black]
additional_dependencies: [black==22.3.0]
types: [rst]
- repo: https://github.com/psf/black
rev: 22.3.0
hooks:
- id: black
types: [python]
language_version: python3.9
- repo: https://github.com/PyCQA/flake8
rev: 4.0.1
hooks:
Expand All @@ -71,7 +71,7 @@ repos:
Pygments,
]
- repo: https://github.com/PyCQA/doc8
rev: 0.11.1
rev: 0.11.2
hooks:
- id: doc8
- repo: meta
Expand All @@ -84,7 +84,7 @@ repos:
hooks:
- id: check-manifest
- repo: https://github.com/PyCQA/doc8
rev: 0.11.1
rev: 0.11.2
hooks:
- id: doc8
- repo: https://github.com/asottile/setup-cfg-fmt
Expand All @@ -102,7 +102,7 @@ repos:
hooks:
- id: codespell
- repo: https://github.com/asottile/pyupgrade
rev: v2.32.1
rev: v2.34.0
hooks:
- id: pyupgrade
args: [--py37-plus]
2 changes: 2 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,5 @@ dependencies:
- pdbpp
- numpy
- pandas
- jax
- jaxlib
9 changes: 9 additions & 0 deletions src/pybaum/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,12 @@
IS_PANDAS_INSTALLED = False
else:
IS_PANDAS_INSTALLED = True


try:
import jax # noqa: F401
import jaxlib # noqa: F401
except ImportError:
IS_JAX_INSTALLED = False
else:
IS_JAX_INSTALLED = True
10 changes: 10 additions & 0 deletions src/pybaum/equality.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Functions to check equality of pytree leaves."""
from pybaum.config import IS_JAX_INSTALLED
from pybaum.config import IS_NUMPY_INSTALLED
from pybaum.config import IS_PANDAS_INSTALLED

Expand All @@ -10,6 +11,9 @@
if IS_PANDAS_INSTALLED:
import pandas as pd

if IS_JAX_INSTALLED:
import jaxlib


EQUALITY_CHECKERS = {}

Expand All @@ -21,3 +25,9 @@
if IS_PANDAS_INSTALLED:
EQUALITY_CHECKERS[pd.Series] = lambda a, b: a.equals(b)
EQUALITY_CHECKERS[pd.DataFrame] = lambda a, b: a.equals(b)


if IS_JAX_INSTALLED:
EQUALITY_CHECKERS[jaxlib.xla_extension.DeviceArray] = lambda a, b: bool(
(a == b).all()
)
1 change: 1 addition & 0 deletions src/pybaum/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def get_registry(types=None, include_defaults=True):
- :obj:`None`
- :class:`collections.OrderedDict`
- "numpy.ndarray"
- "jax.numpy.ndarray"
- "pandas.Series"
- "pandas.DataFrame"
include_defaults (bool): Whether the default pytree containers "tuple", "dict"
Expand Down
22 changes: 22 additions & 0 deletions src/pybaum/registry_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections import OrderedDict
from itertools import product

from pybaum.config import IS_JAX_INSTALLED
from pybaum.config import IS_NUMPY_INSTALLED
from pybaum.config import IS_PANDAS_INSTALLED

Expand All @@ -12,6 +13,10 @@
if IS_PANDAS_INSTALLED:
import pandas as pd

if IS_JAX_INSTALLED:
import jax
import jaxlib


def _none():
"""Create registry entry for NoneType."""
Expand Down Expand Up @@ -117,6 +122,22 @@ def _array_element_names(arr):
return names


def _jax_array():
if IS_JAX_INSTALLED:
entry = {
jaxlib.xla_extension.DeviceArray: {
"flatten": lambda arr: (arr.flatten().tolist(), arr.shape),
"unflatten": lambda aux_data, leaves: jax.numpy.array(leaves).reshape(
aux_data
),
"names": _array_element_names,
},
}
else:
entry = {}
return entry


def _pandas_series():
"""Create registry entry for pandas.Series."""
if IS_PANDAS_INSTALLED:
Expand Down Expand Up @@ -186,6 +207,7 @@ def _index_element_to_string(element):
"tuple": _tuple,
"dict": _dict,
"numpy.ndarray": _numpy_array,
"jax.numpy.ndarray": _jax_array,
"pandas.Series": _pandas_series,
"pandas.DataFrame": _pandas_dataframe,
"None": _none,
Expand Down
10 changes: 3 additions & 7 deletions src/pybaum/tree_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
- The treedef containing information to unflatten pytrees is implemented differently.
"""
import itertools

from pybaum.equality import EQUALITY_CHECKERS
from pybaum.registry import get_registry
from pybaum.typecheck import get_type
Expand Down Expand Up @@ -42,8 +40,8 @@ def tree_flatten(tree, is_leaf=None, registry=None):
is_leaf = _process_is_leaf(is_leaf)

flat = _tree_flatten(tree, is_leaf=is_leaf, registry=registry)
dummy_flat = ["*"] * len(flat)
treedef = tree_unflatten(tree, dummy_flat, is_leaf=is_leaf, registry=registry)
# unflatten the flat tree to make a copy
treedef = tree_unflatten(tree, flat, is_leaf=is_leaf, registry=registry)
return flat, treedef


Expand Down Expand Up @@ -124,9 +122,7 @@ def tree_yield(tree, is_leaf=None, registry=None):
is_leaf = _process_is_leaf(is_leaf)

flat = _tree_yield(tree, is_leaf=is_leaf, registry=registry)
dummy_flat = itertools.repeat("*")
treedef = tree_unflatten(tree, dummy_flat, is_leaf=is_leaf, registry=registry)
return flat, treedef
return flat, tree


def tree_just_yield(tree, is_leaf=None, registry=None):
Expand Down
17 changes: 5 additions & 12 deletions tests/test_tree_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,13 @@ def example_flat():


@pytest.fixture
def example_treedef():
return (["*", "*", {"a": "*", "b": "*"}], "*")
def example_treedef(example_tree):
return example_tree


@pytest.fixture
def extended_treedef():
return (
[
"*",
np.array(["*", "*"]),
{"a": pd.Series(["*", "*"], index=["c", "d"]), "b": "*"},
],
"*",
)
def extended_treedef(example_tree):
return example_tree


@pytest.fixture
Expand Down Expand Up @@ -195,7 +188,7 @@ def test_flatten_df_all_columns():
def test_tree_yield(example_tree, example_treedef, example_flat):
generator, treedef = tree_yield(example_tree)

assert treedef == example_treedef
assert tree_equal(treedef, example_treedef)
assert inspect.isgenerator(generator)
for a, b in zip(generator, example_flat):
if isinstance(a, (np.ndarray, pd.Series)):
Expand Down
45 changes: 45 additions & 0 deletions tests/test_with_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pytest
from pybaum.config import IS_JAX_INSTALLED
from pybaum.registry import get_registry
from pybaum.tree_util import leaf_names
from pybaum.tree_util import tree_equal
from pybaum.tree_util import tree_flatten
from pybaum.tree_util import tree_just_flatten

if IS_JAX_INSTALLED:
import jax.numpy as jnp
else:
# run the tests with normal numpy instead
import numpy as jnp


@pytest.fixture
def tree():
return {"a": {"b": jnp.arange(4).reshape(2, 2)}, "c": jnp.ones(2)}


@pytest.fixture
def flat():
return [0, 1, 2, 3, 1, 1]


@pytest.fixture
def registry():
return get_registry(types=["jax.numpy.ndarray", "numpy.ndarray"])


def test_tree_just_flatten_with_jax(tree, registry, flat):
got = tree_just_flatten(tree, registry=registry)
assert got == flat


def test_tree_flatten_with_jax(tree, registry, flat):
got_flat, got_treedef = tree_flatten(tree, registry=registry)
assert got_flat == flat
assert tree_equal(got_treedef, tree)


def test_leaf_names_with_jax(tree, registry):
got = leaf_names(tree, registry=registry)
expected = ["a_b_0_0", "a_b_0_1", "a_b_1_0", "a_b_1_1", "c_0", "c_1"]
assert got == expected
18 changes: 18 additions & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,24 @@ skip_missing_interpreters = True
basepython = python

[testenv:pytest]
setenv =
CONDA_DLL_SEARCH_MODIFICATION_ENABLE = 1
conda_channels =
conda-forge
defaults
conda_deps =
conda-build
numpy
pandas
pytest
pytest-cov
pytest-mock
pytest-xdist
jax
jaxlib
commands = pytest {posargs}

[testenv:pytest-windows]
setenv =
CONDA_DLL_SEARCH_MODIFICATION_ENABLE = 1
conda_channels =
Expand Down

0 comments on commit aa0e218

Please sign in to comment.