From de7a6279787221116e6be8ac94bc768b4f860dee Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Thu, 10 Aug 2023 05:11:34 -0700 Subject: [PATCH] more fixes and cleanup --- scripts/demo/streamlit_helpers.py | 9 +++------ sgm/inference/api.py | 14 +++++++------- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 5b0214ae..a9ff5e8e 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -35,7 +35,7 @@ def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, A pipeline = SamplingPipeline( model_spec=spec, use_fp16=True, - model_loader=CudaModelManager(device="cuda", swap_device="cpu"), + device_manager=CudaModelManager(device="cuda", swap_device="cpu"), ) else: pipeline = SamplingPipeline(model_spec=spec, use_fp16=False) @@ -207,7 +207,7 @@ def get_discretization(params: SamplingParams, key=1) -> SamplingParams: def get_sampler(params: SamplingParams, key=1) -> SamplingParams: - if params.sampler == Sampler.EULER_EDM or params.sampler == Sampler.HEUN_EDM: + if params.sampler in (Sampler.EULER_EDM, Sampler.HEUN_EDM): params.s_churn = st.sidebar.number_input( f"s_churn #{key}", value=params.s_churn, min_value=0.0 ) @@ -221,10 +221,7 @@ def get_sampler(params: SamplingParams, key=1) -> SamplingParams: f"s_noise #{key}", value=params.s_noise, min_value=0.0 ) - elif ( - params.sampler == Sampler.EULER_ANCESTRAL - or params.sampler == Sampler.DPMPP2S_ANCESTRAL - ): + elif params.sampler in (Sampler.EULER_ANCESTRAL, Sampler.DPMPP2S_ANCESTRAL): params.s_noise = st.sidebar.number_input( "s_noise", value=params.s_noise, min_value=0.0 ) diff --git a/sgm/inference/api.py b/sgm/inference/api.py index 8516e733..96aead65 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -205,9 +205,9 @@ def __init__( f"Checkpoint {self.ckpt} not found, check model spec or config_path" ) - self.model_manager = device_manager + self.device_manager = device_manager self.model = self._load_model( - device_manager=self.model_manager, use_fp16=use_fp16 + device_manager=self.device_manager, use_fp16=use_fp16 ) def _load_model(self, device_manager: DeviceModelManager, use_fp16=True): @@ -229,7 +229,7 @@ def text_to_image( samples: int = 1, return_latents: bool = False, noise_strength: Optional[float] = None, - filter: Any = None, + filter=None, ): sampler = get_sampler_config(params) @@ -257,7 +257,7 @@ def text_to_image( force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [], return_latents=return_latents, filter=filter, - model_manager=self.model_manager, + device_manager=self.device_manager, ) def image_to_image( @@ -269,7 +269,7 @@ def image_to_image( samples: int = 1, return_latents: bool = False, noise_strength: Optional[float] = None, - filter: Any = None, + filter=None, ): sampler = get_sampler_config(params) @@ -295,7 +295,7 @@ def image_to_image( force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [], return_latents=return_latents, filter=filter, - device=self.device, + device_manager=self.device_manager, ) def wrap_discretization( @@ -364,7 +364,7 @@ def refiner( return_latents=return_latents, filter=filter, add_noise=add_noise, - device=self.device, + device_manager=self.device_manager, )