Skip to content

Commit

Permalink
Merge pull request #43 from AllenNeuralDynamics/nan
Browse files Browse the repository at this point in the history
Adds NaN policy to event_triggered_response
  • Loading branch information
alexpiet authored Sep 11, 2024
2 parents e85beb4 + adc7b57 commit 4c4fb26
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 6 deletions.
48 changes: 42 additions & 6 deletions src/aind_dynamic_foraging_data_utils/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ def event_triggered_response( # noqa C901
interpolate=True,
censor=True,
censor_times=None,
nan_policy="error",
): # NOQA E501
"""
Slices a timeseries relative to a given set of event times
Expand Down Expand Up @@ -280,6 +281,10 @@ def event_triggered_response( # noqa C901
censor_times: list or array or None
if None, and censor is True, then use event_times as the censor times
if times are provided, then these are the times at which ETR is censored
nan_policy: How to handle NaNs in the input data
"error": raise an exception if NaNs are present in the time window of an ETR
"interpolate": interpolate over NaN values
"exclude": exclude any response with a NaN in the response window
Returns:
--------
Expand Down Expand Up @@ -355,6 +360,8 @@ def event_triggered_response( # noqa C901

assert (not censor) or (output_format == "tidy"), "cannot censor data in wide output"

assert nan_policy in ["error", "interpolate", "exclude"], "unrecognized nan_policy"

if censor:
event_times = np.sort(event_times)

Expand Down Expand Up @@ -395,13 +402,42 @@ def event_triggered_response( # noqa C901
}
)

# update our dictionary to have a new key defined as
# 'event_{EVENT NUMBER}_t={EVENT TIME}' and
# a value that includes an array that represents the
# sliced data around the current event, interpolated
# on the linearly spaced time array

elif np.any(np.isnan(data_slice)):
if nan_policy == "error":
# raise exception
raise Exception("NaN value in data slice, at event time {}".format(event_time))
elif nan_policy == "exclude":
# exclude this event
data_dict.update(
{
"event_{}_t={}".format(event_number, event_time): np.full(
len(t_array), np.nan
)
}
)
else:
# Interpolate over NaNs
x_data = data_slice[~data_slice.isnull()]
data_slice[:] = np.interp(
data_slice.index.values, x_data.index.values, x_data.values
)

# Add to data dict as normal
data_dict.update(
{
"event_{}_t={}".format(event_number, event_time): np.interp(
data_dict["time"],
data_slice.index - event_time,
data_slice.values,
)
}
)
else:
# update our dictionary to have a new key defined as
# 'event_{EVENT NUMBER}_t={EVENT TIME}' and
# a value that includes an array that represents the
# sliced data around the current event, interpolated
# on the linearly spaced time array
data_dict.update(
{
"event_{}_t={}".format(event_number, event_time): np.interp(
Expand Down
76 changes: 76 additions & 0 deletions tests/test_aind_dynamic_foraging_data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,82 @@ def test_event_triggered_response_censor(self):
assert np.isclose(etr_censored.query("time > 1")["y"].mean(), 10, rtol=0.01)
assert np.isclose(etr_censored.query("time < 0")["y"].mean(), 10, rtol=0.01)

def test_event_triggered_response_nan_policy(self):
"""
tests the `test_event_triggered_response` function
"""
# make a time vector from -10 to 110
t = np.arange(-10, 110, 0.01)

# Make a dataframe with one column as time, and another
# column called 'sinusoid' defined as sin(2*pi*t)
# The sinusoid column will have a period of 1
df = pd.DataFrame({"time": t, "sinusoid": np.sin(2 * np.pi * t)})

# Make an event triggered response, NaN values are outside window
df.loc[0:100, "sinusoid"] = np.nan
etr = alignment.event_triggered_response(
data=df,
t="time",
y="sinusoid",
event_times=np.arange(100),
t_before=1,
t_after=1,
output_sampling_rate=100,
nan_policy="error",
)

# Raises an error
df.loc[3980:4050, "sinusoid"] = np.nan
with self.assertRaises(Exception):
alignment.event_triggered_response(
data=df,
t="time",
y="sinusoid",
event_times=np.arange(100),
t_before=1,
t_after=1,
output_sampling_rate=100,
nan_policy="error",
)

# outputs are NaNs around NaN data
etr = alignment.event_triggered_response(
data=df,
t="time",
y="sinusoid",
event_times=np.arange(100),
t_before=1,
t_after=1,
output_sampling_rate=100,
nan_policy="exclude",
)
assert np.isclose(
etr.query("event_number == 29")["sinusoid"].isnull().mean(), 1.0, rtol=0.01
)
assert np.isclose(
etr.query("event_number == 30")["sinusoid"].isnull().mean(), 1.0, rtol=0.01
)
assert np.isclose(
etr.query("event_number == 31")["sinusoid"].isnull().mean(), 1.0, rtol=0.01
)
assert np.isclose(
etr.query("event_number not in [29,30,31]")["sinusoid"].isnull().mean(), 0.0, rtol=0.01
)

# outputs are NaNs around NaN data
etr = alignment.event_triggered_response(
data=df,
t="time",
y="sinusoid",
event_times=np.arange(100),
t_before=1,
t_after=1,
output_sampling_rate=100,
nan_policy="interpolate",
)
assert np.isclose(etr["sinusoid"].isnull().mean(), 0.0, rtol=0.01)


if __name__ == "__main__":
unittest.main()

0 comments on commit 4c4fb26

Please sign in to comment.