diff --git a/pysaliency/plotting.py b/pysaliency/plotting.py index d4e3dc5..0e129a5 100644 --- a/pysaliency/plotting.py +++ b/pysaliency/plotting.py @@ -150,11 +150,11 @@ def normalize_log_density(log_density): unsorted_cummulative = cummulative[np.argsort(inds)] return unsorted_cummulative.reshape(log_density.shape) -def visualize_distribution(log_densities, ax=None, levels=None, level_colors='black'): +def visualize_distribution(log_densities, ax=None, levels=None, level_colors='black', cmap=plt.cm.viridis): if ax is None: ax = plt.gca() t = normalize_log_density(log_densities) - img = ax.imshow(t, cmap=plt.cm.viridis) + img = ax.imshow(t, cmap=cmap) if levels is None: levels = [0, 0.25, 0.5, 0.75, 1.0] cs = ax.contour(t, levels=levels, colors=level_colors)