From 298ac3ac3b52f4d22e7c6d13f47f27a2c8bbd75f Mon Sep 17 00:00:00 2001 From: Weizheng Lu Date: Thu, 7 Sep 2023 10:23:22 +0800 Subject: [PATCH] fix style --- python/xorbits/numpy/__init__.py | 2 +- .../xorbits/numpy/mars_adapters/__init__.py | 1 + python/xorbits/numpy/mars_adapters/core.py | 4 ++ .../mars_adapters/tests/test_mars_adapters.py | 4 ++ python/xorbits/numpy/numpy_adapters/core.py | 2 + python/xorbits/numpy/special/__init__.py | 47 +++++++++++++++++++ 6 files changed, 59 insertions(+), 1 deletion(-) create mode 100644 python/xorbits/numpy/special/__init__.py diff --git a/python/xorbits/numpy/__init__.py b/python/xorbits/numpy/__init__.py index dac6046bb..7c83e4e91 100644 --- a/python/xorbits/numpy/__init__.py +++ b/python/xorbits/numpy/__init__.py @@ -95,7 +95,7 @@ def _install(): del warnings -from . import fft, linalg, random +from . import fft, linalg, random, special from .core import ndarray diff --git a/python/xorbits/numpy/mars_adapters/__init__.py b/python/xorbits/numpy/mars_adapters/__init__.py index afced4ee2..bb5ae9dfe 100644 --- a/python/xorbits/numpy/mars_adapters/__init__.py +++ b/python/xorbits/numpy/mars_adapters/__init__.py @@ -21,6 +21,7 @@ MARS_TENSOR_MAGIC_METHODS, MARS_TENSOR_OBJECTS, MARS_TENSOR_RANDOM_CALLABLES, + MARS_TENSOR_SPECIAL_CALLABLES, ) diff --git a/python/xorbits/numpy/mars_adapters/core.py b/python/xorbits/numpy/mars_adapters/core.py index 8560d11aa..36ac5b6b7 100644 --- a/python/xorbits/numpy/mars_adapters/core.py +++ b/python/xorbits/numpy/mars_adapters/core.py @@ -16,6 +16,7 @@ from typing import Any, Callable, Dict, List, Optional, Set import numpy +import scipy from ...core.adapter import MARS_TENSOR_TYPE, mars_tensor, wrap_mars_callable @@ -54,6 +55,9 @@ def _collect_module_callables( MARS_TENSOR_LINALG_CALLABLES: Dict[str, Callable] = _collect_module_callables( mars_tensor.linalg, numpy.linalg ) +MARS_TENSOR_SPECIAL_CALLABLES: Dict[str, Callable] = _collect_module_callables( + mars_tensor.special, scipy.special +) def _collect_tensor_objects(): diff --git a/python/xorbits/numpy/mars_adapters/tests/test_mars_adapters.py b/python/xorbits/numpy/mars_adapters/tests/test_mars_adapters.py index d5eea980d..6b04607f8 100644 --- a/python/xorbits/numpy/mars_adapters/tests/test_mars_adapters.py +++ b/python/xorbits/numpy/mars_adapters/tests/test_mars_adapters.py @@ -62,6 +62,10 @@ def test_random(): assert isinstance(np.random.standard_normal(10), DataRef) +def test_special(): + assert isinstance(np.special.erf(np.linspace(-3, 3)), DataRef) + + def test_objects(): assert isinstance(np.c_[np.array([1, 2, 3]), np.array([4, 5, 6])], DataRef) diff --git a/python/xorbits/numpy/numpy_adapters/core.py b/python/xorbits/numpy/numpy_adapters/core.py index ce38db554..2a05e5b4e 100644 --- a/python/xorbits/numpy/numpy_adapters/core.py +++ b/python/xorbits/numpy/numpy_adapters/core.py @@ -19,6 +19,7 @@ from typing import Any, Callable, Dict, Type import numpy as np +import scipy from ..._mars.core import Entity as MarsEntity from ...core import DataType @@ -145,3 +146,4 @@ def collect_numpy_module_members(np_mod: ModuleType) -> Dict[str, Any]: NUMPY_LINALG_MEMBERS = collect_numpy_module_members(np.linalg) NUMPY_FFT_MEMBERS = collect_numpy_module_members(np.fft) NUMPY_RANDOM_MEMBERS = collect_numpy_module_members(np.random) +NUMPY_SPECIAL_MEMBERS = collect_numpy_module_members(scipy.special) diff --git a/python/xorbits/numpy/special/__init__.py b/python/xorbits/numpy/special/__init__.py new file mode 100644 index 000000000..e25a9a5b9 --- /dev/null +++ b/python/xorbits/numpy/special/__init__.py @@ -0,0 +1,47 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +from ...core.utils.fallback import unimplemented_func + + +def __dir__(): + from ..mars_adapters import MARS_TENSOR_SPECIAL_CALLABLES + from ..numpy_adapters.core import NUMPY_SPECIAL_MEMBERS + + return list(MARS_TENSOR_SPECIAL_CALLABLES.keys()) + list( + NUMPY_SPECIAL_MEMBERS.keys() + ) + + +def __getattr__(name: str): + from ..mars_adapters import MARS_TENSOR_SPECIAL_CALLABLES + from ..numpy_adapters.core import NUMPY_SPECIAL_MEMBERS + + if name in MARS_TENSOR_SPECIAL_CALLABLES: + return MARS_TENSOR_SPECIAL_CALLABLES[name] + else: + import numpy + import scipy + + if not hasattr(numpy.linalg, name): + raise AttributeError(name) + elif name in NUMPY_SPECIAL_MEMBERS: + return NUMPY_SPECIAL_MEMBERS[name] + else: # pragma: no cover + if inspect.ismethod(getattr(scipy.special, name)): + return unimplemented_func + else: + raise AttributeError(name)