Skip to content

Commit

Permalink
make generate_image function work with parallel
Browse files Browse the repository at this point in the history
  • Loading branch information
kilianFatras committed Dec 6, 2023
1 parent 12a7f52 commit dfe2ceb
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 5 deletions.
7 changes: 3 additions & 4 deletions examples/cifar10/train_cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,7 @@ def train(argv):
net_model = torch.nn.DataParallel(net_model)
ema_model = torch.nn.DataParallel(ema_model)

net_node = NeuralODE(net_model, solver="euler", sensitivity="adjoint")
ema_node = NeuralODE(ema_model, solver="euler", sensitivity="adjoint")
print("Training is using {} GPUs!".format(torch.cuda.device_count()))
# show model size
model_size = 0
for param in net_model.parameters():
Expand Down Expand Up @@ -151,8 +150,8 @@ def train(argv):

# sample and Saving the weights
if FLAGS.save_step > 0 and step % FLAGS.save_step == 0:
generate_samples(net_node, net_model, savedir, step, net_="normal")
generate_samples(ema_node, ema_model, savedir, step, net_="ema")
generate_samples(net_model, FLAGS.parallel, savedir, step, net_="normal")
generate_samples(ema_model, FLAGS.parallel, savedir, step, net_="ema")
torch.save(
{
"net_model": net_model.state_dict(),
Expand Down
29 changes: 28 additions & 1 deletion examples/cifar10/utils_cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@
device = torch.device("cuda" if use_cuda else "cpu")


def generate_samples(node_, model, savedir, step, net_="normal"):
def generate_samples(model, parallel, savedir, step, net_="normal"):
model.eval()
if parallel:
model = model.module.to(device)
node_ = NeuralODE(model, solver="euler", sensitivity="adjoint")
with torch.no_grad():
traj = node_.trajectory(
torch.randn(64, 3, 32, 32).to(device),
Expand All @@ -35,3 +38,27 @@ def infiniteloop(dataloader):
while True:
for x, y in iter(dataloader):
yield x


class SDE(torch.nn.Module):
noise_type = "diagonal"
sde_type = "ito"

def __init__(self, ode_drift, score, input_size=(3, 32, 32), reverse=False):
super().__init__()
self.drift = ode_drift
self.score = score
self.reverse = reverse

# Drift
def f(self, t, y):
y = y.view(-1, 3, 32, 32)
if self.reverse:
t = 1 - t
return -self.drift(t, y) + self.score(t, y)
return self.drift(t, y).flatten(start_dim=1) - self.score(t, y).flatten(start_dim=1)

# Diffusion
def g(self, t, y):
y = y.view(-1, 3, 32, 32)
return (torch.ones_like(t) * torch.ones_like(y)).flatten(start_dim=1) * sigma

0 comments on commit dfe2ceb

Please sign in to comment.