diff --git a/tests/modules/test_layers.py b/tests/modules/test_layers.py
index d2d830c..0b9f414 100644
--- a/tests/modules/test_layers.py
+++ b/tests/modules/test_layers.py
@@ -56,7 +56,9 @@ def test_SNEmbedding(self):
         num_classes = 10
         X = torch.ones(self.N, dtype=torch.int64)
 
         for default in [True, False]:
-            layer = layers.SNEmbedding(num_classes, self.n_out, default=default)
+            layer = layers.SNEmbedding(num_classes,
+                                       self.n_out,
+                                       default=default)
 
             assert layer(X).shape == (self.N, self.n_out)
diff --git a/tests/training/test_scheduler.py b/tests/training/test_scheduler.py
index c395887..a84a6e5 100644
--- a/tests/training/test_scheduler.py
+++ b/tests/training/test_scheduler.py
@@ -36,12 +36,13 @@ def test_linear_decay(self):
                 assert abs(2e-4 - self.get_lr(optG)) < 1e-5
 
             else:
-                curr_lr = ((1 - (max(0, step - lr_scheduler.start_step) / (self.num_steps-lr_scheduler.start_step))) * self.lr_D)
+                curr_lr = ((1 - (max(0, step - lr_scheduler.start_step) /
+                                 (self.num_steps - lr_scheduler.start_step))) *
+                           self.lr_D)
 
                 assert abs(curr_lr - self.get_lr(optD)) < 1e-5
                 assert abs(curr_lr - self.get_lr(optG)) < 1e-5
 
-
     def test_no_decay(self):
         optD = optim.Adam(self.netD.parameters(), self.lr_D, betas=(0.0, 0.9))
         optG = optim.Adam(self.netG.parameters(), self.lr_G, betas=(0.0, 0.9))
@@ -60,8 +61,12 @@ def test_no_decay(self):
 
     def test_arguments(self):
         with pytest.raises(NotImplementedError):
-            optD = optim.Adam(self.netD.parameters(), self.lr_D, betas=(0.0, 0.9))
-            optG = optim.Adam(self.netG.parameters(), self.lr_G, betas=(0.0, 0.9))
+            optD = optim.Adam(self.netD.parameters(),
+                              self.lr_D,
+                              betas=(0.0, 0.9))
+            optG = optim.Adam(self.netG.parameters(),
+                              self.lr_G,
+                              betas=(0.0, 0.9))
             scheduler.LRScheduler(lr_decay='does_not_exist',
                                   optD=optD,
                                   optG=optG,
diff --git a/tests/training/test_trainer.py b/tests/training/test_trainer.py
index 012296e..406c6a7 100644
--- a/tests/training/test_trainer.py
+++ b/tests/training/test_trainer.py
@@ -143,19 +143,19 @@ def test_attributes(self):
 
         with pytest.raises(ValueError):
             bad_trainer = Trainer(netD=self.netD,
-                              netG=self.netG,
-                              optD=self.optD,
-                              optG=self.optG,
-                              netG_ckpt_file=netG_ckpt_file,
-                              netD_ckpt_file=netD_ckpt_file,
-                              log_dir=os.path.join(self.log_dir, 'extra'),
-                              dataloader=self.dataloader,
-                              num_steps=-1000,
-                              device=device,
-                              save_steps=float('inf'),
-                              log_steps=float('inf'),
-                              vis_steps=float('inf'),
-                              lr_decay='linear')
+                                  netG=self.netG,
+                                  optD=self.optD,
+                                  optG=self.optG,
+                                  netG_ckpt_file=netG_ckpt_file,
+                                  netD_ckpt_file=netD_ckpt_file,
+                                  log_dir=os.path.join(self.log_dir, 'extra'),
+                                  dataloader=self.dataloader,
+                                  num_steps=-1000,
+                                  device=device,
+                                  save_steps=float('inf'),
+                                  log_steps=float('inf'),
+                                  vis_steps=float('inf'),
+                                  lr_decay='linear')
 
     def test_get_latest_checkpoint(self):
         ckpt_files = [
diff --git a/torch_mimicry/nets/wgan_gp/wgan_gp_resblocks.py b/torch_mimicry/nets/wgan_gp/wgan_gp_resblocks.py
index 6f244d2..24b9702 100644
--- a/torch_mimicry/nets/wgan_gp/wgan_gp_resblocks.py
+++ b/torch_mimicry/nets/wgan_gp/wgan_gp_resblocks.py
@@ -88,7 +88,7 @@ def __init__(self,
             self.norm2 = None
 
     # TODO: Verify again. Interestingly, LN has no effect on FID. Not using LN
-    # has almost no difference in FID score. 
+    # has almost no difference in FID score.
     # def residual(self, x):
     #     r"""
    #     Helper function for feedforwarding through main layers.
diff --git a/torch_mimicry/training/scheduler.py b/torch_mimicry/training/scheduler.py
index 8d71618..8df1767 100644
--- a/torch_mimicry/training/scheduler.py
+++ b/torch_mimicry/training/scheduler.py
@@ -19,7 +19,13 @@ class LRScheduler:
         lr_D (float): The initial learning rate of optD.
         lr_G (float): The initial learning rate of optG.
     """
-    def __init__(self, lr_decay, optD, optG, num_steps, start_step=0, **kwargs):
+    def __init__(self,
+                 lr_decay,
+                 optD,
+                 optG,
+                 num_steps,
+                 start_step=0,
+                 **kwargs):
         if lr_decay not in [None, 'None', 'linear']:
             raise NotImplementedError(
                 "lr_decay {} is not currently supported.")
@@ -90,12 +96,14 @@ def step(self, log_data, global_step):
             lr_D = self.linear_decay(optimizer=self.optD,
                                      global_step=global_step,
                                      lr_value_range=(self.lr_D, 0.0),
-                                     lr_step_range=(self.start_step, self.num_steps))
+                                     lr_step_range=(self.start_step,
+                                                    self.num_steps))
 
             lr_G = self.linear_decay(optimizer=self.optG,
                                      global_step=global_step,
                                      lr_value_range=(self.lr_G, 0.0),
-                                     lr_step_range=(self.start_step, self.num_steps))
+                                     lr_step_range=(self.start_step,
+                                                    self.num_steps))
 
         elif self.lr_decay in [None, "None"]:
             lr_D = self.lr_D