Commit

Merge branch 'main' into add_template_data_class
h-mayorquin authored Sep 13, 2023
2 parents 107bdf9 + c94c8ae commit 1585f11
Showing 12 changed files with 1,057 additions and 177 deletions.
21 changes: 21 additions & 0 deletions .github/actions/install-wine/action.yml
@@ -0,0 +1,21 @@
name: Install packages
description: This action installs the package and its dependencies for testing

inputs:
python-version:
description: 'Python version to set up'
required: false
os:
description: 'Operating system to set up'
required: false

runs:
using: "composite"
steps:
- name: Install wine (needed for Plexon2)
run: |
sudo rm -f /etc/apt/sources.list.d/microsoft-prod.list
sudo dpkg --add-architecture i386
sudo apt-get update -qq
sudo apt-get install -yqq --allow-downgrades libc6:i386 libgcc-s1:i386 libstdc++6:i386 wine
shell: bash
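
For context (not part of the commit): wine is needed because the Plexon2 reader drives a Windows DLL, even on Linux. A minimal sketch of the read this unlocks, assuming a hypothetical local .pl2 path:

import spikeinterface.extractors as se

# Hypothetical path; on Linux this call goes through the Windows Plexon2 DLL,
# which is why the composite action above installs wine first.
recording = se.read_plexon2("/path/to/session.pl2")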
7 changes: 7 additions & 0 deletions .github/workflows/full-test.yml
@@ -75,6 +75,10 @@ jobs:
echo "Extractors changed"
echo "EXTRACTORS_CHANGED=true" >> $GITHUB_OUTPUT
fi
if [[ $file == *"plexon2"* ]]; then
echo "Plexon2 changed"
echo "PLEXON2_CHANGED=true" >> $GITHUB_OUTPUT
fi
if [[ $file == *"/preprocessing/"* ]]; then
echo "Preprocessing changed"
echo "PREPROCESSING_CHANGED=true" >> $GITHUB_OUTPUT
@@ -122,6 +126,9 @@ jobs:
done
- name: Set execute permissions on run_tests.sh
run: chmod +x .github/run_tests.sh
- name: Install Wine (Plexon2)
if: ${{ steps.modules-changed.outputs.PLEXON2_CHANGED == 'true' }}
uses: ./.github/actions/install-wine
- name: Test core
run: ./.github/run_tests.sh core
- name: Test extractors
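The workflow hunk above gates the wine install on a PLEXON2_CHANGED flag set by substring-matching each changed path. A minimal Python sketch of the same gating logic (the variable names and sample path are illustrative, not part of the workflow):

# `changed_files` is assumed to come from a `git diff --name-only` step.
changed_files = ["src/spikeinterface/extractors/neoextractors/plexon2.py"]
flags = {"EXTRACTORS_CHANGED": False, "PLEXON2_CHANGED": False}
for path in changed_files:
    if "plexon2" in path:
        flags["PLEXON2_CHANGED"] = True  # later enables the install-wine step
    if "/extractors/" in path:
        flags["EXTRACTORS_CHANGED"] = True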
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -6,7 +6,7 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 23.7.0
rev: 23.9.1
hooks:
- id: black
files: ^src/
1 change: 1 addition & 0 deletions pyproject.toml
@@ -68,6 +68,7 @@ extractors = [
"ONE-api>=1.19.1",
"ibllib>=2.21.0",
"pymatreader>=0.0.32", # For cell explorer matlab files
"zugbruecke>=0.2; sys_platform!='win32'", # For plexon2
]

streaming_extractors = [
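The sys_platform marker keeps zugbruecke off Windows installs, where the Plexon2 DLL loads natively. A sketch of the import pattern this dependency enables (the actual wiring lives in neo's Plexon2 reader, not in this commit; the usage shown is an assumption based on zugbruecke's documented ctypes drop-in):

import sys

if sys.platform != "win32":
    # zugbruecke emulates ctypes on Unix by running the DLL under wine
    import zugbruecke.ctypes as ctypes
else:
    import ctypes  # native DLL loading on Windows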
17 changes: 9 additions & 8 deletions src/spikeinterface/core/generate.py
@@ -1,5 +1,5 @@
import math

import warnings
import numpy as np
from typing import Union, Optional, List, Literal

@@ -1037,13 +1037,14 @@ def __init__(
parent_recording: Union[BaseRecording, None] = None,
num_samples: Optional[List[int]] = None,
upsample_vector: Union[List[int], None] = None,
check_borbers: bool = True,
check_borders: bool = False,
) -> None:
templates = np.asarray(templates)
if check_borbers:
# TODO: this should be external to this class. It is not the responsibility of this class to check the templates
if check_borders:
self._check_templates(templates)
# lets test this only once so force check_borbers=false for kwargs
check_borbers = False
# let's test this only once, so force check_borders=False for kwargs
check_borders = False
self.templates = templates

channel_ids = parent_recording.channel_ids if parent_recording is not None else list(range(templates.shape[2]))
@@ -1131,7 +1132,7 @@ def __init__(
"nbefore": nbefore,
"amplitude_factor": amplitude_factor,
"upsample_vector": upsample_vector,
"check_borbers": check_borbers,
"check_borders": check_borders,
}
if parent_recording is None:
self._kwargs["num_samples"] = num_samples
@@ -1144,8 +1145,8 @@ def _check_templates(templates: np.ndarray):
threshold = 0.01 * max_value

if max(np.max(np.abs(templates[:, 0])), np.max(np.abs(templates[:, -1]))) > threshold:
raise Exception(
"Warning!\nYour templates do not go to 0 on the edges in InjectTemplatesRecording.__init__\nPlease make your window bigger."
warnings.warn(
"Warning! Your templates do not go to 0 on the edges in InjectTemplatesRecording. Please make your window bigger."
)


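Two user-visible changes land in generate.py: the misspelled check_borbers argument becomes check_borders (now defaulting to False), and a border violation emits a warning instead of raising an Exception. A minimal sketch of triggering the new warning, with illustrative shapes and values:

import numpy as np
from spikeinterface.core import generate_sorting
from spikeinterface.core.generate import InjectTemplatesRecording

sorting = generate_sorting(num_units=3, sampling_frequency=30000.0, durations=[1.0])
# (num_units, num_samples, num_channels); random templates do not decay to 0
# at the edges, so with check_borders=True this now warns instead of raising
templates = np.random.randn(3, 50, 4).astype("float32")
recording = InjectTemplatesRecording(
    sorting, templates, nbefore=25, num_samples=[30_000], check_borders=True
)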
171 changes: 86 additions & 85 deletions src/spikeinterface/core/tests/test_waveform_tools.py
@@ -8,7 +8,9 @@
from spikeinterface.core import generate_recording, generate_sorting
from spikeinterface.core.waveform_tools import (
extract_waveforms_to_buffers,
) # allocate_waveforms_buffers, distribute_waveforms_to_buffers
extract_waveforms_to_single_buffer,
split_waveforms_by_units,
)


if hasattr(pytest, "global_test_folder"):
@@ -52,96 +54,95 @@ def test_waveform_tools():
unit_ids = sorting.unit_ids

some_job_kwargs = [
{},
{"n_jobs": 1, "chunk_size": 3000, "progress_bar": True},
{"n_jobs": 2, "chunk_size": 3000, "progress_bar": True},
]
some_modes = [
{"mode": "memmap"},
{"mode": "shared_memory"},
]
# if platform.system() != "Windows":
# # shared memory on windows is buggy...
# some_modes.append(
# {
# "mode": "shared_memory",
# }
# )

some_sparsity = [
dict(sparsity_mask=None),
dict(sparsity_mask=np.random.randint(0, 2, size=(unit_ids.size, recording.channel_ids.size), dtype="bool")),
]

# memmap mode
list_wfs = []
list_wfs_dense = []
list_wfs_sparse = []
for j, job_kwargs in enumerate(some_job_kwargs):
wf_folder = cache_folder / f"test_waveform_tools_{j}"
if wf_folder.is_dir():
shutil.rmtree(wf_folder)
wf_folder.mkdir(parents=True)
# wfs_arrays, wfs_arrays_info = allocate_waveforms_buffers(recording, spikes, unit_ids, nbefore, nafter, mode='memmap', folder=wf_folder, dtype=dtype)
# distribute_waveforms_to_buffers(recording, spikes, unit_ids, wfs_arrays_info, nbefore, nafter, return_scaled, **job_kwargs)
wfs_arrays = extract_waveforms_to_buffers(
recording,
spikes,
unit_ids,
nbefore,
nafter,
mode="memmap",
return_scaled=False,
folder=wf_folder,
dtype=dtype,
sparsity_mask=None,
copy=False,
**job_kwargs,
)
for unit_ind, unit_id in enumerate(unit_ids):
wf = wfs_arrays[unit_id]
assert wf.shape[0] == np.sum(spikes["unit_index"] == unit_ind)
list_wfs.append({unit_id: wfs_arrays[unit_id].copy() for unit_id in unit_ids})
_check_all_wf_equal(list_wfs)

# memory
if platform.system() != "Windows":
# shared memory on windows is buggy...
list_wfs = []
for job_kwargs in some_job_kwargs:
# wfs_arrays, wfs_arrays_info = allocate_waveforms_buffers(recording, spikes, unit_ids, nbefore, nafter, mode='shared_memory', folder=None, dtype=dtype)
# distribute_waveforms_to_buffers(recording, spikes, unit_ids, wfs_arrays_info, nbefore, nafter, return_scaled, mode='shared_memory', **job_kwargs)
wfs_arrays = extract_waveforms_to_buffers(
recording,
spikes,
unit_ids,
nbefore,
nafter,
mode="shared_memory",
return_scaled=False,
folder=None,
dtype=dtype,
sparsity_mask=None,
copy=True,
**job_kwargs,
)
for unit_ind, unit_id in enumerate(unit_ids):
wf = wfs_arrays[unit_id]
assert wf.shape[0] == np.sum(spikes["unit_index"] == unit_ind)
list_wfs.append({unit_id: wfs_arrays[unit_id].copy() for unit_id in unit_ids})
# to avoid warning we need to first destroy arrays then sharedmemm object
# del wfs_arrays
# del wfs_arrays_info
_check_all_wf_equal(list_wfs)

# with sparsity
wf_folder = cache_folder / "test_waveform_tools_sparse"
if wf_folder.is_dir():
shutil.rmtree(wf_folder)
wf_folder.mkdir()

sparsity_mask = np.random.randint(0, 2, size=(unit_ids.size, recording.channel_ids.size), dtype="bool")
job_kwargs = {"n_jobs": 1, "chunk_size": 3000, "progress_bar": True}

# wfs_arrays, wfs_arrays_info = allocate_waveforms_buffers(recording, spikes, unit_ids, nbefore, nafter, mode='memmap', folder=wf_folder, dtype=dtype, sparsity_mask=sparsity_mask)
# distribute_waveforms_to_buffers(recording, spikes, unit_ids, wfs_arrays_info, nbefore, nafter, return_scaled, sparsity_mask=sparsity_mask, **job_kwargs)

wfs_arrays = extract_waveforms_to_buffers(
recording,
spikes,
unit_ids,
nbefore,
nafter,
mode="memmap",
return_scaled=False,
folder=wf_folder,
dtype=dtype,
sparsity_mask=sparsity_mask,
copy=False,
**job_kwargs,
)
for k, mode_kwargs in enumerate(some_modes):
for l, sparsity_kwargs in enumerate(some_sparsity):
# print()
# print(job_kwargs, mode_kwargs, 'sparse=', sparsity_kwargs['sparsity_mask'] is None)

if mode_kwargs["mode"] == "memmap":
wf_folder = cache_folder / f"test_waveform_tools_{j}_{k}_{l}"
if wf_folder.is_dir():
shutil.rmtree(wf_folder)
wf_folder.mkdir(parents=True)
wf_file_path = wf_folder / "waveforms_all_units.npy"

mode_kwargs_ = dict(**mode_kwargs)
if mode_kwargs["mode"] == "memmap":
mode_kwargs_["folder"] = wf_folder

wfs_arrays = extract_waveforms_to_buffers(
recording,
spikes,
unit_ids,
nbefore,
nafter,
return_scaled=False,
dtype=dtype,
copy=True,
**sparsity_kwargs,
**mode_kwargs_,
**job_kwargs,
)
for unit_ind, unit_id in enumerate(unit_ids):
wf = wfs_arrays[unit_id]
assert wf.shape[0] == np.sum(spikes["unit_index"] == unit_ind)

if sparsity_kwargs["sparsity_mask"] is None:
list_wfs_dense.append(wfs_arrays)
else:
list_wfs_sparse.append(wfs_arrays)

mode_kwargs_ = dict(**mode_kwargs)
if mode_kwargs["mode"] == "memmap":
mode_kwargs_["file_path"] = wf_file_path

all_waveforms = extract_waveforms_to_single_buffer(
recording,
spikes,
unit_ids,
nbefore,
nafter,
return_scaled=False,
dtype=dtype,
copy=True,
**sparsity_kwargs,
**mode_kwargs_,
**job_kwargs,
)
wfs_arrays = split_waveforms_by_units(
unit_ids, spikes, all_waveforms, sparsity_mask=sparsity_kwargs["sparsity_mask"]
)
if sparsity_kwargs["sparsity_mask"] is None:
list_wfs_dense.append(wfs_arrays)
else:
list_wfs_sparse.append(wfs_arrays)

_check_all_wf_equal(list_wfs_dense)
_check_all_wf_equal(list_wfs_sparse)


if __name__ == "__main__":
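Stripped of the test scaffolding, the new single-buffer path boils down to one extraction plus one regrouping call. A condensed sketch with illustrative parameters, assuming the flat buffer has shape (num_spikes, num_samples, num_channels):

from spikeinterface.core import generate_recording, generate_sorting
from spikeinterface.core.waveform_tools import (
    extract_waveforms_to_single_buffer,
    split_waveforms_by_units,
)

recording = generate_recording(num_channels=4, durations=[5.0])
sorting = generate_sorting(num_units=3, durations=[5.0])
spikes = sorting.to_spike_vector()

# One flat buffer for all spikes; copy=True detaches it from shared memory
all_waveforms = extract_waveforms_to_single_buffer(
    recording, spikes, sorting.unit_ids, 20, 40,
    mode="shared_memory", return_scaled=False, dtype="float32", copy=True,
)
# Regroup per unit; sparsity_mask=None means dense (all channels kept)
wfs_by_unit = split_waveforms_by_units(
    sorting.unit_ids, spikes, all_waveforms, sparsity_mask=None
)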