Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add motion correction for Widefield Imaging #16

Merged
merged 7 commits into from
Nov 27, 2023
1 change: 1 addition & 0 deletions src/pinto_lab_to_nwb/widefield/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .motion_correction import load_motion_correction_data
92 changes: 92 additions & 0 deletions src/pinto_lab_to_nwb/widefield/utils/motion_correction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from typing import List

import numpy as np
import pymatreader
from neuroconv.tools import get_module
from pynwb import NWBFile, TimeSeries
from roiextractors.extraction_tools import DtypeType


def add_motion_correction(
nwbfile: NWBFile,
motion_correction_series: np.ndarray,
one_photon_series_name: str,
convert_to_dtype: DtypeType = None,
) -> None:
"""Add motion correction data to the NWBFile.

The x, y shifts for the imaging data (identified by 'one_photon_series_name' are added to the NWBFile as a TimeSeries.
The series is added to the 'ophys' processing module.

Parameters
----------
nwbfile: NWBFile
The NWBFile where the motion correction time series will be added to.
motion_correction_series: numpy.ndarray
The x, y shifts for the imaging data.
one_photon_series_name: str
The name of the one photon series in the NWBFile.
convert_to_dtype: DtypeType, optional
The dtype to convert the motion correction series to.
"""
convert_to_dtype = convert_to_dtype or np.uint16

assert (
one_photon_series_name in nwbfile.acquisition
), f"The one photon series '{one_photon_series_name}' does not exist in the NWBFile."
name_suffix = one_photon_series_name.replace("OnePhotonSeries", "")
motion_correction_time_series_name = "MotionCorrectionSeries" + name_suffix
ophys = get_module(nwbfile, "ophys")
if motion_correction_time_series_name in ophys.data_interfaces:
raise ValueError(
f"The motion correction time series '{motion_correction_time_series_name}' already exists in the NWBFile."
)

one_photon_series = nwbfile.acquisition[one_photon_series_name]
num_frames = one_photon_series.data.maxshape[0]
assert (
num_frames == motion_correction_series.shape[0]
), f"The number of frames for motion correction ({motion_correction_series.shape[0]}) does not match the number of frames ({num_frames}) from the {one_photon_series_name} imaging data."
xy_translation = TimeSeries(
name="MotionCorrectionSeries" + name_suffix,
description=f"The x, y shifts for the {one_photon_series_name} imaging data.",
data=motion_correction_series.astype(dtype=convert_to_dtype),
unit="px",
timestamps=one_photon_series.timestamps,
)
ophys.add(xy_translation)


def load_motion_correction_data(file_paths: List[str]) -> np.ndarray:
"""Load motion correction data from mat files.

Parameters
----------
file_paths: List[str]
A list of paths to the motion correction files.

Returns
-------
motion_correction_data: numpy.ndarray
The concatenated xy shifts from all the files.
"""
motion_correction_data = np.concatenate([get_xy_shifts(file_path=str(file)) for file in file_paths], axis=0)
return motion_correction_data


def get_xy_shifts(file_path: str) -> np.ndarray:
"""Get the x, y (column, row) shifts from the motion correction file.

Parameters
----------
file_path: str
The path to the motion correction file.

Returns
-------
motion_correction: numpy.ndarray
The first column is the x shifts. The second column is the y shifts.
"""
motion_correction_data = pymatreader.read_mat(file_path)
motion_correction = np.column_stack((motion_correction_data["xShifts"], motion_correction_data["yShifts"]))
return motion_correction
46 changes: 46 additions & 0 deletions src/pinto_lab_to_nwb/widefield/widefieldnwbconverter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
"""Primary NWBConverter class for this dataset."""
from pathlib import Path
from typing import Optional, Dict

import numpy as np
from natsort import natsorted
from neuroconv import NWBConverter
from pynwb import NWBFile

from pinto_lab_to_nwb.widefield.interfaces import WidefieldImagingInterface, WidefieldProcessedImagingInterface
from pinto_lab_to_nwb.widefield.utils import load_motion_correction_data
from pinto_lab_to_nwb.widefield.utils.motion_correction import add_motion_correction
from pinto_lab_to_nwb.widefield.interfaces import (
WidefieldImagingInterface,
WidefieldProcessedImagingInterface,
Expand All @@ -22,3 +31,40 @@ class WideFieldNWBConverter(NWBConverter):
SummaryImagesBlue=WidefieldSegmentationImagesBlueInterface,
SummaryImagesViolet=WidefieldSegmentationImagesVioletInterface,
)

def __init__(self, source_data: Dict[str, dict], verbose: bool = True):
super().__init__(source_data, verbose)

# Load motion correction data
imaging_interface = self.data_interface_objects["ImagingBlue"]
imaging_folder_path = imaging_interface.source_data["folder_path"]
imaging_folder_name = Path(imaging_folder_path).stem
motion_correction_mat_files = natsorted(Path(imaging_folder_path).glob(f"{imaging_folder_name}*mcorr_1.mat"))
assert motion_correction_mat_files, f"No motion correction files found in {imaging_folder_path}."
self._motion_correction_data = load_motion_correction_data(file_paths=motion_correction_mat_files)

def add_to_nwbfile(self, nwbfile: NWBFile, metadata, conversion_options: Optional[dict] = None) -> None:
super().add_to_nwbfile(nwbfile=nwbfile, metadata=metadata, conversion_options=conversion_options)

# Add motion correction for blue and violet frames
imaging_interface_names = ["ImagingBlue", "ImagingViolet"]
for interface_name in imaging_interface_names:
photon_series_index = conversion_options[interface_name]["photon_series_index"]
one_photon_series_name = metadata["Ophys"]["OnePhotonSeries"][photon_series_index]["name"]

imaging_interface = self.data_interface_objects[interface_name]
frame_indices = imaging_interface.imaging_extractor.frame_indices
# filter motion correction for blue/violet frames
motion_correction = self._motion_correction_data[frame_indices, :]
if interface_name in conversion_options:
if "stub_test" in conversion_options[interface_name]:
if conversion_options[interface_name]["stub_test"]:
num_frames = 100
motion_correction = motion_correction[:num_frames, :]

add_motion_correction(
nwbfile=nwbfile,
motion_correction_series=motion_correction,
one_photon_series_name=one_photon_series_name,
convert_to_dtype=np.uint16,
)