Make clustering algorithms select the best model according to the loss on a given validation set #204

Merged: 2 commits, Oct 9, 2023

Changes from all commits
58 changes: 48 additions & 10 deletions pypots/clustering/crli/model.py
@@ -254,14 +254,42 @@ def _train_model(
self._save_log_into_tb_file(
training_step, "training", loss_results
)

mean_epoch_train_D_loss = np.mean(epoch_train_loss_D_collector)
mean_epoch_train_G_loss = np.mean(epoch_train_loss_G_collector)

if val_loader is not None:
self.model.eval()
epoch_val_loss_G_collector = []
with torch.no_grad():
for idx, data in enumerate(val_loader):
inputs = self._assemble_input_for_validating(data)
results = self.model.forward(inputs, return_loss=True)
epoch_val_loss_G_collector.append(
results["generation_loss"].sum().item()
)
mean_val_G_loss = np.mean(epoch_val_loss_G_collector)
                        # save validation loss logs into the tensorboard file for every epoch if needed
if self.summary_writer is not None:
val_loss_dict = {
"generation_loss": mean_val_G_loss,
}
self._save_log_into_tb_file(epoch, "validating", val_loss_dict)
logger.info(
f"epoch {epoch}: "
f"training loss_generator {mean_epoch_train_G_loss:.4f}, "
f"training loss_discriminator {mean_epoch_train_D_loss:.4f}, "
f"validating loss_generator {mean_val_G_loss:.4f}"
)
mean_loss = mean_val_G_loss
                else:
logger.info(
f"epoch {epoch}: "
f"training loss_generator {mean_epoch_train_G_loss:.4f}, "
f"training loss_discriminator {mean_epoch_train_D_loss:.4f}"
)
mean_loss = mean_epoch_train_G_loss

if mean_loss < self.best_loss:
self.best_loss = mean_loss
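The hunk above implements a standard best-model selection pattern: each epoch produces a selection loss, the validating generator loss when a validation set is available, otherwise the training generator loss, and the weights are snapshotted whenever that loss improves. A minimal sketch of the same pattern outside PyPOTS (`train_one_epoch` and `evaluate` are hypothetical callables, not part of the library):

```python
import copy

def train_with_selection(model, train_one_epoch, evaluate=None, epochs=10):
    # `train_one_epoch(model)` runs one optimization pass and returns the
    # mean training loss; `evaluate(model)` returns the mean loss on a
    # held-out validation set. Both are hypothetical callables.
    best_loss = float("inf")
    best_model_dict = None
    for epoch in range(epochs):
        train_loss = train_one_epoch(model)
        # Select on the validation loss when a validation set was given,
        # otherwise fall back to the training loss, as the diff above does.
        mean_loss = evaluate(model) if evaluate is not None else train_loss
        if mean_loss < best_loss:
            best_loss = mean_loss
            best_model_dict = copy.deepcopy(model.state_dict())
    # Restore the best snapshot before freezing the model for inference.
    model.load_state_dict(best_model_dict)
    return model
```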
@@ -314,8 +342,18 @@ def fit(
num_workers=self.num_workers,
)

val_loader = None
if val_set is not None:
val_set = DatasetForCRLI(val_set, return_labels=False, file_type=file_type)
val_loader = DataLoader(
val_set,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
)

# Step 2: train the model and freeze it
self._train_model(training_loader, val_loader)
self.model.load_state_dict(self.best_model_dict)
self.model.eval() # set the model as eval status to freeze it.

@@ -342,9 +380,9 @@ def predict(
with torch.no_grad():
for idx, data in enumerate(test_loader):
inputs = self._assemble_input_for_testing(data)
inputs = self.model.forward(inputs, return_loss=False)
clustering_latent_collector.append(inputs["fcn_latent"])
imputation_collector.append(inputs["imputation"])
imputation_collector.append(inputs["imputation_latent"])

imputation = torch.cat(imputation_collector).cpu().detach().numpy()
clustering_latent = (
@@ -353,7 +391,7 @@
clustering = self.model.kmeans.fit_predict(clustering_latent)
latent_collector = {
"clustering_latent": clustering_latent,
"imputation": imputation,
"imputation_latent": imputation,
}

result_dict = {
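With this change, CRLI can select its best weights on held-out data instead of the training loss alone. A usage sketch with toy data, assuming `CRLI.fit` gained the same optional `val_set` parameter that the VaDER diff below shows explicitly; hyperparameter values are illustrative only:

```python
import numpy as np
from pypots.clustering import CRLI

# Toy data: 200 samples of 24 steps x 5 features, with artificial missingness.
X = np.random.randn(200, 24, 5)
X[X < -1.5] = np.nan

train_set = {"X": X[:160]}
val_set = {"X": X[160:]}  # held out purely to pick the best epoch

model = CRLI(
    n_steps=24,
    n_features=5,
    n_clusters=3,
    n_generator_layers=2,
    rnn_hidden_size=64,
    epochs=5,
)
model.fit(train_set, val_set)  # best weights chosen by validating generation loss
results = model.predict(train_set)  # returns clustering results and latents
```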
44 changes: 17 additions & 27 deletions pypots/clustering/crli/modules/core.py
@@ -46,46 +46,36 @@ def __init__(
n_clusters=n_clusters,
n_init=10, # FutureWarning: The default value of `n_init` will change from 10 to 'auto' in 1.4. Set the
# value of `n_init` explicitly to suppress the warning.
)

self.n_clusters = n_clusters
self.lambda_kmeans = lambda_kmeans
self.device = device


def forward(
self,
inputs: dict,
training_object: str = "generator",
return_loss: bool = True,
) -> dict:
assert training_object in [
"generator",
"discriminator",
], 'training_object should be "generator" or "discriminator"'

X = inputs["X"]
missing_mask = inputs["missing_mask"]
batch_size, n_steps, n_features = X.shape
losses = {}

# concat final states from generator and input it as the initial state of decoder
imputation_latent, generator_fb_hidden_states = self.generator(inputs)
inputs["imputation_latent"] = imputation_latent
inputs["generator_fb_hidden_states"] = generator_fb_hidden_states
discrimination = self.discriminator(inputs)
inputs["discrimination"] = discrimination

reconstruction, fcn_latent = self.decoder(inputs)
inputs["reconstruction"] = reconstruction
inputs["fcn_latent"] = fcn_latent

# return results directly, skip loss calculation to reduce inference time
if not return_loss:
return inputs

if training_object == "discriminator":
@@ -98,7 +88,7 @@ def forward(
l_G = F.binary_cross_entropy_with_logits(
inputs["discrimination"], 1 - missing_mask, weight=1 - missing_mask
)
l_pre = cal_mse(inputs["imputation"], X, missing_mask)
l_pre = cal_mse(inputs["imputation_latent"], X, missing_mask)
l_rec = cal_mse(inputs["reconstruction"], X, missing_mask)
HTH = torch.matmul(inputs["fcn_latent"], inputs["fcn_latent"].permute(1, 0))
term_F = torch.nn.init.orthogonal_(
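Both `l_pre` and `l_rec` are reconstruction errors restricted to observed values: `missing_mask` zeroes out positions that were never observed, so the generator is only penalized where ground truth exists. A sketch of what `cal_mse` plausibly computes (the real implementation lives in PyPOTS's metrics utilities; the epsilon guard here is an assumption):

```python
import torch

def masked_mse(predictions: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Only entries with mask == 1 (observed values) contribute to the error.
    return torch.sum(((predictions - targets) * mask) ** 2) / (torch.sum(mask) + 1e-12)
```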
14 changes: 7 additions & 7 deletions pypots/clustering/crli/modules/submodules.py
@@ -124,18 +124,15 @@ def __init__(
self.f_rnn = MultiRNNCell(cell_type, n_layers, n_features, d_hidden, device)
self.b_rnn = MultiRNNCell(cell_type, n_layers, n_features, d_hidden, device)

def forward(self, inputs: dict) -> Tuple[torch.Tensor, torch.Tensor]:
f_outputs, f_final_hidden_state = self.f_rnn(inputs)
b_outputs, b_final_hidden_state = self.b_rnn(inputs)
b_outputs = reverse_tensor(b_outputs) # reverse the output of the backward rnn
imputation_latent = (f_outputs + b_outputs) / 2
fb_final_hidden_states = torch.concat(
[f_final_hidden_state, b_final_hidden_state], dim=-1
)
return imputation_latent, fb_final_hidden_states


class Discriminator(nn.Module):
@@ -161,7 +158,10 @@ def __init__(
self.output_layer = nn.Linear(32, d_input)

def forward(self, inputs: dict) -> torch.Tensor:
imputed_X = inputs["imputed_X"]
imputed_X = (inputs["X"] * inputs["missing_mask"]) + (
inputs["imputation_latent"] * (1 - inputs["missing_mask"])
)

bz, n_steps, _ = imputed_X.shape
hidden_states = [
torch.zeros((bz, 32), device=self.device),
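The discriminator now rebuilds `imputed_X` on the fly instead of reading a tensor precomputed by the generator: observed entries are taken straight from `X`, and only the missing positions are filled from `imputation_latent`. The identity in isolation, on toy tensors:

```python
import torch

X = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
missing_mask = torch.tensor([[1.0, 0.0], [0.0, 1.0]])  # 1 = observed, 0 = missing
imputation_latent = torch.full_like(X, 9.0)  # stand-in for the generator's estimates

# Keep observed values, take the generator's estimates everywhere else.
imputed_X = X * missing_mask + imputation_latent * (1 - missing_mask)
print(imputed_X)  # tensor([[1., 9.], [9., 4.]])
```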
13 changes: 12 additions & 1 deletion pypots/clustering/vader/model.py
@@ -340,6 +340,7 @@ def _train_model(
def fit(
self,
train_set: Union[dict, str],
val_set: Optional[Union[dict, str]] = None,
file_type: str = "h5py",
) -> None:
# Step 1: wrap the input data with classes Dataset and DataLoader
@@ -353,8 +354,18 @@
num_workers=self.num_workers,
)

val_loader = None
if val_set is not None:
val_set = DatasetForVaDER(val_set, return_labels=False, file_type=file_type)
val_loader = DataLoader(
val_set,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
)

# Step 2: train the model and freeze it
self._train_model(training_loader, val_loader)
self.model.load_state_dict(self.best_model_dict)
self.model.eval() # set the model as eval status to freeze it.

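VaDER gets the same treatment: `fit` grows an optional `val_set`, which is wrapped in `DatasetForVaDER` and passed through to `_train_model` for best-epoch selection. A usage sketch mirroring the CRLI one above; constructor arguments are assumed to match the released VaDER API and the values are illustrative:

```python
import numpy as np
from pypots.clustering import VaDER

X = np.random.randn(200, 24, 5)
X[X < -1.5] = np.nan  # artificial missingness

vader = VaDER(
    n_steps=24,
    n_features=5,
    n_clusters=3,
    rnn_hidden_size=64,
    d_mu_stddev=2,
    epochs=5,
)
vader.fit({"X": X[:160]}, {"X": X[160:]})  # val_set drives best-epoch choice
```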
3 changes: 1 addition & 2 deletions pypots/clustering/vader/modules/core.py
@@ -172,7 +172,6 @@ def forward(
mu_tilde,
stddev_tilde,
) = self.get_results(X, missing_mask)

if not training and not pretrain:
results = {
Expand All @@ -182,7 +181,7 @@ def forward(
"var": var_c,
"phi": phi_c,
"z": z,
"imputed_X": imputed_X,
"imputation_latent": X_reconstructed,
}
# if only run clustering, then no need to calculate loss
return results