Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Torch 2.0 compatible with Torchdyn #61

Merged
merged 1 commit into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 12 additions & 21 deletions examples/notebooks/SF2M_2D_example.ipynb

Large diffs are not rendered by default.

67 changes: 44 additions & 23 deletions examples/notebooks/conditional_mnist.ipynb

Large diffs are not rendered by default.

80 changes: 39 additions & 41 deletions examples/notebooks/mnist_example.ipynb

Large diffs are not rendered by default.

65 changes: 31 additions & 34 deletions examples/notebooks/model-comparison-plotting.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
"cells": [
{
"cell_type": "markdown",
"id": "681547de-0a62-4688-9447-352310e1100b",
"metadata": {},
"source": [
"# Model Plotting\n",
Expand All @@ -12,8 +11,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"id": "21ee69f6-1bef-4323-b150-58e07145c698",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -34,8 +32,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4bcbdbcc-a7ce-46e1-9221-39d953c4dab9",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -170,7 +167,7 @@
" super().__init__()\n",
" self.model = model\n",
"\n",
" def forward(self, t, x):\n",
" def forward(self, t, x, *args, **kwargs):\n",
" return model(torch.cat([x, t.repeat(x.shape[0])[:, None]], 1))\n",
"\n",
"\n",
Expand All @@ -191,7 +188,7 @@
" self.trace_estimator = trace_estimator if trace_estimator is not None else autograd_trace\n",
" self.noise_dist, self.noise = noise_dist, None\n",
"\n",
" def forward(self, t, x):\n",
" def forward(self, t, x, *args, **kwargs):\n",
" with torch.set_grad_enabled(True):\n",
" x_in = x[:, 1:].requires_grad_(\n",
" True\n",
Expand All @@ -208,8 +205,7 @@
},
{
"cell_type": "code",
"execution_count": 65,
"id": "42f597ce-25d6-4127-acda-2c6d1dee9683",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -226,8 +222,7 @@
},
{
"cell_type": "code",
"execution_count": 68,
"id": "95ca0ef0-1043-4554-94b1-5cb71abf6c97",
"execution_count": 15,
"metadata": {
"tags": []
},
Expand Down Expand Up @@ -332,15 +327,14 @@
},
{
"cell_type": "code",
"execution_count": 74,
"id": "cc1261bf-1e05-43ae-b261-6492918d58f5",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_20114/591848956.py:6: DeprecationWarning: Starting with ImageIO v3 the behavior of this function will switch to that of iio.v3.imread. To keep the current behavior (and make this warning disappear) use `import imageio.v2 as imageio` or call `imageio.v2.imread` directly.\n",
"/tmp/ipykernel_8787/4293379450.py:6: DeprecationWarning: Starting with ImageIO v3 the behavior of this function will switch to that of iio.v3.imread. To keep the current behavior (and make this warning disappear) use `import imageio.v2 as imageio` or call `imageio.v2.imread` directly.\n",
" image = imageio.imread(filename)\n"
]
}
Expand All @@ -357,10 +351,24 @@
},
{
"cell_type": "code",
"execution_count": 4,
"id": "f3252186-11a0-48e0-87dd-0c04aca16a3e",
"execution_count": 17,
"metadata": {},
"outputs": [],
"outputs": [
{
"ename": "FileNotFoundError",
"evalue": "[Errno 2] No such file or directory: 'models/gaussian-moons/cfm_v1.pt'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[17], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m models \u001b[38;5;241m=\u001b[39m {\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCFM\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmodels/gaussian-moons/cfm_v1.pt\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m,\n\u001b[1;32m 3\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mOT-CFM (ours)\u001b[39m\u001b[38;5;124m\"\u001b[39m: torch\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodels/gaussian-moons/otcfm_v1.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 4\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSB-CFM (ours)\u001b[39m\u001b[38;5;124m\"\u001b[39m: torch\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodels/gaussian-moons/sbcfm_v1.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 5\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mVP-CFM\u001b[39m\u001b[38;5;124m\"\u001b[39m: torch\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodels/gaussian-moons/stochastic_interpolant_v1.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 6\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFM\u001b[39m\u001b[38;5;124m\"\u001b[39m: torch\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodels/gaussian-moons/flow_matching_v1.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 7\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mVP-SDE\u001b[39m\u001b[38;5;124m\"\u001b[39m: torch\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodels/gaussian-moons/vp_flow_v1.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 8\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAction-Matching\u001b[39m\u001b[38;5;124m\"\u001b[39m: torch\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodels/gaussian-moons/action_matching_v1.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 9\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAction-Matching (Swish)\u001b[39m\u001b[38;5;124m\"\u001b[39m: torch\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodels/gaussian-moons/action_matching_swish_v1.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 10\u001b[0m }\n",
"File \u001b[0;32m~/anaconda3/envs/torchcfm2/lib/python3.10/site-packages/torch/serialization.py:791\u001b[0m, in \u001b[0;36mload\u001b[0;34m(f, map_location, pickle_module, weights_only, **pickle_load_args)\u001b[0m\n\u001b[1;32m 788\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mencoding\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m pickle_load_args\u001b[38;5;241m.\u001b[39mkeys():\n\u001b[1;32m 789\u001b[0m pickle_load_args[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mencoding\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mutf-8\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m--> 791\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[43m_open_file_like\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mrb\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m opened_file:\n\u001b[1;32m 792\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _is_zipfile(opened_file):\n\u001b[1;32m 793\u001b[0m \u001b[38;5;66;03m# The zipfile reader is going to advance the current file position.\u001b[39;00m\n\u001b[1;32m 794\u001b[0m \u001b[38;5;66;03m# If we want to actually tail call to torch.jit.load, we need to\u001b[39;00m\n\u001b[1;32m 795\u001b[0m \u001b[38;5;66;03m# reset back to the original position.\u001b[39;00m\n\u001b[1;32m 796\u001b[0m orig_position \u001b[38;5;241m=\u001b[39m opened_file\u001b[38;5;241m.\u001b[39mtell()\n",
"File \u001b[0;32m~/anaconda3/envs/torchcfm2/lib/python3.10/site-packages/torch/serialization.py:271\u001b[0m, in \u001b[0;36m_open_file_like\u001b[0;34m(name_or_buffer, mode)\u001b[0m\n\u001b[1;32m 269\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_open_file_like\u001b[39m(name_or_buffer, mode):\n\u001b[1;32m 270\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _is_path(name_or_buffer):\n\u001b[0;32m--> 271\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_open_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43mname_or_buffer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 272\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 273\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mw\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;129;01min\u001b[39;00m mode:\n",
"File \u001b[0;32m~/anaconda3/envs/torchcfm2/lib/python3.10/site-packages/torch/serialization.py:252\u001b[0m, in \u001b[0;36m_open_file.__init__\u001b[0;34m(self, name, mode)\u001b[0m\n\u001b[1;32m 251\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, name, mode):\n\u001b[0;32m--> 252\u001b[0m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m)\u001b[49m)\n",
"\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'models/gaussian-moons/cfm_v1.pt'"
]
}
],
"source": [
"models = {\n",
" \"CFM\": torch.load(\"models/gaussian-moons/cfm_v1.pt\"),\n",
Expand All @@ -376,8 +384,7 @@
},
{
"cell_type": "code",
"execution_count": 79,
"id": "b8d00559-bcaf-4791-a11d-41ba948ef85b",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -483,19 +490,9 @@
},
{
"cell_type": "code",
"execution_count": 6,
"id": "00edb4ff-cb31-474f-88b3-a60952979a4e",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_21097/3442484490.py:7: DeprecationWarning: Starting with ImageIO v3 the behavior of this function will switch to that of iio.v3.imread. To keep the current behavior (and make this warning disappear) use `import imageio.v2 as imageio` or call `imageio.v2.imread` directly.\n",
" image = imageio.imread(filename)\n"
]
}
],
"outputs": [],
"source": [
"gif_name = \"gaussians-to-moons\"\n",
"ts = torch.linspace(0, 1, 101)\n",
Expand All @@ -510,9 +507,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "torchcfm2",
"language": "python",
"name": "python3"
"name": "torchcfm2"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -524,7 +521,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
"version": "3.10.13"
}
},
"nbformat": 4,
Expand Down
94 changes: 47 additions & 47 deletions examples/notebooks/single-cell_example.ipynb

Large diffs are not rendered by default.

142 changes: 71 additions & 71 deletions examples/notebooks/training-8gaussians-to-moons.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions torchcfm/models/unet/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,3 +920,6 @@ def __init__(
resblock_updown=resblock_updown,
use_new_attention_order=use_new_attention_order,
)

def forward(self, t, x, y=None, *args, **kwargs):
return super().forward(t, x, y=y)
2 changes: 1 addition & 1 deletion torchcfm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(self, model):
super().__init__()
self.model = model

def forward(self, t, x):
def forward(self, t, x, *args, **kwargs):
return self.model(torch.cat([x, t.repeat(x.shape[0])[:, None]], 1))


Expand Down
2 changes: 1 addition & 1 deletion torchcfm/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.0.2"
__version__ = "1.0.3"