Skip to content

Commit

Permalink
refactor: apply small code improvements (#113)
Browse files Browse the repository at this point in the history
This commit makes some very small code improvements that make applying a
mock patch easier.
  • Loading branch information
rickstaa authored Jun 26, 2024
1 parent 8b1b455 commit e5f4bf7
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 5 deletions.
3 changes: 1 addition & 2 deletions runner/app/pipelines/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def list(cls):

class ImageToImagePipeline(Pipeline):
def __init__(self, model_id: str):
self.model_id = model_id
kwargs = {"cache_dir": get_model_dir()}

torch_device = get_torch_device()
Expand All @@ -69,8 +70,6 @@ def __init__(self, model_id: str):
kwargs["torch_dtype"] = torch.float16
kwargs["variant"] = "fp16"

self.model_id = model_id

# Special case SDXL-Lightning because the unet for SDXL needs to be swapped
if ModelName.SDXL_LIGHTNING.value in model_id:
base = "stabilityai/stable-diffusion-xl-base-1.0"
Expand Down
2 changes: 1 addition & 1 deletion runner/app/pipelines/image_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

class ImageToVideoPipeline(Pipeline):
def __init__(self, model_id: str):
self.model_id = model_id
kwargs = {"cache_dir": get_model_dir()}

torch_device = get_torch_device()
Expand All @@ -39,7 +40,6 @@ def __init__(self, model_id: str):
kwargs["torch_dtype"] = torch.float16
kwargs["variant"] = "fp16"

self.model_id = model_id
self.ldm = StableVideoDiffusionPipeline.from_pretrained(model_id, **kwargs)
self.ldm.to(get_torch_device())

Expand Down
2 changes: 1 addition & 1 deletion runner/app/pipelines/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

class TextToImagePipeline(Pipeline):
def __init__(self, model_id: str):
self.model_id = model_id
kwargs = {"cache_dir": get_model_dir()}

torch_device = get_torch_device()
Expand Down Expand Up @@ -51,7 +52,6 @@ def __init__(self, model_id: str):
logger.info("TextToImagePipeline using bfloat16 precision for %s", model_id)
kwargs["torch_dtype"] = torch.bfloat16

self.model_id = model_id

# Special case SDXL-Lightning because the unet for SDXL needs to be swapped
if SDXL_LIGHTNING_MODEL_ID in model_id:
Expand Down
2 changes: 1 addition & 1 deletion runner/app/pipelines/upscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

class UpscalePipeline(Pipeline):
def __init__(self, model_id: str):
self.model_id = model_id
kwargs = {"cache_dir": get_model_dir()}

torch_device = get_torch_device()
Expand All @@ -42,7 +43,6 @@ def __init__(self, model_id: str):
kwargs["torch_dtype"] = torch.float16
kwargs["variant"] = "fp16"

self.model_id = model_id
self.ldm = StableDiffusionUpscalePipeline.from_pretrained(
model_id, **kwargs
).to(torch_device)
Expand Down

0 comments on commit e5f4bf7

Please sign in to comment.