diff --git a/convoys/plotting.py b/convoys/plotting.py index 613562c..c488709 100644 --- a/convoys/plotting.py +++ b/convoys/plotting.py @@ -1,4 +1,5 @@ import numpy +from matplotlib import colors as mcolors from matplotlib import pyplot import convoys.multi @@ -17,7 +18,8 @@ def plot_cohorts(G, B, T, t_max=None, model='kaplan-meier', ci=None, ax=None, plot_kwargs={}, plot_ci_kwargs={}, groups=None, specific_groups=None, - label_fmt='%(group)s (n=%(n).0f, k=%(k).0f)'): + label_fmt='%(group)s (n=%(n).0f, k=%(k).0f)', + colormap=None): ''' Helper function to fit data using a model and then plot the cohorts. :param G: list with group assignment @@ -37,6 +39,7 @@ def plot_cohorts(G, B, T, t_max=None, model='kaplan-meier', :param groups: list of group labels :param specific_groups: subset of groups to plot :param label_fmt: custom format for the labels to use in the legend + :param colormap: a colormap to use for the lines that will be plotted See :meth:`convoys.utils.get_arrays` which is handy for converting a Pandas dataframe into arrays `G`, `B`, `T`. @@ -73,6 +76,11 @@ def plot_cohorts(G, B, T, t_max=None, model='kaplan-meier', t = numpy.linspace(0, t_max, 1000) _, y_max = ax.get_ylim() ax.set_prop_cycle(None) # Reset to first color + if isinstance(colormap, mcolors.ListedColormap): + num_groups = len(groups) + colors = colormap(numpy.linspace(0, 1, num_groups)) + else: + colors = None for i, group in enumerate(specific_groups): j = groups.index(group) # matching index of group @@ -84,6 +92,8 @@ def plot_cohorts(G, B, T, t_max=None, model='kaplan-meier', p_y, p_y_lo, p_y_hi = m.predict_ci(j, t, ci=ci).T merged_plot_ci_kwargs = {'alpha': 0.2} merged_plot_ci_kwargs.update(plot_ci_kwargs) + if colors is not None: + merged_plot_ci_kwargs['color'] = colors[i] p = ax.fill_between(t, 100. * p_y_lo, 100. * p_y_hi, **merged_plot_ci_kwargs) color = p.get_facecolor()[0] # reuse color for the line @@ -94,6 +104,8 @@ def plot_cohorts(G, B, T, t_max=None, model='kaplan-meier', merged_plot_kwargs = {'color': color, 'linewidth': 1.5, 'alpha': 0.7} merged_plot_kwargs.update(plot_kwargs) + if colors is not None: + merged_plot_kwargs['color'] = colors[i] ax.plot(t, 100. * p_y, label=label, **merged_plot_kwargs) y_max = max(y_max, 110. * max(p_y))