-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
update starfish-compatible spotiflow detector, added misc script
- Loading branch information
1 parent
3561f09
commit 71ee355
Showing
3 changed files
with
262 additions
and
81 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
""" | ||
Retrieve basic statistics of spot clusters in an image by running Spotiflow to detect individual spots and then aggreagating them according to a radius search. | ||
Usage: | ||
python analyze_spot_clusters.py --input /PATH/TO/IMG --model SPOITFLOW_MODEL --output ./out | ||
""" | ||
import argparse | ||
from pathlib import Path | ||
|
||
import networkx as nx | ||
import numpy as np | ||
import pandas as pd | ||
from skimage import io | ||
from sklearn.neighbors import radius_neighbors_graph | ||
from spotiflow.model import Spotiflow | ||
from spotiflow.utils import write_coords_csv | ||
|
||
|
||
def analyze_clusters(spots: np.ndarray, max_distance: float = 11.0): | ||
""" | ||
Get information of clusters by building an r-radius graph. | ||
""" | ||
adj_matrix = radius_neighbors_graph( | ||
spots, radius=max_distance, mode="distance", metric="euclidean" | ||
) | ||
graph = nx.from_scipy_sparse_array(adj_matrix) | ||
conn_components = nx.connected_components(graph) | ||
columns = ["cluster_id", "mean_y", "mean_x", "num_spots"] | ||
if spots.shape[1] == 3: | ||
columns.insert(1, "mean_z") | ||
df = pd.DataFrame(columns=columns) | ||
for i, component in enumerate(conn_components): | ||
curr_spots = spots[list(component)] | ||
center = np.mean(curr_spots, axis=0) | ||
if center.shape[0] == 3: | ||
mean_z, mean_y, mean_x = center | ||
else: | ||
mean_y, mean_x = center | ||
|
||
component_data = { | ||
"cluster_id": i, | ||
"num_spots": len(component), | ||
"mean_y": mean_y, | ||
"mean_x": mean_x, | ||
} | ||
if center.shape[0] == 3: | ||
component_data["mean_z"] = mean_z | ||
curr_df = pd.DataFrame(component_data, index=[0]) | ||
df = pd.concat([df, curr_df], ignore_index=True) | ||
return df | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--input", type=Path, help="Path to the image") | ||
parser.add_argument("--model", type=Path, help="Path to the model") | ||
parser.add_argument("--output", type=Path, help="Path to the output folder") | ||
parser.add_argument("--max-distance", type=float, default=11.0, help="Max distance to consider two spots as part of the same cluster") | ||
args = parser.parse_args() | ||
|
||
args.output.mkdir(exist_ok=True, parents=True) | ||
|
||
img = io.imread(args.input) # load the image | ||
model = Spotiflow.from_folder(args.model) # load the clusters model | ||
spots, _ = model.predict(img, normalizer="auto") | ||
print("Analyzing clusters...") | ||
clusters_df = analyze_clusters(spots, max_distance=args.max_distance) | ||
clusters_df.to_csv(args.output / f"{args.input.stem}_clusters.csv", index=False) | ||
print("Done!") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
""" | ||
This script showcases how Spotiflow can be used to detect spots in an end-to-end starfish pipeline | ||
through the spotiflow.starfish.SpotiflowDetector class. | ||
Usage: | ||
python run_starfish_spotiflow.py | ||
""" | ||
from starfish import data | ||
from starfish import FieldOfView | ||
from starfish.util.plot import imshow_plane | ||
from starfish.types import Axes | ||
from spotiflow.starfish import SpotiflowDetector | ||
import matplotlib.pyplot as plt | ||
import tifffile | ||
import starfish | ||
from starfish.types import TraceBuildingStrategies | ||
import numpy as np | ||
|
||
|
||
if __name__ == "__main__": | ||
print("Loading data...") | ||
experiment = data.STARmap(use_test_data=True) | ||
stack = experiment['fov_000'].get_image('primary') | ||
print("Projecting...") | ||
projection = stack.reduce({Axes.CH, Axes.ZPLANE}, func="max") | ||
reference_image = projection.sel({Axes.ROUND: 0}) | ||
print("Registering...") | ||
ltt = starfish.image.LearnTransform.Translation( | ||
reference_stack=reference_image, | ||
axes=Axes.ROUND, | ||
upsampling=1000, | ||
) | ||
transforms = ltt.run(projection) | ||
|
||
warp = starfish.image.ApplyTransform.Warp() | ||
stack = warp.run( | ||
stack=stack, | ||
transforms_list=transforms, | ||
) | ||
|
||
print("Detecting spots...") | ||
bd = SpotiflowDetector( | ||
model="smfish_3d", | ||
min_distance=1, | ||
is_volume=True, | ||
probability_threshold=.4, | ||
) | ||
|
||
spots_spotiflow = bd.run(stack, n_processes=1) | ||
|
||
print("Decoding...") | ||
decoder = starfish.spots.DecodeSpots.PerRoundMaxChannel( | ||
codebook=experiment.codebook, | ||
anchor_round=0, | ||
search_radius=10, | ||
trace_building_strategy=TraceBuildingStrategies.NEAREST_NEIGHBOR | ||
) | ||
|
||
decoded_spotiflow = decoder.run(spots=spots_spotiflow) | ||
decoded_spotiflow_df = decoded_spotiflow.to_features_dataframe()[["z", "y", "x", "target"]] | ||
decoded_spotiflow_df_nonan = decoded_spotiflow_df[decoded_spotiflow_df["target"] != "nan"].reset_index(drop=True) | ||
decoded_spotiflow_df.to_csv("decoded_starmap_spotiflow.csv", index=False) | ||
print("Decoded results saved to decoded_starmap_spotiflow.csv") |
Oops, something went wrong.