Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Duplicate sessions #92

Merged
merged 4 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,23 @@ def get_metadata_schema(self) -> dict:
}
return metadata_schema

def get_original_timestamps(self) -> np.ndarray:
def get_original_timestamps(self, metadata: dict, reinforcement: bool = False) -> np.ndarray:
if reinforcement:
filters = [
("uuid", "==", self.source_data["session_uuid"]),
("target_syllable", "==", metadata["Optogenetics"]["target_syllable"][0]),
]
else:
filters = [("uuid", "==", self.source_data["session_uuid"])]
session_df = pd.read_parquet(
self.source_data["file_path"],
columns=["timestamp", "uuid"],
filters=[("uuid", "==", self.source_data["session_uuid"])],
filters=filters,
)
return session_df["timestamp"].to_numpy()

def align_timestamps(self, metadata: dict) -> np.ndarray:
timestamps = self.get_original_timestamps()
def align_timestamps(self, metadata: dict, reinforcement: bool = False) -> np.ndarray:
timestamps = self.get_original_timestamps(metadata=metadata, reinforcement=reinforcement)
self.set_aligned_timestamps(aligned_timestamps=timestamps)
if self.source_data["alignment_path"] is not None:
aligned_starting_time = (
Expand All @@ -67,17 +74,26 @@ def align_timestamps(self, metadata: dict) -> np.ndarray:
self.set_aligned_starting_time(aligned_starting_time=aligned_starting_time)
return self.aligned_timestamps

def add_to_nwbfile(self, nwbfile: NWBFile, metadata: dict, velocity_modulation: bool = False) -> None:
def add_to_nwbfile(
self, nwbfile: NWBFile, metadata: dict, velocity_modulation: bool = False, reinforcement: bool = False
) -> None:
if velocity_modulation:
columns = ["uuid", "predicted_syllable"]
else:
columns = self.source_data["columns"]
if reinforcement:
filters = [
("uuid", "==", self.source_data["session_uuid"]),
("target_syllable", "==", metadata["Optogenetics"]["target_syllable"][0]),
]
else:
filters = [("uuid", "==", self.source_data["session_uuid"])]
session_df = pd.read_parquet(
self.source_data["file_path"],
columns=columns,
filters=[("uuid", "==", self.source_data["session_uuid"])],
filters=filters,
)
timestamps = self.align_timestamps(metadata=metadata)
timestamps = self.align_timestamps(metadata=metadata, reinforcement=reinforcement)
# Add Syllable Data
sorted_pseudoindex2name = metadata["BehavioralSyllable"]["sorted_pseudoindex2name"]
id2sorted_index = metadata["BehavioralSyllable"]["id2sorted_index"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def session_to_nwb(
raw_path: Union[str, Path],
output_dir_path: Union[str, Path],
experiment_type: Literal["reinforcement", "photometry", "reinforcement-photometry", "velocity-modulation"],
processed_only: bool = False,
stub_test: bool = False,
):
processed_path = Path(processed_path)
Expand Down Expand Up @@ -85,7 +86,7 @@ def session_to_nwb(
session_uuid=session_uuid,
session_id=session_id,
)
conversion_options["Optogenetic"] = {}
conversion_options["BehavioralSyllable"] = dict(reinforcement=True)
behavioral_syllable_path = optoda_path
if "photometry" in session_metadata.keys():
tdt_path = list(raw_path.glob("tdt_data*.dat"))[0]
Expand Down Expand Up @@ -122,6 +123,11 @@ def session_to_nwb(
if experiment_type == "velocity-modulation":
conversion_options["BehavioralSyllable"] = dict(velocity_modulation=True)
conversion_options["Optogenetic"] = dict(velocity_modulation=True)
if processed_only:
source_data.pop("MoseqExtract")
source_data.pop("DepthVideo")
conversion_options.pop("MoseqExtract")
conversion_options.pop("DepthVideo")

converter = DattaNWBConverter(source_data=source_data)
metadata = converter.get_metadata()
Expand Down Expand Up @@ -160,23 +166,31 @@ def session_to_nwb(
reinforcement_photometry_examples = [figure1d_example, pulsed_photometry_example, excitation_photometry_example]
raw_rp_example = "b814a426-7ec9-440e-baaa-105ba27a5fa6"
velocity_modulation_example = "c621e134-50ec-4e8b-8175-a8c023d92789"
duplicated_session_example = "1c5441a6-aee8-44ff-999d-6f0787ad4632"

experiment_type2example_sessions = {
"reinforcement-photometry": [raw_rp_example],
"velocity-modulation": [velocity_modulation_example],
"reinforcement": [duplicated_session_example],
}
experiment_type2raw_path = {
"reinforcement-photometry": raw_rp_path,
"velocity-modulation": raw_velocity_path,
"reinforcement": "",
}
for experiment_type, example_sessions in experiment_type2example_sessions.items():
if experiment_type == "reinforcement":
processed_only = True
else:
processed_only = False
for example_session in example_sessions:
session_to_nwb(
session_uuid=example_session,
processed_path=processed_path,
raw_path=experiment_type2raw_path[experiment_type],
output_dir_path=output_dir_path,
experiment_type=experiment_type,
processed_only=processed_only,
stub_test=stub_test,
)
with NWBHDF5IO(output_dir_path / f"reinforcement-photometry-{raw_rp_example}.nwb", "r") as io:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,21 +75,24 @@ def get_metadata_schema(self) -> dict:
"stim_duration_s": {"type": "number"},
"power_watts": {"type": "number"},
"pulse_width_s": {"type": "number"},
"target_syllable": {"type": "number"},
"target_syllable": {"type": "array"},
},
}
return metadata_schema

def get_original_timestamps(self) -> np.ndarray:
def get_original_timestamps(self, metadata: dict) -> np.ndarray:
session_df = pd.read_parquet(
self.source_data["file_path"],
columns=["timestamp", "uuid"],
filters=[("uuid", "==", self.source_data["session_uuid"])],
filters=[
("uuid", "==", self.source_data["session_uuid"]),
("target_syllable", "==", metadata["Optogenetics"]["target_syllable"][0]),
],
)
return session_df["timestamp"].to_numpy()

def align_timestamps(self, metadata: dict, velocity_modulation: bool) -> np.ndarray:
timestamps = self.get_original_timestamps()
timestamps = self.get_original_timestamps(metadata=metadata)
self.set_aligned_timestamps(aligned_timestamps=timestamps)
if self.source_data["alignment_path"] is not None:
aligned_starting_time = (
Expand All @@ -106,7 +109,10 @@ def add_to_nwbfile(self, nwbfile: NWBFile, metadata: dict, velocity_modulation:
session_df = pd.read_parquet(
self.source_data["file_path"],
columns=self.source_data["columns"],
filters=[("uuid", "==", self.source_data["session_uuid"])],
filters=[
("uuid", "==", self.source_data["session_uuid"]),
("target_syllable", "==", metadata["Optogenetics"]["target_syllable"][0]),
],
)

device = nwbfile.create_device(
Expand All @@ -127,11 +133,11 @@ def add_to_nwbfile(self, nwbfile: NWBFile, metadata: dict, velocity_modulation:
else: # pulsed stim
data, timestamps = self.reconstruct_pulsed_stim(metadata, session_df, session_timestamps)
id2sorted_index = metadata["BehavioralSyllable"]["id2sorted_index"]
target_syllable = id2sorted_index[metadata["Optogenetics"]["target_syllable"]]
target_syllables = [id2sorted_index[syllable_id] for syllable_id in metadata["Optogenetics"]["target_syllable"]]
ogen_series = OptogeneticSeries(
name="OptogeneticSeries",
description="Onset of optogenetic stimulation is recorded as a 1, and offset is recorded as a 0.",
comments=f"target_syllable = {target_syllable}",
comments=f"target_syllable(s) = {target_syllables}",
site=ogen_site,
data=H5DataIO(data, compression=True),
timestamps=H5DataIO(timestamps, compression=True),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,6 @@
def extract_photometry_metadata(
data_path: str,
example_uuids: str = None,
num_sessions: int = None,
reinforcement_photometry: bool = False,
) -> dict:
"""Extract metadata from photometry data.
Expand All @@ -212,8 +211,6 @@ def extract_photometry_metadata(
Path to data.
example_uuids : str, optional
UUID of example session to extract metadata from.
num_sessions : int, optional
Number of sessions to extract metadata from.
reinforcement_photometry : bool, optional
If True, extract metadata from reinforcement photometry sessions. If False, extract metadata from
non-reinforcement photometry sessions.
Expand Down Expand Up @@ -265,14 +262,10 @@ def extract_photometry_metadata(
del df
else:
uuids = set(example_uuids)
if num_sessions is None:
num_sessions = len(uuids)
session_metadata = {}
for i, uuid in enumerate(tqdm(uuids, desc="Extracting photometry session metadata")):
extract_session_metadata(session_columns, photometry_data_path, session_metadata, uuid)
session_metadata[uuid]["photometry"] = True
if i + 1 >= num_sessions:
break
subject_ids = set(session_metadata[uuid]["subject_id"] for uuid in session_metadata)
subject_metadata = {}
for mouse_id in tqdm(subject_ids, desc="Extracting photometry subject metadata"):
Expand All @@ -284,7 +277,7 @@ def extract_photometry_metadata(


def extract_reinforcement_metadata(
data_path: str, example_uuids: str = None, num_sessions: int = None, reinforcement_photometry: bool = False
data_path: str, example_uuids: str = None, reinforcement_photometry: bool = False
) -> dict:
"""Extract metadata from reinforcement data.

Expand All @@ -294,8 +287,6 @@ def extract_reinforcement_metadata(
Path to data.
example_uuids : str, optional
UUID of example session to extract metadata from.
num_sessions : int, optional
Number of sessions to extract metadata from.
reinforcement_photometry : bool, optional
If True, extract metadata from reinforcement photometry sessions. If False, extract metadata from
non-photometry reinforcement sessions.
Expand Down Expand Up @@ -346,18 +337,16 @@ def extract_reinforcement_metadata(
else:
uuids = set(example_uuids)
session_metadata, subject_metadata = {}, {}
if num_sessions is None:
num_sessions = len(uuids)
for i, uuid in enumerate(tqdm(uuids, desc="Extracting reinforcement session metadata")):
extract_session_metadata(session_columns, reinforcement_data_path, session_metadata, uuid)
session_df = extract_session_metadata(session_columns, reinforcement_data_path, session_metadata, uuid)
target_syllables = set(session_df.target_syllable[session_df.target_syllable.notnull()])
session_metadata[uuid]["target_syllable"] = list(target_syllables)
# add si units to names
session_metadata[uuid]["stim_duration_s"] = session_metadata[uuid].pop("stim_duration")
session_metadata[uuid]["stim_frequency_Hz"] = session_metadata[uuid].pop("stim_frequency")
session_metadata[uuid]["pulse_width_s"] = session_metadata[uuid].pop("pulse_width")
session_metadata[uuid]["power_watts"] = session_metadata[uuid].pop("power") / 1000
session_metadata[uuid]["reinforcement"] = True
if i + 1 >= num_sessions:
break
subject_ids = set(session_metadata[uuid]["subject_id"] for uuid in session_metadata)
for mouse_id in tqdm(subject_ids, desc="Extracting reinforcement subject metadata"):
extract_subject_metadata(subject_columns, reinforcement_data_path, subject_metadata, mouse_id)
Expand All @@ -370,7 +359,6 @@ def extract_reinforcement_metadata(
def extract_velocity_modulation_metadata(
data_path: str,
example_uuids: str = None,
num_sessions: int = None,
) -> dict:
"""Extract metadata from velocity modulation data.
Parameters
Expand All @@ -379,8 +367,6 @@ def extract_velocity_modulation_metadata(
Path to data.
example_uuids : str, optional
UUID of example session to extract metadata from.
num_sessions : int, optional
Number of sessions to extract metadata from.
Returns
-------
metadata : dict
Expand Down Expand Up @@ -412,19 +398,17 @@ def extract_velocity_modulation_metadata(
else:
uuids = set(example_uuids)
session_metadata, subject_metadata = {}, {}
if num_sessions is None:
num_sessions = len(uuids)
for i, uuid in enumerate(tqdm(uuids, desc="Extracting velocity-modulation session metadata")):
extract_session_metadata(session_columns, velocity_data_path, session_metadata, uuid)
session_df = extract_session_metadata(session_columns, velocity_data_path, session_metadata, uuid)
target_syllables = set(session_df.target_syllable[session_df.target_syllable.notnull()])
session_metadata[uuid]["target_syllable"] = list(target_syllables)
# add si units to names
session_metadata[uuid]["stim_duration_s"] = session_metadata[uuid].pop("stim_duration")
session_metadata[uuid]["stim_frequency_Hz"] = np.NaN
session_metadata[uuid]["pulse_width_s"] = np.NaN
session_metadata[uuid]["power_watts"] = 10 / 1000 # power = 10mW from paper
session_metadata[uuid]["reinforcement"] = True
session_metadata[uuid]["velocity_modulation"] = True
if i + 1 >= num_sessions:
break
subject_ids = set(session_metadata[uuid]["subject_id"] for uuid in session_metadata)
for mouse_id in tqdm(subject_ids, desc="Extracting reinforcement subject metadata"):
extract_subject_metadata(subject_columns, velocity_data_path, subject_metadata, mouse_id)
Expand All @@ -434,9 +418,7 @@ def extract_velocity_modulation_metadata(
return session_metadata, subject_metadata


def extract_reinforcement_photometry_metadata(
data_path: str, example_uuids: str = None, num_sessions: int = None
) -> dict:
def extract_reinforcement_photometry_metadata(data_path: str, example_uuids: str = None) -> dict:
"""Extract metadata from reinforcement photometry data.

Parameters
Expand All @@ -445,19 +427,17 @@ def extract_reinforcement_photometry_metadata(
Path to data.
example_uuids : str, optional
UUID of example session to extract metadata from.
num_sessions : int, optional
Number of sessions to extract metadata from.

Returns
-------
metadata : dict
Dictionary of metadata.
"""
photometry_session_metadata, photometry_subject_metadata = extract_photometry_metadata(
data_path, example_uuids, num_sessions, reinforcement_photometry=True
data_path, example_uuids, reinforcement_photometry=True
)
reinforcement_session_metadata, reinforcement_subject_metadata = extract_reinforcement_metadata(
data_path, example_uuids, num_sessions, reinforcement_photometry=True
data_path, example_uuids, reinforcement_photometry=True
)
photometry_uuids = set(photometry_session_metadata.keys())
reinforcement_uuids = set(reinforcement_session_metadata.keys())
Expand Down Expand Up @@ -542,6 +522,7 @@ def extract_session_metadata(columns, data_path, metadata, uuid):
date = timezone.localize(metadata[uuid].pop("date"))
metadata[uuid]["session_start_time"] = date.isoformat()
metadata[uuid]["subject_id"] = metadata[uuid].pop("mouse_id")
return session_df


def extract_subject_metadata(columns, data_path, metadata, subject_id):
Expand Down Expand Up @@ -627,7 +608,13 @@ def get_session_name(session_df):
reinforcement_example = "dcf0767a-b75d-4c79-a242-84dd5b5cdd00"
excitation_example = "380d4711-85a6-4672-ad48-76e91607c41f"
excitation_pulsed_example = "be01945e-c6d0-4bca-bd56-4d4466d9d832"
reinforcement_examples = [reinforcement_example, excitation_example, excitation_pulsed_example]
duplicated_session_example = "1c5441a6-aee8-44ff-999d-6f0787ad4632"
reinforcement_examples = [
reinforcement_example,
excitation_example,
excitation_pulsed_example,
duplicated_session_example,
]
figure1d_example = "2891f649-4fbd-4119-a807-b8ef507edfab"
pulsed_photometry_example = "b8360fcd-acfd-4414-9e67-ba0dc5c979a8"
excitation_photometry_example = "95bec433-2242-4276-b8a5-6d069afa3910"
Expand Down