Skip to content

Commit

Permalink
Refactor error handling and subprocess calls, extract common utilities
Browse files Browse the repository at this point in the history
- Replace `ValueError` with `UserError` for better user feedback and sentry sanity
- Copy over centralized HTTP & ffmpeg error handling in `exceptions.py` from gooey-server
- Simplify file download logic with `download_file_to_path()`
- Avoid an overly broad exception clause when handling bad face recognition in sadtalker
- Handle bad face reco in eyeblink and ref_post too
- Remove stale wav2lip-src folder
  • Loading branch information
devxpy committed Jul 30, 2024
1 parent 8c86f97 commit 0e3463f
Show file tree
Hide file tree
Showing 53 changed files with 306 additions and 61,884 deletions.
2 changes: 1 addition & 1 deletion chart/model-values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ deployments:
thenlper/gte-base
- name: "retro-sadtalker"
image: *retroImg
image: "crgooeyprodwestus1.azurecr.io/gooey-gpu-retro:9"
autoscaling:
queueLength: 2
minReplicaCount: 3
Expand Down
2 changes: 1 addition & 1 deletion common/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def _safety_checker(clip_input, images):
if not disabled:
images, has_nsfw_concepts = original(images=images, clip_input=clip_input)
if any(has_nsfw_concepts):
raise ValueError(
raise gooey_gpu.UserError(
"Potential NSFW content was detected in one or more images. "
"Try again with a different Prompt and/or Regenerate."
)
Expand Down
2 changes: 1 addition & 1 deletion deforum_sd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def deforum(pipeline: PipelineInfo, inputs: deforum_script.DeforumAnimArgs):
headers={"Content-Type": "video/mp4"},
data=vid_bytes,
)
r.raise_for_status()
gooey_gpu.raise_for_status(r)
return


Expand Down
8 changes: 2 additions & 6 deletions deforum_sd/deforum_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,7 @@ def create_video(args: DeforumArgs, anim_args: DeforumAnimArgs):
max_frames = str(anim_args.max_frames)

# make video
cmd = [
"ffmpeg",
"-y",
gooey_gpu.ffmpeg(
"-vcodec",
bitdepth_extension,
"-r",
Expand All @@ -405,9 +403,7 @@ def create_video(args: DeforumArgs, anim_args: DeforumAnimArgs):
"-pattern_type",
"sequence",
mp4_path,
]
print(f"---> {' '.join(cmd)}")
subprocess.check_call(cmd)
)
# process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
# stdout, stderr = process.communicate()
# if process.returncode != 0:
Expand Down
68 changes: 68 additions & 0 deletions exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import typing

import requests


class UserError(Exception):
    """Exception carrying a message intended for the end user.

    The message, a sentry level and an optional status code are stored as
    attributes, and the same three values are handed to ``Exception`` as one
    dict argument (so they are also available through ``.args[0]``).
    ``str()`` of the error is the bare message.

    NOTE(review): ``sentry_level`` is presumably the severity used when the
    error is reported to Sentry; nothing in this module consumes it.
    """

    def __init__(
        self,
        message: str,
        sentry_level: str = "info",
        status_code: typing.Optional[int] = None,
    ):
        payload = {
            "message": message,
            "sentry_level": sentry_level,
            "status_code": status_code,
        }
        super().__init__(payload)
        self.status_code = status_code
        self.sentry_level = sentry_level
        self.message = message

    def __str__(self):
        return self.message


def raise_for_status(resp: requests.Response, is_user_url: bool = False):
    """Raise an error if ``resp`` carries a 4xx or 5xx status code.

    For any other status code the function returns ``None``.

    Raises:
        requests.HTTPError: when ``is_user_url`` is false. The message holds
            the status code, reason, URL and a truncated preview of the body.
        UserError: when ``is_user_url`` is true. It is chained from the
            ``HTTPError`` described above.
    """
    reason = resp.reason
    if isinstance(reason, bytes):
        # We attempt to decode utf-8 first because some servers
        # choose to localize their reason strings. If the string
        # isn't utf-8, we fall back to iso-8859-1 for all other
        # encodings. (See PR #3538)
        try:
            reason = reason.decode("utf-8")
        except UnicodeDecodeError:
            reason = reason.decode("iso-8859-1")

    if 400 <= resp.status_code < 500:
        side = "Client"
    elif 500 <= resp.status_code < 600:
        side = "Server"
    else:
        return

    http_error_msg = (
        f"{resp.status_code} {side} Error: {reason} | URL: {resp.url} "
        f"| Response: {_response_preview(resp)!r}"
    )
    exc = requests.HTTPError(http_error_msg, response=resp)
    if not is_user_url:
        raise exc
    raise UserError(
        f"[{resp.status_code}] You have provided an invalid URL: {resp.url} "
        "Please make sure the URL is correct and accessible. ",
    ) from exc


def _response_preview(resp: requests.Response) -> bytes:
    """Return the response body, shortened to at most 500 bytes with ``b"..."`` in the middle."""
    body = resp.content
    return truncate_filename(body, 500, sep=b"...")


def truncate_filename(
    text: typing.AnyStr, maxlen: int = 100, sep: typing.AnyStr = "..."
) -> typing.AnyStr:
    """Shorten ``text`` to at most ``maxlen`` characters by cutting out its middle.

    If ``text`` already fits it is returned unchanged. Otherwise the result is
    an equally sized head and tail of ``text`` joined by ``sep``.

    Args:
        text: The ``str`` or ``bytes`` value to shorten.
        maxlen: Maximum length of the result.
        sep: Marker inserted where the middle was removed. It must be the same
            type as ``text``; the default is a ``str``, so pass a ``bytes``
            separator for ``bytes`` input.

    Returns:
        ``text`` itself, or the shortened value of the same type.

    Raises:
        AssertionError: if ``text`` needs shortening and ``sep`` is longer
            than ``maxlen``.
    """
    if len(text) <= maxlen:
        return text
    assert len(sep) <= maxlen
    mid = (maxlen - len(sep)) // 2
    # Take the tail from an explicit start index. `text[-mid:]` would be
    # `text[-0:]` -- the WHOLE text -- when mid == 0, giving a result longer
    # than the input instead of a shortened one.
    return text[:mid] + sep + text[len(text) - mid :]
69 changes: 53 additions & 16 deletions ffmpeg_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import numpy as np
from pydantic import BaseModel

from exceptions import UserError


class VideoMetadata(BaseModel):
width: int = 0
Expand All @@ -23,40 +25,48 @@ class InputOutputVideoMetadata(BaseModel):

class AudioMetadata(BaseModel):
duration_sec: float = 0
codec_name: typing.Optional[str] = None


def ffprobe_audio(input_path: str) -> AudioMetadata:
cmd_args = [
text = call_cmd(
"ffprobe",
"-v",
"error",
"-show_entries",
"format=duration",
"-of",
"default=noprint_wrappers=1:nokey=1",
input_path,
]
print("\t$ " + " ".join(cmd_args))
"-v", "quiet",
"-print_format", "json",
"-show_streams", input_path,
"-select_streams", "a:0",
) # fmt:skip
data = json.loads(text)

try:
stream = data["streams"][0]
except IndexError:
raise UserError(
"Input has no audio streams. Make sure the you have uploaded an appropriate audio/video file."
)

return AudioMetadata(
duration_sec=float(subprocess.check_output(cmd_args, encoding="utf-8"))
duration_sec=float(stream.get("duration") or 0),
codec_name=stream.get("codec_name"),
)


def ffprobe_video(input_path: str) -> VideoMetadata:
cmd_args = [
text = call_cmd(
"ffprobe",
"-v", "quiet",
"-print_format", "json",
"-show_streams", input_path,
"-select_streams", "v:0",
] # fmt:skip
print("\t$ " + " ".join(cmd_args))
data = json.loads(subprocess.check_output(cmd_args, text=True))
) # fmt:skip
data = json.loads(text)

try:
stream = data["streams"][0]
except IndexError:
raise ValueError("input has no video streams")
raise UserError(
"Input has no video streams. Make sure the video you have uploaded is not corrupted."
)

try:
fps = float(Fraction(stream["avg_frame_rate"]))
Expand Down Expand Up @@ -120,3 +130,30 @@ def ffmpeg_get_writer_proc(
] # fmt:skip
print("\t$ " + " ".join(cmd_args))
return subprocess.Popen(cmd_args, stdin=subprocess.PIPE)


# User-facing error text for failed ffmpeg runs: `ffmpeg()` passes it to
# `call_cmd()` as `err_msg`, which uses it as the message of the raised
# `UserError`. The text contains a Markdown link.
FFMPEG_ERR_MSG = (
"Unsupported File Format\n\n"
"We encountered an issue processing your file as it appears to be in a format not supported by our system or may be corrupted. "
"You can find a list of supported formats at [FFmpeg Formats](https://ffmpeg.org/general.html#File-Formats)."
)


def ffmpeg(*args) -> str:
    """Run ``ffmpeg -hide_banner -y`` followed by ``args``.

    Delegates to ``call_cmd`` with ``FFMPEG_ERR_MSG`` as the error message and
    returns whatever ``call_cmd`` returns.
    """
    cmd = ("ffmpeg", "-hide_banner", "-y", *args)
    return call_cmd(*cmd, err_msg=FFMPEG_ERR_MSG)


def call_cmd(
    *args, err_msg: str = "", ok_returncodes: typing.Iterable[int] = ()
) -> str:
    """Run a command and return its combined stdout/stderr as text.

    The command line is printed before it is run.

    Args:
        *args: Program name followed by its arguments.
        err_msg: Message for the ``UserError`` raised on failure. When empty,
            ``"<Program> Error"`` is used, built from the capitalized first
            argument.
        ok_returncodes: Non-zero exit codes to treat as success; the captured
            output is returned for them instead of raising.

    Raises:
        UserError: on any other non-zero exit code. It is chained from a
            ``SubprocessError`` holding the captured output, which in turn is
            chained from the original ``CalledProcessError``.
    """
    print("\t$ " + " ".join(str(arg) for arg in args))
    try:
        output = subprocess.check_output(args, stderr=subprocess.STDOUT, text=True)
    except subprocess.CalledProcessError as failure:
        if failure.returncode in ok_returncodes:
            return failure.output
        if not err_msg:
            err_msg = f"{str(args[0]).capitalize()} Error"
        # Raise-and-catch an intermediate error so the command output ends up
        # in the exception chain of the UserError.
        try:
            raise subprocess.SubprocessError(failure.output) from failure
        except subprocess.SubprocessError as with_output:
            raise UserError(err_msg) from with_output
    else:
        return output
17 changes: 9 additions & 8 deletions gooey_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import mimetypes
import os
import threading
import typing
from concurrent.futures import ThreadPoolExecutor
from functools import wraps

Expand All @@ -16,7 +15,9 @@
import sentry_sdk
import torch
import transformers
from pydantic import BaseModel

from exceptions import raise_for_status
from ffmpeg_util import *

# from accelerate import cpu_offload_with_hook

Expand Down Expand Up @@ -191,7 +192,7 @@ def upload_image(im_pil: PIL.Image.Image, url: str):
headers={"Content-Type": "image/png"},
data=im_bytes,
)
r.raise_for_status()
raise_for_status(r)


def apply_parallel(fn, *iterables):
Expand Down Expand Up @@ -220,22 +221,22 @@ def upload_audio(audio, url: str, rate: int = 16_000):

def upload_audio_from_bytes(audio: bytes, url: str):
r = requests.put(url, headers={"Content-Type": "audio/wav"}, data=audio)
r.raise_for_status()
raise_for_status(r)


def upload_video_from_bytes(video, url: str):
r = requests.put(url, headers={"Content-Type": "video/mp4"}, data=video)
r.raise_for_status()
raise_for_status(r)


# Add some missing mimetypes
mimetypes.add_type("audio/wav", ".wav")


def download_file_cached(*, url: str, path: str):
if os.path.exists(path):
def download_file_to_path(*, url: str, path: str, cached: bool = False):
if cached and os.path.exists(path):
return
r = requests.get(url)
r.raise_for_status()
raise_for_status(r, is_user_url=not cached)
with open(path, "wb") as f:
f.write(r.content)
40 changes: 16 additions & 24 deletions retro/gfpgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import typing
from functools import lru_cache
from tempfile import TemporaryDirectory
from urllib.request import urlretrieve

import PIL.Image
from basicsr.archs.rrdbnet_arch import RRDBNet
Expand All @@ -14,13 +13,6 @@

import gooey_gpu
from celeryconfig import app, setup_queues
from ffmpeg_util import (
ffmpeg_get_writer_proc,
ffmpeg_read_input_frames,
ffprobe_video,
VideoMetadata,
InputOutputVideoMetadata,
)

MAX_RES = 1920 * 1080

Expand All @@ -40,7 +32,7 @@ class EsrganInputs(BaseModel):
@gooey_gpu.endpoint
def realesrgan(
pipeline: EsrganPipeline, inputs: EsrganInputs
) -> InputOutputVideoMetadata:
) -> gooey_gpu.InputOutputVideoMetadata:
esrganer = load_esrgan_model(pipeline.model_id)

def enhance(frame, outscale_factor):
Expand Down Expand Up @@ -71,7 +63,9 @@ class GfpganInputs(BaseModel):

@app.task(name="gfpgan")
@gooey_gpu.endpoint
def gfpgan(pipeline: GfpganPipeline, inputs: GfpganInputs) -> InputOutputVideoMetadata:
def gfpgan(
pipeline: GfpganPipeline, inputs: GfpganInputs
) -> gooey_gpu.InputOutputVideoMetadata:
gfpganer = load_gfpgan_model(pipeline.model_id)
if pipeline.bg_model_id:
gfpganer.bg_upsampler = load_esrgan_model(pipeline.bg_model_id)
Expand Down Expand Up @@ -102,24 +96,22 @@ def run_enhancer(
scale: float,
upload_url: str,
enhance: typing.Callable,
) -> InputOutputVideoMetadata:
input_file = image or video
assert input_file, "Please provide an image or video input"
) -> gooey_gpu.InputOutputVideoMetadata:
input_url = image or video
assert input_url, "Please provide an image or video input"

with TemporaryDirectory() as save_dir:
input_path, _ = urlretrieve(
input_file,
os.path.join(save_dir, "input" + os.path.splitext(input_file)[1]),
)
input_path = os.path.join(save_dir, "input" + os.path.splitext(input_url)[1])
gooey_gpu.download_file_to_path(url=input_url, path=input_path)
output_path = os.path.join(save_dir, "out.mp4")

response = InputOutputVideoMetadata(
input=ffprobe_video(input_path), output=VideoMetadata()
response = gooey_gpu.InputOutputVideoMetadata(
input=gooey_gpu.ffprobe_video(input_path), output=gooey_gpu.VideoMetadata()
)
# ensure max input/output is 1080p
input_pixels = response.input.width * response.input.height
if input_pixels > MAX_RES:
raise ValueError(
raise gooey_gpu.UserError(
"Input video resolution exceeds 1920x1080. Please downscale to 1080p."
)
max_scale = math.sqrt(MAX_RES / input_pixels)
Expand All @@ -128,7 +120,7 @@ def run_enhancer(

ffproc = None
for frame in tqdm(
ffmpeg_read_input_frames(
gooey_gpu.ffmpeg_read_input_frames(
width=response.input.width,
height=response.input.height,
input_path=input_path,
Expand All @@ -152,7 +144,7 @@ def run_enhancer(
response.output.width = restored_img.shape[1]
response.output.height = restored_img.shape[0]
response.output.fps = response.input.fps or 24
ffproc = ffmpeg_get_writer_proc(
ffproc = gooey_gpu.ffmpeg_get_writer_proc(
width=response.output.width,
height=response.output.height,
fps=response.output.fps,
Expand Down Expand Up @@ -214,7 +206,7 @@ def load_gfpgan_model(model_id: str) -> "GFPGANer":

print(f"loading {model_id} via {url}...")
model_path = os.path.join(gfpgan_checkpoint_dir, os.path.basename(url))
gooey_gpu.download_file_cached(url=url, path=model_path)
gooey_gpu.download_file_to_path(url=url, path=model_path, cached=True)

return GFPGANer(
model_path=model_path,
Expand Down Expand Up @@ -282,7 +274,7 @@ def load_esrgan_model(model_id: str) -> "RealESRGANer":
for url in file_url:
print(f"loading {model_id} via {url}...")
model_path = os.path.join(gooey_gpu.CHECKPOINTS_DIR, os.path.basename(url))
gooey_gpu.download_file_cached(url=url, path=model_path)
gooey_gpu.download_file_to_path(url=url, path=model_path, cached=True)
assert model_path, f"Model {model_id} not found"

return RealESRGANer(
Expand Down
2 changes: 1 addition & 1 deletion retro/nvidia_nemo.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def load_model(model_url: str):
# get cached model path
model_path = os.path.join(gooey_gpu.CHECKPOINTS_DIR, os.path.basename(model_url))
# if not cached, download again
gooey_gpu.download_file_cached(url=model_url, path=model_path)
gooey_gpu.download_file_to_path(url=model_url, path=model_path, cached=True)
# load model
return nemo_asr.models.ASRModel.restore_from(model_path)

Expand Down
Loading

0 comments on commit 0e3463f

Please sign in to comment.