diff --git a/src/datta_lab_to_nwb/markowitz_gillis_nature_2023/behavioralsyllableinterface.py b/src/datta_lab_to_nwb/markowitz_gillis_nature_2023/behavioralsyllableinterface.py index 655b35b..920922d 100644 --- a/src/datta_lab_to_nwb/markowitz_gillis_nature_2023/behavioralsyllableinterface.py +++ b/src/datta_lab_to_nwb/markowitz_gillis_nature_2023/behavioralsyllableinterface.py @@ -49,16 +49,23 @@ def get_metadata_schema(self) -> dict: } return metadata_schema - def get_original_timestamps(self) -> np.ndarray: + def get_original_timestamps(self, metadata: dict, reinforcement: bool = False) -> np.ndarray: + if reinforcement: + filters = [ + ("uuid", "==", self.source_data["session_uuid"]), + ("target_syllable", "==", metadata["Optogenetics"]["target_syllable"][0]), + ] + else: + filters = [("uuid", "==", self.source_data["session_uuid"])] session_df = pd.read_parquet( self.source_data["file_path"], columns=["timestamp", "uuid"], - filters=[("uuid", "==", self.source_data["session_uuid"])], + filters=filters, ) return session_df["timestamp"].to_numpy() - def align_timestamps(self, metadata: dict) -> np.ndarray: - timestamps = self.get_original_timestamps() + def align_timestamps(self, metadata: dict, reinforcement: bool = False) -> np.ndarray: + timestamps = self.get_original_timestamps(metadata=metadata, reinforcement=reinforcement) self.set_aligned_timestamps(aligned_timestamps=timestamps) if self.source_data["alignment_path"] is not None: aligned_starting_time = ( @@ -67,17 +74,26 @@ def align_timestamps(self, metadata: dict) -> np.ndarray: self.set_aligned_starting_time(aligned_starting_time=aligned_starting_time) return self.aligned_timestamps - def add_to_nwbfile(self, nwbfile: NWBFile, metadata: dict, velocity_modulation: bool = False) -> None: + def add_to_nwbfile( + self, nwbfile: NWBFile, metadata: dict, velocity_modulation: bool = False, reinforcement: bool = False + ) -> None: if velocity_modulation: columns = ["uuid", "predicted_syllable"] else: columns = self.source_data["columns"] + if reinforcement: + filters = [ + ("uuid", "==", self.source_data["session_uuid"]), + ("target_syllable", "==", metadata["Optogenetics"]["target_syllable"][0]), + ] + else: + filters = [("uuid", "==", self.source_data["session_uuid"])] session_df = pd.read_parquet( self.source_data["file_path"], columns=columns, - filters=[("uuid", "==", self.source_data["session_uuid"])], + filters=filters, ) - timestamps = self.align_timestamps(metadata=metadata) + timestamps = self.align_timestamps(metadata=metadata, reinforcement=reinforcement) # Add Syllable Data sorted_pseudoindex2name = metadata["BehavioralSyllable"]["sorted_pseudoindex2name"] id2sorted_index = metadata["BehavioralSyllable"]["id2sorted_index"] diff --git a/src/datta_lab_to_nwb/markowitz_gillis_nature_2023/convert_session.py b/src/datta_lab_to_nwb/markowitz_gillis_nature_2023/convert_session.py index 7bf1309..38b44cb 100644 --- a/src/datta_lab_to_nwb/markowitz_gillis_nature_2023/convert_session.py +++ b/src/datta_lab_to_nwb/markowitz_gillis_nature_2023/convert_session.py @@ -19,6 +19,7 @@ def session_to_nwb( raw_path: Union[str, Path], output_dir_path: Union[str, Path], experiment_type: Literal["reinforcement", "photometry", "reinforcement-photometry", "velocity-modulation"], + processed_only: bool = False, stub_test: bool = False, ): processed_path = Path(processed_path) @@ -85,7 +86,7 @@ def session_to_nwb( session_uuid=session_uuid, session_id=session_id, ) - conversion_options["Optogenetic"] = {} + conversion_options["BehavioralSyllable"] = dict(reinforcement=True) behavioral_syllable_path = optoda_path if "photometry" in session_metadata.keys(): tdt_path = list(raw_path.glob("tdt_data*.dat"))[0] @@ -122,6 +123,11 @@ def session_to_nwb( if experiment_type == "velocity-modulation": conversion_options["BehavioralSyllable"] = dict(velocity_modulation=True) conversion_options["Optogenetic"] = dict(velocity_modulation=True) + if processed_only: + source_data.pop("MoseqExtract") + source_data.pop("DepthVideo") + conversion_options.pop("MoseqExtract") + conversion_options.pop("DepthVideo") converter = DattaNWBConverter(source_data=source_data) metadata = converter.get_metadata() @@ -160,16 +166,23 @@ def session_to_nwb( reinforcement_photometry_examples = [figure1d_example, pulsed_photometry_example, excitation_photometry_example] raw_rp_example = "b814a426-7ec9-440e-baaa-105ba27a5fa6" velocity_modulation_example = "c621e134-50ec-4e8b-8175-a8c023d92789" + duplicated_session_example = "1c5441a6-aee8-44ff-999d-6f0787ad4632" experiment_type2example_sessions = { "reinforcement-photometry": [raw_rp_example], "velocity-modulation": [velocity_modulation_example], + "reinforcement": [duplicated_session_example], } experiment_type2raw_path = { "reinforcement-photometry": raw_rp_path, "velocity-modulation": raw_velocity_path, + "reinforcement": "", } for experiment_type, example_sessions in experiment_type2example_sessions.items(): + if experiment_type == "reinforcement": + processed_only = True + else: + processed_only = False for example_session in example_sessions: session_to_nwb( session_uuid=example_session, @@ -177,6 +190,7 @@ def session_to_nwb( raw_path=experiment_type2raw_path[experiment_type], output_dir_path=output_dir_path, experiment_type=experiment_type, + processed_only=processed_only, stub_test=stub_test, ) with NWBHDF5IO(output_dir_path / f"reinforcement-photometry-{raw_rp_example}.nwb", "r") as io: diff --git a/src/datta_lab_to_nwb/markowitz_gillis_nature_2023/optogeneticinterface.py b/src/datta_lab_to_nwb/markowitz_gillis_nature_2023/optogeneticinterface.py index 8e4f027..a26efda 100644 --- a/src/datta_lab_to_nwb/markowitz_gillis_nature_2023/optogeneticinterface.py +++ b/src/datta_lab_to_nwb/markowitz_gillis_nature_2023/optogeneticinterface.py @@ -75,21 +75,24 @@ def get_metadata_schema(self) -> dict: "stim_duration_s": {"type": "number"}, "power_watts": {"type": "number"}, "pulse_width_s": {"type": "number"}, - "target_syllable": {"type": "number"}, + "target_syllable": {"type": "array"}, }, } return metadata_schema - def get_original_timestamps(self) -> np.ndarray: + def get_original_timestamps(self, metadata: dict) -> np.ndarray: session_df = pd.read_parquet( self.source_data["file_path"], columns=["timestamp", "uuid"], - filters=[("uuid", "==", self.source_data["session_uuid"])], + filters=[ + ("uuid", "==", self.source_data["session_uuid"]), + ("target_syllable", "==", metadata["Optogenetics"]["target_syllable"][0]), + ], ) return session_df["timestamp"].to_numpy() def align_timestamps(self, metadata: dict, velocity_modulation: bool) -> np.ndarray: - timestamps = self.get_original_timestamps() + timestamps = self.get_original_timestamps(metadata=metadata) self.set_aligned_timestamps(aligned_timestamps=timestamps) if self.source_data["alignment_path"] is not None: aligned_starting_time = ( @@ -106,7 +109,10 @@ def add_to_nwbfile(self, nwbfile: NWBFile, metadata: dict, velocity_modulation: session_df = pd.read_parquet( self.source_data["file_path"], columns=self.source_data["columns"], - filters=[("uuid", "==", self.source_data["session_uuid"])], + filters=[ + ("uuid", "==", self.source_data["session_uuid"]), + ("target_syllable", "==", metadata["Optogenetics"]["target_syllable"][0]), + ], ) device = nwbfile.create_device( @@ -127,11 +133,11 @@ def add_to_nwbfile(self, nwbfile: NWBFile, metadata: dict, velocity_modulation: else: # pulsed stim data, timestamps = self.reconstruct_pulsed_stim(metadata, session_df, session_timestamps) id2sorted_index = metadata["BehavioralSyllable"]["id2sorted_index"] - target_syllable = id2sorted_index[metadata["Optogenetics"]["target_syllable"]] + target_syllables = [id2sorted_index[syllable_id] for syllable_id in metadata["Optogenetics"]["target_syllable"]] ogen_series = OptogeneticSeries( name="OptogeneticSeries", description="Onset of optogenetic stimulation is recorded as a 1, and offset is recorded as a 0.", - comments=f"target_syllable = {target_syllable}", + comments=f"target_syllable(s) = {target_syllables}", site=ogen_site, data=H5DataIO(data, compression=True), timestamps=H5DataIO(timestamps, compression=True), diff --git a/src/datta_lab_to_nwb/markowitz_gillis_nature_2023/preconversion/extract_metadata.py b/src/datta_lab_to_nwb/markowitz_gillis_nature_2023/preconversion/extract_metadata.py index 5399344..91212e0 100644 --- a/src/datta_lab_to_nwb/markowitz_gillis_nature_2023/preconversion/extract_metadata.py +++ b/src/datta_lab_to_nwb/markowitz_gillis_nature_2023/preconversion/extract_metadata.py @@ -201,7 +201,6 @@ def extract_photometry_metadata( data_path: str, example_uuids: str = None, - num_sessions: int = None, reinforcement_photometry: bool = False, ) -> dict: """Extract metadata from photometry data. @@ -212,8 +211,6 @@ def extract_photometry_metadata( Path to data. example_uuids : str, optional UUID of example session to extract metadata from. - num_sessions : int, optional - Number of sessions to extract metadata from. reinforcement_photometry : bool, optional If True, extract metadata from reinforcement photometry sessions. If False, extract metadata from non-reinforcement photometry sessions. @@ -265,14 +262,10 @@ def extract_photometry_metadata( del df else: uuids = set(example_uuids) - if num_sessions is None: - num_sessions = len(uuids) session_metadata = {} for i, uuid in enumerate(tqdm(uuids, desc="Extracting photometry session metadata")): extract_session_metadata(session_columns, photometry_data_path, session_metadata, uuid) session_metadata[uuid]["photometry"] = True - if i + 1 >= num_sessions: - break subject_ids = set(session_metadata[uuid]["subject_id"] for uuid in session_metadata) subject_metadata = {} for mouse_id in tqdm(subject_ids, desc="Extracting photometry subject metadata"): @@ -284,7 +277,7 @@ def extract_photometry_metadata( def extract_reinforcement_metadata( - data_path: str, example_uuids: str = None, num_sessions: int = None, reinforcement_photometry: bool = False + data_path: str, example_uuids: str = None, reinforcement_photometry: bool = False ) -> dict: """Extract metadata from reinforcement data. @@ -294,8 +287,6 @@ def extract_reinforcement_metadata( Path to data. example_uuids : str, optional UUID of example session to extract metadata from. - num_sessions : int, optional - Number of sessions to extract metadata from. reinforcement_photometry : bool, optional If True, extract metadata from reinforcement photometry sessions. If False, extract metadata from non-photometry reinforcement sessions. @@ -346,18 +337,16 @@ def extract_reinforcement_metadata( else: uuids = set(example_uuids) session_metadata, subject_metadata = {}, {} - if num_sessions is None: - num_sessions = len(uuids) for i, uuid in enumerate(tqdm(uuids, desc="Extracting reinforcement session metadata")): - extract_session_metadata(session_columns, reinforcement_data_path, session_metadata, uuid) + session_df = extract_session_metadata(session_columns, reinforcement_data_path, session_metadata, uuid) + target_syllables = set(session_df.target_syllable[session_df.target_syllable.notnull()]) + session_metadata[uuid]["target_syllable"] = list(target_syllables) # add si units to names session_metadata[uuid]["stim_duration_s"] = session_metadata[uuid].pop("stim_duration") session_metadata[uuid]["stim_frequency_Hz"] = session_metadata[uuid].pop("stim_frequency") session_metadata[uuid]["pulse_width_s"] = session_metadata[uuid].pop("pulse_width") session_metadata[uuid]["power_watts"] = session_metadata[uuid].pop("power") / 1000 session_metadata[uuid]["reinforcement"] = True - if i + 1 >= num_sessions: - break subject_ids = set(session_metadata[uuid]["subject_id"] for uuid in session_metadata) for mouse_id in tqdm(subject_ids, desc="Extracting reinforcement subject metadata"): extract_subject_metadata(subject_columns, reinforcement_data_path, subject_metadata, mouse_id) @@ -370,7 +359,6 @@ def extract_reinforcement_metadata( def extract_velocity_modulation_metadata( data_path: str, example_uuids: str = None, - num_sessions: int = None, ) -> dict: """Extract metadata from velocity modulation data. Parameters @@ -379,8 +367,6 @@ def extract_velocity_modulation_metadata( Path to data. example_uuids : str, optional UUID of example session to extract metadata from. - num_sessions : int, optional - Number of sessions to extract metadata from. Returns ------- metadata : dict @@ -412,10 +398,10 @@ def extract_velocity_modulation_metadata( else: uuids = set(example_uuids) session_metadata, subject_metadata = {}, {} - if num_sessions is None: - num_sessions = len(uuids) for i, uuid in enumerate(tqdm(uuids, desc="Extracting velocity-modulation session metadata")): - extract_session_metadata(session_columns, velocity_data_path, session_metadata, uuid) + session_df = extract_session_metadata(session_columns, velocity_data_path, session_metadata, uuid) + target_syllables = set(session_df.target_syllable[session_df.target_syllable.notnull()]) + session_metadata[uuid]["target_syllable"] = list(target_syllables) # add si units to names session_metadata[uuid]["stim_duration_s"] = session_metadata[uuid].pop("stim_duration") session_metadata[uuid]["stim_frequency_Hz"] = np.NaN @@ -423,8 +409,6 @@ def extract_velocity_modulation_metadata( session_metadata[uuid]["power_watts"] = 10 / 1000 # power = 10mW from paper session_metadata[uuid]["reinforcement"] = True session_metadata[uuid]["velocity_modulation"] = True - if i + 1 >= num_sessions: - break subject_ids = set(session_metadata[uuid]["subject_id"] for uuid in session_metadata) for mouse_id in tqdm(subject_ids, desc="Extracting reinforcement subject metadata"): extract_subject_metadata(subject_columns, velocity_data_path, subject_metadata, mouse_id) @@ -434,9 +418,7 @@ def extract_velocity_modulation_metadata( return session_metadata, subject_metadata -def extract_reinforcement_photometry_metadata( - data_path: str, example_uuids: str = None, num_sessions: int = None -) -> dict: +def extract_reinforcement_photometry_metadata(data_path: str, example_uuids: str = None) -> dict: """Extract metadata from reinforcement photometry data. Parameters @@ -445,8 +427,6 @@ def extract_reinforcement_photometry_metadata( Path to data. example_uuids : str, optional UUID of example session to extract metadata from. - num_sessions : int, optional - Number of sessions to extract metadata from. Returns ------- @@ -454,10 +434,10 @@ def extract_reinforcement_photometry_metadata( Dictionary of metadata. """ photometry_session_metadata, photometry_subject_metadata = extract_photometry_metadata( - data_path, example_uuids, num_sessions, reinforcement_photometry=True + data_path, example_uuids, reinforcement_photometry=True ) reinforcement_session_metadata, reinforcement_subject_metadata = extract_reinforcement_metadata( - data_path, example_uuids, num_sessions, reinforcement_photometry=True + data_path, example_uuids, reinforcement_photometry=True ) photometry_uuids = set(photometry_session_metadata.keys()) reinforcement_uuids = set(reinforcement_session_metadata.keys()) @@ -542,6 +522,7 @@ def extract_session_metadata(columns, data_path, metadata, uuid): date = timezone.localize(metadata[uuid].pop("date")) metadata[uuid]["session_start_time"] = date.isoformat() metadata[uuid]["subject_id"] = metadata[uuid].pop("mouse_id") + return session_df def extract_subject_metadata(columns, data_path, metadata, subject_id): @@ -627,7 +608,13 @@ def get_session_name(session_df): reinforcement_example = "dcf0767a-b75d-4c79-a242-84dd5b5cdd00" excitation_example = "380d4711-85a6-4672-ad48-76e91607c41f" excitation_pulsed_example = "be01945e-c6d0-4bca-bd56-4d4466d9d832" - reinforcement_examples = [reinforcement_example, excitation_example, excitation_pulsed_example] + duplicated_session_example = "1c5441a6-aee8-44ff-999d-6f0787ad4632" + reinforcement_examples = [ + reinforcement_example, + excitation_example, + excitation_pulsed_example, + duplicated_session_example, + ] figure1d_example = "2891f649-4fbd-4119-a807-b8ef507edfab" pulsed_photometry_example = "b8360fcd-acfd-4414-9e67-ba0dc5c979a8" excitation_photometry_example = "95bec433-2242-4276-b8a5-6d069afa3910"