From b432297a198cbd13b9801993b230ffa237165c2b Mon Sep 17 00:00:00 2001 From: ChairC <974833488@qq.com> Date: Thu, 12 Dec 2024 21:43:56 +0800 Subject: [PATCH] Add: Add use_gpu params. --- tools/generate.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tools/generate.py b/tools/generate.py index c051aab..75dc871 100644 --- a/tools/generate.py +++ b/tools/generate.py @@ -45,7 +45,7 @@ def __init__(self, gen_args, deploy=False): # Weight path self.weight_path = self.args.weight_path # Run device initializer - self.device = device_initializer() + self.device = device_initializer(device_id=self.args.use_gpu) # Enable conditional generation, sample type, network, image size, # number of classes and select activation function gen_results = generate_initializer(ckpt_path=self.weight_path, conditional=self.args.conditional, @@ -165,6 +165,8 @@ def init_generate_args(): parser.add_argument("--class_name", type=int, default=0) # classifier-free guidance interpolation weight, users can better generate model effect (recommend) parser.add_argument("--cfg_scale", type=int, default=3) + # Set the use GPU in generate (required) + parser.add_argument("--use_gpu", type=int, default=0) # =====================Older versions(version <= 1.1.1)===================== # Enable conditional generation (required)