Skip to content

Commit

Permalink
Use array interface for testing numpy arrays. (dmlc#9602)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Oct 8, 2023
1 parent 032bcc5 commit 88d3db0
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
3 changes: 2 additions & 1 deletion python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2399,6 +2399,7 @@ def inplace_predict(
_is_cudf_df,
_is_cupy_array,
_is_list,
_is_np_array_like,
_is_pandas_df,
_is_pandas_series,
_is_tuple,
Expand Down Expand Up @@ -2428,7 +2429,7 @@ def inplace_predict(
f"got {data.shape[1]}"
)

if isinstance(data, np.ndarray):
if _is_np_array_like(data):
from .data import _ensure_np_dtype

data, _ = _ensure_np_dtype(data, data.dtype)
Expand Down
12 changes: 6 additions & 6 deletions python-package/xgboost/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ def _is_scipy_coo(data: DataType) -> bool:
return isinstance(data, scipy.sparse.coo_matrix)


def _is_numpy_array(data: DataType) -> bool:
return isinstance(data, (np.ndarray, np.matrix))
def _is_np_array_like(data: DataType) -> bool:
return hasattr(data, "__array_interface__")


def _ensure_np_dtype(
Expand Down Expand Up @@ -1051,7 +1051,7 @@ def dispatch_data_backend(
return _from_scipy_csr(
data.tocsr(), missing, threads, feature_names, feature_types
)
if _is_numpy_array(data):
if _is_np_array_like(data):
return _from_numpy_array(
data, missing, threads, feature_names, feature_types, data_split_mode
)
Expand Down Expand Up @@ -1194,7 +1194,7 @@ def dispatch_meta_backend(
if _is_tuple(data):
_meta_from_tuple(data, name, dtype, handle)
return
if _is_numpy_array(data):
if _is_np_array_like(data):
_meta_from_numpy(data, name, dtype, handle)
return
if _is_pandas_df(data):
Expand Down Expand Up @@ -1281,7 +1281,7 @@ def _proxy_transform(
return _transform_dlpack(data), None, feature_names, feature_types
if _is_list(data) or _is_tuple(data):
data = np.array(data)
if _is_numpy_array(data):
if _is_np_array_like(data):
data, _ = _ensure_np_dtype(data, data.dtype)
return data, None, feature_names, feature_types
if _is_scipy_csr(data):
Expand Down Expand Up @@ -1331,7 +1331,7 @@ def dispatch_proxy_set_data(
if not allow_host:
raise err

if _is_numpy_array(data):
if _is_np_array_like(data):
_check_data_shape(data)
proxy._set_data_from_array(data) # pylint: disable=W0212
return
Expand Down

0 comments on commit 88d3db0

Please sign in to comment.