Skip to content

Commit

Permalink
fix SDE class bugs in utils and mnist notebooks (#87)
Browse files Browse the repository at this point in the history
Fixes SDE class bugs in utils and mnist notebooks.
Refactors the example folder to split the examples into different folders.
Adds two new tutorials (Flow Matching and minibatch OT)
  • Loading branch information
kilianFatras authored Dec 15, 2023
1 parent b29d418 commit 6b3adb4
Show file tree
Hide file tree
Showing 22 changed files with 3,205 additions and 1,892 deletions.
566 changes: 566 additions & 0 deletions examples/2D_tutorials/Flow_matching_tutorial.ipynb

Large diffs are not rendered by default.

228 changes: 228 additions & 0 deletions examples/2D_tutorials/SF2M_tutorial.ipynb

Large diffs are not rendered by default.

Large diffs are not rendered by default.

File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
This repository is used to reproduce the CIFAR-10 experiments from [1](https://arxiv.org/abs/2302.00482). We have designed a novel experimental procedure that helps us to reach an __FID of 3.5__ on the Cifar10 dataset.

<p align="center">
<img src="../../assets/169_generated_samples_otcfm.png" width="600"/>
<img src="../../../assets/169_generated_samples_otcfm.png" width="600"/>
</p>

To reproduce the experiments and save the weights, install the requirements from the main repository and then run (runs on a single RTX 2080 GPU):
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
470 changes: 470 additions & 0 deletions examples/images/conditional_mnist.ipynb

Large diffs are not rendered by default.

362 changes: 362 additions & 0 deletions examples/images/mnist_example.ipynb

Large diffs are not rendered by default.

201 changes: 0 additions & 201 deletions examples/notebooks/SF2M_2D_example.ipynb

This file was deleted.

467 changes: 0 additions & 467 deletions examples/notebooks/conditional_mnist.ipynb

This file was deleted.

363 changes: 0 additions & 363 deletions examples/notebooks/mnist_example.ipynb

This file was deleted.

831 changes: 0 additions & 831 deletions examples/notebooks/single-cell_example.ipynb

This file was deleted.

856 changes: 856 additions & 0 deletions examples/single_cell/single-cell_example.ipynb

Large diffs are not rendered by default.

28 changes: 0 additions & 28 deletions torchcfm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,31 +64,3 @@ def plot_trajectories(traj):
plt.xticks([])
plt.yticks([])
plt.show()


class SDE(torch.nn.Module):
noise_type = "diagonal"
sde_type = "ito"

def __init__(self, ode_drift, score, noise=1.0, reverse=False):
super().__init__()
self.drift = ode_drift
self.score = score
self.reverse = reverse
self.noise = noise

# Drift
def f(self, t, y):
if self.reverse:
t = 1 - t
if len(t.shape) == len(y.shape):
x = torch.cat([y, t], 1)
else:
x = torch.cat([y, t.repeat(y.shape[0])[:, None]], 1)
if self.reverse:
return -self.drift(x) + self.score(x)
return self.drift(x) + self.score(x)

# Diffusion
def g(self, t, y):
return torch.ones_like(t) * torch.ones_like(y) * self.noise
2 changes: 1 addition & 1 deletion torchcfm/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.0.5"
__version__ = "1.0.6"

0 comments on commit 6b3adb4

Please sign in to comment.