Skip to content

Commit

Permalink
modified label
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Nov 5, 2024
1 parent c948dc2 commit a02c3bf
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
6 changes: 5 additions & 1 deletion src/nemos/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,11 @@ def __init__(
self._n_output_features = None
self._input_shape = None

self._label = str(label)
if label is None:
self._label = self.__class__.__name__
else:
self._label = str(label)

self.window_size = window_size
self.bounds = bounds

Expand Down
10 changes: 5 additions & 5 deletions tests/test_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@ def test_init_mode(self, mode, expectation):
@pytest.mark.parametrize("label", [None, "label"])
def test_init_label(self, label):
bas = self.cls(5, label=label)
assert bas.label == str(label)
assert bas.label == (str(label) if label is not None else self.cls.__name__)

@pytest.mark.parametrize(
"attribute, value",
Expand Down Expand Up @@ -1445,7 +1445,7 @@ def test_init_mode(self, mode, expectation):
@pytest.mark.parametrize("label", [None, "label"])
def test_init_label(self, label):
bas = self.cls(5, label=label)
assert bas.label == str(label)
assert bas.label == (str(label) if label is not None else self.cls.__name__)

@pytest.mark.parametrize(
"attribute, value",
Expand Down Expand Up @@ -2240,7 +2240,7 @@ def test_init_mode(self, mode, expectation):
@pytest.mark.parametrize("label", [None, "label"])
def test_init_label(self, label):
bas = self.cls(5, label=label)
assert bas.label == str(label)
assert bas.label == (str(label) if label is not None else self.cls.__name__)

@pytest.mark.parametrize(
"attribute, value",
Expand Down Expand Up @@ -3152,7 +3152,7 @@ def test_init_mode(self, mode, expectation):
@pytest.mark.parametrize("label", [None, "label"])
def test_init_label(self, label):
bas = self.cls(5, label=label, decay_rates=np.arange(1, 6))
assert bas.label == str(label)
assert bas.label == (str(label) if label is not None else self.cls.__name__)

@pytest.mark.parametrize(
"attribute, value",
Expand Down Expand Up @@ -3925,7 +3925,7 @@ def test_init_mode(self, mode, expectation):
@pytest.mark.parametrize("label", [None, "label"])
def test_init_label(self, label):
bas = self.cls(5, label=label)
assert bas.label == str(label)
assert bas.label == (str(label) if label is not None else self.cls.__name__)

@pytest.mark.parametrize(
"attribute, value",
Expand Down

0 comments on commit a02c3bf

Please sign in to comment.