Skip to content

Commit

Permalink
Check if DI is installed for tests and CI
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 committed Sep 15, 2023
1 parent 6a68907 commit 9eb4e3b
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 3 deletions.
15 changes: 14 additions & 1 deletion .github/workflows/deepinterpolation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,28 @@ jobs:
- uses: actions/setup-python@v4
with:
python-version: '3.9'
- name: Get changed files
id: changed-files
uses: tj-actions/changed-files@v35
- name: Deepinteprolation changes
id: modules-changed
run: |
for file in ${{ steps.changed-files.outputs.all_changed_files }}; do
if [[ $file == *"/deepinterpolation/"* ]]; then
echo "DeepInterpolation changed"
echo "DEEPINTERPOLATION_CHANGED=true" >> $GITHUB_OUTPUT
fi
- name: Install dependencies
if: ${{ steps.modules-changed.outputs.DEEPINTERPOLATION_CHANGED == 'true' }}
run: |
python -m pip install -U pip # Official recommended way
# install deepinteprolation
pip install tensorflow==2.7.0
pip install deepinterpolation@git+https://github.com/AllenInstitute/deepinterpolation.git
pip install protobuf==3.20.*
pip install -e .[full,test_core]
- name: Test core with pytest
- name: Test DeepInterpolation with pytest
if: ${{ steps.modules-changed.outputs.DEEPINTERPOLATION_CHANGED == 'true' }}
run: |
pytest -v src/spikeinterface/preprocessing/deepinterpolation
shell: bash # Necessary for pipeline to work on windows
2 changes: 1 addition & 1 deletion conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
mark_names = ["core", "extractors", "preprocessing", "postprocessing",
"sorters_external", "sorters_internal", "sorters",
"qualitymetrics", "comparison", "curation",
"widgets", "exporters", "sortingcomponents", "deepinterpolation"]
"widgets", "exporters", "sortingcomponents"]


# define global test folder
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@ markers = [
"exporters",
"widgets",
"sortingcomponents",
"deepinterpolation",
"streaming_extractors: extractors that require streaming such as ross and fsspec",
]
filterwarnings =[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,15 @@
from spikeinterface.preprocessing.deepinterpolation import train_deepinterpolation, deepinterpolate


try:
import tensorflow
import deepinterpolation

HAVE_DEEPINTERPOLATION = True
except ImportError:
HAVE_DEEPINTERPOLATION = False


if hasattr(pytest, "global_test_folder"):
cache_folder = pytest.global_test_folder / "deepinterpolation"
else:
Expand All @@ -37,6 +46,7 @@ def recording_and_shape_fixture():
return recording_and_shape()


@pytest.mark.skipif(not HAVE_DEEPINTERPOLATION, reason="requires deepinterpolation")
def test_deepinterpolation_generator_borders(recording_and_shape_fixture):
"""Test that the generator avoids borders in multi-segment and recording lists cases"""
from spikeinterface.preprocessing.deepinterpolation.generators import SpikeInterfaceRecordingGenerator
Expand All @@ -58,6 +68,7 @@ def test_deepinterpolation_generator_borders(recording_and_shape_fixture):
assert len(gen_multi_list.exclude_intervals) == 2 * len(recording_multi_list) + 2


@pytest.mark.skipif(not HAVE_DEEPINTERPOLATION, reason="requires deepinterpolation")
def test_deepinterpolation_training(recording_and_shape_fixture):
recording, desired_shape = recording_and_shape_fixture

Expand All @@ -82,6 +93,7 @@ def test_deepinterpolation_training(recording_and_shape_fixture):
print(model_path)


@pytest.mark.skipif(not HAVE_DEEPINTERPOLATION, reason="requires deepinterpolation")
@pytest.mark.dependency(depends=["test_deepinterpolation_training"])
def test_deepinterpolation_transfer(recording_and_shape_fixture, tmp_path):
recording, desired_shape = recording_and_shape_fixture
Expand Down Expand Up @@ -109,6 +121,7 @@ def test_deepinterpolation_transfer(recording_and_shape_fixture, tmp_path):
print(model_path)


@pytest.mark.skipif(not HAVE_DEEPINTERPOLATION, reason="requires deepinterpolation")
@pytest.mark.dependency(depends=["test_deepinterpolation_training"])
def test_deepinterpolation_inference(recording_and_shape_fixture):
recording, desired_shape = recording_and_shape_fixture
Expand Down

0 comments on commit 9eb4e3b

Please sign in to comment.