Skip to content

Commit

Permalink
refactor: uniformize options API
Browse files Browse the repository at this point in the history
Introduces InferenceOptions for inference scripts.

The options can also be used by the server and be parsed either as json
or command line.

Options are also better organized.

New tests for inference GAN + diffusion
  • Loading branch information
pnsuau authored and Bycob committed Sep 25, 2023
1 parent 8f8ccf0 commit 5e746e5
Show file tree
Hide file tree
Showing 18 changed files with 1,230 additions and 970 deletions.
4 changes: 4 additions & 0 deletions data/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ def modify_commandline_options(parser, is_train):
"""
return parser

@staticmethod
def modify_commandline_options_train(parser):
return parser

@abstractmethod
def __len__(self):
"""Return the total number of images in the dataset."""
Expand Down
83 changes: 0 additions & 83 deletions data/template_dataset.py

This file was deleted.

6 changes: 6 additions & 0 deletions models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ def get_option_setter(model_name):
return model_class.modify_commandline_options


def get_after_parse(model_name):
"""Return the static method <modify_commandline_options> of the model class."""
model_class = find_model_using_name(model_name)
return model_class.after_parse


def create_model(opt, rank):
"""Create a model given the option.
Expand Down
1 change: 1 addition & 0 deletions models/base_diffusion_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from util.diff_aug import DiffAugment
from util.util import save_image, tensor2im
from util.util import pairs_of_floats, pairs_of_ints, MAX_INT

from .base_model import BaseModel
from .modules.sam.sam_inference import (
Expand Down
4 changes: 4 additions & 0 deletions models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,10 @@ def modify_commandline_options(parser, is_train):
"""
return parser

@staticmethod
def modify_commandline_options_train(parser):
return parser

def set_input(self, data):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
Expand Down
4 changes: 4 additions & 0 deletions models/cut_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ def modify_commandline_options(parser, is_train=True):

return parser

@staticmethod
def after_parse(opt):
return opt

def __init__(self, opt, rank):
super().__init__(opt, rank)

Expand Down
4 changes: 4 additions & 0 deletions models/cycle_gan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ def modify_commandline_options(parser, is_train=True):
)
return parser

@staticmethod
def after_parse(opt):
return opt

def __init__(self, opt, rank):

super().__init__(opt, rank)
Expand Down
1 change: 1 addition & 0 deletions models/modules/sam/sam_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import scipy
import torch

from mobile_sam.modeling import (
ImageEncoderViT,
MaskDecoder,
Expand Down
86 changes: 49 additions & 37 deletions models/palette_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,44 @@ def modify_commandline_options(parser, is_train=True):
parser = BaseDiffusionModel.modify_commandline_options(
parser, is_train=is_train
)

parser.add_argument(
"--alg_palette_cond_image_creation",
type=str,
default="y_t",
choices=[
"y_t",
"previous_frame",
"computed_sketch",
"low_res",
"ref",
],
help="how cond_image is created",
)

parser.add_argument(
"--alg_palette_ddim_num_steps",
type=int,
default=10,
help="number of steps for ddim sampling",
)

parser.add_argument(
"--alg_palette_ddim_eta",
type=float,
default=0.5,
help="eta for ddim sampling variance",
)

if is_train:
parser = PaletteModel.modify_commandline_options_train(parser)

return parser

@staticmethod
def modify_commandline_options_train(parser):
parser = BaseDiffusionModel.modify_commandline_options_train(parser)

parser.add_argument(
"--alg_palette_task",
default="inpainting",
Expand Down Expand Up @@ -63,20 +101,6 @@ def modify_commandline_options(parser, is_train=True):
help="dropout probability for classifier-free guidance",
)

parser.add_argument(
"--alg_palette_cond_image_creation",
type=str,
default="y_t",
choices=[
"y_t",
"previous_frame",
"computed_sketch",
"low_res",
"ref",
],
help="how cond_image is created",
)

parser.add_argument(
"--alg_palette_computed_sketch_list",
nargs="+",
Expand Down Expand Up @@ -184,7 +208,6 @@ def modify_commandline_options(parser, is_train=True):
default=0.5,
help="prob to use previous frame as y cond",
)

parser.add_argument(
"--alg_palette_sampling_method",
type=str,
Expand All @@ -193,20 +216,6 @@ def modify_commandline_options(parser, is_train=True):
help="choose the sampling method between ddpm and ddim",
)

parser.add_argument(
"--alg_palette_ddim_num_steps",
type=int,
default=10,
help="number of steps for ddim sampling",
)

parser.add_argument(
"--alg_palette_ddim_eta",
type=float,
default=0.5,
help="eta for ddim sampling variance",
)

parser.add_argument(
"--alg_palette_conditioning",
type=str,
Expand Down Expand Up @@ -238,6 +247,14 @@ def modify_commandline_options(parser, is_train=True):

return parser

@staticmethod
def after_parse(opt):
if opt.isTrain and opt.alg_palette_dropout_prob > 0:
# we add a class to be the unconditionned one.
opt.f_s_semantic_nclasses += 1
opt.cls_semantic_nclasses += 1
return opt

def __init__(self, opt, rank):
super().__init__(opt, rank)

Expand All @@ -263,14 +280,6 @@ def __init__(self, opt, rank):
else:
self.inference_num = min(self.opt.alg_palette_inference_num, batch_size)

self.ddim_num_steps = self.opt.alg_palette_ddim_num_steps
self.ddim_eta = self.opt.alg_palette_ddim_eta

if self.opt.alg_palette_dropout_prob > 0:
# we add a class to be the unconditionned one.
self.opt.f_s_semantic_nclasses += 1
self.opt.cls_semantic_nclasses += 1

self.num_classes = max(
self.opt.f_s_semantic_nclasses, self.opt.cls_semantic_nclasses
)
Expand Down Expand Up @@ -391,6 +400,9 @@ def __init__(self, opt, rank):

self.sample_num = 2

self.ddim_num_steps = self.opt.alg_palette_ddim_num_steps
self.ddim_eta = self.opt.alg_palette_ddim_eta

def set_input(self, data):
"""must use set_device in tensor"""

Expand Down
Loading

0 comments on commit 5e746e5

Please sign in to comment.