Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
ShouWenWang committed Dec 27, 2022
1 parent 3f95e56 commit b7e833c
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 69 deletions.
6 changes: 3 additions & 3 deletions cospar/plotting/_gene.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def gene_expression_heatmap(
vmax=vmax,
color_bar_label=color_bar_label,
order_map_x=order_map_x,
x_tick_style='italic',
x_tick_style="italic",
order_map_y=order_map_y,
**kwargs,
)
Expand All @@ -427,7 +427,7 @@ def gene_expression_heatmap(
vmax=vmax,
color_bar_label=color_bar_label,
order_map_x=order_map_x,
y_tick_style='italic',
y_tick_style="italic",
order_map_y=order_map_y,
**kwargs,
)
Expand Down Expand Up @@ -506,7 +506,7 @@ def gene_expression_on_manifold(
point_size=point_size,
color_bar_label="Normalized expression",
)
plt.title(g,style='italic')
plt.title(g, style="italic")

plt.tight_layout()
if savefig:
Expand Down
56 changes: 36 additions & 20 deletions cospar/plotting/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,12 +214,10 @@ def customized_embedding(
if color_bar:

norm = mpl_Normalize(vmin=vmin, vmax=vmax)
Clb = plt.colorbar(
plt.cm.ScalarMappable(norm=norm, cmap=color_map),
ax=ax)
Clb = plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=color_map), ax=ax)
Clb.set_label(
color_bar_label,
rotation=270,
rotation=270,
labelpad=20,
)
Clb.ax.set_title(color_bar_title)
Expand Down Expand Up @@ -396,9 +394,9 @@ def heatmap(

if ax is None:
fig, ax = plt.subplots()
ax_=None
ax_ = None
else:
ax_=ax
ax_ = ax
ax.imshow(
new_data,
aspect="auto",
Expand Down Expand Up @@ -438,10 +436,10 @@ def heatmap(
print(f"y_ticks: ['{y_ticks_print}']")

if x_label is not None:
ax.set_xlabel(x_label,style=x_label_style)
ax.set_xlabel(x_label, style=x_label_style)

if y_label is not None:
ax.set_ylabel(y_label,style=y_label_style)
ax.set_ylabel(y_label, style=y_label_style)

if color_bar:
norm = mpl_Normalize(vmin=vmin, vmax=vmax)
Expand Down Expand Up @@ -747,8 +745,10 @@ def layout(node):

display(Image(filename=os.path.join(figure_path, f"{data_des}.png")))


def plot_adata_with_prefered_order(adata,obs_key,basis='X_umap',plot_order=None,palette=None,**kwargs):

def plot_adata_with_prefered_order(
adata, obs_key, basis="X_umap", plot_order=None, palette=None, **kwargs
):
"""
An example code
```python
Expand All @@ -760,15 +760,31 @@ def plot_adata_with_prefered_order(adata,obs_key,basis='X_umap',plot_order=None,
```
"""
if plot_order is None:
plot_order=list(adata.obs[obs_key].unique())
plot_order = list(adata.obs[obs_key].unique())
if palette is None:
palette=dict(zip(plot_order,np.array(sns.color_palette().as_hex())[:len(plot_order)]))

df_fate_map=pd.DataFrame({obs_key:adata.obs[obs_key],'x':adata.obsm[basis][:,0],'y':adata.obsm[basis][:,1]})
df_list=[]
palette = dict(
zip(plot_order, np.array(sns.color_palette().as_hex())[: len(plot_order)])
)

df_fate_map = pd.DataFrame(
{
obs_key: adata.obs[obs_key],
"x": adata.obsm[basis][:, 0],
"y": adata.obsm[basis][:, 1],
}
)
df_list = []
for z in plot_order:
df_list.append(df_fate_map[df_fate_map[obs_key]==z])

df_map_v2=pd.concat(df_list,ignore_index=True)
g=sns.relplot(kind='scatter',data=df_map_v2,x='x',y='y',hue=obs_key,palette=palette,**kwargs)
g.ax.axis('off');
df_list.append(df_fate_map[df_fate_map[obs_key] == z])

df_map_v2 = pd.concat(df_list, ignore_index=True)
g = sns.relplot(
kind="scatter",
data=df_map_v2,
x="x",
y="y",
hue=obs_key,
palette=palette,
**kwargs,
)
g.ax.axis("off")
16 changes: 8 additions & 8 deletions cospar/preprocessing/_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ def get_X_clone(

def refine_state_info_by_leiden_clustering(
adata,
selected_key='state_info',
selected_key="state_info",
selected_values=None,
resolution=0.5,
n_neighbors=20,
Expand All @@ -568,7 +568,7 @@ def refine_state_info_by_leiden_clustering(
Parameters
----------
adata: :class:`~anndata.AnnData` object
selected_key:
selected_key:
A key in adata.obs, including 'state_info', or 'time_info'
selected_values: `list`, optional (default: include all)
A list of clusters/time_points for further sub-clustering. Should be
Expand All @@ -595,8 +595,8 @@ def refine_state_info_by_leiden_clustering(

if selected_values == None:
selected_values = available_time_points
if type(selected_values)==str:
selected_values=[selected_values]
if type(selected_values) == str:
selected_values = [selected_values]

if np.sum(np.in1d(selected_values, available_time_points)) != len(selected_values):
logg.error(
Expand Down Expand Up @@ -635,7 +635,7 @@ def refine_state_info_by_marker_genes(
adata,
marker_genes,
express_threshold=0.1,
selected_key='state_info',
selected_key="state_info",
selected_values=None,
new_cluster_name="new_cluster",
confirm_change=False,
Expand Down Expand Up @@ -663,7 +663,7 @@ def refine_state_info_by_marker_genes(
Relative threshold of marker gene expression, in the range [0,1].
A state must have an expression above this threshold for all genes
to be included.
selected_key:
selected_key:
A key in adata.obs, including 'state_info', or 'time_info'
selected_values: `list`, optional (default: include all)
A list of clusters/time_points for further sub-clustering. Should be
Expand All @@ -688,8 +688,8 @@ def refine_state_info_by_marker_genes(

if selected_values == None:
selected_values = available_time_points
if type(selected_values)==str:
selected_values=[selected_values]
if type(selected_values) == str:
selected_values = [selected_values]

sp_idx = np.zeros(adata.shape[0], dtype=bool)
for xx in selected_values:
Expand Down
20 changes: 11 additions & 9 deletions cospar/tool/_clone.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ def filter_clones(adata, clone_size_threshold=2, filter_larger_clones=False):
else:
filter smaller clones
"""
clone_size=adata.obsm['X_clone'].sum(0).A.flatten()
clone_size = adata.obsm["X_clone"].sum(0).A.flatten()
if filter_larger_clones:
clone_idx = clone_size < clone_size_threshold
logg.info(
Expand All @@ -596,8 +596,8 @@ def filter_clones(adata, clone_size_threshold=2, filter_larger_clones=False):
X_clone_new = adata.obsm["X_clone"][:, clone_idx]

adata.obsm["X_clone_old"] = adata.obsm["X_clone"]
if 'clone_id' not in adata.uns:
adata.uns['clone_id']=np.arange(adata.obsm['X_clone'].shape[1])
if "clone_id" not in adata.uns:
adata.uns["clone_id"] = np.arange(adata.obsm["X_clone"].shape[1])
adata.uns["clone_id"] = np.array(adata.uns["clone_id"])[clone_idx]
adata.obsm["X_clone"] = ssp.csr_matrix(X_clone_new)
logg.info("Updated X_clone")
Expand All @@ -608,10 +608,10 @@ def clone_statistics(adata, joint_variable="time_info"):
"""
Extract the number of clones and clonal cells for each time point
"""
if 'clone_id' not in adata.uns:
adata.uns['clone_id']=np.arange(adata.obsm['X_clone'].shape[1])
adata.obs[joint_variable]=adata.obs[joint_variable].astype(str)

if "clone_id" not in adata.uns:
adata.uns["clone_id"] = np.arange(adata.obsm["X_clone"].shape[1])
adata.obs[joint_variable] = adata.obs[joint_variable].astype(str)

df = (
pd.DataFrame(adata.obsm["X_clone"].A)
Expand Down Expand Up @@ -748,14 +748,16 @@ def compute_sister_cell_distance(
logg.info(np.sum(X_clone.sum(0) >= 2), " clones with >=2 cells selected")

# observed distances
distance_list, selected_clone_idx = get_distance_within_each_clone(X_clone,norm_distance)
distance_list, selected_clone_idx = get_distance_within_each_clone(
X_clone, norm_distance
)

# randomized distances
random_dis = []
random_dis_stat = []
for _ in tqdm(range(max_N_simutation)):
np.random.shuffle(X_clone)
temp, __ = get_distance_within_each_clone(X_clone,norm_distance)
temp, __ = get_distance_within_each_clone(X_clone, norm_distance)
random_dis += temp
random_dis_stat.append(
[np.mean(temp), np.min(temp), np.median(temp), np.max(temp)]
Expand Down
88 changes: 59 additions & 29 deletions cospar/tool/_gene.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def differential_genes(
cell_group_B=None,
FDR_cutoff=0.05,
sort_by="ratio",
min_frac_expr=0.05,
pseudocount=1
min_frac_expr=0.05,
pseudocount=1,
):
"""
Perform differential gene expression analysis and plot top DGE genes.
Expand Down Expand Up @@ -96,7 +96,13 @@ def differential_genes(

else:

dge = hf.get_dge_SW(adata, selections[0], selections[1],min_frac_expr=min_frac_expr,pseudocount=pseudocount)
dge = hf.get_dge_SW(
adata,
selections[0],
selections[1],
min_frac_expr=min_frac_expr,
pseudocount=pseudocount,
)

dge = dge.sort_values(by=sort_by, ascending=False)
diff_gene_A_0 = dge
Expand All @@ -110,51 +116,75 @@ def differential_genes(

return diff_gene_A, diff_gene_B

def identify_TF_and_surface_marker(gene_list,species='mouse',go_term_keywards=['cell surface','cell cycle','regulation of transcription','DNA-binding transcription factor activity','regulation of transcription by RNA polymerase II']):

def identify_TF_and_surface_marker(
gene_list,
species="mouse",
go_term_keywards=[
"cell surface",
"cell cycle",
"regulation of transcription",
"DNA-binding transcription factor activity",
"regulation of transcription by RNA polymerase II",
],
):
"""
From an input gene list, return the go term and annotation for each gene,
and further select the genes identified as TF or cell surface protein
Returns
------
results:
Full annotation for each gene
df_anno
Only include genes identified as TF or cell surface protein
"""
if species not in ['mouse','human']:
raise ValueError('species must be either mouse or human')

if species not in ["mouse", "human"]:
raise ValueError("species must be either mouse or human")
else:
if species=='mouse':
dataset='mmusculus_gene_ensembl'
elif species=='human':
dataset='hsapiens_gene_ensembl'



if species == "mouse":
dataset = "mmusculus_gene_ensembl"
elif species == "human":
dataset = "hsapiens_gene_ensembl"

from gseapy.parser import Biomart

bm = Biomart()
## view validated marts
marts = bm.get_marts()
## view validated dataset
datasets = bm.get_datasets(mart='ENSEMBL_MART_ENSEMBL')
datasets = bm.get_datasets(mart="ENSEMBL_MART_ENSEMBL")
## view validated attributes
attrs = bm.get_attributes(dataset=dataset) #hsapiens_gene_ensembl: Human genes (GRCh38.p13); mmusculus_gene_ensembl for 'Mouse genes (GRCm39)'
attrs = bm.get_attributes(
dataset=dataset
) # hsapiens_gene_ensembl: Human genes (GRCh38.p13); mmusculus_gene_ensembl for 'Mouse genes (GRCm39)'
## view validated filters
filters = bm.get_filters(dataset=dataset) #Gene Name(s) [e.g. MT-TF]
filters = bm.get_filters(dataset=dataset) # Gene Name(s) [e.g. MT-TF]
## query results


results = bm.query(dataset=dataset,
attributes=['ensembl_gene_id', 'external_gene_name', 'namespace_1003','name_1006'],
filters={'external_gene_name': gene_list})
results=results.dropna()
df_list=[]
results = bm.query(
dataset=dataset,
attributes=[
"ensembl_gene_id",
"external_gene_name",
"namespace_1003",
"name_1006",
],
filters={"external_gene_name": gene_list},
)
results = results.dropna()
df_list = []
for term in go_term_keywards:
tmp_genes=list(set(results[results['name_1006'].apply(lambda x: term == x)]['external_gene_name']))
df_tmp=pd.DataFrame({'gene':tmp_genes})
df_tmp['annotation']=term
tmp_genes = list(
set(
results[results["name_1006"].apply(lambda x: term == x)][
"external_gene_name"
]
)
)
df_tmp = pd.DataFrame({"gene": tmp_genes})
df_tmp["annotation"] = term
df_list.append(df_tmp)
df_anno=pd.concat(df_list,ignore_index=True)
return results,df_anno
df_anno = pd.concat(df_list, ignore_index=True)
return results, df_anno

0 comments on commit b7e833c

Please sign in to comment.