Skip to content

Commit

Permalink
keep possible deprecated function commented
Browse files Browse the repository at this point in the history
  • Loading branch information
vinicvaz committed Jul 22, 2024
1 parent 30cea9e commit daa81ee
Showing 1 changed file with 38 additions and 38 deletions.
76 changes: 38 additions & 38 deletions src/vame/analysis/umap.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,44 +61,44 @@ def umap_embedding(cfg: dict, file: str, model_name: str, n_cluster: int, parame
return embed


def umap_vis_community_labels(cfg: dict, embed: np.ndarray, community_labels_all: np.ndarray, save_path: str | None) -> None:
"""Create plotly visualizaton of UMAP embedding with community labels.
Args:
cfg (dict): Configuration parameters.
embed (np.ndarray): UMAP embedding.
community_labels_all (np.ndarray): Community labels.
save_path: Path to save the plot. If None it will not save the plot.
Returns:
None
"""
num_points = cfg['num_points']
community_labels_all = np.asarray(community_labels_all)
if num_points > community_labels_all.shape[0]:
num_points = community_labels_all.shape[0]
logger.info("Embedding %d data points.." %num_points)

num = np.unique(community_labels_all)

fig = plt.figure(1)
plt.scatter(
embed[:,0],
embed[:,1],
c=community_labels_all[:num_points],
cmap='Spectral',
s=2,
alpha=1
)
plt.colorbar(boundaries=np.arange(np.max(num)+2)-0.5).set_ticks(np.arange(np.max(num)+1))
plt.gca().set_aspect('equal', 'datalim')
plt.grid(False)

if save_path is not None:
plt.savefig(save_path)
return fig
plt.show()
return fig
# def umap_vis_community_labels(cfg: dict, embed: np.ndarray, community_labels_all: np.ndarray, save_path: str | None) -> None:
# """Create plotly visualizaton of UMAP embedding with community labels.

# Args:
# cfg (dict): Configuration parameters.
# embed (np.ndarray): UMAP embedding.
# community_labels_all (np.ndarray): Community labels.
# save_path: Path to save the plot. If None it will not save the plot.

# Returns:
# None
# """
# num_points = cfg['num_points']
# community_labels_all = np.asarray(community_labels_all)
# if num_points > community_labels_all.shape[0]:
# num_points = community_labels_all.shape[0]
# logger.info("Embedding %d data points.." %num_points)

# num = np.unique(community_labels_all)

# fig = plt.figure(1)
# plt.scatter(
# embed[:,0],
# embed[:,1],
# c=community_labels_all[:num_points],
# cmap='Spectral',
# s=2,
# alpha=1
# )
# plt.colorbar(boundaries=np.arange(np.max(num)+2)-0.5).set_ticks(np.arange(np.max(num)+1))
# plt.gca().set_aspect('equal', 'datalim')
# plt.grid(False)

# if save_path is not None:
# plt.savefig(save_path)
# return fig
# plt.show()
# return fig


def umap_vis(embed: np.ndarray, num_points: int) -> None:
Expand Down

0 comments on commit daa81ee

Please sign in to comment.