From a15b9cee4f455a995fd603c75db05122087b6ded Mon Sep 17 00:00:00 2001 From: bendichter Date: Wed, 14 Dec 2022 13:25:37 -0600 Subject: [PATCH] rename columns and use start times from next trial --- .../text/timeintervalsinterface.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/neuroconv/datainterfaces/text/timeintervalsinterface.py b/src/neuroconv/datainterfaces/text/timeintervalsinterface.py index ef12bef08..fee55aabb 100644 --- a/src/neuroconv/datainterfaces/text/timeintervalsinterface.py +++ b/src/neuroconv/datainterfaces/text/timeintervalsinterface.py @@ -1,22 +1,32 @@ from abc import abstractmethod +from typing import Dict, Optional + import numpy as np +import pandas as pd from pynwb import NWBFile from pynwb.epoch import TimeIntervals from ...basedatainterface import BaseDataInterface -from ...utils.types import FilePathType, Optional +from ...utils.types import FilePathType from ...tools.nwb_helpers import make_or_load_nwbfile -def convert_df_to_time_intervals(df, name, description=None): +def convert_df_to_time_intervals( + df: pd.DataFrame, + name: str, + description: Optional[str] = None, + column_name_mapping: Dict[str, str] = None, +): + if column_name_mapping is not None: + df.rename(columns=column_name_mapping, inplace=True) if description is None: description = name time_intervals = TimeIntervals(name, description) if "start_time" not in df: raise ValueError(f"df must contain a column named 'start_time'. Existing columns: {df.columns.to_list()}") if "stop_time" not in df: - df["stop_time"] = np.nan + df["stop_time"] = np.r_[df["start_time"][1:].to_numpy(), np.nan] for col in df: if col not in ("start_time", "stop_time"): time_intervals.add_column(col, col)