diff --git a/src/napatrackmater/Trackmate.py b/src/napatrackmater/Trackmate.py index 2109aa3..2d5df38 100644 --- a/src/napatrackmater/Trackmate.py +++ b/src/napatrackmater/Trackmate.py @@ -960,7 +960,7 @@ def _master_track_computer(self, track, track_id, t_start=None, t_end=None): all_source_ids, all_target_ids = self._generate_generations(track) root_root, root_splits, root_leaf = self._create_generations(all_source_ids) - + self._iterate_split_down(root_root, root_leaf, root_splits) # Determine if a track has divisions or none number_dividing = len(root_splits) if number_dividing > 0: diff --git a/src/napatrackmater/Trackvector.py b/src/napatrackmater/Trackvector.py index 79d49fc..c51fada 100644 --- a/src/napatrackmater/Trackvector.py +++ b/src/napatrackmater/Trackvector.py @@ -20,7 +20,9 @@ from statsmodels.tsa.stattools import acf, ccf from scipy.stats import norm, anderson from kapoorlabs_lightning.lightning_trainer import MitosisInception - +from natsort import natsorted +import seaborn as sns +from tqdm import tqdm logger = logging.getLogger(__name__) @@ -521,11 +523,23 @@ def plot_mitosis_times(self, full_dataframe, save_path=""): counts = dividing_counts.values data = {"Time": times, "Count": counts} df = pd.DataFrame(data) - np.save(save_path + "_counts.npy", df.to_numpy()) + np.save(save_path + "counts.npy", df.to_numpy()) max_number_dividing = full_dataframe["Number_Dividing"].max() min_number_dividing = full_dataframe["Number_Dividing"].min() - excluded_keys = ["Track ID", "t", "z", "y", "x"] + excluded_keys = [ + "Track ID", + "t", + "z", + "y", + "x", + "Unnamed: 0", + "Unnamed", + "Track Duration", + "Generation ID", + "TrackMate Track ID", + "Tracklet Number ID", + ] for i in range( min_number_dividing.astype(int), max_number_dividing.astype(int) + 1 ): @@ -534,12 +548,13 @@ def plot_mitosis_times(self, full_dataframe, save_path=""): data = full_dataframe[column][ full_dataframe["Number_Dividing"].astype(int) == i ] + np.save( - 
def extract_number_dividing(file_name):
    """Return the division count encoded in a ``Number_Dividing_<n>.npy`` name.

    Parameters
    ----------
    file_name : str
        File name (or path) to inspect.

    Returns
    -------
    int
        The integer that follows ``Number_Dividing_`` and directly precedes
        ``.npy``, or ``-1`` when the name does not contain that pattern.
    """
    match = re.search(r"Number_Dividing_(\d+)\.npy", file_name)
    return int(match.group(1)) if match else -1


def plot_at_mitosis_time(matrix_directory, save_dir, dataset_name, channel):
    """Plot per-property histograms of the values stored at mitosis time.

    Every ``.npy`` file in ``matrix_directory`` whose name contains
    ``data_at_mitosis_time`` is loaded as a sequence of dictionaries. Values
    are grouped by property name and, within a property, by the entry's
    ``"Number Times Divided"`` value. One figure per property is drawn with
    one histogram per division count, then saved into ``save_dir`` and shown.

    Parameters
    ----------
    matrix_directory : str
        Directory that is scanned for the ``.npy`` input files.
    save_dir : str
        Directory the PNG figures are written to; created if missing.
    dataset_name : str
        First component of every output file name.
    channel : str
        Second component of every output file name.
    """
    # Bookkeeping columns that are identifiers/coordinates, not properties.
    excluded_keys = frozenset(
        (
            "Track ID",
            "t",
            "z",
            "y",
            "x",
            "Unnamed: 0",
            "Unnamed",
            "Track Duration",
            "Generation ID",
            "TrackMate Track ID",
            "Tracklet Number ID",
            "Tracklet_ID",
            "Unique_ID",
            "Track_ID",
        )
    )

    mitosis_files = [
        file
        for file in os.listdir(matrix_directory)
        if file.endswith(".npy") and "data_at_mitosis_time" in file
    ]
    sorted_files = natsorted(
        mitosis_files,
        key=lambda x: (
            "_Dynamic Cluster" in x,
            "_Shape Dynamic Cluster" in x,
            "_Shape Cluster" in x,
            x,
        ),
    )

    # savefig does not create missing directories; an empty save_dir keeps
    # meaning "current working directory".
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    for file_name in sorted_files:
        # NOTE(review): allow_pickle=True unpickles arbitrary objects; only
        # point this function at directories whose content is trusted.
        all_split_data = np.load(
            os.path.join(matrix_directory, file_name), allow_pickle=True
        )

        # property name -> {number of divisions -> list of values}
        grouped_data = {}
        for data_dict in all_split_data:
            number_times_divided = data_dict.get("Number Times Divided")
            for key, value in data_dict.items():
                if key in excluded_keys:
                    continue
                per_property = grouped_data.setdefault(key, {})
                per_property.setdefault(number_times_divided, []).append(value)

        for property_name, property_data in grouped_data.items():
            fig = plt.figure(figsize=(10, 6))
            for number_times_divided, values in property_data.items():
                sns.histplot(
                    values,
                    label=f"Number Times Divided {number_times_divided}",
                    kde=True,
                    bins=20,
                    edgecolor="black",
                    alpha=0.5,
                )
            plt.xlabel("Property Values")
            plt.ylabel("Frequency")
            plt.title(f"{property_name}")
            # The histograms carry labels; without a legend they are never
            # displayed and the overlaid distributions cannot be told apart.
            plt.legend()

            # NOTE(review): the name does not include file_name, so when
            # several input files contain the same property the later figure
            # overwrites the earlier one.
            fig_name = (
                f"{dataset_name}_{channel}_{property_name}_at_mitosis_distribution.png"
            )
            plt.savefig(os.path.join(save_dir, fig_name), dpi=300, bbox_inches="tight")
            plt.show()
            # Release the figure so long runs do not accumulate open figures.
            plt.close(fig)


def plot_histograms_for_groups(matrix_directory, save_dir, dataset_name, channel):
    """Plot one overlaid histogram figure per ``Number_Dividing`` file group.

    ``.npy`` files in ``matrix_directory`` whose name contains
    ``Number_Dividing`` (but not ``DividingNumber_Dividing``) are grouped by
    the part of the file name that precedes the first ``Number_Dividing``.
    For every group one figure is drawn with one histogram per file, labelled
    with the division count parsed from the file name, then saved into
    ``save_dir`` and shown.

    Parameters
    ----------
    matrix_directory : str
        Directory that is scanned for the ``.npy`` input files.
    save_dir : str
        Directory the PNG figures are written to; created if missing.
    dataset_name : str
        First component of every output file name.
    channel : str
        Second component of every output file name.
    """
    # "DividingNumber_Dividing" also covers "Number_DividingNumber_Dividing",
    # so a single exclusion test is sufficient.
    sorted_files = natsorted(
        [
            file
            for file in os.listdir(matrix_directory)
            if file.endswith(".npy")
            and "Number_Dividing" in file
            and "DividingNumber_Dividing" not in file
        ]
    )
    print(sorted_files)

    # Group by exact prefix. A dict keeps the natural-sort order of the files
    # and avoids pulling in files of another group whose name merely contains
    # this group's prefix as a substring.
    files_by_group = {}
    for file_name in sorted_files:
        group_name = file_name.split("Number_Dividing")[0]
        files_by_group.setdefault(group_name, []).append(file_name)

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    for group_name, group_files in files_by_group.items():
        fig = plt.figure(figsize=(8, 6))

        for file_name in group_files:
            file_path = os.path.join(matrix_directory, file_name)
            number_dividing = extract_number_dividing(file_name)

            data = np.load(file_path)
            sns.histplot(
                data, alpha=0.5, kde=True, label=f"Number_Dividing: {number_dividing}"
            )

        plt.xlabel("Value")
        plt.ylabel("Counts")
        # The pattern needs the "_Number_Dividing" suffix, which group_name
        # never contains, so it is matched against a file name of the group.
        title_match = re.search(r"__(.*?)_Number_Dividing", group_files[0])
        simplified_group_name = title_match.group(1) if title_match else group_name
        plt.title(f"{simplified_group_name}")
        plt.legend()
        fig_name = f"{dataset_name}_{channel}_{group_name}_all_distribution.png"
        plt.savefig(os.path.join(save_dir, fig_name), dpi=300, bbox_inches="tight")
        plt.show()
        # Release the figure so long runs do not accumulate open figures.
        plt.close(fig)