diff --git a/src/aind_dynamic_foraging_basic_analysis/metrics/trial_metrics.py b/src/aind_dynamic_foraging_basic_analysis/metrics/trial_metrics.py new file mode 100644 index 0000000..a19eb3d --- /dev/null +++ b/src/aind_dynamic_foraging_basic_analysis/metrics/trial_metrics.py @@ -0,0 +1,61 @@ +""" + Tools for computing trial by trial metrics + df_trials = compute_all_trial_metrics(nwb) + +""" + +import numpy as np + +# TODO, we might want to make these parameters metric specific +WIN_DUR = 15 +MIN_EVENTS = 2 + + +def compute_all_trial_metrics(nwb): + """ + Computes all trial by trial metrics + + response_rate, fraction of trials with a response + gocue_reward_rate, fraction of trials with a reward + response_reward_rate, fraction of trials with a reward, + computed only on trials with a response + choose_right_rate, fraction of trials where chose right, + computed only on trials with a response + + """ + if not hasattr(nwb, "df_trials"): + print("You need to compute df_trials: nwb_utils.create_trials_df(nwb)") + return + + df = nwb.df_trials.copy() + + df["RESPONDED"] = [x in [0, 1] for x in df["animal_response"].values] + # Rolling fraction of goCues with a response + df["response_rate"] = ( + df["RESPONDED"].rolling(WIN_DUR, min_periods=MIN_EVENTS, center=True).mean() + ) + + # Rolling fraction of goCues with a response + df["gocue_reward_rate"] = ( + df["earned_reward"].rolling(WIN_DUR, min_periods=MIN_EVENTS, center=True).mean() + ) + + # Rolling fraction of responses with a response + df["RESPONSE_REWARD"] = [ + x[0] if x[1] else np.nan for x in zip(df["earned_reward"], df["RESPONDED"]) + ] + df["response_reward_rate"] = ( + df["RESPONSE_REWARD"].rolling(WIN_DUR, min_periods=MIN_EVENTS, center=True).mean() + ) + + # Rolling fraction of choosing right + df["WENT_RIGHT"] = [x if x in [0, 1] else np.nan for x in df["animal_response"]] + df["choose_right_rate"] = ( + df["WENT_RIGHT"].rolling(WIN_DUR, min_periods=MIN_EVENTS, center=True).mean() + ) + + # Clean up temp columns + drop_cols = ["RESPONDED", "RESPONSE_REWARD", "WENT_RIGHT"] + df = df.drop(columns=drop_cols) + + return df diff --git a/src/aind_dynamic_foraging_basic_analysis/plot/plot_session_scroller.py b/src/aind_dynamic_foraging_basic_analysis/plot/plot_session_scroller.py index 9ad354b..7c4b0ec 100644 --- a/src/aind_dynamic_foraging_basic_analysis/plot/plot_session_scroller.py +++ b/src/aind_dynamic_foraging_basic_analysis/plot/plot_session_scroller.py @@ -15,7 +15,12 @@ def plot_session_scroller( # noqa: C901 pragma: no cover - nwb, ax=None, fig=None, plot_bouts=False, processing="bright" + nwb, + ax=None, + fig=None, + plot_bouts=False, + processing="bright", + metrics=["pR", "pL", "response_rate"], ): """ Creates an interactive plot of the session. @@ -56,10 +61,16 @@ def plot_session_scroller( # noqa: C901 pragma: no cover df_licks = nwb.df_licks else: df_licks = None + if not hasattr(nwb, "df_trials"): + print("computing df_trials") + nwb.df_trials = nu.create_df_trials(nwb) + df_trials = nwb.df_trials + else: + df_trials = nwb.df_trials if ax is None: if fip_df is None: - fig, ax = plt.subplots(figsize=(15, 3)) + fig, ax = plt.subplots(figsize=(15, 4)) else: fig, ax = plt.subplots(figsize=(15, 8)) @@ -78,30 +89,32 @@ def plot_session_scroller( # noqa: C901 pragma: no cover "right_reward_top": 0.75, "go_cue_bottom": 0, "go_cue_top": 1, - "G_1_dff-bright_bottom": 1, - "G_1_dff-bright_top": 2, - "G_2_dff-bright_bottom": 2, - "G_2_dff-bright_top": 3, - "R_1_dff-bright_bottom": 3, - "R_1_dff-bright_top": 4, - "R_2_dff-bright_bottom": 4, - "R_2_dff-bright_top": 5, - "G_1_dff-poly_bottom": 1, - "G_1_dff-poly_top": 2, - "G_2_dff-poly_bottom": 2, - "G_2_dff-poly_top": 3, - "R_1_dff-poly_bottom": 3, - "R_1_dff-poly_top": 4, - "R_2_dff-poly_bottom": 4, - "R_2_dff-poly_top": 5, - "G_1_dff-exp_bottom": 1, - "G_1_dff-exp_top": 2, - "G_2_dff-exp_bottom": 2, - "G_2_dff-exp_top": 3, - "R_1_dff-exp_bottom": 3, - "R_1_dff-exp_top": 4, - "R_2_dff-exp_bottom": 4, - "R_2_dff-exp_top": 5, + "metrics_bottom": 1, + "metrics_top": 2, + "G_1_dff-bright_bottom": 2, + "G_1_dff-bright_top": 3, + "G_2_dff-bright_bottom": 3, + "G_2_dff-bright_top": 4, + "R_1_dff-bright_bottom": 4, + "R_1_dff-bright_top": 5, + "R_2_dff-bright_bottom": 5, + "R_2_dff-bright_top": 6, + "G_1_dff-poly_bottom": 2, + "G_1_dff-poly_top": 3, + "G_2_dff-poly_bottom": 3, + "G_2_dff-poly_top": 4, + "R_1_dff-poly_bottom": 4, + "R_1_dff-poly_top": 5, + "R_2_dff-poly_bottom": 5, + "R_2_dff-poly_top": 6, + "G_1_dff-exp_bottom": 2, + "G_1_dff-exp_top": 3, + "G_2_dff-exp_bottom": 3, + "G_2_dff-exp_top": 4, + "R_1_dff-exp_bottom": 4, + "R_1_dff-exp_top": 5, + "R_2_dff-exp_bottom": 5, + "R_2_dff-exp_top": 6, } yticks = [ (params["left_lick_top"] - params["left_lick_bottom"]) / 2 + params["left_lick_bottom"], @@ -110,9 +123,22 @@ def plot_session_scroller( # noqa: C901 pragma: no cover + params["left_reward_bottom"], (params["right_reward_top"] - params["right_reward_bottom"]) / 2 + params["right_reward_bottom"], + (params["metrics_top"] - params["metrics_bottom"]) * 0.25 + params["metrics_bottom"], + (params["metrics_top"] - params["metrics_bottom"]) * 0.50 + params["metrics_bottom"], + (params["metrics_top"] - params["metrics_bottom"]) * 0.75 + params["metrics_bottom"], + params["metrics_top"], + ] + ylabels = [ + "left licks", + "right licks", + "left reward", + "right reward", + "0.25", + "0.50", + "0.75", + "metrics", ] - ylabels = ["left licks", "right licks", "left reward", "right reward"] - ycolors = ["k", "k", "r", "r"] + ycolors = ["k", "k", "r", "r", "darkgray", "darkgray", "darkgray", "k"] if fip_df is not None: fip_channels = [ @@ -260,6 +286,12 @@ def plot_session_scroller( # noqa: C901 pragma: no cover "bD", ) + # Plot baiting + bait_right = df_trials.query("bait_right")["goCue_start_time_in_session"].values + bait_left = df_trials.query("bait_left")["goCue_start_time_in_session"].values + ax.plot(bait_right, [params["go_cue_top"] - 0.05] * len(bait_right), "ms", label="baited") + ax.plot(bait_left, [params["go_cue_bottom"] + 0.05] * len(bait_left), "ms") + left_reward_deliverys = df_events.query('event == "left_reward_delivery_time"') left_times = left_reward_deliverys.timestamps.values ax.vlines( @@ -295,6 +327,26 @@ def plot_session_scroller( # noqa: C901 pragma: no cover label="go cue", ) + # plot metrics + ax.axhline(params["metrics_bottom"], color="k", linewidth=0.5, alpha=0.25) + go_cue_times_doubled = np.repeat(go_cue_times, 2)[1:] + if "pR" in metrics: + pR = params["metrics_bottom"] + df_trials["reward_probabilityR"] + pR = np.repeat(pR, 2)[:-1] + ax.plot(go_cue_times_doubled, pR, color="r", label="pR") + if "pL" in metrics: + pL = params["metrics_bottom"] + df_trials["reward_probabilityL"] + pL = np.repeat(pL, 2)[:-1] + ax.plot(go_cue_times_doubled, pL, color="b", label="pL") + + # plot metrics if they are available + for metric in metrics: + if metric in df_trials: + values = df_trials[metric] + params["metrics_bottom"] + ax.plot(go_cue_times, values, label=metric) + elif metric not in ["pL", "pR"]: + print('Metric "{}" not available in df_trials'.format(metric)) + # Clean up plot ax.legend(framealpha=1, loc="lower left", reverse=True) ax.set_yticks(yticks) @@ -304,9 +356,9 @@ def plot_session_scroller( # noqa: C901 pragma: no cover tick.set_color(color) ax.set_xlabel("time (s)", fontsize=STYLE["axis_fontsize"]) if fip_df is None: - ax.set_ylim(0, 1) + ax.set_ylim(0, 2) else: - ax.set_ylim(0, 5) + ax.set_ylim(0, 6) ax.spines["top"].set_visible(False) ax.spines["right"].set_visible(False) if fip_df is not None: