From 2a36b2b1af5227424a87931d0f7e77cac8984f22 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Tue, 10 Oct 2023 00:11:19 +0800 Subject: [PATCH 1/2] refactor: separate functions in pypots.utils.metrics according to their tasks; --- pypots/utils/metrics.py | 765 ------------------------- pypots/utils/metrics/__init__.py | 50 ++ pypots/utils/metrics/classification.py | 256 +++++++++ pypots/utils/metrics/clustering.py | 297 ++++++++++ pypots/utils/metrics/error.py | 231 ++++++++ 5 files changed, 834 insertions(+), 765 deletions(-) delete mode 100644 pypots/utils/metrics.py create mode 100644 pypots/utils/metrics/__init__.py create mode 100644 pypots/utils/metrics/classification.py create mode 100644 pypots/utils/metrics/clustering.py create mode 100644 pypots/utils/metrics/error.py diff --git a/pypots/utils/metrics.py b/pypots/utils/metrics.py deleted file mode 100644 index b6dbd5ae..00000000 --- a/pypots/utils/metrics.py +++ /dev/null @@ -1,765 +0,0 @@ -""" -Utilities for evaluation metrics -""" - -# Created by Wenjie Du -# License: GPL-v3 - -from typing import Union, Optional, Tuple - -import numpy as np -import torch -from sklearn import metrics - - -def cal_mae( - predictions: Union[np.ndarray, torch.Tensor, list], - targets: Union[np.ndarray, torch.Tensor, list], - masks: Optional[Union[np.ndarray, torch.Tensor, list]] = None, -) -> Union[float, torch.Tensor]: - """Calculate the Mean Absolute Error between ``predictions`` and ``targets``. - ``masks`` can be used for filtering. For values==0 in ``masks``, - values at their corresponding positions in ``predictions`` will be ignored. - - Parameters - ---------- - predictions : - The prediction data to be evaluated. - - targets : - The target data for helping evaluate the predictions. - - masks : - The masks for filtering the specific values in inputs and target from evaluation. - When given, only values at corresponding positions where values ==1 in ``masks`` will be used for evaluation. - - Examples - -------- - - >>> import numpy as np - >>> from pypots.utils.metrics import cal_mae - >>> targets = np.array([1, 2, 3, 4, 5]) - >>> predictions = np.array([1, 2, 1, 4, 6]) - >>> mae = cal_mae(predictions, targets) - - mae = 0.6 here, the error is from the 3rd and 5th elements and is :math:`|3-1|+|5-6|=3`, so the result is 3/5=0.6. - - If we want to prevent some values from MAE calculation, e.g. the first three elements here, - we can use ``masks`` to filter out them: - - >>> masks = np.array([0, 0, 0, 1, 1]) - >>> mae = cal_mae(predictions, targets, masks) - - mae = 0.5 here, the first three elements are ignored, the error is from the 5th element and is :math:`|5-6|=1`, - so the result is 1/2=0.5. - - """ - assert isinstance(predictions, type(targets)), ( - f"types of inputs and target must match, but got" - f"type(inputs)={type(predictions)}, type(target)={type(targets)}" - ) - lib = np if isinstance(predictions, np.ndarray) else torch - if masks is not None: - return lib.sum(lib.abs(predictions - targets) * masks) / ( - lib.sum(masks) + 1e-12 - ) - else: - return lib.mean(lib.abs(predictions - targets)) - - -def cal_mse( - predictions: Union[np.ndarray, torch.Tensor, list], - targets: Union[np.ndarray, torch.Tensor, list], - masks: Optional[Union[np.ndarray, torch.Tensor, list]] = None, -) -> Union[float, torch.Tensor]: - """Calculate the Mean Square Error between ``predictions`` and ``targets``. - ``masks`` can be used for filtering. For values==0 in ``masks``, - values at their corresponding positions in ``predictions`` will be ignored. 
- - Parameters - ---------- - predictions : - The prediction data to be evaluated. - - targets : - The target data for helping evaluate the predictions. - - masks : - The masks for filtering the specific values in inputs and target from evaluation. - When given, only values at corresponding positions where values ==1 in ``masks`` will be used for evaluation. - - Examples - -------- - - >>> import numpy as np - >>> from pypots.utils.metrics import cal_mse - >>> targets = np.array([1, 2, 3, 4, 5]) - >>> predictions = np.array([1, 2, 1, 4, 6]) - >>> mse = cal_mse(predictions, targets) - - mse = 1 here, the error is from the 3rd and 5th elements and is :math:`|3-1|^2+|5-6|^2=5`, so the result is 5/5=1. - - If we want to prevent some values from MSE calculation, e.g. the first three elements here, - we can use ``masks`` to filter out them: - - >>> masks = np.array([0, 0, 0, 1, 1]) - >>> mse = cal_mse(predictions, targets, masks) - - mse = 0.5 here, the first three elements are ignored, the error is from the 5th element and is :math:`|5-6|^2=1`, - so the result is 1/2=0.5. - - """ - - assert isinstance(predictions, type(targets)), ( - f"types of inputs and target must match, but got" - f"type(inputs)={type(predictions)}, type(target)={type(targets)}" - ) - lib = np if isinstance(predictions, np.ndarray) else torch - if masks is not None: - return lib.sum(lib.square(predictions - targets) * masks) / ( - lib.sum(masks) + 1e-12 - ) - else: - return lib.mean(lib.square(predictions - targets)) - - -def cal_rmse( - predictions: Union[np.ndarray, torch.Tensor, list], - targets: Union[np.ndarray, torch.Tensor, list], - masks: Optional[Union[np.ndarray, torch.Tensor, list]] = None, -) -> Union[float, torch.Tensor]: - """Calculate the Root Mean Square Error between ``predictions`` and ``targets``. - ``masks`` can be used for filtering. For values==0 in ``masks``, - values at their corresponding positions in ``predictions`` will be ignored. - - Parameters - ---------- - predictions : - The prediction data to be evaluated. - - targets : - The target data for helping evaluate the predictions. - - masks : - The masks for filtering the specific values in inputs and target from evaluation. - When given, only values at corresponding positions where values ==1 in ``masks`` will be used for evaluation. - - Examples - -------- - - >>> import numpy as np - >>> from pypots.utils.metrics import cal_rmse - >>> targets = np.array([1, 2, 3, 4, 5]) - >>> predictions = np.array([1, 2, 1, 4, 6]) - >>> rmse = cal_rmse(predictions, targets) - - rmse = 1 here, the error is from the 3rd and 5th elements and is :math:`|3-1|^2+|5-6|^2=5`, - so the result is :math:`\\sqrt{5/5}=1`. - - If we want to prevent some values from RMSE calculation, e.g. the first three elements here, - we can use ``masks`` to filter out them: - - >>> masks = np.array([0, 0, 0, 1, 1]) - >>> rmse = cal_rmse(predictions, targets, masks) - - rmse = 0.707 here, the first three elements are ignored, the error is from the 5th element and is :math:`|5-6|^2=1`, - so the result is :math:`\\sqrt{1/2}=0.5`. 
- - """ - assert isinstance(predictions, type(targets)), ( - f"types of inputs and target must match, but got" - f"type(inputs)={type(predictions)}, type(target)={type(targets)}" - ) - lib = np if isinstance(predictions, np.ndarray) else torch - return lib.sqrt(cal_mse(predictions, targets, masks)) - - -def cal_mre( - predictions: Union[np.ndarray, torch.Tensor, list], - targets: Union[np.ndarray, torch.Tensor, list], - masks: Optional[Union[np.ndarray, torch.Tensor, list]] = None, -) -> Union[float, torch.Tensor]: - """Calculate the Mean Relative Error between ``predictions`` and ``targets``. - ``masks`` can be used for filtering. For values==0 in ``masks``, - values at their corresponding positions in ``predictions`` will be ignored. - - Parameters - ---------- - predictions : - The prediction data to be evaluated. - - targets : - The target data for helping evaluate the predictions. - - masks : - The masks for filtering the specific values in inputs and target from evaluation. - When given, only values at corresponding positions where values ==1 in ``masks`` will be used for evaluation. - - Examples - -------- - - >>> import numpy as np - >>> from pypots.utils.metrics import cal_mre - >>> targets = np.array([1, 2, 3, 4, 5]) - >>> predictions = np.array([1, 2, 1, 4, 6]) - >>> mre = cal_mre(predictions, targets) - - mre = 0.2 here, the error is from the 3rd and 5th elements and is :math:`|3-1|+|5-6|=3`, - so the result is :math:`\\sqrt{3/(1+2+3+4+5)}=1`. - - If we want to prevent some values from MRE calculation, e.g. the first three elements here, - we can use ``masks`` to filter out them: - - >>> masks = np.array([0, 0, 0, 1, 1]) - >>> mre = cal_mre(predictions, targets, masks) - - mre = 0.111 here, the first three elements are ignored, the error is from the 5th element and is :math:`|5-6|^2=1`, - so the result is :math:`\\sqrt{1/2}=0.5`. - - """ - assert isinstance(predictions, type(targets)), ( - f"types of inputs and target must match, but got" - f"type(inputs)={type(predictions)}, type(target)={type(targets)}" - ) - lib = np if isinstance(predictions, np.ndarray) else torch - if masks is not None: - return lib.sum(lib.abs(predictions - targets) * masks) / ( - lib.sum(lib.abs(targets * masks)) + 1e-12 - ) - else: - return lib.sum(lib.abs(predictions - targets)) / ( - lib.sum(lib.abs(targets)) + 1e-12 - ) - - -def cal_binary_classification_metrics( - prob_predictions: np.ndarray, - targets: np.ndarray, - pos_label: int = 1, -) -> dict: - """Calculate the evaluation metrics for the binary classification task, - including accuracy, precision, recall, f1 score, area under ROC curve, and area under Precision-Recall curve. - If targets contains multiple categories, please set the positive category as `pos_label`. - - Parameters - ---------- - prob_predictions : - Estimated probability predictions returned by a decision function. - - targets : - Ground truth (correct) classification results. - - pos_label : - The label of the positive class. - Note that pos_label is also the index used to extract binary prediction probabilities from `predictions`. 
- - Returns - ------- - classification_metrics : - A dictionary contains classification metrics and useful results: - - predictions: binary categories of the prediction results; - - accuracy: prediction accuracy; - - precision: prediction precision; - - recall: prediction recall; - - f1: F1-score; - - precisions: precision values of Precision-Recall curve - - recalls: recall values of Precision-Recall curve - - pr_auc: area under Precision-Recall curve - - fprs: false positive rates of ROC curve - - tprs: true positive rates of ROC curve - - roc_auc: area under ROC curve - - """ - # check the dimensionality - if len(targets.shape) == 1: - pass - elif len(targets.shape) == 2 and targets.shape[1] == 1: - targets = np.asarray(targets).flatten() - else: - raise f"targets dimensions should be 1 or 2, but got targets.shape: {targets.shape}" - - if len(prob_predictions.shape) == 1 or ( - len(prob_predictions.shape) == 2 and prob_predictions.shape[1] == 1 - ): - prob_predictions = np.asarray( - prob_predictions - ).flatten() # turn the array shape into [n_samples] - binary_predictions = prob_predictions - prediction_categories = (prob_predictions >= 0.5).astype(int) - binary_prediction_categories = prediction_categories - elif len(prob_predictions.shape) == 2 and prob_predictions.shape[1] > 1: - prediction_categories = np.argmax(prob_predictions, axis=1) - binary_predictions = prob_predictions[:, pos_label] - binary_prediction_categories = (prediction_categories == pos_label).astype(int) - else: - raise f"predictions dimensions should be 1 or 2, but got predictions.shape: {prob_predictions.shape}" - - # accuracy score doesn't have to be of binary classification - acc_score = cal_acc(prediction_categories, targets) - - # turn targets into binary targets - mask_val = -1 if pos_label == 0 else 0 - mask = targets == pos_label - binary_targets = np.copy(targets) - binary_targets[~mask] = mask_val - - precision, recall, f1 = cal_precision_recall_f1( - binary_prediction_categories, binary_targets, pos_label - ) - pr_auc, precisions, recalls, _ = cal_pr_auc( - binary_predictions, binary_targets, pos_label - ) - ROC_AUC, fprs, tprs, _ = cal_roc_auc(binary_predictions, binary_targets, pos_label) - PR_AUC = metrics.auc(recalls, precisions) - classification_metrics = { - "predictions": prediction_categories, - "accuracy": acc_score, - "precision": precision, - "recall": recall, - "f1": f1, - "precisions": precisions, - "recalls": recalls, - "pr_auc": PR_AUC, - "fprs": fprs, - "tprs": tprs, - "roc_auc": ROC_AUC, - } - return classification_metrics - - -def cal_precision_recall_f1( - prob_predictions: np.ndarray, - targets: np.ndarray, - pos_label: int = 1, -) -> Tuple[float, float, float]: - """Calculate precision, recall, and F1-score of model predictions. - - Parameters - ---------- - prob_predictions : - Estimated probability predictions returned by a decision function. - - targets : - Ground truth (correct) classification results. - - pos_label: int, default=1 - The label of the positive class. - - Returns - ------- - precision : - The precision value of model predictions. - - recall : - The recall value of model predictions. - - f1 : - The F1 score of model predictions. 
- - """ - precision, recall, f1, _ = metrics.precision_recall_fscore_support( - targets, prob_predictions, pos_label=pos_label - ) - precision, recall, f1 = precision[pos_label], recall[pos_label], f1[pos_label] - return precision, recall, f1 - - -def cal_pr_auc( - prob_predictions: np.ndarray, - targets: np.ndarray, - pos_label: int = 1, -) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray]: - """Calculate precisions, recalls, and area under PR curve of model predictions. - - Parameters - ---------- - prob_predictions : - Estimated probability predictions returned by a decision function. - - targets : - Ground truth (correct) classification results. - - pos_label: int, default=1 - The label of the positive class. - - Returns - ------- - pr_auc : - Value of area under Precision-Recall curve. - - precisions : - Precision values of Precision-Recall curve. - - recalls : - Recall values of Precision-Recall curve. - - thresholds : - Increasing thresholds on the decision function used to compute precision and recall. - - """ - - precisions, recalls, thresholds = metrics.precision_recall_curve( - targets, prob_predictions, pos_label=pos_label - ) - pr_auc = metrics.auc(recalls, precisions) - return pr_auc, precisions, recalls, thresholds - - -def cal_roc_auc( - prob_predictions: np.ndarray, - targets: np.ndarray, - pos_label: int = 1, -) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray]: - """Calculate false positive rates, true positive rates, and area under AUC curve of model predictions. - - Parameters - ---------- - prob_predictions : - Estimated probabilities/predictions returned by a decision function. - - targets : - Ground truth (correct) classification results. - - pos_label: int, default=1 - The label of the positive class. - - Returns - ------- - roc_auc : - The area under ROC curve. - - fprs : - False positive rates of ROC curve. - - tprs : - True positive rates of ROC curve. - - thresholds : - Increasing thresholds on the decision function used to compute FPR and TPR. - - """ - fprs, tprs, thresholds = metrics.roc_curve( - y_true=targets, y_score=prob_predictions, pos_label=pos_label - ) - roc_auc = metrics.auc(fprs, tprs) - return roc_auc, fprs, tprs, thresholds - - -def cal_acc(class_predictions: np.ndarray, targets: np.ndarray) -> float: - """Calculate accuracy score of model predictions. - - Parameters - ---------- - class_predictions : - Estimated classification predictions returned by a classifier. - - targets : - Ground truth (correct) classification results. - - Returns - ------- - acc_score : - The accuracy of model predictions. - - """ - acc_score = metrics.accuracy_score(targets, class_predictions) - return acc_score - - -def cal_rand_index( - class_predictions: np.ndarray, - targets: np.ndarray, -) -> float: - """Calculate Rand Index, a measure of the similarity between two data clusterings. - Refer to :cite:`rand1971RandIndex`. - - Parameters - ---------- - class_predictions : - Clustering results returned by a clusterer. - - targets : - Ground truth (correct) clustering results. - - Returns - ------- - RI : - Rand index. - - References - ---------- - .. L. Hubert and P. Arabie, Comparing Partitions, Journal of - Classification 1985 - https://link.springer.com/article/10.1007%2FBF01908075 - - .. https://en.wikipedia.org/wiki/Simple_matching_coefficient - - .. 
https://en.wikipedia.org/wiki/Rand_index - """ - # # detailed implementation - # n = len(targets) - # TP = 0 - # TN = 0 - # for i in range(n - 1): - # for j in range(i + 1, n): - # if targets[i] != targets[j]: - # if class_predictions[i] != class_predictions[j]: - # TN += 1 - # else: - # if class_predictions[i] == class_predictions[j]: - # TP += 1 - # - # RI = n * (n - 1) / 2 - # RI = (TP + TN) / RI - - RI = metrics.rand_score(targets, class_predictions) - - return RI - - -def cal_adjusted_rand_index( - class_predictions: np.ndarray, - targets: np.ndarray, -) -> float: - """Calculate adjusted Rand Index. - - Parameters - ---------- - class_predictions : - Clustering results returned by a clusterer. - - targets : - Ground truth (correct) clustering results. - - Returns - ------- - aRI : - Adjusted Rand index. - - References - ---------- - .. [1] `L. Hubert and P. Arabie, Comparing Partitions, - Journal of Classification 1985 - `_ - - .. [2] `D. Steinley, Properties of the Hubert-Arabie - adjusted Rand index, Psychological Methods 2004 - `_ - - .. [3] https://en.wikipedia.org/wiki/Rand_index#Adjusted_Rand_index - - """ - aRI = metrics.adjusted_rand_score(targets, class_predictions) - return aRI - - -def cal_nmi( - class_predictions: np.ndarray, - targets: np.ndarray, -) -> float: - """Calculate Normalized Mutual Information between two clusterings. - - Parameters - ---------- - class_predictions : - Clustering results returned by a clusterer. - - targets : - Ground truth (correct) clustering results. - - Returns - ------- - NMI : float, - Normalized Mutual Information - - - """ - NMI = metrics.normalized_mutual_info_score(targets, class_predictions) - return NMI - - -def cal_cluster_purity( - class_predictions: np.ndarray, - targets: np.ndarray, -) -> float: - """Calculate cluster purity. - - Parameters - ---------- - class_predictions : - Clustering results returned by a clusterer. - - targets : - Ground truth (correct) clustering results. - - Returns - ------- - cluster_purity : - cluster purity. - - Notes - ----- - This function is from the answer https://stackoverflow.com/a/51672699 on StackOverflow. - - """ - contingency_matrix = metrics.cluster.contingency_matrix(targets, class_predictions) - cluster_purity = np.sum(np.amax(contingency_matrix, axis=0)) / np.sum( - contingency_matrix - ) - return cluster_purity - - -def cal_external_cluster_validation_metrics(class_predictions, targets): - """Computer all external cluster validation metrics available in PyPOTS and return as a dictionary. - - Parameters - ---------- - class_predictions : - Clustering results returned by a clusterer. - - targets : - Ground truth (correct) clustering results. - - Returns - ------- - external_cluster_validation_metrics : dict - A dictionary contains all external cluster validation metrics available in PyPOTS. - """ - - ri = cal_rand_index(class_predictions, targets) - ari = cal_adjusted_rand_index(class_predictions, targets) - nmi = cal_nmi(class_predictions, targets) - cp = cal_cluster_purity(class_predictions, targets) - - external_cluster_validation_metrics = { - "rand_index": ri, - "adjusted_rand_index": ari, - "nmi": nmi, - "cluster_purity": cp, - } - return external_cluster_validation_metrics - - -def cal_silhouette(X: np.ndarray, predicted_labels: np.ndarray) -> float: - """Compute the mean Silhouette Coefficient of all samples. - - Parameters - ---------- - X : array-like of shape (n_samples_a, n_features) - A feature array, or learned latent representation, that can be used for clustering. 
- - predicted_labels : array-like of shape (n_samples) - Predicted labels for each sample. - - Returns - ------- - silhouette_score : float - Mean Silhouette Coefficient for all samples. In short, the higher, the better. - - References - ---------- - .. [1] `Peter J. Rousseeuw (1987). "Silhouettes: a Graphical Aid to the - Interpretation and Validation of Cluster Analysis". Computational - and Applied Mathematics 20: 53-65. - `_ - - .. [2] `Wikipedia entry on the Silhouette Coefficient - `_ - - """ - silhouette_score = metrics.silhouette_score(X, predicted_labels) - return silhouette_score - - -def cal_chs(X: np.ndarray, predicted_labels: np.ndarray) -> float: - """Compute the Calinski and Harabasz score (also known as the Variance Ratio Criterion). - - X : array-like of shape (n_samples_a, n_features) - A feature array, or learned latent representation, that can be used for clustering. - - predicted_labels : array-like of shape (n_samples) - Predicted labels for each sample. - - Returns - ------- - calinski_harabasz_score : float - The resulting Calinski-Harabasz score. In short, the higher, the better. - - References - ---------- - .. [1] `T. Calinski and J. Harabasz, 1974. "A dendrite method for cluster - analysis". Communications in Statistics - `_ - - """ - calinski_harabasz_score = metrics.calinski_harabasz_score(X, predicted_labels) - return calinski_harabasz_score - - -def cal_dbs(X: np.ndarray, predicted_labels: np.ndarray) -> float: - """Compute the Davies-Bouldin score. - - Parameters - ---------- - X : array-like of shape (n_samples_a, n_features) - A feature array, or learned latent representation, that can be used for clustering. - - predicted_labels : array-like of shape (n_samples) - Predicted labels for each sample. - - Returns - ------- - davies_bouldin_score : float - The resulting Davies-Bouldin score. In short, the lower, the better. - - References - ---------- - .. [1] `Davies, David L.; Bouldin, Donald W. (1979). - "A Cluster Separation Measure" - IEEE Transactions on Pattern Analysis and Machine Intelligence. - PAMI-1 (2): 224-227 - `_ - - """ - davies_bouldin_score = metrics.davies_bouldin_score(X, predicted_labels) - return davies_bouldin_score - - -def cal_internal_cluster_validation_metrics(X, predicted_labels): - """Computer all internal cluster validation metrics available in PyPOTS and return as a dictionary. - - Parameters - ---------- - X : array-like of shape (n_samples_a, n_features) - A feature array, or learned latent representation, that can be used for clustering. - - predicted_labels : array-like of shape (n_samples) - Predicted labels for each sample. - - Returns - ------- - internal_cluster_validation_metrics : dict - A dictionary contains all internal cluster validation metrics available in PyPOTS. 
- """ - - silhouette_score = cal_silhouette(X, predicted_labels) - calinski_harabasz_score = cal_chs(X, predicted_labels) - davies_bouldin_score = cal_dbs(X, predicted_labels) - - internal_cluster_validation_metrics = { - "silhouette_score": silhouette_score, - "calinski_harabasz_score": calinski_harabasz_score, - "davies_bouldin_score": davies_bouldin_score, - } - return internal_cluster_validation_metrics diff --git a/pypots/utils/metrics/__init__.py b/pypots/utils/metrics/__init__.py new file mode 100644 index 00000000..dd7f2cc6 --- /dev/null +++ b/pypots/utils/metrics/__init__.py @@ -0,0 +1,50 @@ +""" + +""" + +# Created by Wenjie Du +# License: GLP-v3 + +from .classification import ( + cal_binary_classification_metrics, + cal_precision_recall_f1, + cal_pr_auc, + cal_roc_auc, + cal_acc, +) +from .clustering import ( + cal_rand_index, + cal_adjusted_rand_index, + cal_cluster_purity, + cal_nmi, + cal_chs, + cal_dbs, + cal_silhouette, + cal_internal_cluster_validation_metrics, + cal_external_cluster_validation_metrics, +) +from .error import cal_mae, cal_mse, cal_rmse, cal_mre + +__all__ = [ + # error + "cal_mae", + "cal_mse", + "cal_rmse", + "cal_mre", + # classification + "cal_binary_classification_metrics", + "cal_precision_recall_f1", + "cal_pr_auc", + "cal_roc_auc", + "cal_acc", + # clustering + "cal_rand_index", + "cal_adjusted_rand_index", + "cal_cluster_purity", + "cal_nmi", + "cal_chs", + "cal_dbs", + "cal_silhouette", + "cal_internal_cluster_validation_metrics", + "cal_external_cluster_validation_metrics", +] diff --git a/pypots/utils/metrics/classification.py b/pypots/utils/metrics/classification.py new file mode 100644 index 00000000..7531e607 --- /dev/null +++ b/pypots/utils/metrics/classification.py @@ -0,0 +1,256 @@ +""" +Evaluation metrics related to classification. +""" + +# Created by Wenjie Du +# License: GPL-v3 + +from typing import Tuple + +import numpy as np +from sklearn import metrics + + +def cal_binary_classification_metrics( + prob_predictions: np.ndarray, + targets: np.ndarray, + pos_label: int = 1, +) -> dict: + """Calculate the evaluation metrics for the binary classification task, + including accuracy, precision, recall, f1 score, area under ROC curve, and area under Precision-Recall curve. + If targets contains multiple categories, please set the positive category as `pos_label`. + + Parameters + ---------- + prob_predictions : + Estimated probability predictions returned by a decision function. + + targets : + Ground truth (correct) classification results. + + pos_label : + The label of the positive class. + Note that pos_label is also the index used to extract binary prediction probabilities from `predictions`. 
+
+    Returns
+    -------
+    classification_metrics :
+        A dictionary containing classification metrics and useful results:
+
+        predictions: binary categories of the prediction results;
+
+        accuracy: prediction accuracy;
+
+        precision: prediction precision;
+
+        recall: prediction recall;
+
+        f1: F1-score;
+
+        precisions: precision values of Precision-Recall curve
+
+        recalls: recall values of Precision-Recall curve
+
+        pr_auc: area under Precision-Recall curve
+
+        fprs: false positive rates of ROC curve
+
+        tprs: true positive rates of ROC curve
+
+        roc_auc: area under ROC curve
+
+    """
+    # check the dimensionality
+    if len(targets.shape) == 1:
+        pass
+    elif len(targets.shape) == 2 and targets.shape[1] == 1:
+        targets = np.asarray(targets).flatten()
+    else:
+        raise ValueError(f"targets dimensions should be 1 or 2, but got targets.shape: {targets.shape}")
+
+    if len(prob_predictions.shape) == 1 or (
+        len(prob_predictions.shape) == 2 and prob_predictions.shape[1] == 1
+    ):
+        prob_predictions = np.asarray(
+            prob_predictions
+        ).flatten()  # turn the array shape into [n_samples]
+        binary_predictions = prob_predictions
+        prediction_categories = (prob_predictions >= 0.5).astype(int)
+        binary_prediction_categories = prediction_categories
+    elif len(prob_predictions.shape) == 2 and prob_predictions.shape[1] > 1:
+        prediction_categories = np.argmax(prob_predictions, axis=1)
+        binary_predictions = prob_predictions[:, pos_label]
+        binary_prediction_categories = (prediction_categories == pos_label).astype(int)
+    else:
+        raise ValueError(f"predictions dimensions should be 1 or 2, but got predictions.shape: {prob_predictions.shape}")
+
+    # accuracy score doesn't have to be of binary classification
+    acc_score = cal_acc(prediction_categories, targets)
+
+    # turn targets into binary targets
+    mask_val = -1 if pos_label == 0 else 0
+    mask = targets == pos_label
+    binary_targets = np.copy(targets)
+    binary_targets[~mask] = mask_val
+
+    precision, recall, f1 = cal_precision_recall_f1(
+        binary_prediction_categories, binary_targets, pos_label
+    )
+    pr_auc, precisions, recalls, _ = cal_pr_auc(
+        binary_predictions, binary_targets, pos_label
+    )
+    ROC_AUC, fprs, tprs, _ = cal_roc_auc(binary_predictions, binary_targets, pos_label)
+    PR_AUC = metrics.auc(recalls, precisions)
+    classification_metrics = {
+        "predictions": prediction_categories,
+        "accuracy": acc_score,
+        "precision": precision,
+        "recall": recall,
+        "f1": f1,
+        "precisions": precisions,
+        "recalls": recalls,
+        "pr_auc": PR_AUC,
+        "fprs": fprs,
+        "tprs": tprs,
+        "roc_auc": ROC_AUC,
+    }
+    return classification_metrics
+
+
+def cal_precision_recall_f1(
+    prob_predictions: np.ndarray,
+    targets: np.ndarray,
+    pos_label: int = 1,
+) -> Tuple[float, float, float]:
+    """Calculate precision, recall, and F1-score of model predictions.
+
+    Parameters
+    ----------
+    prob_predictions :
+        Estimated probability predictions returned by a decision function.
+
+    targets :
+        Ground truth (correct) classification results.
+
+    pos_label: int, default=1
+        The label of the positive class.
+
+    Returns
+    -------
+    precision :
+        The precision value of model predictions.
+
+    recall :
+        The recall value of model predictions.
+
+    f1 :
+        The F1 score of model predictions.
+ + """ + precision, recall, f1, _ = metrics.precision_recall_fscore_support( + targets, prob_predictions, pos_label=pos_label + ) + precision, recall, f1 = precision[pos_label], recall[pos_label], f1[pos_label] + return precision, recall, f1 + + +def cal_pr_auc( + prob_predictions: np.ndarray, + targets: np.ndarray, + pos_label: int = 1, +) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray]: + """Calculate precisions, recalls, and area under PR curve of model predictions. + + Parameters + ---------- + prob_predictions : + Estimated probability predictions returned by a decision function. + + targets : + Ground truth (correct) classification results. + + pos_label: int, default=1 + The label of the positive class. + + Returns + ------- + pr_auc : + Value of area under Precision-Recall curve. + + precisions : + Precision values of Precision-Recall curve. + + recalls : + Recall values of Precision-Recall curve. + + thresholds : + Increasing thresholds on the decision function used to compute precision and recall. + + """ + + precisions, recalls, thresholds = metrics.precision_recall_curve( + targets, prob_predictions, pos_label=pos_label + ) + pr_auc = metrics.auc(recalls, precisions) + return pr_auc, precisions, recalls, thresholds + + +def cal_roc_auc( + prob_predictions: np.ndarray, + targets: np.ndarray, + pos_label: int = 1, +) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray]: + """Calculate false positive rates, true positive rates, and area under AUC curve of model predictions. + + Parameters + ---------- + prob_predictions : + Estimated probabilities/predictions returned by a decision function. + + targets : + Ground truth (correct) classification results. + + pos_label: int, default=1 + The label of the positive class. + + Returns + ------- + roc_auc : + The area under ROC curve. + + fprs : + False positive rates of ROC curve. + + tprs : + True positive rates of ROC curve. + + thresholds : + Increasing thresholds on the decision function used to compute FPR and TPR. + + """ + fprs, tprs, thresholds = metrics.roc_curve( + y_true=targets, y_score=prob_predictions, pos_label=pos_label + ) + roc_auc = metrics.auc(fprs, tprs) + return roc_auc, fprs, tprs, thresholds + + +def cal_acc(class_predictions: np.ndarray, targets: np.ndarray) -> float: + """Calculate accuracy score of model predictions. + + Parameters + ---------- + class_predictions : + Estimated classification predictions returned by a classifier. + + targets : + Ground truth (correct) classification results. + + Returns + ------- + acc_score : + The accuracy of model predictions. + + """ + acc_score = metrics.accuracy_score(targets, class_predictions) + return acc_score diff --git a/pypots/utils/metrics/clustering.py b/pypots/utils/metrics/clustering.py new file mode 100644 index 00000000..e75e6ec5 --- /dev/null +++ b/pypots/utils/metrics/clustering.py @@ -0,0 +1,297 @@ +""" +Evaluation metrics related to clustering. +""" + +# Created by Wenjie Du +# License: GPL-v3 + +import numpy as np +from sklearn import metrics + + +def cal_rand_index( + class_predictions: np.ndarray, + targets: np.ndarray, +) -> float: + """Calculate Rand Index, a measure of the similarity between two data clusterings. + Refer to :cite:`rand1971RandIndex`. + + Parameters + ---------- + class_predictions : + Clustering results returned by a clusterer. + + targets : + Ground truth (correct) clustering results. + + Returns + ------- + RI : + Rand index. + + References + ---------- + .. L. Hubert and P. 
Arabie, Comparing Partitions, Journal of + Classification 1985 + https://link.springer.com/article/10.1007%2FBF01908075 + + .. https://en.wikipedia.org/wiki/Simple_matching_coefficient + + .. https://en.wikipedia.org/wiki/Rand_index + """ + # # detailed implementation + # n = len(targets) + # TP = 0 + # TN = 0 + # for i in range(n - 1): + # for j in range(i + 1, n): + # if targets[i] != targets[j]: + # if class_predictions[i] != class_predictions[j]: + # TN += 1 + # else: + # if class_predictions[i] == class_predictions[j]: + # TP += 1 + # + # RI = n * (n - 1) / 2 + # RI = (TP + TN) / RI + + RI = metrics.rand_score(targets, class_predictions) + + return RI + + +def cal_adjusted_rand_index( + class_predictions: np.ndarray, + targets: np.ndarray, +) -> float: + """Calculate adjusted Rand Index. + + Parameters + ---------- + class_predictions : + Clustering results returned by a clusterer. + + targets : + Ground truth (correct) clustering results. + + Returns + ------- + aRI : + Adjusted Rand index. + + References + ---------- + .. [1] `L. Hubert and P. Arabie, Comparing Partitions, + Journal of Classification 1985 + `_ + + .. [2] `D. Steinley, Properties of the Hubert-Arabie + adjusted Rand index, Psychological Methods 2004 + `_ + + .. [3] https://en.wikipedia.org/wiki/Rand_index#Adjusted_Rand_index + + """ + aRI = metrics.adjusted_rand_score(targets, class_predictions) + return aRI + + +def cal_nmi( + class_predictions: np.ndarray, + targets: np.ndarray, +) -> float: + """Calculate Normalized Mutual Information between two clusterings. + + Parameters + ---------- + class_predictions : + Clustering results returned by a clusterer. + + targets : + Ground truth (correct) clustering results. + + Returns + ------- + NMI : float, + Normalized Mutual Information + + + """ + NMI = metrics.normalized_mutual_info_score(targets, class_predictions) + return NMI + + +def cal_cluster_purity( + class_predictions: np.ndarray, + targets: np.ndarray, +) -> float: + """Calculate cluster purity. + + Parameters + ---------- + class_predictions : + Clustering results returned by a clusterer. + + targets : + Ground truth (correct) clustering results. + + Returns + ------- + cluster_purity : + cluster purity. + + Notes + ----- + This function is from the answer https://stackoverflow.com/a/51672699 on StackOverflow. + + """ + contingency_matrix = metrics.cluster.contingency_matrix(targets, class_predictions) + cluster_purity = np.sum(np.amax(contingency_matrix, axis=0)) / np.sum( + contingency_matrix + ) + return cluster_purity + + +def cal_external_cluster_validation_metrics(class_predictions, targets): + """Computer all external cluster validation metrics available in PyPOTS and return as a dictionary. + + Parameters + ---------- + class_predictions : + Clustering results returned by a clusterer. + + targets : + Ground truth (correct) clustering results. + + Returns + ------- + external_cluster_validation_metrics : dict + A dictionary contains all external cluster validation metrics available in PyPOTS. 
+ """ + + ri = cal_rand_index(class_predictions, targets) + ari = cal_adjusted_rand_index(class_predictions, targets) + nmi = cal_nmi(class_predictions, targets) + cp = cal_cluster_purity(class_predictions, targets) + + external_cluster_validation_metrics = { + "rand_index": ri, + "adjusted_rand_index": ari, + "nmi": nmi, + "cluster_purity": cp, + } + return external_cluster_validation_metrics + + +def cal_silhouette(X: np.ndarray, predicted_labels: np.ndarray) -> float: + """Compute the mean Silhouette Coefficient of all samples. + + Parameters + ---------- + X : array-like of shape (n_samples_a, n_features) + A feature array, or learned latent representation, that can be used for clustering. + + predicted_labels : array-like of shape (n_samples) + Predicted labels for each sample. + + Returns + ------- + silhouette_score : float + Mean Silhouette Coefficient for all samples. In short, the higher, the better. + + References + ---------- + .. [1] `Peter J. Rousseeuw (1987). "Silhouettes: a Graphical Aid to the + Interpretation and Validation of Cluster Analysis". Computational + and Applied Mathematics 20: 53-65. + `_ + + .. [2] `Wikipedia entry on the Silhouette Coefficient + `_ + + """ + silhouette_score = metrics.silhouette_score(X, predicted_labels) + return silhouette_score + + +def cal_chs(X: np.ndarray, predicted_labels: np.ndarray) -> float: + """Compute the Calinski and Harabasz score (also known as the Variance Ratio Criterion). + + X : array-like of shape (n_samples_a, n_features) + A feature array, or learned latent representation, that can be used for clustering. + + predicted_labels : array-like of shape (n_samples) + Predicted labels for each sample. + + Returns + ------- + calinski_harabasz_score : float + The resulting Calinski-Harabasz score. In short, the higher, the better. + + References + ---------- + .. [1] `T. Calinski and J. Harabasz, 1974. "A dendrite method for cluster + analysis". Communications in Statistics + `_ + + """ + calinski_harabasz_score = metrics.calinski_harabasz_score(X, predicted_labels) + return calinski_harabasz_score + + +def cal_dbs(X: np.ndarray, predicted_labels: np.ndarray) -> float: + """Compute the Davies-Bouldin score. + + Parameters + ---------- + X : array-like of shape (n_samples_a, n_features) + A feature array, or learned latent representation, that can be used for clustering. + + predicted_labels : array-like of shape (n_samples) + Predicted labels for each sample. + + Returns + ------- + davies_bouldin_score : float + The resulting Davies-Bouldin score. In short, the lower, the better. + + References + ---------- + .. [1] `Davies, David L.; Bouldin, Donald W. (1979). + "A Cluster Separation Measure" + IEEE Transactions on Pattern Analysis and Machine Intelligence. + PAMI-1 (2): 224-227 + `_ + + """ + davies_bouldin_score = metrics.davies_bouldin_score(X, predicted_labels) + return davies_bouldin_score + + +def cal_internal_cluster_validation_metrics(X, predicted_labels): + """Computer all internal cluster validation metrics available in PyPOTS and return as a dictionary. + + Parameters + ---------- + X : array-like of shape (n_samples_a, n_features) + A feature array, or learned latent representation, that can be used for clustering. + + predicted_labels : array-like of shape (n_samples) + Predicted labels for each sample. + + Returns + ------- + internal_cluster_validation_metrics : dict + A dictionary contains all internal cluster validation metrics available in PyPOTS. 
+ """ + + silhouette_score = cal_silhouette(X, predicted_labels) + calinski_harabasz_score = cal_chs(X, predicted_labels) + davies_bouldin_score = cal_dbs(X, predicted_labels) + + internal_cluster_validation_metrics = { + "silhouette_score": silhouette_score, + "calinski_harabasz_score": calinski_harabasz_score, + "davies_bouldin_score": davies_bouldin_score, + } + return internal_cluster_validation_metrics diff --git a/pypots/utils/metrics/error.py b/pypots/utils/metrics/error.py new file mode 100644 index 00000000..376b267c --- /dev/null +++ b/pypots/utils/metrics/error.py @@ -0,0 +1,231 @@ +""" +Evaluation metrics related to error calculation (like in tasks regression, imputation etc). +""" + +# Created by Wenjie Du +# License: GPL-v3 + +from typing import Union, Optional + +import numpy as np +import torch + + +def cal_mae( + predictions: Union[np.ndarray, torch.Tensor, list], + targets: Union[np.ndarray, torch.Tensor, list], + masks: Optional[Union[np.ndarray, torch.Tensor, list]] = None, +) -> Union[float, torch.Tensor]: + """Calculate the Mean Absolute Error between ``predictions`` and ``targets``. + ``masks`` can be used for filtering. For values==0 in ``masks``, + values at their corresponding positions in ``predictions`` will be ignored. + + Parameters + ---------- + predictions : + The prediction data to be evaluated. + + targets : + The target data for helping evaluate the predictions. + + masks : + The masks for filtering the specific values in inputs and target from evaluation. + When given, only values at corresponding positions where values ==1 in ``masks`` will be used for evaluation. + + Examples + -------- + + >>> import numpy as np + >>> from pypots.utils.metrics import cal_mae + >>> targets = np.array([1, 2, 3, 4, 5]) + >>> predictions = np.array([1, 2, 1, 4, 6]) + >>> mae = cal_mae(predictions, targets) + + mae = 0.6 here, the error is from the 3rd and 5th elements and is :math:`|3-1|+|5-6|=3`, so the result is 3/5=0.6. + + If we want to prevent some values from MAE calculation, e.g. the first three elements here, + we can use ``masks`` to filter out them: + + >>> masks = np.array([0, 0, 0, 1, 1]) + >>> mae = cal_mae(predictions, targets, masks) + + mae = 0.5 here, the first three elements are ignored, the error is from the 5th element and is :math:`|5-6|=1`, + so the result is 1/2=0.5. + + """ + assert isinstance(predictions, type(targets)), ( + f"types of inputs and target must match, but got" + f"type(inputs)={type(predictions)}, type(target)={type(targets)}" + ) + lib = np if isinstance(predictions, np.ndarray) else torch + if masks is not None: + return lib.sum(lib.abs(predictions - targets) * masks) / ( + lib.sum(masks) + 1e-12 + ) + else: + return lib.mean(lib.abs(predictions - targets)) + + +def cal_mse( + predictions: Union[np.ndarray, torch.Tensor, list], + targets: Union[np.ndarray, torch.Tensor, list], + masks: Optional[Union[np.ndarray, torch.Tensor, list]] = None, +) -> Union[float, torch.Tensor]: + """Calculate the Mean Square Error between ``predictions`` and ``targets``. + ``masks`` can be used for filtering. For values==0 in ``masks``, + values at their corresponding positions in ``predictions`` will be ignored. + + Parameters + ---------- + predictions : + The prediction data to be evaluated. + + targets : + The target data for helping evaluate the predictions. + + masks : + The masks for filtering the specific values in inputs and target from evaluation. 
+        When given, only values at corresponding positions where values ==1 in ``masks`` will be used for evaluation.
+
+    Examples
+    --------
+
+    >>> import numpy as np
+    >>> from pypots.utils.metrics import cal_mse
+    >>> targets = np.array([1, 2, 3, 4, 5])
+    >>> predictions = np.array([1, 2, 1, 4, 6])
+    >>> mse = cal_mse(predictions, targets)
+
+    mse = 1 here, the error is from the 3rd and 5th elements and is :math:`|3-1|^2+|5-6|^2=5`, so the result is 5/5=1.
+
+    If we want to prevent some values from MSE calculation, e.g. the first three elements here,
+    we can use ``masks`` to filter them out:
+
+    >>> masks = np.array([0, 0, 0, 1, 1])
+    >>> mse = cal_mse(predictions, targets, masks)
+
+    mse = 0.5 here, the first three elements are ignored, the error is from the 5th element and is :math:`|5-6|^2=1`,
+    so the result is 1/2=0.5.
+
+    """
+
+    assert isinstance(predictions, type(targets)), (
+        f"types of inputs and target must match, but got"
+        f"type(inputs)={type(predictions)}, type(target)={type(targets)}"
+    )
+    lib = np if isinstance(predictions, np.ndarray) else torch
+    if masks is not None:
+        return lib.sum(lib.square(predictions - targets) * masks) / (
+            lib.sum(masks) + 1e-12
+        )
+    else:
+        return lib.mean(lib.square(predictions - targets))
+
+
+def cal_rmse(
+    predictions: Union[np.ndarray, torch.Tensor, list],
+    targets: Union[np.ndarray, torch.Tensor, list],
+    masks: Optional[Union[np.ndarray, torch.Tensor, list]] = None,
+) -> Union[float, torch.Tensor]:
+    """Calculate the Root Mean Square Error between ``predictions`` and ``targets``.
+    ``masks`` can be used for filtering. For values==0 in ``masks``,
+    values at their corresponding positions in ``predictions`` will be ignored.
+
+    Parameters
+    ----------
+    predictions :
+        The prediction data to be evaluated.
+
+    targets :
+        The target data for helping evaluate the predictions.
+
+    masks :
+        The masks for filtering the specific values in inputs and target from evaluation.
+        When given, only values at corresponding positions where values ==1 in ``masks`` will be used for evaluation.
+
+    Examples
+    --------
+
+    >>> import numpy as np
+    >>> from pypots.utils.metrics import cal_rmse
+    >>> targets = np.array([1, 2, 3, 4, 5])
+    >>> predictions = np.array([1, 2, 1, 4, 6])
+    >>> rmse = cal_rmse(predictions, targets)
+
+    rmse = 1 here, the error is from the 3rd and 5th elements and is :math:`|3-1|^2+|5-6|^2=5`,
+    so the result is :math:`\\sqrt{5/5}=1`.
+
+    If we want to prevent some values from RMSE calculation, e.g. the first three elements here,
+    we can use ``masks`` to filter them out:
+
+    >>> masks = np.array([0, 0, 0, 1, 1])
+    >>> rmse = cal_rmse(predictions, targets, masks)
+
+    rmse = 0.707 here, the first three elements are ignored, the error is from the 5th element and is :math:`|5-6|^2=1`,
+    so the result is :math:`\\sqrt{1/2}\\approx 0.707`.
+
+    """
+    assert isinstance(predictions, type(targets)), (
+        f"types of inputs and target must match, but got"
+        f"type(inputs)={type(predictions)}, type(target)={type(targets)}"
+    )
+    lib = np if isinstance(predictions, np.ndarray) else torch
+    return lib.sqrt(cal_mse(predictions, targets, masks))
+
+
+def cal_mre(
+    predictions: Union[np.ndarray, torch.Tensor, list],
+    targets: Union[np.ndarray, torch.Tensor, list],
+    masks: Optional[Union[np.ndarray, torch.Tensor, list]] = None,
+) -> Union[float, torch.Tensor]:
+    """Calculate the Mean Relative Error between ``predictions`` and ``targets``.
+    ``masks`` can be used for filtering. For values==0 in ``masks``,
+    values at their corresponding positions in ``predictions`` will be ignored.
+
+    Parameters
+    ----------
+    predictions :
+        The prediction data to be evaluated.
+
+    targets :
+        The target data for helping evaluate the predictions.
+
+    masks :
+        The masks for filtering the specific values in inputs and target from evaluation.
+        When given, only values at corresponding positions where values ==1 in ``masks`` will be used for evaluation.
+
+    Examples
+    --------
+
+    >>> import numpy as np
+    >>> from pypots.utils.metrics import cal_mre
+    >>> targets = np.array([1, 2, 3, 4, 5])
+    >>> predictions = np.array([1, 2, 1, 4, 6])
+    >>> mre = cal_mre(predictions, targets)
+
+    mre = 0.2 here, the error is from the 3rd and 5th elements and is :math:`|3-1|+|5-6|=3`,
+    so the result is :math:`3/(1+2+3+4+5)=0.2`.
+
+    If we want to prevent some values from MRE calculation, e.g. the first three elements here,
+    we can use ``masks`` to filter them out:
+
+    >>> masks = np.array([0, 0, 0, 1, 1])
+    >>> mre = cal_mre(predictions, targets, masks)
+
+    mre = 0.111 here, the first three elements are ignored, the error is from the 5th element and is :math:`|5-6|=1`,
+    so the result is :math:`1/(4+5)\\approx 0.111`.
+
+    """
+    assert isinstance(predictions, type(targets)), (
+        f"types of inputs and target must match, but got"
+        f"type(inputs)={type(predictions)}, type(target)={type(targets)}"
+    )
+    lib = np if isinstance(predictions, np.ndarray) else torch
+    if masks is not None:
+        return lib.sum(lib.abs(predictions - targets) * masks) / (
+            lib.sum(lib.abs(targets * masks)) + 1e-12
+        )
+    else:
+        return lib.sum(lib.abs(predictions - targets)) / (
+            lib.sum(lib.abs(targets)) + 1e-12
+        )
From 2decac519efa186ba16ff9a7ebab0232e4113ee4 Mon Sep 17 00:00:00 2001
From: Wenjie Du
Date: Tue, 10 Oct 2023 00:57:29 +0800
Subject: [PATCH 2/2] feat: enable clustering algorithms to select model according to loss on the validation set;

---
 pypots/clustering/crli/model.py              | 58 ++++++++++++++++----
 pypots/clustering/crli/modules/core.py       | 44 ++++++---------
 pypots/clustering/crli/modules/submodules.py | 14 ++---
 pypots/clustering/vader/model.py             | 13 ++++-
 pypots/clustering/vader/modules/core.py      |  3 +-
 tests/clustering/crli.py                     |  5 +-
 tests/clustering/vader.py                    |  3 +-
 7 files changed, 90 insertions(+), 50 deletions(-)

diff --git a/pypots/clustering/crli/model.py b/pypots/clustering/crli/model.py
index ab13ce7a..c8f99455 100644
--- a/pypots/clustering/crli/model.py
+++ b/pypots/clustering/crli/model.py
@@ -254,14 +254,42 @@ def _train_model(
                         self._save_log_into_tb_file(
                             training_step, "training", loss_results
                         )
+
                 mean_epoch_train_D_loss = np.mean(epoch_train_loss_D_collector)
                 mean_epoch_train_G_loss = np.mean(epoch_train_loss_G_collector)
-                logger.info(
-                    f"epoch {epoch}: "
-                    f"training loss_generator {mean_epoch_train_G_loss:.4f}, "
-                    f"train loss_discriminator {mean_epoch_train_D_loss:.4f}"
-                )
-                mean_loss = mean_epoch_train_G_loss
+
+                if val_loader is not None:
+                    self.model.eval()
+                    epoch_val_loss_G_collector = []
+                    with torch.no_grad():
+                        for idx, data in enumerate(val_loader):
+                            inputs = self._assemble_input_for_validating(data)
+                            results = self.model.forward(inputs, return_loss=True)
+                            epoch_val_loss_G_collector.append(
+                                results["generation_loss"].sum().item()
+                            )
+                    mean_val_G_loss = np.mean(epoch_val_loss_G_collector)
+                    # save validating loss logs into the tensorboard file for every epoch if in need
+                    if self.summary_writer is not None:
+                        val_loss_dict = {
+                            "generation_loss": mean_val_G_loss,
+                        }
+
self._save_log_into_tb_file(epoch, "validating", val_loss_dict) + logger.info( + f"epoch {epoch}: " + f"training loss_generator {mean_epoch_train_G_loss:.4f}, " + f"training loss_discriminator {mean_epoch_train_D_loss:.4f}, " + f"validating loss_generator {mean_val_G_loss:.4f}" + ) + mean_loss = mean_val_G_loss + else: + + logger.info( + f"epoch {epoch}: " + f"training loss_generator {mean_epoch_train_G_loss:.4f}, " + f"training loss_discriminator {mean_epoch_train_D_loss:.4f}" + ) + mean_loss = mean_epoch_train_G_loss if mean_loss < self.best_loss: self.best_loss = mean_loss @@ -314,8 +342,18 @@ def fit( num_workers=self.num_workers, ) + val_loader = None + if val_set is not None: + val_set = DatasetForCRLI(val_set, return_labels=False, file_type=file_type) + val_loader = DataLoader( + val_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + # Step 2: train the model and freeze it - self._train_model(training_loader) + self._train_model(training_loader, val_loader) self.model.load_state_dict(self.best_model_dict) self.model.eval() # set the model as eval status to freeze it. @@ -342,9 +380,9 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - inputs = self.model.forward(inputs, training=False) + inputs = self.model.forward(inputs, return_loss=False) clustering_latent_collector.append(inputs["fcn_latent"]) - imputation_collector.append(inputs["imputation"]) + imputation_collector.append(inputs["imputation_latent"]) imputation = torch.cat(imputation_collector).cpu().detach().numpy() clustering_latent = ( @@ -353,7 +391,7 @@ def predict( clustering = self.model.kmeans.fit_predict(clustering_latent) latent_collector = { "clustering_latent": clustering_latent, - "imputation": imputation, + "imputation_latent": imputation, } result_dict = { diff --git a/pypots/clustering/crli/modules/core.py b/pypots/clustering/crli/modules/core.py index da653cde..cbca6356 100644 --- a/pypots/clustering/crli/modules/core.py +++ b/pypots/clustering/crli/modules/core.py @@ -46,46 +46,36 @@ def __init__( n_clusters=n_clusters, n_init=10, # FutureWarning: The default value of `n_init` will change from 10 to 'auto' in 1.4. Set the # value of `n_init` explicitly to suppress the warning. 
- ) # TODO: implement KMean with torch for gpu acceleration + ) self.n_clusters = n_clusters self.lambda_kmeans = lambda_kmeans self.device = device - def cluster(self, inputs: dict, training_object: str = "generator") -> dict: - # concat final states from generator and input it as the initial state of decoder - imputation, imputed_X, generator_fb_hidden_states = self.generator(inputs) - inputs["imputation"] = imputation - inputs["imputed_X"] = imputed_X - inputs["generator_fb_hidden_states"] = generator_fb_hidden_states - if training_object == "discriminator": - discrimination = self.discriminator(inputs) - inputs["discrimination"] = discrimination - return inputs # if only train discriminator, then no need to run decoder - - reconstruction, fcn_latent = self.decoder(inputs) - inputs["reconstruction"] = reconstruction - inputs["fcn_latent"] = fcn_latent - return inputs - def forward( self, inputs: dict, training_object: str = "generator", - training: bool = True, + return_loss: bool = True, ) -> dict: - assert training_object in [ - "generator", - "discriminator", - ], 'training_object should be "generator" or "discriminator"' - X = inputs["X"] missing_mask = inputs["missing_mask"] batch_size, n_steps, n_features = X.shape losses = {} - inputs = self.cluster(inputs, training_object) - if not training: - # if only run clustering, then no need to calculate loss + + # concat final states from generator and input it as the initial state of decoder + imputation_latent, generator_fb_hidden_states = self.generator(inputs) + inputs["imputation_latent"] = imputation_latent + inputs["generator_fb_hidden_states"] = generator_fb_hidden_states + discrimination = self.discriminator(inputs) + inputs["discrimination"] = discrimination + + reconstruction, fcn_latent = self.decoder(inputs) + inputs["reconstruction"] = reconstruction + inputs["fcn_latent"] = fcn_latent + + # return results directly, skip loss calculation to reduce inference time + if not return_loss: return inputs if training_object == "discriminator": @@ -98,7 +88,7 @@ def forward( l_G = F.binary_cross_entropy_with_logits( inputs["discrimination"], 1 - missing_mask, weight=1 - missing_mask ) - l_pre = cal_mse(inputs["imputation"], X, missing_mask) + l_pre = cal_mse(inputs["imputation_latent"], X, missing_mask) l_rec = cal_mse(inputs["reconstruction"], X, missing_mask) HTH = torch.matmul(inputs["fcn_latent"], inputs["fcn_latent"].permute(1, 0)) term_F = torch.nn.init.orthogonal_( diff --git a/pypots/clustering/crli/modules/submodules.py b/pypots/clustering/crli/modules/submodules.py index f6837647..59b155b9 100644 --- a/pypots/clustering/crli/modules/submodules.py +++ b/pypots/clustering/crli/modules/submodules.py @@ -124,18 +124,15 @@ def __init__( self.f_rnn = MultiRNNCell(cell_type, n_layers, n_features, d_hidden, device) self.b_rnn = MultiRNNCell(cell_type, n_layers, n_features, d_hidden, device) - def forward(self, inputs: dict) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def forward(self, inputs: dict) -> Tuple[torch.Tensor, torch.Tensor]: f_outputs, f_final_hidden_state = self.f_rnn(inputs) b_outputs, b_final_hidden_state = self.b_rnn(inputs) b_outputs = reverse_tensor(b_outputs) # reverse the output of the backward rnn - imputation = (f_outputs + b_outputs) / 2 - imputed_X = inputs["X"] * inputs["missing_mask"] + imputation * ( - 1 - inputs["missing_mask"] - ) + imputation_latent = (f_outputs + b_outputs) / 2 fb_final_hidden_states = torch.concat( [f_final_hidden_state, b_final_hidden_state], dim=-1 ) - return imputation, 
imputed_X, fb_final_hidden_states + return imputation_latent, fb_final_hidden_states class Discriminator(nn.Module): @@ -161,7 +158,10 @@ def __init__( self.output_layer = nn.Linear(32, d_input) def forward(self, inputs: dict) -> torch.Tensor: - imputed_X = inputs["imputed_X"] + imputed_X = (inputs["X"] * inputs["missing_mask"]) + ( + inputs["imputation_latent"] * (1 - inputs["missing_mask"]) + ) + bz, n_steps, _ = imputed_X.shape hidden_states = [ torch.zeros((bz, 32), device=self.device), diff --git a/pypots/clustering/vader/model.py b/pypots/clustering/vader/model.py index 4e31b412..fff13643 100644 --- a/pypots/clustering/vader/model.py +++ b/pypots/clustering/vader/model.py @@ -340,6 +340,7 @@ def _train_model( def fit( self, train_set: Union[dict, str], + val_set: Optional[Union[dict, str]] = None, file_type: str = "h5py", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader @@ -353,8 +354,18 @@ def fit( num_workers=self.num_workers, ) + val_loader = None + if val_set is not None: + val_set = DatasetForVaDER(val_set, return_labels=False, file_type=file_type) + val_loader = DataLoader( + val_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + # Step 2: train the model and freeze it - self._train_model(training_loader) + self._train_model(training_loader, val_loader) self.model.load_state_dict(self.best_model_dict) self.model.eval() # set the model as eval status to freeze it. diff --git a/pypots/clustering/vader/modules/core.py b/pypots/clustering/vader/modules/core.py index 1e33ba54..9f461d80 100644 --- a/pypots/clustering/vader/modules/core.py +++ b/pypots/clustering/vader/modules/core.py @@ -172,7 +172,6 @@ def forward( mu_tilde, stddev_tilde, ) = self.get_results(X, missing_mask) - imputed_X = X_reconstructed * (1 - missing_mask) + X * missing_mask if not training and not pretrain: results = { @@ -182,7 +181,7 @@ def forward( "var": var_c, "phi": phi_c, "z": z, - "imputed_X": imputed_X, + "imputation_latent": X_reconstructed, } # if only run clustering, then no need to calculate loss return results diff --git a/tests/clustering/crli.py b/tests/clustering/crli.py index 2385d1e5..99524753 100644 --- a/tests/clustering/crli.py +++ b/tests/clustering/crli.py @@ -21,6 +21,7 @@ from tests.clustering.config import ( EPOCHS, TRAIN_SET, + VAL_SET, TEST_SET, RESULT_SAVING_DIR_FOR_CLUSTERING, ) @@ -74,9 +75,9 @@ class TestCRLI(unittest.TestCase): @pytest.mark.xdist_group(name="clustering-crli") def test_0_fit(self): logger.info("Training CRLI-GRU...") - self.crli_gru.fit(TRAIN_SET) + self.crli_gru.fit(TRAIN_SET, VAL_SET) logger.info("Training CRLI-LSTM...") - self.crli_lstm.fit(TRAIN_SET) + self.crli_lstm.fit(TRAIN_SET, VAL_SET) @pytest.mark.xdist_group(name="clustering-crli") def test_1_parameters(self): diff --git a/tests/clustering/vader.py b/tests/clustering/vader.py index a76b61e8..42bcda00 100644 --- a/tests/clustering/vader.py +++ b/tests/clustering/vader.py @@ -22,6 +22,7 @@ from tests.clustering.config import ( EPOCHS, TRAIN_SET, + VAL_SET, TEST_SET, RESULT_SAVING_DIR_FOR_CLUSTERING, ) @@ -58,7 +59,7 @@ class TestVaDER(unittest.TestCase): @pytest.mark.xdist_group(name="clustering-vader") def test_0_fit(self): - self.vader.fit(TRAIN_SET) + self.vader.fit(TRAIN_SET, VAL_SET) @pytest.mark.xdist_group(name="clustering-vader") def test_1_cluster(self):
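
A quick way to sanity-check the first commit is that the public import path is unchanged: everything re-exported in pypots/utils/metrics/__init__.py can still be imported exactly as before the refactoring. The sketch below is an illustration only, not part of the patch; it reuses the toy inputs and expected values already given in the docstrings of pypots/utils/metrics/error.py above.

>>> import numpy as np
>>> from pypots.utils.metrics import cal_mae, cal_mse, cal_rmse, cal_mre
>>> targets = np.array([1, 2, 3, 4, 5])
>>> predictions = np.array([1, 2, 1, 4, 6])
>>> masks = np.array([0, 0, 0, 1, 1])  # only the last two positions are evaluated
>>> cal_mae(predictions, targets, masks)   # 0.5, per the cal_mae docstring
>>> cal_mse(predictions, targets, masks)   # 0.5
>>> cal_rmse(predictions, targets, masks)  # sqrt(1/2), about 0.707
>>> cal_mre(predictions, targets, masks)   # 1/(4+5), about 0.111

The second commit is exercised the same way as in the updated tests: the clustering models now accept an optional validation set, e.g. model.fit(TRAIN_SET, VAL_SET) with a CRLI or VaDER instance named model here for illustration, and pick the best epoch by the loss computed on that set instead of the training loss.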