diff --git a/mriqc_learn/viz/metrics.py b/mriqc_learn/viz/metrics.py index f79debe..ef8366f 100644 --- a/mriqc_learn/viz/metrics.py +++ b/mriqc_learn/viz/metrics.py @@ -184,6 +184,7 @@ def plot_corrmat( cbarlabel="", symmetric=True, figsize=None, + sorted=False, **kwargs, ): """ @@ -204,14 +205,33 @@ def plot_corrmat( A dictionary with arguments to `matplotlib.Figure.colorbar`. Optional. cbarlabel The label for the colorbar. Optional. + sorted : :obj:`bool` + Flag to perform hierachical clustering on the correlation plot **kwargs All other arguments are forwarded to `imshow`. """ from mpl_toolkits.axes_grid1.inset_locator import inset_axes + # Cluster rows and columns (if arguments enabled) + if sorted: + from scipy.cluster.hierarchy import linkage, dendrogram, fcluster + + Z = linkage(data, "complete", optimal_ordering=True) + + dendrogram(Z, labels=data.columns, no_plot=True) + + # Clusterize the data + threshold = 0.1 + labels = fcluster(Z, threshold, criterion="distance") + # Keep the indices to sort labels + labels_order = np.argsort(labels) + + # Reorder data + data = data.take(labels_order, axis=0).take(labels_order, axis=1) + if hasattr(data, "columns"): - col_labels = data.columns.tolist() + col_labels = data.columns data = data.values if figsize is not None: @@ -220,6 +240,7 @@ def plot_corrmat( if not ax: ax = plt.gca() + # If matrix is symmetric, keep only lower triangle if symmetric: data[np.triu(np.ones(data.shape, dtype=bool))] = np.nan @@ -252,10 +273,16 @@ def plot_corrmat( ax.tick_params(top=False, bottom=True, labeltop=False, labelbottom=True) # Rotate the tick labels and set their alignment. - plt.setp(ax.get_xticklabels(), rotation=90, ha="right", rotation_mode="anchor") + plt.setp( + ax.get_xticklabels(), + rotation=90, + ha="right", + va="center", + rotation_mode="anchor", + ) # Turn spines off and create white grid. - ax.spines[:].set_visible(False) + plt.setp(ax.spines.values(), visible=False) ax.set_xticks(np.arange(data.shape[1] + 1) - 0.5, minor=True) ax.set_yticks(np.arange(data.shape[0] + 1) - 0.5, minor=True)