Commit b2a3dba: Train a model doc review response
chrishalcrow committed Nov 22, 2024
1 changed file: examples/tutorials/curation/plot_2_train_a_model.py (55 additions, 31 deletions)
If the pretrained models do not give satisfactory performance on your data, it is easy to train your own classifier using SpikeInterface.
"""


##############################################################################
# Step 1: Generate and label data
# -------------------------------
#
# First we will import our dependencies.
import warnings
warnings.filterwarnings("ignore")
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import spikeinterface.full as si
import spikeinterface.curation as sc
import spikeinterface.widgets as sw


# Note: you can set the number of cores to use with, e.g.,
# si.set_global_job_kwargs(n_jobs = 8)

##############################################################################
# For this tutorial, we will use simulated data to create ``recording`` and ``sorting`` objects. We'll
# create two sorting objects: :code:`sorting_1` is coupled to the real recording, so the spike times of the sorter will
# perfectly match the spikes in the recording. Hence this will contain good units. However, we've
# uncoupled :code:`sorting_2` from the recording, so its spike times will not match the spikes in the recording.
# Hence these units will mostly be random noise. We'll combine the "good" and "noise" sortings into one sorting
# object using :code:`si.aggregate_units`.
#
# (When making your own model, you should use real data.)

# Generate a simulated recording with a coupled sorting, plus an independent
# uncoupled sorting, each with five units (the parameters here are illustrative):
recording, sorting_1 = si.generate_ground_truth_recording(num_channels=4, num_units=5, seed=1)
_, sorting_2 = si.generate_ground_truth_recording(num_channels=4, num_units=5, seed=2)
both_sortings = si.aggregate_units([sorting_1, sorting_2])

##############################################################################
# To do some visualisation and postprocessing, we need to create a sorting analyzer and
# compute some extensions:

analyzer = si.create_sorting_analyzer(sorting = both_sortings, recording=recording)
analyzer.compute(['noise_levels','random_spikes','waveforms','templates'])

##############################################################################
# Now we can plot the templates for the first and sixth units. The first (unit id 0) belongs to
# :code:`sorting_1` so should look like a real unit; the sixth (unit id 5) belongs to :code:`sorting_2`
# so should look like noise.

sw.plot_unit_templates(analyzer, unit_ids=[0,5])

##############################################################################
# This is as expected: great! (Find out more about plotting using widgets `here <https://spikeinterface.readthedocs.io/en/latest/modules/widgets.html>`_.)
# We've set our system up so that the first five units are 'good' and the next five are 'bad'.
# So we can make a list of labels containing this information. For real data, you could
# use a manual curation tool to make your own list, as sketched below.

labels = ['good', 'good', 'good', 'good', 'good', 'bad', 'bad', 'bad', 'bad', 'bad']
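
##############################################################################
# If your manual curation lives in a csv file, you could instead read the labels
# in from there. A minimal sketch, assuming a hypothetical ``curated_labels.csv``
# with one ``label`` entry ('good' or 'bad') per unit:

if Path("curated_labels.csv").exists():
    labels = pd.read_csv("curated_labels.csv")["label"].tolist()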

##############################################################################
# Step 2: Train our model
# -----------------------
#
# We'll now train a model, based on our labelled data. The model will be trained using properties
# of the units, and then be applied to units from other sortings. The properties we use are the
# `quality metrics <https://spikeinterface.readthedocs.io/en/latest/modules/qualitymetrics.html>`_
# and `template metrics <https://spikeinterface.readthedocs.io/en/latest/modules/postprocessing.html#template-metrics>`_.
# Hence we need to compute these, using some ``sorting_analyzer`` extensions.

folder = "my_model"
analyzer.compute(['spike_locations','spike_amplitudes','correlograms','principal_components','quality_metrics','template_metrics'])
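
##############################################################################
# If you'd like to inspect the computed metrics before training, they are stored
# on the analyzer's extensions and can be viewed directly, e.g.

analyzer.get_extension("quality_metrics").get_data().head()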

##############################################################################
# Now that we have metrics and labels, we're ready to train the model using the
# :code:`train_model` function. The trainer will try several classifiers, imputation strategies and
# scaling techniques, then save the most accurate. To save time in this tutorial,
# we'll only try one classifier (Random Forest), imputation strategy (median) and scaling
# technique (standard scaler).
#
# We will use a list of one analyzer here, so the model is trained on a single
# session. In reality, we would usually train a model using multiple analyzers from an
# experiment, which should make the model more robust. To do this, you can simply pass
# a list of analyzers and a list of manually curated labels for each
# of these analyzers. Then the model would use all of these data as input (a sketch of
# such a call follows the training code below).

trainer = sc.train_model(
    mode = "analyzers", # You can supply a labelled csv file instead of an analyzer
    labels = [labels],
    analyzers = [analyzer],
    folder = "my_folder", # Where to save the model and model_info.json file
    metric_names = None, # Specify which metrics to use for training: by default uses those already calculated
    imputation_strategies = ["median"], # Defaults to all
    scaling_techniques = ["standard_scaler"], # Defaults to all
    classifiers = None, # Defaults to Random Forest only
)
best_model = trainer.best_pipeline
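
##############################################################################
# As mentioned above, a multi-session model is trained by passing longer lists; a
# sketch, where ``analyzer_1``, ``analyzer_2`` and their label lists are hypothetical:
#
# trainer = sc.train_model(
#     mode = "analyzers",
#     labels = [labels_1, labels_2],
#     analyzers = [analyzer_1, analyzer_2],
#     folder = "my_multisession_model",
# )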

##############################################################################
# The ``train_model`` call saves the model in ``model.skops``, some metadata in
# ``model_info.json`` and the model accuracies in ``model_accuracies.csv``
# in the specified ``folder`` (in this case ``my_folder``).
#
# (``skops`` is a file format: you can think of it as a more-secure pkl file. `Read more <https://skops.readthedocs.io/en/stable/index.html>`_.)
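#
# If you ever need to load the saved model yourself, you could use the ``skops`` API
# directly, e.g.
#
# import skops.io
# model_path = Path("my_folder") / "model.skops"
# model = skops.io.load(model_path, trusted=skops.io.get_untrusted_types(file=model_path))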
#
# The ``model_accuracies.csv`` file contains the accuracy, precision and recall of the
# tested models. Let's take a look:

accuracies = pd.read_csv(Path("my_folder") / "model_accuracies.csv", index_col = 0)
accuracies.head()

##############################################################################
# Our model is perfect!! This is because the task was *very* easy. We had 10 units, half
# of which were pure noise and half of which were not.
#
# The model also contains some more information, such as which features are "important",
# as defined by sklearn (learn about feature importance of a Random Forest Classifier
# `here <https://scikit-learn.org/1.5/auto_examples/ensemble/plot_forest_importances.html>`_.)
# We can plot these:

# Plot feature importances
importances = best_model.named_steps['classifier'].feature_importances_ # the pipeline's final step is named 'classifier'
indices = np.argsort(importances)[::-1]
features = best_model.feature_names_in_
n_features = best_model.n_features_in_

plt.figure(figsize=(12, 7))
plt.title("Feature Importances")
plt.bar(range(n_features), importances[indices], align="center")
plt.xticks(range(n_features), features, rotation=90)
plt.xlim([-1, n_features])
plt.show()

##############################################################################
# Roughly, this means the model isn't using metrics such as "l_ratio" and "d_prime"
# but is using "snr" and "firing_rate". Using this information, you could retrain another,
# simpler model using a subset of the metrics, by passing, e.g.,
# ``metric_names = ['snr', 'firing_rate', ...]`` to the ``train_model`` function, as sketched below.
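#
# A sketch of that simpler retrain (the shortened metric list here is illustrative):
#
# trainer = sc.train_model(
#     mode = "analyzers",
#     labels = [labels],
#     analyzers = [analyzer],
#     folder = "my_simpler_model",
#     metric_names = ["snr", "firing_rate"],
# )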
#
# Now that you have a model, you can `apply it to another sorting
# <https://spikeinterface.readthedocs.io/en/latest/tutorials/curation/plot_1_automated_curation.html>`_
# or upload it to `HuggingFaceHub <https://huggingface.co/>`_.
