Skip to content

Commit

Permalink
add distribution plot functions
Browse files Browse the repository at this point in the history
  • Loading branch information
kapoorlab committed Jul 12, 2024
1 parent e22e8cc commit 71c36c3
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/napatrackmater/Trackmate.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,7 +960,7 @@ def _master_track_computer(self, track, track_id, t_start=None, t_end=None):

all_source_ids, all_target_ids = self._generate_generations(track)
root_root, root_splits, root_leaf = self._create_generations(all_source_ids)

self._iterate_split_down(root_root, root_leaf, root_splits)
# Determine if a track has divisions or none
number_dividing = len(root_splits)
if number_dividing > 0:
Expand Down
157 changes: 151 additions & 6 deletions src/napatrackmater/Trackvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
from statsmodels.tsa.stattools import acf, ccf
from scipy.stats import norm, anderson
from kapoorlabs_lightning.lightning_trainer import MitosisInception

from natsort import natsorted
import seaborn as sns
from tqdm import tqdm

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -521,11 +523,23 @@ def plot_mitosis_times(self, full_dataframe, save_path=""):
counts = dividing_counts.values
data = {"Time": times, "Count": counts}
df = pd.DataFrame(data)
np.save(save_path + "_counts.npy", df.to_numpy())
np.save(save_path + "counts.npy", df.to_numpy())

max_number_dividing = full_dataframe["Number_Dividing"].max()
min_number_dividing = full_dataframe["Number_Dividing"].min()
excluded_keys = ["Track ID", "t", "z", "y", "x"]
excluded_keys = [
"Track ID",
"t",
"z",
"y",
"x",
"Unnamed: 0",
"Unnamed",
"Track Duration",
"Generation ID",
"TrackMate Track ID",
"Tracklet Number ID",
]
for i in range(
min_number_dividing.astype(int), max_number_dividing.astype(int) + 1
):
Expand All @@ -534,12 +548,13 @@ def plot_mitosis_times(self, full_dataframe, save_path=""):
data = full_dataframe[column][
full_dataframe["Number_Dividing"].astype(int) == i
]

np.save(
f"{save_path}_{column}_Number_Dividing_{i}.npy", data.to_numpy()
f"{save_path}{column}Number_Dividing_{i}.npy", data.to_numpy()
)

all_split_data = []
for split_id in self.split_cell_ids:
for split_id in tqdm(self.split_cell_ids, desc="Cell split IDs"):
spot_properties = self.unique_spot_properties[split_id]
track_id = spot_properties[self.trackid_key]
unique_id = spot_properties[self.uniqueid_key]
Expand Down Expand Up @@ -592,7 +607,7 @@ def plot_mitosis_times(self, full_dataframe, save_path=""):

all_split_data.append(data)

np.save(f"{save_path}_data_at_mitosis_time.npy", all_split_data)
np.save(f"{save_path}data_at_mitosis_time.npy", all_split_data)

def get_shape_dynamic_feature_dataframe(self):

Expand Down Expand Up @@ -3157,6 +3172,14 @@ def extract_number_from_string(string):
return numbers


def extract_number_dividing(file_name):
    """Return the division count encoded in a ``Number_Dividing_<n>.npy`` name.

    The pattern is searched anywhere inside ``file_name``. When it is found,
    ``<n>`` is returned as an ``int``; when it is absent, ``-1`` is returned.
    """
    found = re.search(r"Number_Dividing_(\d+)\.npy", file_name)
    if found is None:
        # Sentinel for names that do not carry a division count.
        return -1
    return int(found.group(1))


def cross_correlation_class(tracks_dataframe, cell_type_label=None):

if cell_type_label is not None:
Expand Down Expand Up @@ -3374,3 +3397,125 @@ def cross_correlation_class(tracks_dataframe, cell_type_label=None):
N_gen_dynamic_test_dict,
N_gen_shape_test_dict,
)


def plot_at_mitosis_time(matrix_directory, save_dir, dataset_name, channel):
    """Plot per-property histograms of the values stored at mitosis time.

    Every ``.npy`` file in ``matrix_directory`` whose name contains
    ``data_at_mitosis_time`` is loaded as a sequence of dictionaries. The
    values are grouped first by dictionary key (the property) and then by the
    entry's ``"Number Times Divided"`` value. For each property one figure is
    drawn holding one histogram (with KDE) per division count, and the figure
    is saved as a PNG in ``save_dir`` and shown.

    Args:
        matrix_directory: Directory scanned for the input ``.npy`` files.
        save_dir: Directory the PNG figures are written to.
        dataset_name: First component of each figure file name.
        channel: Second component of each figure file name.

    Note:
        The figure file name does not contain the input file name, so when
        several input files provide the same property, the figure written for
        a later file replaces the one written for an earlier file.
    """
    candidate_files = [
        file
        for file in os.listdir(matrix_directory)
        if file.endswith(".npy") and "data_at_mitosis_time" in file
    ]
    sorted_files = natsorted(
        candidate_files,
        key=lambda x: (
            "_Dynamic Cluster" in x,
            "_Shape Dynamic Cluster" in x,
            "_Shape Cluster" in x,
            x,
        ),
    )

    # Bookkeeping keys that are identifiers or coordinates, not properties
    # to plot. A set gives constant-time membership tests in the inner loop.
    excluded_keys = {
        "Track ID",
        "t",
        "z",
        "y",
        "x",
        "Unnamed: 0",
        "Unnamed",
        "Track Duration",
        "Generation ID",
        "TrackMate Track ID",
        "Tracklet Number ID",
        "Tracklet_ID",
        "Unique_ID",
        "Track_ID",
    }

    for file_name in sorted_files:
        # NOTE: allow_pickle=True unpickles arbitrary Python objects; only
        # point this function at directories holding trusted files.
        all_split_data = np.load(
            os.path.join(matrix_directory, file_name), allow_pickle=True
        )

        # property name -> division count -> list of values
        grouped_data = {}
        for data_dict in all_split_data:
            number_times_divided = data_dict.get("Number Times Divided")
            for key, value in data_dict.items():
                if key in excluded_keys:
                    continue
                per_property = grouped_data.setdefault(key, {})
                per_property.setdefault(number_times_divided, []).append(value)

        for property_name, property_data in grouped_data.items():
            fig = plt.figure(figsize=(10, 6))
            for number_times_divided, values in property_data.items():
                sns.histplot(
                    values,
                    label=f"Number Times Divided {number_times_divided}",
                    kde=True,
                    bins=20,
                    edgecolor="black",
                    alpha=0.5,
                )
            plt.xlabel("Property Values")
            plt.ylabel("Frequency")
            plt.title(f"{property_name}")
            # The histograms carry labels; without a legend they are never
            # displayed and the overlaid distributions cannot be told apart.
            plt.legend()

            fig_name = (
                f"{dataset_name}_{channel}_{property_name}_at_mitosis_distribution.png"
            )
            plt.savefig(os.path.join(save_dir, fig_name), dpi=300, bbox_inches="tight")
            plt.show()
            # Release the figure so that looping over many properties does
            # not keep every figure alive.
            plt.close(fig)


def plot_histograms_for_groups(matrix_directory, save_dir, dataset_name, channel):
    """Plot overlaid histograms for every group of ``Number_Dividing`` files.

    ``matrix_directory`` is scanned for ``.npy`` files whose name contains
    ``Number_Dividing``; names containing ``DividingNumber_Dividing`` are
    skipped. The text before the first ``Number_Dividing`` in a file name is
    that file's group. Each group gets one figure holding one histogram (with
    KDE) per file, labelled with the number returned by
    ``extract_number_dividing``. The figure is saved as a PNG in ``save_dir``
    and shown.

    Args:
        matrix_directory: Directory scanned for the input ``.npy`` files.
        save_dir: Directory the PNG figures are written to.
        dataset_name: First component of each figure file name.
        channel: Second component of each figure file name.
    """
    marker = "Number_Dividing"

    # The former extra clause `and "Number_DividingNumber_Dividing"` was a
    # bare, always-true string; the remaining exclusion already covers it.
    sorted_files = natsorted(
        [
            file
            for file in os.listdir(matrix_directory)
            if file.endswith(".npy")
            and marker in file
            and "DividingNumber_Dividing" not in file
        ]
    )
    logger.debug("Number_Dividing files found: %s", sorted_files)

    # Group by the exact prefix. A substring test would also pull in files of
    # any other group whose name merely contains this prefix. A dict keeps
    # the natural-sort order, so figures are produced in a stable order.
    files_by_group = {}
    for file_name in sorted_files:
        group_name = file_name.split(marker)[0]
        files_by_group.setdefault(group_name, []).append(file_name)

    for group_name, group_files in files_by_group.items():
        fig = plt.figure(figsize=(8, 6))

        for file_name in group_files:
            file_path = os.path.join(matrix_directory, file_name)
            number_dividing = extract_number_dividing(file_name)

            data = np.load(file_path)
            sns.histplot(
                data, alpha=0.5, kde=True, label=f"Number_Dividing: {number_dividing}"
            )

        plt.xlabel("Value")
        plt.ylabel("Counts")
        # group_name never contains the marker itself, so the pattern can
        # only match once the marker is appended again.
        title_match = re.search(r"__(.*?)_Number_Dividing", group_name + marker)
        simplified_group_name = title_match.group(1) if title_match else group_name
        plt.title(f"{simplified_group_name}")
        plt.legend()
        fig_name = f"{dataset_name}_{channel}_{group_name}_all_distribution.png"
        plt.savefig(os.path.join(save_dir, fig_name), dpi=300, bbox_inches="tight")
        plt.show()
        # Release the figure so that many groups do not accumulate figures.
        plt.close(fig)
4 changes: 2 additions & 2 deletions src/napatrackmater/_version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__version__ = version = "5.3.2"
__version_tuple__ = version_tuple = (5, 3, 2)
__version__ = version = "5.3.3"
__version_tuple__ = version_tuple = (5, 3, 3)

0 comments on commit 71c36c3

Please sign in to comment.