Skip to content

Commit

Permalink
feat: enable clustering algorithms to select model according to loss …
Browse files Browse the repository at this point in the history
…on the validation set;
  • Loading branch information
WenjieDu committed Oct 9, 2023
1 parent 2a36b2b commit 2decac5
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 50 deletions.
58 changes: 48 additions & 10 deletions pypots/clustering/crli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,14 +254,42 @@ def _train_model(
self._save_log_into_tb_file(
training_step, "training", loss_results
)

mean_epoch_train_D_loss = np.mean(epoch_train_loss_D_collector)
mean_epoch_train_G_loss = np.mean(epoch_train_loss_G_collector)
logger.info(
f"epoch {epoch}: "
f"training loss_generator {mean_epoch_train_G_loss:.4f}, "
f"train loss_discriminator {mean_epoch_train_D_loss:.4f}"
)
mean_loss = mean_epoch_train_G_loss

if val_loader is not None:
self.model.eval()
epoch_val_loss_G_collector = []
with torch.no_grad():
for idx, data in enumerate(val_loader):
inputs = self._assemble_input_for_validating(data)
results = self.model.forward(inputs, return_loss=True)
epoch_val_loss_G_collector.append(
results["generation_loss"].sum().item()
)
mean_val_G_loss = np.mean(epoch_val_loss_G_collector)
# save the validation loss into the tensorboard file after every epoch, if a summary writer is configured
if self.summary_writer is not None:
val_loss_dict = {
"generation_loss": mean_val_G_loss,
}
self._save_log_into_tb_file(epoch, "validating", val_loss_dict)
logger.info(
f"epoch {epoch}: "
f"training loss_generator {mean_epoch_train_G_loss:.4f}, "
f"training loss_discriminator {mean_epoch_train_D_loss:.4f}, "
f"validating loss_generator {mean_val_G_loss:.4f}"
)
mean_loss = mean_val_G_loss
else:

logger.info(
f"epoch {epoch}: "
f"training loss_generator {mean_epoch_train_G_loss:.4f}, "
f"training loss_discriminator {mean_epoch_train_D_loss:.4f}"
)
mean_loss = mean_epoch_train_G_loss

if mean_loss < self.best_loss:
self.best_loss = mean_loss
Expand Down Expand Up @@ -314,8 +342,18 @@ def fit(
num_workers=self.num_workers,
)

val_loader = None
if val_set is not None:
val_set = DatasetForCRLI(val_set, return_labels=False, file_type=file_type)
val_loader = DataLoader(
val_set,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
)

# Step 2: train the model and freeze it
self._train_model(training_loader)
self._train_model(training_loader, val_loader)
self.model.load_state_dict(self.best_model_dict)
self.model.eval() # set the model as eval status to freeze it.

Expand All @@ -342,9 +380,9 @@ def predict(
with torch.no_grad():
for idx, data in enumerate(test_loader):
inputs = self._assemble_input_for_testing(data)
inputs = self.model.forward(inputs, training=False)
inputs = self.model.forward(inputs, return_loss=False)
clustering_latent_collector.append(inputs["fcn_latent"])
imputation_collector.append(inputs["imputation"])
imputation_collector.append(inputs["imputation_latent"])

imputation = torch.cat(imputation_collector).cpu().detach().numpy()
clustering_latent = (
Expand All @@ -353,7 +391,7 @@ def predict(
clustering = self.model.kmeans.fit_predict(clustering_latent)
latent_collector = {
"clustering_latent": clustering_latent,
"imputation": imputation,
"imputation_latent": imputation,
}

result_dict = {
Expand Down
44 changes: 17 additions & 27 deletions pypots/clustering/crli/modules/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,46 +46,36 @@ def __init__(
n_clusters=n_clusters,
n_init=10, # FutureWarning: The default value of `n_init` will change from 10 to 'auto' in 1.4. Set the
# value of `n_init` explicitly to suppress the warning.
) # TODO: implement KMean with torch for gpu acceleration
)

self.n_clusters = n_clusters
self.lambda_kmeans = lambda_kmeans
self.device = device

def cluster(self, inputs: dict, training_object: str = "generator") -> dict:
# concatenate the final hidden states from the generator and feed them in as the initial state of the decoder
imputation, imputed_X, generator_fb_hidden_states = self.generator(inputs)
inputs["imputation"] = imputation
inputs["imputed_X"] = imputed_X
inputs["generator_fb_hidden_states"] = generator_fb_hidden_states
if training_object == "discriminator":
discrimination = self.discriminator(inputs)
inputs["discrimination"] = discrimination
return inputs # if only train discriminator, then no need to run decoder

reconstruction, fcn_latent = self.decoder(inputs)
inputs["reconstruction"] = reconstruction
inputs["fcn_latent"] = fcn_latent
return inputs

def forward(
self,
inputs: dict,
training_object: str = "generator",
training: bool = True,
return_loss: bool = True,
) -> dict:
assert training_object in [
"generator",
"discriminator",
], 'training_object should be "generator" or "discriminator"'

X = inputs["X"]
missing_mask = inputs["missing_mask"]
batch_size, n_steps, n_features = X.shape
losses = {}
inputs = self.cluster(inputs, training_object)
if not training:
# if only run clustering, then no need to calculate loss

# concatenate the final hidden states from the generator and feed them in as the initial state of the decoder
imputation_latent, generator_fb_hidden_states = self.generator(inputs)
inputs["imputation_latent"] = imputation_latent
inputs["generator_fb_hidden_states"] = generator_fb_hidden_states
discrimination = self.discriminator(inputs)
inputs["discrimination"] = discrimination

reconstruction, fcn_latent = self.decoder(inputs)
inputs["reconstruction"] = reconstruction
inputs["fcn_latent"] = fcn_latent

# return results directly, skip loss calculation to reduce inference time
if not return_loss:
return inputs

if training_object == "discriminator":
Expand All @@ -98,7 +88,7 @@ def forward(
l_G = F.binary_cross_entropy_with_logits(
inputs["discrimination"], 1 - missing_mask, weight=1 - missing_mask
)
l_pre = cal_mse(inputs["imputation"], X, missing_mask)
l_pre = cal_mse(inputs["imputation_latent"], X, missing_mask)
l_rec = cal_mse(inputs["reconstruction"], X, missing_mask)
HTH = torch.matmul(inputs["fcn_latent"], inputs["fcn_latent"].permute(1, 0))
term_F = torch.nn.init.orthogonal_(
Expand Down
14 changes: 7 additions & 7 deletions pypots/clustering/crli/modules/submodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,18 +124,15 @@ def __init__(
self.f_rnn = MultiRNNCell(cell_type, n_layers, n_features, d_hidden, device)
self.b_rnn = MultiRNNCell(cell_type, n_layers, n_features, d_hidden, device)

def forward(self, inputs: dict) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def forward(self, inputs: dict) -> Tuple[torch.Tensor, torch.Tensor]:
f_outputs, f_final_hidden_state = self.f_rnn(inputs)
b_outputs, b_final_hidden_state = self.b_rnn(inputs)
b_outputs = reverse_tensor(b_outputs) # reverse the output of the backward rnn
imputation = (f_outputs + b_outputs) / 2
imputed_X = inputs["X"] * inputs["missing_mask"] + imputation * (
1 - inputs["missing_mask"]
)
imputation_latent = (f_outputs + b_outputs) / 2
fb_final_hidden_states = torch.concat(
[f_final_hidden_state, b_final_hidden_state], dim=-1
)
return imputation, imputed_X, fb_final_hidden_states
return imputation_latent, fb_final_hidden_states


class Discriminator(nn.Module):
Expand All @@ -161,7 +158,10 @@ def __init__(
self.output_layer = nn.Linear(32, d_input)

def forward(self, inputs: dict) -> torch.Tensor:
imputed_X = inputs["imputed_X"]
imputed_X = (inputs["X"] * inputs["missing_mask"]) + (
inputs["imputation_latent"] * (1 - inputs["missing_mask"])
)

bz, n_steps, _ = imputed_X.shape
hidden_states = [
torch.zeros((bz, 32), device=self.device),
Expand Down
13 changes: 12 additions & 1 deletion pypots/clustering/vader/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ def _train_model(
def fit(
self,
train_set: Union[dict, str],
val_set: Optional[Union[dict, str]] = None,
file_type: str = "h5py",
) -> None:
# Step 1: wrap the input data with classes Dataset and DataLoader
Expand All @@ -353,8 +354,18 @@ def fit(
num_workers=self.num_workers,
)

val_loader = None
if val_set is not None:
val_set = DatasetForVaDER(val_set, return_labels=False, file_type=file_type)
val_loader = DataLoader(
val_set,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
)

# Step 2: train the model and freeze it
self._train_model(training_loader)
self._train_model(training_loader, val_loader)
self.model.load_state_dict(self.best_model_dict)
self.model.eval() # set the model as eval status to freeze it.

Expand Down
3 changes: 1 addition & 2 deletions pypots/clustering/vader/modules/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,6 @@ def forward(
mu_tilde,
stddev_tilde,
) = self.get_results(X, missing_mask)
imputed_X = X_reconstructed * (1 - missing_mask) + X * missing_mask

if not training and not pretrain:
results = {
Expand All @@ -182,7 +181,7 @@ def forward(
"var": var_c,
"phi": phi_c,
"z": z,
"imputed_X": imputed_X,
"imputation_latent": X_reconstructed,
}
# if only run clustering, then no need to calculate loss
return results
Expand Down
5 changes: 3 additions & 2 deletions tests/clustering/crli.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from tests.clustering.config import (
EPOCHS,
TRAIN_SET,
VAL_SET,
TEST_SET,
RESULT_SAVING_DIR_FOR_CLUSTERING,
)
Expand Down Expand Up @@ -74,9 +75,9 @@ class TestCRLI(unittest.TestCase):
@pytest.mark.xdist_group(name="clustering-crli")
def test_0_fit(self):
logger.info("Training CRLI-GRU...")
self.crli_gru.fit(TRAIN_SET)
self.crli_gru.fit(TRAIN_SET, VAL_SET)
logger.info("Training CRLI-LSTM...")
self.crli_lstm.fit(TRAIN_SET)
self.crli_lstm.fit(TRAIN_SET, VAL_SET)

@pytest.mark.xdist_group(name="clustering-crli")
def test_1_parameters(self):
Expand Down
3 changes: 2 additions & 1 deletion tests/clustering/vader.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from tests.clustering.config import (
EPOCHS,
TRAIN_SET,
VAL_SET,
TEST_SET,
RESULT_SAVING_DIR_FOR_CLUSTERING,
)
Expand Down Expand Up @@ -58,7 +59,7 @@ class TestVaDER(unittest.TestCase):

@pytest.mark.xdist_group(name="clustering-vader")
def test_0_fit(self):
self.vader.fit(TRAIN_SET)
self.vader.fit(TRAIN_SET, VAL_SET)

@pytest.mark.xdist_group(name="clustering-vader")
def test_1_cluster(self):
Expand Down

0 comments on commit 2decac5

Please sign in to comment.