From b35cb1524c6b53819da4ba1e8618b1f2e1ff90ff Mon Sep 17 00:00:00 2001 From: Kilian Fatras Date: Wed, 6 Dec 2023 11:14:47 -0500 Subject: [PATCH] pep 8 --- examples/cifar10/train_cifar10.py | 2 +- examples/cifar10/utils_cifar.py | 24 ------------------------ 2 files changed, 1 insertion(+), 25 deletions(-) diff --git a/examples/cifar10/train_cifar10.py b/examples/cifar10/train_cifar10.py index 1ed30e1..97cccc0 100644 --- a/examples/cifar10/train_cifar10.py +++ b/examples/cifar10/train_cifar10.py @@ -108,7 +108,7 @@ def train(argv): net_model = torch.nn.DataParallel(net_model) ema_model = torch.nn.DataParallel(ema_model) - print("Training is using {} GPUs!".format(torch.cuda.device_count())) + print(f"Training is using {torch.cuda.device_count()} GPUs!") # show model size model_size = 0 for param in net_model.parameters(): diff --git a/examples/cifar10/utils_cifar.py b/examples/cifar10/utils_cifar.py index d37c295..4afbb90 100644 --- a/examples/cifar10/utils_cifar.py +++ b/examples/cifar10/utils_cifar.py @@ -38,27 +38,3 @@ def infiniteloop(dataloader): while True: for x, y in iter(dataloader): yield x - - -class SDE(torch.nn.Module): - noise_type = "diagonal" - sde_type = "ito" - - def __init__(self, ode_drift, score, input_size=(3, 32, 32), reverse=False): - super().__init__() - self.drift = ode_drift - self.score = score - self.reverse = reverse - - # Drift - def f(self, t, y): - y = y.view(-1, 3, 32, 32) - if self.reverse: - t = 1 - t - return -self.drift(t, y) + self.score(t, y) - return self.drift(t, y).flatten(start_dim=1) - self.score(t, y).flatten(start_dim=1) - - # Diffusion - def g(self, t, y): - y = y.view(-1, 3, 32, 32) - return (torch.ones_like(t) * torch.ones_like(y)).flatten(start_dim=1) * sigma \ No newline at end of file