diff --git a/examples/tutorials/curation/plot_2_train_a_model.py b/examples/tutorials/curation/plot_2_train_a_model.py index 29388d5cfa..2b5a06f8a6 100644 --- a/examples/tutorials/curation/plot_2_train_a_model.py +++ b/examples/tutorials/curation/plot_2_train_a_model.py @@ -4,6 +4,13 @@ If the pretrained models do not give satisfactory performance on your data, it is easy to train your own classifier using SpikeInterface. """ + + +############################################################################## +# Step 1: Generate and label data +# ------------------------------- +# +# First we will import our dependencies import warnings warnings.filterwarnings("ignore") from pathlib import Path @@ -15,17 +22,15 @@ import spikeinterface.curation as sc import spikeinterface.widgets as sw - # Note, you can set the number of cores you use using e.g. # si.set_global_job_kwargs(n_jobs = 8) ############################################################################## -# Step 1: Generate and label data -# ------------------------------- -# -# For this tutorial, we will use simulated data to create ``recording`` ``sorting`` objects. We'll -# create two sorting objects: :code:`sorting_1` is coupled to the real recording, so will contain good -# units; :code:`sorting_2` is uncoupled, so should produce noise. We'll combine the two into one sorting +# For this tutorial, we will use simulated data to create ``recording`` and ``sorting`` objects. We'll +# create two sorting objects: :code:`sorting_1` is coupled to the real recording, so the spike times of the sorter will +# perfectly match the spikes in the recording. Hence this will contain good units. However, we've +# uncoupled :code:`sorting_2` to the recording and the spike times will not be matched with the spikes in the recording. +# Hence these units will mostly be random noise. We'll combine the "good" and "noise" sortings into one sorting # object using :code:`si.aggregate_units`. # # (When making your own model, you should @@ -38,24 +43,24 @@ both_sortings = si.aggregate_units([sorting_1, sorting_2]) ############################################################################## -# The models are based on `quality metrics `_ -# and `template metrics `_, -# which are computed using a :code:`sorting_analyzer`. So we'll now create a sorting -# analyzer and compute the extensions needed to get the metrics. +# To do some visualisation and postprocessing, we need to create a sorting analyzer, and +# compute some extensions: analyzer = si.create_sorting_analyzer(sorting = both_sortings, recording=recording) -analyzer.compute(['noise_levels','random_spikes','waveforms','templates','spike_locations','spike_amplitudes','correlograms','principal_components','quality_metrics','template_metrics']) +analyzer.compute(['noise_levels','random_spikes','waveforms','templates']) ############################################################################## -# Let's plot the templates for the first and fifth units. The first (unit id 0) belongs to +# Now we can plot the templates for the first and fifth units. The first (unit id 0) belongs to # :code:`sorting_1` so should look like a real unit; the sixth (unit id 5) belongs to :code:`sorting_2` # so should look like noise. sw.plot_unit_templates(analyzer, unit_ids=[0,5]) ############################################################################## -# This is as expected: great! Find out more about plotting using widgets `here `_. The labels -# for our units are then easy to put in a list: +# This is as expected: great! (Find out more about plotting using widgets `here `_.) +# We've set out system up so that the first five units are 'good' and the next five are 'bad'. +# So we can make a list of labels which contain this information. For real data, you could +# use a manual curation tool to make your own list. labels = ['good', 'good', 'good', 'good', 'good', 'bad', 'bad', 'bad', 'bad', 'bad'] @@ -63,20 +68,32 @@ # Step 2: Train our model # ----------------------- # -# With the labelled data, we can train the model using the :code:`train_model` function. -# Here, the idea is that the trainer will try several classifiers, imputation strategies and -# scaling techniques then save the most accurate. To save time, we'll only try one classifier -# (Random Forest), imputation strategy (median) and scaling technique (standard scaler). +# We'll now train a model, based on our labelled data. The model will be trained using properties +# of the units, and then be applied to units from other sortings. The properties we use are the +# `quality metrics `_ +# and `template metrics `_. +# Hence we need to compute these, using some `sorting_analyzer` extensions. -folder = "my_model" +analyzer.compute(['spike_locations','spike_amplitudes','correlograms','principal_components','quality_metrics','template_metrics']) + +############################################################################## +# Now that we have metrics and labels, we're ready to train the model using the +# `train_model` function. The trainer will try several classifiers, imputation strategies and +# scaling techniques then save the most accurate. To save time in this tutorial, +# we'll only try one classifier (Random Forest), imputation strategy (median) and scaling +# technique (standard scaler). +# +# We will use a list of one analyzer here, so the model is trained on a single +# session. In reality, we would usually train a model using multiple analyzers from an +# experiment, which should make the model more robust. To do this, you can simply pass +# a list of analyzers and a list of manually curated labels for each +# of these analyzers. Then the model would use all of these data as input. -# We will use a list of one analyzer here, we would strongly advise using more than one to -# improve model performance trainer = sc.train_model( mode = "analyzers", # You can supply a labelled csv file instead of an analyzer labels = [labels], analyzers = [analyzer], - folder = folder, # Where to save the model and model_info.json file + folder = "my_folder", # Where to save the model and model_info.json file metric_names = None, # Specify which metrics to use for training: by default uses those already calculted imputation_strategies = ["median"], # Defaults to all scaling_techniques = ["standard_scaler"], # Defaults to all @@ -87,23 +104,25 @@ best_model = trainer.best_pipeline ############################################################################## -# The above code saves the model in ``my_model/model.skops``, some metadata in -# ``my_model/model_info.json`` and the model accuracies in ``model_accuracies.csv`` -# in the specified ``folder``. +# The above code saves the model in ``model.skops``, some metadata in +# ``model_info.json`` and the model accuracies in ``model_accuracies.csv`` +# in the specified ``folder`` (in this case ``my_folder``. # -# ``skops`` is a file format: you can think of it as a more-secture pkl file. `Read more `_. +# (``skops`` is a file format: you can think of it as a more-secure pkl file. `Read more `_.) # # The ``model_accuracies.csv`` file contains the accuracy, precision and recall of the -# tested models. Let's take a look +# tested models. Let's take a look: -accuracies = pd.read_csv(Path(folder) / "model_accuracies.csv", index_col = 0) +accuracies = pd.read_csv(Path("my_folder") / "model_accuracies.csv", index_col = 0) accuracies.head() ############################################################################## # Our model is perfect!! This is because the task was *very* easy. We had 10 units; where # half were pure noise and half were not. # -# The model also contains some more information, such as which features are importantly. +# The model also contains some more information, such as which features are "important", +# as defined by sklearn (learn about feature importance of a Random Forest Classifier +# `here `_.) # We can plot these: # Plot feature importances @@ -112,7 +131,7 @@ features = best_model.feature_names_in_ n_features = best_model.n_features_in_ -plt.figure(figsize=(12, 6)) +plt.figure(figsize=(12, 7)) plt.title("Feature Importances") plt.bar(range(n_features), importances[indices], align="center") plt.xticks(range(n_features), features, rotation=90) @@ -120,6 +139,11 @@ plt.show() ############################################################################## +# Roughly, this means the model isn't using metrics such as "l_ratio" and "d_prime" +# but is using "snr" and "firing_rate". Using this information, you could retrain another, +# simpler model using a subset of the metrics, by passing, e.g., +# `metric_names = ['snr', 'firing_rate',...]` to the `train_model` function. +# # Now that you have a model, you can `apply it to another sorting # `_ # or upload it to `HuggingFaceHub `_.