Skip to content

Commit

Permalink
more fixes and cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
palp committed Aug 10, 2023
1 parent 9b18e6f commit de7a627
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 13 deletions.
9 changes: 3 additions & 6 deletions scripts/demo/streamlit_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, A
pipeline = SamplingPipeline(
model_spec=spec,
use_fp16=True,
model_loader=CudaModelManager(device="cuda", swap_device="cpu"),
device_manager=CudaModelManager(device="cuda", swap_device="cpu"),
)
else:
pipeline = SamplingPipeline(model_spec=spec, use_fp16=False)
Expand Down Expand Up @@ -207,7 +207,7 @@ def get_discretization(params: SamplingParams, key=1) -> SamplingParams:


def get_sampler(params: SamplingParams, key=1) -> SamplingParams:
if params.sampler == Sampler.EULER_EDM or params.sampler == Sampler.HEUN_EDM:
if params.sampler in (Sampler.EULER_EDM, Sampler.HEUN_EDM):
params.s_churn = st.sidebar.number_input(
f"s_churn #{key}", value=params.s_churn, min_value=0.0
)
Expand All @@ -221,10 +221,7 @@ def get_sampler(params: SamplingParams, key=1) -> SamplingParams:
f"s_noise #{key}", value=params.s_noise, min_value=0.0
)

elif (
params.sampler == Sampler.EULER_ANCESTRAL
or params.sampler == Sampler.DPMPP2S_ANCESTRAL
):
elif params.sampler in (Sampler.EULER_ANCESTRAL, Sampler.DPMPP2S_ANCESTRAL):
params.s_noise = st.sidebar.number_input(
"s_noise", value=params.s_noise, min_value=0.0
)
Expand Down
14 changes: 7 additions & 7 deletions sgm/inference/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,9 @@ def __init__(
f"Checkpoint {self.ckpt} not found, check model spec or config_path"
)

self.model_manager = device_manager
self.device_manager = device_manager
self.model = self._load_model(
device_manager=self.model_manager, use_fp16=use_fp16
device_manager=self.device_manager, use_fp16=use_fp16
)

def _load_model(self, device_manager: DeviceModelManager, use_fp16=True):
Expand All @@ -229,7 +229,7 @@ def text_to_image(
samples: int = 1,
return_latents: bool = False,
noise_strength: Optional[float] = None,
filter: Any = None,
filter=None,
):
sampler = get_sampler_config(params)

Expand Down Expand Up @@ -257,7 +257,7 @@ def text_to_image(
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
return_latents=return_latents,
filter=filter,
model_manager=self.model_manager,
device_manager=self.device_manager,
)

def image_to_image(
Expand All @@ -269,7 +269,7 @@ def image_to_image(
samples: int = 1,
return_latents: bool = False,
noise_strength: Optional[float] = None,
filter: Any = None,
filter=None,
):
sampler = get_sampler_config(params)

Expand All @@ -295,7 +295,7 @@ def image_to_image(
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
return_latents=return_latents,
filter=filter,
device=self.device,
device_manager=self.device_manager,
)

def wrap_discretization(
Expand Down Expand Up @@ -364,7 +364,7 @@ def refiner(
return_latents=return_latents,
filter=filter,
add_noise=add_noise,
device=self.device,
device_manager=self.device_manager,
)


Expand Down

0 comments on commit de7a627

Please sign in to comment.