Skip to content

Commit

Permalink
Merge pull request #39 from AllenNeuralDynamics/add_censor
Browse files Browse the repository at this point in the history
adding censoring to event triggered response
  • Loading branch information
alexpiet authored Sep 4, 2024
2 parents 04dc76f + 4a47c3e commit 5116f3c
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 2 deletions.
41 changes: 39 additions & 2 deletions src/aind_dynamic_foraging_data_utils/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def index_of_nearest_value(data_timestamps, event_timestamps):
return event_indices


def event_triggered_response(
def event_triggered_response( # noqa C901
data,
t,
y,
Expand All @@ -200,6 +200,7 @@ def event_triggered_response(
include_endpoint=True,
output_format="tidy",
interpolate=True,
censor=True,
): # NOQA E501
"""
Slices a timeseries relative to a given set of event times
Expand Down Expand Up @@ -272,6 +273,9 @@ def event_triggered_response(
interpolate : Boolean
if True (default), interpolates each response onto a common timebase
if False, shifts each response to align indices to a common timebase
censor: Boolean
if True (default), censor observations that take place after the next event time
if False, do not censor
Returns:
--------
Expand Down Expand Up @@ -345,6 +349,8 @@ def event_triggered_response(
# ensure that t_end is greater than t_start
assert t_end > t_start, "must define t_end to be greater than t_start"

assert (not censor) or (output_format == "tidy"), "cannot censor data in wide output"

if output_sampling_rate is None:
# if sampling rate is None,
# set it to be the mean sampling rate of the input data
Expand Down Expand Up @@ -452,4 +458,35 @@ def event_triggered_response(
# drop the "variable" column, rename the "value" column
tidy_etr = tidy_etr.drop(columns=["variable"]).rename(columns={"value": y})
# return the tidy event triggered responses
return tidy_etr
if censor:
tidy_etr = censor_event_triggered_response(tidy_etr, y, t_start, t_end, event_times)
return tidy_etr


def censor_event_triggered_response(etr, y, t_start, t_end, event_times):
"""
censors the event triggered response by the immediately preceeding or
subsequent event times if that event time is within the (t_start, t_end)
time window
censored timepoints are replaced with NaN, so all data points are still present
"""

# Compute when we should censor
diff = np.diff(event_times)
diff_backward = np.concatenate([[np.inf], diff])
diff_forward = np.concatenate([diff, [np.inf]])
backward_time = [-np.min([np.abs(t_end), x]) for x in diff_backward]
forward_time = [np.min([t_end, x]) for x in diff_forward]

# double check we have all events
assert len(event_times) == len(etr["event_number"].unique()), "event times missing"

# Censor trials
for index, time in enumerate(event_times):
vec = (etr["event_number"] == index) & (etr["time"] < backward_time[index])
etr.loc[vec, y] = np.nan
vec = (etr["event_number"] == index) & (etr["time"] > forward_time[index])
etr.loc[vec, y] = np.nan

return etr
29 changes: 29 additions & 0 deletions tests/test_aind_dynamic_foraging_data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,35 @@ def test_event_triggered_response(self):
# Assert that the dataframe is unchanged
pd.testing.assert_frame_equal(df, df_copy)

def test_event_triggered_response_censor(self):
"""
tests the censoring property of the event_triggered_response function
"""
# make a sample test set
t = np.arange(0, 10, 0.01)
y = np.array([10] * len(t))
event_times = [2, 3, 4, 5, 6, 7, 8]
for e in event_times:
y[(t > e) & (t < (e + 0.25))] = 1

df = pd.DataFrame({"time": t, "y": y})

# make etr
etr_censored = alignment.event_triggered_response(
data=df,
t="time",
y="y",
event_times=event_times,
t_before=2,
t_after=2,
output_sampling_rate=100,
censor=True,
)

# assert properties of etr
assert np.isclose(etr_censored.query("time > 1")["y"].mean(), 10, rtol=0.01)
assert np.isclose(etr_censored.query("time < -1")["y"].mean(), 10, rtol=0.01)


if __name__ == "__main__":
unittest.main()

0 comments on commit 5116f3c

Please sign in to comment.