Skip to content

Commit

Permalink
BUG: SINDyDerivative set_params() missed sklearn interface
Browse files Browse the repository at this point in the history
Was resulting in:
```
> self.kwargs
{"kind": "spline", "s": .01, "kwargs":{"kind": finitedifference, "k": 1}}
```

This is because set_params has to mimic the exact functionality of __init__.

See https://scikit-learn.org/stable/developers/develop.html#parameters-and-init
  • Loading branch information
Jacob-Stevens-Haas committed Jun 18, 2022
1 parent cbe47de commit 7b16aec
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pysindy/differentiation/sindy_derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def set_params(self, **params):
# Simple optimization to gain speed (inspect is slow)
return self
else:
self.kwargs.update(params)
self.kwargs.update(params["kwargs"])

return self

Expand Down
7 changes: 7 additions & 0 deletions test/differentiation/test_differentiation_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,13 @@ def test_wrapper_equivalence_with_dxdt(data, derivative_kws):
)


def test_sindy_derivative_kwarg_update():
method = SINDyDerivative(kind="spectral", foo=2)
method.set_params(kwargs={"kind": "spline", "foo": 1})
assert method.kwargs["kind"] == "spline"
assert method.kwargs["foo"] == 1


@pytest.mark.parametrize(
"data, derivative_kws",
[
Expand Down

0 comments on commit 7b16aec

Please sign in to comment.