diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 82f55483b4..3f47c505ad 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -41,7 +41,7 @@ class TemplateMetricsCalculator(BaseWaveformExtractorExtension): extension_name = "template_metrics" min_channels_for_multi_channel_warning = 10 - def __init__(self, waveform_extractor): + def __init__(self, waveform_extractor: WaveformExtractor): BaseWaveformExtractorExtension.__init__(self, waveform_extractor) def _set_params( @@ -212,7 +212,6 @@ def get_extension_function(): ) -# TODO: add typing def compute_template_metrics( waveform_extractor, load_if_exists: bool = False,