Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change absolute time reference to first go cue #50

Merged
merged 17 commits into from
Nov 9, 2024
216 changes: 127 additions & 89 deletions src/aind_dynamic_foraging_data_utils/nwb_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import os
import re
import warnings

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -46,6 +47,7 @@ def unpack_metadata(nwb):
Unpacks metadata as a dictionary attribute, instead of a Dynamic
table nested inside a dictionary
"""
# TODO, this should be outdated once we fix the NWB files themselves
nwb.metadata = nwb.scratch["metadata"].to_dataframe().iloc[0].to_dict()


Expand Down Expand Up @@ -297,20 +299,15 @@ def create_single_df_session(nwb_filename):
return df_session


def create_df_trials(nwb_filename):
def create_df_trials(nwb_filename, adjust_time=True):
"""
Process nwb and create df_trials for every single session

adjust_time (bool) if true, adjust t0 to be the first gocue
"""
nwb = load_nwb_from_filename(nwb_filename)

key_from_acq = [
"left_lick_time",
"right_lick_time",
"left_reward_delivery_time",
"right_reward_delivery_time",
"FIP_falling_time",
"FIP_rising_time",
]
# If we are given a filename, load the NWB object itself
nwb = load_nwb_from_filename(nwb_filename)

# Parse subject and session_date
if nwb.session_id.startswith("behavior") or nwb.session_id.startswith("FIP"):
Expand All @@ -321,101 +318,132 @@ def create_df_trials(nwb_filename):
splits = nwb.session_id.split("_")
subject_id = splits[0]
session_date = splits[1]

ses_idx = subject_id + "_" + session_date

df_ses_trials = nwb.trials.to_dataframe().reset_index()
df_ses_trials = df_ses_trials.rename(columns={"id": "trial"})
df_ses_trials["ses_idx"] = ses_idx
# Build dataframe
df = nwb.trials.to_dataframe().reset_index()
df = df.rename(columns={"id": "trial"})
df["ses_idx"] = ses_idx

# Adjust all times relative to start of the first trial
t0 = df_ses_trials.start_time[0]
# Adjust for gaps in trial start/stop, and use the last stop time
last_stop = df.iloc[-1]["stop_time"]
df["stop_time"] = df["start_time"].shift(-1, fill_value=last_stop)

# We skip these columns because they are how long the valve is open
# not the times at which the valves were opened
skip_cols = ["right_valve_open_time", "left_valve_open_time"]
for col in df_ses_trials.columns:

# compute times relative to start of trial and start of session
t0 = nwb.trials.goCue_start_time[0]
drop_cols = []
for col in df.columns:
if ("time" in col) and (col not in skip_cols):
df_ses_trials[col + "_absolute"] = df_ses_trials[col] - t0
# Adjust all times relative to start of the first go cue
if adjust_time:
df[col + "_in_session"] = df[col] - t0
else:
df[col + "_in_session"] = df[col]

# Adjust for gaps in trial start/stop, and use the last stop time
last_stop = df_ses_trials.iloc[-1]["stop_time_absolute"]
df_ses_trials["stop_time_absolute"] = df_ses_trials["start_time_absolute"].shift(
-1, fill_value=last_stop
)
# Adjust times relative to go cue on each trial
if ("time" in col) and (col not in skip_cols):
df[col + "_in_trial"] = df[col].values - df["goCue_start_time"].values
alexpiet marked this conversation as resolved.
Show resolved Hide resolved
alexpiet marked this conversation as resolved.
Show resolved Hide resolved

# Adjust times relative to go cue
for col in df_ses_trials.columns:
if (
("time" in col)
and ("time_absolute" not in col)
and (col != "goCue_start_time")
and (col not in skip_cols)
):
df_ses_trials.loc[:, col] = (
df_ses_trials[col].values - df_ses_trials["goCue_start_time"].values
)
df_ses_trials["goCue_start_time"] = 0.0

# Adjust event times relative to trial
events_ses = {key: nwb.acquisition[key].timestamps[:] - t0 for key in key_from_acq}
for event in [
# Clean up these column names that are not clear
drop_cols.append(col)

# Get lick and reward times
key_from_acq = [
"left_lick_time",
"right_lick_time",
"left_reward_delivery_time",
"right_reward_delivery_time",
]:
event_times = events_ses[event]
df_ses_trials[event] = df_ses_trials.apply(
lambda x: np.round(
event_times[
(event_times > (x["goCue_start_time"] + x["goCue_start_time_absolute"]))
& (event_times < (x["stop_time"] + x["goCue_start_time_absolute"]))
]
- x["goCue_start_time_absolute"],
4,
),
]
if adjust_time:
events = {key: nwb.acquisition[key].timestamps[:] - t0 for key in key_from_acq}
else:
events = {key: nwb.acquisition[key].timestamps[:] for key in key_from_acq}

# Map events to trials
# Here we map an event to the most recent goCue
df["next_goCue_start_time_in_session"] = df["goCue_start_time_in_session"].shift(
-1, fill_value=np.inf
)
drop_cols.append("next_goCue_start_time_in_session")
for event in key_from_acq:
event_times = events[event]
df[event] = df.apply(
lambda x: event_times[
(event_times >= x["goCue_start_time_in_session"])
& (event_times < x["next_goCue_start_time_in_session"])
],
axis=1,
)

# Compute time of reward for each trial
df_ses_trials["reward_time"] = df_ses_trials.apply(
lambda x: np.nanmin(
np.concatenate(
[
[np.nan],
x["right_reward_delivery_time"],
x["left_reward_delivery_time"],
]
)
),
axis=1,
)
df_ses_trials["reward_time_absolute"] = (
df_ses_trials["reward_time"] + df_ses_trials["goCue_start_time_absolute"]
)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="All-NaN slice encountered")
df["reward_time_in_session"] = df.apply(
lambda x: np.nanmin(
np.concatenate(
[
[np.nan],
x["right_reward_delivery_time"],
x["left_reward_delivery_time"],
]
)
),
axis=1,
)
df["reward_time_in_trial"] = df["reward_time_in_session"] - df["goCue_start_time_in_session"]

# Compute time of choice for each trials
df_ses_trials["choice_time"] = df_ses_trials.apply(
lambda x: np.nanmin(np.concatenate([[np.nan], x["right_lick_time"], x["left_lick_time"]])),
axis=1,
)
df_ses_trials["choice_time_absolute"] = (
df_ses_trials["choice_time"] + df_ses_trials["goCue_start_time_absolute"]
)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="All-NaN slice encountered")
df["choice_time_in_session"] = df.apply(
lambda x: np.nanmin(
np.concatenate([[np.nan], x["right_lick_time"], x["left_lick_time"]])
),
axis=1,
)
df["choice_time_in_trial"] = df["choice_time_in_session"] - df["goCue_start_time_in_session"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

again, same thing here-- yet another note to "goCue_start_time" that is buried in here


# Filtering out choices greater than response window
slow_choice = df["choice_time_in_trial"] > df["response_duration"]
df.loc[slow_choice, "choice_time_in_session"] = np.nan
df.loc[slow_choice, "choice_time_in_trial"] = np.nan

# Compute boolean of whether animal was rewarded
df_ses_trials["reward"] = df_ses_trials.rewarded_historyR.astype(
int
) | df_ses_trials.rewarded_historyL.astype(int)
df["earned_reward"] = df.rewarded_historyR.astype(int) | df.rewarded_historyL.astype(int)
df["extra_reward"] = (df["earned_reward"] == 0) & df["reward_time_in_session"].notnull()
alexpiet marked this conversation as resolved.
Show resolved Hide resolved

# Sanity checks
rewarded_df = df.query("earned_reward == 1")
assert (
np.isnan(rewarded_df["reward_time_in_session"]).sum() == 0
), "Rewarded trials without reward time"
assert (
np.isnan(rewarded_df["choice_time_in_session"]).sum() == 0
), "Rewarded trials without choice time"
assert np.all(
rewarded_df["choice_time_in_session"] <= rewarded_df["reward_time_in_session"]
), "Reward before choice time"
alexpiet marked this conversation as resolved.
Show resolved Hide resolved
assert np.all(
rewarded_df["choice_time_in_trial"] >= 0
), "Rewarded trial with negative choice_time_in_trial"
assert np.all(
np.isnan(
df.query("earned_reward == 0").query("extra_reward == 0")["reward_time_in_session"]
)
), "Unrewarded trials with reward time"
alexpiet marked this conversation as resolved.
Show resolved Hide resolved
# TODO, documentation of added columns

# Drop columns
df_ses_trials = df_ses_trials.drop(
columns=[
"left_lick_time",
"right_lick_time",
"left_reward_delivery_time",
"right_reward_delivery_time",
]
)
return df_ses_trials
drop_cols += key_from_acq
df = df.drop(columns=drop_cols)

if adjust_time:
print("Timestamps are adjusted so t(0) = first go cue")
alexpiet marked this conversation as resolved.
Show resolved Hide resolved
return df


def create_events_df(nwb_filename, adjust_time=True):
Expand Down Expand Up @@ -461,8 +489,8 @@ def create_events_df(nwb_filename, adjust_time=True):
)
event_types -= ignore_types

# Determine time 0
t0 = nwb.trials.start_time[0]
# Determine time 0 as first go Cue
t0 = nwb.trials.goCue_start_time[0]

# Iterate over event types and build a dataframe of each
events = []
Expand Down Expand Up @@ -494,8 +522,8 @@ def create_events_df(nwb_filename, adjust_time=True):
df = df.dropna(subset="timestamps").reset_index(drop=True)

# Add trial index for each event
trial_starts = nwb.trials.start_time[:] - nwb.trials.start_time[0]
last_stop = nwb.trials.stop_time[-1] - nwb.trials.start_time[0]
trial_starts = nwb.trials.start_time[:] - nwb.trials.goCue_start_time[0]
last_stop = nwb.trials.stop_time[-1] - nwb.trials.goCue_start_time[0]
alexpiet marked this conversation as resolved.
Show resolved Hide resolved
trial_index = []
for index, e in df.iterrows():
starts = np.where(e.timestamps > trial_starts)[0]
Expand All @@ -507,6 +535,13 @@ def create_events_df(nwb_filename, adjust_time=True):
trial_index.append(starts[-1])
df["trial"] = trial_index

# Sanity check that the first go cue is time 0
gocues = df.query('event == "goCue_start_time"')
if (len(gocues) > 0) and (adjust_time):
assert np.isclose(gocues.iloc[0]["timestamps"], 0, rtol=0.01)

if adjust_time:
print("Timestamps are adjusted so t(0) = first go cue")
alexpiet marked this conversation as resolved.
Show resolved Hide resolved
return df


Expand Down Expand Up @@ -560,7 +595,7 @@ def create_fib_df(nwb_filename, tidy=True, adjust_time=True):
return None

# Determine time 0
t0 = nwb.trials.start_time[0]
t0 = nwb.trials.goCue_start_time[0]

# Iterate over event types and build a dataframe of each
events = []
Expand Down Expand Up @@ -591,6 +626,9 @@ def create_fib_df(nwb_filename, tidy=True, adjust_time=True):
ses_idx = subject_id + "_" + session_date
df["ses_idx"] = ses_idx

if adjust_time:
print("Timestamps are adjusted so t(0) = first go cue")
alexpiet marked this conversation as resolved.
Show resolved Hide resolved

# pivot table based on timestamps
if not tidy:
df_pivoted = pd.pivot(df, index="timestamps", columns=["event"], values="data")
Expand Down
Loading