diff --git a/README.md b/README.md index c756f1c..bb30906 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,85 @@ To develop the code, run pip install -e .[dev] ``` +## Usage +### Annotate licks +To create a dataframe of licks that has been annotated with licking bout starts/stops, cue responsive licks, reward triggered licks, and intertrial choices. +``` +import aind_dynamic_foraging_basic_analysis.licks.annotation as annotation +df_licks = annotation.annotate_licks(nwb) +``` + +You can then plot interlick interval analyses with: +``` +import aind_dynamic_foraging_basic_analysis.licks.plot_interlick_interval as pii + +#Plot interlick interval of all licks +pii.plot_interlick_interval(df_licks) + +#plot interlick interval for left and right licks separately +pii.plot_interlick_interval(df_licks, categories='event') +``` + +### Create lick analysis report +To create a figure with several licking pattern analyses: + +``` +import aind_dynamic_foraging_basic_analysis.licks.lick_analysis as lick_analysis +lick_analysis.plot_lick_analysis(nwb) +``` + +### Compute trial by trial metrics +To annotate the trials dataframe with trial by trial metrics: + +``` +import aind_dynamic_foraging_basic_analysis.metrics.trial_metrics as tm +df_trials = tm.compute_all_trial_metrics(nwb) +``` + +### Plot interactive session scroller +``` +import aind_dynamic_foraging_basic_analysis.plot.plot_session_scroller as pss +pss.plot_session_scroller(nwb) +``` + +To disable lick bout and other annotations: +``` +pss.plot_session_scroller(nwb,plot_bouts=False) +``` + +This function will automatically plot FIP data if available. To change the processing method plotted use: +``` +pss.plot_session_scroller(nwb, processing="bright") +``` + +To change which trial by trial metrics plotted: +``` +pss.plot_session_scroller(nwb, metrics=['response_rate']) +``` + +### Plot FIP PSTH +You can use the `plot_fip` module to compute and plot PSTHs for the FIP data. + +To compare one channel to multiple event types +``` +from aind_dynamic_foraging_basic_analysis.plot import plot_fip as pf +channel = 'G_1_dff-poly' +rewarded_go_cues = nwb.df_trials.query('earned_reward == 1')['goCue_start_time_in_session'].values +unrewarded_go_cues = nwb.df_trials.query('earned_reward == 0')['goCue_start_time_in_session'].values +pf.plot_fip_psth_compare_alignments( + nwb, + {'rewarded goCue':rewarded_go_cues,'unrewarded goCue':unrewarded_go_cues}, + channel, + censor=True + ) +``` + +To compare multiple channels to the same event type: +``` +pf.plot_fip_psth(nwb, 'goCue_start_time') +``` + + ## Contributing ### Linters and testing diff --git a/src/aind_dynamic_foraging_basic_analysis/__init__.py b/src/aind_dynamic_foraging_basic_analysis/__init__.py index f7dcc5f..74f4fca 100644 --- a/src/aind_dynamic_foraging_basic_analysis/__init__.py +++ b/src/aind_dynamic_foraging_basic_analysis/__init__.py @@ -1,6 +1,6 @@ """Init package""" -__version__ = "0.3.5" +__version__ = "0.3.7" from .foraging_efficiency import compute_foraging_efficiency # noqa: F401 from .plot.plot_foraging_session import plot_foraging_session # noqa: F401 diff --git a/src/aind_dynamic_foraging_basic_analysis/metrics/trial_metrics.py b/src/aind_dynamic_foraging_basic_analysis/metrics/trial_metrics.py index 2478640..ee29713 100644 --- a/src/aind_dynamic_foraging_basic_analysis/metrics/trial_metrics.py +++ b/src/aind_dynamic_foraging_basic_analysis/metrics/trial_metrics.py @@ -1,21 +1,24 @@ """ Tools for computing trial by trial metrics - df_trials = compute_all_trial_metrics(nwb) + df_trials = compute_trial_metrics(nwb) + df_trials = compute_bias(nwb) """ +import pandas as pd import numpy as np +import aind_dynamic_foraging_models.logistic_regression.model as model + # TODO, we might want to make these parameters metric specific WIN_DUR = 15 MIN_EVENTS = 2 LEFT, RIGHT, IGNORE = 0, 1, 2 - def compute_trial_metrics(nwb): """ - Computes all trial by trial metrics + Computes trial by trial metrics response_rate, fraction of trials with a response gocue_reward_rate, fraction of trials with a reward @@ -38,6 +41,7 @@ def compute_trial_metrics(nwb): df_trials = nwb.df_trials.copy() + # --- Add reward-related columns --- df_trials["reward"] = False df_trials.loc[ @@ -143,6 +147,7 @@ def compute_trial_metrics(nwb): df_trials.loc[i, "n_valid_licks_right"] = 0 df_trials.loc[i, "n_valid_licks_all"] = 0 + df_trials["RESPONDED"] = [x in [0, 1] for x in df_trials["animal_response"].values] # Rolling fraction of goCues with a response df_trials["response_rate"] = ( @@ -168,6 +173,16 @@ def compute_trial_metrics(nwb): df_trials["WENT_RIGHT"].rolling(WIN_DUR, min_periods=MIN_EVENTS, center=True).mean() ) + # TODO, add from process_nwb + # trial duration (stop-time - start-time) (start/stop time, or gocue to gocue?) + # n_licks_left (# of left licks in response window) + # n_licks_left_total (# of left licks from goCue to next go cue) + # Same for Right, same for all + # intertrial choices (boolean) + # number of intertrial choices + # number of intertrial switches + # response switch or repeat + # Clean up temp columns drop_cols = [ "RESPONDED", @@ -176,4 +191,105 @@ def compute_trial_metrics(nwb): ] df_trials = df_trials.drop(columns=drop_cols) + + return df_trials + + +def compute_bias(nwb): + """ + Computes side bias by fitting a logistic regression model + returns trials table with the following columns: + bias, the side bias + bias_ci_lower, the lower confidence interval on the bias + bias_ci_upper, the uppwer confidence interval on the bias + """ + + # Parameters for computing bias + n_trials_back = 5 + max_window = 200 + cv = 1 + compute_every = 10 + BIAS_LIMIT = 10 + + # Make sure trials table has been computed + if not hasattr(nwb, "df_trials"): + print("You need to compute df_trials: nwb_utils.create_trials_df(nwb)") + return + + # extract choice and reward + df_trials = nwb.df_trials.copy() + df_trials["choice"] = [np.nan if x == 2 else x for x in df_trials["animal_response"]] + df_trials["reward"] = [ + any(x) for x in zip(df_trials["earned_reward"], df_trials["extra_reward"]) + ] + + # Set up lists to store results + bias = [] + ci_lower = [] + ci_upper = [] + C = [] + + # Iterate over trials and compute + compute_on = np.arange(compute_every, len(df_trials), compute_every) + for i in compute_on: + # Determine interval to compute on + start = np.max([0, i - max_window]) + end = i + + # extract choice and reward + choice = df_trials.loc[start:end]["choice"].values + reward = df_trials.loc[start:end]["reward"].values + + # Determine if we have valid data to fit model + unique = np.unique(choice[~np.isnan(choice)]) + if len(unique) == 0: + # no choices, report bias confidence as (-inf, +inf) + bias.append(np.nan) + ci_lower.append(-BIAS_LIMIT) + ci_upper.append(BIAS_LIMIT) + C.append(np.nan) + elif len(unique) == 2: + # Fit model + out = model.fit_logistic_regression( + choice, reward, n_trial_back=n_trials_back, cv=cv, fit_exponential=False + ) + bias.append(out["df_beta"].loc["bias"]["bootstrap_mean"].values[0]) + ci_lower.append(out["df_beta"].loc["bias"]["bootstrap_CI_lower"].values[0]) + ci_upper.append(out["df_beta"].loc["bias"]["bootstrap_CI_upper"].values[0]) + C.append(out["C"]) + elif unique[0] == 0: + # only left choices, report bias confidence as (-inf, 0) + bias.append(-1) + ci_lower.append(-BIAS_LIMIT) + ci_upper.append(0) + C.append(np.nan) + elif unique[0] == 1: + # only right choices, report bias confidence as (0, +inf) + bias.append(+1) + ci_lower.append(0) + ci_upper.append(BIAS_LIMIT) + C.append(np.nan) + + # Pack results into a dataframe + df = pd.DataFrame() + df["trial"] = compute_on + df["bias"] = bias + df["bias_ci_lower"] = ci_lower + df["bias_ci_upper"] = ci_upper + df["bias_C"] = C + + # merge onto trials dataframe + df_trials = pd.merge( + nwb.df_trials.drop(columns=["bias", "bias_ci_lower", "bias_ci_upper"], errors="ignore"), + df[["trial", "bias", "bias_ci_lower", "bias_ci_upper"]], + how="left", + on=["trial"], + ) + + # fill in bias on non-computed trials + df_trials["bias"] = df_trials["bias"].bfill().ffill() + df_trials["bias_ci_lower"] = df_trials["bias_ci_lower"].bfill().ffill() + df_trials["bias_ci_upper"] = df_trials["bias_ci_upper"].bfill().ffill() + + return df_trials diff --git a/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session.py b/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session.py index 35860ca..8e0bde3 100644 --- a/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session.py +++ b/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session.py @@ -22,6 +22,56 @@ def moving_average(a, n=3): return ret[(n - 1) :] / n # noqa: E203 +def plot_foraging_session_nwb(nwb, **kwargs): + """ + Wrapper function that extracts fields + """ + + if not hasattr(nwb, "df_trials"): + print("You need to compute df_trials: nwb_utils.create_trials_df(nwb)") + return + + if "bias" not in nwb.df_trials: + fig, axes = plot_foraging_session( + [np.nan if x == 2 else x for x in nwb.df_trials["animal_response"].values], + nwb.df_trials["earned_reward"].values, + [nwb.df_trials["reward_probabilityL"], nwb.df_trials["reward_probabilityR"]], + **kwargs, + ) + else: + if "plot_list" not in kwargs: + kwargs["plot_list"] = ["choice", "finished", "reward_prob", "bias"] + fig, axes = plot_foraging_session( + [np.nan if x == 2 else x for x in nwb.df_trials["animal_response"].values], + nwb.df_trials["earned_reward"].values, + [nwb.df_trials["reward_probabilityL"], nwb.df_trials["reward_probabilityR"]], + bias=nwb.df_trials["bias"].values, + bias_lower=nwb.df_trials["bias_ci_lower"].values, + bias_upper=nwb.df_trials["bias_ci_upper"].values, + autowater_offered=nwb.df_trials[["auto_waterL", "auto_waterR"]].any(axis=1), + **kwargs, + ) + + # Add some text info + # TODO, waiting for AIND metadata to get integrated before adding this info: + # {df_session.metadata.rig.iloc[0]}, {df_session.metadata.user_name.iloc[0]}\n' + # f'FORAGING finished {df_session.session_stats.finished_trials.iloc[0]} ' + # f'ignored {df_session.session_stats.ignored_trials.iloc[0]} + ' + # f'AUTOWATER collected {df_session.session_stats.autowater_collected.iloc[0]} ' + # f'ignored {df_session.session_stats.autowater_ignored.iloc[0]}\n' + # f'FORAGING finished rate {df_session.session_stats.finished_rate.iloc[0]:.2%}, ' + axes[0].text( + 0, + 1.05, + f"{nwb.session_id}\n" + f'Total trials {len(nwb.df_trials)}, ignored {np.sum(nwb.df_trials["animal_response"]==2)},' + f' left {np.sum(nwb.df_trials["animal_response"] == 0)},' + f' right {np.sum(nwb.df_trials["animal_response"] == 1)}', + fontsize=8, + transform=axes[0].transAxes, + ) + + def plot_foraging_session( # noqa: C901 choice_history: Union[List, np.ndarray], reward_history: Union[List, np.ndarray], @@ -34,6 +84,10 @@ def plot_foraging_session( # noqa: C901 base_color: str = "y", ax: plt.Axes = None, vertical: bool = False, + bias: Union[List, np.ndarray] = None, + bias_lower: Union[List, np.ndarray] = None, + bias_upper: Union[List, np.ndarray] = None, + plot_list: List = ["choice", "finished", "reward_prob"], ) -> Tuple[plt.Figure, List[plt.Axes]]: """Plot dynamic foraging session. @@ -124,30 +178,78 @@ def plot_foraging_session( # noqa: C901 # Rewarded trials (real foraging, autowater excluded) xx = np.nonzero(rewarded_excluding_autowater)[0] + 1 yy = 0.5 + (choice_history[rewarded_excluding_autowater] - 0.5) * 1.4 - ax_choice_reward.plot( - *(xx, yy) if not vertical else [*(yy, xx)], - "|" if not vertical else "_", - color="black", - markersize=10, - markeredgewidth=2, - label="Rewarded choices", - ) + yy_temp = choice_history[rewarded_excluding_autowater] + yy_right = yy_temp[yy_temp > 0.5] + 0.05 + xx_right = xx[yy_temp > 0.5] + yy_left = yy_temp[yy_temp < 0.5] - 0.05 + xx_left = xx[yy_temp < 0.5] + if not vertical: + ax_choice_reward.vlines( + xx_right, + yy_right, + yy_right + 0.1, + alpha=1, + linewidth=1, + color="black", + label="Rewarded choices", + ) + ax_choice_reward.vlines( + xx_left, + yy_left - 0.1, + yy_left, + alpha=1, + linewidth=1, + color="black", + ) + else: + ax_choice_reward.plot( + *(xx, yy) if not vertical else [*(yy, xx)], + "|" if not vertical else "_", + color="black", + markersize=10, + markeredgewidth=2, + label="Rewarded choices", + ) # Unrewarded trials (real foraging; not ignored or autowater trials) xx = np.nonzero(unrewarded_trials)[0] + 1 yy = 0.5 + (choice_history[unrewarded_trials] - 0.5) * 1.4 - ax_choice_reward.plot( - *(xx, yy) if not vertical else [*(yy, xx)], - "|" if not vertical else "_", - color="gray", - markersize=6, - markeredgewidth=1, - label="Unrewarded choices", - ) + yy_temp = choice_history[unrewarded_trials] + yy_right = yy_temp[yy_temp > 0.5] + xx_right = xx[yy_temp > 0.5] + yy_left = yy_temp[yy_temp < 0.5] + xx_left = xx[yy_temp < 0.5] + if not vertical: + ax_choice_reward.vlines( + xx_right, + yy_right + 0.05, + yy_right + 0.1, + alpha=1, + linewidth=1, + color="gray", + label="Unrewarded choices", + ) + ax_choice_reward.vlines( + xx_left, + yy_left - 0.1, + yy_left - 0.05, + alpha=1, + linewidth=1, + color="gray", + ) + else: + ax_choice_reward.plot( + *(xx, yy) if not vertical else [*(yy, xx)], + "|" if not vertical else "_", + color="gray", + markersize=6, + markeredgewidth=1, + label="Unrewarded choices", + ) # Ignored trials xx = np.nonzero(ignored & ~autowater_ignored)[0] + 1 - yy = [1.1] * sum(ignored & ~autowater_ignored) + yy = [1.2] * sum(ignored & ~autowater_ignored) ax_choice_reward.plot( *(xx, yy) if not vertical else [*(yy, xx)], "x", @@ -162,18 +264,44 @@ def plot_foraging_session( # noqa: C901 # Autowater offered and collected xx = np.nonzero(autowater_collected)[0] + 1 yy = 0.5 + (choice_history[autowater_collected] - 0.5) * 1.4 - ax_choice_reward.plot( - *(xx, yy) if not vertical else [*(yy, xx)], - "|" if not vertical else "_", - color="royalblue", - markersize=10, - markeredgewidth=2, - label="Autowater collected", - ) + + yy_temp = choice_history[autowater_collected] + yy_right = yy_temp[yy_temp > 0.5] + 0.05 + xx_right = xx[yy_temp > 0.5] + yy_left = yy_temp[yy_temp < 0.5] - 0.05 + xx_left = xx[yy_temp < 0.5] + + if not vertical: + ax_choice_reward.vlines( + xx_right, + yy_right, + yy_right + 0.1, + alpha=1, + linewidth=1, + color="royalblue", + label="Autowater collected", + ) + ax_choice_reward.vlines( + xx_left, + yy_left - 0.1, + yy_left, + alpha=1, + linewidth=1, + color="royalblue", + ) + else: + ax_choice_reward.plot( + *(xx, yy) if not vertical else [*(yy, xx)], + "|" if not vertical else "_", + color="royalblue", + markersize=10, + markeredgewidth=2, + label="Autowater collected", + ) # Also highlight the autowater offered but still ignored xx = np.nonzero(autowater_ignored)[0] + 1 - yy = [1.1] * sum(autowater_ignored) + yy = [1.2] * sum(autowater_ignored) ax_choice_reward.plot( *(xx, yy) if not vertical else [*(yy, xx)], "x", @@ -186,12 +314,13 @@ def plot_foraging_session( # noqa: C901 # Base probability xx = np.arange(0, n_trials) + 1 yy = p_reward_fraction - ax_choice_reward.plot( - *(xx, yy) if not vertical else [*(yy, xx)], - color=base_color, - label="Base rew. prob.", - lw=1.5, - ) + if "reward_prob" in plot_list: + ax_choice_reward.plot( + *(xx, yy) if not vertical else [*(yy, xx)], + color=base_color, + label="Base rew. prob.", + lw=1.5, + ) # Smoothed choice history y = moving_average(choice_history, smooth_factor) / ( @@ -199,24 +328,37 @@ def plot_foraging_session( # noqa: C901 ) y[y > 100] = np.nan x = np.arange(0, len(y)) + int(smooth_factor / 2) + 1 - ax_choice_reward.plot( - *(x, y) if not vertical else [*(y, x)], - linewidth=1.5, - color="black", - label="Choice (smooth = %g)" % smooth_factor, - ) + if "choice" in plot_list: + ax_choice_reward.plot( + *(x, y) if not vertical else [*(y, x)], + linewidth=1.5, + color="black", + label="Choice (smooth = %g)" % smooth_factor, + ) # finished ratio if np.sum(np.isnan(choice_history)): x = np.arange(0, len(y)) + int(smooth_factor / 2) + 1 y = moving_average(~np.isnan(choice_history), smooth_factor) - ax_choice_reward.plot( - *(x, y) if not vertical else [*(y, x)], - linewidth=0.8, - color="m", - alpha=1, - label="Finished (smooth = %g)" % smooth_factor, - ) + if "finished" in plot_list: + ax_choice_reward.plot( + *(x, y) if not vertical else [*(y, x)], + linewidth=0.8, + color="m", + alpha=1, + label="Finished (smooth = %g)" % smooth_factor, + ) + + # Bias + if ("bias" in plot_list) and (bias is not None): + bias = (np.array(bias) + 1) / (2) + bias_lower = (np.array(bias_lower) + 1) / (2) + bias_upper = (np.array(bias_upper) + 1) / (2) + bias_lower[bias_lower < 0] = 0 + bias_upper[bias_upper > 1] = 1 + ax_choice_reward.plot(xx, bias, color="g", lw=1.5, label="bias") + ax_choice_reward.fill_between(xx, bias_lower, bias_upper, color="g", alpha=0.25) + ax_choice_reward.plot(xx, [0.5] * len(xx), color="g", linestyle="--", alpha=0.2, lw=1) # add valid ranage if valid_range is not None: @@ -267,18 +409,17 @@ def plot_foraging_session( # noqa: C901 ax_reward_schedule.legend(fontsize=5, ncol=1, loc="upper left", bbox_to_anchor=(0, 1)) if not vertical: - ax_choice_reward.set_yticks([0, 1]) - ax_choice_reward.set_yticklabels(["Left", "Right"]) - ax_choice_reward.legend(fontsize=6, loc="upper left", bbox_to_anchor=(0.6, 1.3), ncol=3) + ax_choice_reward.set_yticks([0, 1, 1.2]) + ax_choice_reward.set_yticklabels(["Left", "Right", "Ignored"]) + ax_choice_reward.legend(fontsize=6, loc="upper left", bbox_to_anchor=(0.4, 1.3), ncol=3) - # sns.despine(trim=True, bottom=True, ax=ax_1) ax_choice_reward.spines["top"].set_visible(False) ax_choice_reward.spines["right"].set_visible(False) ax_choice_reward.spines["bottom"].set_visible(False) ax_choice_reward.tick_params(labelbottom=False) ax_choice_reward.xaxis.set_ticks_position("none") + ax_choice_reward.set_ylim([-0.15, 1.25]) - # sns.despine(trim=True, ax=ax_2) ax_reward_schedule.set_ylim([0, 1]) ax_reward_schedule.spines["top"].set_visible(False) ax_reward_schedule.spines["right"].set_visible(False) @@ -305,5 +446,6 @@ def plot_foraging_session( # noqa: C901 ax_reward_schedule.set(ylabel="Trial number") ax.remove() + plt.tight_layout() return ax_choice_reward.get_figure(), [ax_choice_reward, ax_reward_schedule] diff --git a/src/aind_dynamic_foraging_basic_analysis/plot/plot_session_scroller.py b/src/aind_dynamic_foraging_basic_analysis/plot/plot_session_scroller.py index 7c4b0ec..1491dbd 100644 --- a/src/aind_dynamic_foraging_basic_analysis/plot/plot_session_scroller.py +++ b/src/aind_dynamic_foraging_basic_analysis/plot/plot_session_scroller.py @@ -18,7 +18,7 @@ def plot_session_scroller( # noqa: C901 pragma: no cover nwb, ax=None, fig=None, - plot_bouts=False, + plot_bouts=True, processing="bright", metrics=["pR", "pL", "response_rate"], ): @@ -38,6 +38,11 @@ def plot_session_scroller( # noqa: C901 pragma: no cover plot_bouts (bool), if True, plot licks colored by segmented lick bouts + processing (str) processing method for FIP data to plot + + metrics (list of strings), list of metrics to plot. Must be either 'pR','pL' or + columns in nwb.df_trials + EXAMPLES: plot_foraging_session.plot_session_scroller(nwb) plot_foraging_session.plot_session_scroller(nwb, plot_bouts=True) @@ -333,11 +338,11 @@ def plot_session_scroller( # noqa: C901 pragma: no cover if "pR" in metrics: pR = params["metrics_bottom"] + df_trials["reward_probabilityR"] pR = np.repeat(pR, 2)[:-1] - ax.plot(go_cue_times_doubled, pR, color="r", label="pR") + ax.plot(go_cue_times_doubled, pR, color="b", label="pR") if "pL" in metrics: pL = params["metrics_bottom"] + df_trials["reward_probabilityL"] pL = np.repeat(pL, 2)[:-1] - ax.plot(go_cue_times_doubled, pL, color="b", label="pL") + ax.plot(go_cue_times_doubled, pL, color="r", label="pL") # plot metrics if they are available for metric in metrics: diff --git a/tests/data/test_plot_session.png b/tests/data/test_plot_session.png index 07082ca..eff8dac 100644 Binary files a/tests/data/test_plot_session.png and b/tests/data/test_plot_session.png differ diff --git a/tests/data/test_plot_session_vertical.png b/tests/data/test_plot_session_vertical.png index 2879d16..0d975d5 100644 Binary files a/tests/data/test_plot_session_vertical.png and b/tests/data/test_plot_session_vertical.png differ diff --git a/tests/test_plot_foraging_session.py b/tests/test_plot_foraging_session.py index 162a586..cf5f134 100644 --- a/tests/test_plot_foraging_session.py +++ b/tests/test_plot_foraging_session.py @@ -8,11 +8,21 @@ import unittest import numpy as np +import pandas as pd from aind_dynamic_foraging_basic_analysis import plot_foraging_session +import aind_dynamic_foraging_basic_analysis.plot.plot_foraging_session as pfs from tests.nwb_io import get_history_from_nwb +class EmptyNWB: + """ + Just an empty class for saving attributes to + """ + + pass + + class TestPlotSession(unittest.TestCase): """Test plot session""" @@ -29,6 +39,36 @@ def setUpClass(cls): _, ) = get_history_from_nwb(nwb_file) + def test_nwb_wrapper(self): + """ + Test wrapper function that plots foraging session from nwb file + """ + # Test we have df_trials + nwb = EmptyNWB() + pfs.plot_foraging_session_nwb(nwb) + + # Test without bias column + choices = np.array([0, 0, 1, 1, 2, 2]) + rewards = np.array([True, False, True, False, False, False]) + pL = [0.1] * 6 + pR = [0.8] * 6 + df = pd.DataFrame() + df["animal_response"] = choices + df["earned_reward"] = rewards + df["reward_probabilityL"] = pL + df["reward_probabilityR"] = pR + df["auto_waterL"] = [0] * 6 + df["auto_waterR"] = [0] * 6 + nwb.df_trials = df + nwb.session_id = "test" + pfs.plot_foraging_session_nwb(nwb) + + # Test with bias column + nwb.df_trials["bias"] = np.array([0, 0, 0.1, 0.1, 0.05, 0.05]) + nwb.df_trials["bias_ci_lower"] = np.array([0] * 6) + nwb.df_trials["bias_ci_upper"] = np.array([0.2] * 6) + pfs.plot_foraging_session_nwb(nwb) + def test_plot_session(self): """Test plot real session""" # Add some fake data for testing @@ -101,6 +141,20 @@ def test_plot_session_vertical(self): bbox_inches="tight", ) + fig, _ = plot_foraging_session( + choice_history=self.choice_history, + reward_history=self.reward_history, + p_reward=self.p_reward, + autowater_offered=np.array([0] * len(self.choice_history)), + fitted_data=None, + photostim=None, # trial, power, s_type + valid_range=None, + smooth_factor=5, + base_color="y", + ax=None, + vertical=True, + ) + def test_plot_session_wrong_format(self): """Some wrong input format""" with self.assertRaises(ValueError):