Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reparametrize raisedcosine #50

Merged
merged 36 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
1f1e95b
added tests
BalzaniEdoardo Oct 13, 2023
ccb1c5e
linted
BalzaniEdoardo Oct 14, 2023
6bd116a
run on mac
BalzaniEdoardo Oct 14, 2023
5e4122a
test on linux and mac
BalzaniEdoardo Oct 14, 2023
df69255
clip -1
BalzaniEdoardo Nov 21, 2023
47cfaee
changed default
BalzaniEdoardo Nov 21, 2023
f13661d
merged
BalzaniEdoardo Nov 30, 2023
56c51db
reparametrization
BalzaniEdoardo Nov 30, 2023
f2b270e
linted
BalzaniEdoardo Nov 30, 2023
8c9b152
updated param names
BalzaniEdoardo Nov 30, 2023
b23eebd
updated default
BalzaniEdoardo Nov 30, 2023
9936ca3
improved parametrization
BalzaniEdoardo Nov 30, 2023
16bd684
removed unused variable
BalzaniEdoardo Dec 7, 2023
29cd30b
meged conflicts in basis?
BalzaniEdoardo Dec 15, 2023
b9bc61b
added checks
BalzaniEdoardo Dec 15, 2023
e8f318c
linted
BalzaniEdoardo Dec 15, 2023
b4b2e2c
fixed tests
BalzaniEdoardo Dec 15, 2023
ff32160
fixed refs
BalzaniEdoardo Dec 15, 2023
06757d9
fixed note
BalzaniEdoardo Dec 15, 2023
664169d
improved docstring
BalzaniEdoardo Dec 15, 2023
4d58d32
reversed ci change to only ubuntu
BalzaniEdoardo Dec 19, 2023
1fcd446
Update src/nemos/basis.py
BalzaniEdoardo Dec 19, 2023
fd3d1d4
Update src/nemos/basis.py
BalzaniEdoardo Dec 19, 2023
93a1b1e
Update src/nemos/basis.py
BalzaniEdoardo Dec 19, 2023
de1a895
made width a read only property, added comment on evaluate
BalzaniEdoardo Dec 19, 2023
cf4a60d
getter only for time scaling
BalzaniEdoardo Dec 19, 2023
120afe1
Update src/nemos/basis.py
BalzaniEdoardo Dec 19, 2023
b86760e
Update tests/test_basis.py
BalzaniEdoardo Dec 19, 2023
29d116d
Merge branch 'reparametrize_raisedcosine' of github.com:flatironinsti…
BalzaniEdoardo Dec 19, 2023
50fec49
linted
BalzaniEdoardo Dec 19, 2023
5a1a350
modified enforcing basis to zero
BalzaniEdoardo Dec 19, 2023
753ed20
improved code linting
BalzaniEdoardo Dec 19, 2023
348f8af
linted
BalzaniEdoardo Dec 19, 2023
e168cc4
changed code to have the very last time point being 0
BalzaniEdoardo Dec 19, 2023
5785acc
added comment
BalzaniEdoardo Dec 19, 2023
14055b6
linted
BalzaniEdoardo Dec 19, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
tox:
strategy:
matrix:
os: [ubuntu-latest, macos-latest] #[ubuntu-latest, macos-latest, windows-latest]
os: [ubuntu-latest] #[ubuntu-latest, macos-latest, windows-latest]
python-version: ['3.10'] #['3.8', '3.9', '3.10']
runs-on: ${{ matrix.os }}
steps:
Expand Down
88 changes: 50 additions & 38 deletions src/nemos/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ def _check_n_basis_min(self) -> None:


class MSplineBasis(SplineBasis):
"""M-spline[$^1$](references) 1-dimensional basis functions.
"""M-spline[$^1$](#references) 1-dimensional basis functions.

Parameters
----------
Expand Down Expand Up @@ -576,7 +576,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:

class BSplineBasis(SplineBasis):
"""
B-spline[$^1$](references) 1-dimensional basis functions.
B-spline[$^1$](#references) 1-dimensional basis functions.

Parameters
----------
Expand Down Expand Up @@ -772,7 +772,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
class RaisedCosineBasisLinear(Basis):
"""Represent linearly-spaced raised cosine basis functions.

This implementation is based on the cosine bumps used by Pillow et al.[$^1$](references)
This implementation is based on the cosine bumps used by Pillow et al.[$^1$](#references)
to uniformly tile the internal points of the domain.
billbrod marked this conversation as resolved.
Show resolved Hide resolved

Parameters
Expand All @@ -793,19 +793,14 @@ class RaisedCosineBasisLinear(Basis):
def __init__(self, n_basis_funcs: int, width: float = 2.0) -> None:
super().__init__(n_basis_funcs)
self._n_input_dimensionality = 1
self.width = width
self._check_width(width)
self._width = width

@property
def width(self):
"""Return width of the raised cosine."""
return self._width

@width.setter
def width(self, width: float):
"""Check and set width of the raised cosine."""
self._check_width(width)
self._width = width

@staticmethod
def _check_width(width: float):
"""Validate the width value.
Expand Down Expand Up @@ -852,8 +847,11 @@ def evaluate(self, sample_pts: ArrayLike) -> NDArray:
if any(sample_pts < 0) or any(sample_pts > 1):
raise ValueError("Sample points for RaisedCosine basis must lie in [0,1]!")

peaks = np.linspace(0, 1, self.n_basis_funcs)
peaks = self._compute_peaks()
delta = peaks[1] - peaks[0]
# generate a set of shifted cosines, and constrain them to be non-zero
# over a single period, then enforce the codomain to be [0,1], by adding 1
# and then multiply by 0.5
basis_funcs = 0.5 * (
np.cos(
np.clip(
billbrod marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -867,6 +865,16 @@ def evaluate(self, sample_pts: ArrayLike) -> NDArray:

return basis_funcs

def _compute_peaks(self):
"""
Compute the location of raised cosine peaks.

Returns
-------
Peak locations of each basis element.
"""
return np.linspace(0, 1, self.n_basis_funcs)

def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
"""Evaluate the basis set on a grid of equi-spaced sample points.

Expand Down Expand Up @@ -939,46 +947,67 @@ def __init__(
super().__init__(n_basis_funcs, width=width)
self.enforce_decay_to_zero = enforce_decay_to_zero
if time_scaling is None:
self.time_scaling = 50.0
self._time_scaling = 50.0
else:
self.time_scaling = time_scaling
self._check_time_scaling(time_scaling)
self._time_scaling = time_scaling

@property
def time_scaling(self):
return self._time_scaling

@time_scaling.setter
def time_scaling(self, time_scaling):
@staticmethod
def _check_time_scaling(time_scaling):
if time_scaling <= 0:
raise ValueError(
f"Only strictly positive time_scaling are allowed, {time_scaling} provided instead."
)
self._time_scaling = time_scaling

def _transform_samples(self, sample_pts: NDArray) -> NDArray:
"""
Map the sample domain to log-space.

Parameters
----------
sample_pts : NDArray
sample_pts :
Sample points used for evaluating the splines,
shape (n_samples, ).

Returns
-------
NDArray
Transformed version of the sample points that matches the Raised Cosine basis domain,
shape (n_samples, ).
"""
# This log-stretching of the sample axis has the following effect:
# - as the time_scaling tends to 0, the points will be linearly spaced.
# - as the time_scaling tends to inf, basis will be dense around 0.
# - as the time_scaling tends to 0, the points will be linearly spaced across the whole domain.
# - as the time_scaling tends to inf, basis will be small and dense around 0 and
# progressively larger and less dense towards 1.
Comment on lines +983 to +984
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# - as the time_scaling tends to inf, basis will be small and dense around 0 and
# progressively larger and less dense towards 1.
# - as the time_scaling tends to inf, basis will be small and dense around 0 and
# progressively larger and less dense towards 1.

log_spaced_pts = np.log(self.time_scaling * sample_pts + 1) / np.log(
self.time_scaling + 1
)
return log_spaced_pts

def _compute_peaks(self):
"""
Peak location of each log-spaced cosine basis element

Compute the peak location for the log-spaced raised cosine basis.
Enforcing that the last basis decays to zero is equivalent to
setting the last peak to a value smaller than 1.

Returns
-------
Peak locations of each basis element.

"""
if self.enforce_decay_to_zero:
# compute the last peak location such that the last
# basis element decays to zero at the last sample.
last_peak = 1 - self.width / (self.n_basis_funcs + self.width - 1)
else:
last_peak = 1
return np.linspace(0, last_peak, self.n_basis_funcs)

def evaluate(self, sample_pts: ArrayLike) -> NDArray:
"""Generate log-spaced raised cosine basis with given samples.

Expand All @@ -998,24 +1027,7 @@ def evaluate(self, sample_pts: ArrayLike) -> NDArray:
If the sample provided do not lie in [0,1].
"""
(sample_pts,) = self._check_evaluate_input(sample_pts)
if not self.enforce_decay_to_zero:
# keep the last basis, i.e. do not enforce decay to zero
# for the filter.
eval_basis = super().evaluate(self._transform_samples(sample_pts))
else:
# temporarily add n_trim basis element
# guaranteeing that the last basis element decays to 0 for any width.
n_trim = int(np.ceil(self.width))
self.n_basis_funcs += n_trim
# wrap the evaluation in a try -> finally to make sure that the original
# basis function number is preserved even if an exception is raised.
try:
eval_basis = super().evaluate(self._transform_samples(sample_pts))
eval_basis = eval_basis[..., :-n_trim]
finally:
self.n_basis_funcs -= n_trim

return eval_basis
return super().evaluate(self._transform_samples(sample_pts))


class OrthExponentialBasis(Basis):
Expand Down
14 changes: 2 additions & 12 deletions tests/test_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,18 +255,8 @@ def test_decay_to_zero_basis_number_match(self, width):
@pytest.mark.parametrize(
"time_scaling ,expectation",
[
(
-1,
pytest.raises(
ValueError, match="Only strictly positive time_scaling are allowed"
),
),
(
0,
pytest.raises(
ValueError, match="Only strictly positive time_scaling are allowed"
),
),
(-1, pytest.raises(ValueError, match="Only strictly positive time_scaling are allowed")),
(0, pytest.raises(ValueError, match="Only strictly positive time_scaling are allowed")),
(0.1, does_not_raise()),
(10, does_not_raise()),
],
Expand Down