Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Suite2pSegmentationExtractor to support multi channel and multi plane outputs #242

Merged
merged 31 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
d6be165
update with streams
weiglszonja Sep 14, 2023
47e9903
Merge branch 'main' into update_suite2psegmentationextractor
weiglszonja Sep 14, 2023
03fbc1e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2023
8b04ae7
fix io test
weiglszonja Sep 14, 2023
1a35f8f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2023
28dabeb
fix extractor test
weiglszonja Sep 14, 2023
d9474fa
pydocstyle
weiglszonja Sep 14, 2023
2874854
fix tests
weiglszonja Sep 16, 2023
fc07661
allow for incomplete input
weiglszonja Sep 18, 2023
802164d
Merge branch 'main' into update_suite2psegmentationextractor
weiglszonja Sep 28, 2023
a4b9df2
Merge branch 'main' into update_suite2psegmentationextractor
weiglszonja Oct 2, 2023
3ddda89
change from stream to channel+plane name
alessandratrapani Oct 2, 2023
9607007
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 2, 2023
798994c
add tests
weiglszonja Oct 3, 2023
7658e61
change to warn for multi channel multi plane
weiglszonja Oct 3, 2023
849a0f0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 3, 2023
1155278
pydocstyle
weiglszonja Oct 3, 2023
c02ac97
Merge remote-tracking branch 'origin/update_suite2psegmentationextrac…
weiglszonja Oct 3, 2023
b442c87
refactor get_available_planes and get_available_channels
weiglszonja Oct 5, 2023
1107f86
fix io test for suite2p
weiglszonja Oct 5, 2023
bcda76b
format code
alessandratrapani Oct 5, 2023
c9d7628
Merge branch 'update_suite2psegmentationextractor' of https://github.…
alessandratrapani Oct 5, 2023
1529a41
add docstring for plane and channel name
weiglszonja Oct 16, 2023
84be179
Merge branch 'main' into update_suite2psegmentationextractor
weiglszonja Oct 16, 2023
980d738
add channel name as class attribute
weiglszonja Oct 17, 2023
7e198f5
add _image_masks to suite2p extractor
weiglszonja Oct 23, 2023
636265d
remove get_roi_image_masks override
weiglszonja Oct 23, 2023
2b86402
Merge branch 'main' into update_suite2psegmentationextractor
weiglszonja Oct 23, 2023
ade4b3f
Merge branch 'main' into update_suite2psegmentationextractor
weiglszonja Oct 30, 2023
55589ca
update CHANGELOG.md
weiglszonja Oct 31, 2023
c26c4ca
Merge remote-tracking branch 'origin/update_suite2psegmentationextrac…
weiglszonja Oct 31, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Upcoming

### Features

* Updated `Suite2pSegmentationExtractor` to support multi channel and multi plane data. [PR #242](https://github.com/catalystneuro/roiextractors/pull/242)



# v0.5.4

### Features
Expand Down
246 changes: 173 additions & 73 deletions src/roiextractors/extractors/suite2p/suite2psegmentationextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
import shutil
from pathlib import Path
from typing import Optional

from warnings import warn
import os
import numpy as np

from ...extraction_tools import PathType, IntType
from ...extraction_tools import PathType
from ...extraction_tools import _image_mask_extractor
from ...multisegmentationextractor import MultiSegmentationExtractor
from ...segmentationextractor import SegmentationExtractor
Expand All @@ -23,134 +24,233 @@ class Suite2pSegmentationExtractor(SegmentationExtractor):
extractor_name = "Suite2pSegmentationExtractor"
installed = True # check at class level if installed or not
is_writable = False
mode = "file"
mode = "folder"
installation_mesg = "" # error message when not installed

@classmethod
def get_available_channels(cls, folder_path: PathType):
"""Get the available channel names from the folder paths produced by Suite2p.

Parameters
----------
file_path : PathType
Path to Suite2p output path.

Returns
-------
channel_names: list
List of channel names.
"""
plane_names = cls.get_available_planes(folder_path=folder_path)

channel_names = ["chan1"]
second_channel_paths = list((Path(folder_path) / plane_names[0]).glob("F_chan2.npy"))
if not second_channel_paths:
return channel_names
channel_names.append("chan2")

return channel_names

@classmethod
def get_available_planes(cls, folder_path: PathType):
"""Get the available plane names from the folder produced by Suite2p.

Parameters
----------
file_path : PathType
Path to Suite2p output path.

Returns
-------
plane_names: list
List of plane names.
"""
from natsort import natsorted

folder_path = Path(folder_path)
prefix = "plane"
plane_paths = natsorted(folder_path.glob(pattern=prefix + "*"))
assert len(plane_paths), f"No planes found in '{folder_path}'."
plane_names = [plane_path.stem for plane_path in plane_paths]
return plane_names

def __init__(
self,
folder_path: Optional[PathType] = None,
combined: bool = False,
plane_no: IntType = 0,
file_path: Optional[PathType] = None,
folder_path: PathType,
channel_name: Optional[str] = None,
plane_name: Optional[str] = None,
combined: Optional[bool] = None, # TODO: to be removed
plane_no: Optional[int] = None, # TODO: to be removed
):
"""Create SegmentationExtractor object out of suite 2p data type.

Parameters
----------
folder_path: str or Path
~/suite2p folder location on disk
combined: bool
if the plane is a combined plane as in the Suite2p pipeline
plane_no: int
the plane for which to extract segmentation for.
file_path: str or Path [Deprecated]
~/suite2p folder location on disk
The path to the 'suite2p' folder.
channel_name: str, optional
The name of the channel to load, to determine what channels are available use Suite2pSegmentationExtractor.get_available_channels(folder_path).
plane_name: str, optional
The name of the plane to load, to determine what planes are available use Suite2pSegmentationExtractor.get_available_planes(folder_path).

"""
from warnings import warn

if file_path is not None:
if combined:
warning_string = "Keyword argument 'combined' is deprecated and will be removed on or after Nov, 2023. "
warn(
message=warning_string,
category=DeprecationWarning,
)
if plane_no:
warning_string = (
"The keyword argument 'file_path' is being deprecated on or after August, 2022 in favor of 'folder_path'. "
"'folder_path' takes precence over 'file_path'."
"Keyword argument 'plane_no' is deprecated and will be removed on or after Nov, 2023 in favor of 'plane_name'."
"Specify which stream you wish to load with the 'plane_name' keyword argument."
)
warn(
message=warning_string,
category=DeprecationWarning,
)
folder_path = file_path if folder_path is None else folder_path

SegmentationExtractor.__init__(self)
self.combined = combined
self.plane_no = plane_no
channel_names = self.get_available_channels(folder_path=folder_path)
if channel_name is None:
if len(channel_names) > 1:
# For backward compatibility maybe it is better to warn first
warn(
"More than one channel is detected! Please specify which channel you wish to load with the `channel_name` argument. "
"To see what channels are available, call `Suite2pSegmentationExtractor.get_available_channels(folder_path=...)`.",
UserWarning,
)
channel_name = channel_names[0]

self.channel_name = channel_name
if self.channel_name not in channel_names:
raise ValueError(
f"The selected channel '{channel_name}' is not a valid channel name. To see what channels are available, "
f"call `Suite2pSegmentationExtractor.get_available_channels(folder_path=...)`."
)

plane_names = self.get_available_planes(folder_path=folder_path)
if plane_name is None:
if len(plane_names) > 1:
# For backward compatibility maybe it is better to warn first
warn(
"More than one plane is detected! Please specify which plane you wish to load with the `plane_name` argument. "
"To see what planes are available, call `Suite2pSegmentationExtractor.get_available_planes(folder_path=...)`.",
UserWarning,
)
plane_name = plane_names[0]

if plane_name not in plane_names:
raise ValueError(
f"The selected plane '{plane_name}' is not a valid plane name. To see what planes are available, "
f"call `Suite2pSegmentationExtractor.get_available_planes(folder_path=...)`."
)
self.plane_name = plane_name

super().__init__()

self.folder_path = Path(folder_path)

self.stat = self._load_npy("stat.npy")
self._roi_response_raw = self._load_npy("F.npy", mmap_mode="r").T
self._roi_response_neuropil = self._load_npy("Fneu.npy", mmap_mode="r").T
self._roi_response_deconvolved = self._load_npy("spks.npy", mmap_mode="r").T
options = self._load_npy(file_name="ops.npy")
self.options = options.item() if options is not None else options
self._sampling_frequency = self.options["fs"]
self._num_frames = self.options["nframes"]
self._image_size = (self.options["Ly"], self.options["Lx"])

self.stat = self._load_npy(file_name="stat.npy")

fluorescence_traces_file_name = "F.npy" if channel_name == "chan1" else "F_chan2.npy"
neuropil_traces_file_name = "Fneu.npy" if channel_name == "chan1" else "Fneu_chan2.npy"
self._roi_response_raw = self._load_npy(file_name=fluorescence_traces_file_name, mmap_mode="r", transpose=True)
self._roi_response_neuropil = self._load_npy(file_name=neuropil_traces_file_name, mmap_mode="r", transpose=True)
self._roi_response_deconvolved = (
self._load_npy(file_name="spks.npy", mmap_mode="r", transpose=True) if channel_name == "chan1" else None
)

self.iscell = self._load_npy("iscell.npy", mmap_mode="r")
self.ops = self._load_npy("ops.npy").item()

self._channel_names = [f"OpticalChannel{i}" for i in range(self.ops["nchannels"])]
self._sampling_frequency = self.ops["fs"] * [2 if self.combined else 1][0]
self._raw_movie_file_location = self.ops.get("filelist", [None])[0]
self._image_correlation = self._summary_image_read("Vcorr")
self._image_mean = self._summary_image_read("meanImg")
channel_name = "OpticalChannel" if len(channel_names) == 1 else channel_name.capitalize()
self._channel_names = [channel_name]

self._image_correlation = self._correlation_image_read()
image_mean_name = "meanImg" if channel_name == "chan1" else f"meanImg_chan2"
self._image_mean = self.options[image_mean_name] if image_mean_name in self.options else None
roi_indices = list(range(self.get_num_rois()))
self._image_masks = _image_mask_extractor(
self.get_roi_pixel_masks(),
roi_indices,
self.get_image_size(),
)

def _load_npy(self, filename, mmap_mode=None):
"""Load a .npy file with specified filename.
def _load_npy(self, file_name: str, mmap_mode=None, transpose: bool = False):
"""Load a .npy file with specified filename. Returns None if file is missing.

Parameters
----------
filename: str
file_name: str
The name of the .npy file to load.
mmap_mode: str
The mode to use for memory mapping. See numpy.load for details.
transpose: bool, optional
Whether to transpose the loaded array.

Returns
-------
The loaded .npy file.
The loaded .npy file.
"""
file_path = self.folder_path / f"plane{self.plane_no}" / filename
return np.load(file_path, mmap_mode=mmap_mode, allow_pickle=mmap_mode is None)
file_path = self.folder_path / self.plane_name / file_name
if not file_path.exists():
return

data = np.load(file_path, mmap_mode=mmap_mode, allow_pickle=mmap_mode is None)
if transpose:
return data.T

return data

def get_num_frames(self) -> int:
return self._num_frames

def get_accepted_list(self):
return list(np.where(self.iscell[:, 0] == 1)[0])

def get_rejected_list(self):
return list(np.where(self.iscell[:, 0] == 0)[0])

def _summary_image_read(self, bstr="meanImg"):
"""Read summary image from ops (settings) dict.

Parameters
----------
bstr: str
The name of the summary image to read.
def _correlation_image_read(self):
"""Read correlation image from ops (settings) dict.

Returns
-------
img : numpy.ndarray | None
The summary image if bstr is in ops, else None.
The correlation image.
"""
img = None
if bstr in self.ops:
if bstr == "Vcorr" or bstr == "max_proj":
img = np.zeros((self.ops["Ly"], self.ops["Lx"]), np.float32)
img[
(self.ops["Ly"] - self.ops["yrange"][-1]) : (self.ops["Ly"] - self.ops["yrange"][0]),
self.ops["xrange"][0] : self.ops["xrange"][-1],
] = self.ops[bstr]
else:
img = self.ops[bstr]
if "Vcorr" not in self.options:
return None

correlation_image = self.options["Vcorr"]
if (self.options["yrange"][-1], self.options["xrange"][-1]) == self._image_size:
return correlation_image

img = np.zeros(self._image_size, correlation_image.dtype)
img[
(self.options["Ly"] - self.options["yrange"][-1]) : (self.options["Ly"] - self.options["yrange"][0]),
self.options["xrange"][0] : self.options["xrange"][-1],
] = correlation_image

return img

@property
def roi_locations(self):
"""Returns the center locations (x, y) of each ROI."""
return np.array([j["med"] for j in self.stat]).T.astype(int)

def get_roi_image_masks(self, roi_ids=None):
if roi_ids is None:
roi_idx_ = range(self.get_num_rois())
else:
roi_idx = [np.where(np.array(i) == self.get_roi_ids())[0] for i in roi_ids]
ele = [i for i, j in enumerate(roi_idx) if j.size == 0]
roi_idx_ = [j[0] for i, j in enumerate(roi_idx) if i not in ele]
return _image_mask_extractor(
self.get_roi_pixel_masks(roi_ids=roi_idx_),
list(range(len(roi_idx_))),
self.get_image_size(),
)

def get_roi_pixel_masks(self, roi_ids=None):
pixel_mask = []
for i in range(self.get_num_rois()):
pixel_mask.append(
np.vstack(
[
self.ops["Ly"] - 1 - self.stat[i]["ypix"],
self.stat[i]["ypix"],
self.stat[i]["xpix"],
self.stat[i]["lam"],
]
Expand All @@ -165,7 +265,7 @@ def get_roi_pixel_masks(self, roi_ids=None):
return [pixel_mask[i] for i in roi_idx_]

def get_image_size(self):
return [self.ops["Ly"], self.ops["Lx"]]
return self._image_size

@staticmethod
def write_segmentation(segmentation_object: SegmentationExtractor, save_path: PathType, overwrite=True):
Expand Down Expand Up @@ -238,7 +338,7 @@ def write_segmentation(segmentation_object: SegmentationExtractor, save_path: Pa
for no, i in enumerate(stat):
stat[no] = {
"med": roi_locs[no, :].tolist(),
"ypix": segmentation_object.get_image_size()[0] - 1 - pixel_masks[no][:, 0],
"ypix": pixel_masks[no][:, 0],
"xpix": pixel_masks[no][:, 1],
"lam": pixel_masks[no][:, 2],
}
Expand Down
10 changes: 5 additions & 5 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,11 @@ def test_imaging_extractors_canonical_shape(self, extractor_class, extractor_kwa
),
param(
extractor_class=Suite2pSegmentationExtractor,
extractor_kwargs=dict(folder_path=str(OPHYS_DATA_PATH / "segmentation_datasets" / "suite2p")),
),
param(
extractor_class=Suite2pSegmentationExtractor,
extractor_kwargs=dict(file_path=str(OPHYS_DATA_PATH / "segmentation_datasets" / "suite2p")),
extractor_kwargs=dict(
folder_path=str(OPHYS_DATA_PATH / "segmentation_datasets" / "suite2p"),
channel_name="chan1",
plane_name="plane0",
),
),
]

Expand Down
Loading
Loading