
Deepinterpolation revived: training, transfer, and inference with more flexibility #1804

Merged
merged 32 commits into from
Oct 23, 2023

Conversation

alejoe91
Member

Refactored and extended the deepinterpolation submodule.

The existing code was only compatible with a pre-trained model from Neuropixels 1.0 with interleaved zeros (so that the final input shape was (384, 2)).

Training

With this PR, one can train a DI model from any SpikeInterface recording object using the train_deepinterpolation function. The desired_shape parameter controls how the channels should be reshaped to form the input image.

For example, a NP1 probe could be treated as follows:

  • desired_shape=(192, 2): the 4 staggered columns are "aligned" into two columns of 192 electrodes
  • desired_shape=(96, 4): each "real" column becomes a different image column (in this shape the channels would need to be reordered first).

To use the previously built (384, 2) model, one can use the zero-channel pad to "manually" interleave zeros in the original recording.
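As a sketch of what desired_shape does (plain NumPy, with hypothetical data; real probe layouts may require channel reordering first), folding one 384-channel frame into the two image shapes above looks like:

```python
import numpy as np

# One frame of a NP1 recording: 384 channel values (hypothetical data)
frame = np.arange(384, dtype="float32")

# desired_shape=(192, 2): fold the channels into a 192 x 2 image.
# Assumes consecutive channels alternate between the two image columns.
image_2col = frame.reshape(192, 2)

# desired_shape=(96, 4): four image columns, one per "real" probe column
image_4col = frame.reshape(96, 4)
```

The same channel count must factor exactly into the requested shape, which is why the (384, 2) legacy model needs zero-padding on a 384-channel probe.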

Transfer

The existing_model_path parameter allows one to do transfer learning from an existing model:

fine_tuned_model = train_deepinterpolation(recording, ..., existing_model_path="path-to-existing")

Inference

Same as before, but the desired shape is inferred from the network:

recording_di = spre.deepinterpolate(recording, model_path="path-to-model", **kwargs)

NOTE: it is advised to zscore the recording prior to using DeepInterpolation, but this is not currently enforced.
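In SpikeInterface the z-scoring step is spre.zscore(recording); conceptually it is equivalent to this NumPy sketch on a traces array (samples as rows, channels as columns; data here is synthetic):

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic traces: 1000 samples x 4 channels with nonzero offset and scale
traces = rng.normal(loc=5.0, scale=2.0, size=(1000, 4)).astype("float32")

# z-score each channel independently: subtract its mean, divide by its std,
# so every channel is zero-mean and unit-variance before DeepInterpolation
zscored = (traces - traces.mean(axis=0)) / traces.std(axis=0)
```

Since this is not enforced by deepinterpolate itself, it is up to the user to apply it before both training and inference so the data matches the model's expected scale.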

@alejoe91 alejoe91 added the preprocessing Related to preprocessing module label Jul 10, 2023
@alejoe91 alejoe91 marked this pull request as ready for review July 10, 2023 14:20
@alejoe91
Member Author

@khl02007 FYI ;)

@alejoe91
Member Author

alejoe91 commented Sep 7, 2023

@jeromelecoq

I updated the train_deepinterpolation function as follows:

  • You can now pass a single recording or a list of recordings as input. In case of a recording with multiple segments (e.g., several play/pause) or multiple recordings, I modified the _calculate_list_samples so that intervals at the borders are excluded.
  • You can optionally pass test_recordings (a single recording or a list of recordings), which can be a different recording used for the test set
  • Unified the correct usage of total_samples (and removed overridden functions)
  • Set mean_squared_error as default loss
  • Extended the tests

Would you have time to review it? It's ready to merge on my side :)

total_num_samples = sum(r.get_total_samples() for r in recordings)
# In case of multiple recordings and/or multiple segments, we calculate the frame periods to be excluded (borders)
exclude_intervals = []
pre_extended = pre_frame + pre_post_omission


I think the SequentialGenerator might have some of this logic already?

Member Author


to exclude borders? I couldn't find it!

)

# this is overridden to exclude samples from borders
def _calculate_list_samples(self, total_frame_per_movie, exclude_intervals=[]):


I wonder if this is needed given that you are inheriting from the SequentialGenerator

Member Author


in this case it's needed because of the borders exclusion

model_folder: str | Path,
model_name: str,
desired_shape: tuple[int, int],
train_start_s: Optional[float] = None,


To what extent using seconds is helpful here? I am always worried about edge conditions created by wrapping code.

Member Author


I think that using seconds here is much more convenient. E.g., I'd rather use 600 to indicate 10 minutes than 18000000 samples (easy to forget or add a zero).
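The convenience argument is simple arithmetic: a time in seconds maps to a frame index via the sampling frequency. A sketch (helper name hypothetical, not the actual wrapper code):

```python
def seconds_to_frames(t_s: float, sampling_frequency: float) -> int:
    """Convert a time in seconds to a frame index at the given sampling rate."""
    return int(t_s * sampling_frequency)

# 10 minutes at 30 kHz (typical Neuropixels AP band) -> 18,000,000 frames
train_start_frame = seconds_to_frames(600.0, 30_000.0)
```

The edge-condition worry then reduces to how the conversion rounds and clips at recording boundaries, which is easy to test in one place.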

nb_gpus: int = 1,
steps_per_epoch: int = 10,
period_save: int = 100,
apply_learning_decay: int = 0,


I would consider only exposing the parameters you care about, not all of them.


For example, which parameters will in practice never change?

Member Author


you tell me :) In my opinion there's no harm in exposing everything with defaults!

@alejoe91
Member Author

@samuelgarcia we can merge this

@samuelgarcia
Member

Enormous and amazing work, camarade.
I hope this will be useful and usable.
As a general comment, I think I prefer a single file for this kind of preprocessing, even if the file is enormous.
I understand that it was easier while developing to have multiple files and subfolders.

@samuelgarcia samuelgarcia merged commit 1c6535a into SpikeInterface:main Oct 23, 2023
9 checks passed
@alejoe91
Member Author

Enormous and amazing work, camarade. I hope this will be useful and usable. As a general comment, I think I prefer a single file for this kind of preprocessing, even if the file is enormous. I understand that it was easier while developing to have multiple files and subfolders.

Thanks! The multiple files are useful because I can do lazy imports without conditional definitions ;)

@alejoe91 alejoe91 deleted the deepinterp branch October 26, 2023 11:00