diff --git a/README.rst b/README.rst index 3289c4c3..3d054766 100644 --- a/README.rst +++ b/README.rst @@ -18,10 +18,10 @@ Major components Tjax's major components are: - A dataclass decorator :python:`dataclasss` that facilitates defining JAX trees, and has a MyPy plugin. - (See `dataclass `_ and `mypy_plugin `_.) + (See `dataclass `_ and `mypy_plugin `_.) - A fixed point finding library heavily based on `fax `_. Our - library (`fixed_point `_): + library (`fixed_point `_): - supports stochastic iterated functions, - uses dataclasses instead of closures to avoid leaking JAX tracers, and @@ -34,33 +34,33 @@ Minor components Tjax also includes: - An object-oriented wrapper on top of `optax `_. (See - `gradient `_.) + `gradient `_.) - A pretty printer :python:`print_generic` for aggregate and vector types, including dataclasses. (See - `display `_.) + `display `_.) - Versions of :python:`custom_vjp` and :python:`custom_jvp` that support being used on methods. - (See `shims `_.) + (See `shims `_.) - Tools for working with cotangents. (See - `cotangent_tools `_.) + `cotangent_tools `_.) -- A random number generator class :python:`Generator`. (See `generator `_.) +- A random number generator class :python:`Generator`. (See `generator `_.) - JAX tree registration for `NetworkX `_ graph types. (See - `graph `_.) + `graph `_.) - Leaky integration :python:`leaky_integrate` and Ornstein-Uhlenbeck process iteration - :python:`diffused_leaky_integrate`. (See `leaky_integral `_.) + :python:`diffused_leaky_integrate`. (See `leaky_integral `_.) -- An improved version of :python:`jax.tree_util.Partial`. (See `partial `_.) +- An improved version of :python:`jax.tree_util.Partial`. (See `partial `_.) -- A Matplotlib trajectory plotter :python:`PlottableTrajectory`. (See `plottable_trajectory `_.) +- A Matplotlib trajectory plotter :python:`PlottableTrajectory`. (See `plottable_trajectory `_.) - A testing function :python:`assert_jax_allclose` that automatically produces testing code. And, a related - function :python:`jax_allclose`. (See `testing `_.) + function :python:`jax_allclose`. (See `testing `_.) -- Basic tools :python:`sum_tensors` and :python:`is_scalar`. (See `tools `_.) +- Basic tools :python:`sum_tensors` and :python:`is_scalar`. (See `tools `_.) Also, see the `documentation `_. diff --git a/tjax/__init__.py b/tjax/__init__.py index aab5b5d4..7f51b71b 100644 --- a/tjax/__init__.py +++ b/tjax/__init__.py @@ -2,33 +2,44 @@ This library implements a variety of tools for the differential programming library [JAX](https://github.com/google/jax). """ -from .annotations import * -from .color_stub import * -from .cotangent_tools import * -from .dataclass import * -from .dataclass_patch import * -from .display import * -from .dtypes import * -from .generator import * -from .graph import * -from .leaky_integral import * -from .partial import * -from .plottable_trajectory import * -from .shims import * -from .testing import * -from .tools import * +from . import dataclasses, fixed_point, gradient +from ._src.dataclasses import dataclass +from ._src.annotations import (Array, BoolArray, ComplexArray, IntegerArray, PyTree, RealArray, + Shape, ShapeLike, SliceLike, TapFunctionTransforms) +from ._src.cotangent_tools import (block_cotangent, copy_cotangent, print_cotangent, + replace_cotangent) +from ._src.display import display_generic, print_generic +from ._src.dtypes import (complex_dtype, default_atol, default_rtol, default_tols, int_dtype, + real_dtype) +from ._src.generator import Generator +from ._src.leaky_integral import (diffused_leaky_integrate, leaky_covariance, leaky_data_weight, + leaky_integrate, leaky_integrate_time_series) +from ._src.partial import Partial +from ._src.plottable_trajectory import PlottableTrajectory +from ._src.shims import custom_jvp, custom_vjp, jit +from ._src.testing import (assert_jax_allclose, get_relative_test_string, get_test_string, + jax_allclose) +from ._src.tools import abs_square, is_scalar, safe_divide, sum_tensors -__pdoc__ = {} -__pdoc__['real_dtype'] = False -__pdoc__['complex_dtype'] = False -__pdoc__['PyTreeLike'] = False -__pdoc__['Field'] = False -__pdoc__['InitVar'] = False -__pdoc__['FrozenInstanceError'] = False - -document_dataclass(__pdoc__, 'Generator') -document_dataclass(__pdoc__, 'Partial') -del document_dataclass - - -__all__ = list(locals()) +__all__ = ['Array', 'BoolArray', 'ComplexArray', 'Generator', 'IntegerArray', 'Partial', + 'PlottableTrajectory', 'PyTree', 'RealArray', 'Shape', 'ShapeLike', 'SliceLike', + 'TapFunctionTransforms', 'abs_square', 'assert_jax_allclose', 'block_cotangent', + 'complex_dtype', 'copy_cotangent', 'custom_jvp', 'custom_vjp', 'dataclass', + 'dataclasses', 'default_atol', 'default_rtol', 'default_tols', + 'diffused_leaky_integrate', 'display_generic', 'fixed_point', 'get_relative_test_string', + 'get_test_string', 'gradient', 'int_dtype', 'is_scalar', 'jax_allclose', 'jit', + 'leaky_covariance', 'leaky_data_weight', 'leaky_integrate', + 'leaky_integrate_time_series', 'print_cotangent', 'print_generic', 'real_dtype', + 'replace_cotangent', 'safe_divide', 'sum_tensors'] +# +# __pdoc__ = {} +# __pdoc__['real_dtype'] = False +# __pdoc__['complex_dtype'] = False +# __pdoc__['PyTreeLike'] = False +# __pdoc__['Field'] = False +# __pdoc__['InitVar'] = False +# __pdoc__['FrozenInstanceError'] = False +# +# document_dataclass(__pdoc__, 'Generator') +# document_dataclass(__pdoc__, 'Partial') +# del document_dataclass diff --git a/tjax/_src/__init__.py b/tjax/_src/__init__.py new file mode 100644 index 00000000..2431d3e9 --- /dev/null +++ b/tjax/_src/__init__.py @@ -0,0 +1,14 @@ +from .annotations import * +from .color_stub import * +from .cotangent_tools import * +from .dataclasses import * +from .display import * +from .dtypes import * +from .generator import * +from .graph import * +from .leaky_integral import * +from .partial import * +from .plottable_trajectory import * +from .shims import * +from .testing import * +from .tools import * diff --git a/tjax/annotations.py b/tjax/_src/annotations.py similarity index 100% rename from tjax/annotations.py rename to tjax/_src/annotations.py diff --git a/tjax/color_stub.py b/tjax/_src/color_stub.py similarity index 100% rename from tjax/color_stub.py rename to tjax/_src/color_stub.py diff --git a/tjax/cotangent_tools.py b/tjax/_src/cotangent_tools.py similarity index 100% rename from tjax/cotangent_tools.py rename to tjax/_src/cotangent_tools.py diff --git a/tjax/_src/dataclasses/__init__.py b/tjax/_src/dataclasses/__init__.py new file mode 100644 index 00000000..0b9238e3 --- /dev/null +++ b/tjax/_src/dataclasses/__init__.py @@ -0,0 +1,3 @@ +from .dataclass import * +from .helpers import * +from .patch import * diff --git a/tjax/dataclass.py b/tjax/_src/dataclasses/dataclass.py similarity index 63% rename from tjax/dataclass.py rename to tjax/_src/dataclasses/dataclass.py index af8b6b61..bf59d157 100644 --- a/tjax/dataclass.py +++ b/tjax/_src/dataclasses/dataclass.py @@ -1,23 +1,15 @@ import dataclasses -from dataclasses import MISSING, Field, FrozenInstanceError, InitVar, asdict, astuple -from dataclasses import fields as d_fields -from dataclasses import is_dataclass, replace +from dataclasses import MISSING, FrozenInstanceError, InitVar, replace from functools import partial -from typing import (Any, Callable, Hashable, Iterable, List, Mapping, MutableMapping, Optional, - Sequence, Tuple, Type, TypeVar, overload) +from typing import Any, Callable, Hashable, List, Optional, Sequence, Tuple, Type, TypeVar, overload from jax.tree_util import register_pytree_node -from .annotations import PyTree -from .display import display_class, display_generic, display_key_and_value -from .testing import get_relative_test_string, get_test_string, jax_allclose +from ..annotations import PyTree +from ..display import display_class, display_generic, display_key_and_value +from ..testing import get_relative_test_string, get_test_string, jax_allclose -__all__ = ['dataclass', 'field', 'Field', 'FrozenInstanceError', 'InitVar', 'MISSING', - # Helper functions. - 'fields', 'asdict', 'astuple', 'replace', 'is_dataclass', 'field_names', - 'field_names_and_values', 'field_names_values_metadata', 'field_values', - # New functions. - 'document_dataclass'] +__all__ = ['dataclass', 'InitVar', 'MISSING', 'FrozenInstanceError'] T = TypeVar('T', bound=Any) @@ -194,80 +186,3 @@ def get_relative_dataclass_test_string(actual: Any, if not jax_allclose(getattr(actual, fn), getattr(original, fn), rtol=rtol, atol=atol)) retval += ")" return retval - - -# NOTE: Actual return type is 'Field[T]', but we want to help type checkers -# to understand the magic that happens at runtime. -# pylint: disable=redefined-builtin -@overload # `default` and `default_factory` are optional and mutually exclusive. -def field(*, static: bool = False, default: T, init: bool = ..., repr: bool = ..., - hash: Optional[bool] = ..., compare: bool = ..., - metadata: Optional[Mapping[str, Any]] = ...) -> T: - ... - - -@overload -def field(*, static: bool = False, default_factory: Callable[[], T], init: bool = ..., - repr: bool = ..., hash: Optional[bool] = ..., compare: bool = ..., - metadata: Optional[Mapping[str, Any]] = ...) -> T: - ... - - -@overload -def field(*, static: bool = False, init: bool = ..., repr: bool = ..., hash: Optional[bool] = ..., - compare: bool = ..., metadata: Optional[Mapping[str, Any]] = ...) -> Any: - ... - - -def field(*, static: bool = False, default: Any = MISSING, - default_factory: Callable[[], Any] = MISSING, init: bool = True, # type: ignore - repr: bool = True, hash: Optional[bool] = None, compare: bool = True, - metadata: Optional[Mapping[str, Any]] = None) -> Any: - """ - Args: - static: Indicates whether a field is a pytree or static. Pytree fields are - differentiated and traced. Static fields are hashed and compared. - """ - if metadata is None: - metadata = {} - return dataclasses.field(metadata={**metadata, 'static': static}, - default=default, default_factory=default_factory, init=init, repr=repr, - hash=hash, compare=compare) # type: ignore - - -def fields(d: Any, *, static: Optional[bool] = None) -> Iterable[Field[Any]]: - if static is None: - yield from d_fields(d) - for this_field in d_fields(d): - if this_field.metadata.get('static', False) == static: - yield this_field - - -def field_names(d: Any, *, static: Optional[bool] = None) -> Iterable[str]: - for this_field in fields(d, static=static): - yield this_field.name - - -def field_names_and_values(d: Any, *, static: Optional[bool] = None) -> Iterable[Tuple[str, Any]]: - for name in field_names(d, static=static): - yield name, getattr(d, name) - - -def field_values(d: Any, *, static: Optional[bool] = None) -> Iterable[Any]: - for name in field_names(d, static=static): - yield getattr(d, name) - - -def field_names_values_metadata(d: Any, *, static: Optional[bool] = None) -> ( - Iterable[Tuple[str, Any, Mapping[str, Any]]]): - for this_field in fields(d, static=static): - yield this_field.name, getattr(d, this_field.name), this_field.metadata - - -def document_dataclass(pdoc: MutableMapping[str, Any], name: str) -> None: - pdoc[f'{name}.static_fields'] = False - pdoc[f'{name}.nonstatic_fields'] = False - pdoc[f'{name}.tree_flatten'] = False - pdoc[f'{name}.tree_unflatten'] = False - pdoc[f'{name}.display'] = False - pdoc[f'{name}.replace'] = False diff --git a/tjax/_src/dataclasses/helpers.py b/tjax/_src/dataclasses/helpers.py new file mode 100644 index 00000000..94219749 --- /dev/null +++ b/tjax/_src/dataclasses/helpers.py @@ -0,0 +1,90 @@ +import dataclasses +from dataclasses import MISSING, Field, asdict, astuple +from dataclasses import fields as d_fields +from dataclasses import is_dataclass, replace +from typing import (Any, Callable, Iterable, Mapping, MutableMapping, Optional, Tuple, TypeVar, + overload) + +__all__ = ['field', 'Field', 'fields', 'asdict', 'astuple', 'replace', 'is_dataclass', + 'field_names', 'field_names_and_values', 'field_names_values_metadata', 'field_values', + 'document_dataclass'] + + +T = TypeVar('T', bound=Any) + + +# NOTE: Actual return type is 'Field[T]', but we want to help type checkers +# to understand the magic that happens at runtime. +# pylint: disable=redefined-builtin +@overload # `default` and `default_factory` are optional and mutually exclusive. +def field(*, static: bool = False, default: T, init: bool = ..., repr: bool = ..., + hash: Optional[bool] = ..., compare: bool = ..., + metadata: Optional[Mapping[str, Any]] = ...) -> T: + ... + + +@overload +def field(*, static: bool = False, default_factory: Callable[[], T], init: bool = ..., + repr: bool = ..., hash: Optional[bool] = ..., compare: bool = ..., + metadata: Optional[Mapping[str, Any]] = ...) -> T: + ... + + +@overload +def field(*, static: bool = False, init: bool = ..., repr: bool = ..., hash: Optional[bool] = ..., + compare: bool = ..., metadata: Optional[Mapping[str, Any]] = ...) -> Any: + ... + + +def field(*, static: bool = False, default: Any = MISSING, + default_factory: Callable[[], Any] = MISSING, init: bool = True, # type: ignore + repr: bool = True, hash: Optional[bool] = None, compare: bool = True, + metadata: Optional[Mapping[str, Any]] = None) -> Any: + """ + Args: + static: Indicates whether a field is a pytree or static. Pytree fields are + differentiated and traced. Static fields are hashed and compared. + """ + if metadata is None: + metadata = {} + return dataclasses.field(metadata={**metadata, 'static': static}, + default=default, default_factory=default_factory, init=init, repr=repr, + hash=hash, compare=compare) # type: ignore + + +def fields(d: Any, *, static: Optional[bool] = None) -> Iterable[Field[Any]]: + if static is None: + yield from d_fields(d) + for this_field in d_fields(d): + if this_field.metadata.get('static', False) == static: + yield this_field + + +def field_names(d: Any, *, static: Optional[bool] = None) -> Iterable[str]: + for this_field in fields(d, static=static): + yield this_field.name + + +def field_names_and_values(d: Any, *, static: Optional[bool] = None) -> Iterable[Tuple[str, Any]]: + for name in field_names(d, static=static): + yield name, getattr(d, name) + + +def field_values(d: Any, *, static: Optional[bool] = None) -> Iterable[Any]: + for name in field_names(d, static=static): + yield getattr(d, name) + + +def field_names_values_metadata(d: Any, *, static: Optional[bool] = None) -> ( + Iterable[Tuple[str, Any, Mapping[str, Any]]]): + for this_field in fields(d, static=static): + yield this_field.name, getattr(d, this_field.name), this_field.metadata + + +def document_dataclass(pdoc: MutableMapping[str, Any], name: str) -> None: + pdoc[f'{name}.static_fields'] = False + pdoc[f'{name}.nonstatic_fields'] = False + pdoc[f'{name}.tree_flatten'] = False + pdoc[f'{name}.tree_unflatten'] = False + pdoc[f'{name}.display'] = False + pdoc[f'{name}.replace'] = False diff --git a/tjax/dataclass_patch.py b/tjax/_src/dataclasses/patch.py similarity index 100% rename from tjax/dataclass_patch.py rename to tjax/_src/dataclasses/patch.py diff --git a/tjax/display.py b/tjax/_src/display.py similarity index 100% rename from tjax/display.py rename to tjax/_src/display.py diff --git a/tjax/dtypes.py b/tjax/_src/dtypes.py similarity index 100% rename from tjax/dtypes.py rename to tjax/_src/dtypes.py diff --git a/tjax/fixed_point/__init__.py b/tjax/_src/fixed_point/__init__.py similarity index 100% rename from tjax/fixed_point/__init__.py rename to tjax/_src/fixed_point/__init__.py diff --git a/tjax/fixed_point/augmented.py b/tjax/_src/fixed_point/augmented.py similarity index 88% rename from tjax/fixed_point/augmented.py rename to tjax/_src/fixed_point/augmented.py index 06e6f477..38aa2ed9 100644 --- a/tjax/fixed_point/augmented.py +++ b/tjax/_src/fixed_point/augmented.py @@ -3,7 +3,7 @@ from typing import Generic, TypeVar from ..annotations import PyTree -from ..dataclass import dataclass +from ..dataclasses import dataclass __all__ = ['AugmentedState'] diff --git a/tjax/fixed_point/combinator.py b/tjax/_src/fixed_point/combinator.py similarity index 99% rename from tjax/fixed_point/combinator.py rename to tjax/_src/fixed_point/combinator.py index ffc14803..8250d7ff 100644 --- a/tjax/fixed_point/combinator.py +++ b/tjax/_src/fixed_point/combinator.py @@ -7,7 +7,7 @@ from jax.tree_util import tree_multimap from ..annotations import PyTree -from ..dataclass import dataclass +from ..dataclasses import dataclass from ..shims import custom_vjp from .augmented import State from .comparing import ComparingIteratedFunction, ComparingState diff --git a/tjax/fixed_point/comparing.py b/tjax/_src/fixed_point/comparing.py similarity index 98% rename from tjax/fixed_point/comparing.py rename to tjax/_src/fixed_point/comparing.py index fc1ad27a..7d760d4b 100644 --- a/tjax/fixed_point/comparing.py +++ b/tjax/_src/fixed_point/comparing.py @@ -7,7 +7,7 @@ from chex import Array from jax.tree_util import tree_map, tree_multimap, tree_reduce -from ..dataclass import dataclass +from ..dataclasses import dataclass from ..tools import safe_divide from .augmented import AugmentedState, State from .iterated_function import Comparand, IteratedFunction, Parameters, Trajectory diff --git a/tjax/fixed_point/iterated_function.py b/tjax/_src/fixed_point/iterated_function.py similarity index 99% rename from tjax/fixed_point/iterated_function.py rename to tjax/_src/fixed_point/iterated_function.py index 61782090..e62fe636 100644 --- a/tjax/fixed_point/iterated_function.py +++ b/tjax/_src/fixed_point/iterated_function.py @@ -11,7 +11,7 @@ from jax.tree_util import tree_multimap from ..annotations import PyTree, TapFunctionTransforms -from ..dataclass import dataclass, field +from ..dataclasses import dataclass, field from ..dtypes import default_atol, default_rtol from .augmented import AugmentedState, State diff --git a/tjax/fixed_point/stochastic.py b/tjax/_src/fixed_point/stochastic.py similarity index 99% rename from tjax/fixed_point/stochastic.py rename to tjax/_src/fixed_point/stochastic.py index 1fbb6371..79a5e863 100644 --- a/tjax/fixed_point/stochastic.py +++ b/tjax/_src/fixed_point/stochastic.py @@ -7,7 +7,7 @@ from chex import Array from jax.tree_util import tree_map, tree_multimap, tree_reduce -from ..dataclass import dataclass +from ..dataclasses import dataclass from ..leaky_integral import leaky_data_weight, leaky_integrate from ..tools import abs_square, safe_divide from .augmented import AugmentedState, State diff --git a/tjax/generator.py b/tjax/_src/generator.py similarity index 98% rename from tjax/generator.py rename to tjax/_src/generator.py index 5f607073..bdf719ee 100644 --- a/tjax/generator.py +++ b/tjax/_src/generator.py @@ -8,7 +8,7 @@ from chex import Array from .annotations import RealArray, Shape, ShapeLike -from .dataclass import dataclass +from .dataclasses import dataclass __all__ = ['Generator'] @@ -23,7 +23,6 @@ class Generator: has no mutating methods. Instead, its generation methods return a new instance along with the generated tensor. """ - key: Array # Class methods -------------------------------------------------------------------------------- diff --git a/tjax/gradient/__init__.py b/tjax/_src/gradient/__init__.py similarity index 100% rename from tjax/gradient/__init__.py rename to tjax/_src/gradient/__init__.py diff --git a/tjax/gradient/aliases.py b/tjax/_src/gradient/aliases.py similarity index 100% rename from tjax/gradient/aliases.py rename to tjax/_src/gradient/aliases.py diff --git a/tjax/gradient/chain.py b/tjax/_src/gradient/chain.py similarity index 96% rename from tjax/gradient/chain.py rename to tjax/_src/gradient/chain.py index ef404a35..eb21bc99 100644 --- a/tjax/gradient/chain.py +++ b/tjax/_src/gradient/chain.py @@ -1,7 +1,7 @@ from typing import Any, Generic, List, Optional, Tuple, TypeVar from ..annotations import PyTree -from ..dataclass import dataclass +from ..dataclasses import dataclass from .transform import GradientTransformation, Weights __all__ = ['ChainedGradientTransformation'] diff --git a/tjax/gradient/smd.py b/tjax/_src/gradient/smd.py similarity index 98% rename from tjax/gradient/smd.py rename to tjax/_src/gradient/smd.py index 200f5f53..68ccaaf9 100644 --- a/tjax/gradient/smd.py +++ b/tjax/_src/gradient/smd.py @@ -5,7 +5,7 @@ from jax.tree_util import tree_map, tree_multimap from ..annotations import PyTree -from ..dataclass import dataclass +from ..dataclasses import dataclass from .transform import SecondOrderGradientTransformation __all__ = ['SMDState', 'SMDGradient'] diff --git a/tjax/gradient/transform.py b/tjax/_src/gradient/transform.py similarity index 99% rename from tjax/gradient/transform.py rename to tjax/_src/gradient/transform.py index 221fcecf..096573bb 100644 --- a/tjax/gradient/transform.py +++ b/tjax/_src/gradient/transform.py @@ -4,7 +4,7 @@ from jax.tree_util import tree_map, tree_multimap, tree_reduce from ..annotations import PyTree -from ..dataclass import dataclass +from ..dataclasses import dataclass from ..tools import abs_square __all__ = ['GradientState', 'GradientTransformation', 'SecondOrderGradientTransformation', diff --git a/tjax/gradient/transforms.py b/tjax/_src/gradient/transforms.py similarity index 98% rename from tjax/gradient/transforms.py rename to tjax/_src/gradient/transforms.py index 4ec3bed7..88ea4543 100644 --- a/tjax/gradient/transforms.py +++ b/tjax/_src/gradient/transforms.py @@ -3,7 +3,7 @@ from chex import Numeric from optax import ScaleByAdamState, ScaleState, additive_weight_decay, scale, scale_by_adam -from ..dataclass import dataclass +from ..dataclasses import dataclass from .transform import GradientTransformation, Weights __all__ = ['Scale', 'ScaleByAdam', 'AdditiveWeightDecay'] diff --git a/tjax/graph.py b/tjax/_src/graph.py similarity index 100% rename from tjax/graph.py rename to tjax/_src/graph.py diff --git a/tjax/leaky_integral.py b/tjax/_src/leaky_integral.py similarity index 99% rename from tjax/leaky_integral.py rename to tjax/_src/leaky_integral.py index dee3252e..55803867 100644 --- a/tjax/leaky_integral.py +++ b/tjax/_src/leaky_integral.py @@ -8,7 +8,7 @@ from chex import Array from jax.lax import scan -from .dataclass import dataclass +from .dataclasses import dataclass from .dtypes import real_dtype from .generator import Generator diff --git a/tjax/partial.py b/tjax/_src/partial.py similarity index 100% rename from tjax/partial.py rename to tjax/_src/partial.py diff --git a/tjax/plottable_trajectory.py b/tjax/_src/plottable_trajectory.py similarity index 98% rename from tjax/plottable_trajectory.py rename to tjax/_src/plottable_trajectory.py index d55e4a1e..1a731dc6 100644 --- a/tjax/plottable_trajectory.py +++ b/tjax/_src/plottable_trajectory.py @@ -8,7 +8,7 @@ from matplotlib.axes import Axes from .annotations import PyTree -from .dataclass import dataclass +from .dataclasses import dataclass from .leaky_integral import leaky_integrate_time_series __all__ = ['PlottableTrajectory'] diff --git a/tjax/shims.py b/tjax/_src/shims.py similarity index 100% rename from tjax/shims.py rename to tjax/_src/shims.py diff --git a/tjax/testing.py b/tjax/_src/testing.py similarity index 100% rename from tjax/testing.py rename to tjax/_src/testing.py diff --git a/tjax/tools.py b/tjax/_src/tools.py similarity index 100% rename from tjax/tools.py rename to tjax/_src/tools.py diff --git a/tjax/dataclasses.py b/tjax/dataclasses.py new file mode 100644 index 00000000..69e9a175 --- /dev/null +++ b/tjax/dataclasses.py @@ -0,0 +1,9 @@ +from ._src.dataclasses.dataclass import MISSING, FrozenInstanceError, InitVar, dataclass +from ._src.dataclasses.helpers import (Field, asdict, astuple, document_dataclass, field, + field_names, field_names_and_values, + field_names_values_metadata, field_values, fields, + is_dataclass, replace) + +__all__ = ['Field', 'FrozenInstanceError', 'InitVar', 'MISSING', 'asdict', 'astuple', 'dataclass', + 'document_dataclass', 'field', 'field_names', 'field_names_and_values', + 'field_names_values_metadata', 'field_values', 'fields', 'is_dataclass', 'replace'] diff --git a/tjax/fixed_point.py b/tjax/fixed_point.py new file mode 100644 index 00000000..60bc2ab2 --- /dev/null +++ b/tjax/fixed_point.py @@ -0,0 +1,13 @@ +from tjax._src.fixed_point.augmented import AugmentedState +from tjax._src.fixed_point.combinator import (ComparingIteratedFunctionWithCombinator, + IteratedFunctionWithCombinator) +from tjax._src.fixed_point.comparing import ComparingIteratedFunction, ComparingState +from tjax._src.fixed_point.iterated_function import IteratedFunction +from tjax._src.fixed_point.stochastic import (StochasticIteratedFunction, + StochasticIteratedFunctionWithCombinator, + StochasticState) + +__all__ = ['AugmentedState', 'ComparingIteratedFunction', 'ComparingIteratedFunctionWithCombinator', + 'ComparingState', 'IteratedFunction', 'IteratedFunctionWithCombinator', + 'StochasticIteratedFunction', 'StochasticIteratedFunctionWithCombinator', + 'StochasticState'] diff --git a/tjax/gradient.py b/tjax/gradient.py new file mode 100644 index 00000000..37a0b276 --- /dev/null +++ b/tjax/gradient.py @@ -0,0 +1,11 @@ +from tjax._src.gradient.aliases import adam, adamw +from tjax._src.gradient.chain import ChainedGradientTransformation +from tjax._src.gradient.smd import SMDGradient, SMDState +from tjax._src.gradient.transform import (GradientState, GradientTransformation, + SecondOrderGradientTransformation, + ThirdOrderGradientTransformation) +from tjax._src.gradient.transforms import AdditiveWeightDecay, Scale, ScaleByAdam + +__all__ = ['AdditiveWeightDecay', 'ChainedGradientTransformation', 'GradientState', + 'GradientTransformation', 'SMDGradient', 'SMDState', 'Scale', 'ScaleByAdam', + 'SecondOrderGradientTransformation', 'ThirdOrderGradientTransformation', 'adam', 'adamw'] diff --git a/tjax/mypy_plugin.py b/tjax/mypy_plugin.py index 145d4a9b..37cab0b0 100644 --- a/tjax/mypy_plugin.py +++ b/tjax/mypy_plugin.py @@ -31,11 +31,11 @@ def plugin(version: str) -> Any: # The set of decorators that generate dataclasses. dataclass_makers = { - 'tjax.dataclass.dataclass', + 'tjax._src.dataclasses.dataclass.dataclass', } # type: Final field_makers = { - 'tjax.dataclass.field', + 'tjax._src.dataclasses.helpers.field', } # type: Final SELF_TVAR_NAME = '_DT' # type: Final