-
Notifications
You must be signed in to change notification settings - Fork 67
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into bug/df_index_column_pruning_issue
- Loading branch information
Showing
53 changed files
with
2,097 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# Copyright 2022-2023 XProbe Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
def _install(): | ||
"""Nothing required for installing sklearn.""" | ||
|
||
|
||
# Names exposed by a star-import of this package.
__all__ = [
    "cluster",
    "datasets",
    "decomposition",
    "ensemble",
    "linear_model",
    "metrics",
    "model_selection",
    "neighbors",
    "preprocessing",
    "semi_supervised",
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# Copyright 2022-2023 XProbe Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from ...core.utils.fallback import unimplemented_func | ||
|
||
|
||
def _install(): | ||
"""Nothing required for installing sklearn.""" | ||
|
||
|
||
def __dir__():  # pragma: no cover
    """Return the names registered in ``MARS_SKLEARN_CLUSTER_CALLABLES``.

    Returns:
        list: The keys of ``MARS_SKLEARN_CLUSTER_CALLABLES``.

    Raises:
        AttributeError: If ``sklearn`` cannot be imported. The original
            ``ImportError`` is attached as ``__cause__``.
    """
    try:
        # Imported only to probe availability; the name itself is unused.
        import sklearn  # noqa: F401
    except ImportError as err:
        raise AttributeError("sklearn is required but not installed.") from err
    # Deferred until sklearn is known to be importable.
    from .mars_adapters import MARS_SKLEARN_CLUSTER_CALLABLES

    return list(MARS_SKLEARN_CLUSTER_CALLABLES.keys())
|
||
|
||
def __getattr__(name: str):  # pragma: no cover
    """Resolve the module attribute *name* on demand.

    Args:
        name: Attribute being looked up on this module.

    Returns:
        The entry of ``MARS_SKLEARN_CLUSTER_CALLABLES`` stored under *name*,
        or the result of ``unimplemented_func()`` when ``sklearn.cluster``
        exposes *name* as a bound method.

    Raises:
        AttributeError: If ``sklearn`` cannot be imported, if
            ``sklearn.cluster`` has no attribute *name*, or if it has one
            that is neither registered nor a bound method.
    """
    import inspect

    try:
        import sklearn.cluster as sk_cluster
    except ImportError as err:
        raise AttributeError("sklearn is required but not installed.") from err
    # Deferred until sklearn is known to be importable.
    from .mars_adapters import MARS_SKLEARN_CLUSTER_CALLABLES

    if name in MARS_SKLEARN_CLUSTER_CALLABLES:
        return MARS_SKLEARN_CLUSTER_CALLABLES[name]
    if not hasattr(sk_cluster, name):
        raise AttributeError(name)
    # NOTE(review): inspect.ismethod() is true only for bound methods, so
    # plain functions and classes exported by sklearn.cluster fall through to
    # the AttributeError below -- confirm whether callable() was intended.
    if inspect.ismethod(getattr(sk_cluster, name)):
        return unimplemented_func()
    # Carry the attribute name so the error message is not empty.
    raise AttributeError(name)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Copyright 2022-2023 XProbe Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from .core import MARS_SKLEARN_CLUSTER_CALLABLES |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# Copyright 2022-2023 XProbe Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import sklearn.cluster as sk_cluster | ||
|
||
from ...._mars.learn import cluster as mars_cluster | ||
from ...._mars.learn.cluster import KMeans as MarsKMeans | ||
from ....core.utils.docstring import attach_module_callable_docstring | ||
from ...utils import SKLearnBase, _collect_module_callables, _install_cls_members | ||
|
||
|
||
# NOTE: no class docstring is written here on purpose. This module later calls
# attach_module_callable_docstring(KMeans, ...), which presumably supplies
# __doc__ -- verify before adding a literal docstring.
class KMeans(SKLearnBase):
    # Mars KMeans class associated with this wrapper.
    _marscls = MarsKMeans
|
||
|
||
# Maps each wrapper class defined in this module to its Mars class.
SKLEARN_CLUSTER_CLS_MAP = {KMeans: MarsKMeans}

# Built from the Mars cluster module and sklearn.cluster, with "register_op"
# passed as a member to skip.
MARS_SKLEARN_CLUSTER_CALLABLES = _collect_module_callables(
    mars_cluster, sk_cluster, skip_members=["register_op"]
)
# presumably installs members onto the mapped wrapper classes -- confirm
# against the helper's definition in ...utils.
_install_cls_members(
    SKLEARN_CLUSTER_CLS_MAP, MARS_SKLEARN_CLUSTER_CALLABLES, sk_cluster
)
attach_module_callable_docstring(KMeans, sk_cluster, sk_cluster.KMeans)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright 2022-2023 XProbe Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# Copyright 2022-2023 XProbe Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
try: | ||
import sklearn | ||
except ImportError: # pragma: no cover | ||
sklearn = None | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from .... import numpy as xnp | ||
from .. import KMeans | ||
|
||
# Shared fixtures for the tests below: dimensions of the random inputs.
n_rows = 1000
n_clusters = 8
n_columns = 10
chunk_size = 200
# Fixed seed so X and X_new are reproducible across runs.
rs = xnp.random.RandomState(0)
X = rs.rand(n_rows, n_columns, chunk_size=chunk_size)
X_new = rs.rand(n_rows, n_columns, chunk_size=chunk_size)
|
||
|
||
@pytest.mark.skipif(sklearn is None, reason="scikit-learn not installed")
def test_doc():
    """Check the attribution suffix on the KMeans class and ``fit`` docstrings."""
    docstring = KMeans.__doc__
    assert docstring is not None and docstring.endswith(
        "This docstring was copied from sklearn.cluster."
    )

    docstring = KMeans.fit.__doc__
    assert docstring is not None and docstring.endswith(
        "This docstring was copied from sklearn.cluster._kmeans.KMeans."
    )
|
||
|
||
@pytest.mark.skipif(sklearn is None, reason="scikit-learn not installed")
def test_kmeans_cluster():
    """Fit KMeans on ``X``, predict on ``X_new``, and check result shapes."""
    kms = KMeans(n_clusters=n_clusters, random_state=0)
    kms.fit(X)
    predict = kms.predict(X_new).fetch()

    assert kms.n_clusters == n_clusters
    # One label per input row, one centre per cluster.
    assert np.shape(kms.labels_.fetch()) == (n_rows,)
    assert np.shape(kms.cluster_centers_.fetch()) == (n_clusters, n_columns)
    assert np.shape(predict) == (n_rows,)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# Copyright 2022-2023 XProbe Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
def _install(): | ||
"""Nothing required for installing sklearn.""" | ||
|
||
|
||
def __dir__():  # pragma: no cover
    """Return the names registered in ``MARS_SKLEARN_DATASETS_CALLABLES``.

    Returns:
        list: The keys of ``MARS_SKLEARN_DATASETS_CALLABLES``.

    Raises:
        AttributeError: If ``sklearn`` cannot be imported. The original
            ``ImportError`` is attached as ``__cause__``.
    """
    try:
        # Imported only to probe availability; the name itself is unused.
        import sklearn  # noqa: F401
    except ImportError as err:
        raise AttributeError("sklearn is required but not installed.") from err
    # Deferred until sklearn is known to be importable.
    from .mars_adapters import MARS_SKLEARN_DATASETS_CALLABLES

    return list(MARS_SKLEARN_DATASETS_CALLABLES.keys())
|
||
|
||
def __getattr__(name: str):  # pragma: no cover
    """Resolve the module attribute *name* on demand.

    Args:
        name: Attribute being looked up on this module.

    Returns:
        The entry of ``MARS_SKLEARN_DATASETS_CALLABLES`` stored under *name*.

    Raises:
        AttributeError: If ``sklearn`` cannot be imported, if
            ``sklearn.datasets`` has no attribute *name*, or if it has one
            that is neither registered nor a bound method.
        NotImplementedError: If ``sklearn.datasets`` exposes *name* as a
            bound method that is not registered.
    """
    import inspect

    try:
        import sklearn.datasets as sk_datasets
    except ImportError as err:
        raise AttributeError("sklearn is required but not installed.") from err
    # Deferred until sklearn is known to be importable.
    from .mars_adapters import MARS_SKLEARN_DATASETS_CALLABLES

    if name in MARS_SKLEARN_DATASETS_CALLABLES:
        return MARS_SKLEARN_DATASETS_CALLABLES[name]
    if not hasattr(sk_datasets, name):
        raise AttributeError(name)
    # NOTE(review): inspect.ismethod() is true only for bound methods, so
    # plain functions exported by sklearn.datasets fall through to the
    # AttributeError below -- confirm whether callable() was intended.
    if inspect.ismethod(getattr(sk_datasets, name)):
        raise NotImplementedError("This function is not implemented yet.")
    # Carry the attribute name so the error message is not empty.
    raise AttributeError(name)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Copyright 2022-2023 XProbe Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from .core import MARS_SKLEARN_DATASETS_CALLABLES |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Copyright 2022-2023 XProbe Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import sklearn.datasets as sk_datasets | ||
|
||
from ...._mars.learn import datasets as mars_datasets | ||
from ...utils import _collect_module_callables | ||
|
||
# Built from the Mars datasets module and sklearn.datasets, with "register_op"
# passed as a member to skip.
MARS_SKLEARN_DATASETS_CALLABLES = _collect_module_callables(
    mars_datasets, sk_datasets, skip_members=["register_op"]
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright 2022-2023 XProbe Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
Oops, something went wrong.