diff --git a/pertpy/tools/_coda/_base_coda.py b/pertpy/tools/_coda/_base_coda.py index 65110586..580f5859 100644 --- a/pertpy/tools/_coda/_base_coda.py +++ b/pertpy/tools/_coda/_base_coda.py @@ -1378,7 +1378,6 @@ def plot_effects_barplot( # pragma: no cover if len(covariate_names_zero) != 0: if plot_facets: if plot_zero_covariate and not plot_zero_cell_type: - plot_df = plot_df[plot_df["value"] != 0] for covariate_name_zero in covariate_names_zero: new_row = { "Covariate": covariate_name_zero, diff --git a/pertpy/tools/_mixscape.py b/pertpy/tools/_mixscape.py index 12ff14d0..fddc54c2 100644 --- a/pertpy/tools/_mixscape.py +++ b/pertpy/tools/_mixscape.py @@ -521,6 +521,7 @@ def plot_barplot( # pragma: no cover legend_title_size: int = 8, legend_text_size: int = 8, legend_bbox_to_anchor: tuple[float, float] = None, + figsize: tuple[float, float] = (25, 25), show: bool = True, return_fig: bool = False, ) -> Figure | None: @@ -537,6 +538,7 @@ def plot_barplot( # pragma: no cover legend_title_size: Size of the legend title. legend_text_size: Size of the legend text. legend_bbox_to_anchor: The bbox that the legend will be anchored. + figsize: The size of the figure. {common_plot_args} Returns: @@ -574,7 +576,7 @@ def plot_barplot( # pragma: no cover color_mapping = {"KO": "salmon", "NP": "lightgray", "NT": "grey"} unique_genes = NP_KO_cells["gene"].unique() - fig, axs = plt.subplots(int(len(unique_genes) / 5), 5, figsize=(25, 25), sharey=True) + fig, axs = plt.subplots(int(len(unique_genes) / 5), 5, figsize=figsize, sharey=True) for i, gene in enumerate(unique_genes): ax = axs[int(i / 5), i % 5] grouped_df = ( @@ -594,11 +596,8 @@ def plot_barplot( # pragma: no cover ax.set_title(gene, bbox={"facecolor": "white", "edgecolor": "black", "pad": 1}, fontsize=axis_title_size) ax.set(xlabel="sgRNA", ylabel="% of cells") sns.despine(ax=ax, top=True, right=True, left=False, bottom=False) - ax.set_xticklabels(ax.get_xticklabels(), rotation=0, ha="right", fontsize=axis_text_x_size) - ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=axis_text_y_size) - - fig.subplots_adjust(right=0.8) - fig.subplots_adjust(hspace=0.5, wspace=0.5) + ax.set_xticks(ax.get_xticks(),ax.get_xticklabels(), rotation=0, ha="right", fontsize=axis_text_x_size) + ax.set_yticks(ax.get_yticks(), ax.get_yticklabels(), rotation=0, fontsize=axis_text_y_size) ax.legend( title="Mixscape Class", loc="center right", @@ -608,10 +607,14 @@ def plot_barplot( # pragma: no cover title_fontsize=legend_title_size, ) + fig.subplots_adjust(right=0.8) + fig.subplots_adjust(hspace=0.5, wspace=0.5) + plt.tight_layout() + if show: plt.show() if return_fig: - return plt.gcf() + return fig return None @_doc_params(common_plot_args=doc_common_plot_args) @@ -792,12 +795,6 @@ def plot_perturbscore( # pragma: no cover plt.legend(title="gene_target", title_fontsize=14, fontsize=12) sns.despine() - if show: - plt.show() - if return_fig: - return plt.gcf() - return None - # If before_mixscape is False, split densities based on mixscape classifications else: if palette is None: @@ -854,11 +851,11 @@ def plot_perturbscore( # pragma: no cover plt.legend(title="mixscape class", title_fontsize=14, fontsize=12) sns.despine() - if show: - plt.show() - if return_fig: - return plt.gcf() - return None + if show: + plt.show() + if return_fig: + return plt.gcf() + return None @_doc_params(common_plot_args=doc_common_plot_args) def plot_violin( # pragma: no cover