Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Oct 7, 2024
1 parent e259f59 commit 8943048
Show file tree
Hide file tree
Showing 17 changed files with 167 additions and 244 deletions.
36 changes: 19 additions & 17 deletions cunumeric/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,11 @@ def wrapper(*args: Any, **kwargs: Any) -> R:

# Convert relevant arguments to cuNumeric ndarrays
args = tuple(
convert_to_cunumeric_ndarray(arg)
if idx in indices and arg is not None
else arg
(
convert_to_cunumeric_ndarray(arg)
if idx in indices and arg is not None
else arg
)
for (idx, arg) in enumerate(args)
)
for k, v in kwargs.items():
Expand Down Expand Up @@ -2290,16 +2292,14 @@ def clip(
min = (
min
if min is not None
else np.iinfo(self.dtype).min
if self.dtype.kind == "i"
else -np.inf
else (
np.iinfo(self.dtype).min if self.dtype.kind == "i" else -np.inf
)
)
max = (
max
if max is not None
else np.iinfo(self.dtype).max
if self.dtype.kind == "i"
else np.inf
else np.iinfo(self.dtype).max if self.dtype.kind == "i" else np.inf
)
args = (
np.array(min, dtype=self.dtype),
Expand Down Expand Up @@ -2552,9 +2552,7 @@ def _diag_helper(
res_dtype = (
dtype
if dtype is not None
else out.dtype
if out is not None
else a.dtype
else out.dtype if out is not None else a.dtype
)
a = a._maybe_convert(res_dtype, (a,))
if out is not None and out.shape != out_shape:
Expand Down Expand Up @@ -4294,11 +4292,15 @@ def _perform_unary_op(
else:
out = ndarray(
shape=out_shape,
dtype=src.dtype
if src.dtype.kind != "c"
else np.dtype(np.float32)
if src.dtype == np.dtype(np.complex64)
else np.dtype(np.float64),
dtype=(
src.dtype
if src.dtype.kind != "c"
else (
np.dtype(np.float32)
if src.dtype == np.dtype(np.complex64)
else np.dtype(np.float64)
)
),
inputs=(src,),
)

Expand Down
6 changes: 2 additions & 4 deletions cunumeric/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,14 +273,12 @@ class _CunumericSharedLib:
CUNUMERIC_ZIP: int

@abstractmethod
def cunumeric_has_curand(self) -> int:
...
def cunumeric_has_curand(self) -> int: ...

@abstractmethod
def cunumeric_register_reduction_op(
self, type_uid: int, elem_type_code: int
) -> None:
...
) -> None: ...


# Load the cuNumeric library first so we have a shard object that
Expand Down
3 changes: 1 addition & 2 deletions cunumeric/coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ def filter_namespace(


class AnyCallable(Protocol):
def __call__(self, *args: Any, **kwargs: Any) -> Any:
...
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...


@dataclass(frozen=True)
Expand Down
8 changes: 5 additions & 3 deletions cunumeric/deferred.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,11 @@ def wrapper(*args: Any, **kwargs: Any) -> R:
# Convert relevant arguments to DeferredArrays
self = args[0]
args = tuple(
self.runtime.to_deferred_array(arg)
if idx in indices and arg is not None
else arg
(
self.runtime.to_deferred_array(arg)
if idx in indices and arg is not None
else arg
)
for (idx, arg) in enumerate(args)
)
for k, v in kwargs.items():
Expand Down
40 changes: 22 additions & 18 deletions cunumeric/eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1473,17 +1473,21 @@ def unary_op(
func(
rhs.array,
out=self.array,
where=where
if not isinstance(where, EagerArray)
else where.array,
where=(
where
if not isinstance(where, EagerArray)
else where.array
),
)
else:
func(
rhs.array,
out=(self.array, *(out.array for out in multiout)),
where=where
if not isinstance(where, EagerArray)
else where.array,
where=(
where
if not isinstance(where, EagerArray)
else where.array
),
)
elif op == UnaryOpCode.CLIP:
np.clip(rhs.array, out=self.array, a_min=args[0], a_max=args[1])
Expand Down Expand Up @@ -1542,9 +1546,9 @@ def unary_reduction(
out=self.array,
axis=orig_axis,
keepdims=keepdims,
where=where
if not isinstance(where, EagerArray)
else where.array,
where=(
where if not isinstance(where, EagerArray) else where.array
),
**kws,
)
elif op == UnaryRedCode.SUM_SQUARES:
Expand All @@ -1553,9 +1557,9 @@ def unary_reduction(
squared,
out=self.array,
axis=orig_axis,
where=where
if not isinstance(where, EagerArray)
else where.array,
where=(
where if not isinstance(where, EagerArray) else where.array
),
keepdims=keepdims,
)
elif op == UnaryRedCode.VARIANCE:
Expand All @@ -1565,9 +1569,9 @@ def unary_reduction(
np.sum(
squares,
axis=orig_axis,
where=where
if not isinstance(where, EagerArray)
else where.array,
where=(
where if not isinstance(where, EagerArray) else where.array
),
keepdims=keepdims,
out=self.array,
)
Expand Down Expand Up @@ -1607,9 +1611,9 @@ def binary_op(
rhs1.array,
rhs2.array,
out=self.array,
where=where
if not isinstance(where, EagerArray)
else where.array,
where=(
where if not isinstance(where, EagerArray) else where.array
),
)

def binary_reduction(
Expand Down
8 changes: 4 additions & 4 deletions cunumeric/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -3086,17 +3086,17 @@ def flatnonzero(a: ndarray) -> ndarray:


@overload
def where(a: npt.ArrayLike | ndarray, x: None, y: None) -> tuple[ndarray, ...]:
...
def where(
a: npt.ArrayLike | ndarray, x: None, y: None
) -> tuple[ndarray, ...]: ...


@overload
def where(
a: npt.ArrayLike | ndarray,
x: npt.ArrayLike | ndarray,
y: npt.ArrayLike | ndarray,
) -> ndarray:
...
) -> ndarray: ...


# TODO(mpapadakis): @add_boilerplate should extend the types of array
Expand Down
3 changes: 1 addition & 2 deletions cunumeric/random/bitgenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,7 @@ def __init__(
)

@abstractproperty
def generatorType(self) -> BitGeneratorType:
...
def generatorType(self) -> BitGeneratorType: ...

def __del__(self) -> None:
if self.handle != 0:
Expand Down
Loading

0 comments on commit 8943048

Please sign in to comment.