Skip to content

Commit

Permalink
feat: add sampling options to test
Browse files Browse the repository at this point in the history
  • Loading branch information
royale authored and beniz committed Aug 28, 2023
1 parent 911fc20 commit a2958dc
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 2 deletions.
2 changes: 2 additions & 0 deletions models/palette_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,8 @@ def inference(self):
mask=mask,
sample_num=self.sample_num,
cls=cls,
ddim_num_steps=self.ddim_num_steps,
ddim_eta=self.ddim_eta,
)
self.fake_B = self.output

Expand Down
39 changes: 37 additions & 2 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
from models import create_model
from util.parser import get_opt
from util.util import MAX_INT
from models.modules.diffusion_utils import set_new_noise_schedule


def launch_testing(opt):
def launch_testing(opt, main_opt):
rank = 0

opt.jg_dir = os.path.join("/".join(__file__.split("/")[:-1]))
Expand Down Expand Up @@ -47,6 +48,20 @@ def launch_testing(opt):

model = create_model(opt, rank) # create a model given opt.model and other options
model.setup(opt) # regular setup: load and print networks; create schedulers

# sampling options
if main_opt.sampling_steps is not None:
model.netG_A.denoise_fn.model.beta_schedule["test"][
"n_timestep"
] = main_opt.sampling_steps
set_new_noise_schedule(model.netG_A.denoise_fn.model, "test")
if main_opt.sampling_method is not None:
model.netG_A.set_new_sampling_method(main_opt.sampling_method)
if main_opt.ddim_num_steps is not None:
model.ddim_num_steps = main_opt.ddim_num_steps
if main_opt.ddim_eta is not None:
model.ddim_eta = main_opt.ddim_eta

model.use_temporal = use_temporal
model.eval()
if opt.use_cuda:
Expand Down Expand Up @@ -106,6 +121,25 @@ def launch_testing(opt):
main_parser.add_argument(
"--test_seed", type=int, default=42, help="seed to use for tests"
)
main_parser.add_argument(
"--sampling_steps", type=int, help="number of sampling steps"
)
main_parser.add_argument(
"--sampling_method",
type=str,
choices=["ddpm", "ddim"],
help="choose the sampling method between ddpm and ddim",
)
main_parser.add_argument(
"--ddim_num_steps",
type=int,
help="number of steps for ddim sampling method",
)
main_parser.add_argument(
"--ddim_eta",
type=float,
help="eta parameter for ddim variance",
)

main_opt, remaining_args = main_parser.parse_known_args()
main_opt.config_json = os.path.join(main_opt.test_model_dir, "train_config.json")
Expand All @@ -119,9 +153,10 @@ def launch_testing(opt):
opt.train_metrics_list = main_opt.test_metrics_list
opt.train_nb_img_max_fid = main_opt.test_nb_img
opt.test_batch_size = main_opt.test_batch_size
opt.alg_palette_generate_per_class = False

random.seed(main_opt.test_seed)
torch.manual_seed(main_opt.test_seed)
np.random.seed(main_opt.test_seed)

launch_testing(opt)
launch_testing(opt, main_opt)

0 comments on commit a2958dc

Please sign in to comment.