diff --git a/tests/conftest.py b/tests/conftest.py index 4aa6241b..9a191749 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,12 +24,13 @@ def dj_config(): "database.user": os.environ.get("DJ_USER") or dj.config["database.user"], } ) + os.environ["DATABASE_PREFIX"] = "test_" return @pytest.fixture(autouse=True, scope="session") def pipeline(): - import tutorial_pipeline as pipeline + from . import tutorial_pipeline as pipeline yield { "lab": pipeline.lab, @@ -37,11 +38,17 @@ def pipeline(): "session": pipeline.session, "probe": pipeline.probe, "ephys": pipeline.ephys, + "ephys_report": pipeline.ephys_report, "get_ephys_root_data_dir": pipeline.get_ephys_root_data_dir, } if _tear_down: - pipeline.subject.Subject.delete() + pipeline.ephys_report.schema.drop() + pipeline.ephys.schema.drop() + pipeline.probe.schema.drop() + pipeline.session.schema.drop() + pipeline.subject.schema.drop() + pipeline.lab.schema.drop() @pytest.fixture(scope="session") @@ -53,37 +60,46 @@ def insert_upstreams(pipeline): ephys = pipeline["ephys"] subject.Subject.insert1( - dict(subject="subject5", subject_birth_date="2023-01-01", sex="U") + dict(subject="subject5", subject_birth_date="2023-01-01", sex="U"), + skip_duplicates=True, ) session_key = dict(subject="subject5", session_datetime="2023-01-01 00:00:00") + session.Session.insert1(session_key, skip_duplicates=True) session_dir = "raw/subject5/session1" - session.SessionDirectory.insert1(dict(**session_key, session_dir=session_dir)) - probe.Probe.insert1(dict(probe="714000838", probe_type="neuropixels 1.0 - 3B")) + session.SessionDirectory.insert1( + dict(**session_key, session_dir=session_dir), skip_duplicates=True + ) + probe.Probe.insert1( + dict(probe="714000838", probe_type="neuropixels 1.0 - 3B"), skip_duplicates=True + ) ephys.ProbeInsertion.insert1( dict( - session_key, + **session_key, insertion_number=1, probe="714000838", - ) + ), + skip_duplicates=True, ) - yield - if _tear_down: - subject.Subject.delete() - probe.Probe.delete() + return @pytest.fixture(scope="session") -def populate_ephys_recording(pipeline, insert_upstream): +def populate_ephys_recording(pipeline, insert_upstreams): ephys = pipeline["ephys"] ephys.EphysRecording.populate() - yield + return - if _tear_down: - ephys.EphysRecording.delete() + +@pytest.fixture(scope="session") +def populate_lfp(pipeline, insert_upstreams): + ephys = pipeline["ephys"] + ephys.LFP.populate() + + return @pytest.fixture(scope="session") @@ -129,25 +145,20 @@ def insert_clustering_task(pipeline, populate_ephys_recording): paramset_idx=0, task_mode="load", # load or trigger clustering_output_dir="processed/subject5/session1/probe_1/kilosort2-5_1", - ) + ), + skip_duplicates=True, ) - yield - - if _tear_down: - ephys.ClusteringParamSet.delete() + return @pytest.fixture(scope="session") -def processing(pipeline, populate_ephys_recording): +def processing(pipeline, insert_clustering_task): ephys = pipeline["ephys"] + ephys.Clustering.populate() ephys.CuratedClustering.populate() - ephys.LFP.populate() ephys.WaveformSet.populate() + ephys.QualityMetrics.populate() - yield - - if _tear_down: - ephys.CuratedClustering.delete() - ephys.LFP.delete() + return diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py new file mode 100644 index 00000000..fe809ccc --- /dev/null +++ b/tests/test_pipeline.py @@ -0,0 +1,132 @@ +import numpy as np +import pandas as pd +import datetime +from uuid import UUID + + +def test_generate_pipeline(pipeline): + subject = pipeline["subject"] + session = pipeline["session"] + ephys = pipeline["ephys"] + probe = pipeline["probe"] + + # test elements connection from lab, subject to Session + assert subject.Subject.full_table_name in session.Session.parents() + + # test elements connection from Session to probe, ephys, ephys_report + assert session.Session.full_table_name in ephys.ProbeInsertion.parents() + assert probe.Probe.full_table_name in ephys.ProbeInsertion.parents() + assert "spike_times" in (ephys.CuratedClustering.Unit.heading.secondary_attributes) + + +def test_insert_upstreams(pipeline, insert_upstreams): + """Check number of subjects inserted into the `subject.Subject` table""" + subject = pipeline["subject"] + session = pipeline["session"] + probe = pipeline["probe"] + ephys = pipeline["ephys"] + + assert len(subject.Subject()) == 1 + assert len(session.Session()) == 1 + assert len(probe.Probe()) == 1 + assert len(ephys.ProbeInsertion()) == 1 + + +def test_populate_ephys_recording(pipeline, populate_ephys_recording): + ephys = pipeline["ephys"] + + assert ephys.EphysRecording.fetch1() == { + "subject": "subject5", + "session_datetime": datetime.datetime(2023, 1, 1, 0, 0), + "insertion_number": 1, + "electrode_config_hash": UUID("8d4cc6d8-a02d-42c8-bf27-7459c39ea0ee"), + "acq_software": "SpikeGLX", + "sampling_rate": 30000.0, + "recording_datetime": datetime.datetime(2018, 7, 3, 20, 32, 28), + "recording_duration": 338.666, + } + assert ( + ephys.EphysRecording.EphysFile.fetch1("file_path") + == "raw/subject5/session1/probe_1/npx_g0_t0.imec.ap.meta" + ) + + +def test_populate_lfp(pipeline, populate_lfp): + ephys = pipeline["ephys"] + + assert np.mean(ephys.LFP.fetch1("lfp_mean")) == -716.0220556825378 + assert len((ephys.LFP.Electrode).fetch("electrode")) == 43 + + +def test_insert_clustering_task(pipeline, insert_clustering_task): + ephys = pipeline["ephys"] + + assert ephys.ClusteringParamSet.fetch1("param_set_hash") == UUID( + "de78cee1-526f-319e-b6d5-8a2ba04963d8" + ) + + assert ephys.ClusteringTask.fetch1() == { + "subject": "subject5", + "session_datetime": datetime.datetime(2023, 1, 1, 0, 0), + "insertion_number": 1, + "paramset_idx": 0, + "clustering_output_dir": "processed/subject5/session1/probe_1/kilosort2-5_1", + "task_mode": "load", + } + + +def test_processing(pipeline, processing): + + ephys = pipeline["ephys"] + + # test ephys.CuratedClustering + assert len(ephys.CuratedClustering.Unit & 'cluster_quality_label = "good"') == 176 + assert np.sum(ephys.CuratedClustering.Unit.fetch("spike_count")) == 328167 + # test ephys.WaveformSet + waveforms = np.vstack( + (ephys.WaveformSet.PeakWaveform).fetch("peak_electrode_waveform") + ) + assert waveforms.shape == (227, 82) + + # test ephys.QualityMetrics + cluster_df = (ephys.QualityMetrics.Cluster).fetch(format="frame", order_by="unit") + waveform_df = (ephys.QualityMetrics.Waveform).fetch(format="frame", order_by="unit") + test_df = pd.concat([cluster_df, waveform_df], axis=1).reset_index() + test_value = test_df.select_dtypes(include=[np.number]).mean().values + + assert np.allclose( + test_value, + np.array( + [ + 1.00000000e00, + 0.00000000e00, + 1.13000000e02, + 4.26880089e00, + 1.24162431e00, + 7.17929515e-01, + 4.41633793e-01, + 3.08736082e-01, + 1.24039274e15, + 1.66763828e-02, + 4.33231948e00, + 7.12304747e-01, + 1.48995215e-02, + 7.73432472e-02, + 5.06451613e00, + 7.79528634e00, + 6.30182452e-01, + 1.19562726e02, + 7.90175419e-01, + np.nan, + 8.78436780e-01, + 1.08028193e-01, + -5.19418717e-02, + 2.36035242e02, + 7.48443665e-02, + 2.77550214e-02, + ] + ), + rtol=1e-03, + atol=1e-03, + equal_nan=True, + ) diff --git a/tests/tutorial_pipeline.py b/tests/tutorial_pipeline.py index 8f4ed7f7..fc92280b 100644 --- a/tests/tutorial_pipeline.py +++ b/tests/tutorial_pipeline.py @@ -3,7 +3,7 @@ import datajoint as dj from element_animal import subject from element_animal.subject import Subject -from element_array_ephys import probe, ephys_no_curation as ephys +from element_array_ephys import probe, ephys_no_curation as ephys, ephys_report from element_lab import lab from element_lab.lab import Lab, Location, Project, Protocol, Source, User from element_lab.lab import Device as Equipment