additional trial by trial metrics - bias and plotting updates #29

Merged: 29 commits merged on Nov 21, 2024
Changes from 21 commits
147 changes: 143 additions & 4 deletions src/aind_dynamic_foraging_basic_analysis/metrics/trial_metrics.py
@@ -1,19 +1,23 @@
"""
Tools for computing trial by trial metrics
df_trials = compute_all_trial_metrics(nwb)
df_trials = compute_trial_metrics(nwb)
df_trials = compute_bias(nwb)

"""

import pandas as pd
import numpy as np

import aind_dynamic_foraging_models.logistic_regression.model as model

# TODO, we might want to make these parameters metric specific
WIN_DUR = 15
MIN_EVENTS = 2


def compute_all_trial_metrics(nwb):
def compute_trial_metrics(nwb):
"""
Computes all trial by trial metrics
Computes trial by trial metrics

response_rate, fraction of trials with a response
gocue_reward_rate, fraction of trials with a reward
@@ -54,8 +58,143 @@ def compute_all_trial_metrics(nwb):
df["WENT_RIGHT"].rolling(WIN_DUR, min_periods=MIN_EVENTS, center=True).mean()
)

# Rolling reward probability for best option
df["IDEAL_OBSERVER_REWARD_PROB"] = df[["reward_probabilityR", "reward_probabilityL"]].max(
axis=1
)
df["ideal_observer_reward_rate"] = (
df["IDEAL_OBSERVER_REWARD_PROB"]
.rolling(WIN_DUR, min_periods=MIN_EVENTS, center=True)
.mean()
)

# Rolling reward probability for best option with baiting
if "bait_left" in df:
df["IDEAL_OBSERVER_REWARD_PROB_WITH_BAITING"] = [
1 if (x[0] or x[1]) else x[2]
for x in zip(df["bait_left"], df["bait_right"], df["IDEAL_OBSERVER_REWARD_PROB"])
]
df["ideal_observer_reward_rate_with_baiting"] = (
df["IDEAL_OBSERVER_REWARD_PROB_WITH_BAITING"]
.rolling(WIN_DUR, min_periods=MIN_EVENTS, center=True)
.mean()
)

# TODO, add from process_nwb
# trial duration (stop-time - start-time) (start/stop time, or gocue to gocue?)
# n_licks_left (# of left licks in response window)
# n_licks_left_total (# of left licks from goCue to next go cue)
# Same for Right, same for all
# intertrial choices (boolean)
# number of intertrial choices
# number of intertrial switches
# response switch or repeat

# Clean up temp columns
drop_cols = ["RESPONDED", "RESPONSE_REWARD", "WENT_RIGHT"]
drop_cols = [
"RESPONDED",
"RESPONSE_REWARD",
"WENT_RIGHT",
"IDEAL_OBSERVER_REWARD_PROB",
"IDEAL_OBSERVER_REWARD_PROB_WITH_BAITING",
]
df = df.drop(columns=drop_cols)

return df


def compute_bias(nwb):
"""
Computes side bias by fitting a logistic regression model
returns trials table with the following columns:
bias, the side bias
bias_ci_lower, the lower confidence interval on the bias
bias_ci_upper, the upper confidence interval on the bias
"""

# Parameters for computing bias
n_trials_back = 5
max_window = 200
cv = 1
compute_every = 10
BIAS_LIMIT = 10

# Make sure trials table has been computed
if not hasattr(nwb, "df_trials"):
print("You need to compute df_trials: nwb_utils.create_trials_df(nwb)")
return

# extract choice and reward
df = nwb.df_trials.copy()
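# animal_response coding: 0 = left, 1 = right, 2 = ignored (mapped to NaN so it is excluded from the fit)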
df["choice"] = [np.nan if x == 2 else x for x in df["animal_response"]]
df["reward"] = [any(x) for x in zip(df["earned_reward"], df["extra_reward"])]

# Set up lists to store results
bias = []
ci_lower = []
ci_upper = []
C = []

# Iterate over trials and compute
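# Refit every `compute_every` trials, using at most the `max_window` most recent trials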
compute_on = np.arange(compute_every, len(df), compute_every)
for i in compute_on:
# Determine interval to compute on
start = np.max([0, i - max_window])
end = i

# extract choice and reward
choice = df.loc[start:end]["choice"].values
reward = df.loc[start:end]["reward"].values

# Determine if we have valid data to fit model
unique = np.unique(choice[~np.isnan(choice)])
if len(unique) == 0:
# no choices; bias is undefined, report NaN with the CI capped at (-BIAS_LIMIT, +BIAS_LIMIT)
bias.append(np.nan)
ci_lower.append(-BIAS_LIMIT)
ci_upper.append(BIAS_LIMIT)
C.append(np.nan)
elif len(unique) == 2:
# Fit model
out = model.fit_logistic_regression(
choice, reward, n_trial_back=n_trials_back, cv=cv, fit_exponential=False
)
bias.append(out["df_beta"].loc["bias"]["bootstrap_mean"].values[0])
ci_lower.append(out["df_beta"].loc["bias"]["bootstrap_CI_lower"].values[0])
ci_upper.append(out["df_beta"].loc["bias"]["bootstrap_CI_upper"].values[0])
C.append(out["C"])
elif unique[0] == 0:
# only left choices; report full left bias with the CI capped at (-BIAS_LIMIT, 0)
bias.append(-1)
ci_lower.append(-BIAS_LIMIT)
ci_upper.append(0)
C.append(np.nan)
elif unique[0] == 1:
# only right choices; report full right bias with the CI capped at (0, +BIAS_LIMIT)
bias.append(+1)
ci_lower.append(0)
ci_upper.append(BIAS_LIMIT)
C.append(np.nan)

# Pack results into a dataframe
df = pd.DataFrame()
df["trial"] = compute_on
df["bias"] = bias
df["bias_ci_lower"] = ci_lower
df["bias_ci_upper"] = ci_upper
df["bias_C"] = C

# merge onto trials dataframe
df_trials = pd.merge(
nwb.df_trials.drop(columns=["bias", "bias_ci_lower", "bias_ci_upper"], errors="ignore"),
df[["trial", "bias", "bias_ci_lower", "bias_ci_upper"]],
how="left",
on=["trial"],
)

# fill in bias on non-computed trials
df_trials["bias"] = df_trials["bias"].bfill().ffill()
df_trials["bias_ci_lower"] = df_trials["bias_ci_lower"].bfill().ffill()
df_trials["bias_ci_upper"] = df_trials["bias_ci_upper"].bfill().ffill()

return df_trials
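For quick reference, a minimal usage sketch of the two entry points in this file. The import path is an assumption based on the src layout, and it presumes nwb.df_trials has already been built with nwb_utils.create_trials_df(nwb):

# Sketch only: `nwb` is a loaded NWB session object with nwb.df_trials attached.
from aind_dynamic_foraging_basic_analysis.metrics import trial_metrics

df_metrics = trial_metrics.compute_trial_metrics(nwb)  # rolling response/reward/ideal-observer rates
df_bias = trial_metrics.compute_bias(nwb)  # trials table with bias, bias_ci_lower, bias_ci_upper columns
# Note: compute_bias returns a new trials dataframe; it does not modify nwb.df_trials in place.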
107 changes: 87 additions & 20 deletions src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session.py
@@ -22,6 +22,54 @@ def moving_average(a, n=3):
return ret[(n - 1) :] / n # noqa: E203


def plot_foraging_session_nwb(nwb, **kwargs):
"""
Wrapper around plot_foraging_session() that extracts the required fields from nwb.df_trials
"""

if not hasattr(nwb, "df_trials"):
print("You need to compute df_trials: nwb_utils.create_trials_df(nwb)")
return

if "bias" not in nwb.df_trials:
fig, axes = plot_foraging_session(
[np.nan if x == 2 else x for x in nwb.df_trials["animal_response"].values],
nwb.df_trials["earned_reward"].values,
[nwb.df_trials["reward_probabilityL"], nwb.df_trials["reward_probabilityR"]],
**kwargs,
)
else:
fig, axes = plot_foraging_session(
[np.nan if x == 2 else x for x in nwb.df_trials["animal_response"].values],
nwb.df_trials["earned_reward"].values,
[nwb.df_trials["reward_probabilityL"], nwb.df_trials["reward_probabilityR"]],
bias=nwb.df_trials["bias"].values,
bias_lower=nwb.df_trials["bias_ci_lower"].values,
bias_upper=nwb.df_trials["bias_ci_upper"].values,
plot_list=["bias"],
**kwargs,
)

# Add some text info
# TODO, waiting for AIND metadata to get integrated before adding this info:
# {df_session.metadata.rig.iloc[0]}, {df_session.metadata.user_name.iloc[0]}\n'
# f'FORAGING finished {df_session.session_stats.finished_trials.iloc[0]} '
# f'ignored {df_session.session_stats.ignored_trials.iloc[0]} + '
# f'AUTOWATER collected {df_session.session_stats.autowater_collected.iloc[0]} '
# f'ignored {df_session.session_stats.autowater_ignored.iloc[0]}\n'
# f'FORAGING finished rate {df_session.session_stats.finished_rate.iloc[0]:.2%}, '
axes[0].text(
0,
1.05,
f"{nwb.session_id}\n"
f'Total trials {len(nwb.df_trials)}, ignored {np.sum(nwb.df_trials["animal_response"]==2)},'
f' left {np.sum(nwb.df_trials["animal_response"] == 0)},'
f' right {np.sum(nwb.df_trials["animal_response"] == 1)}',
fontsize=8,
transform=axes[0].transAxes,
)


def plot_foraging_session( # noqa: C901
choice_history: Union[List, np.ndarray],
reward_history: Union[List, np.ndarray],
@@ -34,6 +82,10 @@ def plot_foraging_session(  # noqa: C901
base_color: str = "y",
ax: plt.Axes = None,
vertical: bool = False,
bias: Union[List, np.ndarray] = None,
bias_lower: Union[List, np.ndarray] = None,
bias_upper: Union[List, np.ndarray] = None,
plot_list: List = ["choice", "finished", "reward_prob"],
) -> Tuple[plt.Figure, List[plt.Axes]]:
"""Plot dynamic foraging session.

@@ -186,37 +238,51 @@
# Base probability
xx = np.arange(0, n_trials) + 1
yy = p_reward_fraction
ax_choice_reward.plot(
*(xx, yy) if not vertical else [*(yy, xx)],
color=base_color,
label="Base rew. prob.",
lw=1.5,
)
if "reward_prob" in plot_list:
ax_choice_reward.plot(
*(xx, yy) if not vertical else [*(yy, xx)],
color=base_color,
label="Base rew. prob.",
lw=1.5,
)

# Smoothed choice history
y = moving_average(choice_history, smooth_factor) / (
moving_average(~np.isnan(choice_history), smooth_factor) + 1e-6
)
y[y > 100] = np.nan
x = np.arange(0, len(y)) + int(smooth_factor / 2) + 1
ax_choice_reward.plot(
*(x, y) if not vertical else [*(y, x)],
linewidth=1.5,
color="black",
label="Choice (smooth = %g)" % smooth_factor,
)
if "choice" in plot_list:
ax_choice_reward.plot(
*(x, y) if not vertical else [*(y, x)],
linewidth=1.5,
color="black",
label="Choice (smooth = %g)" % smooth_factor,
)

# finished ratio
if np.sum(np.isnan(choice_history)):
x = np.arange(0, len(y)) + int(smooth_factor / 2) + 1
y = moving_average(~np.isnan(choice_history), smooth_factor)
ax_choice_reward.plot(
*(x, y) if not vertical else [*(y, x)],
linewidth=0.8,
color="m",
alpha=1,
label="Finished (smooth = %g)" % smooth_factor,
)
if "finished" in plot_list:
ax_choice_reward.plot(
*(x, y) if not vertical else [*(y, x)],
linewidth=0.8,
color="m",
alpha=1,
label="Finished (smooth = %g)" % smooth_factor,
)

# Bias
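# bias = -1 maps to 0 (Left) and bias = +1 maps to 1 (Right) so the trace shares the choice axis;
# the confidence band is clipped to [0, 1]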
if ("bias" in plot_list) and (bias is not None):
bias = (np.array(bias) + 1) / (2)
bias_lower = (np.array(bias_lower) + 1) / (2)
bias_upper = (np.array(bias_upper) + 1) / (2)
bias_lower[bias_lower < 0] = 0
bias_upper[bias_upper > 1] = 1
ax_choice_reward.plot(xx, bias, color="g", lw=1.5, label="bias")
ax_choice_reward.fill_between(xx, bias_lower, bias_upper, color="g", alpha=0.25)
ax_choice_reward.plot(xx, [0.5] * len(xx), color="g", linestyle="--", alpha=0.2, lw=1)

# add valid range
if valid_range is not None:
@@ -269,7 +335,7 @@ def plot_foraging_session(  # noqa: C901
if not vertical:
ax_choice_reward.set_yticks([0, 1])
ax_choice_reward.set_yticklabels(["Left", "Right"])
ax_choice_reward.legend(fontsize=6, loc="upper left", bbox_to_anchor=(0.6, 1.3), ncol=3)
ax_choice_reward.legend(fontsize=6, loc="upper left", bbox_to_anchor=(0.4, 1.3), ncol=3)

# sns.despine(trim=True, bottom=True, ax=ax_1)
ax_choice_reward.spines["top"].set_visible(False)
@@ -305,5 +371,6 @@ def plot_foraging_session(  # noqa: C901
ax_reward_schedule.set(ylabel="Trial number")

ax.remove()
plt.tight_layout()

return ax_choice_reward.get_figure(), [ax_choice_reward, ax_reward_schedule]
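As a reviewer aid, a hedged sketch of how the bias overlay could be exercised end to end; the imports and the step of writing the bias columns back onto nwb.df_trials are assumptions, not part of this diff:

# Sketch only. plot_foraging_session_nwb detects the "bias" column and requests the
# bias overlay via plot_list=["bias"].
from aind_dynamic_foraging_basic_analysis.metrics import trial_metrics
from aind_dynamic_foraging_basic_analysis.plot import plot_foraging_session

nwb.df_trials = trial_metrics.compute_bias(nwb)  # assumed write-back so "bias" is available
plot_foraging_session.plot_foraging_session_nwb(nwb)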
@@ -38,6 +38,11 @@ def plot_session_scroller(  # noqa: C901 pragma: no cover

plot_bouts (bool), if True, plot licks colored by segmented lick bouts

processing (str), processing method for FIP data to plot

metrics (list of strings), list of metrics to plot. Each entry must be 'pR', 'pL', or a
column in nwb.df_trials

EXAMPLES:
plot_foraging_session.plot_session_scroller(nwb)
plot_foraging_session.plot_session_scroller(nwb, plot_bouts=True)
@@ -333,11 +338,11 @@ def plot_session_scroller(  # noqa: C901 pragma: no cover
if "pR" in metrics:
pR = params["metrics_bottom"] + df_trials["reward_probabilityR"]
pR = np.repeat(pR, 2)[:-1]
ax.plot(go_cue_times_doubled, pR, color="r", label="pR")
ax.plot(go_cue_times_doubled, pR, color="b", label="pR")
if "pL" in metrics:
pL = params["metrics_bottom"] + df_trials["reward_probabilityL"]
pL = np.repeat(pL, 2)[:-1]
ax.plot(go_cue_times_doubled, pL, color="b", label="pL")
ax.plot(go_cue_times_doubled, pL, color="r", label="pL")

# plot metrics if they are available
for metric in metrics:
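For the scroller, a hedged example of the new metrics argument; "response_rate" is illustrative and assumes that column has been merged into nwb.df_trials (e.g., from compute_trial_metrics):

# Sketch only: "pR" and "pL" are handled directly; any other entry must be a column of nwb.df_trials.
plot_foraging_session.plot_session_scroller(nwb, metrics=["pR", "pL", "response_rate"])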
Binary file modified tests/data/test_plot_session.png
Binary file modified tests/data/test_plot_session_vertical.png