Deepinterpolation revived: training, transfer, and inference with more flexibility #1804
Conversation
@khl02007 FYI ;)
…+ correct usage of total_samples
I updated the
Would you have time to review it? It's ready to merge on my side :)
total_num_samples = np.sum(r.get_total_samples() for r in recordings)
# In case of multiple recordings and/or multiple segments, we calculate the frame periods to be excluded (borders)
exclude_intervals = []
pre_extended = pre_frame + pre_post_omission
I think the SequentialGenerator might have some of this logic already?
to exclude borders? I couldn't find it!
)

# this is overridden to exclude samples from borders
def _calculate_list_samples(self, total_frame_per_movie, exclude_intervals=[]):
I wonder if this is needed given that you are inheriting from the SequentialGenerator
In this case it's needed because of the border exclusion.
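The border-exclusion idea under discussion can be sketched as follows. This is a hypothetical illustration, not the actual PR code: the names `get_total_samples`, `pre_frame`, `post_frame`, and `pre_post_omission` follow the diff, but the recording stubs and the helper function are made up for the example.

```python
# Sketch: frames within the pre/post context of a recording boundary cannot
# form a full training sample, so their intervals are collected for exclusion.
class FakeRecording:
    """Stand-in for a SpikeInterface recording, exposing only a sample count."""
    def __init__(self, n_samples):
        self._n_samples = n_samples

    def get_total_samples(self):
        return self._n_samples


def compute_exclude_intervals(recordings, pre_frame, post_frame, pre_post_omission):
    pre_extended = pre_frame + pre_post_omission
    post_extended = post_frame + pre_post_omission
    exclude_intervals = []
    offset = 0  # cumulative frame offset across concatenated recordings
    for rec in recordings:
        n = rec.get_total_samples()
        # frames too close to the start or end of this recording
        exclude_intervals.append((offset, offset + pre_extended))
        exclude_intervals.append((offset + n - post_extended, offset + n))
        offset += n
    return exclude_intervals


intervals = compute_exclude_intervals(
    [FakeRecording(1000), FakeRecording(500)],
    pre_frame=30, post_frame=30, pre_post_omission=3,
)
print(intervals)  # [(0, 33), (967, 1000), (1000, 1033), (1467, 1500)]
```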
model_folder: str | Path,
model_name: str,
desired_shape: tuple[int, int],
train_start_s: Optional[float] = None,
To what extent is using seconds helpful here? I am always worried about edge conditions created by wrapping code.
I think that using seconds here is much more convenient. E.g. I'd rather use 600 to indicate 10 minutes than 18,000,000 samples (easy to forget or add a 0 here).
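The convenience argument can be illustrated with a small conversion helper. This is a sketch: the function name and the 30 kHz sampling rate are illustrative, not the actual API.

```python
# Hypothetical helper converting a window given in seconds into frame
# indices, as a seconds-based training API would do internally.
def seconds_to_frames(start_s, end_s, sampling_frequency):
    start_frame = int(start_s * sampling_frequency)
    end_frame = int(end_s * sampling_frequency)
    return start_frame, end_frame


# 10 minutes at 30 kHz: 600 is far easier to write correctly than 18_000_000
print(seconds_to_frames(0, 600, 30_000))  # (0, 18000000)
```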
nb_gpus: int = 1,
steps_per_epoch: int = 10,
period_save: int = 100,
apply_learning_decay: int = 0,
I would consider only exposing the parameters you care about, not all of them.
For example, which parameters will in practice never change?
You tell me :) In my opinion, there's no harm in exposing everything with defaults!
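The "expose everything with defaults" approach being discussed can be sketched like this; the class name is hypothetical, while the parameter names and defaults mirror the signature in the diff.

```python
# Sketch: gathering DeepInterpolation training options in one place with
# defaults, so callers only override what they care about. Not the actual
# SpikeInterface API.
from dataclasses import dataclass, asdict


@dataclass
class TrainingOptions:
    nb_gpus: int = 1
    steps_per_epoch: int = 10
    period_save: int = 100
    apply_learning_decay: int = 0


# A caller who never changes a parameter simply relies on its default:
opts = TrainingOptions(steps_per_epoch=50)
print(asdict(opts))
```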
src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py
@samuelgarcia we can merge this
Enormous and amazing work, comrade.
Thanks! The multiple files are useful because I can do lazy imports without conditional definitions ;)
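The lazy-import pattern referred to here can be sketched as below; the function name is illustrative and does not reflect the actual module layout.

```python
# Sketch: the heavy optional dependency is imported inside the function,
# so the defining module always imports cleanly even when the dependency
# is missing. The error surfaces only when the function is actually called.
def train_model(*args, **kwargs):
    try:
        import deepinterpolation  # noqa: F401 -- heavy optional dependency
    except ImportError as err:
        raise ImportError(
            "train_model requires the 'deepinterpolation' package"
        ) from err
    # ... actual training logic would go here ...


# The module defining train_model imports fine without deepinterpolation
# installed, with no need for conditional function definitions.
print(callable(train_model))  # True
```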
Refactored and extended the `deepinterpolation` submodule. The existing code was only compatible with a pre-trained model from Neuropixels 1.0 with interleaved zeros (so that the final shape was (384, 2)).
Training
With this PR, one can train a DI model from ANY SpikeInterface object using the `train_deepinterpolation` function. The `desired_shape` parameter controls how the input should be reshaped to make the input image. For example, an NP1 probe could be treated as follows:

- `desired_shape=(192, 2)`: the 4 staggered columns are "aligned" into two columns of 192 electrodes
- `desired_shape=(96, 4)`: each "real" column is a different image column (in this shape the channels would need to be reordered first)

To use the previously built `(384, 2)` model, one can use the zero-channel pad to "manually" interleave zeros in the original recording.

Transfer
The `existing_model_path` parameter allows one to do transfer learning from an existing model.
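The `desired_shape` reshaping described under Training can be pictured with a small pure-Python sketch (the actual implementation operates on NumPy trace arrays; the helper here is illustrative only).

```python
# Illustration: reshape one frame of 384 channel values into a (192, 2)
# "image", as described for desired_shape=(192, 2).
def reshape_frame(frame, desired_shape):
    rows, cols = desired_shape
    if rows * cols != len(frame):
        raise ValueError("desired_shape must match the channel count")
    return [frame[r * cols:(r + 1) * cols] for r in range(rows)]


frame = list(range(384))          # one frame of a 384-channel recording
image = reshape_frame(frame, (192, 2))
print(len(image), len(image[0]))  # 192 2
```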
Same as before, but the desired shape is inferred from the network.
NOTE: it is advised to z-score the recording prior to using DeepInterpolation, but this is not currently enforced.
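Since z-scoring is advised but not enforced, the normalization would typically be applied as a preprocessing step first. A minimal pure-Python illustration of the per-channel operation (in practice one would use a library routine rather than this sketch):

```python
# Illustration of per-channel z-scoring: subtract the mean and divide by
# the standard deviation so each channel has mean 0 and std 1.
import math


def zscore_channel(samples):
    mean = sum(samples) / len(samples)
    var = sum((s - mean) ** 2 for s in samples) / len(samples)
    return [(s - mean) / math.sqrt(var) for s in samples]


traces = [1.0, 2.0, 3.0, 4.0, 5.0]
z = zscore_channel(traces)
```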