Skip to content

Commit

Permalink
SAFETY CHECK DRAFT DO NOT MERGE
Browse files Browse the repository at this point in the history
  • Loading branch information
rickstaa committed May 5, 2024
1 parent 792b620 commit e087d51
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 12 deletions.
26 changes: 22 additions & 4 deletions runner/app/pipelines/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
UNet2DConditionModel,
EulerDiscreteScheduler,
)
from transformers import CLIPImageProcessor
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from safetensors.torch import load_file
from huggingface_hub import file_download, hf_hub_download
import torch
Expand Down Expand Up @@ -52,6 +54,11 @@ def __init__(self, model_id: str):

self.model_id = model_id

# Load SafetyChecker if requested
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
"CompVis/stable-diffusion-safety-checker"
)

# Special case SDXL-Lightning because the unet for SDXL needs to be swapped
if SDXL_LIGHTNING_MODEL_ID in model_id:
base = "stabilityai/stable-diffusion-xl-base-1.0"
Expand Down Expand Up @@ -84,16 +91,27 @@ def __init__(self, model_id: str):
)

self.ldm = StableDiffusionXLPipeline.from_pretrained(
base, unet=unet, **kwargs
base,
unet=unet,
safety_checker=safety_checker,
feature_extractor=CLIPImageProcessor.from_pretrained(
"openai/clip-vit-base-patch32"
),
**kwargs,
).to(torch_device)

self.ldm.scheduler = EulerDiscreteScheduler.from_config(
self.ldm.scheduler.config, timestep_spacing="trailing"
)
else:
self.ldm = AutoPipelineForText2Image.from_pretrained(model_id, **kwargs).to(
torch_device
)
self.ldm = AutoPipelineForText2Image.from_pretrained(
model_id,
safety_checker=safety_checker,
feature_extractor=CLIPImageProcessor.from_pretrained(
"openai/clip-vit-base-patch32"
),
**kwargs,
).to(torch_device)

if os.environ.get("TORCH_COMPILE"):
torch._inductor.config.conv_1x1_as_mm = True
Expand Down
18 changes: 10 additions & 8 deletions runner/app/routes/text_to_image.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from pydantic import BaseModel
from fastapi import Depends, APIRouter
from fastapi.responses import JSONResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from app.pipelines.base import Pipeline
from app.dependencies import get_pipeline
from app.routes.util import image_to_data_url, ImageResponse, HTTPError, http_error
import logging
import random
import os
import random

from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.util import HTTPError, ImageResponse, http_error, image_to_data_url
from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel

router = APIRouter()

Expand All @@ -23,6 +24,7 @@ class TextToImageParams(BaseModel):
width: int = None
guidance_scale: float = 7.5
negative_prompt: str = ""
safety_check: bool = False
seed: int = None
num_images_per_prompt: int = 1

Expand Down

0 comments on commit e087d51

Please sign in to comment.