Skip to content

Commit

Permalink
refine format
Browse files Browse the repository at this point in the history
  • Loading branch information
weishi-deng committed Jan 12, 2024
1 parent 796f272 commit ce0bd2e
Showing 1 changed file with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,18 @@ def __init__(self, test, device, batch_size=None, extra_args=[]):
results_arg = f"--results_dir {results_dir}"
data_root = os.path.join(DATA_PATH, "pytorch_CycleGAN_and_pix2pix_inputs")
device_arg = ""
device_type = f"--device_type {self.device}"
device_type_arg = f"--device_type {self.device}"
if self.device == "cpu":
device_arg = "--gpu_ids -1"
else:
device_arg = "--gpu_ids 0"

if self.test == "train":
train_args = f"--tb_device {self.device} --dataroot {data_root}/datasets/horse2zebra --name horse2zebra --model cycle_gan --display_id 0 --n_epochs 3 " + \
f"--n_epochs_decay 3 {device_type} {device_arg} {checkpoints_arg}"
f"--n_epochs_decay 3 {device_type_arg} {device_arg} {checkpoints_arg}"
self.training_loop = prepare_training_loop(train_args.split(' '))
args = f"--dataroot {data_root}/datasets/horse2zebra/testA --name horse2zebra_pretrained --model test " + \
f"--no_dropout {device_type} {device_arg} {checkpoints_arg} {results_arg}"
f"--no_dropout {device_type_arg} {device_arg} {checkpoints_arg} {results_arg}"
self.model, self.input = get_model(args, self.device)

def get_module(self):
Expand Down

0 comments on commit ce0bd2e

Please sign in to comment.