diff --git a/retro/sadtalker.py b/retro/sadtalker.py index 3667ddb..5399422 100644 --- a/retro/sadtalker.py +++ b/retro/sadtalker.py @@ -6,7 +6,7 @@ from functools import lru_cache from tempfile import TemporaryDirectory import requests -import random +from urllib.parse import urlparse import cv2 import numpy as np @@ -151,11 +151,13 @@ def sadtalker( ) -> InputOutputVideoMetadata: assert len(pipeline.upload_urls) == 1, "Expected exactly 1 upload url" - face_mime_type = mimetypes.guess_type(inputs.source_image.split("?")[0])[0] or "" + face_url_without_query = urlparse(inputs.source_image)._replace(query={}).geturl() + face_mime_type = mimetypes.guess_type(face_url_without_query)[0] or "" if not ("video/" in face_mime_type or "image/" in face_mime_type): raise ValueError(f"Unsupported face format {face_mime_type!r}") - audio_mime_type = mimetypes.guess_type(inputs.driven_audio.split("?")[0])[0] or "" + audio_url_without_query = urlparse(inputs.driven_audio)._replace(query={}).geturl() + audio_mime_type = mimetypes.guess_type(audio_url_without_query)[0] or "" if not ("audio/" in audio_mime_type or "video/" in audio_mime_type): raise ValueError(f"Unsupported audio format {audio_mime_type!r}")