diff --git a/src/aind_dynamic_foraging_data_utils/nwb_utils.py b/src/aind_dynamic_foraging_data_utils/nwb_utils.py index fec54c6..d1f9c6a 100644 --- a/src/aind_dynamic_foraging_data_utils/nwb_utils.py +++ b/src/aind_dynamic_foraging_data_utils/nwb_utils.py @@ -12,12 +12,16 @@ import os import re +import warnings import numpy as np import pandas as pd from pynwb import NWBHDF5IO from hdmf_zarr import NWBZarrIO +# If we adjust time_in_session, adjust it to this +SESSION_ALIGNMENT = "goCue_start_time" + def load_nwb_from_filename(filename): """ @@ -46,6 +50,7 @@ def unpack_metadata(nwb): Unpacks metadata as a dictionary attribute, instead of a Dynamic table nested inside a dictionary """ + # TODO, this should be outdated once we fix the NWB files themselves nwb.metadata = nwb.scratch["metadata"].to_dataframe().iloc[0].to_dict() @@ -297,20 +302,25 @@ def create_single_df_session(nwb_filename): return df_session -def create_df_trials(nwb_filename): +def create_df_trials(nwb_filename, adjust_time=True): """ Process nwb and create df_trials for every single session + + ARGS: + nwb_filename (str or NWB object), the session to extract the trials from + adjust_time (bool) if true, adjust t0 to be the first gocue + + RETURNS: + A pandas dataframe containing the columns of nwb.trials plus: + "_in_trial" time alignments where time is relative to the go cue on that trial + "_in_session" time alignments where time is relative to the first go cue + of the session. + earned_reward, (bool) whether a reward was earned in that trial + extra_reward (bool) whether a manual reward was given in that trial """ - nwb = load_nwb_from_filename(nwb_filename) - key_from_acq = [ - "left_lick_time", - "right_lick_time", - "left_reward_delivery_time", - "right_reward_delivery_time", - "FIP_falling_time", - "FIP_rising_time", - ] + # If we are given a filename, load the NWB object itself + nwb = load_nwb_from_filename(nwb_filename) # Parse subject and session_date if nwb.session_id.startswith("behavior") or nwb.session_id.startswith("FIP"): @@ -321,101 +331,147 @@ def create_df_trials(nwb_filename): splits = nwb.session_id.split("_") subject_id = splits[0] session_date = splits[1] - ses_idx = subject_id + "_" + session_date - df_ses_trials = nwb.trials.to_dataframe().reset_index() - df_ses_trials = df_ses_trials.rename(columns={"id": "trial"}) - df_ses_trials["ses_idx"] = ses_idx - - # Adjust all times relative to start of the first trial - t0 = df_ses_trials.start_time[0] - skip_cols = ["right_valve_open_time", "left_valve_open_time"] - for col in df_ses_trials.columns: - if ("time" in col) and (col not in skip_cols): - df_ses_trials[col + "_absolute"] = df_ses_trials[col] - t0 + # Build dataframe + df = nwb.trials.to_dataframe().reset_index() + df = df.rename(columns={"id": "trial"}) + df["ses_idx"] = ses_idx # Adjust for gaps in trial start/stop, and use the last stop time - last_stop = df_ses_trials.iloc[-1]["stop_time_absolute"] - df_ses_trials["stop_time_absolute"] = df_ses_trials["start_time_absolute"].shift( - -1, fill_value=last_stop - ) + last_stop = df.iloc[-1]["stop_time"] + df["stop_time"] = df["start_time"].shift(-1, fill_value=last_stop) + + # We skip these columns because they are how long the valve is open + # not the times at which the valves were opened + skip_cols = ["right_valve_open_time", "left_valve_open_time"] - # Adjust times relative to go cue - for col in df_ses_trials.columns: - if ( - ("time" in col) - and ("time_absolute" not in col) - and (col != "goCue_start_time") - and (col not in skip_cols) - ): - df_ses_trials.loc[:, col] = ( - df_ses_trials[col].values - df_ses_trials["goCue_start_time"].values - ) - df_ses_trials["goCue_start_time"] = 0.0 - - # Adjust event times relative to trial - events_ses = {key: nwb.acquisition[key].timestamps[:] - t0 for key in key_from_acq} - for event in [ + # compute times relative to start of trial and start of session + t0 = nwb.trials[SESSION_ALIGNMENT][0] + drop_cols = [] + for col in df.columns: + if ("time" in col) and (col not in skip_cols): + # Adjust all times relative to start of the first go cue + if adjust_time: + df[col + "_in_session"] = df[col] - t0 + else: + df[col + "_in_session"] = df[col] + + # Adjust times relative to go cue on each trial + if ("time" in col) and (col not in skip_cols): + # Here we always align to goCue_start_time, not SESSION_ALIGNMENT + # since this aligns events relative to the trial go cue, not the start + # of the session + df[col + "_in_trial"] = df[col].values - df["goCue_start_time"].values + + # Clean up these column names that are not clear + drop_cols.append(col) + + # Add a column of raw time so users can map if they want + df[SESSION_ALIGNMENT + "_raw"] = df[SESSION_ALIGNMENT] + + # Get lick and reward times + key_from_acq = [ "left_lick_time", "right_lick_time", "left_reward_delivery_time", "right_reward_delivery_time", - ]: - event_times = events_ses[event] - df_ses_trials[event] = df_ses_trials.apply( - lambda x: np.round( - event_times[ - (event_times > (x["goCue_start_time"] + x["goCue_start_time_absolute"])) - & (event_times < (x["stop_time"] + x["goCue_start_time_absolute"])) - ] - - x["goCue_start_time_absolute"], - 4, - ), + ] + if adjust_time: + events = {key: nwb.acquisition[key].timestamps[:] - t0 for key in key_from_acq} + else: + events = {key: nwb.acquisition[key].timestamps[:] for key in key_from_acq} + + # Map events to trials + # Here we map an event to the most recent goCue + df["next_goCue_start_time_in_session"] = df["goCue_start_time_in_session"].shift( + -1, fill_value=np.inf + ) + drop_cols.append("next_goCue_start_time_in_session") + for event in key_from_acq: + event_times = events[event] + df[event] = df.apply( + lambda x: event_times[ + (event_times >= x["goCue_start_time_in_session"]) + & (event_times < x["next_goCue_start_time_in_session"]) + ], axis=1, ) # Compute time of reward for each trial - df_ses_trials["reward_time"] = df_ses_trials.apply( - lambda x: np.nanmin( - np.concatenate( - [ - [np.nan], - x["right_reward_delivery_time"], - x["left_reward_delivery_time"], - ] - ) - ), - axis=1, - ) - df_ses_trials["reward_time_absolute"] = ( - df_ses_trials["reward_time"] + df_ses_trials["goCue_start_time_absolute"] + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="All-NaN slice encountered") + df["reward_time_in_session"] = df.apply( + lambda x: np.nanmin( + np.concatenate( + [ + [np.nan], + x["right_reward_delivery_time"], + x["left_reward_delivery_time"], + ] + ) + ), + axis=1, + ) + df["reward_time_in_trial"] = ( + df["reward_time_in_session"] - df[SESSION_ALIGNMENT + "_in_session"] ) # Compute time of choice for each trials - df_ses_trials["choice_time"] = df_ses_trials.apply( - lambda x: np.nanmin(np.concatenate([[np.nan], x["right_lick_time"], x["left_lick_time"]])), - axis=1, - ) - df_ses_trials["choice_time_absolute"] = ( - df_ses_trials["choice_time"] + df_ses_trials["goCue_start_time_absolute"] + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="All-NaN slice encountered") + df["choice_time_in_session"] = df.apply( + lambda x: np.nanmin( + np.concatenate([[np.nan], x["right_lick_time"], x["left_lick_time"]]) + ), + axis=1, + ) + df["choice_time_in_trial"] = ( + df["choice_time_in_session"] - df[SESSION_ALIGNMENT + "_in_session"] ) + # Filtering out choices greater than response window + slow_choice = df["choice_time_in_trial"] > df["response_duration"] + df.loc[slow_choice, "choice_time_in_session"] = np.nan + df.loc[slow_choice, "choice_time_in_trial"] = np.nan + # Compute boolean of whether animal was rewarded - df_ses_trials["reward"] = df_ses_trials.rewarded_historyR.astype( - int - ) | df_ses_trials.rewarded_historyL.astype(int) + # AutoWater and manual water is not included in earned_reward + df["earned_reward"] = df.rewarded_historyR | df.rewarded_historyL + # TODO update this section once we have reliable labels for manual rewards + # See issue #54 + df["extra_reward"] = (~df["earned_reward"]) & df["reward_time_in_session"].notnull() + + # Sanity checks + rewarded_df = df.query("earned_reward") + assert ( + np.isnan(rewarded_df["reward_time_in_session"]).sum() == 0 + ), "Rewarded trials without reward time" + assert ( + np.isnan(rewarded_df["choice_time_in_session"]).sum() == 0 + ), "Rewarded trials without choice time" + # assert np.all( + # rewarded_df["choice_time_in_session"] <= rewarded_df["reward_time_in_session"] + # ), "Reward before choice time" + if not np.all(rewarded_df["choice_time_in_session"] <= rewarded_df["reward_time_in_session"]): + warnings.warn("Reward before choice time. This is likely due to manual rewards") + # TODO, auto water can be delievered before choice time + assert np.all( + rewarded_df["choice_time_in_trial"] >= 0 + ), "Rewarded trial with negative choice_time_in_trial" + assert np.all( + np.isnan(df.query("not earned_reward").query("not extra_reward")["reward_time_in_session"]) + ), "Unrewarded trials with reward time" # Drop columns - df_ses_trials = df_ses_trials.drop( - columns=[ - "left_lick_time", - "right_lick_time", - "left_reward_delivery_time", - "right_reward_delivery_time", - ] - ) - return df_ses_trials + drop_cols += key_from_acq + df = df.drop(columns=drop_cols) + + if adjust_time: + print( + "Timestamps are adjusted such that `_in_session` timestamps start at the first go cue" + ) + return df def create_events_df(nwb_filename, adjust_time=True): @@ -461,31 +517,35 @@ def create_events_df(nwb_filename, adjust_time=True): ) event_types -= ignore_types - # Determine time 0 - t0 = nwb.trials.start_time[0] + # Determine time 0 as first go Cue + if adjust_time: + t0 = nwb.trials[SESSION_ALIGNMENT][0] + else: + t0 = 0 # Iterate over event types and build a dataframe of each events = [] for e in event_types: # For each event, get timestamps, data, and label - stamps = nwb.acquisition[e].timestamps[:] + raw_stamps = nwb.acquisition[e].timestamps[:] data = nwb.acquisition[e].data[:] labels = [e] * len(data) - if adjust_time: - stamps = stamps - t0 - df = pd.DataFrame({"timestamps": stamps, "data": data, "event": labels}) + stamps = raw_stamps - t0 + df = pd.DataFrame( + {"timestamps": stamps, "data": data, "event": labels, "raw_timestamps": raw_stamps} + ) events.append(df) # Add keys from trials table - # I don't like hardcoding dynamic foraging specific things here. - # I think these keys should be added to the stimulus field of the nwb trial_events = ["goCue_start_time"] for e in trial_events: - stamps = nwb.trials[:][e].values - labels = [e] * len(stamps) - if adjust_time: - stamps = stamps - t0 - df = pd.DataFrame({"timestamps": stamps, "event": labels}) + raw_stamps = nwb.trials[:][e].values + labels = [e] * len(raw_stamps) + data = [1] * len(raw_stamps) + stamps = raw_stamps - t0 + df = pd.DataFrame( + {"timestamps": stamps, "data": data, "event": labels, "raw_timestamps": raw_stamps} + ) events.append(df) # Build dataframe by concatenating each event @@ -494,8 +554,8 @@ def create_events_df(nwb_filename, adjust_time=True): df = df.dropna(subset="timestamps").reset_index(drop=True) # Add trial index for each event - trial_starts = nwb.trials.start_time[:] - nwb.trials.start_time[0] - last_stop = nwb.trials.stop_time[-1] - nwb.trials.start_time[0] + trial_starts = nwb.trials.start_time[:] - t0 + last_stop = nwb.trials.stop_time[-1] - t0 trial_index = [] for index, e in df.iterrows(): starts = np.where(e.timestamps > trial_starts)[0] @@ -507,6 +567,16 @@ def create_events_df(nwb_filename, adjust_time=True): trial_index.append(starts[-1]) df["trial"] = trial_index + # Sanity check that the first go cue is time 0 + gocues = df.query("event == @SESSION_ALIGNMENT") + if (len(gocues) > 0) and (adjust_time): + assert np.isclose(gocues.iloc[0]["timestamps"], 0, rtol=0.01) + # TODO, need more checks here for time alignment on trial index. + + if adjust_time: + print( + "Timestamps are adjusted such that `_in_session` timestamps start at the first go cue" + ) return df @@ -559,19 +629,23 @@ def create_fib_df(nwb_filename, tidy=True, adjust_time=True): if len(event_types) == 0: return None - # Determine time 0 - t0 = nwb.trials.start_time[0] + # Determine time 0 as first go Cue + if adjust_time: + t0 = nwb.trials[SESSION_ALIGNMENT][0] + else: + t0 = 0 # Iterate over event types and build a dataframe of each events = [] for e in event_types: # For each event, get timestamps, data, and label - stamps = nwb.acquisition[e].timestamps[:] + raw_stamps = nwb.acquisition[e].timestamps[:] data = nwb.acquisition[e].data[:] labels = [e] * len(data) - if adjust_time: - stamps = stamps - t0 - df = pd.DataFrame({"timestamps": stamps, "data": data, "event": labels}) + stamps = raw_stamps - t0 + df = pd.DataFrame( + {"timestamps": stamps, "data": data, "event": labels, "raw_timestamps": raw_stamps} + ) events.append(df) # Build dataframe by concatenating each event @@ -591,6 +665,11 @@ def create_fib_df(nwb_filename, tidy=True, adjust_time=True): ses_idx = subject_id + "_" + session_date df["ses_idx"] = ses_idx + if adjust_time: + print( + "Timestamps are adjusted such that `_in_session` timestamps start at the first go cue" + ) + # pivot table based on timestamps if not tidy: df_pivoted = pd.pivot(df, index="timestamps", columns=["event"], values="data")