diff --git a/pertpy/preprocessing/_guide_rna.py b/pertpy/preprocessing/_guide_rna.py index d8ddc764..50fad2bf 100644 --- a/pertpy/preprocessing/_guide_rna.py +++ b/pertpy/preprocessing/_guide_rna.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: from anndata import AnnData - from matplotlib.axes import Axes + from matplotlib.pyplot import Figure class GuideAssignment: @@ -113,13 +113,14 @@ def assign_to_max_guide( def plot_heatmap( self, adata: AnnData, + *, layer: str | None = None, order_by: np.ndarray | str | None = None, key_to_save_order: str = None, show: bool = True, return_fig: bool = False, **kwargs, - ) -> list[Axes]: + ) -> Figure | None: """Heatmap plotting of guide RNA expression matrix. Assuming guides have sparse expression, this function reorders cells @@ -141,8 +142,8 @@ def plot_heatmap( kwargs: Are passed to sc.pl.heatmap. Returns: - If return_fig is True, returns a list of Axes. Alternatively you can pass save or show parameters as they will be passed to sc.pl.heatmap. - Order of cells in the y-axis will be saved on adata.obs[key_to_save_order] if provided. + If `return_fig` is `True`, returns the figure, otherwise `None`. + Order of cells in the y-axis will be saved on `adata.obs[key_to_save_order]` if provided. Examples: Each cell is assigned to gRNA that occurs at least 5 times in the respective cell, which is then @@ -179,7 +180,7 @@ def plot_heatmap( adata.obs[key_to_save_order] = pd.Categorical(order) try: - axis_group = sc.pl.heatmap( + fig = sc.pl.heatmap( adata[order, :], var_names=adata.var.index.tolist(), groupby=temp_col_name, @@ -196,5 +197,5 @@ def plot_heatmap( if show: plt.show() if return_fig: - return axis_group + return fig return None diff --git a/pertpy/tools/_augur.py b/pertpy/tools/_augur.py index 5c682128..bd226f6f 100644 --- a/pertpy/tools/_augur.py +++ b/pertpy/tools/_augur.py @@ -979,6 +979,7 @@ def predict_differential_prioritization( def plot_dp_scatter( self, results: pd.DataFrame, + *, top_n: int = None, ax: Axes = None, show: bool = True, @@ -1050,6 +1051,7 @@ def plot_dp_scatter( def plot_important_features( self, data: dict[str, Any], + *, key: str = "augurpy_results", top_n: int = 10, ax: Axes = None, @@ -1117,11 +1119,12 @@ def plot_important_features( def plot_lollipop( self, data: dict[str, Any] | AnnData, + *, key: str = "augurpy_results", ax: Axes = None, show: bool = True, return_fig: bool = False, - ) -> Axes | Figure | None: + ) -> Figure | None: """Plot a lollipop plot of the mean augur values. Args: @@ -1180,6 +1183,7 @@ def plot_scatterplot( self, results1: dict[str, Any], results2: dict[str, Any], + *, top_n: int = None, show: bool = True, return_fig: bool = False, diff --git a/pertpy/tools/_cinemaot.py b/pertpy/tools/_cinemaot.py index a2f1cd38..f79ea951 100644 --- a/pertpy/tools/_cinemaot.py +++ b/pertpy/tools/_cinemaot.py @@ -651,6 +651,7 @@ def plot_vis_matching( control: str, de_label: str, source_label: str, + *, matching_rep: str = "ot", resolution: float = 0.5, normalize: str = "col", @@ -677,6 +678,9 @@ def plot_vis_matching( {common_plot_args} **kwargs: Other parameters to input for seaborn.heatmap. + Returns: + If `return_fig` is `True`, returns the figure, otherwise `None`. + Examples: >>> import pertpy as pt >>> adata = pt.dt.cinemaot_example() @@ -716,10 +720,7 @@ def plot_vis_matching( if show: plt.show() if return_fig: - if ax is not None: - return ax - else: - return g + return g return None diff --git a/pertpy/tools/_coda/_base_coda.py b/pertpy/tools/_coda/_base_coda.py index c320b1b3..65110586 100644 --- a/pertpy/tools/_coda/_base_coda.py +++ b/pertpy/tools/_coda/_base_coda.py @@ -1192,6 +1192,7 @@ def plot_stacked_barplot( # pragma: no cover self, data: AnnData | MuData, feature_name: str, + *, modality_key: str = "coda", palette: ListedColormap | None = cm.tab20, show_legend: bool | None = True, @@ -1200,7 +1201,7 @@ def plot_stacked_barplot( # pragma: no cover dpi: int | None = 100, show: bool = True, return_fig: bool = False, - ) -> plt.Axes | Figure | None: + ) -> Figure | None: """Plots a stacked barplot for all levels of a covariate or all samples (if feature_name=="samples"). Args: @@ -1287,6 +1288,7 @@ def plot_stacked_barplot( # pragma: no cover def plot_effects_barplot( # pragma: no cover self, data: AnnData | MuData, + *, modality_key: str = "coda", covariates: str | list | None = None, parameter: Literal["log2-fold change", "Final Parameter", "Expected Sample"] = "log2-fold change", @@ -1300,7 +1302,7 @@ def plot_effects_barplot( # pragma: no cover dpi: int | None = 100, show: bool = True, return_fig: bool = False, - ) -> plt.Axes | plt.Figure | sns.axisgrid.FacetGrid | None: + ) -> Figure | None: """Barplot visualization for effects. The effect results for each covariate are shown as a group of barplots, with intra--group separation by cell types. @@ -1322,8 +1324,7 @@ def plot_effects_barplot( # pragma: no cover {common_plot_args} Returns: - Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`) - or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object + If `return_fig` is `True`, returns the figure, otherwise `None`. Examples: >>> import pertpy as pt @@ -1476,6 +1477,7 @@ def plot_boxplots( # pragma: no cover self, data: AnnData | MuData, feature_name: str, + *, modality_key: str = "coda", y_scale: Literal["relative", "log", "log10", "count"] = "relative", plot_facets: bool = False, @@ -1490,7 +1492,7 @@ def plot_boxplots( # pragma: no cover dpi: int | None = 100, show: bool = True, return_fig: bool = False, - ) -> plt.Axes | plt.Figure | sns.axisgrid.FacetGrid | None: + ) -> Figure | None: """Grouped boxplot visualization. The cell counts for each cell type are shown as a group of boxplots @@ -1515,8 +1517,7 @@ def plot_boxplots( # pragma: no cover {common_plot_args} Returns: - Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`) - or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object + If `return_fig` is `True`, returns the figure, otherwise `None`. Examples: >>> import pertpy as pt @@ -1707,6 +1708,7 @@ def plot_boxplots( # pragma: no cover def plot_rel_abundance_dispersion_plot( # pragma: no cover self, data: AnnData | MuData, + *, modality_key: str = "coda", abundant_threshold: float | None = 0.9, default_color: str | None = "Grey", @@ -1717,7 +1719,7 @@ def plot_rel_abundance_dispersion_plot( # pragma: no cover ax: plt.Axes | None = None, show: bool = True, return_fig: bool = False, - ) -> plt.Axes | plt.Figure | None: + ) -> Figure | None: """Plots total variance of relative abundance versus minimum relative abundance of all cell types for determination of a reference cell type. If the count of the cell type is larger than 0 in more than abundant_threshold percent of all samples, the cell type will be marked in a different color. @@ -1735,7 +1737,7 @@ def plot_rel_abundance_dispersion_plot( # pragma: no cover {common_plot_args} Returns: - A :class:`~matplotlib.axes.Axes` object + If `return_fig` is `True`, returns the figure, otherwise `None`. Examples: >>> import pertpy as pt @@ -1829,6 +1831,7 @@ def label_point(x, y, val, ax): def plot_draw_tree( # pragma: no cover self, data: AnnData | MuData, + *, modality_key: str = "coda", tree: str = "tree", # Also type ete3.Tree. Omitted due to import errors tight_text: bool | None = False, @@ -1912,6 +1915,7 @@ def plot_draw_effects( # pragma: no cover self, data: AnnData | MuData, covariate: str, + *, modality_key: str = "coda", tree: str = "tree", # Also type ete3.Tree. Omitted due to import errors show_legend: bool | None = None, @@ -2106,6 +2110,7 @@ def plot_effects_umap( # pragma: no cover mdata: MuData, effect_name: str | list | None, cluster_key: str, + *, modality_key_1: str = "rna", modality_key_2: str = "coda", color_map: Colormap | str | None = None, @@ -2114,7 +2119,7 @@ def plot_effects_umap( # pragma: no cover show: bool = True, return_fig: bool = False, **kwargs, - ) -> plt.Axes | plt.Figure | None: + ) -> Figure | None: """Plot a UMAP visualization colored by effect strength. Effect results in .varm of aggregated sample-level AnnData (default is data['coda']) are assigned to cell-level AnnData @@ -2134,7 +2139,7 @@ def plot_effects_umap( # pragma: no cover **kwargs: All other keyword arguments are passed to `scanpy.plot.umap()` Returns: - If `return_fig==True` a :class:`~matplotlib.axes.Axes` or a list of it. + If `return_fig` is `True`, returns the figure, otherwise `None`. Examples: >>> import pertpy as pt diff --git a/pertpy/tools/_dialogue.py b/pertpy/tools/_dialogue.py index 61a842ca..9eaf349e 100644 --- a/pertpy/tools/_dialogue.py +++ b/pertpy/tools/_dialogue.py @@ -1067,11 +1067,12 @@ def plot_split_violins( adata: AnnData, split_key: str, celltype_key: str, + *, split_which: tuple[str, str] = None, mcp: str = "mcp_0", show: bool = True, return_fig: bool = False, - ) -> Axes | Figure | None: + ) -> Figure | None: """Plots split violin plots for a given MCP and split variable. Any cells with a value for split_key not in split_which are removed from the plot. @@ -1122,10 +1123,11 @@ def plot_pairplot( celltype_key: str, color: str, sample_id: str, + *, mcp: str = "mcp_0", show: bool = True, return_fig: bool = False, - ) -> PairGrid | Figure | None: + ) -> Figure | None: """Generate a pairplot visualization for multi-cell perturbation (MCP) data. Computes the mean of a specified MCP feature (mcp) for each combination of sample and cell type, diff --git a/pertpy/tools/_differential_gene_expression/_base.py b/pertpy/tools/_differential_gene_expression/_base.py index 54ce26ea..c6c30293 100644 --- a/pertpy/tools/_differential_gene_expression/_base.py +++ b/pertpy/tools/_differential_gene_expression/_base.py @@ -529,7 +529,7 @@ def plot_paired( show_legend: bool = True, size: int = 10, y_label: str = "expression", - pvalue_template=lambda x: f"unadj. p={x:.2e}, t-test", + pvalue_template=lambda x: f"p={x:.2e}", boxplot_properties=None, palette=None, show: bool = True, @@ -594,7 +594,7 @@ def plot_paired( raise ValueError("The number of groups in the group_by column must be exactly 2 to enable paired testing") if var_names is None: - var_names = results_df.sort_values(pvalue_col, ascending=True).head(n_top_vars)[symbol_col].tolist() + var_names = results_df.head(n_top_vars)[symbol_col].tolist() adata = adata[:, var_names] diff --git a/pertpy/tools/_enrichment.py b/pertpy/tools/_enrichment.py index 63f9b5fa..8c6b6bc5 100644 --- a/pertpy/tools/_enrichment.py +++ b/pertpy/tools/_enrichment.py @@ -296,6 +296,7 @@ def gsea( def plot_dotplot( self, adata: AnnData, + *, targets: dict[str, dict[str, list[str]]] = None, source: Literal["chembl", "dgidb", "pharmgkb"] = "chembl", category_name: str = "interaction_type", @@ -426,6 +427,7 @@ def plot_gsea( self, adata: AnnData, enrichment: dict[str, pd.DataFrame], + *, n: int = 10, key: str = "pertpy_enrichment_gsea", interactive_plot: bool = False, diff --git a/pertpy/tools/_milo.py b/pertpy/tools/_milo.py index 2244788c..53bda6b9 100644 --- a/pertpy/tools/_milo.py +++ b/pertpy/tools/_milo.py @@ -718,6 +718,7 @@ def _graph_spatial_fdr( def plot_nhood_graph( self, mdata: MuData, + *, alpha: float = 0.1, min_logFC: float = 0, min_size: int = 10, @@ -813,6 +814,7 @@ def plot_nhood( self, mdata: MuData, ix: int, + *, feature_key: str | None = "rna", basis: str = "X_umap", color_map: Colormap | str | None = None, @@ -874,6 +876,7 @@ def plot_nhood( def plot_da_beeswarm( self, mdata: MuData, + *, feature_key: str | None = "rna", anno_col: str = "nhood_annotation", alpha: float = 0.1, @@ -881,7 +884,7 @@ def plot_da_beeswarm( palette: str | Sequence[str] | dict[str, str] | None = None, show: bool = True, return_fig: bool = False, - ) -> Figure | Axes | None: + ) -> Figure | None: """Plot beeswarm plot of logFC against nhood labels Args: @@ -894,6 +897,9 @@ def plot_da_beeswarm( Defaults to pre-defined category colors for violinplots. {common_plot_args} + Returns: + If `return_fig` is `True`, returns the figure, otherwise `None`. + Examples: >>> import pertpy as pt >>> import scanpy as sc @@ -999,11 +1005,12 @@ def plot_nhood_counts_by_cond( self, mdata: MuData, test_var: str, + *, subset_nhoods: list[str] = None, log_counts: bool = False, show: bool = True, return_fig: bool = False, - ) -> Figure | Axes | None: + ) -> Figure | None: """Plot boxplot of cell numbers vs condition of interest. Args: @@ -1012,6 +1019,9 @@ def plot_nhood_counts_by_cond( subset_nhoods: List of obs_names for neighbourhoods to include in plot. If None, plot all nhoods. log_counts: Whether to plot log1p of cell counts. {common_plot_args} + + Returns: + If `return_fig` is `True`, returns the figure, otherwise `None`. """ try: nhood_adata = mdata["milo"].T.copy() diff --git a/pertpy/tools/_mixscape.py b/pertpy/tools/_mixscape.py index 53577a2b..ade7ea49 100644 --- a/pertpy/tools/_mixscape.py +++ b/pertpy/tools/_mixscape.py @@ -513,16 +513,16 @@ def plot_barplot( # pragma: no cover self, adata: AnnData, guide_rna_column: str, + *, mixscape_class_global: str = "mixscape_class_global", axis_text_x_size: int = 8, axis_text_y_size: int = 6, axis_title_size: int = 8, legend_title_size: int = 8, legend_text_size: int = 8, - ax: Axes | None = None, show: bool = True, return_fig: bool = False, - ): + ) -> Figure | None: """Barplot to visualize perturbation scores calculated by the `mixscape` function. Args: @@ -614,6 +614,7 @@ def plot_heatmap( # pragma: no cover labels: str, target_gene: str, control: str, + *, layer: str | None = None, method: str | None = "wilcoxon", subsample_number: int | None = 900, @@ -622,7 +623,7 @@ def plot_heatmap( # pragma: no cover show: bool = True, return_fig: bool = False, **kwds, - ) -> Axes | None: + ) -> Figure | None: """Heatmap plot using mixscape results. Requires `pt.tl.mixscape()` to be run first. Args: @@ -639,7 +640,7 @@ def plot_heatmap( # pragma: no cover **kwds: Additional arguments to `scanpy.pl.rank_genes_groups_heatmap`. Returns: - If `return_fig` is `True`, return a :class:`~matplotlib.axes.Axes`. + If `return_fig` is `True`, returns the figure, otherwise `None`. Examples: >>> import pertpy as pt @@ -684,6 +685,7 @@ def plot_perturbscore( # pragma: no cover adata: AnnData, labels: str, target_gene: str, + *, mixscape_class: str = "mixscape_class", color: str = "orange", palette: dict[str, str] = None, @@ -856,6 +858,7 @@ def plot_violin( # pragma: no cover self, adata: AnnData, target_gene_idents: str | list[str], + *, keys: str | Sequence[str] = "mixscape_class_p_ko", groupby: str | None = "mixscape_class", log: bool = False, @@ -875,7 +878,7 @@ def plot_violin( # pragma: no cover show: bool = True, return_fig: bool = False, **kwargs, - ) -> Axes | None: + ) -> Axes | Figure | None: """Violin plot using mixscape results. Requires `pt.tl.mixscape` to be run first. @@ -897,7 +900,7 @@ def plot_violin( # pragma: no cover **kwargs: Additional arguments to `seaborn.violinplot`. Returns: - A :class:`~matplotlib.axes.Axes` object if `ax` is `None` else `None`. + If `return_fig` is `True`, returns the figure (as Axes list if it's a multi-panel plot), otherwise `None`. Examples: >>> import pertpy as pt @@ -1060,6 +1063,7 @@ def plot_lda( # pragma: no cover self, adata: AnnData, control: str, + *, mixscape_class: str = "mixscape_class", mixscape_class_global: str = "mixscape_class_global", perturbation_type: str | None = "KO", diff --git a/pertpy/tools/_perturbation_space/_simple.py b/pertpy/tools/_perturbation_space/_simple.py index 6a6db40c..a3cd954d 100644 --- a/pertpy/tools/_perturbation_space/_simple.py +++ b/pertpy/tools/_perturbation_space/_simple.py @@ -180,6 +180,7 @@ def plot_psbulk_samples( self, adata: AnnData, groupby: str, + *, show: bool = True, return_fig: bool = False, **kwargs, diff --git a/pertpy/tools/_scgen/_scgen.py b/pertpy/tools/_scgen/_scgen.py index 3168de01..eca887ee 100644 --- a/pertpy/tools/_scgen/_scgen.py +++ b/pertpy/tools/_scgen/_scgen.py @@ -381,6 +381,7 @@ def plot_reg_mean_plot( condition_key: str, axis_keys: dict[str, str], labels: dict[str, str], + *, gene_list: list[str] = None, top_100_genes: list[str] = None, verbose: bool = False, @@ -521,6 +522,7 @@ def plot_reg_var_plot( condition_key: str, axis_keys: dict[str, str], labels: dict[str, str], + *, gene_list: list[str] = None, top_100_genes: list[str] = None, legend: bool = True, @@ -650,10 +652,11 @@ def plot_binary_classifier( delta: np.ndarray, ctrl_key: str, stim_key: str, + *, fontsize: float = 14, show: bool = True, return_fig: bool = False, - ) -> plt.Axes | None: + ) -> Figure | None: """Plots the dot product between delta and latent representation of a linear classifier. Builds a linear classifier based on the dot product between @@ -670,6 +673,9 @@ def plot_binary_classifier( stim_key: Key for `stimulated` part of the `data` found in `condition_key`. fontsize: Set the font size of the plot. {common_plot_args} + + Returns: + If `return_fig` is `True`, returns the figure, otherwise `None`. """ plt.close("all") adata = scgen._validate_anndata(adata)