diff --git a/pypots/utils/visualization.py b/pypots/utils/visualization.py index 282b9e63..69d0ffde 100644 --- a/pypots/utils/visualization.py +++ b/pypots/utils/visualization.py @@ -5,6 +5,8 @@ # Created by Bhargav Vemuri # License: GPL-v3 +from typing import Dict + import numpy as np import pandas as pd import matplotlib.pyplot as plt @@ -14,8 +16,7 @@ def get_cluster_members( test_data: np.ndarray, class_predictions: np.ndarray -) #-> dict[int, np.ndarray] -: +) -> Dict[int, np.ndarray]: """ Subset time series array using predicted cluster membership. @@ -40,7 +41,7 @@ def get_cluster_members( def clusters_for_plotting( cluster_members: dict[int, np.ndarray], -) -> dict[int, dict]: +) -> Dict[int, dict]: """ Organize clustered arrays into format ready for plotting. @@ -132,7 +133,7 @@ def plot_clusters(dict_to_plot: dict[int, dict]) -> None: plt.show() -def get_cluster_means(dict_to_plot: dict[int, dict]) -> dict[int, dict]: +def get_cluster_means(dict_to_plot: dict[int, dict]) -> Dict[int, dict]: """ Get time series variables' mean values and 95% confidence intervals at each time point per cluster.