Skip to content

Commit

Permalink
build: bump up precommit hooks (#837)
Browse files Browse the repository at this point in the history
* build: bump up precommit hooks

* fix: add toml to filetypes that ruff fixes

Ruff started to introduce support for pyproject.toml linting

* fix: pyright

* restore: old state of _conv + remove a bunch of ignores

* fix: typos, while on it

* fix: two more typos
  • Loading branch information
knyazer authored Sep 9, 2024
1 parent dae889d commit 97ac55a
Show file tree
Hide file tree
Showing 30 changed files with 67 additions and 80 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.4
rev: v0.6.4
hooks:
- id: ruff # linter
types_or: [python, pyi, jupyter]
types_or: [python, pyi, jupyter, toml]
args: [--fix]
- id: ruff-format # formatter
types_or: [python, pyi, jupyter]
types_or: [python, pyi, jupyter, toml]
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.351
rev: v1.1.379
hooks:
- id: pyright
additional_dependencies:
Expand Down
2 changes: 1 addition & 1 deletion docs/all-of-equinox.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ Finally, Equinox offers a number of more advanced goodies, like serialisation, d

**Equinox integrates smoothly with JAX**

Equinox introduces a powerful yet straightforward way to build neural networks, without introducing lots of new notions or tieing you into a framework. Indeed Equinox is a *library*, not a *framework* -- this means that anything you write in Equinox is fully compatible with anything else in the JAX ecosystem.
Equinox introduces a powerful yet straightforward way to build neural networks, without introducing lots of new notions or tying you into a framework. Indeed Equinox is a *library*, not a *framework* -- this means that anything you write in Equinox is fully compatible with anything else in the JAX ecosystem.

Equinox is all just regular JAX: PyTrees and transformations. Together, these two pieces allow us to specify complex models in JAX-friendly ways.

Expand Down
2 changes: 1 addition & 1 deletion docs/api/transformations.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

These offer an alternate (easier to use) API for JAX transformations.

For example, JAX uses `jax.jit(..., static_argnums=...)` to manually indicate which arguments should be treated dynamically/statically. Meanwhile `equinox.filter_jit` automatically treats all JAX/NumPy arrays dynamically, and everything else statically. Moreover, this is done at the level of individual PyTree leaves, so that unlike `jax.jit`, one argment can have both dynamic (array-valued) and static leaves.
For example, JAX uses `jax.jit(..., static_argnums=...)` to manually indicate which arguments should be treated dynamically/statically. Meanwhile `equinox.filter_jit` automatically treats all JAX/NumPy arrays dynamically, and everything else statically. Moreover, this is done at the level of individual PyTree leaves, so that unlike `jax.jit`, one argument can have both dynamic (array-valued) and static leaves.

Most users find that this is a simpler API when working with complicated PyTrees, such as are produced when using Equinox modules. But you can also still use Equinox with normal `jax.jit` etc. if you so prefer.

Expand Down
12 changes: 6 additions & 6 deletions docs/faq.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class Module(eqx.Module):
```
as this is used to accomplish something different: this creates two separate layers, that are initialised with the same values for their parameters. After making some gradient updates, you'll find that `self.linear1` and `self.linear2` are now different.

The reason for this is that in Equinox+JAX, models are Py*Trees*, not DAGs. (Directed acyclic graphs.) JAX follows a functional-programming-like style, in which the *identity* of an object (whether tha be a layer, a weight, or whatever) doesn't matter. Only its *value* matters. (This is known as referential transparency.)
The reason for this is that in Equinox+JAX, models are Py*Trees*, not DAGs. (Directed acyclic graphs.) JAX follows a functional-programming-like style, in which the *identity* of an object (whether that be a layer, a weight, or whatever) doesn't matter. Only its *value* matters. (This is known as referential transparency.)

See also the [`equinox.tree_check`][] function, which can be ran on a model to check if you have duplicate nodes.

Expand Down Expand Up @@ -153,20 +153,20 @@ def rollout(mlp, xs):
val = mlp(x)
carry = mlp
return carry, [val]

_, scan_out = jax.lax.scan(
step,
[mlp],
xs
)

return scan_out

key, subkey = jax.random.split(key)
vals = rollout(mlp, jax.random.normal(key=subkey, shape=(200, 3)))
```

will error. To fix this, you can explicitly capture the static elements via
will error. To fix this, you can explicitly capture the static elements via

```python
def rollout(mlp, xs):
Expand All @@ -176,7 +176,7 @@ def rollout(mlp, xs):
val = mlp(x)
carry, _ = eqx.partition(mlp, eqx.is_array)
return carry, [val]

_, scan_out = jax.lax.scan(
step,
arr,
Expand All @@ -196,7 +196,7 @@ Use [`equinox.debug.assert_max_traces`][], for example
def your_function(x, y, z):
...
```
will raise an error if it is compiled more than once, and tell you which argment caused the recompilation. (A function will be recompiled every time the shape or dtype of one of its array-valued inputs change, or if any of its static (non-array) inputs change (as measured by `__eq__`).)
will raise an error if it is compiled more than once, and tell you which argument caused the recompilation. (A function will be recompiled every time the shape or dtype of one of its array-valued inputs change, or if any of its static (non-array) inputs change (as measured by `__eq__`).)

As an alternative, a quick check for announcing each time your function is compiled can be achieved with a print statement:
```python
Expand Down
4 changes: 2 additions & 2 deletions docs/pattern.md
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,9 @@ In practice, we argue that's a good idea! This rule means that when you see code
def foo(interp: AbstractPolynomialInterpolation)
... = interp.degree()
```
you know that it is calling precisely `AbstractPolynomialInterpolation.degree`, and not an override in some subclass. This is excellent for code readability. Thus we get the rule that no method should be overriden. (And this rule will also be checked via the `strict=True` flag.)
you know that it is calling precisely `AbstractPolynomialInterpolation.degree`, and not an override in some subclass. This is excellent for code readability. Thus we get the rule that no method should be overridden. (And this rule will also be checked via the `strict=True` flag.)

If we assume this, then we now find ourselves arriving at a conclusion: concrete means final. That is, once we have a concrete class (every abstract method/attribute defined in our ABCs is now overriden with an implementation, so we can instantiate this class), then it is now final (we're not allowed to re-override things, so subclassing is pointless). This is how we arrive at the abstract-or-final rule itself!
If we assume this, then we now find ourselves arriving at a conclusion: concrete means final. That is, once we have a concrete class (every abstract method/attribute defined in our ABCs is now overridden with an implementation, so we can instantiate this class), then it is now final (we're not allowed to re-override things, so subclassing is pointless). This is how we arrive at the abstract-or-final rule itself!

What about when you have an existing concrete class that you want to tweak just-a-little-bit? In this case, prefer composition over inheritance. Write a wrapper that forwards each method as appropriate. This is just as expressive, and means we keep these readable type-safe rules.

Expand Down
4 changes: 2 additions & 2 deletions equinox/_ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,7 @@ def filter_closure_convert(fn: Callable[_P, _T], *args, **kwargs) -> Callable[_P
- `fn`: The function to call. Will be called as `fun(*args, **kwargs)`.
- `args`, `kwargs`: Example arguments at which to call the function. The function is
not actually evaluated on these arguments; all JAX arrays are subsituted for
not actually evaluated on these arguments; all JAX arrays are substituted for
tracers. Note that Python builtins (`bool`, `int`, `float`, `complex`) are
not substituted for tracers and are passed through as-is.
Expand Down Expand Up @@ -668,7 +668,7 @@ def f(x, y):
else:
fn = cast(Callable[_P, _T], fn)
closed_jaxpr, out_dynamic_struct, out_static = filter_make_jaxpr(fn)(
*args, # pyright: ignore
*args,
**kwargs,
)
jaxpr = closed_jaxpr.jaxpr
Expand Down
4 changes: 2 additions & 2 deletions equinox/_better_abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class ConcreteX(AbstractX):
!!! Info
`AbstractVar` does not create a dataclass field. This affects the order of
`__init__` argments. E.g.
`__init__` arguments. E.g.
```python
class AbstractX(Module):
attr1: AbstractVar[bool]
Expand Down Expand Up @@ -109,7 +109,7 @@ class ConcreteX(AbstractX):
!!! Info
`AbstractClassVar` does not create a dataclass field. This affects the order
of `__init__` argments. E.g.
of `__init__` arguments. E.g.
```python
class AbstractX(Module):
attr1: AbstractClassVar[bool]
Expand Down
8 changes: 4 additions & 4 deletions equinox/_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ def _is_struct(x):
def filter_pure_callback(
callback,
*args,
result_shape_dtypes, # pyright: ignore
vectorized=False, # pyright: ignore
result_shape_dtypes,
vectorized=False,
**kwargs,
):
"""Calls a Python function inside a JIT region. As `jax.pure_callback` but accepts
Expand All @@ -38,13 +38,13 @@ def filter_pure_callback(

def _callback(_dynamic):
_args, _kwargs = combine(_dynamic, static)
_out = callback(*_args, **_kwargs) # pyright: ignore
_out = callback(*_args, **_kwargs)
_dynamic_out, _static_out = partition(_out, is_array)
if not tree_equal(_static_out, static_struct):
raise ValueError("Callback did not return matching static elements")
return _dynamic_out

dynamic_out = jax.pure_callback( # pyright: ignore
dynamic_out = jax.pure_callback(
_callback, dynamic_struct, dynamic, vectorized=vectorized
)
return combine(dynamic_out, static_struct)
10 changes: 4 additions & 6 deletions equinox/_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ class EnumerationItem(Module):
def __init__(self, x):
pass

def __eq__(self, other) -> Bool[Array, ""]: # pyright: ignore
def __eq__(self, other) -> Bool[Array, ""]:
if isinstance(other, EnumerationItem):
if self._enumeration is other._enumeration:
with jax.ensure_compile_time_eval():
Expand Down Expand Up @@ -209,11 +209,9 @@ def __getitem__(cls, item) -> str: ...

def __len__(cls) -> int: ...

class Enumeration( # pyright: ignore
enum.Enum, EnumerationItem, metaclass=_Sequence
):
_name_to_item: ClassVar[dict[str, EnumerationItem]]
_index_to_message: ClassVar[list[str]]
class Enumeration(enum.Enum, EnumerationItem, metaclass=_Sequence): # pyright: ignore
_name_to_item: ClassVar[dict[str, EnumerationItem]] # pyright: ignore
_index_to_message: ClassVar[list[str]] # pyright: ignore
_base_offsets: ClassVar[dict["Enumeration", int]]

@classmethod
Expand Down
12 changes: 5 additions & 7 deletions equinox/_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _nan_like(x: Union[Array, np.ndarray]) -> Union[Array, np.ndarray]:
_frames_msg = f"""
-------------------
Opening a breakpoint with {EQX_ON_ERROR_BREAKPOINT_FRAMES} frames. You can control this
Opening a breakpoint with {EQX_ON_ERROR_BREAKPOINT_FRAMES} frames. You can control this
value by setting the environment variable `EQX_ON_ERROR_BREAKPOINT_FRAMES=<some value>`.
(Note that setting large values of this number may lead to crashes at trace time; see
`https://docs.kidger.site/equinox/api/errors/#equinox.error_if` for more information.)
Expand Down Expand Up @@ -109,11 +109,11 @@ def tpu_msg(_out, _index):
return jtu.tree_map(_nan_like, _out)

def handle_error(): # pyright: ignore
out = jax.pure_callback(raises, struct, index) # pyright: ignore
out = jax.pure_callback(raises, struct, index)
# If we make it this far then we're on the TPU, which squelches runtime
# errors and returns dummy values instead.
# Fortunately, we're able to outsmart it!
return jax.pure_callback(tpu_msg, struct, out, index) # pyright: ignore
return jax.pure_callback(tpu_msg, struct, out, index)

struct = jax.eval_shape(lambda: x)
return lax.cond(pred, handle_error, lambda: x)
Expand All @@ -131,7 +131,7 @@ def to_nan(_index):

def handle_error():
index_struct = jax.eval_shape(lambda: index)
_index = jax.pure_callback( # pyright: ignore
_index = jax.pure_callback(
display_msg, index_struct, index, vectorized=True
)
# Support JAX with and without DCE behaviour on breakpoints.
Expand All @@ -146,9 +146,7 @@ def handle_error():
if EQX_ON_ERROR_BREAKPOINT_FRAMES is not None:
breakpoint_kwargs["num_frames"] = EQX_ON_ERROR_BREAKPOINT_FRAMES
_index = jax.debug.breakpoint(**breakpoint_kwargs)
return jax.pure_callback( # pyright: ignore
to_nan, struct, _index, vectorized=True
)
return jax.pure_callback(to_nan, struct, _index, vectorized=True)

struct = jax.eval_shape(lambda: x)
return lax.cond(pred, handle_error, lambda: x)
Expand Down
2 changes: 1 addition & 1 deletion equinox/_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def filter(
`pytree`. Each of its leaves should either be:
- `True`, in which case the leaf or subtree is kept;
- `False`, in which case the leaf or subtree is replaced with `replace`;
- a callable `Leaf -> bool`, in which case this is evaluted on the leaf or
- a callable `Leaf -> bool`, in which case this is evaluated on the leaf or
mapped over the subtree, and the leaf kept or replaced as appropriate.
- `inverse` switches the truthy/falsey behaviour: falsey results are kept and
truthy results are replaced.
Expand Down
4 changes: 2 additions & 2 deletions equinox/_make_jaxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _fn(*_dynamic_flat):
_out_dynamic, _out_static = partition(_out, is_array)
return _out_dynamic, Static(_out_static)

jaxpr, out_struct = jax.make_jaxpr(_fn, return_shape=True)(*dynamic_flat) # pyright: ignore
jaxpr, out_struct = jax.make_jaxpr(_fn, return_shape=True)(*dynamic_flat)
dynamic_out_struct, static_out = out_struct
static_out = static_out.value
return jaxpr, dynamic_out_struct, static_out
Expand Down Expand Up @@ -70,7 +70,7 @@ def filter_make_jaxpr(
The example arguments to be traced may be anything with `.shape` and `.dtype`
fields (typically JAX arrays, NumPy arrays, of `jax.ShapeDtypeStruct`s). All
other argments are treated statically. In particular, Python builtins (`bool`,
other arguments are treated statically. In particular, Python builtins (`bool`,
`int`, `float`, `complex`) are treated as static inputs; wrap them in JAX/NumPy
arrays if you would like them to be traced.
"""
Expand Down
10 changes: 5 additions & 5 deletions equinox/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ class AbstractFoo(eqx.Module, strict=True):

# Inherits from ABCMeta to support `eqx.{AbstractVar, AbstractClassVar}` and
# `abc.abstractmethod`.
class _ActualModuleMeta(ABCMeta): # pyright: ignore
class _ActualModuleMeta(ABCMeta):
# This method is called whenever you definite a module: `class Foo(eqx.Module): ...`
def __new__(
mcs,
Expand Down Expand Up @@ -366,7 +366,7 @@ def __post_init__(self, *args, **kwargs):
if post_init is None:
init = cls.__init__

@ft.wraps(init) # pyright: ignore
@ft.wraps(init)
def __init__(self, *args, **kwargs):
__tracebackhide__ = True
init(self, *args, **kwargs)
Expand All @@ -377,7 +377,7 @@ def __init__(self, *args, **kwargs):

cls.__init__ = __init__

# Assign `__doc__` in case it has been manually overriden:
# Assign `__doc__` in case it has been manually overridden:
# ```
# class Foo(eqx.Module):
# x: int
Expand Down Expand Up @@ -803,9 +803,9 @@ def _make_initable(
if wraps:
field_names = _wrapper_field_names
else:
field_names = {field.name for field in dataclasses.fields(cls)} # pyright: ignore
field_names = {field.name for field in dataclasses.fields(cls)}

class _InitableModule(cls, _Initable): # pyright: ignore
class _InitableModule(cls, _Initable):
pass

def __setattr__(self, name, value):
Expand Down
2 changes: 1 addition & 1 deletion equinox/_serialisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def run(..., load_path=None):
`filter_spec` should typically be a function `(File, Any) -> Any`, which takes
a file handle and a leaf from `like`, and either returns the corresponding
loaded leaf, or retuns the leaf from `like` unchanged.
loaded leaf, or returns the leaf from `like` unchanged.
It can also be a PyTree of such functions, in which case the PyTree structure
should be a prefix of `pytree`, and each function will be mapped over the
Expand Down
2 changes: 1 addition & 1 deletion equinox/_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def tree_equal(

def tree_flatten_one_level(
pytree: PyTree,
) -> tuple[list[PyTree], PyTreeDef]: # pyright: ignore
) -> tuple[list[PyTree], PyTreeDef]:
"""Returns the immediate subnodes of a PyTree node. If called on a leaf node then it
will return just that leaf.
Expand Down
10 changes: 5 additions & 5 deletions equinox/_unvmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def _unvmap_all_impl(x):


def _unvmap_all_abstract_eval(x):
return jax.core.ShapedArray(shape=(), dtype=jax.numpy.bool_.dtype) # pyright: ignore
return jax.core.ShapedArray(shape=(), dtype=jax.numpy.bool_.dtype)


def _unvmap_all_batch(x, batch_axes):
Expand All @@ -33,7 +33,7 @@ def _unvmap_all_batch(x, batch_axes):

unvmap_all_p.def_impl(_unvmap_all_impl)
unvmap_all_p.def_abstract_eval(_unvmap_all_abstract_eval)
batching.primitive_batchers[unvmap_all_p] = _unvmap_all_batch # pyright: ignore
batching.primitive_batchers[unvmap_all_p] = _unvmap_all_batch
mlir.register_lowering(
unvmap_all_p,
mlir.lower_fun(_unvmap_all_impl, multiple_results=False),
Expand All @@ -54,7 +54,7 @@ def _unvmap_any_impl(x):


def _unvmap_any_abstract_eval(x):
return jax.core.ShapedArray(shape=(), dtype=jax.numpy.bool_.dtype) # pyright: ignore
return jax.core.ShapedArray(shape=(), dtype=jax.numpy.bool_.dtype)


def _unvmap_any_batch(x, batch_axes):
Expand All @@ -64,7 +64,7 @@ def _unvmap_any_batch(x, batch_axes):

unvmap_any_p.def_impl(_unvmap_any_impl)
unvmap_any_p.def_abstract_eval(_unvmap_any_abstract_eval)
batching.primitive_batchers[unvmap_any_p] = _unvmap_any_batch # pyright: ignore
batching.primitive_batchers[unvmap_any_p] = _unvmap_any_batch
mlir.register_lowering(
unvmap_any_p,
mlir.lower_fun(_unvmap_any_impl, multiple_results=False),
Expand Down Expand Up @@ -95,7 +95,7 @@ def _unvmap_max_batch(x, batch_axes):

unvmap_max_p.def_impl(_unvmap_max_impl)
unvmap_max_p.def_abstract_eval(_unvmap_max_abstract_eval)
batching.primitive_batchers[unvmap_max_p] = _unvmap_max_batch # pyright: ignore
batching.primitive_batchers[unvmap_max_p] = _unvmap_max_batch
mlir.register_lowering(
unvmap_max_p,
mlir.lower_fun(_unvmap_max_impl, multiple_results=False),
Expand Down
Loading

0 comments on commit 97ac55a

Please sign in to comment.