diff --git a/dask_ml/_partial.py b/dask_ml/_partial.py index 943237336..3f9fb2cf9 100644 --- a/dask_ml/_partial.py +++ b/dask_ml/_partial.py @@ -95,7 +95,7 @@ def fit( if not hasattr(model, "partial_fit"): msg = "The class '{}' does not implement 'partial_fit'." - raise ValueError(msg.format(type(model))) + raise AttributeError(msg.format(type(model))) order = list(range(nblocks)) if shuffle_blocks: diff --git a/tests/test_incremental.py b/tests/test_incremental.py index 7867a3b51..2e67484d6 100644 --- a/tests/test_incremental.py +++ b/tests/test_incremental.py @@ -9,7 +9,7 @@ from dask.array.utils import assert_eq from scipy.sparse import csr_matrix from sklearn.base import clone -from sklearn.linear_model import SGDClassifier, SGDRegressor +from sklearn.linear_model import SGDClassifier, SGDRegressor, LinearRegression from sklearn.pipeline import make_pipeline import dask_ml.feature_extraction.text @@ -235,3 +235,25 @@ def test_incremental_sparse_inputs(): clf_output = clf.predict(X).astype(np.int64) assert_eq(clf_output, wrap_output, ignore_dtype=True) + + +def test_no_partial_fit(): + # Create data + n, d = 100, 10 + X_np = np.random.uniform(size=(n, d)) + y_np = np.random.uniform(size=n) + X_da = da.from_array(X_np, chunks=(n // 2, -1)) + y_da = da.from_array(y_np, chunks=n // 2) + + est = LinearRegression() + dask_est = Incremental(est) + + with pytest.raises(AttributeError, match="partial_fit"): + dask_est.fit(X_np, y_np) + with pytest.raises(AttributeError, match="partial_fit"): + dask_est.partial_fit(X_np, y_np) + + with pytest.raises(AttributeError, match="partial_fit"): + dask_est.fit(X_da, y_da) + with pytest.raises(AttributeError, match="partial_fit"): + dask_est.partial_fit(X_da, y_da) diff --git a/tests/test_partial.py b/tests/test_partial.py index 03e52ec84..4507ac28a 100644 --- a/tests/test_partial.py +++ b/tests/test_partial.py @@ -114,5 +114,5 @@ def test_bag(): def test_no_partial_fit_raises(): X, y = make_classification(chunks=50) - with pytest.raises(ValueError, match="RandomForestClassifier"): + with pytest.raises(AttributeError, match="does not implement 'partial_fit'"): fit(RandomForestClassifier(), X, y)