From 811dadf90733399b0a2a892da2dda8261e1b03e5 Mon Sep 17 00:00:00 2001 From: Alex Piet Date: Thu, 29 Aug 2024 14:10:03 -0700 Subject: [PATCH 1/6] adding censoring to event triggered response --- .../alignment.py | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/src/aind_dynamic_foraging_data_utils/alignment.py b/src/aind_dynamic_foraging_data_utils/alignment.py index 46a4c56..5a12cfc 100644 --- a/src/aind_dynamic_foraging_data_utils/alignment.py +++ b/src/aind_dynamic_foraging_data_utils/alignment.py @@ -200,6 +200,7 @@ def event_triggered_response( include_endpoint=True, output_format="tidy", interpolate=True, + censor=True, ): # NOQA E501 """ Slices a timeseries relative to a given set of event times @@ -272,6 +273,9 @@ def event_triggered_response( interpolate : Boolean if True (default), interpolates each response onto a common timebase if False, shifts each response to align indices to a common timebase + censor: Boolean + if True (default), censor observations that take place after the next event time + if False, do not censor Returns: -------- @@ -345,6 +349,8 @@ def event_triggered_response( # ensure that t_end is greater than t_start assert t_end > t_start, "must define t_end to be greater than t_start" + assert (not censor) or (output_format == "tidy"), "cannot censor data in wide output" + if output_sampling_rate is None: # if sampling rate is None, # set it to be the mean sampling rate of the input data @@ -452,4 +458,35 @@ def event_triggered_response( # drop the "variable" column, rename the "value" column tidy_etr = tidy_etr.drop(columns=["variable"]).rename(columns={"value": y}) # return the tidy event triggered responses + if censor: + tidy_etr = censor_event_triggered_response(tidy_etr, t_start, t_end, event_times) return tidy_etr + + +def censor_event_triggered_response(etr, t_start, t_end, event_times): + """ + censors the event triggered response by the immediately preceeding or + subsequent event times if that event time is within the (t_start, t_end) + time window + + censored timepoints are replaced with NaN, so all data points are still present + """ + + # Compute when we should censor + diff = np.diff(event_times) + diff_backward = np.concatenate([[np.inf], diff]) + diff_forward = np.concatenate([diff, [np.inf]]) + backward_time = [-np.min([np.abs(t_end), x]) for x in diff_backward] + forward_time = [np.min([t_end, x]) for x in diff_forward] + + # double check we have all events + assert len(event_times) == len(etr["event_number"].unique()), "event times missing" + + # Censor trials + for index, time in enumerate(event_times): + vec = (etr["event_number"] == index) & (etr["time"] < backward_time[index]) + etr.loc[vec, "data"] = np.nan + vec = (etr["event_number"] == index) & (etr["time"] > forward_time[index]) + etr.loc[vec, "data"] = np.nan + + return etr From e99aca386f96549f65b4deb7b8a392af4a2880ab Mon Sep 17 00:00:00 2001 From: Alex Piet Date: Thu, 29 Aug 2024 14:15:08 -0700 Subject: [PATCH 2/6] linting --- src/aind_dynamic_foraging_data_utils/alignment.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/aind_dynamic_foraging_data_utils/alignment.py b/src/aind_dynamic_foraging_data_utils/alignment.py index 5a12cfc..e40226c 100644 --- a/src/aind_dynamic_foraging_data_utils/alignment.py +++ b/src/aind_dynamic_foraging_data_utils/alignment.py @@ -458,9 +458,9 @@ def event_triggered_response( # drop the "variable" column, rename the "value" column tidy_etr = tidy_etr.drop(columns=["variable"]).rename(columns={"value": y}) # return the tidy event triggered responses - if censor: - tidy_etr = censor_event_triggered_response(tidy_etr, t_start, t_end, event_times) - return tidy_etr + if censor: + tidy_etr = censor_event_triggered_response(tidy_etr, t_start, t_end, event_times) + return tidy_etr def censor_event_triggered_response(etr, t_start, t_end, event_times): From 2341aa7af9aa386e99e055956cdc30ba12516581 Mon Sep 17 00:00:00 2001 From: Alex Piet Date: Thu, 29 Aug 2024 14:17:10 -0700 Subject: [PATCH 3/6] linting --- src/aind_dynamic_foraging_data_utils/alignment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aind_dynamic_foraging_data_utils/alignment.py b/src/aind_dynamic_foraging_data_utils/alignment.py index e40226c..abf705e 100644 --- a/src/aind_dynamic_foraging_data_utils/alignment.py +++ b/src/aind_dynamic_foraging_data_utils/alignment.py @@ -187,7 +187,7 @@ def index_of_nearest_value(data_timestamps, event_timestamps): return event_indices -def event_triggered_response( +def event_triggered_response( # noqa C901 data, t, y, From e86ce4da83e3b7d5955619a09cab494d70bf04a1 Mon Sep 17 00:00:00 2001 From: Alex Piet Date: Tue, 3 Sep 2024 20:25:15 -0700 Subject: [PATCH 4/6] adding unit test, and naming bug fix --- .../alignment.py | 8 ++--- .../test_aind_dynamic_foraging_data_utils.py | 29 +++++++++++++++++++ 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/src/aind_dynamic_foraging_data_utils/alignment.py b/src/aind_dynamic_foraging_data_utils/alignment.py index abf705e..eef0c64 100644 --- a/src/aind_dynamic_foraging_data_utils/alignment.py +++ b/src/aind_dynamic_foraging_data_utils/alignment.py @@ -459,11 +459,11 @@ def event_triggered_response( # noqa C901 tidy_etr = tidy_etr.drop(columns=["variable"]).rename(columns={"value": y}) # return the tidy event triggered responses if censor: - tidy_etr = censor_event_triggered_response(tidy_etr, t_start, t_end, event_times) + tidy_etr = censor_event_triggered_response(tidy_etr, y,t_start, t_end, event_times) return tidy_etr -def censor_event_triggered_response(etr, t_start, t_end, event_times): +def censor_event_triggered_response(etr, y,t_start, t_end, event_times): """ censors the event triggered response by the immediately preceeding or subsequent event times if that event time is within the (t_start, t_end) @@ -485,8 +485,8 @@ def censor_event_triggered_response(etr, t_start, t_end, event_times): # Censor trials for index, time in enumerate(event_times): vec = (etr["event_number"] == index) & (etr["time"] < backward_time[index]) - etr.loc[vec, "data"] = np.nan + etr.loc[vec, y] = np.nan vec = (etr["event_number"] == index) & (etr["time"] > forward_time[index]) - etr.loc[vec, "data"] = np.nan + etr.loc[vec, y] = np.nan return etr diff --git a/tests/test_aind_dynamic_foraging_data_utils.py b/tests/test_aind_dynamic_foraging_data_utils.py index e2e3908..f048cf6 100644 --- a/tests/test_aind_dynamic_foraging_data_utils.py +++ b/tests/test_aind_dynamic_foraging_data_utils.py @@ -139,6 +139,35 @@ def test_event_triggered_response(self): # Assert that the dataframe is unchanged pd.testing.assert_frame_equal(df, df_copy) + def test_event_triggered_response_censor(self): + ''' + tests the censoring property of the event_triggered_response function + ''' + # make a sample test set + t = np.arange(0, 10, 0.01) + y = np.array([10]*len(t)) + event_times=[2,3,4,5,6,7,8] + for e in event_times: + y[(t>e)&(t<(e+.25))]=1 + + df = pd.DataFrame({"time": t, "y": y}) + + # make etr + etr_censored = alignment.event_triggered_response( + data=df, + t='time', + y='y', + event_times=event_times, + t_before=2, + t_after=2, + output_sampling_rate=100, + censor=True + ) + + # assert properties of etr + assert np.isclose(etr_censored.query('time > 1')['y'].mean(),10,rtol=0.01) + assert np.isclose(etr_censored.query('time < -1')['y'].mean(),10,rtol=0.01) + if __name__ == "__main__": unittest.main() From 77284d6677dbe267721ce95ff705d5bae65868c4 Mon Sep 17 00:00:00 2001 From: Alex Piet Date: Tue, 3 Sep 2024 20:26:42 -0700 Subject: [PATCH 5/6] linting --- .../test_aind_dynamic_foraging_data_utils.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/test_aind_dynamic_foraging_data_utils.py b/tests/test_aind_dynamic_foraging_data_utils.py index f048cf6..49590c2 100644 --- a/tests/test_aind_dynamic_foraging_data_utils.py +++ b/tests/test_aind_dynamic_foraging_data_utils.py @@ -140,33 +140,33 @@ def test_event_triggered_response(self): pd.testing.assert_frame_equal(df, df_copy) def test_event_triggered_response_censor(self): - ''' + """ tests the censoring property of the event_triggered_response function - ''' + """ # make a sample test set t = np.arange(0, 10, 0.01) - y = np.array([10]*len(t)) - event_times=[2,3,4,5,6,7,8] + y = np.array([10] * len(t)) + event_times = [2, 3, 4, 5, 6, 7, 8] for e in event_times: - y[(t>e)&(t<(e+.25))]=1 + y[(t > e) & (t < (e + 0.25))] = 1 df = pd.DataFrame({"time": t, "y": y}) # make etr etr_censored = alignment.event_triggered_response( data=df, - t='time', - y='y', + t="time", + y="y", event_times=event_times, t_before=2, t_after=2, output_sampling_rate=100, - censor=True + censor=True, ) # assert properties of etr - assert np.isclose(etr_censored.query('time > 1')['y'].mean(),10,rtol=0.01) - assert np.isclose(etr_censored.query('time < -1')['y'].mean(),10,rtol=0.01) + assert np.isclose(etr_censored.query("time > 1")["y"].mean(), 10, rtol=0.01) + assert np.isclose(etr_censored.query("time < -1")["y"].mean(), 10, rtol=0.01) if __name__ == "__main__": From 4a47c3e8bee47986636af356d0852e9f8cd2fe0c Mon Sep 17 00:00:00 2001 From: Alex Piet Date: Tue, 3 Sep 2024 20:27:52 -0700 Subject: [PATCH 6/6] linting --- src/aind_dynamic_foraging_data_utils/alignment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/aind_dynamic_foraging_data_utils/alignment.py b/src/aind_dynamic_foraging_data_utils/alignment.py index eef0c64..98bbaf6 100644 --- a/src/aind_dynamic_foraging_data_utils/alignment.py +++ b/src/aind_dynamic_foraging_data_utils/alignment.py @@ -459,11 +459,11 @@ def event_triggered_response( # noqa C901 tidy_etr = tidy_etr.drop(columns=["variable"]).rename(columns={"value": y}) # return the tidy event triggered responses if censor: - tidy_etr = censor_event_triggered_response(tidy_etr, y,t_start, t_end, event_times) + tidy_etr = censor_event_triggered_response(tidy_etr, y, t_start, t_end, event_times) return tidy_etr -def censor_event_triggered_response(etr, y,t_start, t_end, event_times): +def censor_event_triggered_response(etr, y, t_start, t_end, event_times): """ censors the event triggered response by the immediately preceeding or subsequent event times if that event time is within the (t_start, t_end)