Skip to content

Commit

Permalink
correct bugs from Torchdyn
Browse files Browse the repository at this point in the history
  • Loading branch information
kilianFatras committed Oct 2, 2023
1 parent 5cef086 commit dbeeb75
Show file tree
Hide file tree
Showing 9 changed files with 249 additions and 239 deletions.
33 changes: 12 additions & 21 deletions examples/notebooks/SF2M_2D_example.ipynb

Large diffs are not rendered by default.

67 changes: 44 additions & 23 deletions examples/notebooks/conditional_mnist.ipynb

Large diffs are not rendered by default.

80 changes: 39 additions & 41 deletions examples/notebooks/mnist_example.ipynb

Large diffs are not rendered by default.

65 changes: 31 additions & 34 deletions examples/notebooks/model-comparison-plotting.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
"cells": [
{
"cell_type": "markdown",
"id": "681547de-0a62-4688-9447-352310e1100b",
"metadata": {},
"source": [
"# Model Plotting\n",
Expand All @@ -12,8 +11,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"id": "21ee69f6-1bef-4323-b150-58e07145c698",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -34,8 +32,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4bcbdbcc-a7ce-46e1-9221-39d953c4dab9",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -170,7 +167,7 @@
" super().__init__()\n",
" self.model = model\n",
"\n",
" def forward(self, t, x):\n",
" def forward(self, t, x, *args, **kwargs):\n",
" return model(torch.cat([x, t.repeat(x.shape[0])[:, None]], 1))\n",
"\n",
"\n",
Expand All @@ -191,7 +188,7 @@
" self.trace_estimator = trace_estimator if trace_estimator is not None else autograd_trace\n",
" self.noise_dist, self.noise = noise_dist, None\n",
"\n",
" def forward(self, t, x):\n",
" def forward(self, t, x, *args, **kwargs):\n",
" with torch.set_grad_enabled(True):\n",
" x_in = x[:, 1:].requires_grad_(\n",
" True\n",
Expand All @@ -208,8 +205,7 @@
},
{
"cell_type": "code",
"execution_count": 65,
"id": "42f597ce-25d6-4127-acda-2c6d1dee9683",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -226,8 +222,7 @@
},
{
"cell_type": "code",
"execution_count": 68,
"id": "95ca0ef0-1043-4554-94b1-5cb71abf6c97",
"execution_count": 15,
"metadata": {
"tags": []
},
Expand Down Expand Up @@ -332,15 +327,14 @@
},
{
"cell_type": "code",
"execution_count": 74,
"id": "cc1261bf-1e05-43ae-b261-6492918d58f5",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_20114/591848956.py:6: DeprecationWarning: Starting with ImageIO v3 the behavior of this function will switch to that of iio.v3.imread. To keep the current behavior (and make this warning disappear) use `import imageio.v2 as imageio` or call `imageio.v2.imread` directly.\n",
"/tmp/ipykernel_8787/4293379450.py:6: DeprecationWarning: Starting with ImageIO v3 the behavior of this function will switch to that of iio.v3.imread. To keep the current behavior (and make this warning disappear) use `import imageio.v2 as imageio` or call `imageio.v2.imread` directly.\n",
" image = imageio.imread(filename)\n"
]
}
Expand All @@ -357,10 +351,24 @@
},
{
"cell_type": "code",
"execution_count": 4,
"id": "f3252186-11a0-48e0-87dd-0c04aca16a3e",
"execution_count": 17,
"metadata": {},
"outputs": [],
"outputs": [
{
"ename": "FileNotFoundError",
"evalue": "[Errno 2] No such file or directory: 'models/gaussian-moons/cfm_v1.pt'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[17], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m models \u001b[38;5;241m=\u001b[39m {\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCFM\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmodels/gaussian-moons/cfm_v1.pt\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m,\n\u001b[1;32m 3\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mOT-CFM (ours)\u001b[39m\u001b[38;5;124m\"\u001b[39m: torch\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodels/gaussian-moons/otcfm_v1.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 4\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSB-CFM (ours)\u001b[39m\u001b[38;5;124m\"\u001b[39m: torch\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodels/gaussian-moons/sbcfm_v1.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 5\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mVP-CFM\u001b[39m\u001b[38;5;124m\"\u001b[39m: torch\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodels/gaussian-moons/stochastic_interpolant_v1.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 6\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFM\u001b[39m\u001b[38;5;124m\"\u001b[39m: torch\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodels/gaussian-moons/flow_matching_v1.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 7\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mVP-SDE\u001b[39m\u001b[38;5;124m\"\u001b[39m: torch\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodels/gaussian-moons/vp_flow_v1.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 8\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAction-Matching\u001b[39m\u001b[38;5;124m\"\u001b[39m: torch\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodels/gaussian-moons/action_matching_v1.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 9\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAction-Matching (Swish)\u001b[39m\u001b[38;5;124m\"\u001b[39m: torch\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodels/gaussian-moons/action_matching_swish_v1.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 10\u001b[0m }\n",
"File \u001b[0;32m~/anaconda3/envs/torchcfm2/lib/python3.10/site-packages/torch/serialization.py:791\u001b[0m, in \u001b[0;36mload\u001b[0;34m(f, map_location, pickle_module, weights_only, **pickle_load_args)\u001b[0m\n\u001b[1;32m 788\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mencoding\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m pickle_load_args\u001b[38;5;241m.\u001b[39mkeys():\n\u001b[1;32m 789\u001b[0m pickle_load_args[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mencoding\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mutf-8\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m--> 791\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[43m_open_file_like\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mrb\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m opened_file:\n\u001b[1;32m 792\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _is_zipfile(opened_file):\n\u001b[1;32m 793\u001b[0m \u001b[38;5;66;03m# The zipfile reader is going to advance the current file position.\u001b[39;00m\n\u001b[1;32m 794\u001b[0m \u001b[38;5;66;03m# If we want to actually tail call to torch.jit.load, we need to\u001b[39;00m\n\u001b[1;32m 795\u001b[0m \u001b[38;5;66;03m# reset back to the original position.\u001b[39;00m\n\u001b[1;32m 796\u001b[0m orig_position \u001b[38;5;241m=\u001b[39m opened_file\u001b[38;5;241m.\u001b[39mtell()\n",
"File \u001b[0;32m~/anaconda3/envs/torchcfm2/lib/python3.10/site-packages/torch/serialization.py:271\u001b[0m, in \u001b[0;36m_open_file_like\u001b[0;34m(name_or_buffer, mode)\u001b[0m\n\u001b[1;32m 269\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_open_file_like\u001b[39m(name_or_buffer, mode):\n\u001b[1;32m 270\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _is_path(name_or_buffer):\n\u001b[0;32m--> 271\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_open_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43mname_or_buffer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 272\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 273\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mw\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;129;01min\u001b[39;00m mode:\n",
"File \u001b[0;32m~/anaconda3/envs/torchcfm2/lib/python3.10/site-packages/torch/serialization.py:252\u001b[0m, in \u001b[0;36m_open_file.__init__\u001b[0;34m(self, name, mode)\u001b[0m\n\u001b[1;32m 251\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, name, mode):\n\u001b[0;32m--> 252\u001b[0m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m)\u001b[49m)\n",
"\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'models/gaussian-moons/cfm_v1.pt'"
]
}
],
"source": [
"models = {\n",
" \"CFM\": torch.load(\"models/gaussian-moons/cfm_v1.pt\"),\n",
Expand All @@ -376,8 +384,7 @@
},
{
"cell_type": "code",
"execution_count": 79,
"id": "b8d00559-bcaf-4791-a11d-41ba948ef85b",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -483,19 +490,9 @@
},
{
"cell_type": "code",
"execution_count": 6,
"id": "00edb4ff-cb31-474f-88b3-a60952979a4e",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_21097/3442484490.py:7: DeprecationWarning: Starting with ImageIO v3 the behavior of this function will switch to that of iio.v3.imread. To keep the current behavior (and make this warning disappear) use `import imageio.v2 as imageio` or call `imageio.v2.imread` directly.\n",
" image = imageio.imread(filename)\n"
]
}
],
"outputs": [],
"source": [
"gif_name = \"gaussians-to-moons\"\n",
"ts = torch.linspace(0, 1, 101)\n",
Expand All @@ -510,9 +507,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "torchcfm2",
"language": "python",
"name": "python3"
"name": "torchcfm2"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -524,7 +521,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
"version": "3.10.13"
}
},
"nbformat": 4,
Expand Down
94 changes: 47 additions & 47 deletions examples/notebooks/single-cell_example.ipynb

Large diffs are not rendered by default.

142 changes: 71 additions & 71 deletions examples/notebooks/training-8gaussians-to-moons.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions torchcfm/models/unet/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,3 +920,6 @@ def __init__(
resblock_updown=resblock_updown,
use_new_attention_order=use_new_attention_order,
)

def forward(self, t, x, y=None, *args, **kwargs):
return super().forward(t, x, y=y)
2 changes: 1 addition & 1 deletion torchcfm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(self, model):
super().__init__()
self.model = model

def forward(self, t, x):
def forward(self, t, x, *args, **kwargs):
return self.model(torch.cat([x, t.repeat(x.shape[0])[:, None]], 1))


Expand Down
2 changes: 1 addition & 1 deletion torchcfm/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.0.2"
__version__ = "1.0.3"

0 comments on commit dbeeb75

Please sign in to comment.