Merge pull request #3032 from h-mayorquin/support_numpy2.0
Numpy 2.0 cap: fix most egregious deprecated behavior and cap version
alejoe91 authored Jun 18, 2024
2 parents 90a2474 + 14970e1 commit b597a83
Showing 9 changed files with 36 additions and 27 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -20,7 +20,7 @@ classifiers = [


dependencies = [
"numpy",
"numpy>=1.26, <2.0", # 1.20 np.ptp, 1.26 for avoiding pickling errors when numpy >2.0
"threadpoolctl>=3.0.0",
"tqdm",
"zarr>=2.16,<2.18",
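
For context, a minimal sketch (not part of this diff) of how the pinned range above can be checked against the installed NumPy at runtime, using only NumPy's own version helper:

import numpy as np
from numpy.lib import NumpyVersion

# The pyproject.toml pin translates to: 1.26 <= numpy < 2.0
installed = NumpyVersion(np.__version__)
if installed < NumpyVersion("1.26.0") or installed >= NumpyVersion("2.0.0"):
    raise RuntimeError(f"numpy {np.__version__} is outside the supported range >=1.26,<2.0")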
7 changes: 6 additions & 1 deletion src/spikeinterface/core/core_tools.py
@@ -83,7 +83,12 @@ def default(self, obj):
if isinstance(obj, np.generic):
return obj.item()

if np.issctype(obj): # Cast numpy datatypes to their names
# Standard numpy dtypes like np.dtype('int32') are transformed this way
if isinstance(obj, np.dtype):
return np.dtype(obj).name

# This transforms the dtype class to its canonical string representation (e.g. np.int32 -> 'int32')
if isinstance(obj, type) and issubclass(obj, np.generic):
return np.dtype(obj).name

if isinstance(obj, np.ndarray):
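
A standalone sketch of the same dispatch logic, assuming a bare json.JSONEncoder subclass rather than SpikeInterface's SIJsonEncoder:

import json
import numpy as np

class DtypeEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.generic):          # numpy scalar, e.g. np.float32(1.5)
            return obj.item()
        if isinstance(obj, np.dtype):            # dtype instance, e.g. np.dtype("int32")
            return np.dtype(obj).name
        if isinstance(obj, type) and issubclass(obj, np.generic):  # dtype class, e.g. np.int32
            return np.dtype(obj).name
        return super().default(obj)

print(json.dumps({"dtype": np.int32, "value": np.float32(1.5)}, cls=DtypeEncoder))
# -> {"dtype": "int32", "value": 1.5}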
1 change: 0 additions & 1 deletion src/spikeinterface/core/tests/test_jsonification.py
@@ -122,7 +122,6 @@ def test_numpy_dtype_alises_encoding():
# People tend to use this as a dtype instead of the proper classes
json.dumps(np.int32, cls=SIJsonEncoder)
json.dumps(np.float32, cls=SIJsonEncoder)
json.dumps(np.bool_, cls=SIJsonEncoder) # Note that np.bool was deprecated in numpy 1.20.0


def test_recording_encoding(numpy_generated_recording):
@@ -36,7 +36,7 @@ def test_fetch_templates_database_info():
def test_query_templates_from_database():
templates_info = fetch_templates_database_info()

templates_info = templates_info.iloc[::15]
templates_info = templates_info.iloc[[1, 3, 5]]
num_selected = len(templates_info)

templates = query_templates_from_database(templates_info)
@@ -8,14 +8,7 @@
import spikeinterface.extractors as se
from spikeinterface.core import generate_recording
import spikeinterface.widgets as sw

try:
import spikeglx
import neurodsp.voltage as voltage

HAVE_IBL_NPIX = True
except ImportError:
HAVE_IBL_NPIX = False
import importlib.util

ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS"))

@@ -31,7 +24,10 @@
# ----------------------------------------------------------------------------------------------------------------------


@pytest.mark.skipif(not HAVE_IBL_NPIX or ON_GITHUB, reason="Only local. Requires ibl-neuropixel install")
@pytest.mark.skipif(
importlib.util.find_spec("neurodsp") is not None or importlib.util.find_spec("spikeglx") or ON_GITHUB,
reason="Only local. Requires ibl-neuropixel install",
)
@pytest.mark.parametrize("lagc", [False, 1, 300])
def test_highpass_spatial_filter_real_data(lagc):
"""
@@ -56,6 +52,9 @@ def test_highpass_spatial_filter_real_data(lagc):
use DEBUG = true to visualise.
"""
import spikeglx
import neurodsp.voltage as voltage

options = dict(lagc=lagc, ntr_pad=25, ntr_tap=50, butter_kwargs=None)
print(options)

@@ -146,6 +145,8 @@ def get_ibl_si_data():
"""
Set fixture scope to session to ensure original data is not changed.
"""
import spikeglx

local_path = si.download_dataset(remote_path="spikeglx/Noise4Sam_g0")
ibl_recording = spikeglx.Reader(
local_path / "Noise4Sam_g0_imec0" / "Noise4Sam_g0_t0.imec0.ap.bin", ignore_warnings=True
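
The pattern used above, shown in isolation: importlib.util.find_spec returns None when a module is not installed, so pytest can skip optional-dependency tests without importing the heavy packages at module level. Module and test names here are only illustrative:

import importlib.util
import pytest

HAVE_IBL_NPIX = (
    importlib.util.find_spec("spikeglx") is not None
    and importlib.util.find_spec("neurodsp") is not None
)

@pytest.mark.skipif(not HAVE_IBL_NPIX, reason="Requires ibl-neuropixel install")
def test_needs_ibl():
    import spikeglx  # deferred import, only runs when the test is not skipped
    ...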
@@ -6,17 +6,10 @@
import spikeinterface.preprocessing as spre
import spikeinterface.extractors as se
from spikeinterface.core.generate import generate_recording
import importlib.util

try:
import spikeglx
import neurodsp.voltage as voltage

HAVE_IBL_NPIX = True
except ImportError:
HAVE_IBL_NPIX = False

ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS"))

DEBUG = False
if DEBUG:
import matplotlib.pyplot as plt
@@ -30,7 +23,10 @@
# -------------------------------------------------------------------------------


@pytest.mark.skipif(not HAVE_IBL_NPIX or ON_GITHUB, reason="Only local. Requires ibl-neuropixel install")
@pytest.mark.skipif(
importlib.util.find_spec("neurodsp") is None or importlib.util.find_spec("spikeglx") is None or ON_GITHUB,
reason="Only local. Requires ibl-neuropixel install",
)
def test_compare_real_data_with_ibl():
"""
Test SI implementation of bad channel interpolation against native IBL.
@@ -43,6 +39,9 @@ def test_compare_real_data_with_ibl():
si_scaled_recordin.get_traces(0) is also close to 1e-2.
"""
# Download and load data
import spikeglx
import neurodsp.voltage as voltage

local_path = si.download_dataset(remote_path="spikeglx/Noise4Sam_g0")
si_recording = se.read_spikeglx(local_path, stream_id="imec0.ap")
ibl_recording = spikeglx.Reader(
@@ -80,7 +79,10 @@ def test_compare_real_data_with_ibl():
assert np.mean(is_close) > 0.999


@pytest.mark.skipif(not HAVE_IBL_NPIX, reason="Requires ibl-neuropixel install")
@pytest.mark.skipif(
importlib.util.find_spec("neurodsp") is None or importlib.util.find_spec("spikeglx") is None,
reason="Requires ibl-neuropixel install",
)
@pytest.mark.parametrize("num_channels", [32, 64])
@pytest.mark.parametrize("sigma_um", [1.25, 40])
@pytest.mark.parametrize("p", [0, -0.5, 1, 5])
Expand All @@ -90,6 +92,8 @@ def test_compare_input_argument_ranges_against_ibl(shanks, p, sigma_um, num_chan
Perform an extended test across a range of function inputs to check
IBL and SI interpolation results match.
"""
import neurodsp.voltage as voltage

recording = generate_recording(num_channels=num_channels, durations=[1])

# distribute default probe locations across 4 shanks if set
4 changes: 2 additions & 2 deletions src/spikeinterface/sortingcomponents/peak_localization.py
@@ -204,7 +204,7 @@ def compute(self, traces, peaks, waveforms):
wf = waveforms[idx][:, :, chan_inds]

if self.feature == "ptp":
wf_data = wf.ptp(axis=1)
wf_data = np.ptp(wf, axis=1)
elif self.feature == "mean":
wf_data = wf.mean(axis=1)
elif self.feature == "energy":
@@ -293,7 +293,7 @@ def compute(self, traces, peaks, waveforms):

wf = waveforms[i, :][:, chan_inds]
if self.feature == "ptp":
wf_data = wf.ptp(axis=0)
wf_data = np.ptp(wf, axis=0)
elif self.feature == "energy":
wf_data = np.linalg.norm(wf, axis=0)
elif self.feature == "peak_voltage":
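
The reason for this change, sketched on its own: NumPy 2.0 removed the ndarray.ptp method, while the np.ptp function remains available and gives identical results on older versions:

import numpy as np

wf = np.random.randn(10, 30, 4)        # (peaks, samples, channels)
ptp_per_channel = np.ptp(wf, axis=1)   # works on NumPy 1.x and 2.x
# wf.ptp(axis=1) raises AttributeError on NumPy >= 2.0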
@@ -37,7 +37,7 @@ def test_waveform_thresholder_ptp(
recording, peaks, nodes=pipeline_nodes, job_kwargs=chunk_executor_kwargs
)

data = tresholded_waveforms.ptp(axis=1) / noise_levels
data = np.ptp(tresholded_waveforms, axis=1) / noise_levels
assert np.all(data[data != 0] > 3)


@@ -78,7 +78,7 @@ def __init__(

def compute(self, traces, peaks, waveforms):
if self.feature == "ptp":
wf_data = waveforms.ptp(axis=1) / self.noise_levels
wf_data = np.ptp(waveforms, axis=1) / self.noise_levels
elif self.feature == "mean":
wf_data = waveforms.mean(axis=1) / self.noise_levels
elif self.feature == "energy":
