Skip to content

Commit

Permalink
all checked
Browse files Browse the repository at this point in the history
  • Loading branch information
kapoorlab committed Dec 4, 2023
1 parent 1387287 commit b7291e8
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 69 deletions.
4 changes: 2 additions & 2 deletions src/napatrackmater/_version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__version__ = version = "4.5.1"
__version_tuple__ = version_tuple = (4, 5, 1)
__version__ = version = "4.5.2"
__version_tuple__ = version_tuple = (4, 5, 2)
174 changes: 107 additions & 67 deletions src/napatrackmater/clustering.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from kapoorlabs_lightning.lightning_trainer import AutoLightningModel
from kapoorlabs_lightning.pytorch_models import CloudAutoEncoder
import numpy as np
import concurrent
import os
Expand All @@ -17,6 +16,7 @@
from typing import List
from tqdm import tqdm


class PointCloudDataset(Dataset):
def __init__(self, clouds: List[PyntCloud], center=True, scale_z=1.0, scale_xy=1.0):
self.clouds = clouds
Expand Down Expand Up @@ -60,7 +60,7 @@ def __init__(
center=True,
compute_with_autoencoder=True,
):

self.pretrainer = pretrainer
self.accelerator = accelerator
self.devices = devices
Expand All @@ -79,12 +79,11 @@ def __init__(
self.timed_cluster_label = {}
self.timed_latent_features = {}
self.count = 0

def _compute_latent_features(self):

ndim = len(self.label_image.shape)

# YX image
if ndim == 2:

labels, centroids, clouds, marching_cube_points = _label_cluster(
Expand All @@ -94,7 +93,11 @@ def _compute_latent_features(self):
ndim,
self.compute_with_autoencoder,
)
latent_features, cluster_centroids, output_largest_eigenvalues = _extract_latent_features(
(
latent_features,
cluster_centroids,
output_largest_eigenvalues,
) = _extract_latent_features(
self.model,
self.accelerator,
clouds,
Expand All @@ -103,64 +106,31 @@ def _compute_latent_features(self):
self.batch_size,
self.scale_z,
self.scale_xy,

)

self.timed_latent_features[str(self.key)] = latent_features, cluster_centroids, output_largest_eigenvalues
self.timed_latent_features[str(self.key)] = (
latent_features,
cluster_centroids,
output_largest_eigenvalues,
)

# ZYX image
if ndim == 3 and "T" not in self.axes:

labels, centroids, clouds, marching_cube_points = _label_cluster(
self.label_image,
self.num_points,
self.min_size,
ndim,
self.compute_with_autoencoder,
)
if len(labels) > 1:

latent_features, cluster_centroids, output_largest_eigenvalues = _extract_latent_features(
self.model,
self.accelerator,
clouds,
marching_cube_points,
centroids,
self.batch_size,
self.scale_z,
self.scale_xy,
)

self.timed_latent_features[str(self.key)] = latent_features, cluster_centroids, output_largest_eigenvalues

# TYX
if ndim == 3 and "T" in self.axes:

for i in range(self.label_image.shape[0]):
latent_features, cluster_centroids, output_largest_eigenvalues = self._latent_computer(i, ndim - 1)
self.timed_latent_features[str(i)] = latent_features , cluster_centroids, output_largest_eigenvalues


# TZYX image
if ndim == 4:

for i in range(self.label_image.shape[0]):
latent_features, cluster_centroids, output_largest_eigenvalues = self._latent_computer(i, ndim)
self.timed_latent_features[str(i)] = latent_features, cluster_centroids, output_largest_eigenvalues

def _latent_computer(self, i, dim):

xyz_label_image = self.label_image[i, :]
labels, centroids, clouds, marching_cube_points = _label_cluster(
xyz_label_image,
self.label_image,
self.num_points,
self.min_size,
dim,
ndim,
self.compute_with_autoencoder,
)
if len(labels) > 1:

latent_features, cluster_centroids, output_largest_eigenvalues = _extract_latent_features(

(
latent_features,
cluster_centroids,
output_largest_eigenvalues,
) = _extract_latent_features(
self.model,
self.accelerator,
clouds,
Expand All @@ -169,9 +139,71 @@ def _latent_computer(self, i, dim):
self.batch_size,
self.scale_z,
self.scale_xy,
)
return latent_features, cluster_centroids, output_largest_eigenvalues

)

self.timed_latent_features[str(self.key)] = (
latent_features,
cluster_centroids,
output_largest_eigenvalues,
)

# TYX
if ndim == 3 and "T" in self.axes:

for i in range(self.label_image.shape[0]):
(
latent_features,
cluster_centroids,
output_largest_eigenvalues,
) = self._latent_computer(i, ndim - 1)
self.timed_latent_features[str(i)] = (
latent_features,
cluster_centroids,
output_largest_eigenvalues,
)

# TZYX image
if ndim == 4:

for i in range(self.label_image.shape[0]):
(
latent_features,
cluster_centroids,
output_largest_eigenvalues,
) = self._latent_computer(i, ndim)
self.timed_latent_features[str(i)] = (
latent_features,
cluster_centroids,
output_largest_eigenvalues,
)

def _latent_computer(self, i, dim):

xyz_label_image = self.label_image[i, :]
labels, centroids, clouds, marching_cube_points = _label_cluster(
xyz_label_image,
self.num_points,
self.min_size,
dim,
self.compute_with_autoencoder,
)
if len(labels) > 1:

(
latent_features,
cluster_centroids,
output_largest_eigenvalues,
) = _extract_latent_features(
self.model,
self.accelerator,
clouds,
marching_cube_points,
centroids,
self.batch_size,
self.scale_z,
self.scale_xy,
)
return latent_features, cluster_centroids, output_largest_eigenvalues

def _create_cluster_labels(self):

Expand Down Expand Up @@ -349,6 +381,7 @@ def _label_computer(self, i, dim):
output_cloud_surface_area,
)


def _extract_latent_features(
model: AutoLightningModel,
accelerator: str,
Expand All @@ -371,20 +404,23 @@ def _extract_latent_features(
device = accelerator
torch_model.to(device)
latent_features = []
output_largest_eigenvalue = [get_eccentricity(cloud_input)[2] if get_eccentricity(cloud_input) is not None else -1 for cloud_input in marching_cube_points]

output_largest_eigenvalue = [
get_eccentricity(cloud_input)[2]
if get_eccentricity(cloud_input) is not None
else -1
for cloud_input in marching_cube_points
]

for batch in dataloader:

with torch.no_grad():
batch = batch.to(device).float()
latent_representation_list = torch_model.encoder(batch)
latent_representation_list = torch_model.encoder(batch)
for latent_representation in latent_representation_list:
latent_features.append(latent_representation.cpu().numpy())


latent_features.append(latent_representation.cpu().numpy())

return latent_features, output_cluster_centroids, output_largest_eigenvalue


def _model_output(
model: AutoLightningModel,
Expand Down Expand Up @@ -417,7 +453,7 @@ def _model_output(
if compute_with_autoencoder:

model.eval()

outputs_list = pretrainer.predict(model=model, dataloaders=dataloader)

for outputs in tqdm(outputs_list, desc="Autoencoder model", unit="batch"):
Expand All @@ -444,11 +480,15 @@ def _model_output(

else:

for cloud_input in tqdm(marching_cube_points, desc="Marching cubes", unit="cloud_input"):
for cloud_input in tqdm(
marching_cube_points, desc="Marching cubes", unit="cloud_input"
):
try:
ConvexHull(cloud_input)

output_cloud_eccentricity.append(tuple(get_eccentricity(cloud_input))[0])

output_cloud_eccentricity.append(
tuple(get_eccentricity(cloud_input))[0]
)

output_largest_eigenvector.append(get_eccentricity(cloud_input)[1])
output_largest_eigenvalue.append(get_eccentricity(cloud_input)[2])
Expand Down

0 comments on commit b7291e8

Please sign in to comment.