Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update single-cell notebook with trajectories from a given cell #64

Merged
merged 2 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 36 additions & 35 deletions examples/notebooks/model-comparison-plotting.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -205,7 +205,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -222,7 +222,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 4,
"metadata": {
"tags": []
},
Expand Down Expand Up @@ -327,14 +327,14 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_8787/4293379450.py:6: DeprecationWarning: Starting with ImageIO v3 the behavior of this function will switch to that of iio.v3.imread. To keep the current behavior (and make this warning disappear) use `import imageio.v2 as imageio` or call `imageio.v2.imread` directly.\n",
"/tmp/ipykernel_10858/4293379450.py:6: DeprecationWarning: Starting with ImageIO v3 the behavior of this function will switch to that of iio.v3.imread. To keep the current behavior (and make this warning disappear) use `import imageio.v2 as imageio` or call `imageio.v2.imread` directly.\n",
" image = imageio.imread(filename)\n"
]
}
Expand All @@ -351,40 +351,25 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 11,
"metadata": {},
"outputs": [
{
"ename": "FileNotFoundError",
"evalue": "[Errno 2] No such file or directory: 'models/gaussian-moons/cfm_v1.pt'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[17], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m models \u001b[38;5;241m=\u001b[39m {\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCFM\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmodels/gaussian-moons/cfm_v1.pt\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m,\n\u001b[1;32m 3\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mOT-CFM (ours)\u001b[39m\u001b[38;5;124m\"\u001b[39m: torch\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodels/gaussian-moons/otcfm_v1.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 4\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSB-CFM (ours)\u001b[39m\u001b[38;5;124m\"\u001b[39m: torch\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodels/gaussian-moons/sbcfm_v1.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 5\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mVP-CFM\u001b[39m\u001b[38;5;124m\"\u001b[39m: torch\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodels/gaussian-moons/stochastic_interpolant_v1.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 6\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFM\u001b[39m\u001b[38;5;124m\"\u001b[39m: torch\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodels/gaussian-moons/flow_matching_v1.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 7\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mVP-SDE\u001b[39m\u001b[38;5;124m\"\u001b[39m: torch\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodels/gaussian-moons/vp_flow_v1.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 8\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAction-Matching\u001b[39m\u001b[38;5;124m\"\u001b[39m: torch\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodels/gaussian-moons/action_matching_v1.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 9\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAction-Matching (Swish)\u001b[39m\u001b[38;5;124m\"\u001b[39m: torch\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodels/gaussian-moons/action_matching_swish_v1.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 10\u001b[0m }\n",
"File \u001b[0;32m~/anaconda3/envs/torchcfm2/lib/python3.10/site-packages/torch/serialization.py:791\u001b[0m, in \u001b[0;36mload\u001b[0;34m(f, map_location, pickle_module, weights_only, **pickle_load_args)\u001b[0m\n\u001b[1;32m 788\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mencoding\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m pickle_load_args\u001b[38;5;241m.\u001b[39mkeys():\n\u001b[1;32m 789\u001b[0m pickle_load_args[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mencoding\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mutf-8\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m--> 791\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[43m_open_file_like\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mrb\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m opened_file:\n\u001b[1;32m 792\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _is_zipfile(opened_file):\n\u001b[1;32m 793\u001b[0m \u001b[38;5;66;03m# The zipfile reader is going to advance the current file position.\u001b[39;00m\n\u001b[1;32m 794\u001b[0m \u001b[38;5;66;03m# If we want to actually tail call to torch.jit.load, we need to\u001b[39;00m\n\u001b[1;32m 795\u001b[0m \u001b[38;5;66;03m# reset back to the original position.\u001b[39;00m\n\u001b[1;32m 796\u001b[0m orig_position \u001b[38;5;241m=\u001b[39m opened_file\u001b[38;5;241m.\u001b[39mtell()\n",
"File \u001b[0;32m~/anaconda3/envs/torchcfm2/lib/python3.10/site-packages/torch/serialization.py:271\u001b[0m, in \u001b[0;36m_open_file_like\u001b[0;34m(name_or_buffer, mode)\u001b[0m\n\u001b[1;32m 269\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_open_file_like\u001b[39m(name_or_buffer, mode):\n\u001b[1;32m 270\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _is_path(name_or_buffer):\n\u001b[0;32m--> 271\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_open_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43mname_or_buffer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 272\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 273\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mw\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;129;01min\u001b[39;00m mode:\n",
"File \u001b[0;32m~/anaconda3/envs/torchcfm2/lib/python3.10/site-packages/torch/serialization.py:252\u001b[0m, in \u001b[0;36m_open_file.__init__\u001b[0;34m(self, name, mode)\u001b[0m\n\u001b[1;32m 251\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, name, mode):\n\u001b[0;32m--> 252\u001b[0m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m)\u001b[49m)\n",
"\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'models/gaussian-moons/cfm_v1.pt'"
]
}
],
"outputs": [],
"source": [
"models = {\n",
" \"CFM\": torch.load(\"models/gaussian-moons/cfm_v1.pt\"),\n",
" \"OT-CFM (ours)\": torch.load(\"models/gaussian-moons/otcfm_v1.pt\"),\n",
" \"SB-CFM (ours)\": torch.load(\"models/gaussian-moons/sbcfm_v1.pt\"),\n",
" \"VP-CFM\": torch.load(\"models/gaussian-moons/stochastic_interpolant_v1.pt\"),\n",
" \"FM\": torch.load(\"models/gaussian-moons/flow_matching_v1.pt\"),\n",
" \"VP-SDE\": torch.load(\"models/gaussian-moons/vp_flow_v1.pt\"),\n",
" \"Action-Matching\": torch.load(\"models/gaussian-moons/action_matching_v1.pt\"),\n",
" \"Action-Matching (Swish)\": torch.load(\"models/gaussian-moons/action_matching_swish_v1.pt\"),\n",
" \"CFM\": torch.load(\"./models/8gaussian-moons/cfm_v1.pt\"),\n",
" \"OT-CFM (ours)\": torch.load(\"models/8gaussian-moons/otcfm_v1.pt\"),\n",
" \"SB-CFM (ours)\": torch.load(\"models/8gaussian-moons/sbcfm_v1.pt\"),\n",
" \"VP-CFM\": torch.load(\"models/8gaussian-moons/stochastic_interpolant_v1.pt\"),\n",
" # \"FM\": torch.load(\"models/8gaussian-moons/flow_matching_v1.pt\"),\n",
" # \"VP-SDE\": torch.load(\"models/8gaussian-moons/vp_flow_v1.pt\"),\n",
" \"Action-Matching\": torch.load(\"models/8gaussian-moons/action_matching_v1.pt\"),\n",
" \"Action-Matching (Swish)\": torch.load(\"models/8gaussian-moons/action_matching_swish_v1.pt\"),\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -411,8 +396,8 @@
" traj = nde.trajectory(sample.to(device), t_span=ts.to(device)).detach().cpu().numpy()\n",
" trajs[name] = traj\n",
"names = [\n",
" \"VP-SDE\",\n",
" \"FM\",\n",
" # \"VP-SDE\",\n",
" # \"FM\",\n",
" \"CFM\",\n",
" \"Action-Matching\",\n",
" \"Action-Matching (Swish)\",\n",
Expand Down Expand Up @@ -490,9 +475,18 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 16,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_10858/1452409276.py:7: DeprecationWarning: Starting with ImageIO v3 the behavior of this function will switch to that of iio.v3.imread. To keep the current behavior (and make this warning disappear) use `import imageio.v2 as imageio` or call `imageio.v2.imread` directly.\n",
" image = imageio.imread(filename)\n"
]
}
],
"source": [
"gif_name = \"gaussians-to-moons\"\n",
"ts = torch.linspace(0, 1, 101)\n",
Expand All @@ -503,6 +497,13 @@
" image = imageio.imread(filename)\n",
" writer.append_data(image)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
412 changes: 340 additions & 72 deletions examples/notebooks/single-cell_example.ipynb

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions runner/src/datamodules/distribution_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@
from pytorch_lightning import LightningDataModule
from pytorch_lightning.trainer.supporters import CombinedLoader
from sklearn.preprocessing import StandardScaler
from src import utils
from torch.utils.data import DataLoader, Sampler, TensorDataset, random_split
from torchdyn.datasets import ToyDataset

from src import utils

from .components.base import BaseLightningDataModule
from .components.time_dataset import load_dataset
from .components.tnet_dataset import SCData
Expand Down
1 change: 0 additions & 1 deletion runner/src/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
from omegaconf import DictConfig
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.loggers import LightningLoggerBase

from src import utils

log = utils.get_pylogger(__name__)
Expand Down
1 change: 0 additions & 1 deletion runner/src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
from omegaconf import DictConfig
from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer
from pytorch_lightning.loggers import LightningLoggerBase

from src import utils

log = utils.get_pylogger(__name__)
Expand Down
Loading