Skip to content

Commit

Permalink
Add PyRight errors
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Sep 24, 2024
1 parent b8b61f5 commit 0eefc81
Show file tree
Hide file tree
Showing 11 changed files with 191 additions and 166 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ repos:
- id: check-yaml

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.4
rev: v0.6.7
hooks:
- id: ruff

- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.379
rev: v1.1.381
hooks:
- id: pyright

Expand Down
11 changes: 11 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ disable = [
'too-many-branches',
'too-many-instance-attributes',
'too-many-locals',
'too-many-positional-arguments',
'too-many-return-statements',
'too-many-statements',
'too-many-try-statements',
Expand Down Expand Up @@ -158,7 +159,11 @@ venv = '.venv'
enableTypeIgnoreComments = false
reportImportCycles = true
reportCallInDefaultInitializer = true
reportConstantRedefinition = true
reportDeprecated = true
reportDuplicateImport = true
reportImplicitOverride = true
reportImplicitStringConcatenation = false
reportIncompatibleMethodOverride = true
reportIncompatibleVariableOverride = true
reportInconsistentConstructor = true
Expand All @@ -168,6 +173,7 @@ reportMissingSuperCall = true
reportMissingTypeArgument = true
reportOverlappingOverload = true
reportPrivateImportUsage = true
reportPropertyTypeMismatch = true
reportShadowedImports = true
reportUninitializedInstanceVariable = true
reportUnknownArgumentType = false
Expand All @@ -184,6 +190,11 @@ reportUntypedBaseClass = true
reportUntypedClassDecorator = true
reportUntypedFunctionDecorator = true
reportUntypedNamedTuple = true
reportUnusedCallResult = false
reportUnusedClass = true
reportUnusedExpression = true
reportUnusedFunction = true
reportUnusedVariable = true

[tool.mypy]
files = ['tjax', 'tests']
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


@pytest.fixture(autouse=True)
def _jax_enable64() -> Generator[None, None, None]:
def _jax_enable64() -> Generator[None]: # pyright: ignore
with enable_x64():
yield

Expand Down
6 changes: 3 additions & 3 deletions tjax/_src/dataclasses/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,11 @@ def f(x: type[Any], /) -> type[TDataclassInstance]:
dynamic_fields.append(field_info.name)
data_clz.static_fields = static_fields
data_clz.dynamic_fields = dynamic_fields
register_dataclass(data_clz, dynamic_fields, static_fields)
_ = register_dataclass(data_clz, dynamic_fields, static_fields)

# Register the dynamically-dispatched functions.
get_test_string.register(data_clz, get_dataclass_test_string)
get_relative_test_string.register(data_clz, get_relative_dataclass_test_string)
_ = get_test_string.register(data_clz, get_dataclass_test_string)
_ = get_relative_test_string.register(data_clz, get_relative_dataclass_test_string)
return data_clz


Expand Down
2 changes: 1 addition & 1 deletion tjax/_src/dataclasses/flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init_subclass__(cls,
kw_only: bool = False,
**kwargs: Any) -> None:
super().__init_subclass__(**kwargs, experimental_pytree=True)
dataclass(init=init, repr=repr, eq=eq, order=order, kw_only=kw_only)(cls)
_ = dataclass(init=init, repr=repr, eq=eq, order=order, kw_only=kw_only)(cls)


class DataClassModule(_DataClassModule):
Expand Down
6 changes: 3 additions & 3 deletions tjax/_src/display/display_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def _variables(value: Any) -> dict[str, Any]:
def _verify(value: Any,
seen: MutableSet[int],
key: str
) -> Generator[Tree | None, None, None]:
) -> Generator[Tree | None]:
if id(value) in seen:
yield _assemble(key, Text(f'<seen {id(value)}>', style=_seen_color))
return
Expand Down Expand Up @@ -362,7 +362,7 @@ def _show_array(tree: Tree, array: NumpyArray) -> None:
tree.children.append(display_generic(float(np.std(xarray)), seen=set(), key="deviation"))
return
if len(array.shape) == 0:
tree.add(_format_number(array[()]))
_ = tree.add(_format_number(array[()]))
return
table = Table(show_header=False,
show_edge=False,
Expand All @@ -376,4 +376,4 @@ def _show_array(tree: Tree, array: NumpyArray) -> None:
for j in range(array.shape[0]):
table.add_row(*(_format_number(array[j, i])
for i in range(array.shape[1])))
tree.add(table)
_ = tree.add(table)
10 changes: 5 additions & 5 deletions tjax/_src/display/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,18 @@

# Functions ----------------------------------------------------------------------------------------
def internal_print_generic(*args: Any,
raise_on_nan: bool = True,
console: Console | None = None,
**kwargs: Any) -> None:
raise_on_nan: bool = True,
console: Console | None = None,
**kwargs: Any) -> None:
if console is None:
console = global_console
found_nan = False
root = Tree("", hide_root=True)
for value in args:
root.add(display_generic(value, seen=set()))
_ = root.add(display_generic(value, seen=set()))
found_nan = found_nan or raise_on_nan and 'nan' in str(root)
for key, value in kwargs.items():
root.add(display_generic(value, seen=set(), key=key))
_ = root.add(display_generic(value, seen=set(), key=key))
found_nan = found_nan or raise_on_nan and 'nan' in str(root)
console.print(root)
if found_nan:
Expand Down
2 changes: 1 addition & 1 deletion tjax/_src/math_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def create_diagonal_array(m: T) -> T:
target_index = (*index, slice(None, None, n + 1))
source_values = m[*index, :] # type: ignore[arg-type]
if isinstance(retval, jax.Array):
retval.at[target_index].set(source_values)
retval = retval.at[target_index].set(source_values)
else:
retval[target_index] = source_values
return xp.reshape(retval, s)
14 changes: 7 additions & 7 deletions tjax/_src/shims.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def jit(func: F, **kwargs: Any) -> F:
This ensures that abstract methods stay abstract, method overrides remain overrides.
"""
retval = jax.jit(func, **kwargs)
update_wrapper(retval, func, all_wrapper_assignments)
_ = update_wrapper(retval, func, all_wrapper_assignments)
# Return type is fixed by https://github.com/NeilGirdhar/jax/tree/jit_annotation.
return retval # type: ignore[return-value] # pyright: ignore

Expand All @@ -47,7 +47,7 @@ def __init__(self,
super().__init__()
static_argnums = tuple(sorted(static_argnums))
self.vjp = jax.custom_vjp(func, nondiff_argnums=static_argnums)
update_wrapper(self, func, all_wrapper_assignments)
_ = update_wrapper(self, func, all_wrapper_assignments)

def defvjp(self,
fwd: Callable[P, tuple[R_co, Any]],
Expand Down Expand Up @@ -78,7 +78,7 @@ def __init__(self,
super().__init__()
static_argnums = tuple(sorted(static_argnums))
self.vjp = jax.custom_vjp(func, nondiff_argnums=static_argnums)
update_wrapper(self, func, all_wrapper_assignments)
_ = update_wrapper(self, func, all_wrapper_assignments)

def defvjp(self,
fwd: Callable[Concatenate[U, P], tuple[R_co, Any]],
Expand Down Expand Up @@ -124,15 +124,15 @@ def __init__(self,
super().__init__()
nondiff_argnums = tuple(sorted(nondiff_argnums))
self.jvp = jax.custom_jvp(func, nondiff_argnums=nondiff_argnums)
update_wrapper(self, func, all_wrapper_assignments)
_ = update_wrapper(self, func, all_wrapper_assignments)

def defjvp(self, jvp: Callable[..., tuple[R_co, R_co]]) -> None:
"""Implement the custom forward pass of the custom derivative.
Args:
jvp: The custom forward pass.
"""
self.jvp.defjvp(jvp)
_ = self.jvp.defjvp(jvp)

def __call__(self, /, *args: P.args, **kwargs: P.kwargs) -> R_co:
return self.jvp(*args, **kwargs)
Expand All @@ -155,15 +155,15 @@ def __init__(self,
super().__init__()
nondiff_argnums = tuple(sorted(nondiff_argnums))
self.jvp = jax.custom_jvp(func, nondiff_argnums=nondiff_argnums)
update_wrapper(self, func, all_wrapper_assignments)
_ = update_wrapper(self, func, all_wrapper_assignments)

def defjvp(self, jvp: Callable[Concatenate[U, P], tuple[R_co, R_co]]) -> None:
"""Implement the custom forward pass of the custom derivative.
Args:
jvp: The custom forward pass.
"""
self.jvp.defjvp(jvp)
_ = self.jvp.defjvp(jvp)

def __call__(self, u: U, /, *args: P.args, **kwargs: P.kwargs) -> R_co:
return self.jvp(u, *args, **kwargs)
Expand Down
3 changes: 1 addition & 2 deletions tjax/_src/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,7 @@ def _(actual: list[Any] | tuple[Any], rtol: float, atol: float) -> str:
is_named_tuple = not is_list and type(actual).__name__ != 'tuple'
return ((type(actual).__name__ if is_named_tuple else "")
+ ("[" if is_list else "(")
+ ", ".join(get_test_string(sub_actual, rtol, atol)
for i, sub_actual in enumerate(actual))
+ ", ".join(get_test_string(sub_actual, rtol, atol) for sub_actual in actual)
+ (',' if len(actual) == 1 else '')
+ ("]" if is_list else ")"))

Expand Down
Loading

0 comments on commit 0eefc81

Please sign in to comment.