Merge pull request #192 from WenjieDu/extract_latent_from_clustering_models

Extract latent from clustering models
WenjieDu authored Sep 26, 2023
2 parents 1984734 + 8dfd189 commit 28b9fdc
Showing 4 changed files with 89 additions and 18 deletions.
27 changes: 20 additions & 7 deletions pypots/clustering/crli/model.py
@@ -10,7 +10,7 @@
 # Created by Wenjie Du <[email protected]>
 # License: GLP-v3

-from typing import Union, Optional
+from typing import Union, Optional, Tuple

 import numpy as np
 import torch
@@ -419,7 +419,8 @@ def cluster(
         self,
         X: Union[dict, str],
         file_type: str = "h5py",
-    ) -> np.ndarray:
+        return_latent: bool = False,
+    ) -> Union[np.ndarray, Tuple[np.ndarray, dict]]:
         self.model.eval()  # set the model as eval status to freeze it.
         test_set = DatasetForCRLI(X, return_labels=False, file_type=file_type)
         test_loader = DataLoader(
@@ -428,15 +429,27 @@
             shuffle=False,
             num_workers=self.num_workers,
         )
-        latent_collector = []
+        clustering_latent_collector = []
+        imputation_collector = []

         with torch.no_grad():
             for idx, data in enumerate(test_loader):
                 inputs = self._assemble_input_for_testing(data)
                 inputs = self.model.forward(inputs, training=False)
-                latent_collector.append(inputs["fcn_latent"])
+                clustering_latent_collector.append(inputs["fcn_latent"])
+                imputation_collector.append(inputs["imputation"])

-        latent_collector = torch.cat(latent_collector).cpu().detach().numpy()
-        clustering = self.model.kmeans.fit_predict(latent_collector)
+        imputation = torch.cat(imputation_collector).cpu().detach().numpy()
+        clustering_latent = (
+            torch.cat(clustering_latent_collector).cpu().detach().numpy()
+        )
+        clustering_results = self.model.kmeans.fit_predict(clustering_latent)
+        latent_collector = {
+            "clustering_latent": clustering_latent,
+            "imputation": imputation,
+        }
+
+        if return_latent:
+            return clustering_results, latent_collector

-        return clustering
+        return clustering_results
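With this change, CRLI.cluster() stays backward compatible: by default it still returns only the cluster assignments, while return_latent=True additionally returns the dict assembled above, holding the "clustering_latent" array fed to KMeans and the "imputation" array. A minimal usage sketch follows; the dataset shape and the CRLI hyperparameter values are illustrative assumptions, not taken from this commit:

    import numpy as np
    from pypots.clustering import CRLI

    # Toy dataset: 100 samples, 16 steps, 5 features; NaN marks missing values.
    X = np.random.randn(100, 16, 5)
    X[X < -1.5] = np.nan
    dataset = {"X": X}

    # Hyperparameter values here are placeholders, not tuned settings.
    crli = CRLI(
        n_steps=16,
        n_features=5,
        n_clusters=3,
        n_generator_layers=2,
        rnn_hidden_size=32,
        epochs=3,
    )
    crli.fit(dataset)

    # Old behavior, unchanged: assignments only.
    clustering = crli.cluster(dataset)

    # New behavior: assignments plus the collected latent outputs.
    clustering, latent_collector = crli.cluster(dataset, return_latent=True)
    print(latent_collector["clustering_latent"].shape)  # latent fed to KMeans
    print(latent_collector["imputation"].shape)         # imputed time series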
42 changes: 41 additions & 1 deletion pypots/clustering/vader/model.py
@@ -182,13 +182,17 @@ def forward(
             mu_tilde,
             stddev_tilde,
         ) = self.get_results(X, missing_mask)
+        imputed_X = X_reconstructed * (1 - missing_mask) + X * missing_mask
+
         if not training and not pretrain:
             results = {
                 "mu_tilde": mu_tilde,
+                "stddev_tilde": stddev_tilde,
                 "mu": mu_c,
                 "var": var_c,
                 "phi": phi_c,
                 "z": z,
+                "imputed_X": imputed_X,
             }
             # if only run clustering, then no need to calculate loss
             return results
@@ -260,6 +264,7 @@ def forward(
         results = {
             "loss": reconstruction_loss + self.alpha * latent_loss,
             "z": z,
+            "imputed_X": imputed_X,
         }

         return results
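The imputed_X line added above uses the standard masking convention: wherever missing_mask is 1 the observed value in X is kept, and wherever it is 0 the model's reconstruction fills the gap. A tiny numeric check of that formula (illustrative values only):

    import numpy as np

    X = np.array([1.0, 0.0, 3.0])                # 0.0 sits in the missing slot
    missing_mask = np.array([1.0, 0.0, 1.0])     # 1 = observed, 0 = missing
    X_reconstructed = np.array([0.9, 2.2, 2.8])  # model output for every step

    imputed_X = X_reconstructed * (1 - missing_mask) + X * missing_mask
    print(imputed_X)  # [1.  2.2 3. ] -- observed values kept, gap filled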
@@ -597,7 +602,12 @@ def fit(
         # Step 3: save the model if necessary
         self._auto_save_model_if_necessary(training_finished=True)

-    def cluster(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray:
+    def cluster(
+        self,
+        X: Union[dict, str],
+        file_type: str = "h5py",
+        return_latent: bool = False,
+    ) -> Union[np.ndarray, Tuple[np.ndarray, dict]]:
         self.model.eval()  # set the model as eval status to freeze it.
         test_set = DatasetForVaDER(X, return_labels=False, file_type=file_type)
         test_loader = DataLoader(
@@ -606,6 +616,13 @@ def cluster(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray:
             shuffle=False,
             num_workers=self.num_workers,
         )
+        mu_tilde_collector = []
+        stddev_tilde_collector = []
+        mu_collector = []
+        var_collector = []
+        phi_collector = []
+        z_collector = []
+        imputed_X_collector = []
         clustering_results_collector = []

         with torch.no_grad():
@@ -614,9 +631,19 @@
                 results = self.model.forward(inputs, training=False)

                 mu_tilde = results["mu_tilde"].cpu().numpy()
+                mu_tilde_collector.append(mu_tilde)
                 stddev_tilde = results["stddev_tilde"].cpu().numpy()
+                stddev_tilde_collector.append(stddev_tilde)
                 mu = results["mu"].cpu().numpy()
+                mu_collector.append(mu)
                 var = results["var"].cpu().numpy()
+                var_collector.append(var)
                 phi = results["phi"].cpu().numpy()
+                phi_collector.append(phi)
                 z = results["z"].cpu().numpy()
+                z_collector.append(z)
+                imputed_X = results["imputed_X"].cpu().numpy()
+                imputed_X_collector.append(imputed_X)

                 def func_to_apply(
                     mu_t_: np.ndarray,
@@ -640,4 +667,17 @@ def func_to_apply(
                 clustering_results_collector.append(clustering_results)

         clustering_results = np.concatenate(clustering_results_collector)
+        latent_collector = {
+            "mu_tilde": np.concatenate(mu_tilde_collector),
+            "stddev_tilde": np.concatenate(stddev_tilde_collector),
+            "mu": np.concatenate(mu_collector),
+            "var": np.concatenate(var_collector),
+            "phi": np.concatenate(phi_collector),
+            "z": np.concatenate(z_collector),
+            "imputation": np.concatenate(imputed_X_collector),
+        }
+
+        if return_latent:
+            return clustering_results, latent_collector
+
         return clustering_results
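VaDER.cluster() follows the same pattern as CRLI: the default return is unchanged, and return_latent=True additionally yields the dict built above. A rough sketch of consuming it, assuming a VaDER model already trained elsewhere (vader and dataset below are stand-ins):

    # vader is a fitted pypots.clustering.VaDER instance; dataset is {"X": ...}.
    clustering, latent_collector = vader.cluster(dataset, return_latent=True)

    # Per-sample encoder outputs:
    mu_tilde = latent_collector["mu_tilde"]          # encoder means
    stddev_tilde = latent_collector["stddev_tilde"]  # encoder scale parameters
    z = latent_collector["z"]                        # sampled latent vectors

    # Fitted Gaussian-mixture parameters:
    mu = latent_collector["mu"]    # component means
    var = latent_collector["var"]  # component variances
    phi = latent_collector["phi"]  # component weights

    # Imputed input series, keyed "imputation" like CRLI's output:
    imputation = latent_collector["imputation"]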
18 changes: 13 additions & 5 deletions tests/clustering/crli.py
@@ -14,7 +14,10 @@
 from pypots.clustering import CRLI
 from pypots.optim import Adam
 from pypots.utils.logging import logger
-from pypots.utils.metrics import cal_rand_index, cal_cluster_purity
+from pypots.utils.metrics import (
+    cal_external_cluster_validation_metrics,
+    cal_internal_cluster_validation_metrics,
+)
 from tests.clustering.config import (
     EPOCHS,
     TRAIN_SET,
@@ -74,10 +77,15 @@ def test_1_parameters(self):

     @pytest.mark.xdist_group(name="clustering-crli")
     def test_2_cluster(self):
-        clustering = self.crli.cluster(TEST_SET)
-        RI = cal_rand_index(clustering, DATA["test_y"])
-        CP = cal_cluster_purity(clustering, DATA["test_y"])
-        logger.info(f"RI: {RI}\nCP: {CP}")
+        clustering, latent_collector = self.crli.cluster(TEST_SET, return_latent=True)
+        external_metrics = cal_external_cluster_validation_metrics(
+            clustering, DATA["test_y"]
+        )
+        internal_metrics = cal_internal_cluster_validation_metrics(
+            latent_collector["clustering_latent"], DATA["test_y"]
+        )
+        logger.info(f"{external_metrics}")
+        logger.info(f"{internal_metrics}")

     @pytest.mark.xdist_group(name="clustering-crli")
     def test_3_saving_path(self):
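The updated tests replace the old rand-index and cluster-purity checks with PyPOTS's paired helpers: cal_external_cluster_validation_metrics compares assignments against reference labels, while cal_internal_cluster_validation_metrics computes label-free quality scores from the latent representation and a set of assignments. For intuition only, a hand-rolled internal report built on scikit-learn (an assumed equivalent, not PyPOTS's actual implementation) could look like:

    from sklearn.metrics import (
        calinski_harabasz_score,
        davies_bouldin_score,
        silhouette_score,
    )

    def internal_validation_report(latent, labels):
        """Score a clustering using only latent vectors and label assignments."""
        return {
            "silhouette": silhouette_score(latent, labels),
            "calinski_harabasz": calinski_harabasz_score(latent, labels),
            "davies_bouldin": davies_bouldin_score(latent, labels),
        }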
20 changes: 15 additions & 5 deletions tests/clustering/vader.py
@@ -15,7 +15,10 @@
 from pypots.clustering import VaDER
 from pypots.optim import Adam
 from pypots.utils.logging import logger
-from pypots.utils.metrics import cal_rand_index, cal_cluster_purity
+from pypots.utils.metrics import (
+    cal_external_cluster_validation_metrics,
+    cal_internal_cluster_validation_metrics,
+)
 from tests.clustering.config import (
     EPOCHS,
     TRAIN_SET,
@@ -60,10 +63,17 @@ def test_0_fit(self):
     @pytest.mark.xdist_group(name="clustering-vader")
     def test_1_cluster(self):
         try:
-            clustering = self.vader.cluster(TEST_SET)
-            RI = cal_rand_index(clustering, DATA["test_y"])
-            CP = cal_cluster_purity(clustering, DATA["test_y"])
-            logger.info(f"RI: {RI}\nCP: {CP}")
+            clustering, latent_collector = self.vader.cluster(
+                TEST_SET, return_latent=True
+            )
+            external_metrics = cal_external_cluster_validation_metrics(
+                clustering, DATA["test_y"]
+            )
+            internal_metrics = cal_internal_cluster_validation_metrics(
+                latent_collector["z"], DATA["test_y"]
+            )
+            logger.info(f"{external_metrics}")
+            logger.info(f"{internal_metrics}")
         except np.linalg.LinAlgError as e:
             logger.error(
                 f"{e}\n"
