Skip to content

Commit

Permalink
pep 8
Browse files Browse the repository at this point in the history
  • Loading branch information
kilianFatras committed Dec 6, 2023
1 parent dfe2ceb commit b35cb15
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 25 deletions.
2 changes: 1 addition & 1 deletion examples/cifar10/train_cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def train(argv):
net_model = torch.nn.DataParallel(net_model)
ema_model = torch.nn.DataParallel(ema_model)

print("Training is using {} GPUs!".format(torch.cuda.device_count()))
print(f"Training is using {torch.cuda.device_count()} GPUs!")
# show model size
model_size = 0
for param in net_model.parameters():
Expand Down
24 changes: 0 additions & 24 deletions examples/cifar10/utils_cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,27 +38,3 @@ def infiniteloop(dataloader):
while True:
for x, y in iter(dataloader):
yield x


class SDE(torch.nn.Module):
noise_type = "diagonal"
sde_type = "ito"

def __init__(self, ode_drift, score, input_size=(3, 32, 32), reverse=False):
super().__init__()
self.drift = ode_drift
self.score = score
self.reverse = reverse

# Drift
def f(self, t, y):
y = y.view(-1, 3, 32, 32)
if self.reverse:
t = 1 - t
return -self.drift(t, y) + self.score(t, y)
return self.drift(t, y).flatten(start_dim=1) - self.score(t, y).flatten(start_dim=1)

# Diffusion
def g(self, t, y):
y = y.view(-1, 3, 32, 32)
return (torch.ones_like(t) * torch.ones_like(y)).flatten(start_dim=1) * sigma

0 comments on commit b35cb15

Please sign in to comment.