From b7e833c57719b0ab6f0b4b491591ea6fdd184e6d Mon Sep 17 00:00:00 2001 From: ShouWenWang Date: Tue, 27 Dec 2022 16:41:04 -0500 Subject: [PATCH] black --- cospar/plotting/_gene.py | 6 +- cospar/plotting/_utils.py | 56 ++++++++++------ cospar/preprocessing/_preprocessing.py | 16 ++--- cospar/tool/_clone.py | 20 +++--- cospar/tool/_gene.py | 88 +++++++++++++++++--------- 5 files changed, 117 insertions(+), 69 deletions(-) diff --git a/cospar/plotting/_gene.py b/cospar/plotting/_gene.py index 564e080..f2a2d13 100644 --- a/cospar/plotting/_gene.py +++ b/cospar/plotting/_gene.py @@ -409,7 +409,7 @@ def gene_expression_heatmap( vmax=vmax, color_bar_label=color_bar_label, order_map_x=order_map_x, - x_tick_style='italic', + x_tick_style="italic", order_map_y=order_map_y, **kwargs, ) @@ -427,7 +427,7 @@ def gene_expression_heatmap( vmax=vmax, color_bar_label=color_bar_label, order_map_x=order_map_x, - y_tick_style='italic', + y_tick_style="italic", order_map_y=order_map_y, **kwargs, ) @@ -506,7 +506,7 @@ def gene_expression_on_manifold( point_size=point_size, color_bar_label="Normalized expression", ) - plt.title(g,style='italic') + plt.title(g, style="italic") plt.tight_layout() if savefig: diff --git a/cospar/plotting/_utils.py b/cospar/plotting/_utils.py index a48deac..a3acee0 100644 --- a/cospar/plotting/_utils.py +++ b/cospar/plotting/_utils.py @@ -214,12 +214,10 @@ def customized_embedding( if color_bar: norm = mpl_Normalize(vmin=vmin, vmax=vmax) - Clb = plt.colorbar( - plt.cm.ScalarMappable(norm=norm, cmap=color_map), - ax=ax) + Clb = plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=color_map), ax=ax) Clb.set_label( color_bar_label, - rotation=270, + rotation=270, labelpad=20, ) Clb.ax.set_title(color_bar_title) @@ -396,9 +394,9 @@ def heatmap( if ax is None: fig, ax = plt.subplots() - ax_=None + ax_ = None else: - ax_=ax + ax_ = ax ax.imshow( new_data, aspect="auto", @@ -438,10 +436,10 @@ def heatmap( print(f"y_ticks: ['{y_ticks_print}']") if x_label is not None: - ax.set_xlabel(x_label,style=x_label_style) + ax.set_xlabel(x_label, style=x_label_style) if y_label is not None: - ax.set_ylabel(y_label,style=y_label_style) + ax.set_ylabel(y_label, style=y_label_style) if color_bar: norm = mpl_Normalize(vmin=vmin, vmax=vmax) @@ -747,8 +745,10 @@ def layout(node): display(Image(filename=os.path.join(figure_path, f"{data_des}.png"))) - -def plot_adata_with_prefered_order(adata,obs_key,basis='X_umap',plot_order=None,palette=None,**kwargs): + +def plot_adata_with_prefered_order( + adata, obs_key, basis="X_umap", plot_order=None, palette=None, **kwargs +): """ An example code ```python @@ -760,15 +760,31 @@ def plot_adata_with_prefered_order(adata,obs_key,basis='X_umap',plot_order=None, ``` """ if plot_order is None: - plot_order=list(adata.obs[obs_key].unique()) + plot_order = list(adata.obs[obs_key].unique()) if palette is None: - palette=dict(zip(plot_order,np.array(sns.color_palette().as_hex())[:len(plot_order)])) - - df_fate_map=pd.DataFrame({obs_key:adata.obs[obs_key],'x':adata.obsm[basis][:,0],'y':adata.obsm[basis][:,1]}) - df_list=[] + palette = dict( + zip(plot_order, np.array(sns.color_palette().as_hex())[: len(plot_order)]) + ) + + df_fate_map = pd.DataFrame( + { + obs_key: adata.obs[obs_key], + "x": adata.obsm[basis][:, 0], + "y": adata.obsm[basis][:, 1], + } + ) + df_list = [] for z in plot_order: - df_list.append(df_fate_map[df_fate_map[obs_key]==z]) - - df_map_v2=pd.concat(df_list,ignore_index=True) - g=sns.relplot(kind='scatter',data=df_map_v2,x='x',y='y',hue=obs_key,palette=palette,**kwargs) - g.ax.axis('off'); \ No newline at end of file + df_list.append(df_fate_map[df_fate_map[obs_key] == z]) + + df_map_v2 = pd.concat(df_list, ignore_index=True) + g = sns.relplot( + kind="scatter", + data=df_map_v2, + x="x", + y="y", + hue=obs_key, + palette=palette, + **kwargs, + ) + g.ax.axis("off") diff --git a/cospar/preprocessing/_preprocessing.py b/cospar/preprocessing/_preprocessing.py index e36f199..f6f2cf3 100644 --- a/cospar/preprocessing/_preprocessing.py +++ b/cospar/preprocessing/_preprocessing.py @@ -550,7 +550,7 @@ def get_X_clone( def refine_state_info_by_leiden_clustering( adata, - selected_key='state_info', + selected_key="state_info", selected_values=None, resolution=0.5, n_neighbors=20, @@ -568,7 +568,7 @@ def refine_state_info_by_leiden_clustering( Parameters ---------- adata: :class:`~anndata.AnnData` object - selected_key: + selected_key: A key in adata.obs, including 'state_info', or 'time_info' selected_values: `list`, optional (default: include all) A list of clusters/time_points for further sub-clustering. Should be @@ -595,8 +595,8 @@ def refine_state_info_by_leiden_clustering( if selected_values == None: selected_values = available_time_points - if type(selected_values)==str: - selected_values=[selected_values] + if type(selected_values) == str: + selected_values = [selected_values] if np.sum(np.in1d(selected_values, available_time_points)) != len(selected_values): logg.error( @@ -635,7 +635,7 @@ def refine_state_info_by_marker_genes( adata, marker_genes, express_threshold=0.1, - selected_key='state_info', + selected_key="state_info", selected_values=None, new_cluster_name="new_cluster", confirm_change=False, @@ -663,7 +663,7 @@ def refine_state_info_by_marker_genes( Relative threshold of marker gene expression, in the range [0,1]. A state must have an expression above this threshold for all genes to be included. - selected_key: + selected_key: A key in adata.obs, including 'state_info', or 'time_info' selected_values: `list`, optional (default: include all) A list of clusters/time_points for further sub-clustering. Should be @@ -688,8 +688,8 @@ def refine_state_info_by_marker_genes( if selected_values == None: selected_values = available_time_points - if type(selected_values)==str: - selected_values=[selected_values] + if type(selected_values) == str: + selected_values = [selected_values] sp_idx = np.zeros(adata.shape[0], dtype=bool) for xx in selected_values: diff --git a/cospar/tool/_clone.py b/cospar/tool/_clone.py index eb0048d..2f66ffa 100644 --- a/cospar/tool/_clone.py +++ b/cospar/tool/_clone.py @@ -581,7 +581,7 @@ def filter_clones(adata, clone_size_threshold=2, filter_larger_clones=False): else: filter smaller clones """ - clone_size=adata.obsm['X_clone'].sum(0).A.flatten() + clone_size = adata.obsm["X_clone"].sum(0).A.flatten() if filter_larger_clones: clone_idx = clone_size < clone_size_threshold logg.info( @@ -596,8 +596,8 @@ def filter_clones(adata, clone_size_threshold=2, filter_larger_clones=False): X_clone_new = adata.obsm["X_clone"][:, clone_idx] adata.obsm["X_clone_old"] = adata.obsm["X_clone"] - if 'clone_id' not in adata.uns: - adata.uns['clone_id']=np.arange(adata.obsm['X_clone'].shape[1]) + if "clone_id" not in adata.uns: + adata.uns["clone_id"] = np.arange(adata.obsm["X_clone"].shape[1]) adata.uns["clone_id"] = np.array(adata.uns["clone_id"])[clone_idx] adata.obsm["X_clone"] = ssp.csr_matrix(X_clone_new) logg.info("Updated X_clone") @@ -608,10 +608,10 @@ def clone_statistics(adata, joint_variable="time_info"): """ Extract the number of clones and clonal cells for each time point """ - - if 'clone_id' not in adata.uns: - adata.uns['clone_id']=np.arange(adata.obsm['X_clone'].shape[1]) - adata.obs[joint_variable]=adata.obs[joint_variable].astype(str) + + if "clone_id" not in adata.uns: + adata.uns["clone_id"] = np.arange(adata.obsm["X_clone"].shape[1]) + adata.obs[joint_variable] = adata.obs[joint_variable].astype(str) df = ( pd.DataFrame(adata.obsm["X_clone"].A) @@ -748,14 +748,16 @@ def compute_sister_cell_distance( logg.info(np.sum(X_clone.sum(0) >= 2), " clones with >=2 cells selected") # observed distances - distance_list, selected_clone_idx = get_distance_within_each_clone(X_clone,norm_distance) + distance_list, selected_clone_idx = get_distance_within_each_clone( + X_clone, norm_distance + ) # randomized distances random_dis = [] random_dis_stat = [] for _ in tqdm(range(max_N_simutation)): np.random.shuffle(X_clone) - temp, __ = get_distance_within_each_clone(X_clone,norm_distance) + temp, __ = get_distance_within_each_clone(X_clone, norm_distance) random_dis += temp random_dis_stat.append( [np.mean(temp), np.min(temp), np.median(temp), np.max(temp)] diff --git a/cospar/tool/_gene.py b/cospar/tool/_gene.py index 60adc9e..cb72225 100644 --- a/cospar/tool/_gene.py +++ b/cospar/tool/_gene.py @@ -27,8 +27,8 @@ def differential_genes( cell_group_B=None, FDR_cutoff=0.05, sort_by="ratio", - min_frac_expr=0.05, - pseudocount=1 + min_frac_expr=0.05, + pseudocount=1, ): """ Perform differential gene expression analysis and plot top DGE genes. @@ -96,7 +96,13 @@ def differential_genes( else: - dge = hf.get_dge_SW(adata, selections[0], selections[1],min_frac_expr=min_frac_expr,pseudocount=pseudocount) + dge = hf.get_dge_SW( + adata, + selections[0], + selections[1], + min_frac_expr=min_frac_expr, + pseudocount=pseudocount, + ) dge = dge.sort_values(by=sort_by, ascending=False) diff_gene_A_0 = dge @@ -110,11 +116,22 @@ def differential_genes( return diff_gene_A, diff_gene_B -def identify_TF_and_surface_marker(gene_list,species='mouse',go_term_keywards=['cell surface','cell cycle','regulation of transcription','DNA-binding transcription factor activity','regulation of transcription by RNA polymerase II']): + +def identify_TF_and_surface_marker( + gene_list, + species="mouse", + go_term_keywards=[ + "cell surface", + "cell cycle", + "regulation of transcription", + "DNA-binding transcription factor activity", + "regulation of transcription by RNA polymerase II", + ], +): """ From an input gene list, return the go term and annotation for each gene, and further select the genes identified as TF or cell surface protein - + Returns ------ results: @@ -122,39 +139,52 @@ def identify_TF_and_surface_marker(gene_list,species='mouse',go_term_keywards=[' df_anno Only include genes identified as TF or cell surface protein """ - - if species not in ['mouse','human']: - raise ValueError('species must be either mouse or human') + + if species not in ["mouse", "human"]: + raise ValueError("species must be either mouse or human") else: - if species=='mouse': - dataset='mmusculus_gene_ensembl' - elif species=='human': - dataset='hsapiens_gene_ensembl' - - - + if species == "mouse": + dataset = "mmusculus_gene_ensembl" + elif species == "human": + dataset = "hsapiens_gene_ensembl" + from gseapy.parser import Biomart + bm = Biomart() ## view validated marts marts = bm.get_marts() ## view validated dataset - datasets = bm.get_datasets(mart='ENSEMBL_MART_ENSEMBL') + datasets = bm.get_datasets(mart="ENSEMBL_MART_ENSEMBL") ## view validated attributes - attrs = bm.get_attributes(dataset=dataset) #hsapiens_gene_ensembl: Human genes (GRCh38.p13); mmusculus_gene_ensembl for 'Mouse genes (GRCm39)' + attrs = bm.get_attributes( + dataset=dataset + ) # hsapiens_gene_ensembl: Human genes (GRCh38.p13); mmusculus_gene_ensembl for 'Mouse genes (GRCm39)' ## view validated filters - filters = bm.get_filters(dataset=dataset) #Gene Name(s) [e.g. MT-TF] + filters = bm.get_filters(dataset=dataset) # Gene Name(s) [e.g. MT-TF] ## query results - - results = bm.query(dataset=dataset, - attributes=['ensembl_gene_id', 'external_gene_name', 'namespace_1003','name_1006'], - filters={'external_gene_name': gene_list}) - results=results.dropna() - df_list=[] + results = bm.query( + dataset=dataset, + attributes=[ + "ensembl_gene_id", + "external_gene_name", + "namespace_1003", + "name_1006", + ], + filters={"external_gene_name": gene_list}, + ) + results = results.dropna() + df_list = [] for term in go_term_keywards: - tmp_genes=list(set(results[results['name_1006'].apply(lambda x: term == x)]['external_gene_name'])) - df_tmp=pd.DataFrame({'gene':tmp_genes}) - df_tmp['annotation']=term + tmp_genes = list( + set( + results[results["name_1006"].apply(lambda x: term == x)][ + "external_gene_name" + ] + ) + ) + df_tmp = pd.DataFrame({"gene": tmp_genes}) + df_tmp["annotation"] = term df_list.append(df_tmp) - df_anno=pd.concat(df_list,ignore_index=True) - return results,df_anno \ No newline at end of file + df_anno = pd.concat(df_list, ignore_index=True) + return results, df_anno