Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

require sample axis to be axis=0 #254

Merged
merged 15 commits into from
Oct 25, 2024
Merged

Conversation

BalzaniEdoardo
Copy link
Collaborator

This PR fixes a corner-case bug related to basis in mode="conv".

Issue

Before this PR, one could pass at basis initialization any kwargs from nemos.convolve.create_convolutional_predictor. This included axis meaning that one could specify what's the axis along which the convolution is applied.

What happened in basis.compute_features is that the axis is transposed to the first, and the rest of the axis are flattened to the second axis, so that the compute_features returns an 2darray. However, when the basis is composite, the we check that the number of samples matches for all input, and this check relied on the fact that the sample axis is axis=0.

This caused the following corner case:

>>> b = nmo.basis.RaisedCosineBasisLog(5, mode="conv", window_size=100, axis=1)

>>> # this works because there is one input (the check on the sample axis passes)
>>> b.compute_features(np.ones((2, 200)))
Array([[       nan,        nan,        nan, ...,        nan,        nan,
               nan],
       [       nan,        nan,        nan, ...,        nan,        nan,
               nan],
       [       nan,        nan,        nan, ...,        nan,        nan,
               nan],
       ...,
       [ 2.4915924,  2.4915924,  5.5825534, ..., 20.70655  , 39.878582 ,
        39.878582 ],
       [ 2.4915924,  2.4915924,  5.5825534, ..., 20.70655  , 39.878582 ,
        39.878582 ],
       [ 2.4915924,  2.4915924,  5.5825534, ..., 20.706553 , 39.878582 ,
        39.878586 ]], dtype=float32)

>>> # this fails because the check on the sample axis will raise an error since 3 != 2
>>> (b + b).compute_features(np.ones((2, 200)), np.ones((3, 200)))

Solution

Instead of complicating the check logic, or move it after the transposition happens for all 1D basis, I simplified our life and enforced that the sample axis is always the first one. This is always true for pynapple, and requires the user to transpose an array if the data are organized differently;

Now the initialization would raise an error if axis different from 0 is passed.

>>> b = nmo.basis.RaisedCosineBasisLog(5, mode="conv", window_size=100, axis=1)
ValueError: Invalid `axis=1` provided. Basis requires  requires the convolution to be applied along the first axis (`axis=0`).
Please transpose your input so that the desired axis for convolution is the first dimension (axis=0).

@BalzaniEdoardo BalzaniEdoardo marked this pull request as ready for review October 23, 2024 14:15
@codecov-commenter
Copy link

codecov-commenter commented Oct 23, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 97.15%. Comparing base (6b3a1e8) to head (d51f8d3).

Additional details and impacted files
@@               Coverage Diff               @@
##           development     #254      +/-   ##
===============================================
+ Coverage        97.09%   97.15%   +0.06%     
===============================================
  Files               20       20              
  Lines             1857     1863       +6     
===============================================
+ Hits              1803     1810       +7     
+ Misses              54       53       -1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Member

@billbrod billbrod left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My vote would be to have axis just not be allowed at all. Since the behavior cannot be modified, it feels a bit confusing to set it.

src/nemos/basis.py Outdated Show resolved Hide resolved
@BalzaniEdoardo BalzaniEdoardo merged commit 6645096 into development Oct 25, 2024
13 checks passed
@BalzaniEdoardo BalzaniEdoardo deleted the bug_fix_sample_axis_basis branch October 25, 2024 15:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants