Skip to content

Commit

Permalink
Hide modules and packages in _src
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Apr 12, 2021
1 parent 0bfcf9d commit 488e475
Show file tree
Hide file tree
Showing 36 changed files with 213 additions and 148 deletions.
26 changes: 13 additions & 13 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ Major components
Tjax's major components are:

- A dataclass decorator :python:`dataclasss` that facilitates defining JAX trees, and has a MyPy plugin.
(See `dataclass <https://github.com/NeilGirdhar/tjax/blob/master/tjax/dataclass.py>`_ and `mypy_plugin <https://github.com/NeilGirdhar/tjax/blob/master/tjax/mypy_plugin.py>`_.)
(See `dataclass <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/dataclasses>`_ and `mypy_plugin <https://github.com/NeilGirdhar/tjax/blob/master/tjax/mypy_plugin.py>`_.)

- A fixed point finding library heavily based on `fax <https://github.com/gehring/fax>`_. Our
library (`fixed_point <https://github.com/NeilGirdhar/tjax/blob/master/tjax/fixed_point>`_):
library (`fixed_point <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/fixed_point>`_):

- supports stochastic iterated functions,
- uses dataclasses instead of closures to avoid leaking JAX tracers, and
Expand All @@ -34,33 +34,33 @@ Minor components
Tjax also includes:

- An object-oriented wrapper on top of `optax <https://github.com/deepmind/optax>`_. (See
`gradient <https://github.com/NeilGirdhar/tjax/blob/master/tjax/gradient>`_.)
`gradient <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/gradient>`_.)

- A pretty printer :python:`print_generic` for aggregate and vector types, including dataclasses. (See
`display <https://github.com/NeilGirdhar/tjax/blob/master/tjax/display.py>`_.)
`display <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/display.py>`_.)

- Versions of :python:`custom_vjp` and :python:`custom_jvp` that support being used on methods.
(See `shims <https://github.com/NeilGirdhar/tjax/blob/master/tjax/shims.py>`_.)
(See `shims <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/shims.py>`_.)

- Tools for working with cotangents. (See
`cotangent_tools <https://github.com/NeilGirdhar/tjax/blob/master/tjax/cotangent_tools.py>`_.)
`cotangent_tools <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/cotangent_tools.py>`_.)

- A random number generator class :python:`Generator`. (See `generator <https://github.com/NeilGirdhar/tjax/blob/master/tjax/generator.py>`_.)
- A random number generator class :python:`Generator`. (See `generator <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/generator.py>`_.)

- JAX tree registration for `NetworkX <https://networkx.github.io/>`_ graph types. (See
`graph <https://github.com/NeilGirdhar/tjax/blob/master/tjax/graph.py>`_.)
`graph <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/graph.py>`_.)

- Leaky integration :python:`leaky_integrate` and Ornstein-Uhlenbeck process iteration
:python:`diffused_leaky_integrate`. (See `leaky_integral <https://github.com/NeilGirdhar/tjax/blob/master/tjax/leaky_integral.py>`_.)
:python:`diffused_leaky_integrate`. (See `leaky_integral <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/leaky_integral.py>`_.)

- An improved version of :python:`jax.tree_util.Partial`. (See `partial <https://github.com/NeilGirdhar/tjax/blob/master/tjax/partial.py>`_.)
- An improved version of :python:`jax.tree_util.Partial`. (See `partial <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/partial.py>`_.)

- A Matplotlib trajectory plotter :python:`PlottableTrajectory`. (See `plottable_trajectory <https://github.com/NeilGirdhar/tjax/blob/master/tjax/plottable_trajectory.py>`_.)
- A Matplotlib trajectory plotter :python:`PlottableTrajectory`. (See `plottable_trajectory <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/plottable_trajectory.py>`_.)

- A testing function :python:`assert_jax_allclose` that automatically produces testing code. And, a related
function :python:`jax_allclose`. (See `testing <https://github.com/NeilGirdhar/tjax/blob/master/tjax/testing.py>`_.)
function :python:`jax_allclose`. (See `testing <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/testing.py>`_.)

- Basic tools :python:`sum_tensors` and :python:`is_scalar`. (See `tools <https://github.com/NeilGirdhar/tjax/blob/master/tjax/tools.py>`_.)
- Basic tools :python:`sum_tensors` and :python:`is_scalar`. (See `tools <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/tools.py>`_.)

Also, see the `documentation <https://neilgirdhar.github.io/tjax/tjax/index.html>`_.

Expand Down
69 changes: 40 additions & 29 deletions tjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,44 @@
This library implements a variety of tools for the differential programming library
[JAX](https://github.com/google/jax).
"""
from .annotations import *
from .color_stub import *
from .cotangent_tools import *
from .dataclass import *
from .dataclass_patch import *
from .display import *
from .dtypes import *
from .generator import *
from .graph import *
from .leaky_integral import *
from .partial import *
from .plottable_trajectory import *
from .shims import *
from .testing import *
from .tools import *
from . import dataclasses, fixed_point, gradient
from ._src.dataclasses import dataclass
from ._src.annotations import (Array, BoolArray, ComplexArray, IntegerArray, PyTree, RealArray,
Shape, ShapeLike, SliceLike, TapFunctionTransforms)
from ._src.cotangent_tools import (block_cotangent, copy_cotangent, print_cotangent,
replace_cotangent)
from ._src.display import display_generic, print_generic
from ._src.dtypes import (complex_dtype, default_atol, default_rtol, default_tols, int_dtype,
real_dtype)
from ._src.generator import Generator
from ._src.leaky_integral import (diffused_leaky_integrate, leaky_covariance, leaky_data_weight,
leaky_integrate, leaky_integrate_time_series)
from ._src.partial import Partial
from ._src.plottable_trajectory import PlottableTrajectory
from ._src.shims import custom_jvp, custom_vjp, jit
from ._src.testing import (assert_jax_allclose, get_relative_test_string, get_test_string,
jax_allclose)
from ._src.tools import abs_square, is_scalar, safe_divide, sum_tensors

__pdoc__ = {}
__pdoc__['real_dtype'] = False
__pdoc__['complex_dtype'] = False
__pdoc__['PyTreeLike'] = False
__pdoc__['Field'] = False
__pdoc__['InitVar'] = False
__pdoc__['FrozenInstanceError'] = False

document_dataclass(__pdoc__, 'Generator')
document_dataclass(__pdoc__, 'Partial')
del document_dataclass


__all__ = list(locals())
__all__ = ['Array', 'BoolArray', 'ComplexArray', 'Generator', 'IntegerArray', 'Partial',
'PlottableTrajectory', 'PyTree', 'RealArray', 'Shape', 'ShapeLike', 'SliceLike',
'TapFunctionTransforms', 'abs_square', 'assert_jax_allclose', 'block_cotangent',
'complex_dtype', 'copy_cotangent', 'custom_jvp', 'custom_vjp', 'dataclass',
'dataclasses', 'default_atol', 'default_rtol', 'default_tols',
'diffused_leaky_integrate', 'display_generic', 'fixed_point', 'get_relative_test_string',
'get_test_string', 'gradient', 'int_dtype', 'is_scalar', 'jax_allclose', 'jit',
'leaky_covariance', 'leaky_data_weight', 'leaky_integrate',
'leaky_integrate_time_series', 'print_cotangent', 'print_generic', 'real_dtype',
'replace_cotangent', 'safe_divide', 'sum_tensors']
#
# __pdoc__ = {}
# __pdoc__['real_dtype'] = False
# __pdoc__['complex_dtype'] = False
# __pdoc__['PyTreeLike'] = False
# __pdoc__['Field'] = False
# __pdoc__['InitVar'] = False
# __pdoc__['FrozenInstanceError'] = False
#
# document_dataclass(__pdoc__, 'Generator')
# document_dataclass(__pdoc__, 'Partial')
# del document_dataclass
14 changes: 14 additions & 0 deletions tjax/_src/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from .annotations import *
from .color_stub import *
from .cotangent_tools import *
from .dataclasses import *
from .display import *
from .dtypes import *
from .generator import *
from .graph import *
from .leaky_integral import *
from .partial import *
from .plottable_trajectory import *
from .shims import *
from .testing import *
from .tools import *
File renamed without changes.
File renamed without changes.
File renamed without changes.
3 changes: 3 additions & 0 deletions tjax/_src/dataclasses/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .dataclass import *
from .helpers import *
from .patch import *
97 changes: 6 additions & 91 deletions tjax/dataclass.py → tjax/_src/dataclasses/dataclass.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,15 @@
import dataclasses
from dataclasses import MISSING, Field, FrozenInstanceError, InitVar, asdict, astuple
from dataclasses import fields as d_fields
from dataclasses import is_dataclass, replace
from dataclasses import MISSING, FrozenInstanceError, InitVar, replace
from functools import partial
from typing import (Any, Callable, Hashable, Iterable, List, Mapping, MutableMapping, Optional,
Sequence, Tuple, Type, TypeVar, overload)
from typing import Any, Callable, Hashable, List, Optional, Sequence, Tuple, Type, TypeVar, overload

from jax.tree_util import register_pytree_node

from .annotations import PyTree
from .display import display_class, display_generic, display_key_and_value
from .testing import get_relative_test_string, get_test_string, jax_allclose
from ..annotations import PyTree
from ..display import display_class, display_generic, display_key_and_value
from ..testing import get_relative_test_string, get_test_string, jax_allclose

__all__ = ['dataclass', 'field', 'Field', 'FrozenInstanceError', 'InitVar', 'MISSING',
# Helper functions.
'fields', 'asdict', 'astuple', 'replace', 'is_dataclass', 'field_names',
'field_names_and_values', 'field_names_values_metadata', 'field_values',
# New functions.
'document_dataclass']
__all__ = ['dataclass', 'InitVar', 'MISSING', 'FrozenInstanceError']


T = TypeVar('T', bound=Any)
Expand Down Expand Up @@ -194,80 +186,3 @@ def get_relative_dataclass_test_string(actual: Any,
if not jax_allclose(getattr(actual, fn), getattr(original, fn), rtol=rtol, atol=atol))
retval += ")"
return retval


# NOTE: Actual return type is 'Field[T]', but we want to help type checkers
# to understand the magic that happens at runtime.
# pylint: disable=redefined-builtin
@overload # `default` and `default_factory` are optional and mutually exclusive.
def field(*, static: bool = False, default: T, init: bool = ..., repr: bool = ...,
hash: Optional[bool] = ..., compare: bool = ...,
metadata: Optional[Mapping[str, Any]] = ...) -> T:
...


@overload
def field(*, static: bool = False, default_factory: Callable[[], T], init: bool = ...,
repr: bool = ..., hash: Optional[bool] = ..., compare: bool = ...,
metadata: Optional[Mapping[str, Any]] = ...) -> T:
...


@overload
def field(*, static: bool = False, init: bool = ..., repr: bool = ..., hash: Optional[bool] = ...,
compare: bool = ..., metadata: Optional[Mapping[str, Any]] = ...) -> Any:
...


def field(*, static: bool = False, default: Any = MISSING,
default_factory: Callable[[], Any] = MISSING, init: bool = True, # type: ignore
repr: bool = True, hash: Optional[bool] = None, compare: bool = True,
metadata: Optional[Mapping[str, Any]] = None) -> Any:
"""
Args:
static: Indicates whether a field is a pytree or static. Pytree fields are
differentiated and traced. Static fields are hashed and compared.
"""
if metadata is None:
metadata = {}
return dataclasses.field(metadata={**metadata, 'static': static},
default=default, default_factory=default_factory, init=init, repr=repr,
hash=hash, compare=compare) # type: ignore


def fields(d: Any, *, static: Optional[bool] = None) -> Iterable[Field[Any]]:
if static is None:
yield from d_fields(d)
for this_field in d_fields(d):
if this_field.metadata.get('static', False) == static:
yield this_field


def field_names(d: Any, *, static: Optional[bool] = None) -> Iterable[str]:
for this_field in fields(d, static=static):
yield this_field.name


def field_names_and_values(d: Any, *, static: Optional[bool] = None) -> Iterable[Tuple[str, Any]]:
for name in field_names(d, static=static):
yield name, getattr(d, name)


def field_values(d: Any, *, static: Optional[bool] = None) -> Iterable[Any]:
for name in field_names(d, static=static):
yield getattr(d, name)


def field_names_values_metadata(d: Any, *, static: Optional[bool] = None) -> (
Iterable[Tuple[str, Any, Mapping[str, Any]]]):
for this_field in fields(d, static=static):
yield this_field.name, getattr(d, this_field.name), this_field.metadata


def document_dataclass(pdoc: MutableMapping[str, Any], name: str) -> None:
pdoc[f'{name}.static_fields'] = False
pdoc[f'{name}.nonstatic_fields'] = False
pdoc[f'{name}.tree_flatten'] = False
pdoc[f'{name}.tree_unflatten'] = False
pdoc[f'{name}.display'] = False
pdoc[f'{name}.replace'] = False
90 changes: 90 additions & 0 deletions tjax/_src/dataclasses/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import dataclasses
from dataclasses import MISSING, Field, asdict, astuple
from dataclasses import fields as d_fields
from dataclasses import is_dataclass, replace
from typing import (Any, Callable, Iterable, Mapping, MutableMapping, Optional, Tuple, TypeVar,
overload)

__all__ = ['field', 'Field', 'fields', 'asdict', 'astuple', 'replace', 'is_dataclass',
'field_names', 'field_names_and_values', 'field_names_values_metadata', 'field_values',
'document_dataclass']


T = TypeVar('T', bound=Any)


# NOTE: Actual return type is 'Field[T]', but we want to help type checkers
# to understand the magic that happens at runtime.
# pylint: disable=redefined-builtin
@overload # `default` and `default_factory` are optional and mutually exclusive.
def field(*, static: bool = False, default: T, init: bool = ..., repr: bool = ...,
hash: Optional[bool] = ..., compare: bool = ...,
metadata: Optional[Mapping[str, Any]] = ...) -> T:
...


@overload
def field(*, static: bool = False, default_factory: Callable[[], T], init: bool = ...,
repr: bool = ..., hash: Optional[bool] = ..., compare: bool = ...,
metadata: Optional[Mapping[str, Any]] = ...) -> T:
...


@overload
def field(*, static: bool = False, init: bool = ..., repr: bool = ..., hash: Optional[bool] = ...,
compare: bool = ..., metadata: Optional[Mapping[str, Any]] = ...) -> Any:
...


def field(*, static: bool = False, default: Any = MISSING,
default_factory: Callable[[], Any] = MISSING, init: bool = True, # type: ignore
repr: bool = True, hash: Optional[bool] = None, compare: bool = True,
metadata: Optional[Mapping[str, Any]] = None) -> Any:
"""
Args:
static: Indicates whether a field is a pytree or static. Pytree fields are
differentiated and traced. Static fields are hashed and compared.
"""
if metadata is None:
metadata = {}
return dataclasses.field(metadata={**metadata, 'static': static},
default=default, default_factory=default_factory, init=init, repr=repr,
hash=hash, compare=compare) # type: ignore


def fields(d: Any, *, static: Optional[bool] = None) -> Iterable[Field[Any]]:
if static is None:
yield from d_fields(d)
for this_field in d_fields(d):
if this_field.metadata.get('static', False) == static:
yield this_field


def field_names(d: Any, *, static: Optional[bool] = None) -> Iterable[str]:
for this_field in fields(d, static=static):
yield this_field.name


def field_names_and_values(d: Any, *, static: Optional[bool] = None) -> Iterable[Tuple[str, Any]]:
for name in field_names(d, static=static):
yield name, getattr(d, name)


def field_values(d: Any, *, static: Optional[bool] = None) -> Iterable[Any]:
for name in field_names(d, static=static):
yield getattr(d, name)


def field_names_values_metadata(d: Any, *, static: Optional[bool] = None) -> (
Iterable[Tuple[str, Any, Mapping[str, Any]]]):
for this_field in fields(d, static=static):
yield this_field.name, getattr(d, this_field.name), this_field.metadata


def document_dataclass(pdoc: MutableMapping[str, Any], name: str) -> None:
pdoc[f'{name}.static_fields'] = False
pdoc[f'{name}.nonstatic_fields'] = False
pdoc[f'{name}.tree_flatten'] = False
pdoc[f'{name}.tree_unflatten'] = False
pdoc[f'{name}.display'] = False
pdoc[f'{name}.replace'] = False
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Generic, TypeVar

from ..annotations import PyTree
from ..dataclass import dataclass
from ..dataclasses import dataclass

__all__ = ['AugmentedState']

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from jax.tree_util import tree_multimap

from ..annotations import PyTree
from ..dataclass import dataclass
from ..dataclasses import dataclass
from ..shims import custom_vjp
from .augmented import State
from .comparing import ComparingIteratedFunction, ComparingState
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from chex import Array
from jax.tree_util import tree_map, tree_multimap, tree_reduce

from ..dataclass import dataclass
from ..dataclasses import dataclass
from ..tools import safe_divide
from .augmented import AugmentedState, State
from .iterated_function import Comparand, IteratedFunction, Parameters, Trajectory
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from jax.tree_util import tree_multimap

from ..annotations import PyTree, TapFunctionTransforms
from ..dataclass import dataclass, field
from ..dataclasses import dataclass, field
from ..dtypes import default_atol, default_rtol
from .augmented import AugmentedState, State

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from chex import Array
from jax.tree_util import tree_map, tree_multimap, tree_reduce

from ..dataclass import dataclass
from ..dataclasses import dataclass
from ..leaky_integral import leaky_data_weight, leaky_integrate
from ..tools import abs_square, safe_divide
from .augmented import AugmentedState, State
Expand Down
Loading

0 comments on commit 488e475

Please sign in to comment.