diff --git a/torchbenchmark/models/pytorch_CycleGAN_and_pix2pix/__init__.py b/torchbenchmark/models/pytorch_CycleGAN_and_pix2pix/__init__.py index 524af4ec25..5b1bc3e5b1 100644 --- a/torchbenchmark/models/pytorch_CycleGAN_and_pix2pix/__init__.py +++ b/torchbenchmark/models/pytorch_CycleGAN_and_pix2pix/__init__.py @@ -35,7 +35,7 @@ def __init__(self, test, device, batch_size=None, extra_args=[]): results_arg = f"--results_dir {results_dir}" data_root = os.path.join(DATA_PATH, "pytorch_CycleGAN_and_pix2pix_inputs") device_arg = "" - device_type = f"--device_type {self.device}" + device_type_arg = f"--device_type {self.device}" if self.device == "cpu": device_arg = "--gpu_ids -1" else: @@ -43,10 +43,10 @@ def __init__(self, test, device, batch_size=None, extra_args=[]): if self.test == "train": train_args = f"--tb_device {self.device} --dataroot {data_root}/datasets/horse2zebra --name horse2zebra --model cycle_gan --display_id 0 --n_epochs 3 " + \ - f"--n_epochs_decay 3 {device_type} {device_arg} {checkpoints_arg}" + f"--n_epochs_decay 3 {device_type_arg} {device_arg} {checkpoints_arg}" self.training_loop = prepare_training_loop(train_args.split(' ')) args = f"--dataroot {data_root}/datasets/horse2zebra/testA --name horse2zebra_pretrained --model test " + \ - f"--no_dropout {device_type} {device_arg} {checkpoints_arg} {results_arg}" + f"--no_dropout {device_type_arg} {device_arg} {checkpoints_arg} {results_arg}" self.model, self.input = get_model(args, self.device) def get_module(self):