diff --git a/src/aind_dynamic_foraging_data_utils/alignment.py b/src/aind_dynamic_foraging_data_utils/alignment.py index 46a4c56..98bbaf6 100644 --- a/src/aind_dynamic_foraging_data_utils/alignment.py +++ b/src/aind_dynamic_foraging_data_utils/alignment.py @@ -187,7 +187,7 @@ def index_of_nearest_value(data_timestamps, event_timestamps): return event_indices -def event_triggered_response( +def event_triggered_response( # noqa C901 data, t, y, @@ -200,6 +200,7 @@ def event_triggered_response( include_endpoint=True, output_format="tidy", interpolate=True, + censor=True, ): # NOQA E501 """ Slices a timeseries relative to a given set of event times @@ -272,6 +273,9 @@ def event_triggered_response( interpolate : Boolean if True (default), interpolates each response onto a common timebase if False, shifts each response to align indices to a common timebase + censor: Boolean + if True (default), censor observations that take place after the next event time + if False, do not censor Returns: -------- @@ -345,6 +349,8 @@ def event_triggered_response( # ensure that t_end is greater than t_start assert t_end > t_start, "must define t_end to be greater than t_start" + assert (not censor) or (output_format == "tidy"), "cannot censor data in wide output" + if output_sampling_rate is None: # if sampling rate is None, # set it to be the mean sampling rate of the input data @@ -452,4 +458,35 @@ def event_triggered_response( # drop the "variable" column, rename the "value" column tidy_etr = tidy_etr.drop(columns=["variable"]).rename(columns={"value": y}) # return the tidy event triggered responses - return tidy_etr + if censor: + tidy_etr = censor_event_triggered_response(tidy_etr, y, t_start, t_end, event_times) + return tidy_etr + + +def censor_event_triggered_response(etr, y, t_start, t_end, event_times): + """ + censors the event triggered response by the immediately preceeding or + subsequent event times if that event time is within the (t_start, t_end) + time window + + censored timepoints are replaced with NaN, so all data points are still present + """ + + # Compute when we should censor + diff = np.diff(event_times) + diff_backward = np.concatenate([[np.inf], diff]) + diff_forward = np.concatenate([diff, [np.inf]]) + backward_time = [-np.min([np.abs(t_end), x]) for x in diff_backward] + forward_time = [np.min([t_end, x]) for x in diff_forward] + + # double check we have all events + assert len(event_times) == len(etr["event_number"].unique()), "event times missing" + + # Censor trials + for index, time in enumerate(event_times): + vec = (etr["event_number"] == index) & (etr["time"] < backward_time[index]) + etr.loc[vec, y] = np.nan + vec = (etr["event_number"] == index) & (etr["time"] > forward_time[index]) + etr.loc[vec, y] = np.nan + + return etr diff --git a/tests/test_aind_dynamic_foraging_data_utils.py b/tests/test_aind_dynamic_foraging_data_utils.py index e2e3908..49590c2 100644 --- a/tests/test_aind_dynamic_foraging_data_utils.py +++ b/tests/test_aind_dynamic_foraging_data_utils.py @@ -139,6 +139,35 @@ def test_event_triggered_response(self): # Assert that the dataframe is unchanged pd.testing.assert_frame_equal(df, df_copy) + def test_event_triggered_response_censor(self): + """ + tests the censoring property of the event_triggered_response function + """ + # make a sample test set + t = np.arange(0, 10, 0.01) + y = np.array([10] * len(t)) + event_times = [2, 3, 4, 5, 6, 7, 8] + for e in event_times: + y[(t > e) & (t < (e + 0.25))] = 1 + + df = pd.DataFrame({"time": t, "y": y}) + + # make etr + etr_censored = alignment.event_triggered_response( + data=df, + t="time", + y="y", + event_times=event_times, + t_before=2, + t_after=2, + output_sampling_rate=100, + censor=True, + ) + + # assert properties of etr + assert np.isclose(etr_censored.query("time > 1")["y"].mean(), 10, rtol=0.01) + assert np.isclose(etr_censored.query("time < -1")["y"].mean(), 10, rtol=0.01) + if __name__ == "__main__": unittest.main()