Skip to content

Commit

Permalink
Update visualization.py
Browse files Browse the repository at this point in the history
  • Loading branch information
vemuribv authored Oct 19, 2023
1 parent deb5642 commit 77c9aa0
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions pypots/utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# Created by Bhargav Vemuri <[email protected]>
# License: GPL-v3

from typing import Dict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
Expand All @@ -14,8 +16,7 @@

def get_cluster_members(
test_data: np.ndarray, class_predictions: np.ndarray
) #-> dict[int, np.ndarray]
:
) -> Dict[int, np.ndarray]:
"""
Subset time series array using predicted cluster membership.
Expand All @@ -40,7 +41,7 @@ def get_cluster_members(

def clusters_for_plotting(
cluster_members: dict[int, np.ndarray],
) -> dict[int, dict]:
) -> Dict[int, dict]:
"""
Organize clustered arrays into format ready for plotting.
Expand Down Expand Up @@ -132,7 +133,7 @@ def plot_clusters(dict_to_plot: dict[int, dict]) -> None:
plt.show()


def get_cluster_means(dict_to_plot: dict[int, dict]) -> dict[int, dict]:
def get_cluster_means(dict_to_plot: dict[int, dict]) -> Dict[int, dict]:
"""
Get time series variables' mean values and 95% confidence intervals at each time point per cluster.
Expand Down

0 comments on commit 77c9aa0

Please sign in to comment.