Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
weiglszonja committed Oct 3, 2023
1 parent 9607007 commit 798994c
Showing 1 changed file with 80 additions and 19 deletions.
99 changes: 80 additions & 19 deletions tests/test_suite2psegmentationextractor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
import shutil
import tempfile
from pathlib import Path

import numpy as np
from hdmf.testing import TestCase
from numpy.testing import assert_array_equal

from roiextractors import Suite2pSegmentationExtractor
from tests.setup_paths import OPHYS_DATA_PATH
Expand All @@ -8,32 +14,87 @@ class TestSuite2pSegmentationExtractor(TestCase):
@classmethod
def setUpClass(cls):
folder_path = str(OPHYS_DATA_PATH / "segmentation_datasets" / "suite2p")
cls.available_streams = dict(
channel_streams=["chan1", "chan2"],
plane_streams=dict(
chan1=["chan1_combined", "chan1_plane0", "chan1_plane1"], chan2=["chan2_plane0", "chan2_plane1"]
),
)
cls.channel_names = ["chan1", "chan2"]
cls.plane_names = ["plane0", "plane1"]

cls.folder_path = folder_path
cls.folder_path = Path(folder_path)

extractor = Suite2pSegmentationExtractor(folder_path=folder_path, stream_name="chan1_plane1")
extractor = Suite2pSegmentationExtractor(folder_path=folder_path, channel_name="chan1", plane_name="plane0")
cls.extractor = extractor

def test_stream_names(self):
self.assertEqual(Suite2pSegmentationExtractor.get_streams(folder_path=self.folder_path), self.available_streams)
cls.test_dir = Path(tempfile.mkdtemp())

def test_multi_stream_warns(self):
exc_msg = "More than one channel is detected! Please specify which stream you wish to load with the `stream_name` argument. To see what streams are available, call `Suite2pSegmentationExtractor.get_streams(folder_path=...)`."
with self.assertRaisesWith(exc_type=ValueError, exc_msg=exc_msg):
cls.first_channel_raw_traces = np.load(cls.folder_path / "plane0" / "F.npy").T
cls.second_channel_raw_traces = np.load(cls.folder_path / "plane0" / "F_chan2.npy").T

@classmethod
def tearDownClass(cls):
# remove the temporary directory and its contents
shutil.rmtree(cls.test_dir)

def test_channel_names(self):
self.assertEqual(Suite2pSegmentationExtractor.get_available_channels(folder_path=self.folder_path), self.channel_names)

def test_plane_names(self):
self.assertEqual(Suite2pSegmentationExtractor.get_available_planes(folder_path=self.folder_path), self.plane_names)

def test_multi_channel_warns(self):
exc_msg = "More than one channel is detected! Please specify which channel you wish to load with the `channel_name` argument. To see what channels are available, call `Suite2pSegmentationExtractor.get_available_channels(folder_path=...)`."
with self.assertWarnsWith(warn_type=UserWarning, exc_msg=exc_msg):
Suite2pSegmentationExtractor(folder_path=self.folder_path)

def test_invalid_stream_raises(self):
exc_msg = "The selected stream 'plane0' is not a valid stream name. To see what streams are available, call `Suite2pSegmentationExtractor.get_streams(folder_path=...)`."
def test_multi_plane_warns(self):
exc_msg = "More than one plane is detected! Please specify which plane you wish to load with the `plane_name` argument. To see what planes are available, call `Suite2pSegmentationExtractor.get_available_planes(folder_path=...)`."
with self.assertWarnsWith(warn_type=UserWarning, exc_msg=exc_msg):
Suite2pSegmentationExtractor(folder_path=self.folder_path, channel_name="chan2")

def test_incorrect_plane_name_raises(self):
exc_msg = "The selected plane 'plane2' is not a valid plane name. To see what planes are available, call `Suite2pSegmentationExtractor.get_available_planes(folder_path=...)`."
with self.assertRaisesWith(exc_type=ValueError, exc_msg=exc_msg):
Suite2pSegmentationExtractor(folder_path=self.folder_path, stream_name="plane0")
Suite2pSegmentationExtractor(folder_path=self.folder_path, plane_name="plane2")

def test_incorrect_stream_raises(self):
exc_msg = "The selected stream 'chan1_plane2' is not in the available plane_streams '['chan1_combined', 'chan1_plane0', 'chan1_plane1']'!"
def test_incorrect_channel_name_raises(self):
exc_msg = "The selected channel 'test' is not a valid channel name. To see what channels are available, call `Suite2pSegmentationExtractor.get_available_channels(folder_path=...)`."
with self.assertRaisesWith(exc_type=ValueError, exc_msg=exc_msg):
Suite2pSegmentationExtractor(folder_path=self.folder_path, stream_name="chan1_plane2")
Suite2pSegmentationExtractor(folder_path=self.folder_path, channel_name="test")

def test_incomplete_extractor_load(self):
"""Check extractor can be initialized when not all traces are available."""
# temporary directory for testing assertion when some of the files are missing
files_to_copy = ["stat.npy", "ops.npy", "iscell.npy", "Fneu.npy"]
(self.test_dir / "plane0").mkdir(exist_ok=True)
[shutil.copy(Path(self.folder_path) / "plane0" / file, self.test_dir / "plane0" / file) for file in files_to_copy]

extractor = Suite2pSegmentationExtractor(folder_path=self.test_dir)
traces_dict = extractor.get_traces_dict()
self.assertEqual(traces_dict["raw"], None)
self.assertEqual(traces_dict["dff"], None)
self.assertEqual(traces_dict["deconvolved"], None)

def test_image_size(self):
self.assertEqual(self.extractor.get_image_size(), (128, 128))

def test_num_frames(self):
self.assertEqual(self.extractor.get_num_frames(), 250)

def test_sampling_frequency(self):
self.assertEqual(self.extractor.get_sampling_frequency(), 10.0)

def test_channel_names(self):
self.assertEqual(self.extractor.get_channel_names(), ["Chan1"])

def test_num_channels(self):
self.assertEqual(self.extractor.get_num_channels(), 1)

def test_num_rois(self):
self.assertEqual(self.extractor.get_num_rois(), 15)

def test_extractor_first_channel_raw_traces(self):
assert_array_equal(self.extractor.get_traces(name="raw"), self.first_channel_raw_traces)

def test_extractor_second_channel(self):
extractor = Suite2pSegmentationExtractor(folder_path=self.folder_path, channel_name="chan2")
self.assertEqual(extractor.get_channel_names(), ["Chan2"])
traces = extractor.get_traces_dict()
self.assertEqual(traces["deconvolved"], None)
assert_array_equal(traces["raw"], self.second_channel_raw_traces)

0 comments on commit 798994c

Please sign in to comment.