Skip to content

Commit

Permalink
pep8
Browse files Browse the repository at this point in the history
  • Loading branch information
kilianFatras committed Dec 14, 2023
1 parent 980e3d9 commit 2efadd2
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
7 changes: 4 additions & 3 deletions examples/notebooks/SF2M_2D_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,13 @@
" x0,\n",
" t_span=torch.linspace(0, 1, 100).to(device),\n",
" )\n",
" \n",
" \n",
"\n",
"\n",
"class SDE(torch.nn.Module):\n",
" noise_type = \"diagonal\"\n",
" sde_type = \"ito\"\n",
"\n",
" def __init__(self, ode_drift, score, input_size=(3, 32, 32), sigma=1.):\n",
" def __init__(self, ode_drift, score, input_size=(3, 32, 32), sigma=1.0):\n",
" super().__init__()\n",
" self.drift = ode_drift\n",
" self.score = score\n",
Expand All @@ -154,6 +154,7 @@
" def g(self, t, y):\n",
" return torch.ones_like(y) * self.sigma\n",
"\n",
"\n",
"sde = SDE(model, score_model, input_size=(2,), sigma=sigma)\n",
"with torch.no_grad():\n",
" sde_traj = torchsde.sdeint(\n",
Expand Down
3 changes: 2 additions & 1 deletion examples/notebooks/single-cell_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@
" noise_type = \"diagonal\"\n",
" sde_type = \"ito\"\n",
"\n",
" def __init__(self, ode_drift, score, input_size=(3, 32, 32), sigma=1.):\n",
" def __init__(self, ode_drift, score, input_size=(3, 32, 32), sigma=1.0):\n",
" super().__init__()\n",
" self.drift = ode_drift\n",
" self.score = score\n",
Expand All @@ -523,6 +523,7 @@
" def g(self, t, y):\n",
" return torch.ones_like(y) * self.sigma\n",
"\n",
"\n",
"sde = SDE(sf2m_model, sf2m_score_model, input_size=(2,), sigma=sigma)\n",
"with torch.no_grad():\n",
" sde_traj = torchsde.sdeint(\n",
Expand Down

0 comments on commit 2efadd2

Please sign in to comment.