From 7b16aec7035beb853791ec90b31ca95a820a9a6d Mon Sep 17 00:00:00 2001 From: Jake Stevens-Haas Date: Fri, 17 Jun 2022 18:22:16 -0700 Subject: [PATCH] BUG: SINDyDerivative set_params() missed sklearn interface Was resulting in: ``` > self.kwargs {"kind": "spline", "s": .01, "kwargs":{"kind": finitedifference, "k": 1}} ``` This is because set_params has to mimic the exact functionality of __init__. See https://scikit-learn.org/stable/developers/develop.html#parameters-and-init --- pysindy/differentiation/sindy_derivative.py | 2 +- test/differentiation/test_differentiation_methods.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/pysindy/differentiation/sindy_derivative.py b/pysindy/differentiation/sindy_derivative.py index 7ae11133b..17ace83eb 100644 --- a/pysindy/differentiation/sindy_derivative.py +++ b/pysindy/differentiation/sindy_derivative.py @@ -52,7 +52,7 @@ def set_params(self, **params): # Simple optimization to gain speed (inspect is slow) return self else: - self.kwargs.update(params) + self.kwargs.update(params["kwargs"]) return self diff --git a/test/differentiation/test_differentiation_methods.py b/test/differentiation/test_differentiation_methods.py index e24481aac..6bae3da42 100644 --- a/test/differentiation/test_differentiation_methods.py +++ b/test/differentiation/test_differentiation_methods.py @@ -262,6 +262,13 @@ def test_wrapper_equivalence_with_dxdt(data, derivative_kws): ) +def test_sindy_derivative_kwarg_update(): + method = SINDyDerivative(kind="spectral", foo=2) + method.set_params(kwargs={"kind": "spline", "foo": 1}) + assert method.kwargs["kind"] == "spline" + assert method.kwargs["foo"] == 1 + + @pytest.mark.parametrize( "data, derivative_kws", [