Skip to content

Commit

Permalink
new models
Browse files Browse the repository at this point in the history
  • Loading branch information
manuelgloeckler committed Jun 6, 2024
1 parent 7b5e3e1 commit 240ffd6
Show file tree
Hide file tree
Showing 10 changed files with 1,025 additions and 68 deletions.
842 changes: 775 additions & 67 deletions docs/notebooks/fid/additional_models.ipynb

Large diffs are not rendered by default.

245 changes: 245 additions & 0 deletions docs/notebooks/fid/model_table.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np "
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
"results_fid = np.load('results_fid.npy', allow_pickle=True).item()"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
"def print_mean(means, scientific=True):\n",
" out = ''\n",
" i = np.argmin(means[:3])\n",
" j = np.argmin(means[3:]) + 3\n",
" for k,mean in enumerate(means):\n",
" \n",
" if k == i:\n",
" out += '$\\\\mathbf{'\n",
" elif k == j:\n",
" out += '$\\\\mathbf{'\n",
" else:\n",
" out += '$'\n",
" \n",
" if scientific:\n",
" out += f'{mean:.1e} ' # Scientific notation for numbers smaller than 1e-3\n",
" else:\n",
" out += f'{mean:.2f} '\n",
" \n",
" if k == i or k == j:\n",
" out += '}$ &'\n",
" else:\n",
" out += '$ &'\n",
" \n",
" out = out[:-2] # Remove the last ' & '\n",
" out += '\\\\\\\\' # Add '\\\\'\n",
" print(out)\n"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"means = []\n",
"stds = []\n",
"for key in results_fid.keys():\n",
" means.append(np.mean(results_fid[key]))\n",
" stds.append(np.std(results_fid[key]))"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"$\\mathbf{6.2e+00 }$ &$6.4e+00 $ &$7.0e+00 $ &$1.3e+01 $ &$1.7e+01 $ &$1.7e+01 $ &$\\mathbf{1.1e+01 }$ &$1.9e+01 $ &$1.3e+01 $ &$1.7e+01 $\\\\\n"
]
}
],
"source": [
"print_mean(means)"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"$2.3e-02 $ &$\\mathbf{2.2e-02 }$ &$2.5e-02 $ &$5.1e-02 $ &$5.6e-02 $ &$5.6e-02 $ &$\\mathbf{4.1e-02 }$ &$5.6e-02 $ &$5.0e-02 $ &$4.8e-02 $\\\\\n"
]
}
],
"source": [
"results_fid = np.load('results_sw.npy', allow_pickle=True).item()\n",
"means = []\n",
"stds = []\n",
"for key in results_fid.keys():\n",
" means.append(np.mean(results_fid[key]))\n",
" stds.append(np.std(results_fid[key]))\n",
"print_mean(means)"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"$6.6e-05 $ &$\\mathbf{6.2e-05 }$ &$8.5e-05 $ &$1.9e-04 $ &$2.1e-04 $ &$2.1e-04 $ &$\\mathbf{1.5e-04 }$ &$2.0e-04 $ &$1.9e-04 $ &$1.8e-04 $\\\\\n"
]
}
],
"source": [
"results_fid = np.load('results_mmd_rbf64.npy', allow_pickle=True).item()\n",
"means = []\n",
"stds = []\n",
"for key in results_fid.keys():\n",
" means.append(np.mean(results_fid[key]))\n",
" stds.append(np.std(results_fid[key]))\n",
"print_mean(means)"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"$2.5e-01 $ &$\\mathbf{2.4e-01 }$ &$3.2e-01 $ &$6.3e-01 $ &$6.9e-01 $ &$6.9e-01 $ &$\\mathbf{5.1e-01 }$ &$6.4e-01 $ &$6.5e-01 $ &$6.0e-01 $\\\\\n"
]
}
],
"source": [
"results_fid = np.load('results_mmd_lin.npy', allow_pickle=True).item()\n",
"means = []\n",
"stds = []\n",
"for key in results_fid.keys():\n",
" means.append(np.mean(results_fid[key]))\n",
" stds.append(np.std(results_fid[key]))\n",
"print_mean(means)"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"$\\mathbf{1.1e+04 }$ &$1.7e+04 $ &$1.6e+04 $ &$3.0e+04 $ &$3.6e+04 $ &$3.6e+04 $ &$\\mathbf{2.3e+04 }$ &$3.8e+04 $ &$3.4e+04 $ &$3.6e+04 $\\\\\n"
]
}
],
"source": [
"results_fid = np.load('results_mmd_poly_kid.npy', allow_pickle=True).item()\n",
"means = []\n",
"stds = []\n",
"for key in results_fid.keys():\n",
" means.append(np.mean(results_fid[key]))\n",
" stds.append(np.std(results_fid[key]))\n",
"print_mean(means)"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"$0.65 $ &$\\mathbf{0.63 }$ &$0.65 $ &$\\mathbf{0.75 }$ &$0.79 $ &$0.79 $ &$0.77 $ &$0.79 $ &$0.76 $ &$0.79 $\\\\\n"
]
}
],
"source": [
"results_fid = np.load('results_c2st_knn.npy', allow_pickle=True).item()\n",
"means = []\n",
"stds = []\n",
"for key in results_fid.keys():\n",
" means.append(np.mean(results_fid[key]))\n",
" stds.append(np.std(results_fid[key]))\n",
"print_mean(means, False)"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"$\\mathbf{0.72 }$ &$0.77 $ &$0.76 $ &$0.86 $ &$0.92 $ &$0.92 $ &$\\mathbf{0.85 }$ &$0.92 $ &$0.85 $ &$0.94 $\\\\\n"
]
}
],
"source": [
"results_fid = np.load('results_c2st_nn.npy', allow_pickle=True).item()\n",
"means = []\n",
"stds = []\n",
"for key in results_fid.keys():\n",
" means.append(np.mean(results_fid[key]))\n",
" stds.append(np.std(results_fid[key]))\n",
"print_mean(means, False)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "labproject",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Binary file added docs/notebooks/fid/results_c2st_knn.npy
Binary file not shown.
Binary file added docs/notebooks/fid/results_c2st_nn.npy
Binary file not shown.
Binary file added docs/notebooks/fid/results_fid.npy
Binary file not shown.
Binary file added docs/notebooks/fid/results_mmd_lin.npy
Binary file not shown.
Binary file added docs/notebooks/fid/results_mmd_poly_kid.npy
Binary file not shown.
Binary file added docs/notebooks/fid/results_mmd_rbf64.npy
Binary file not shown.
Binary file added docs/notebooks/fid/results_sw.npy
Binary file not shown.
6 changes: 5 additions & 1 deletion labproject/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def imagenet_uncond_embeddings(n=1000, d=2048):
# raise FileNotFoundError(f"No file `data/samples_50k_unconditional_moresteps_embeddings.pt` found")
# data = torch.load(os.path.join(data_dir, "samples_50k_unconditional_moresteps_embeddings.pt"))
@register_dataset("imagenet_unconditional_model_embedding")
def imagenet_unconditional_model_embedding(n, d=2048, device="cpu", save_path="data"):
def imagenet_unconditional_model_embedding(n, d=2048, device="cpu", save_path="data", permute=False):
r"""Get the unconditional model embeddings for ImageNet
Args:
Expand All @@ -484,6 +484,10 @@ def imagenet_unconditional_model_embedding(n, d=2048, device="cpu", save_path="d
quiet=False,
)
unconditional_embeddigns = torch.load("imagenet_unconditional_model_embedding.pt")

if permute:
idx = torch.randperm(unconditional_embeddigns.shape[0])
unconditional_embeddigns = unconditional_embeddigns[idx]

max_n = unconditional_embeddigns.shape[0]

Expand Down

0 comments on commit 240ffd6

Please sign in to comment.