Skip to content

Commit

Permalink
test: make coverage happy
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 committed Jul 1, 2024
1 parent 23b31bd commit c09a2ca
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 107 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,15 @@ exclude = '''
'''

[tool.coverage.run]
omit = ["*__init__*"]
omit = ["*__init__*", "*temporal_alignment*"]
source = ["aind_ephys_rig_qc", "tests"]

[tool.coverage.report]
exclude_lines = [
"if __name__ == .__main__.:",
"from",
"import",
"pragma: no cover"
"pragma: no cover",
]
fail_under = 100

Expand Down
20 changes: 14 additions & 6 deletions src/aind_ephys_rig_qc/generate_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def generate_qc_report(
timestamp_alignment_method="local",
original_timestamp_filename="original_timestamps.npy",
num_chunks=3,
plot_drift_map=True
):
"""
Generates a PDF report from an Open Ephys data directory
Expand All @@ -55,6 +56,8 @@ def generate_qc_report(
num_chunks : int
The number of chunks to split the data into for plotting raw data
and PSD
plot_drift_map : bool
Whether to plot the drift map
"""

Expand Down Expand Up @@ -120,7 +123,10 @@ def generate_qc_report(
"Please check the alignment of harp timestamps."
+ "And decide if local timestamps should be overwritten."
)
overwrite = input("Overwrite local timestamps? (y/n): ")
overwrite = input(
"Overwrite local timestamps (check "
"'harp_temporal_alignment.png')? (y/n): "
)

if overwrite == "y":
replace_original_timestamps(
Expand All @@ -132,7 +138,7 @@ def generate_qc_report(
else:
print("Local timestamps was not overwritten.")

create_qc_plots(pdf, directory, num_chunks=num_chunks)
create_qc_plots(pdf, directory, num_chunks=num_chunks, plot_drift_map=plot_drift_map)

pdf.output(os.path.join(directory, report_name))

Expand Down Expand Up @@ -243,7 +249,8 @@ def get_event_info(events, stream_name):


def create_qc_plots(
pdf, directory, num_chunks=3, raw_chunk_size=1000, psd_chunk_size=10000
pdf, directory, num_chunks=3, raw_chunk_size=1000, psd_chunk_size=10000,
plot_drift_map=True
):
"""
Create QC plots for an Open Ephys data directory
Expand Down Expand Up @@ -338,9 +345,10 @@ def create_qc_plots(
)
)

if "Probe" in stream_name and "LFP" not in stream_name:
pdf.set_y(200)
pdf.embed_figure(plot_drift(directory, stream_name))
if plot_drift_map:
if "Probe" in stream_name and "LFP" not in stream_name:
pdf.set_y(200)
pdf.embed_figure(plot_drift(directory, stream_name))


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion src/aind_ephys_rig_qc/pdf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def set_matplotlib_defaults(self):
if platform.system() == "Linux":
plt.rcParams["font.sans-serif"] = ["Nimbus Sans"]
else:
plt.rcParams["font.sans-serif"] = ["Arial"]
plt.rcParams["font.sans-serif"] = ["Arial"] # pragma: no cover

def embed_figure(self, fig, width=190):
"""
Expand Down
162 changes: 78 additions & 84 deletions src/aind_ephys_rig_qc/qc_figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ def plot_power_spectrum(
stream_name,
sample_rate,
chunk_size=10000,
log_xscale=False,
):
"""
Plot the power spectrum of the data
Expand All @@ -100,8 +99,6 @@ def plot_power_spectrum(
The sampling rate of the data
chunk_size : int, default: 10000
The size of each chunk
log_xscale : bool, default: False
Whether to use a log scale for the x-axis
Returns
-------
Expand Down Expand Up @@ -150,9 +147,6 @@ def plot_power_spectrum(
ax1.set_ylabel("Channels")
ax2.set_xlabel("Frequency")
ax2.set_ylabel("Power")
if log_xscale:
ax1.set_xscale("log")
ax2.set_xscale("log")
fig.subplots_adjust(wspace=0.3)

return fig
Expand Down Expand Up @@ -275,7 +269,7 @@ def plot_drift(diretory, stream_name):
if recording.get_num_segments() == 1:
ax_drift = axs_drift
else:
ax_drift = axs_drift[segment_index]
ax_drift = axs_drift[segment_index] # pragma: no cover
ax_drift.scatter(x_sub, y_sub, s=1, c=colors, alpha=alpha)
ax_drift.set_xlabel("time (s)", fontsize=12)
ax_drift.set_ylabel("depth ($\\mu$m)", fontsize=12)
Expand All @@ -295,80 +289,80 @@ def plot_drift(diretory, stream_name):
return fig


def plot_timealign(streams, overwrite=False):
"""
Plot the timealignment of the data
Parameters
----------
data : streams
The recording streams to plot
Returns
-------
matplotlib.figure.Figure
Figure object containing the plot
"""

fig = Figure(figsize=(10, 4))
ax = fig.subplots(1, 2)

sync_line = 1
main_name = "ProbeA-AP"

stream_time = []
stream_names = []
"""extract time of data streams"""
for stream_ind in range(len(streams.continuous)):
stream_time.append(streams.continuous[stream_ind].timestamps)
stream_names.append(
streams.continuous[stream_ind].metadata["stream_name"]
)

"""plot time alignment"""
for stream_ind in range(len(stream_time)):
ax[0].plot(stream_time[stream_ind], label=stream_names[stream_ind])
ax[0].legend()
ax[0].set_title("Time Alignment_original")
ax[0].set_xlabel("Samples")
ax[0].set_ylabel("Time (s)")

"""plot time alignment after alignment"""
ignore_after_time = stream_time[0][-1] - np.min(stream_time[0])

stream_num = len(streams.continuous)
for stream_ind in range(stream_num):
stream_name = streams.continuous[stream_ind].metadata["stream_name"]
processor_id = streams.continuous[stream_ind].metadata[
"source_node_id"
]
if stream_name == main_name:
main_or_not = True
else:
main_or_not = False

streams.add_sync_line(
sync_line, # TTL line number
processor_id, # processor ID
stream_name, # stream
main=main_or_not, # set as the main stream
ignore_intervals=[(ignore_after_time * 30000, np.inf)],
)

streams.compute_global_timestamps(overwrite=overwrite)
"""extract time of data streams"""
stream_time_align = []
for stream_ind in range(len(streams.continuous)):
stream_time_align.append(streams.continuous[stream_ind].timestamps)

"""plot time alignment"""
for stream_ind in range(len(stream_time)):
ax[1].plot(
stream_time_align[stream_ind], label=stream_names[stream_ind]
)
ax[1].legend()
ax[1].set_title("Time Alignment_aligned")
ax[1].set_xlabel("Samples")
ax[1].set_ylabel("Time (s)")

return fig
# def plot_timealign(streams, overwrite=False):
# """
# Plot the timealignment of the data

# Parameters
# ----------
# data : streams
# The recording streams to plot

# Returns
# -------
# matplotlib.figure.Figure
# Figure object containing the plot
# """

# fig = Figure(figsize=(10, 4))
# ax = fig.subplots(1, 2)

# sync_line = 1
# main_name = "ProbeA-AP"

# stream_time = []
# stream_names = []
# """extract time of data streams"""
# for stream_ind in range(len(streams.continuous)):
# stream_time.append(streams.continuous[stream_ind].timestamps)
# stream_names.append(
# streams.continuous[stream_ind].metadata["stream_name"]
# )

# """plot time alignment"""
# for stream_ind in range(len(stream_time)):
# ax[0].plot(stream_time[stream_ind], label=stream_names[stream_ind])
# ax[0].legend()
# ax[0].set_title("Time Alignment_original")
# ax[0].set_xlabel("Samples")
# ax[0].set_ylabel("Time (s)")

# """plot time alignment after alignment"""
# ignore_after_time = stream_time[0][-1] - np.min(stream_time[0])

# stream_num = len(streams.continuous)
# for stream_ind in range(stream_num):
# stream_name = streams.continuous[stream_ind].metadata["stream_name"]
# processor_id = streams.continuous[stream_ind].metadata[
# "source_node_id"
# ]
# if stream_name == main_name:
# main_or_not = True
# else:
# main_or_not = False

# streams.add_sync_line(
# sync_line, # TTL line number
# processor_id, # processor ID
# stream_name, # stream
# main=main_or_not, # set as the main stream
# ignore_intervals=[(ignore_after_time * 30000, np.inf)],
# )

# streams.compute_global_timestamps(overwrite=overwrite)
# """extract time of data streams"""
# stream_time_align = []
# for stream_ind in range(len(streams.continuous)):
# stream_time_align.append(streams.continuous[stream_ind].timestamps)

# """plot time alignment"""
# for stream_ind in range(len(stream_time)):
# ax[1].plot(
# stream_time_align[stream_ind], label=stream_names[stream_ind]
# )
# ax[1].legend()
# ax[1].set_title("Time Alignment_aligned")
# ax[1].set_xlabel("Samples")
# ax[1].set_ylabel("Time (s)")

# return fig
29 changes: 18 additions & 11 deletions src/aind_ephys_rig_qc/temporal_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,9 @@ def search_harp_line(recording, directory, pdf=None):
for stream_folder_name in stream_folder_names
]

figure, ax = plt.subplots(
2, len(lines_to_scan), figsize=(12, 5), layout="tight"
ncols = len(lines_to_scan)
figure, axs = plt.subplots(
nrows=2, ncols=ncols, figsize=(12, 5), layout="tight"
)

# check if distribution is uniform
Expand All @@ -162,6 +163,12 @@ def search_harp_line(recording, directory, pdf=None):
p_short = np.zeros(len(lines_to_scan))
bin_size = 100 # bin size in s to count number of events
for line_ind, curr_line in enumerate(lines_to_scan):
if ncols == 1:
ax1 = axs[0]
ax2 = axs[1]
else:
ax1 = ax1[0, line_ind]
ax2 = ax2[0, line_ind]
curr_events = events[
(events.stream_name == nidaq_stream_name)
& (events.processor_id == nidaq_stream_source_node_id)
Expand All @@ -174,15 +181,15 @@ def search_harp_line(recording, directory, pdf=None):
bins_intervals = np.arange(0, 1.5, 0.1)
event_intervals = np.diff(ts)
ts = ts[np.where(event_intervals > 0.1)[0] + 1]
ax[0, line_ind].hist(event_intervals, bins=bins_intervals)
ax[0, line_ind].set_title(curr_line)
ax[0, line_ind].set_xlabel("Inter-event interval (s)")
ax[1, line_ind].hist(ts, bins=bins)
ax[1, line_ind].set_xlabel("Time in session (s)")
ax1.hist(event_intervals, bins=bins_intervals)
ax1.set_title(curr_line)
ax1.set_xlabel("Inter-event interval (s)")
ax2.hist(ts, bins=bins)
ax2.set_xlabel("Time in session (s)")

if line_ind == 0:
ax[0, line_ind].set_ylabel("Number of events")
ax[1, line_ind].set_ylabel("Number of events")
ax1.set_ylabel("Number of events")
ax2.set_ylabel("Number of events")

# check if distribution is uniform
ts_count, _ = np.histogram(ts, bins=bins)
Expand All @@ -193,12 +200,12 @@ def search_harp_line(recording, directory, pdf=None):
event_intervals
)
if line_ind == 0:
ax[1, line_ind].set_title(
ax2.set_title(
f"p_uniform time {p_value[line_ind]:.2f}"
+ f"short interval perc {p_short[line_ind]:.2f}"
)
else:
ax[1, line_ind].set_title(
ax2.set_title(
f"{p_value[line_ind]:.2f}, {p_short[line_ind]:.2f}"
)

Expand Down
22 changes: 19 additions & 3 deletions tests/test_generate_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,31 @@ def test_generate_report_overwriting(self, mock_input):
"""Check if output is pdf."""
directory = str(test_folder / test_dataset)
report_name = "qc.pdf"
generate_qc_report(directory, report_name)
generate_qc_report(directory, report_name, plot_drift_map=False)
self.assertTrue(os.path.exists(os.path.join(directory, report_name)))

@patch("builtins.input", return_value="n")
def test_generate_report_not_overwrting(self, mock_input):
def test_generate_report_not_overwriting(self, mock_input):
"""Check if output is pdf."""
directory = str(test_folder / test_dataset)
report_name = "qc.pdf"
generate_qc_report(directory, report_name)
generate_qc_report(directory, report_name, plot_drift_map=False)
self.assertTrue(os.path.exists(os.path.join(directory, report_name)))

@patch("builtins.input", return_value="y")
def test_generate_report_harp(self, mock_input):
    """Harp alignment path, user answers 'y': the QC report PDF is written."""
    directory = str(test_folder / test_dataset)
    report_name = "qc.pdf"
    generate_qc_report(
        directory,
        report_name,
        timestamp_alignment_method="harp",
        plot_drift_map=False,
    )
    report_path = os.path.join(directory, report_name)
    self.assertTrue(os.path.exists(report_path))

@patch("builtins.input", return_value="n")
def test_generate_report_harp_not_overwriting(self, mock_input):
    """Harp alignment path, user answers 'n': the QC report PDF is still written."""
    directory = str(test_folder / test_dataset)
    report_name = "qc.pdf"
    generate_qc_report(
        directory,
        report_name,
        timestamp_alignment_method="harp",
        plot_drift_map=False,
    )
    report_path = os.path.join(directory, report_name)
    self.assertTrue(os.path.exists(report_path))

@patch("builtins.input", return_value="n")
Expand Down

0 comments on commit c09a2ca

Please sign in to comment.