Skip to content

Commit

Permalink
pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
pavelzw committed Jun 10, 2024
1 parent 539a1d5 commit 0895f23
Show file tree
Hide file tree
Showing 10 changed files with 400 additions and 24 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ repos:
language: system
types: [python]
require_serial: true
exclude: ^(tests|api-coverage-tests)/
# prettier
- id: prettier
name: prettier
Expand Down
320 changes: 318 additions & 2 deletions ndonnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,324 @@
import importlib.metadata
import warnings

from ._array import Array, array, from_spox_var
from ._build import (
build,
)
from ._data_types import (
CastError,
CoreType,
Floating,
Integral,
Nullable,
NullableFloating,
NullableIntegral,
NullableNumerical,
Numerical,
from_numpy_dtype,
bool,
float32,
float64,
int8,
int16,
int32,
int64,
nbool,
nfloat32,
nfloat64,
nint8,
nint16,
nint32,
nint64,
nuint8,
nuint16,
nuint32,
nuint64,
nutf8,
uint8,
uint16,
uint32,
uint64,
utf8,
)
from ._funcs import (
arange,
asarray,
empty,
empty_like,
eye,
full,
full_like,
linspace,
ones,
ones_like,
tril,
triu,
zeros,
zeros_like,
astype,
broadcast_arrays,
broadcast_to,
can_cast,
finfo,
iinfo,
result_type,
abs,
acos,
acosh,
add,
asin,
asinh,
atan,
atan2,
atanh,
bitwise_and,
bitwise_left_shift,
bitwise_invert,
bitwise_or,
bitwise_right_shift,
bitwise_xor,
ceil,
cos,
cosh,
divide,
equal,
exp,
expm1,
floor,
floor_divide,
greater,
greater_equal,
isfinite,
isinf,
isnan,
less,
less_equal,
log,
log1p,
log2,
log10,
logaddexp,
logical_and,
logical_not,
logical_or,
logical_xor,
multiply,
negative,
not_equal,
positive,
pow,
remainder,
round,
sign,
sin,
sinh,
square,
sqrt,
subtract,
tan,
tanh,
trunc,
matmul,
matrix_transpose,
concat,
expand_dims,
flip,
permute_dims,
reshape,
roll,
squeeze,
stack,
argmax,
argmin,
nonzero,
searchsorted,
where,
unique_all,
unique_counts,
unique_inverse,
unique_values,
argsort,
sort,
cumulative_sum,
max,
mean,
min,
prod,
clip,
std,
sum,
var,
all,
any,
take,
)
from ._constants import (
e,
inf,
nan,
pi,
)

try:
__version__ = importlib.metadata.version(__name__)
except importlib.metadata.PackageNotFoundError as e: # pragma: no cover
warnings.warn(f"Could not determine version of {__name__}\n{e!s}", stacklevel=2)
except importlib.metadata.PackageNotFoundError as err: # pragma: no cover
warnings.warn(f"Could not determine version of {__name__}\n{err!s}", stacklevel=2)
__version__ = "unknown"


__all__ = [
"Array",
"array",
"from_spox_var",
"e",
"inf",
"nan",
"pi",
"arange",
"asarray",
"empty",
"empty_like",
"eye",
"full",
"full_like",
"linspace",
"ones",
"ones_like",
"tril",
"triu",
"zeros",
"zeros_like",
"astype",
"take",
"broadcast_arrays",
"broadcast_to",
"can_cast",
"finfo",
"iinfo",
"result_type",
"abs",
"acos",
"acosh",
"add",
"asin",
"asinh",
"atan",
"atan2",
"atanh",
"bitwise_and",
"bitwise_left_shift",
"bitwise_invert",
"bitwise_or",
"bitwise_right_shift",
"bitwise_xor",
"ceil",
"cos",
"cosh",
"divide",
"equal",
"exp",
"expm1",
"floor",
"floor_divide",
"greater",
"greater_equal",
"isfinite",
"isinf",
"isnan",
"less",
"less_equal",
"log",
"log1p",
"log2",
"log10",
"logaddexp",
"logical_and",
"logical_not",
"logical_or",
"logical_xor",
"multiply",
"negative",
"not_equal",
"positive",
"pow",
"remainder",
"round",
"sign",
"sin",
"sinh",
"square",
"sqrt",
"subtract",
"tan",
"tanh",
"trunc",
"matmul",
"matrix_transpose",
"concat",
"expand_dims",
"flip",
"permute_dims",
"reshape",
"roll",
"squeeze",
"stack",
"argmax",
"argmin",
"nonzero",
"searchsorted",
"where",
"unique_all",
"unique_counts",
"unique_inverse",
"unique_values",
"argsort",
"sort",
"cumulative_sum",
"max",
"mean",
"min",
"prod",
"clip",
"std",
"sum",
"var",
"all",
"any",
"build",
"bool",
"utf8",
"float32",
"float64",
"int8",
"int16",
"int32",
"int64",
"uint8",
"uint16",
"uint32",
"uint64",
"nbool",
"nutf8",
"nfloat32",
"nfloat64",
"nint8",
"nint16",
"nint32",
"nint64",
"nuint8",
"nuint16",
"nuint32",
"nuint64",
"NullableNumerical",
"Numerical",
"NullableFloating",
"Floating",
"NullableIntegral",
"Nullable",
"Integral",
"CoreType",
"CastError",
"promote_nullable",
"from_numpy_dtype",
]
4 changes: 2 additions & 2 deletions ndonnx/_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ def __ne__(self, other):
return ndx.not_equal(self, other)

@property
def mT(self) -> ndx.Array:
def mT(self) -> ndx.Array: # noqa: N802
"""Transpose of a matrix (or a stack of matrices).
If an array instance has fewer than two dimensions, an error should be raised.
Expand All @@ -496,7 +496,7 @@ def size(self) -> ndx.Array:
return ndx.prod(self.shape)

@property
def T(self) -> ndx.Array:
def T(self) -> ndx.Array: # noqa: N802
"""Transpose of the array.
The array instance must be two-dimensional.
Expand Down
2 changes: 1 addition & 1 deletion ndonnx/_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def _get_dtype(dtype: str, version: int) -> dtypes.StructType | dtypes.CoreType:


def _extract_output_names(
output_dtypes: dict[str, dtypes.StructType | dtypes.CoreType]
output_dtypes: dict[str, dtypes.StructType | dtypes.CoreType],
) -> dict[str, ndx.CoreType]:
"""Given a dictionary mapping output names to their data types, extract the
underlying fully qualified names and their CoreTypes.
Expand Down
2 changes: 0 additions & 2 deletions ndonnx/_core/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@


class CoreOperationsImpl(OperationsBlock):

# elementwise.py

def abs(self, x):
Expand Down Expand Up @@ -522,7 +521,6 @@ def where_dtype_agnostic(a: ndx.Array, b: ndx.Array) -> ndx.Array:

# set.py
def unique_all(self, x):

new_dtype = x.dtype
if isinstance(x.dtype, dtypes.Integral) or x.dtype in (
dtypes.bool,
Expand Down
1 change: 0 additions & 1 deletion ndonnx/_corearray.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ def __setitem__(
def _normalise_index(
self, index: IndexType
) -> _CoreArray | tuple[ScalarIndexType, ...]:

if isinstance(index, _CoreArray):
return index
else:
Expand Down
4 changes: 1 addition & 3 deletions ndonnx/_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,7 @@ def collect_lazy_arguments(obj):
else:
if obj.dtype in (ndx.utf8, ndx.nutf8):
# Lazy variant due to onnxruntime bug
lazy = ndx.array(
shape=obj._static_shape, dtype=obj.dtype
) # type: ignore
lazy = ndx.array(shape=obj._static_shape, dtype=obj.dtype) # type: ignore

# disassemble the array into its core_arrays
_flatten(obj, lazy, flattened_inference_inputs)
Expand Down
Loading

0 comments on commit 0895f23

Please sign in to comment.