Skip to content

Commit

Permalink
Merge pull request #96 from catalystneuro/align_keypoints
Browse files Browse the repository at this point in the history
align keypoints with raw FP
  • Loading branch information
CodyCBakerPhD authored Apr 9, 2024
2 parents 46ac51e + cdfb50a commit 4e64a12
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -457,15 +457,22 @@ def extract_reinforcement_photometry_metadata(data_path: str, example_uuids: str
return session_metadata, subject_metadata


def extract_keypoint_metadata(data_path: str):
def extract_keypoint_metadata():
keypoint_subjects = ["dls-dlight-9", "dls-dlight-10", "dls-dlight-11", "dls-dlight-12", "dls-dlight-13"]
keypoint_start_times = [
"2022-07-14T11:24:31-05:00",
"2022-07-13T11:49:49-05:00",
"2022-07-13T12:21:37-05:00",
"2022-07-13T17:03:55-05:00",
"2022-07-13T16:28:19-05:00",
]
session_metadata, subject_metadata = {}, {}
for subject in keypoint_subjects:
for subject, session_start_time in zip(keypoint_subjects, keypoint_start_times):
session_metadata[subject] = dict(
keypoint=True,
photometry=True,
session_description="keypoint session",
session_start_time="1901-01-01T00:00:00-05:00", # TODO: replace with real session start time
session_start_time=session_start_time,
reference_max=np.NaN,
signal_max=np.NaN,
signal_reference_corr=np.NaN,
Expand Down Expand Up @@ -649,12 +656,11 @@ def get_session_name(session_df):
(
reinforcement_photometry_session_metadata,
reinforcement_photometry_subject_metadata,
) = extract_reinforcement_photometry_metadata(data_path)
# velocity_session_metadata, velocity_subject_metadata = extract_velocity_modulation_metadata(
# data_path,
# )
# keypoint_session_metadata, keypoint_subject_metadata = extract_keypoint_metadata(data_path)

) = extract_reinforcement_photometry_metadata(data_path, example_uuids=reinforcement_photometry_examples)
velocity_session_metadata, velocity_subject_metadata = extract_velocity_modulation_metadata(
data_path,
)
keypoint_session_metadata, keypoint_subject_metadata = extract_keypoint_metadata()
path2metadata = {
# photometry_session_metadata_path: photometry_session_metadata,
# photometry_subject_metadata_path: photometry_subject_metadata,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ def add_to_nwbfile(self, nwbfile: NWBFile, metadata: dict) -> None:
raw_photometry = RoiResponseSeries(
name="RawPhotometry",
description="The raw acquisition with mixed signal from both the blue light excitation (470nm) and UV excitation (405nm).",
comments=("Note: Raw photometry data is not temporally aligned for keypoint sessions."),
data=H5DataIO(raw_photometry[ascending_timestamps_indices], compression=True),
unit="F",
timestamps=commanded_signal_series.timestamps,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def session_to_nwb(
file_path=str(processed_path),
tdt_path=str(tdt_path),
tdt_metadata_path=str(tdt_metadata_path),
depth_timestamp_path="",
session_metadata_path=str(session_metadata_path),
subject_metadata_path=str(subject_metadata_path),
session_uuid=subject_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# NWB Ecosystem
from pynwb.file import NWBFile
from pynwb.ophys import RoiResponseSeries
from ..markowitz_gillis_nature_2023.rawfiberphotometryinterface import RawFiberPhotometryInterface
from ..markowitz_gillis_nature_2023.rawfiberphotometryinterface import RawFiberPhotometryInterface, load_tdt_data
from neuroconv.tools import nwb_helpers
from hdmf.backends.hdf5.h5_utils import H5DataIO

Expand All @@ -18,26 +18,47 @@ def __init__(
file_path: str,
tdt_path: str,
tdt_metadata_path: str,
depth_timestamp_path: str,
session_uuid: str,
session_id: str,
session_metadata_path: str,
subject_metadata_path: str,
alignment_path: str = None,
):
super().__init__(
file_path=file_path,
tdt_path=tdt_path,
tdt_metadata_path=tdt_metadata_path,
depth_timestamp_path=depth_timestamp_path,
session_uuid=session_uuid,
session_id=session_id,
session_metadata_path=session_metadata_path,
subject_metadata_path=subject_metadata_path,
alignment_path=alignment_path,
)

def get_original_timestamps(self, metadata) -> np.ndarray:
    """Derive the native timestamps of the processed photometry trace.

    Loads the processed photometry file and builds a uniformly spaced
    timestamp vector: one sample per dlight frame, spaced at the video
    sampling rate taken from ``metadata["Constants"]``.
    """
    dlight_trace = joblib.load(self.source_data["file_path"])["dlight"]
    frames_per_second = metadata["Constants"]["VIDEO_SAMPLING_RATE"]
    return np.arange(dlight_trace.shape[0]) / frames_per_second

def align_processed_timestamps(self, metadata: dict) -> np.ndarray:
    """Set and return aligned timestamps for the processed photometry data.

    No alignment transform is currently applied: the original timestamps
    are registered unchanged via ``set_aligned_timestamps``.
    TODO: align timestamps if we get alignment_df.parquet
    """
    original_timestamps = self.get_original_timestamps(metadata=metadata)
    self.set_aligned_timestamps(aligned_timestamps=original_timestamps)
    return self.aligned_timestamps

def align_raw_timestamps(self, metadata: dict) -> np.ndarray:
    """Set and return aligned timestamps for the raw TDT photometry stream.

    Reads the TDT recording's own time axis (``tstep``) and registers it as
    the aligned timestamps without further transformation.
    TODO: remove if we get alignment_df.parquet
    """
    tdt_data = load_tdt_data(self.source_data["tdt_path"], fs=metadata["FiberPhotometry"]["raw_rate"])
    self.set_aligned_timestamps(aligned_timestamps=tdt_data["tstep"])
    return self.aligned_timestamps

def add_to_nwbfile(self, nwbfile: NWBFile, metadata: dict) -> None:
SAMPLING_RATE = 30
super().add_to_nwbfile(nwbfile, metadata)
processed_photometry = joblib.load(self.source_data["file_path"])
timestamps = np.arange(processed_photometry["dlight"].shape[0]) / SAMPLING_RATE
timestamps = self.align_processed_timestamps(metadata)
signal_series = RoiResponseSeries(
name="SignalF",
description=(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,20 @@ def __init__(
session_id: str,
session_metadata_path: str,
subject_metadata_path: str,
alignment_path: str = None,
):
super().__init__(
data_path=Path(data_path),
session_uuid=session_uuid,
session_id=session_id,
session_metadata_path=session_metadata_path,
subject_metadata_path=subject_metadata_path,
alignment_path=alignment_path,
)

def get_original_timestamps(self, metadata) -> np.ndarray:
    """Original timestamps are not available for this interface.

    TODO: align timestamps if we get alignment_df.parquet
    """
    raise NotImplementedError

def add_to_nwbfile(self, nwbfile: NWBFile, metadata: dict) -> None:
SAMPLING_RATE = metadata["Constants"]["VIDEO_SAMPLING_RATE"]
matched_timestamp_path = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __init__(
session_metadata_path: str,
subject_metadata_path: str,
summary_image_path: str,
alignment_path: str = None,
):
super().__init__(
file_path=file_path,
Expand All @@ -29,6 +30,7 @@ def __init__(
session_id=session_id,
session_metadata_path=session_metadata_path,
subject_metadata_path=subject_metadata_path,
alignment_path=alignment_path,
)

def get_metadata_schema(self) -> dict:
Expand All @@ -41,14 +43,17 @@ def get_metadata_schema(self) -> dict:
}
return metadata_schema

def get_original_timestamps(self, metadata) -> np.ndarray:
    """Original timestamps are not available for this interface.

    TODO: align timestamps if we get alignment_df.parquet
    """
    raise NotImplementedError

def add_to_nwbfile(self, nwbfile: NWBFile, metadata: dict) -> None:
SAMPLING_RATE = metadata["Constants"]["VIDEO_SAMPLING_RATE"]
keypoint_dict = joblib.load(self.source_data["file_path"])
raw_keypoints = keypoint_dict["positions_median"]
timestamps = H5DataIO(np.arange(raw_keypoints.shape[0]) / SAMPLING_RATE, compression=True)

index_to_name = metadata["Keypoint"]["index_to_name"]
camera_names = ["bottom", "side1", "side2", "side3", "side4", "top"]
camera_names = ["bottom", "side1", "side2", "side3", "side4", "top"] # as confirmed by email with authors
keypoints = []
for camera in camera_names:
nwbfile.create_device(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def reproduce_figS3(nwbfile_paths, config_path, metadata):
timestamps = (
nwbfile.processing["behavior"]
.data_interfaces["keypoints"]
.pose_estimation_series["rostral spine"]
.pose_estimation_series["rostral_spine"]
.timestamps[:]
)
positions_median = np.zeros((len(timestamps), 15, 3))
Expand Down

0 comments on commit 4e64a12

Please sign in to comment.