Skip to content

Commit

Permalink
update starfish-compatible spotiflow detector, added misc script
Browse files Browse the repository at this point in the history
  • Loading branch information
AlbertDominguez committed Dec 18, 2024
1 parent 3561f09 commit 71ee355
Show file tree
Hide file tree
Showing 3 changed files with 262 additions and 81 deletions.
69 changes: 69 additions & 0 deletions extra/analyze_spot_clusters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""
Retrieve basic statistics of spot clusters in an image by running Spotiflow to detect individual spots and then aggreagating them according to a radius search.
Usage:
python analyze_spot_clusters.py --input /PATH/TO/IMG --model SPOITFLOW_MODEL --output ./out
"""
import argparse
from pathlib import Path

import networkx as nx
import numpy as np
import pandas as pd
from skimage import io
from sklearn.neighbors import radius_neighbors_graph
from spotiflow.model import Spotiflow
from spotiflow.utils import write_coords_csv


def analyze_clusters(spots: np.ndarray, max_distance: float = 11.0):
"""
Get information of clusters by building an r-radius graph.
"""
adj_matrix = radius_neighbors_graph(
spots, radius=max_distance, mode="distance", metric="euclidean"
)
graph = nx.from_scipy_sparse_array(adj_matrix)
conn_components = nx.connected_components(graph)
columns = ["cluster_id", "mean_y", "mean_x", "num_spots"]
if spots.shape[1] == 3:
columns.insert(1, "mean_z")
df = pd.DataFrame(columns=columns)
for i, component in enumerate(conn_components):
curr_spots = spots[list(component)]
center = np.mean(curr_spots, axis=0)
if center.shape[0] == 3:
mean_z, mean_y, mean_x = center
else:
mean_y, mean_x = center

component_data = {
"cluster_id": i,
"num_spots": len(component),
"mean_y": mean_y,
"mean_x": mean_x,
}
if center.shape[0] == 3:
component_data["mean_z"] = mean_z
curr_df = pd.DataFrame(component_data, index=[0])
df = pd.concat([df, curr_df], ignore_index=True)
return df


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--input", type=Path, help="Path to the image")
parser.add_argument("--model", type=Path, help="Path to the model")
parser.add_argument("--output", type=Path, help="Path to the output folder")
parser.add_argument("--max-distance", type=float, default=11.0, help="Max distance to consider two spots as part of the same cluster")
args = parser.parse_args()

args.output.mkdir(exist_ok=True, parents=True)

img = io.imread(args.input) # load the image
model = Spotiflow.from_folder(args.model) # load the clusters model
spots, _ = model.predict(img, normalizer="auto")
print("Analyzing clusters...")
clusters_df = analyze_clusters(spots, max_distance=args.max_distance)
clusters_df.to_csv(args.output / f"{args.input.stem}_clusters.csv", index=False)
print("Done!")
63 changes: 63 additions & 0 deletions extra/run_starfish_spotiflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""
This script showcases how Spotiflow can be used to detect spots in an end-to-end starfish pipeline
through the spotiflow.starfish.SpotiflowDetector class.
Usage:
python run_starfish_spotiflow.py
"""
from starfish import data
from starfish import FieldOfView
from starfish.util.plot import imshow_plane
from starfish.types import Axes
from spotiflow.starfish import SpotiflowDetector
import matplotlib.pyplot as plt
import tifffile
import starfish
from starfish.types import TraceBuildingStrategies
import numpy as np


if __name__ == "__main__":
print("Loading data...")
experiment = data.STARmap(use_test_data=True)
stack = experiment['fov_000'].get_image('primary')
print("Projecting...")
projection = stack.reduce({Axes.CH, Axes.ZPLANE}, func="max")
reference_image = projection.sel({Axes.ROUND: 0})
print("Registering...")
ltt = starfish.image.LearnTransform.Translation(
reference_stack=reference_image,
axes=Axes.ROUND,
upsampling=1000,
)
transforms = ltt.run(projection)

warp = starfish.image.ApplyTransform.Warp()
stack = warp.run(
stack=stack,
transforms_list=transforms,
)

print("Detecting spots...")
bd = SpotiflowDetector(
model="smfish_3d",
min_distance=1,
is_volume=True,
probability_threshold=.4,
)

spots_spotiflow = bd.run(stack, n_processes=1)

print("Decoding...")
decoder = starfish.spots.DecodeSpots.PerRoundMaxChannel(
codebook=experiment.codebook,
anchor_round=0,
search_radius=10,
trace_building_strategy=TraceBuildingStrategies.NEAREST_NEIGHBOR
)

decoded_spotiflow = decoder.run(spots=spots_spotiflow)
decoded_spotiflow_df = decoded_spotiflow.to_features_dataframe()[["z", "y", "x", "target"]]
decoded_spotiflow_df_nonan = decoded_spotiflow_df[decoded_spotiflow_df["target"] != "nan"].reset_index(drop=True)
decoded_spotiflow_df.to_csv("decoded_starmap_spotiflow.csv", index=False)
print("Decoded results saved to decoded_starmap_spotiflow.csv")
Loading

0 comments on commit 71ee355

Please sign in to comment.