Skip to content

Commit

Permalink
sorting function from other repo, config fields for prob and emb plots
Browse files Browse the repository at this point in the history
  • Loading branch information
geoffwoollard committed Jun 26, 2024
1 parent 6eff549 commit a71457e
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@
)


def sort_by_transport(cost):
m,n = cost.shape
_, transport = compute_wasserstein_between_distributions_from_weights_and_cost(np.ones(m) / m, np.ones(n)/n, cost)
indices = np.argsort((transport * np.arange(m)[...,None]).sum(0))
return cost[:,indices], indices, transport


def compute_wasserstein_between_distributions_from_weights_and_cost(
weights_a, weights_b, cost, numItermax=1000000
):
Expand Down
107 changes: 59 additions & 48 deletions tutorials/5_tutorial_plotting.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"from cryomethods_comparison_pipeline.distribution_to_distribution import sort_by_transport\n",
"from cryo_challenge._distribution_to_distribution.distribution_to_distribution import sort_by_transport\n",
"from cryo_challenge._ploting.plotting_utils import res_at_fsc_threshold\n",
"\n",
"from dataclasses import dataclass\n",
Expand Down Expand Up @@ -57,50 +57,60 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"path_to_config = FileChooser(os.path.expanduser(\"~\"))\n",
"display(path_to_config)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9242b69072f64782b69ff2b63aa23d53",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"FileChooser(path='/mnt/home/gwoollard', filename='', title='', show_hidden=False, select_desc='Select', change…"
"{'prob_submitted_plot': {'pkl_fnames': ['/mnt/home/gwoollard/ceph/repos/cryomethods_comparison_pipeline/results/distribution_to_distribution_20240416_npoolmicrostate10_norank_submission_0.pkl',\n",
" '/mnt/home/gwoollard/ceph/repos/cryomethods_comparison_pipeline/results/distribution_to_distribution_20240416_npoolmicrostate10_norank_submission_1.pkl',\n",
" '/mnt/home/gwoollard/ceph/repos/cryomethods_comparison_pipeline/results/distribution_to_distribution_20240416_npoolmicrostate10_norank_submission_2.pkl',\n",
" '/mnt/home/gwoollard/ceph/repos/cryomethods_comparison_pipeline/results/distribution_to_distribution_20240416_npoolmicrostate10_norank_submission_3.pkl',\n",
" '/mnt/home/gwoollard/ceph/repos/cryomethods_comparison_pipeline/results/distribution_to_distribution_20240416_npoolmicrostate10_norank_submission_4.pkl',\n",
" '/mnt/home/gwoollard/ceph/repos/cryomethods_comparison_pipeline/results/distribution_to_distribution_20240416_npoolmicrostate10_norank_submission_5.pkl',\n",
" '/mnt/home/gwoollard/ceph/repos/cryomethods_comparison_pipeline/results/distribution_to_distribution_20240416_npoolmicrostate10_norank_submission_6.pkl',\n",
" '/mnt/home/gwoollard/ceph/repos/cryomethods_comparison_pipeline/results/distribution_to_distribution_20240416_npoolmicrostate10_norank_submission_7.pkl',\n",
" '/mnt/home/gwoollard/ceph/repos/cryomethods_comparison_pipeline/results/distribution_to_distribution_20240416_npoolmicrostate10_norank_submission_8.pkl',\n",
" '/mnt/home/gwoollard/ceph/repos/cryomethods_comparison_pipeline/results/distribution_to_distribution_20240416_npoolmicrostate10_norank_submission_9.pkl',\n",
" '/mnt/home/gwoollard/ceph/repos/cryomethods_comparison_pipeline/results/distribution_to_distribution_20240416_npoolmicrostate10_norank_submission_10.pkl',\n",
" '/mnt/home/gwoollard/ceph/repos/cryomethods_comparison_pipeline/results/distribution_to_distribution_20240416_npoolmicrostate10_norank_submission_11.pkl']},\n",
" 'emd_plot': {'pkl_globs': ['/mnt/home/gwoollard/ceph/repos/cryomethods_comparison_pipeline/results/distribution_to_distribution_20240416_npoolmicrostate[0-9]_norank_submission_*.pkl',\n",
" '/mnt/home/gwoollard/ceph/repos/cryomethods_comparison_pipeline/results/distribution_to_distribution_20240416_npoolmicrostate[0-9][0-9]_norank_submission_*.pkl']}}"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "display_data"
"output_type": "execute_result"
}
],
"source": [
"path_to_config = FileChooser(os.path.expanduser(\"~\"))\n",
"display(path_to_config)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"@dataclass_json\n",
"@dataclass\n",
"class PlottingConfig:\n",
" gt_metadata: str\n",
" map2map_results: List[str]\n",
" dist2dist_results: Dict[str, List[str]]\n",
" dist2dist_results: Dict[str, Dict[str, List[str]]]\n",
"\n",
"with open(path_to_config.value, \"r\") as file:\n",
"with open(path_to_config, \"r\") as file:\n",
" config = yaml.safe_load(file)\n",
"config = PlottingConfig.from_dict(config)"
"config = PlottingConfig.from_dict(config)\n",
"config.dist2dist_results"
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -118,7 +128,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -136,7 +146,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 23,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -196,7 +206,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -215,26 +225,27 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"res_fsc_half, fraction_nyquist = res_at_fsc_threshold(fscs_sorted_d['Cookie Dough'], threshold=0.5)\n",
"units_Angstroms = 2 * 2.146 / (np.arange(1,112+1) / 112)"
"n_fourier_bins = fscs_sorted_d['Cookie Dough'].shape[-1]\n",
"units_Angstroms = 2 * 2.146 / (np.arange(1,n_fourier_bins+1) / n_fourier_bins)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.colorbar.Colorbar at 0x1552b08f1880>"
"<matplotlib.colorbar.Colorbar at 0x15533d547970>"
]
},
"execution_count": 10,
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
},
Expand All @@ -256,7 +267,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 30,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -286,10 +297,10 @@
"\n",
" for idx, (anonymous_label, fscs) in enumerate(fscs_sorted_d.items()):\n",
" # map2map_dist_matrix = data.iloc[gt_ordering].values\n",
" res_fsc_half, fraction_nyquist = res_at_fsc_threshold(fscs, threshold=0.5)\n",
" res_fsc_half, _ = res_at_fsc_threshold(fscs, threshold=0.5)\n",
" map2map_dist_matrix = units_Angstroms[res_fsc_half][gt_ordering]\n",
"\n",
" sorted_map2map_dist_matrix, indices, transport = sort_by_transport(map2map_dist_matrix)\n",
" sorted_map2map_dist_matrix, _, _ = sort_by_transport(map2map_dist_matrix)\n",
"\n",
"\n",
" ncols = 4\n",
Expand Down Expand Up @@ -323,19 +334,19 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"fname = config.dist2dist_results['pkl_fnames'][0]\n",
"fname = config.dist2dist_results['prob_submitted_plot']['pkl_fnames'][0]\n",
"\n",
"with open(fname, 'rb') as f:\n",
" data = pickle.load(f)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -349,12 +360,12 @@
" data_d[anonymous_label] = data\n",
" return data_d\n",
"\n",
"dist2dist_results_d = get_dist2dist_results(config.dist2dist_results['pkl_fnames'])"
"dist2dist_results_d = get_dist2dist_results(config.dist2dist_results['prob_submitted_plot']['pkl_fnames'])"
]
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 34,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -383,7 +394,7 @@
" y=0.95)\n",
" alpha = 0.05\n",
"\n",
" for idx_fname, (key,data) in enumerate(dist2dist_results_d.items()):\n",
" for idx_fname, (_,data) in enumerate(dist2dist_results_d.items()):\n",
" \n",
"\n",
" axes[idx_fname//ncols, idx_fname%ncols].plot(data['user_submitted_populations'], color='black', label='submited')\n",
Expand Down Expand Up @@ -421,7 +432,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -461,16 +472,16 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"df = wragle_pkl_to_dataframe(config.dist2dist_results['pkl_globs'])"
"df = wragle_pkl_to_dataframe(config.dist2dist_results['emd_plot']['pkl_globs'])"
]
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 37,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -505,9 +516,9 @@
" plt.figure(figsize=(plot_width, plot_height), dpi=300)\n",
"\n",
" # Create a scatter plot for each id\n",
" for i, id in enumerate(ids):\n",
" df_average_id = df_average[df_average['id'] == id]\n",
" sns.scatterplot(x='EMD_submitted_norm', y='EMD_opt_norm', data=df_average_id, alpha=alpha, marker=markers[i % len(markers)], label=id, s=marker_size)\n",
" for idx, id_label in enumerate(ids):\n",
" df_average_id = df_average[df_average['id'] == id_label]\n",
" sns.scatterplot(x='EMD_submitted_norm', y='EMD_opt_norm', data=df_average_id, alpha=alpha, marker=markers[idx % len(markers)], label=id_label, s=marker_size)\n",
"\n",
" plt.errorbar(x=df_average_and_error['EMD_submitted_norm'], \n",
" y=df_average_and_error['EMD_opt_norm'], \n",
Expand Down

0 comments on commit a71457e

Please sign in to comment.