
Commit 05f1d7c: knn fit

kapoorlab committed Oct 29, 2023
1 parent 42b0c21

Showing 2 changed files with 80 additions and 67 deletions.
143 changes: 78 additions & 65 deletions src/napatrackmater/Trackvector.py
@@ -18,6 +18,7 @@
from sklearn.neighbors import KNeighborsClassifier
from joblib import dump


class TrackVector(TrackMate):
    def __init__(
        self,
@@ -562,11 +563,9 @@ def create_gt_analysis_vectors_dict(global_shape_dynamic_dataframe: pd.DataFrame
        gt_dataframe = track_data[
            [
                "Cluster",
            ]
        ]

        full_dataframe = track_data[
            [
                "Track ID",
@@ -587,7 +586,7 @@ def create_gt_analysis_vectors_dict(global_shape_dynamic_dataframe: pd.DataFrame
"Distance_Cell_mask",
"Radial_Angle",
"Cell_Axis_Mask",
"Cluster"
"Cluster",
]
]

@@ -602,83 +601,94 @@ def create_gt_analysis_vectors_dict(global_shape_dynamic_dataframe: pd.DataFrame

    return gt_analysis_vectors


def create_global_gt_dataframe(
    full_dataframe,
    ground_truth_csv_file,
    calibration_z,
    calibration_y,
    calibration_x,
    time_veto_threshold=0.0,
    space_veto_threshold=5.0,
    cell_type_key="Celltype_label",
):

    ground_truth_data_frame = pd.read_csv(ground_truth_csv_file)
    ground_truth_data_frame.dropna(subset=[cell_type_key], inplace=True)

    # Prepare ground truth tuples and labels
    ground_truth_tuples = np.array(
        [
            ground_truth_data_frame["FRAME"].values,
            ground_truth_data_frame["POSITION_Z"].values / calibration_z,
            ground_truth_data_frame["POSITION_Y"].values / calibration_y,
            ground_truth_data_frame["POSITION_X"].values / calibration_x,
        ]
    ).T
    ground_truth_labels = ground_truth_data_frame[cell_type_key].values
    theory_tuples_spatial = full_dataframe[["t", "z", "y", "x"]].values

    tree_spatial = cKDTree(theory_tuples_spatial)

    # Initialize lists to store the indices of the closest theory tuples and their corresponding ground truth labels
    closest_theory_indices = []
    corresponding_ground_truth_labels = []
    closest_theory_tuples_found = []
    closest_theory_track_ids = []
    # Find the closest theory tuple for each ground truth tuple
    for i, (ground_tuple, ground_label) in enumerate(
        zip(ground_truth_tuples, ground_truth_labels)
    ):
        # Use the KD-Tree for nearest-neighbor search
        spatial_valid_indices = tree_spatial.query_ball_point(
            ground_tuple, space_veto_threshold, p=2
        )

        if spatial_valid_indices:
            # Find the closest theory index within the valid indices
            closest_theory_index = spatial_valid_indices[
                np.argmin(tree_spatial.query(ground_tuple)[0])
            ]

            # Get the closest theory tuple
            closest_theory_tuple = full_dataframe.loc[closest_theory_index]

            # Check if the index is valid, within the DataFrame's range, and satisfies the time veto
            if (
                0 <= closest_theory_index < len(full_dataframe)
                and abs(closest_theory_tuple["t"] - ground_tuple[0])
                <= time_veto_threshold
            ):
                closest_theory_indices.append(closest_theory_index)
                corresponding_ground_truth_labels.append(ground_label)
                closest_theory_tuples_found.append(closest_theory_tuple)
                closest_theory_track_ids.append(closest_theory_tuple["Track ID"])

    track_id_to_cluster = {
        track_id: cluster_label
        for track_id, cluster_label in zip(
            closest_theory_track_ids, corresponding_ground_truth_labels
        )
    }
    full_dataframe["Cluster"] = full_dataframe["Track ID"].map(track_id_to_cluster)

    return full_dataframe


def supervised_clustering(
    csv_file_name,
    gt_analysis_vectors,
    num_clusters,
):
    csv_file_name_original = csv_file_name + "_training_data"
    data_list = []

    for track_id, (
        shape_dynamic_dataframe_list,
        gt_dataframe_list,
        full_dataframe_list,
    ) in gt_analysis_vectors.items():

        shape_dynamic_track_array = np.array(
            [
                [item for item in record.values()]
@@ -688,36 +698,39 @@ def supervised_clustering(

        gt_track_array = np.array(
            [[item for item in record.values()] for record in gt_dataframe_list]
        )
        if not np.isnan(gt_track_array[0]) and shape_dynamic_track_array.shape[0] > 1:
            print(gt_track_array[0][0], shape_dynamic_track_array.shape)
            (
                shape_dynamic_covariance,
                shape_dynamic_eigenvectors,
            ) = compute_covariance_matrix(shape_dynamic_track_array)

            flattened_covariance = shape_dynamic_covariance.flatten()
            data_list.append(
                {
                    "Flattened_Covariance": flattened_covariance,
                    "gt_label": gt_track_array[0][0],
                }
            )
    result_dataframe = pd.DataFrame(data_list)
    if os.path.exists(csv_file_name_original):
        os.remove(csv_file_name_original)
    result_dataframe.to_csv(csv_file_name_original, index=False)
    X = np.vstack(result_dataframe["Flattened_Covariance"].values)
    y = result_dataframe["gt_label"].values
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.01, random_state=42
    )
    knn = KNeighborsClassifier(n_neighbors=num_clusters)
    knn.fit(X_train, y_train)
    accuracy = knn.score(X_test, y_test)
    print(f"Model Accuracy: {accuracy:.2f}")

    model_filename = "knn_model.joblib"
    dump(knn, model_filename)

    return knn


def unsupervised_clustering(
    full_dataframe,
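A minimal usage sketch, not part of this commit, of how the helpers touched here might be chained, based only on the signatures visible in the diff: create_global_gt_dataframe attaches a "Cluster" column to the per-spot track dataframe by KD-tree matching against a ground-truth CSV (FRAME, POSITION_Z/Y/X plus a Celltype_label column), create_gt_analysis_vectors_dict turns the labelled dataframe into per-track vectors, and supervised_clustering fits the KNeighborsClassifier and writes knn_model.joblib. The file names, calibration values, and num_clusters are assumptions.

import pandas as pd

from napatrackmater.Trackvector import (
    create_global_gt_dataframe,
    create_gt_analysis_vectors_dict,
    supervised_clustering,
)

# Hypothetical export of the per-spot track table; it must contain the
# "t", "z", "y", "x" and "Track ID" columns used by the KD-tree matching.
full_df = pd.read_csv("tracks_with_features.csv")

labelled_df = create_global_gt_dataframe(
    full_df,
    ground_truth_csv_file="ground_truth_cell_types.csv",  # FRAME, POSITION_Z/Y/X, Celltype_label
    calibration_z=1.0,  # assumed voxel calibration dividing the CSV positions
    calibration_y=0.5,
    calibration_x=0.5,
    time_veto_threshold=0.0,  # matched spot must lie in the same frame
    space_veto_threshold=5.0,  # maximum distance for a valid KD-tree match
)

# Per-track (shape/dynamic, ground-truth, full) vectors keyed by Track ID.
gt_vectors = create_gt_analysis_vectors_dict(labelled_df)

# Fits a KNeighborsClassifier on flattened covariance matrices, writes
# "celltype_knn_training_data" and "knn_model.joblib", and returns the model.
knn = supervised_clustering(
    csv_file_name="celltype_knn",
    gt_analysis_vectors=gt_vectors,
    num_clusters=3,  # used as n_neighbors for the KNN
)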
4 changes: 2 additions & 2 deletions src/napatrackmater/_version.py
@@ -1,2 +1,2 @@
__version__ = version = "4.3.9"
__version_tuple__ = version_tuple = (4, 3, 9)
__version__ = version = "4.4.0"
__version_tuple__ = version_tuple = (4, 4, 0)
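A second sketch, also not part of the commit, of reloading the classifier that supervised_clustering saves with joblib.dump and scoring one new track. The random stand-in array, its feature count, and the import location of compute_covariance_matrix are assumptions; a real track array must carry the same feature columns the model was trained on.

import numpy as np
from joblib import load

from napatrackmater.Trackvector import compute_covariance_matrix

knn = load("knn_model.joblib")

# Stand-in for one track's per-timepoint feature rows (rows = timepoints,
# columns = the same shape/dynamic features used to build
# shape_dynamic_track_array during training).
rng = np.random.default_rng(0)
new_track_array = rng.random((20, 8))

covariance, _eigenvectors = compute_covariance_matrix(new_track_array)

# The classifier was fit on flattened covariance matrices, so flatten
# and reshape to a single sample before predicting.
predicted_label = knn.predict(covariance.flatten().reshape(1, -1))[0]
print("Predicted cell-type label:", predicted_label)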
