Skip to content

Commit

Permalink
refactor: uniformize options API
Browse files Browse the repository at this point in the history
Introduces InferenceOptions for inference scripts.

The options can also be used by the server and can be parsed either from JSON
or from the command line.

Options are also better organized.

New tests for inference GAN + diffusion
  • Loading branch information
pnsuau authored and Bycob committed Sep 22, 2023
1 parent 8f8ccf0 commit c27ddd1
Show file tree
Hide file tree
Showing 18 changed files with 1,203 additions and 827 deletions.
4 changes: 4 additions & 0 deletions data/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ def modify_commandline_options(parser, is_train):
"""
return parser

@staticmethod
def modify_commandline_options_train(parser):
    """Hook for registering training-only dataset options.

    This implementation registers nothing.

    Parameters:
        parser -- the option parser to extend

    Returns:
        the very same parser object that was passed in.
    """
    return parser

@abstractmethod
def __len__(self):
"""Return the total number of images in the dataset."""
Expand Down
83 changes: 0 additions & 83 deletions data/template_dataset.py

This file was deleted.

6 changes: 6 additions & 0 deletions models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ def get_option_setter(model_name):
return model_class.modify_commandline_options


def get_after_parse(model_name):
    """Return the static method <after_parse> of the model class.

    Parameters:
        model_name -- name used by find_model_using_name to look up the model class

    Returns:
        the <after_parse> attribute of the class that was found.
    """
    model_class = find_model_using_name(model_name)
    return model_class.after_parse


def create_model(opt, rank):
"""Create a model given the option.
Expand Down
136 changes: 136 additions & 0 deletions models/base_diffusion_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from util.diff_aug import DiffAugment
from util.util import save_image, tensor2im
from util.util import pairs_of_floats, pairs_of_ints, MAX_INT

from .base_model import BaseModel
from .modules.sam.sam_inference import (
Expand All @@ -35,6 +36,141 @@ class BaseDiffusionModel(BaseModel):
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
"""

@staticmethod
def modify_commandline_options(parser, is_train: bool = True):
    """Add the diffusion-model options that match the requested mode.

    Parameters:
        parser   -- the option parser to extend
        is_train -- when True the training options are registered,
                    otherwise the inference options are registered

    Returns:
        the parser object that was passed in.
    """
    # The helpers' return values are ignored; the object received from the
    # caller is the one handed back.
    if is_train:
        BaseDiffusionModel.modify_commandline_options_train(parser)
    else:
        BaseDiffusionModel.modify_commandline_options_inference(parser)

    return parser

@staticmethod
def modify_commandline_options_train(parser):
    """Hook for registering training-only diffusion options.

    This implementation registers nothing.

    Parameters:
        parser -- the option parser to extend

    Returns:
        the very same parser object that was passed in.
    """
    return parser

@staticmethod
def modify_commandline_options_inference(parser):
    """Register the inference-time options of diffusion models.

    Parameters:
        parser -- the option parser to extend; it must expose an
                  ``add_argument`` method

    Returns:
        the parser that was passed in, once all options have been added.
    """
    add = parser.add_argument

    # --- crop / bbox geometry ---
    add(
        "--crop_width",
        type=int,
        default=-1,
        help="crop width added on each side of the bbox (optional)",
    )
    add(
        "--crop_height",
        type=int,
        default=-1,
        help="crop height added on each side of the bbox (optional)",
    )
    add(
        "--bbox_width_factor",
        type=float,
        default=0.0,
        help="bbox width added factor of original width",
    )
    add(
        "--bbox_height_factor",
        type=float,
        default=0.0,
        help="bbox height added factor of original height",
    )

    # --- sampling ---
    add("--sampling_steps", type=int, default=-1, help="number of sampling steps")
    add("--seed", type=int, default=-1, help="random seed for reproducibility")

    # --- mask shaping ---
    add(
        "--mask_delta",
        type=pairs_of_ints,
        nargs="+",
        default=[[0]],
        help="mask offset to allow generation of a bigger object, format : width (x) height (y) for each class or only one size if square",
    )
    add(
        "--mask_delta_ratio",
        type=pairs_of_floats,
        nargs="+",
        default=[[0]],
        help="ratio mask offset to allow generation of a bigger object, format : width (x),height (y) for each class or only one size if square",
    )
    add("--mask_square", action="store_true", help="whether to use square mask")

    add("--name", default="img", help="generated img name")

    add(
        "--sampling_method",
        type=str,
        choices=["ddpm", "ddim"],
        default="ddpm",
        help="choose the sampling method between ddpm and ddim",
    )

    add(
        "--cls_value",
        type=int,
        default=-1,
        help="override input bbox classe for generation",
    )

    # last options
    add("--previous_frame", default=None, help="image to transform")
    add("--mask_in", required=False, help="mask used for image transformation")
    # XXX: put directly in the script?
    add(
        "--dir_out",
        required=True,
        help="The directory where to output result images",
    )
    add("--bbox_in", help="bbox file used for masking")

    add("--nb_samples", type=int, default=1, help="nb of samples generated")
    add("--bbox_ref_id", type=int, default=-1, help="bbox id to use")

    # --- conditioning image ---
    add("--cond_in", help="conditionning image to use")
    add("--cond_keep_ratio", action="store_true")
    add("--cond_rotation", type=float, default=0)
    add("--cond_persp_horizontal", type=float, default=0)
    add("--cond_persp_vertical", type=float, default=0)

    # --- palette algorithm ---
    add(
        "--alg_palette_guidance_scale",
        type=float,
        default=0.0,  # literature value: 0.2
        help="scale for classifier-free guidance, default is conditional DDPM only",
    )
    add(
        "--alg_palette_sketch_canny_thresholds",
        type=int,
        nargs="+",
        default=[0, 255 * 3],
        help="Canny thresholds",
    )
    add(
        "--alg_palette_super_resolution_downsample",
        action="store_true",
        help="whether to downsample the image for super resolution",
    )
    add(
        "--min_crop_bbox_ratio",
        type=float,
        help="minimum crop/bbox ratio, allows to add context when bbox is larger than crop",
    )

    return parser

def __init__(self, opt, rank):
"""Initialize the BaseModel class.
Expand Down
4 changes: 4 additions & 0 deletions models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,10 @@ def modify_commandline_options(parser, is_train):
"""
return parser

@staticmethod
def modify_commandline_options_train(parser):
    """Hook for registering training-only model options.

    This implementation registers nothing.

    Parameters:
        parser -- the option parser to extend

    Returns:
        the very same parser object that was passed in.
    """
    return parser

def set_input(self, data):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
Expand Down
4 changes: 4 additions & 0 deletions models/cut_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ def modify_commandline_options(parser, is_train=True):

return parser

@staticmethod
def after_parse(opt):
    """Post-parsing hook for the options object.

    This implementation leaves the options untouched.

    Parameters:
        opt -- the parsed options

    Returns:
        the very same options object that was passed in.
    """
    return opt

def __init__(self, opt, rank):
super().__init__(opt, rank)

Expand Down
4 changes: 4 additions & 0 deletions models/cycle_gan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ def modify_commandline_options(parser, is_train=True):
)
return parser

@staticmethod
def after_parse(opt):
    """Post-parsing hook for the options object.

    This implementation leaves the options untouched.

    Parameters:
        opt -- the parsed options

    Returns:
        the very same options object that was passed in.
    """
    return opt

def __init__(self, opt, rank):

super().__init__(opt, rank)
Expand Down
1 change: 1 addition & 0 deletions models/modules/sam/sam_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import scipy
import torch

from mobile_sam.modeling import (
ImageEncoderViT,
MaskDecoder,
Expand Down
Loading

0 comments on commit c27ddd1

Please sign in to comment.