Skip to content

Commit

Permalink
Merge pull request #1859 from samuelgarcia/refactor_widget_module
Browse files Browse the repository at this point in the history
Refactor widget module
  • Loading branch information
samuelgarcia authored Jul 21, 2023
2 parents fc1c0b8 + 370dc66 commit 3d14d1c
Show file tree
Hide file tree
Showing 96 changed files with 2,683 additions and 4,384 deletions.
3 changes: 2 additions & 1 deletion doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -269,13 +269,14 @@ spikeinterface.widgets
.. autofunction:: plot_amplitudes
.. autofunction:: plot_autocorrelograms
.. autofunction:: plot_crosscorrelograms
.. autofunction:: plot_motion
.. autofunction:: plot_quality_metrics
.. autofunction:: plot_sorting_summary
.. autofunction:: plot_spike_locations
.. autofunction:: plot_spikes_on_traces
.. autofunction:: plot_template_metrics
.. autofunction:: plot_template_similarity
.. autofunction:: plot_timeseries
.. autofunction:: plot_traces
.. autofunction:: plot_unit_depths
.. autofunction:: plot_unit_locations
.. autofunction:: plot_unit_summary
Expand Down
10 changes: 5 additions & 5 deletions doc/how_to/analyse_neuropixels.rst
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ the ipywydgets interactive ploter
.. code:: python
%matplotlib widget
si.plot_timeseries({'filter':rec1, 'cmr': rec4}, backend='ipywidgets')
si.plot_traces({'filter':rec1, 'cmr': rec4}, backend='ipywidgets')
Note that using this ipywydgets make possible to explore diffrents
preprocessing chain wihtout to save the entire file to disk. Everything
Expand All @@ -276,9 +276,9 @@ is lazy, so you can change the previsous cell (parameters, step order,
# here we use static plot using matplotlib backend
fig, axs = plt.subplots(ncols=3, figsize=(20, 10))
si.plot_timeseries(rec1, backend='matplotlib', clim=(-50, 50), ax=axs[0])
si.plot_timeseries(rec4, backend='matplotlib', clim=(-50, 50), ax=axs[1])
si.plot_timeseries(rec, backend='matplotlib', clim=(-50, 50), ax=axs[2])
si.plot_traces(rec1, backend='matplotlib', clim=(-50, 50), ax=axs[0])
si.plot_traces(rec4, backend='matplotlib', clim=(-50, 50), ax=axs[1])
si.plot_traces(rec, backend='matplotlib', clim=(-50, 50), ax=axs[2])
for i, label in enumerate(('filter', 'cmr', 'final')):
axs[i].set_title(label)
Expand All @@ -292,7 +292,7 @@ is lazy, so you can change the previsous cell (parameters, step order,
# plot some channels
fig, ax = plt.subplots(figsize=(20, 10))
some_chans = rec.channel_ids[[100, 150, 200, ]]
si.plot_timeseries({'filter':rec1, 'cmr': rec4}, backend='matplotlib', mode='line', ax=ax, channel_ids=some_chans)
si.plot_traces({'filter':rec1, 'cmr': rec4}, backend='matplotlib', mode='line', ax=ax, channel_ids=some_chans)
Expand Down
2 changes: 1 addition & 1 deletion doc/how_to/get_started.rst
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ and the raster plots.

.. code:: ipython3
w_ts = sw.plot_timeseries(recording, time_range=(0, 5))
w_ts = sw.plot_traces(recording, time_range=(0, 5))
w_rs = sw.plot_rasters(sorting_true, time_range=(0, 5))
Expand Down
8 changes: 4 additions & 4 deletions doc/modules/widgets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ The :code:`plot_*(..., backend="matplotlib")` functions come with the following
.. code-block:: python
# matplotlib backend
w = plot_timeseries(recording, backend="matplotlib")
w = plot_traces(recording, backend="matplotlib")
**Output:**

Expand All @@ -146,9 +146,9 @@ Each function has the following additional arguments:
from spikeinterface.preprocessing import common_reference
# ipywidgets backend also supports multiple "layers" for plot_timeseries
# ipywidgets backend also supports multiple "layers" for plot_traces
rec_dict = dict(filt=recording, cmr=common_reference(recording))
w = sw.plot_timeseries(rec_dict, backend="ipywidgets")
w = sw.plot_traces(rec_dict, backend="ipywidgets")
**Output:**

Expand All @@ -171,7 +171,7 @@ The functions have the following additional arguments:
.. code-block:: python
# sortingview backend
w_ts = sw.plot_timeseries(recording, backend="ipywidgets")
w_ts = sw.plot_traces(recording, backend="ipywidgets")
w_ss = sw.plot_sorting_summary(recording, backend="sortingview")
Expand Down
10 changes: 5 additions & 5 deletions examples/how_to/analyse_neuropixels.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
#
# ```python
# # %matplotlib widget
# si.plot_timeseries({'filter':rec1, 'cmr': rec4}, backend='ipywidgets')
# si.plot_traces({'filter':rec1, 'cmr': rec4}, backend='ipywidgets')
# ```
#
# Note that using this ipywidgets make possible to explore different preprocessing chains without saving the entire file to disk.
Expand All @@ -94,17 +94,17 @@
# here we use a static plot using matplotlib backend
fig, axs = plt.subplots(ncols=3, figsize=(20, 10))

si.plot_timeseries(rec1, backend='matplotlib', clim=(-50, 50), ax=axs[0])
si.plot_timeseries(rec4, backend='matplotlib', clim=(-50, 50), ax=axs[1])
si.plot_timeseries(rec, backend='matplotlib', clim=(-50, 50), ax=axs[2])
si.plot_traces(rec1, backend='matplotlib', clim=(-50, 50), ax=axs[0])
si.plot_traces(rec4, backend='matplotlib', clim=(-50, 50), ax=axs[1])
si.plot_traces(rec, backend='matplotlib', clim=(-50, 50), ax=axs[2])
for i, label in enumerate(('filter', 'cmr', 'final')):
axs[i].set_title(label)
# -

# plot some channels
fig, ax = plt.subplots(figsize=(20, 10))
some_chans = rec.channel_ids[[100, 150, 200, ]]
si.plot_timeseries({'filter':rec1, 'cmr': rec4}, backend='matplotlib', mode='line', ax=ax, channel_ids=some_chans)
si.plot_traces({'filter':rec1, 'cmr': rec4}, backend='matplotlib', mode='line', ax=ax, channel_ids=some_chans)


# ### Should we save the preprocessed data to a binary file?
Expand Down
2 changes: 1 addition & 1 deletion examples/how_to/get_started.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
#
# Let's use the `spikeinterface.widgets` module to visualize the traces and the raster plots.

w_ts = sw.plot_timeseries(recording, time_range=(0, 5))
w_ts = sw.plot_traces(recording, time_range=(0, 5))
w_rs = sw.plot_rasters(sorting_true, time_range=(0, 5))

# This is how you retrieve info from a `BaseRecording`...
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@

import spikeinterface.widgets as sw

w_ts = sw.plot_timeseries(recording, time_range=(0, 5))
w_ts = sw.plot_traces(recording, time_range=(0, 5))
w_rs = sw.plot_rasters(sorting, time_range=(0, 5))

plt.show()
10 changes: 5 additions & 5 deletions examples/modules_gallery/widgets/plot_1_rec_gallery.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,22 @@
recording, sorting = se.toy_example(duration=10, num_channels=4, seed=0, num_segments=1)

##############################################################################
# plot_timeseries()
# plot_traces()
# ~~~~~~~~~~~~~~~~~

w_ts = sw.plot_timeseries(recording)
w_ts = sw.plot_traces(recording)

##############################################################################
# We can select time range

w_ts1 = sw.plot_timeseries(recording, time_range=(5, 8))
w_ts1 = sw.plot_traces(recording, time_range=(5, 8))

##############################################################################
# We can color with groups

recording2 = recording.clone()
recording2.set_channel_groups(channel_ids=recording.get_channel_ids(), groups=[0, 0, 1, 1])
w_ts2 = sw.plot_timeseries(recording2, time_range=(5, 8), color_groups=True)
w_ts2 = sw.plot_traces(recording2, time_range=(5, 8), color_groups=True)

##############################################################################
# **Note**: each function returns a widget object, which allows to access the figure and axis.
Expand All @@ -41,7 +41,7 @@
##############################################################################
# We can also use the 'map' mode useful for high channel count

w_ts = sw.plot_timeseries(recording, mode='map', time_range=(5, 8),
w_ts = sw.plot_traces(recording, mode='map', time_range=(5, 8),
show_channel_ids=True, order_channel_by_depth=True)

##############################################################################
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class CompressedBinaryIblExtractorTest(RecordingCommonTestSuite, unittest.TestCa
# ~ import matplotlib.pyplot as plt
# ~ import spikeinterface.widgets as sw
# ~ from probeinterface.plotting import plot_probe
# ~ sw.plot_timeseries(rec)
# ~ sw.plot_traces(rec)
# ~ plot_probe(rec.get_probe())
# ~ plt.show()

Expand Down
8 changes: 4 additions & 4 deletions src/spikeinterface/preprocessing/tests/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,10 @@ def test_filter_opencl():
# rec2_cached0 = rec2.save(chunk_size=1000,verbose=False, progress_bar=True, n_jobs=4)

# import matplotlib.pyplot as plt
# from spikeinterface.widgets import plot_timeseries
# plot_timeseries(rec, segment_index=0)
# plot_timeseries(rec_filtered, segment_index=0)
# plot_timeseries(rec2_cached0, segment_index=0)
# from spikeinterface.widgets import plot_traces
# plot_traces(rec, segment_index=0)
# plot_traces(rec_filtered, segment_index=0)
# plot_traces(rec2_cached0, segment_index=0)
# plt.show()


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_normalize_by_quantile():
rec2.save(verbose=False)

# import matplotlib.pyplot as plt
# from spikeinterface.widgets import plot_timeseries
# from spikeinterface.widgets import plot_traces
# fig, ax = plt.subplots()
# ax.plot(rec.get_traces(segment_index=0)[:, 0], color='g')
# ax.plot(rec2.get_traces(segment_index=0)[:, 0], color='r')
Expand Down
6 changes: 3 additions & 3 deletions src/spikeinterface/preprocessing/tests/test_phase_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ def test_phase_shift():

# ~ import matplotlib.pyplot as plt
# ~ import spikeinterface.full as si
# ~ si.plot_timeseries(rec, segment_index=0, time_range=[0, 10])
# ~ si.plot_timeseries(rec2, segment_index=0, time_range=[0, 10])
# ~ si.plot_timeseries(rec3, segment_index=0, time_range=[0, 10])
# ~ si.plot_traces(rec, segment_index=0, time_range=[0, 10])
# ~ si.plot_traces(rec2, segment_index=0, time_range=[0, 10])
# ~ si.plot_traces(rec3, segment_index=0, time_range=[0, 10])
# ~ plt.show()


Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/preprocessing/tests/test_rectify.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_rectify():
assert traces.shape[1] == 1

# import matplotlib.pyplot as plt
# from spikeinterface.widgets import plot_timeseries
# from spikeinterface.widgets import plot_traces
# fig, ax = plt.subplots()
# ax.plot(rec.get_traces(segment_index=0)[:, 0], color='g')
# ax.plot(rec2.get_traces(segment_index=0)[:, 0], color='r')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def plot_figure_1(benchmark, mode="average", cell_ind="auto"):
)

print(benchmark.recording)
# si.plot_timeseries(benchmark.recording, mode='line', time_range=(times[0]-0.01, times[0] + 0.1), channel_ids=benchmark.recording.channel_ids[:20], ax=axs[0, 1])
# si.plot_traces(benchmark.recording, mode='line', time_range=(times[0]-0.01, times[0] + 0.1), channel_ids=benchmark.recording.channel_ids[:20], ax=axs[0, 1])
# axs[0, 1].set_ylabel('Neurons')

# si.plot_spikes_on_traces(benchmark.waveforms, unit_ids=[unit_id], time_range=(times[0]-0.01, times[0] + 0.1), unit_colors={unit_id : 'r'}, ax=axs[0, 1],
Expand Down
34 changes: 0 additions & 34 deletions src/spikeinterface/widgets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,3 @@
# check if backend are available
try:
import matplotlib

HAVE_MPL = True
except:
HAVE_MPL = False

try:
import sortingview

HAVE_SV = True
except:
HAVE_SV = False

try:
import ipywidgets

HAVE_IPYW = True
except:
HAVE_IPYW = False


# theses import make the Widget.resgister() at import time
if HAVE_MPL:
import spikeinterface.widgets.matplotlib

if HAVE_SV:
import spikeinterface.widgets.sortingview

if HAVE_IPYW:
import spikeinterface.widgets.ipywidgets

# when importing widget list backend are already registered
from .widget_list import *

# general functions
Expand Down
14 changes: 1 addition & 13 deletions src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,20 @@
# basics
# from .timeseries import plot_timeseries, TimeseriesWidget
# from .timeseries import plot_timeseries, TracesWidget
from .rasters import plot_rasters, RasterWidget
from .probemap import plot_probe_map, ProbeMapWidget

# isi/ccg/acg
from .isidistribution import plot_isi_distribution, ISIDistributionWidget

# from .correlograms import (plot_crosscorrelograms, CrossCorrelogramsWidget,
# plot_autocorrelograms, AutoCorrelogramsWidget)

# peak activity
from .activity import plot_peak_activity_map, PeakActivityMapWidget

# waveform/PC related
# from .unitwaveforms import plot_unit_waveforms, plot_unit_templates
# from .unitwaveformdensitymap import plot_unit_waveform_density_map, UnitWaveformDensityMapWidget
# from .amplitudes import plot_amplitudes_distribution
from .principalcomponent import plot_principal_component

# from .unitlocalization import plot_unit_localization, UnitLocalizationWidget

# units on probe
from .unitprobemap import plot_unit_probe_map, UnitProbeMapWidget

# from .depthamplitude import plot_units_depth_vs_amplitude

# comparison related
from .confusionmatrix import plot_confusion_matrix, ConfusionMatrixWidget
from .agreementmatrix import plot_agreement_matrix, AgreementMatrixWidget
Expand Down Expand Up @@ -77,8 +67,6 @@
ComparisonPerformancesByTemplateSimilarity,
)

# unit summary
# from .unitsummary import plot_unit_summary, UnitSummaryWidget

# unit presence
from .presence import plot_presence, PresenceWidget
Expand Down
Loading

0 comments on commit 3d14d1c

Please sign in to comment.