Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change absolute time reference to first go cue #50

Merged
merged 17 commits into from
Nov 9, 2024
271 changes: 168 additions & 103 deletions src/aind_dynamic_foraging_data_utils/nwb_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@

import os
import re
import warnings

import numpy as np
import pandas as pd
from pynwb import NWBHDF5IO
from hdmf_zarr import NWBZarrIO

# If we adjust time_in_session, adjust it to this
SESSION_ALIGNMENT = "goCue_start_time"


def load_nwb_from_filename(filename):
"""
Expand Down Expand Up @@ -46,6 +50,7 @@ def unpack_metadata(nwb):
Unpacks metadata as a dictionary attribute, instead of a Dynamic
table nested inside a dictionary
"""
# TODO, this should be outdated once we fix the NWB files themselves
nwb.metadata = nwb.scratch["metadata"].to_dataframe().iloc[0].to_dict()


Expand Down Expand Up @@ -297,20 +302,25 @@ def create_single_df_session(nwb_filename):
return df_session


def create_df_trials(nwb_filename):
def create_df_trials(nwb_filename, adjust_time=True):
"""
Process nwb and create df_trials for every single session

ARGS:
nwb_filename (str or NWB object), the session to extract the trials from
adjust_time (bool) if true, adjust t0 to be the first gocue

RETURNS:
A pandas dataframe containing the columns of nwb.trials plus:
"_in_trial" time alignments where time is relative to the go cue on that trial
"_in_session" time alignments where time is relative to the first go cue
of the session.
earned_reward, (bool) whether a reward was earned in that trial
extra_reward (bool) whether a manual reward was given in that trial
alexpiet marked this conversation as resolved.
Show resolved Hide resolved
"""
nwb = load_nwb_from_filename(nwb_filename)

key_from_acq = [
"left_lick_time",
"right_lick_time",
"left_reward_delivery_time",
"right_reward_delivery_time",
"FIP_falling_time",
"FIP_rising_time",
]
# If we are given a filename, load the NWB object itself
nwb = load_nwb_from_filename(nwb_filename)

# Parse subject and session_date
if nwb.session_id.startswith("behavior") or nwb.session_id.startswith("FIP"):
Expand All @@ -321,101 +331,140 @@ def create_df_trials(nwb_filename):
splits = nwb.session_id.split("_")
subject_id = splits[0]
session_date = splits[1]

ses_idx = subject_id + "_" + session_date

df_ses_trials = nwb.trials.to_dataframe().reset_index()
df_ses_trials = df_ses_trials.rename(columns={"id": "trial"})
df_ses_trials["ses_idx"] = ses_idx
# Build dataframe
df = nwb.trials.to_dataframe().reset_index()
df = df.rename(columns={"id": "trial"})
df["ses_idx"] = ses_idx

# Adjust for gaps in trial start/stop, and use the last stop time
last_stop = df.iloc[-1]["stop_time"]
df["stop_time"] = df["start_time"].shift(-1, fill_value=last_stop)

# Adjust all times relative to start of the first trial
t0 = df_ses_trials.start_time[0]
# We skip these columns because they are how long the valve is open
# not the times at which the valves were opened
skip_cols = ["right_valve_open_time", "left_valve_open_time"]
for col in df_ses_trials.columns:

# compute times relative to start of trial and start of session
t0 = nwb.trials[SESSION_ALIGNMENT][0]
drop_cols = []
for col in df.columns:
if ("time" in col) and (col not in skip_cols):
df_ses_trials[col + "_absolute"] = df_ses_trials[col] - t0
# Adjust all times relative to start of the first go cue
if adjust_time:
df[col + "_in_session"] = df[col] - t0
else:
df[col + "_in_session"] = df[col]

# Adjust for gaps in trial start/stop, and use the last stop time
last_stop = df_ses_trials.iloc[-1]["stop_time_absolute"]
df_ses_trials["stop_time_absolute"] = df_ses_trials["start_time_absolute"].shift(
-1, fill_value=last_stop
)
# Adjust times relative to go cue on each trial
if ("time" in col) and (col not in skip_cols):
df[col + "_in_trial"] = df[col].values - df["goCue_start_time"].values
alexpiet marked this conversation as resolved.
Show resolved Hide resolved
alexpiet marked this conversation as resolved.
Show resolved Hide resolved

# Adjust times relative to go cue
for col in df_ses_trials.columns:
if (
("time" in col)
and ("time_absolute" not in col)
and (col != "goCue_start_time")
and (col not in skip_cols)
):
df_ses_trials.loc[:, col] = (
df_ses_trials[col].values - df_ses_trials["goCue_start_time"].values
)
df_ses_trials["goCue_start_time"] = 0.0

# Adjust event times relative to trial
events_ses = {key: nwb.acquisition[key].timestamps[:] - t0 for key in key_from_acq}
for event in [
# Clean up these column names that are not clear
drop_cols.append(col)

# Add a column of raw time so users can map if they want
df[SESSION_ALIGNMENT + "_raw"] = df[SESSION_ALIGNMENT]

# Get lick and reward times
key_from_acq = [
"left_lick_time",
"right_lick_time",
"left_reward_delivery_time",
"right_reward_delivery_time",
]:
event_times = events_ses[event]
df_ses_trials[event] = df_ses_trials.apply(
lambda x: np.round(
event_times[
(event_times > (x["goCue_start_time"] + x["goCue_start_time_absolute"]))
& (event_times < (x["stop_time"] + x["goCue_start_time_absolute"]))
]
- x["goCue_start_time_absolute"],
4,
),
]
if adjust_time:
events = {key: nwb.acquisition[key].timestamps[:] - t0 for key in key_from_acq}
else:
events = {key: nwb.acquisition[key].timestamps[:] for key in key_from_acq}

# Map events to trials
# Here we map an event to the most recent goCue
df["next_goCue_start_time_in_session"] = df["goCue_start_time_in_session"].shift(
-1, fill_value=np.inf
)
drop_cols.append("next_goCue_start_time_in_session")
for event in key_from_acq:
event_times = events[event]
df[event] = df.apply(
lambda x: event_times[
(event_times >= x["goCue_start_time_in_session"])
& (event_times < x["next_goCue_start_time_in_session"])
],
axis=1,
)

# Compute time of reward for each trial
df_ses_trials["reward_time"] = df_ses_trials.apply(
lambda x: np.nanmin(
np.concatenate(
[
[np.nan],
x["right_reward_delivery_time"],
x["left_reward_delivery_time"],
]
)
),
axis=1,
)
df_ses_trials["reward_time_absolute"] = (
df_ses_trials["reward_time"] + df_ses_trials["goCue_start_time_absolute"]
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="All-NaN slice encountered")
df["reward_time_in_session"] = df.apply(
lambda x: np.nanmin(
np.concatenate(
[
[np.nan],
x["right_reward_delivery_time"],
x["left_reward_delivery_time"],
]
)
),
axis=1,
)
df["reward_time_in_trial"] = (
df["reward_time_in_session"] - df[SESSION_ALIGNMENT + "_in_session"]
)

# Compute time of choice for each trial
df_ses_trials["choice_time"] = df_ses_trials.apply(
lambda x: np.nanmin(np.concatenate([[np.nan], x["right_lick_time"], x["left_lick_time"]])),
axis=1,
)
df_ses_trials["choice_time_absolute"] = (
df_ses_trials["choice_time"] + df_ses_trials["goCue_start_time_absolute"]
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="All-NaN slice encountered")
df["choice_time_in_session"] = df.apply(
lambda x: np.nanmin(
np.concatenate([[np.nan], x["right_lick_time"], x["left_lick_time"]])
),
axis=1,
)
df["choice_time_in_trial"] = (
df["choice_time_in_session"] - df[SESSION_ALIGNMENT + "_in_session"]
)

# Filtering out choices greater than response window
slow_choice = df["choice_time_in_trial"] > df["response_duration"]
df.loc[slow_choice, "choice_time_in_session"] = np.nan
df.loc[slow_choice, "choice_time_in_trial"] = np.nan

# Compute boolean of whether animal was rewarded
df_ses_trials["reward"] = df_ses_trials.rewarded_historyR.astype(
int
) | df_ses_trials.rewarded_historyL.astype(int)
# AutoWater and manual water are not included in earned_reward
df["earned_reward"] = df.rewarded_historyR | df.rewarded_historyL
df["extra_reward"] = (~df["earned_reward"]) & df["reward_time_in_session"].notnull()
alexpiet marked this conversation as resolved.
Show resolved Hide resolved

# Sanity checks
rewarded_df = df.query("earned_reward")
assert (
np.isnan(rewarded_df["reward_time_in_session"]).sum() == 0
), "Rewarded trials without reward time"
assert (
np.isnan(rewarded_df["choice_time_in_session"]).sum() == 0
), "Rewarded trials without choice time"
# assert np.all(
# rewarded_df["choice_time_in_session"] <= rewarded_df["reward_time_in_session"]
# ), "Reward before choice time"
if not np.all(rewarded_df["choice_time_in_session"] <= rewarded_df["reward_time_in_session"]):
warnings.warn("Reward before choice time. This is likely due to manual rewards")
# TODO, auto water can be delivered before choice time
assert np.all(
rewarded_df["choice_time_in_trial"] >= 0
), "Rewarded trial with negative choice_time_in_trial"
assert np.all(
np.isnan(df.query("not earned_reward").query("not extra_reward")["reward_time_in_session"])
), "Unrewarded trials with reward time"
alexpiet marked this conversation as resolved.
Show resolved Hide resolved

# Drop columns
df_ses_trials = df_ses_trials.drop(
columns=[
"left_lick_time",
"right_lick_time",
"left_reward_delivery_time",
"right_reward_delivery_time",
]
)
return df_ses_trials
drop_cols += key_from_acq
df = df.drop(columns=drop_cols)

if adjust_time:
print("Timestamps are adjusted so t(0) = first go cue")
alexpiet marked this conversation as resolved.
Show resolved Hide resolved
return df


def create_events_df(nwb_filename, adjust_time=True):
Expand Down Expand Up @@ -461,31 +510,32 @@ def create_events_df(nwb_filename, adjust_time=True):
)
event_types -= ignore_types

# Determine time 0
t0 = nwb.trials.start_time[0]
# Determine time 0 as first go Cue
if adjust_time:
t0 = nwb.trials[SESSION_ALIGNMENT][0]
else:
t0 = 0

# Iterate over event types and build a dataframe of each
events = []
for e in event_types:
# For each event, get timestamps, data, and label
stamps = nwb.acquisition[e].timestamps[:]
raw_stamps = nwb.acquisition[e].timestamps[:]
data = nwb.acquisition[e].data[:]
labels = [e] * len(data)
if adjust_time:
stamps = stamps - t0
df = pd.DataFrame({"timestamps": stamps, "data": data, "event": labels})
stamps = raw_stamps - t0
df = pd.DataFrame(
{"timestamps": stamps, "data": data, "event": labels, "raw_timestamps": raw_stamps}
alexpiet marked this conversation as resolved.
Show resolved Hide resolved
)
events.append(df)

# Add keys from trials table
# I don't like hardcoding dynamic foraging specific things here.
# I think these keys should be added to the stimulus field of the nwb
trial_events = ["goCue_start_time"]
for e in trial_events:
stamps = nwb.trials[:][e].values
labels = [e] * len(stamps)
if adjust_time:
stamps = stamps - t0
df = pd.DataFrame({"timestamps": stamps, "event": labels})
raw_stamps = nwb.trials[:][e].values
labels = [e] * len(raw_stamps)
stamps = raw_stamps - t0
df = pd.DataFrame({"timestamps": stamps, "event": labels, "raw_timestamps": raw_stamps})
events.append(df)

# Build dataframe by concatenating each event
Expand All @@ -494,8 +544,8 @@ def create_events_df(nwb_filename, adjust_time=True):
df = df.dropna(subset="timestamps").reset_index(drop=True)

# Add trial index for each event
trial_starts = nwb.trials.start_time[:] - nwb.trials.start_time[0]
last_stop = nwb.trials.stop_time[-1] - nwb.trials.start_time[0]
trial_starts = nwb.trials.start_time[:] - t0
last_stop = nwb.trials.stop_time[-1] - t0
trial_index = []
for index, e in df.iterrows():
starts = np.where(e.timestamps > trial_starts)[0]
Expand All @@ -507,6 +557,14 @@ def create_events_df(nwb_filename, adjust_time=True):
trial_index.append(starts[-1])
df["trial"] = trial_index

# Sanity check that the first go cue is time 0
gocues = df.query("event == @SESSION_ALIGNMENT")
if (len(gocues) > 0) and (adjust_time):
assert np.isclose(gocues.iloc[0]["timestamps"], 0, rtol=0.01)
# TODO, need more checks here for time alignment on trial index.

if adjust_time:
print("Timestamps are adjusted so t(0) = first go cue")
alexpiet marked this conversation as resolved.
Show resolved Hide resolved
return df


Expand Down Expand Up @@ -559,19 +617,23 @@ def create_fib_df(nwb_filename, tidy=True, adjust_time=True):
if len(event_types) == 0:
return None

# Determine time 0
t0 = nwb.trials.start_time[0]
# Determine time 0 as first go Cue
if adjust_time:
t0 = nwb.trials[SESSION_ALIGNMENT][0]
else:
t0 = 0

# Iterate over event types and build a dataframe of each
events = []
for e in event_types:
# For each event, get timestamps, data, and label
stamps = nwb.acquisition[e].timestamps[:]
raw_stamps = nwb.acquisition[e].timestamps[:]
data = nwb.acquisition[e].data[:]
labels = [e] * len(data)
if adjust_time:
stamps = stamps - t0
df = pd.DataFrame({"timestamps": stamps, "data": data, "event": labels})
stamps = raw_stamps - t0
df = pd.DataFrame(
{"timestamps": stamps, "data": data, "event": labels, "raw_timestamps": raw_stamps}
)
events.append(df)

# Build dataframe by concatenating each event
Expand All @@ -591,6 +653,9 @@ def create_fib_df(nwb_filename, tidy=True, adjust_time=True):
ses_idx = subject_id + "_" + session_date
df["ses_idx"] = ses_idx

if adjust_time:
print("Timestamps are adjusted so t(0) = first go cue")
alexpiet marked this conversation as resolved.
Show resolved Hide resolved

# pivot table based on timestamps
if not tidy:
df_pivoted = pd.pivot(df, index="timestamps", columns=["event"], values="data")
Expand Down
Loading