diff --git a/sklearn/manifold/_mds.py b/sklearn/manifold/_mds.py index e497c49a117be..aa641253d1c27 100644 --- a/sklearn/manifold/_mds.py +++ b/sklearn/manifold/_mds.py @@ -15,7 +15,7 @@ from ..isotonic import IsotonicRegression from ..metrics import euclidean_distances from ..utils import check_array, check_random_state, check_symmetric -from ..utils._param_validation import Hidden, Interval, StrOptions +from ..utils._param_validation import Hidden, Interval, StrOptions, validate_params from ..utils.parallel import Parallel, delayed @@ -167,6 +167,27 @@ def _smacof_single( return X, stress, it + 1 +@validate_params( + { + "dissimilarities": ["array-like"], + "metric": ["boolean"], + "n_components": [Interval(Integral, 1, None, closed="left")], + "init": ["array-like", None], + "n_init": [Interval(Integral, 1, None, closed="left")], + "n_jobs": [Integral, None], + "max_iter": [Interval(Integral, 1, None, closed="left")], + "verbose": ["verbose"], + "eps": [Interval(Real, 0, None, closed="left")], + "random_state": ["random_state"], + "return_n_iter": ["boolean"], + "normalized_stress": [ + "boolean", + StrOptions({"auto"}), + Hidden(StrOptions({"warn"})), + ], + }, + prefer_skip_nested_validation=True, +) def smacof( dissimilarities, *, @@ -204,7 +225,7 @@ def smacof( Parameters ---------- - dissimilarities : ndarray of shape (n_samples, n_samples) + dissimilarities : array-like of shape (n_samples, n_samples) Pairwise dissimilarities between the points. Must be symmetric. metric : bool, default=True @@ -218,7 +239,7 @@ def smacof( ``init`` is used to determine the dimensionality of the embedding space. - init : ndarray of shape (n_samples, n_components), default=None + init : array-like of shape (n_samples, n_components), default=None Starting configuration of the embedding to initialize the algorithm. By default, the algorithm is initialized with a randomly chosen array. diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index a5f57b581c8e8..313acb607b266 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -202,6 +202,7 @@ def _check_function_param_validation( "sklearn.linear_model.orthogonal_mp_gram", "sklearn.linear_model.ridge_regression", "sklearn.metrics.accuracy_score", + "sklearn.manifold.smacof", "sklearn.metrics.auc", "sklearn.metrics.average_precision_score", "sklearn.metrics.balanced_accuracy_score",