diff --git a/examples/cifar10/train_cifar10.py b/examples/cifar10/train_cifar10.py index e24db2f..1ed30e1 100644 --- a/examples/cifar10/train_cifar10.py +++ b/examples/cifar10/train_cifar10.py @@ -108,8 +108,7 @@ def train(argv): net_model = torch.nn.DataParallel(net_model) ema_model = torch.nn.DataParallel(ema_model) - net_node = NeuralODE(net_model, solver="euler", sensitivity="adjoint") - ema_node = NeuralODE(ema_model, solver="euler", sensitivity="adjoint") + print("Training is using {} GPUs!".format(torch.cuda.device_count())) # show model size model_size = 0 for param in net_model.parameters(): @@ -151,8 +150,8 @@ def train(argv): # sample and Saving the weights if FLAGS.save_step > 0 and step % FLAGS.save_step == 0: - generate_samples(net_node, net_model, savedir, step, net_="normal") - generate_samples(ema_node, ema_model, savedir, step, net_="ema") + generate_samples(net_model, FLAGS.parallel, savedir, step, net_="normal") + generate_samples(ema_model, FLAGS.parallel, savedir, step, net_="ema") torch.save( { "net_model": net_model.state_dict(), diff --git a/examples/cifar10/utils_cifar.py b/examples/cifar10/utils_cifar.py index bc47cbb..d37c295 100644 --- a/examples/cifar10/utils_cifar.py +++ b/examples/cifar10/utils_cifar.py @@ -8,8 +8,11 @@ device = torch.device("cuda" if use_cuda else "cpu") -def generate_samples(node_, model, savedir, step, net_="normal"): +def generate_samples(model, parallel, savedir, step, net_="normal"): model.eval() + if parallel: + model = model.module.to(device) + node_ = NeuralODE(model, solver="euler", sensitivity="adjoint") with torch.no_grad(): traj = node_.trajectory( torch.randn(64, 3, 32, 32).to(device), @@ -35,3 +38,27 @@ def infiniteloop(dataloader): while True: for x, y in iter(dataloader): yield x + + +class SDE(torch.nn.Module): + noise_type = "diagonal" + sde_type = "ito" + + def __init__(self, ode_drift, score, input_size=(3, 32, 32), reverse=False): + super().__init__() + self.drift = ode_drift + self.score = 
score
+        self.reverse = reverse
+
+    # Drift
+    def f(self, t, y):
+        y = y.view(-1, 3, 32, 32)
+        if self.reverse:
+            t = 1 - t
+            return (-self.drift(t, y) + self.score(t, y)).flatten(start_dim=1)
+        return self.drift(t, y).flatten(start_dim=1) - self.score(t, y).flatten(start_dim=1)
+
+    # Diffusion
+    def g(self, t, y):
+        y = y.view(-1, 3, 32, 32)
+        return (torch.ones_like(t) * torch.ones_like(y)).flatten(start_dim=1) * sigma
\ No newline at end of file