Skip to content

Commit

Permalink
Use array interface for testing numpy arrays.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Sep 21, 2023
1 parent 0080c97 commit 2e1d276
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
3 changes: 2 additions & 1 deletion python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2434,6 +2434,7 @@ def inplace_predict(
_is_cudf_df,
_is_cupy_array,
_is_list,
_is_numpy_array,
_is_pandas_df,
_is_pandas_series,
_is_tuple,
Expand Down Expand Up @@ -2463,7 +2464,7 @@ def inplace_predict(
f"got {data.shape[1]}"
)

if isinstance(data, np.ndarray):
if _is_numpy_array(data):
from .data import _ensure_np_dtype

data, _ = _ensure_np_dtype(data, data.dtype)
Expand Down
2 changes: 1 addition & 1 deletion python-package/xgboost/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def _is_scipy_coo(data: DataType) -> bool:


def _is_numpy_array(data: DataType) -> bool:
return isinstance(data, (np.ndarray, np.matrix))
return hasattr(data, "__array_interface__")


def _ensure_np_dtype(
Expand Down

0 comments on commit 2e1d276

Please sign in to comment.