diff --git a/src/napatrackmater/Trackvector.py b/src/napatrackmater/Trackvector.py index 4974bd8..a3a9126 100644 --- a/src/napatrackmater/Trackvector.py +++ b/src/napatrackmater/Trackvector.py @@ -3389,7 +3389,9 @@ def plot_at_mitosis_time(matrix_directory, save_dir, dataset_name, channel): plt.show() -def plot_histograms_for_groups(matrix_directory, save_dir, dataset_name, channel, name = 'all'): +def plot_histograms_for_groups( + matrix_directory, save_dir, dataset_name, channel, name="all" +): files = os.listdir(matrix_directory) sorted_files = natsorted( @@ -3440,7 +3442,7 @@ def plot_histograms_for_groups(matrix_directory, save_dir, dataset_name, channel def plot_histograms_for_cell_type_groups( - matrix_directory, save_dir, dataset_name, channel, label_dict = None, name = 'all' + matrix_directory, save_dir, dataset_name, channel, label_dict=None, name="all" ): files = os.listdir(matrix_directory) @@ -3470,11 +3472,13 @@ def plot_histograms_for_cell_type_groups( file_path = os.path.join(matrix_directory, file_name) cell_type = extract_celltype(file_name) if label_dict is not None: - cell_type_name = label_dict[cell_type] + cell_type_name = label_dict[cell_type] else: - cell_type_name = cell_type + cell_type_name = cell_type data = np.load(file_path, allow_pickle=True) - sns.histplot(data, alpha=0.5, kde=True, label=f"Cell_Type: {cell_type_name}") + sns.histplot( + data, alpha=0.5, kde=True, label=f"Cell_Type: {cell_type_name}" + ) plt.xlabel("Value") plt.ylabel("Counts") @@ -4042,7 +4046,6 @@ def inception_model_prediction( class_map, dynamic_model=None, shape_model=None, - num_samples=10, device="cpu", ): sub_dataframe = dataframe[dataframe["Track ID"] == track_id] @@ -4051,13 +4054,13 @@ def inception_model_prediction( total_duration = sub_dataframe["Track Duration"].max() - def sample_subarrays(data, num_samples, tracklet_length, total_duration): + if sub_dataframe.shape[0] < tracklet_length: + return "UnClassified" + + def sample_subarrays(data, tracklet_length, total_duration): max_start_index = total_duration - tracklet_length - if max_start_index > num_samples: - start_indices = random.sample(range(max_start_index), num_samples) - else: - start_indices = [0] * num_samples + start_indices = random.sample(range(max_start_index), max_start_index) subarrays = [] for start_index in start_indices: @@ -4070,10 +4073,10 @@ def sample_subarrays(data, num_samples, tracklet_length, total_duration): return subarrays sub_arrays_shape = sample_subarrays( - sub_dataframe_shape, num_samples, tracklet_length, total_duration + sub_dataframe_shape, tracklet_length, total_duration ) sub_arrays_dynamic = sample_subarrays( - sub_dataframe_dynamic, num_samples, tracklet_length, total_duration + sub_dataframe_dynamic, tracklet_length, total_duration ) def make_prediction(input_data, model): @@ -4089,12 +4092,11 @@ def make_prediction(input_data, model): return predicted_class.item() def get_most_frequent_prediction(predictions): - if predictions: - prediction_counts = Counter(predictions) - most_common_prediction, count = prediction_counts.most_common(1)[0] + prediction_counts = Counter(predictions) + most_common_prediction, count = prediction_counts.most_common(1)[0] - return most_common_prediction + return most_common_prediction shape_predictions = [] if shape_model is not None: @@ -4109,16 +4111,11 @@ def get_most_frequent_prediction(predictions): dynamic_predictions.append(predicted_class) final_predictions = shape_predictions + dynamic_predictions - if len(final_predictions) > 0: - most_frequent_prediction = get_most_frequent_prediction(final_predictions) - if most_frequent_prediction is not None: - most_predicted_class = class_map[int(most_frequent_prediction)] - - return most_predicted_class + most_frequent_prediction = get_most_frequent_prediction(final_predictions) - else: + most_predicted_class = class_map[int(most_frequent_prediction)] - return "UnClassified" + return most_predicted_class def save_cell_type_predictions( diff --git a/src/napatrackmater/_version.py b/src/napatrackmater/_version.py index 26034aa..1e50c25 100644 --- a/src/napatrackmater/_version.py +++ b/src/napatrackmater/_version.py @@ -1,2 +1,2 @@ -__version__ = version = "5.4.9" -__version_tuple__ = version_tuple = (5, 4, 9) +__version__ = version = "5.5.0" +__version_tuple__ = version_tuple = (5, 5, 0)