Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

align keypoints with raw FP #96

Merged
merged 8 commits into from
Apr 9, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -457,15 +457,22 @@ def extract_reinforcement_photometry_metadata(data_path: str, example_uuids: str
return session_metadata, subject_metadata


def extract_keypoint_metadata(data_path: str):
def extract_keypoint_metadata():
keypoint_subjects = ["dls-dlight-9", "dls-dlight-10", "dls-dlight-11", "dls-dlight-12", "dls-dlight-13"]
keypoint_start_times = [
"2022-07-14T11:24:31-05:00",
"2022-07-13T11:49:49-05:00",
"2022-07-13T12:21:37-05:00",
"2022-07-13T17:03:55-05:00",
"2022-07-13T16:28:19-05:00",
]
session_metadata, subject_metadata = {}, {}
for subject in keypoint_subjects:
for subject, session_start_time in zip(keypoint_subjects, keypoint_start_times):
session_metadata[subject] = dict(
keypoint=True,
photometry=True,
session_description="keypoint session",
session_start_time="1901-01-01T00:00:00-05:00", # TODO: replace with real session start time
session_start_time=session_start_time,
reference_max=np.NaN,
signal_max=np.NaN,
signal_reference_corr=np.NaN,
Expand Down Expand Up @@ -649,12 +656,11 @@ def get_session_name(session_df):
(
reinforcement_photometry_session_metadata,
reinforcement_photometry_subject_metadata,
) = extract_reinforcement_photometry_metadata(data_path)
# velocity_session_metadata, velocity_subject_metadata = extract_velocity_modulation_metadata(
# data_path,
# )
# keypoint_session_metadata, keypoint_subject_metadata = extract_keypoint_metadata(data_path)

) = extract_reinforcement_photometry_metadata(data_path, example_uuids=reinforcement_photometry_examples)
velocity_session_metadata, velocity_subject_metadata = extract_velocity_modulation_metadata(
data_path,
)
keypoint_session_metadata, keypoint_subject_metadata = extract_keypoint_metadata()
path2metadata = {
# photometry_session_metadata_path: photometry_session_metadata,
# photometry_subject_metadata_path: photometry_subject_metadata,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ def add_to_nwbfile(self, nwbfile: NWBFile, metadata: dict) -> None:
raw_photometry = RoiResponseSeries(
name="RawPhotometry",
description="The raw acquisition with mixed signal from both the blue light excitation (470nm) and UV excitation (405nm).",
comments=("Note: Raw photometry data is not temporally aligned for keypoint sessions."),
data=H5DataIO(raw_photometry[ascending_timestamps_indices], compression=True),
unit="F",
timestamps=commanded_signal_series.timestamps,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def session_to_nwb(
file_path=str(processed_path),
tdt_path=str(tdt_path),
tdt_metadata_path=str(tdt_metadata_path),
depth_timestamp_path="",
session_metadata_path=str(session_metadata_path),
subject_metadata_path=str(subject_metadata_path),
session_uuid=subject_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# NWB Ecosystem
from pynwb.file import NWBFile
from pynwb.ophys import RoiResponseSeries
from ..markowitz_gillis_nature_2023.rawfiberphotometryinterface import RawFiberPhotometryInterface
from ..markowitz_gillis_nature_2023.rawfiberphotometryinterface import RawFiberPhotometryInterface, load_tdt_data
from neuroconv.tools import nwb_helpers
from hdmf.backends.hdf5.h5_utils import H5DataIO

Expand All @@ -18,26 +18,47 @@ def __init__(
file_path: str,
tdt_path: str,
tdt_metadata_path: str,
depth_timestamp_path: str,
session_uuid: str,
session_id: str,
session_metadata_path: str,
subject_metadata_path: str,
alignment_path: str = None,
):
super().__init__(
file_path=file_path,
tdt_path=tdt_path,
tdt_metadata_path=tdt_metadata_path,
depth_timestamp_path=depth_timestamp_path,
session_uuid=session_uuid,
session_id=session_id,
session_metadata_path=session_metadata_path,
subject_metadata_path=subject_metadata_path,
alignment_path=alignment_path,
)

def get_original_timestamps(self, metadata) -> np.ndarray:
processed_photometry = joblib.load(self.source_data["file_path"])
timestamps = np.arange(processed_photometry["dlight"].shape[0]) / metadata["Constants"]["VIDEO_SAMPLING_RATE"]
return timestamps

def align_processed_timestamps(
self, metadata: dict
) -> np.ndarray: # TODO: align timestamps if we get alignment_df.parquet
timestamps = self.get_original_timestamps(metadata=metadata)
self.set_aligned_timestamps(aligned_timestamps=timestamps)
return self.aligned_timestamps

def align_raw_timestamps(self, metadata: dict) -> np.ndarray: # TODO: remove if we get alignment_df.parquet
photometry_dict = load_tdt_data(self.source_data["tdt_path"], fs=metadata["FiberPhotometry"]["raw_rate"])
timestamps = photometry_dict["tstep"]
self.set_aligned_timestamps(aligned_timestamps=timestamps)
return self.aligned_timestamps

def add_to_nwbfile(self, nwbfile: NWBFile, metadata: dict) -> None:
SAMPLING_RATE = 30
super().add_to_nwbfile(nwbfile, metadata)
processed_photometry = joblib.load(self.source_data["file_path"])
timestamps = np.arange(processed_photometry["dlight"].shape[0]) / SAMPLING_RATE
timestamps = self.align_processed_timestamps(metadata)
signal_series = RoiResponseSeries(
name="SignalF",
description=(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,20 @@ def __init__(
session_id: str,
session_metadata_path: str,
subject_metadata_path: str,
alignment_path: str = None,
):
super().__init__(
data_path=Path(data_path),
session_uuid=session_uuid,
session_id=session_id,
session_metadata_path=session_metadata_path,
subject_metadata_path=subject_metadata_path,
alignment_path=alignment_path,
)

def get_original_timestamps(self, metadata) -> np.ndarray:
raise NotImplementedError # TODO: align timestamps if we get alignment_df.parquet

def add_to_nwbfile(self, nwbfile: NWBFile, metadata: dict) -> None:
SAMPLING_RATE = metadata["Constants"]["VIDEO_SAMPLING_RATE"]
matched_timestamp_path = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __init__(
session_metadata_path: str,
subject_metadata_path: str,
summary_image_path: str,
alignment_path: str = None,
):
super().__init__(
file_path=file_path,
Expand All @@ -29,6 +30,7 @@ def __init__(
session_id=session_id,
session_metadata_path=session_metadata_path,
subject_metadata_path=subject_metadata_path,
alignment_path=alignment_path,
)

def get_metadata_schema(self) -> dict:
Expand All @@ -41,14 +43,17 @@ def get_metadata_schema(self) -> dict:
}
return metadata_schema

def get_original_timestamps(self, metadata) -> np.ndarray:
raise NotImplementedError # TODO: align timestamps if we get alignment_df.parquet

def add_to_nwbfile(self, nwbfile: NWBFile, metadata: dict) -> None:
SAMPLING_RATE = metadata["Constants"]["VIDEO_SAMPLING_RATE"]
keypoint_dict = joblib.load(self.source_data["file_path"])
raw_keypoints = keypoint_dict["positions_median"]
timestamps = H5DataIO(np.arange(raw_keypoints.shape[0]) / SAMPLING_RATE, compression=True)

index_to_name = metadata["Keypoint"]["index_to_name"]
camera_names = ["bottom", "side1", "side2", "side3", "side4", "top"]
camera_names = ["bottom", "side1", "side2", "side3", "side4", "top"] # as confirmed by email with authors
keypoints = []
for camera in camera_names:
nwbfile.create_device(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def reproduce_figS3(nwbfile_paths, config_path, metadata):
timestamps = (
nwbfile.processing["behavior"]
.data_interfaces["keypoints"]
.pose_estimation_series["rostral spine"]
.pose_estimation_series["rostral_spine"]
.timestamps[:]
)
positions_median = np.zeros((len(timestamps), 15, 3))
Expand Down