Skip to content

Commit

Permalink
FEAT: Add xorbits.sklearn module (#716)
Browse files Browse the repository at this point in the history
  • Loading branch information
JiaYaobo authored Oct 17, 2023
1 parent c15148d commit 251d1ca
Show file tree
Hide file tree
Showing 53 changed files with 2,097 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/xorbits/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def _install():
from .lightgbm import _install as _install_lightgbm
from .numpy import _install as _install_numpy
from .pandas import _install as _install_pandas
from .sklearn import _install as _install_sklearn
from .web import _install as _install_web
from .xgboost import _install as _install_xgboost

Expand All @@ -34,6 +35,7 @@ def _install():
_install_xgboost()
_install_datasets()
_install_experimental()
_install_sklearn()


_install()
Expand Down
31 changes: 31 additions & 0 deletions python/xorbits/sklearn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def _install() -> None:
    """Nothing required for installing sklearn.

    Intentionally a no-op. ``xorbits/__init__.py`` imports this function as
    ``_install_sklearn`` and calls it from its own ``_install()``.
    """


# Names exported by ``from xorbits.sklearn import *``.
# NOTE(review): these look like sub-package names (``cluster`` and
# ``datasets`` are visible in this commit) -- confirm the others exist too.
__all__ = [
    "cluster",
    "datasets",
    "decomposition",
    "ensemble",
    "linear_model",
    "metrics",
    "model_selection",
    "neighbors",
    "preprocessing",
    "semi_supervised",
]
49 changes: 49 additions & 0 deletions python/xorbits/sklearn/cluster/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...core.utils.fallback import unimplemented_func


def _install() -> None:
    """Nothing required for installing sklearn.

    Intentionally a no-op: the body consists of this docstring only.
    """


def __dir__():  # pragma: no cover
    """List the names served from the cluster adapter table.

    Returns:
        The keys of ``MARS_SKLEARN_CLUSTER_CALLABLES`` as a list.

    Raises:
        AttributeError: if scikit-learn cannot be imported. The underlying
            ``ImportError`` is attached as ``__cause__``.
    """
    try:
        # Imported only to verify that scikit-learn is available.
        import sklearn  # noqa: F401
    except ImportError as err:
        raise AttributeError("sklearn is required but not installed.") from err
    # Deferred import: the adapter module imports ``sklearn.cluster`` at
    # module level, so it can only be loaded once the check above passed.
    from .mars_adapters import MARS_SKLEARN_CLUSTER_CALLABLES

    return list(MARS_SKLEARN_CLUSTER_CALLABLES.keys())


def __getattr__(name: str):  # pragma: no cover
    """Resolve ``xorbits.sklearn.cluster.<name>`` lazily (module ``__getattr__``).

    Args:
        name: attribute name being looked up on this module.

    Returns:
        The entry of ``MARS_SKLEARN_CLUSTER_CALLABLES`` registered under
        ``name``.

    Raises:
        AttributeError: if scikit-learn cannot be imported (the original
            ``ImportError`` is attached as ``__cause__``), or if ``name`` is
            not in the adapter table. The exception always carries ``name``
            so the failing lookup can be identified.
    """
    import inspect

    try:
        import sklearn.cluster as sk_cluster
    except ImportError as err:
        raise AttributeError("sklearn is required but not installed.") from err
    # Deferred import: the adapter module imports ``sklearn.cluster`` at
    # module level, so it can only be loaded once the check above passed.
    from .mars_adapters import MARS_SKLEARN_CLUSTER_CALLABLES

    if name in MARS_SKLEARN_CLUSTER_CALLABLES:
        return MARS_SKLEARN_CLUSTER_CALLABLES[name]

    # NOTE(review): ``inspect.ismethod`` is true only for bound methods, and
    # module-level sklearn attributes are functions, classes or submodules, so
    # this branch looks unreachable. Widening it (e.g. to ``callable``) would
    # make ``hasattr()`` on this module raise instead of returning False, so
    # the check is deliberately left as it was -- confirm the intent.
    if inspect.ismethod(getattr(sk_cluster, name, None)):
        # Presumably raises NotImplementedError -- confirm in
        # ``core.utils.fallback``.
        return unimplemented_func()
    raise AttributeError(name)
14 changes: 14 additions & 0 deletions python/xorbits/sklearn/cluster/mars_adapters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .core import MARS_SKLEARN_CLUSTER_CALLABLES
35 changes: 35 additions & 0 deletions python/xorbits/sklearn/cluster/mars_adapters/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sklearn.cluster as sk_cluster

from ...._mars.learn import cluster as mars_cluster
from ...._mars.learn.cluster import KMeans as MarsKMeans
from ....core.utils.docstring import attach_module_callable_docstring
from ...utils import SKLearnBase, _collect_module_callables, _install_cls_members


# Xorbits-facing ``KMeans`` wrapper around the Mars implementation.
# No docstring is written here: ``attach_module_callable_docstring`` is called
# on this class further down with ``sklearn.cluster.KMeans`` as the source.
class KMeans(SKLearnBase):
    # Mars estimator class this wrapper is tied to.
    # NOTE(review): presumably read by ``SKLearnBase`` to delegate calls --
    # confirm in ``xorbits/sklearn/utils``.
    _marscls = MarsKMeans


# Maps each xorbits wrapper class to the Mars class it wraps.
SKLEARN_CLUSTER_CLS_MAP = {KMeans: MarsKMeans}

# Lookup table used by ``__getattr__``/``__dir__`` of the ``cluster`` package
# (indexed there by attribute name). Collected from the Mars and scikit-learn
# cluster modules, with ``register_op`` explicitly skipped.
MARS_SKLEARN_CLUSTER_CALLABLES = _collect_module_callables(
    mars_cluster, sk_cluster, skip_members=["register_op"]
)
# Must run after the table above exists, since the table is passed in.
# NOTE(review): presumably installs members on the wrapper classes and
# registers them in the table -- confirm against ``_install_cls_members``.
_install_cls_members(
    SKLEARN_CLUSTER_CLS_MAP, MARS_SKLEARN_CLUSTER_CALLABLES, sk_cluster
)
# Attach documentation derived from ``sklearn.cluster.KMeans`` to the wrapper.
attach_module_callable_docstring(KMeans, sk_cluster, sk_cluster.KMeans)
13 changes: 13 additions & 0 deletions python/xorbits/sklearn/cluster/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
57 changes: 57 additions & 0 deletions python/xorbits/sklearn/cluster/tests/test_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

try:
import sklearn
except ImportError: # pragma: no cover
sklearn = None

import numpy as np
import pytest

from .... import numpy as xnp
from .. import KMeans

# Module-level fixtures shared by the tests below; built at import time.
n_rows = 1000
n_clusters = 8
n_columns = 10
chunk_size = 200
# Single seeded random state; ``X`` is drawn before ``X_new``, so the order of
# the two calls below determines their contents.
rs = xnp.random.RandomState(0)
# ``X`` is what the model is fitted on, ``X_new`` what it predicts on.
X = rs.rand(n_rows, n_columns, chunk_size=chunk_size)
X_new = rs.rand(n_rows, n_columns, chunk_size=chunk_size)


@pytest.mark.skipif(sklearn is None, reason="scikit-learn not installed")
def test_doc():
    """Check the docstrings attached to ``KMeans`` and ``KMeans.fit``."""
    class_doc = KMeans.__doc__
    assert class_doc is not None
    assert class_doc.endswith("This docstring was copied from sklearn.cluster.")

    fit_doc = KMeans.fit.__doc__
    assert fit_doc is not None
    assert fit_doc.endswith(
        "This docstring was copied from sklearn.cluster._kmeans.KMeans."
    )


@pytest.mark.skipif(sklearn is None, reason="scikit-learn not installed")
def test_kmeans_cluster():
    """Fit ``KMeans`` on ``X``, predict on ``X_new`` and check output shapes."""
    kms = KMeans(n_clusters=n_clusters, random_state=0)
    kms.fit(X)
    predicted = kms.predict(X_new).fetch()

    # The constructor argument must be preserved on the estimator.
    assert kms.n_clusters == n_clusters
    # One label per training row, one centre per cluster, one prediction per
    # row of the new data.
    assert np.shape(kms.labels_.fetch()) == (n_rows,)
    assert np.shape(kms.cluster_centers_.fetch()) == (n_clusters, n_columns)
    assert np.shape(predicted) == (n_rows,)
48 changes: 48 additions & 0 deletions python/xorbits/sklearn/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def _install() -> None:
    """Nothing required for installing sklearn.

    Intentionally a no-op: the body consists of this docstring only.
    """


def __dir__():  # pragma: no cover
    """List the names served from the datasets adapter table.

    Returns:
        The keys of ``MARS_SKLEARN_DATASETS_CALLABLES`` as a list.

    Raises:
        AttributeError: if scikit-learn cannot be imported. The underlying
            ``ImportError`` is attached as ``__cause__``.
    """
    try:
        # Imported only to verify that scikit-learn is available.
        import sklearn  # noqa: F401
    except ImportError as err:
        raise AttributeError("sklearn is required but not installed.") from err
    # Deferred import: the adapter module imports ``sklearn.datasets`` at
    # module level, so it can only be loaded once the check above passed.
    from .mars_adapters import MARS_SKLEARN_DATASETS_CALLABLES

    return list(MARS_SKLEARN_DATASETS_CALLABLES.keys())


def __getattr__(name: str):  # pragma: no cover
    """Resolve ``xorbits.sklearn.datasets.<name>`` lazily (module ``__getattr__``).

    Args:
        name: attribute name being looked up on this module.

    Returns:
        The entry of ``MARS_SKLEARN_DATASETS_CALLABLES`` registered under
        ``name``.

    Raises:
        AttributeError: if scikit-learn cannot be imported (the original
            ``ImportError`` is attached as ``__cause__``), or if ``name`` is
            not in the adapter table. The exception always carries ``name``
            so the failing lookup can be identified.
        NotImplementedError: if ``sklearn.datasets.<name>`` is a bound method
            that has no adapter.
    """
    import inspect

    try:
        import sklearn.datasets as sk_datasets
    except ImportError as err:
        raise AttributeError("sklearn is required but not installed.") from err
    # Deferred import: the adapter module imports ``sklearn.datasets`` at
    # module level, so it can only be loaded once the check above passed.
    from .mars_adapters import MARS_SKLEARN_DATASETS_CALLABLES

    if name in MARS_SKLEARN_DATASETS_CALLABLES:
        return MARS_SKLEARN_DATASETS_CALLABLES[name]

    # NOTE(review): ``inspect.ismethod`` is true only for bound methods, and
    # module-level sklearn attributes are functions, classes or submodules, so
    # this branch looks unreachable. Widening it (e.g. to ``callable``) would
    # make ``hasattr()`` on this module raise instead of returning False, so
    # the check is deliberately left as it was -- confirm the intent.
    if inspect.ismethod(getattr(sk_datasets, name, None)):
        raise NotImplementedError("This function is not implemented yet.")
    raise AttributeError(name)
14 changes: 14 additions & 0 deletions python/xorbits/sklearn/datasets/mars_adapters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .core import MARS_SKLEARN_DATASETS_CALLABLES
22 changes: 22 additions & 0 deletions python/xorbits/sklearn/datasets/mars_adapters/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sklearn.datasets as sk_datasets

from ...._mars.learn import datasets as mars_datasets
from ...utils import _collect_module_callables

# Lookup table used by ``__getattr__``/``__dir__`` of the ``datasets`` package
# (indexed there by attribute name). Collected from the Mars and scikit-learn
# datasets modules, with ``register_op`` explicitly skipped.
MARS_SKLEARN_DATASETS_CALLABLES = _collect_module_callables(
    mars_datasets, sk_datasets, skip_members=["register_op"]
)
13 changes: 13 additions & 0 deletions python/xorbits/sklearn/datasets/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Loading

0 comments on commit 251d1ca

Please sign in to comment.