From a02c3bf637edf715823861dba6c06560400a6777 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 5 Nov 2024 09:10:48 -0500 Subject: [PATCH] modified label --- src/nemos/basis.py | 6 +++++- tests/test_basis.py | 10 +++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index 44380bca..8751f273 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -560,7 +560,11 @@ def __init__( self._n_output_features = None self._input_shape = None - self._label = str(label) + if label is None: + self._label = self.__class__.__name__ + else: + self._label = str(label) + self.window_size = window_size self.bounds = bounds diff --git a/tests/test_basis.py b/tests/test_basis.py index 6fe59698..4e31b49a 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -599,7 +599,7 @@ def test_init_mode(self, mode, expectation): @pytest.mark.parametrize("label", [None, "label"]) def test_init_label(self, label): bas = self.cls(5, label=label) - assert bas.label == str(label) + assert bas.label == (str(label) if label is not None else self.cls.__name__) @pytest.mark.parametrize( "attribute, value", @@ -1445,7 +1445,7 @@ def test_init_mode(self, mode, expectation): @pytest.mark.parametrize("label", [None, "label"]) def test_init_label(self, label): bas = self.cls(5, label=label) - assert bas.label == str(label) + assert bas.label == (str(label) if label is not None else self.cls.__name__) @pytest.mark.parametrize( "attribute, value", @@ -2240,7 +2240,7 @@ def test_init_mode(self, mode, expectation): @pytest.mark.parametrize("label", [None, "label"]) def test_init_label(self, label): bas = self.cls(5, label=label) - assert bas.label == str(label) + assert bas.label == (str(label) if label is not None else self.cls.__name__) @pytest.mark.parametrize( "attribute, value", @@ -3152,7 +3152,7 @@ def test_init_mode(self, mode, expectation): @pytest.mark.parametrize("label", [None, "label"]) def test_init_label(self, label): bas = self.cls(5, label=label, decay_rates=np.arange(1, 6)) - assert bas.label == str(label) + assert bas.label == (str(label) if label is not None else self.cls.__name__) @pytest.mark.parametrize( "attribute, value", @@ -3925,7 +3925,7 @@ def test_init_mode(self, mode, expectation): @pytest.mark.parametrize("label", [None, "label"]) def test_init_label(self, label): bas = self.cls(5, label=label) - assert bas.label == str(label) + assert bas.label == (str(label) if label is not None else self.cls.__name__) @pytest.mark.parametrize( "attribute, value",