Skip to content

Commit

Permalink
Merge pull request #26 from AllenNeuralDynamics/metrics
Browse files Browse the repository at this point in the history
Adding computation and plotting of trial by trial metrics
  • Loading branch information
alexpiet authored Nov 13, 2024
2 parents 5543aa8 + afbf1b0 commit dff9747
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 30 deletions.
61 changes: 61 additions & 0 deletions src/aind_dynamic_foraging_basic_analysis/metrics/trial_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""
Tools for computing trial by trial metrics
df_trials = compute_all_trial_metrics(nwb)
"""

import numpy as np

# TODO, we might want to make these parameters metric specific
WIN_DUR = 15
MIN_EVENTS = 2


def compute_all_trial_metrics(nwb):
"""
Computes all trial by trial metrics
response_rate, fraction of trials with a response
gocue_reward_rate, fraction of trials with a reward
response_reward_rate, fraction of trials with a reward,
computed only on trials with a response
choose_right_rate, fraction of trials where chose right,
computed only on trials with a response
"""
if not hasattr(nwb, "df_trials"):
print("You need to compute df_trials: nwb_utils.create_trials_df(nwb)")
return

df = nwb.df_trials.copy()

df["RESPONDED"] = [x in [0, 1] for x in df["animal_response"].values]
# Rolling fraction of goCues with a response
df["response_rate"] = (
df["RESPONDED"].rolling(WIN_DUR, min_periods=MIN_EVENTS, center=True).mean()
)

# Rolling fraction of goCues with a response
df["gocue_reward_rate"] = (
df["earned_reward"].rolling(WIN_DUR, min_periods=MIN_EVENTS, center=True).mean()
)

# Rolling fraction of responses with a response
df["RESPONSE_REWARD"] = [
x[0] if x[1] else np.nan for x in zip(df["earned_reward"], df["RESPONDED"])
]
df["response_reward_rate"] = (
df["RESPONSE_REWARD"].rolling(WIN_DUR, min_periods=MIN_EVENTS, center=True).mean()
)

# Rolling fraction of choosing right
df["WENT_RIGHT"] = [x if x in [0, 1] else np.nan for x in df["animal_response"]]
df["choose_right_rate"] = (
df["WENT_RIGHT"].rolling(WIN_DUR, min_periods=MIN_EVENTS, center=True).mean()
)

# Clean up temp columns
drop_cols = ["RESPONDED", "RESPONSE_REWARD", "WENT_RIGHT"]
df = df.drop(columns=drop_cols)

return df
112 changes: 82 additions & 30 deletions src/aind_dynamic_foraging_basic_analysis/plot/plot_session_scroller.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@


def plot_session_scroller( # noqa: C901 pragma: no cover
nwb, ax=None, fig=None, plot_bouts=False, processing="bright"
nwb,
ax=None,
fig=None,
plot_bouts=False,
processing="bright",
metrics=["pR", "pL", "response_rate"],
):
"""
Creates an interactive plot of the session.
Expand Down Expand Up @@ -56,10 +61,16 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
df_licks = nwb.df_licks
else:
df_licks = None
if not hasattr(nwb, "df_trials"):
print("computing df_trials")
nwb.df_trials = nu.create_df_trials(nwb)
df_trials = nwb.df_trials
else:
df_trials = nwb.df_trials

if ax is None:
if fip_df is None:
fig, ax = plt.subplots(figsize=(15, 3))
fig, ax = plt.subplots(figsize=(15, 4))
else:
fig, ax = plt.subplots(figsize=(15, 8))

Expand All @@ -78,30 +89,32 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
"right_reward_top": 0.75,
"go_cue_bottom": 0,
"go_cue_top": 1,
"G_1_dff-bright_bottom": 1,
"G_1_dff-bright_top": 2,
"G_2_dff-bright_bottom": 2,
"G_2_dff-bright_top": 3,
"R_1_dff-bright_bottom": 3,
"R_1_dff-bright_top": 4,
"R_2_dff-bright_bottom": 4,
"R_2_dff-bright_top": 5,
"G_1_dff-poly_bottom": 1,
"G_1_dff-poly_top": 2,
"G_2_dff-poly_bottom": 2,
"G_2_dff-poly_top": 3,
"R_1_dff-poly_bottom": 3,
"R_1_dff-poly_top": 4,
"R_2_dff-poly_bottom": 4,
"R_2_dff-poly_top": 5,
"G_1_dff-exp_bottom": 1,
"G_1_dff-exp_top": 2,
"G_2_dff-exp_bottom": 2,
"G_2_dff-exp_top": 3,
"R_1_dff-exp_bottom": 3,
"R_1_dff-exp_top": 4,
"R_2_dff-exp_bottom": 4,
"R_2_dff-exp_top": 5,
"metrics_bottom": 1,
"metrics_top": 2,
"G_1_dff-bright_bottom": 2,
"G_1_dff-bright_top": 3,
"G_2_dff-bright_bottom": 3,
"G_2_dff-bright_top": 4,
"R_1_dff-bright_bottom": 4,
"R_1_dff-bright_top": 5,
"R_2_dff-bright_bottom": 5,
"R_2_dff-bright_top": 6,
"G_1_dff-poly_bottom": 2,
"G_1_dff-poly_top": 3,
"G_2_dff-poly_bottom": 3,
"G_2_dff-poly_top": 4,
"R_1_dff-poly_bottom": 4,
"R_1_dff-poly_top": 5,
"R_2_dff-poly_bottom": 5,
"R_2_dff-poly_top": 6,
"G_1_dff-exp_bottom": 2,
"G_1_dff-exp_top": 3,
"G_2_dff-exp_bottom": 3,
"G_2_dff-exp_top": 4,
"R_1_dff-exp_bottom": 4,
"R_1_dff-exp_top": 5,
"R_2_dff-exp_bottom": 5,
"R_2_dff-exp_top": 6,
}
yticks = [
(params["left_lick_top"] - params["left_lick_bottom"]) / 2 + params["left_lick_bottom"],
Expand All @@ -110,9 +123,22 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
+ params["left_reward_bottom"],
(params["right_reward_top"] - params["right_reward_bottom"]) / 2
+ params["right_reward_bottom"],
(params["metrics_top"] - params["metrics_bottom"]) * 0.25 + params["metrics_bottom"],
(params["metrics_top"] - params["metrics_bottom"]) * 0.50 + params["metrics_bottom"],
(params["metrics_top"] - params["metrics_bottom"]) * 0.75 + params["metrics_bottom"],
params["metrics_top"],
]
ylabels = [
"left licks",
"right licks",
"left reward",
"right reward",
"0.25",
"0.50",
"0.75",
"metrics",
]
ylabels = ["left licks", "right licks", "left reward", "right reward"]
ycolors = ["k", "k", "r", "r"]
ycolors = ["k", "k", "r", "r", "darkgray", "darkgray", "darkgray", "k"]

if fip_df is not None:
fip_channels = [
Expand Down Expand Up @@ -260,6 +286,12 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
"bD",
)

# Plot baiting
bait_right = df_trials.query("bait_right")["goCue_start_time_in_session"].values
bait_left = df_trials.query("bait_left")["goCue_start_time_in_session"].values
ax.plot(bait_right, [params["go_cue_top"] - 0.05] * len(bait_right), "ms", label="baited")
ax.plot(bait_left, [params["go_cue_bottom"] + 0.05] * len(bait_left), "ms")

left_reward_deliverys = df_events.query('event == "left_reward_delivery_time"')
left_times = left_reward_deliverys.timestamps.values
ax.vlines(
Expand Down Expand Up @@ -295,6 +327,26 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
label="go cue",
)

# plot metrics
ax.axhline(params["metrics_bottom"], color="k", linewidth=0.5, alpha=0.25)
go_cue_times_doubled = np.repeat(go_cue_times, 2)[1:]
if "pR" in metrics:
pR = params["metrics_bottom"] + df_trials["reward_probabilityR"]
pR = np.repeat(pR, 2)[:-1]
ax.plot(go_cue_times_doubled, pR, color="r", label="pR")
if "pL" in metrics:
pL = params["metrics_bottom"] + df_trials["reward_probabilityL"]
pL = np.repeat(pL, 2)[:-1]
ax.plot(go_cue_times_doubled, pL, color="b", label="pL")

# plot metrics if they are available
for metric in metrics:
if metric in df_trials:
values = df_trials[metric] + params["metrics_bottom"]
ax.plot(go_cue_times, values, label=metric)
elif metric not in ["pL", "pR"]:
print('Metric "{}" not available in df_trials'.format(metric))

# Clean up plot
ax.legend(framealpha=1, loc="lower left", reverse=True)
ax.set_yticks(yticks)
Expand All @@ -304,9 +356,9 @@ def plot_session_scroller( # noqa: C901 pragma: no cover
tick.set_color(color)
ax.set_xlabel("time (s)", fontsize=STYLE["axis_fontsize"])
if fip_df is None:
ax.set_ylim(0, 1)
ax.set_ylim(0, 2)
else:
ax.set_ylim(0, 5)
ax.set_ylim(0, 6)
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
if fip_df is not None:
Expand Down

0 comments on commit dff9747

Please sign in to comment.