Skip to content

Commit

Permalink
update SDE class within notebooks and remove from utils
Browse files Browse the repository at this point in the history
  • Loading branch information
kilianFatras committed Dec 14, 2023
1 parent 91fd396 commit 980e3d9
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 103 deletions.
38 changes: 32 additions & 6 deletions examples/notebooks/SF2M_2D_example.ipynb

Large diffs are not rendered by default.

51 changes: 25 additions & 26 deletions examples/notebooks/conditional_mnist.ipynb

Large diffs are not rendered by default.

41 changes: 20 additions & 21 deletions examples/notebooks/mnist_example.ipynb

Large diffs are not rendered by default.

82 changes: 53 additions & 29 deletions examples/notebooks/single-cell_example.ipynb

Large diffs are not rendered by default.

21 changes: 0 additions & 21 deletions torchcfm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,24 +64,3 @@ def plot_trajectories(traj):
plt.xticks([])
plt.yticks([])
plt.show()


class SDE(torch.nn.Module):
noise_type = "diagonal"
sde_type = "ito"

def __init__(self, ode_drift, score, input_size=(3, 32, 32), sigma=0.1):
super().__init__()
self.drift = ode_drift
self.score = score
self.input_size = input_size
self.sigma = sigma

# Drift
def f(self, t, y):
y = y.view(-1, *self.input_size)
return self.drift(t, y).flatten(start_dim=1) + self.score(t, y).flatten(start_dim=1)

# Diffusion
def g(self, t, y):
return torch.ones_like(y) * self.sigma

0 comments on commit 980e3d9

Please sign in to comment.