Skip to content

Commit

Permalink
update defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
palp committed Aug 6, 2023
1 parent 76ca428 commit ced97f0
Showing 1 changed file with 24 additions and 10 deletions.
34 changes: 24 additions & 10 deletions sgm/inference/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ class SamplingParams:
discretization: Discretization = Discretization.LEGACY_DDPM
guider: Guider = Guider.VANILLA
thresholder: Thresholder = Thresholder.NONE
scale: float = 6.0
aesthetic_score: float = 5.0
negative_aesthetic_score: float = 5.0
scale: float = 5.0
aesthetic_score: float = 6.0
negative_aesthetic_score: float = 2.5
img2img_strength: float = 1.0
orig_width: int = 1024
orig_height: int = 1024
Expand Down Expand Up @@ -181,20 +181,30 @@ def __init__(
model_path = pathlib.Path(__file__).parent.parent.resolve() / "checkpoints"
if not os.path.exists(model_path):
# This supports development installs where checkpoints is root level of the repo
model_path = pathlib.Path(__file__).parent.parent.parent.resolve() / "checkpoints"
model_path = (
pathlib.Path(__file__).parent.parent.parent.resolve()
/ "checkpoints"
)
if config_path is None:
config_path = pathlib.Path(__file__).parent.parent.resolve() / "configs/inference"
config_path = (
pathlib.Path(__file__).parent.parent.resolve() / "configs/inference"
)
if not os.path.exists(config_path):
# This supports development installs where configs is root level of the repo
config_path = (
pathlib.Path(__file__).parent.parent.parent.resolve() / "configs/inference"
pathlib.Path(__file__).parent.parent.parent.resolve()
/ "configs/inference"
)
self.config = str(pathlib.Path(config_path) / self.specs.config)
self.ckpt = str(pathlib.Path(model_path) / self.specs.ckpt)
if not os.path.exists(self.config):
raise ValueError(f"Config {self.config} not found, check model spec or config_path")
raise ValueError(
f"Config {self.config} not found, check model spec or config_path"
)
if not os.path.exists(self.ckpt):
raise ValueError(f"Checkpoint {self.ckpt} not found, check model spec or config_path")
raise ValueError(
f"Checkpoint {self.ckpt} not found, check model spec or config_path"
)
self.device = device
self.model = self._load_model(device=device, use_fp16=use_fp16)

Expand Down Expand Up @@ -290,7 +300,9 @@ def wrap_discretization(
):
return discretization # Already wrapped
if image_strength is not None and image_strength < 1.0 and image_strength > 0.0:
discretization = Img2ImgDiscretizationWrapper(discretization, strength=image_strength)
discretization = Img2ImgDiscretizationWrapper(
discretization, strength=image_strength
)

if (
noise_strength is not None
Expand Down Expand Up @@ -349,7 +361,9 @@ def refiner(

def get_guider_config(params: SamplingParams):
if params.guider == Guider.IDENTITY:
guider_config = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
guider_config = {
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
}
elif params.guider == Guider.VANILLA:
scale = params.scale

Expand Down

0 comments on commit ced97f0

Please sign in to comment.