From 5f3fdab246eb306920c70e2fd63843cd494df10f Mon Sep 17 00:00:00 2001 From: Janos Gabler Date: Fri, 16 Dec 2022 11:31:46 +0100 Subject: [PATCH] Fix for new jax version. (#22) --- .pre-commit-config.yaml | 24 ++++++------ setup.cfg | 4 -- src/pybaum/equality.py | 8 +--- src/pybaum/registry_entries.py | 6 +-- src/pybaum/typecheck.py | 72 +++++++++++++++++++++++++++++++--- tests/test_typecheck.py | 4 +- 6 files changed, 84 insertions(+), 34 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 258fe23..c215738 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,17 +1,17 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v4.4.0 hooks: - id: check-merge-conflict - id: debug-statements - id: end-of-file-fixer - repo: https://github.com/asottile/reorder_python_imports - rev: v3.1.0 + rev: v3.9.0 hooks: - id: reorder-python-imports types: [python] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v4.4.0 hooks: - id: check-added-large-files args: ['--maxkb=100'] @@ -45,12 +45,12 @@ repos: additional_dependencies: [black==22.3.0] types: [rst] - repo: https://github.com/psf/black - rev: 22.3.0 + rev: 22.12.0 hooks: - id: black - language_version: python3.9 + language_version: python3.10 - repo: https://github.com/PyCQA/flake8 - rev: 4.0.1 + rev: 5.0.4 hooks: - id: flake8 types: [python] @@ -71,7 +71,7 @@ repos: Pygments, ] - repo: https://github.com/PyCQA/doc8 - rev: 0.11.2 + rev: v1.0.0 hooks: - id: doc8 - repo: meta @@ -80,15 +80,15 @@ repos: - id: check-useless-excludes # - id: identity # Prints all files passed to pre-commits. Debugging. - repo: https://github.com/mgedmin/check-manifest - rev: "0.48" + rev: "0.49" hooks: - id: check-manifest - repo: https://github.com/PyCQA/doc8 - rev: 0.11.2 + rev: v1.0.0 hooks: - id: doc8 - repo: https://github.com/asottile/setup-cfg-fmt - rev: v1.20.1 + rev: v2.2.0 hooks: - id: setup-cfg-fmt - repo: https://github.com/econchick/interrogate @@ -98,11 +98,11 @@ repos: args: [-v, --fail-under=20] exclude: ^(tests|docs|setup\.py) - repo: https://github.com/codespell-project/codespell - rev: v2.1.0 + rev: v2.2.2 hooks: - id: codespell - repo: https://github.com/asottile/pyupgrade - rev: v2.34.0 + rev: v3.3.1 hooks: - id: pyupgrade args: [--py37-plus] diff --git a/setup.cfg b/setup.cfg index 912498f..6edd072 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,10 +17,6 @@ classifiers = Operating System :: POSIX Programming Language :: Python :: 3 Programming Language :: Python :: 3 :: Only - Programming Language :: Python :: 3.7 - Programming Language :: Python :: 3.8 - Programming Language :: Python :: 3.9 - Programming Language :: Python :: 3.10 Topic :: Scientific/Engineering Topic :: Utilities diff --git a/src/pybaum/equality.py b/src/pybaum/equality.py index dd6b146..9c36b3c 100644 --- a/src/pybaum/equality.py +++ b/src/pybaum/equality.py @@ -11,10 +11,6 @@ if IS_PANDAS_INSTALLED: import pandas as pd -if IS_JAX_INSTALLED: - import jaxlib - - EQUALITY_CHECKERS = {} @@ -28,6 +24,4 @@ if IS_JAX_INSTALLED: - EQUALITY_CHECKERS[jaxlib.xla_extension.DeviceArray] = lambda a, b: bool( - (a == b).all() - ) + EQUALITY_CHECKERS["jax.numpy.ndarray"] = lambda a, b: bool((a == b).all()) diff --git a/src/pybaum/registry_entries.py b/src/pybaum/registry_entries.py index 3fa0bd2..71f6bd4 100644 --- a/src/pybaum/registry_entries.py +++ b/src/pybaum/registry_entries.py @@ -1,5 +1,4 @@ import itertools -from collections import namedtuple from collections import OrderedDict from itertools import product @@ -15,7 +14,6 @@ if IS_JAX_INSTALLED: import jax - import jaxlib def _none(): @@ -69,7 +67,7 @@ def _tuple(): def _namedtuple(): """Create registry entry for namedtuple and NamedTuple.""" entry = { - namedtuple: { + "namedtuple": { "flatten": lambda tree: (list(tree), tree), "unflatten": _unflatten_namedtuple, "names": lambda tree: list(tree._fields), @@ -125,7 +123,7 @@ def _array_element_names(arr): def _jax_array(): if IS_JAX_INSTALLED: entry = { - jaxlib.xla_extension.DeviceArray: { + "jax.numpy.ndarray": { "flatten": lambda arr: (arr.flatten().tolist(), arr.shape), "unflatten": lambda aux_data, leaves: jax.numpy.array(leaves).reshape( aux_data diff --git a/src/pybaum/typecheck.py b/src/pybaum/typecheck.py index b02d86a..74f69d6 100644 --- a/src/pybaum/typecheck.py +++ b/src/pybaum/typecheck.py @@ -1,8 +1,37 @@ -from collections import namedtuple +from pybaum.config import IS_JAX_INSTALLED +from pybaum.config import IS_NUMPY_INSTALLED + +if IS_JAX_INSTALLED: + import jax.numpy as jnp + +if IS_NUMPY_INSTALLED: + import numpy as np def get_type(obj): - """namdetuple aware type check. + """Get type of candidate objects in a pytree. + + This function allows us to reliably identify namedtuples, NamedTuples and jax arrays + for which standard ``type`` function does not work. + + Args: + obj: The object to be checked + + Returns: + type or str: The type of the object or a string with the type name. + + """ + if _is_namedtuple(obj): + out = "namedtuple" + elif _is_jax_array(obj): + out = "jax.numpy.ndarray" + else: + out = type(obj) + return out + + +def _is_namedtuple(obj): + """Check if an object is a namedtuple. As in JAX we treat collections.namedtuple and typing.NamedTuple both as namedtuple but the exact type is preserved in the unflatten function. @@ -24,8 +53,41 @@ def get_type(obj): bool """ - if isinstance(obj, tuple) and hasattr(obj, "_fields") and hasattr(obj, "_replace"): - out = namedtuple + out = ( + isinstance(obj, tuple) and hasattr(obj, "_fields") and hasattr(obj, "_replace") + ) + return out + + +def _is_jax_array(obj): + """Check if an object is a jax array. + + The exact type of jax arrays has changed over time and is an implementation detail. + + Instead we rely on isinstance checks which will likely be more stable in the future. + However, the behavior of isinstance for jax arrays has also changed over time. For + jax versions before 0.2.21, standard numpy arrays were instances of jax arrays, + now they are not. + + Resources: + ---------- + + - https://github.com/google/jax/issues/2115 + - https://github.com/google/jax/issues/2014 + - https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0221-sept-23-2021 + - https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0318-sep-26-2022 + + Args: + obj: The object to be checked + + Returns: + bool + + """ + if not IS_JAX_INSTALLED: + out = False + elif IS_NUMPY_INSTALLED: + out = isinstance(obj, jnp.ndarray) and not isinstance(obj, np.ndarray) else: - out = type(obj) + out = isinstance(obj, jnp.ndarray) return out diff --git a/tests/test_typecheck.py b/tests/test_typecheck.py index 54445e4..1cbe0ad 100644 --- a/tests/test_typecheck.py +++ b/tests/test_typecheck.py @@ -6,7 +6,7 @@ def test_namedtuple_is_discovered(): bla = namedtuple("bla", ["a", "b"])(1, 2) - assert get_type(bla) == namedtuple + assert get_type(bla) == "namedtuple" def test_typed_namedtuple_is_discovered(): @@ -15,7 +15,7 @@ class Blubb(NamedTuple): b: int blubb = Blubb(1, 2) - assert get_type(blubb) == namedtuple + assert get_type(blubb) == "namedtuple" def test_standard_tuple_is_not_discovered():