From d2a5f1f6f8e9bc7b08d2e82099908b7bdf9863c5 Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Mon, 29 Apr 2024 22:58:06 -0700 Subject: [PATCH 1/4] upgraded to v2 and added text2speech + translation support --- chart/values-azure.yaml | 3 +- common/requirements.txt | 2 +- common/seamless_asr.py | 103 ++++++++++++++++++++++++++++------------ sadtalker/SadTalker | 1 + scripts/run-dev.sh | 5 +- 5 files changed, 79 insertions(+), 35 deletions(-) create mode 160000 sadtalker/SadTalker diff --git a/chart/values-azure.yaml b/chart/values-azure.yaml index c14f8c0..6235380 100644 --- a/chart/values-azure.yaml +++ b/chart/values-azure.yaml @@ -325,8 +325,7 @@ deployments: IMPORTS: |- common.seamless_asr SEAMLESS_MODEL_IDS: |- - facebook/hf-seamless-m4t-large - facebook/hf-seamless-m4t-medium + facebook/seamless-m4t-v2-large - name: "common-diffusion-instruct-pix2pix" image: *commonImg diff --git a/common/requirements.txt b/common/requirements.txt index 4ea5a80..63027cf 100644 --- a/common/requirements.txt +++ b/common/requirements.txt @@ -7,7 +7,7 @@ accelerate ~= 0.20.3 celery ~= 5.3.0 pydantic ~= 1.10.9 redis ~= 4.5.5 -transformers ~= 4.35.0 +transformers ~= 4.40.1 ## pytorch --extra-index-url https://download.pytorch.org/whl/cu118 diff --git a/common/seamless_asr.py b/common/seamless_asr.py index 5325032..6078645 100644 --- a/common/seamless_asr.py +++ b/common/seamless_asr.py @@ -7,6 +7,9 @@ import transformers from pydantic import BaseModel +import io +from scipy.io.wavfile import write + import gooey_gpu from api import AsrOutput from celeryconfig import app, setup_queues @@ -14,15 +17,15 @@ class SeamlessM4TPipeline(BaseModel): upload_urls: typing.List[str] = [] - model_id: typing.Literal[ - "facebook/hf-seamless-m4t-large", "facebook/hf-seamless-m4t-medium" - ] = "facebook/hf-seamless-m4t-large" + model_id: typing.Literal["facebook/seamless-m4t-v2-large"] = ( + "facebook/seamless-m4t-v2-large" + ) class SeamlessM4TInputs(BaseModel): - audio: str | None # required for ASR, S2ST, and S2TT - text: str | None # required for T2ST and T2TT - task: typing.Literal["S2ST", "T2ST", "S2TT", "T2TT", "ASR"] = "ASR" + audio: str | None = None # required for ASR, S2ST, and S2TT + text: str | None = None # required for T2ST and T2TT + task: typing.Literal["T2ST", "T2TT", "ASR"] = "ASR" # we do not need S2ST and S2TT src_lang: str | None = None # required for T2ST and T2TT tgt_lang: str | None = None # ignored for ASR (only src_lang is used) # seamless uses ISO 639-3 codes for languages @@ -31,10 +34,7 @@ class SeamlessM4TInputs(BaseModel): stride_length_s: typing.Tuple[float, float] = (6, 0) batch_size: int = 16 - -class SeamlessM4TOutput(typing.TypedDict): - text: str | None - audio: str | None + speaker_id: int = 0 # only used for T2ST, value in [0, 200) @app.task(name="seamless") @@ -42,42 +42,83 @@ class SeamlessM4TOutput(typing.TypedDict): def seamless_asr( pipeline: SeamlessM4TPipeline, inputs: SeamlessM4TInputs, -) -> AsrOutput: - audio = requests.get(inputs.audio).content - pipe = load_pipe(pipeline.model_id) +) -> AsrOutput | None: + pipe, processor, model = load_pipe(pipeline.model_id) + tgt_lang = inputs.tgt_lang or inputs.src_lang or "eng" - previous_src_lang = None - if inputs.src_lang: - previous_src_lang = pipe.tokenizer.src_lang - pipe.tokenizer.src_lang = inputs.src_lang + if inputs.task == "ASR": + assert inputs.audio is not None - tgt_lang = inputs.tgt_lang or inputs.src_lang or "eng" + audio = requests.get(inputs.audio).content - prediction = pipe( - audio, - # see https://colab.research.google.com/drive/1rS1L4YSJqKUH_3YxIQHBI982zso23wor#scrollTo=Ca4YYdtATxzo&line=5&uniqifier=1 - chunk_length_s=inputs.chunk_length_s, - stride_length_s=inputs.stride_length_s, - batch_size=inputs.batch_size, - generate_kwargs=dict(tgt_lang=tgt_lang), - ) + previous_src_lang = pipe.tokenizer.src_lang + if inputs.src_lang: + pipe.tokenizer.src_lang = inputs.src_lang + + prediction = pipe( + audio, + # see https://colab.research.google.com/drive/1rS1L4YSJqKUH_3YxIQHBI982zso23wor#scrollTo=Ca4YYdtATxzo&line=5&uniqifier=1 + chunk_length_s=inputs.chunk_length_s, + stride_length_s=inputs.stride_length_s, + batch_size=inputs.batch_size, + generate_kwargs=dict(tgt_lang=tgt_lang), + ) - if previous_src_lang: pipe.tokenizer.src_lang = previous_src_lang - return prediction + return prediction + + assert inputs.text is not None + assert inputs.src_lang is not None + text_inputs = processor( + text=inputs.text, src_lang=inputs.src_lang, return_tensors="pt" + ) + + if inputs.task == "T2ST": + audio_array_from_text = ( + model.generate( + **text_inputs, tgt_lang=tgt_lang, speaker_id=inputs.speaker_id + )[0] + .cpu() + .numpy() + .squeeze() + ) + + bytes_wav = bytes() + byte_io = io.BytesIO(bytes_wav) + write(byte_io, 16000, audio_array_from_text) + audio_bytes = byte_io.read() + gooey_gpu.upload_audio_from_bytes(audio_bytes, pipeline.upload_urls[0]) + return + if inputs.task == "T2TT": + output_tokens = model.generate( + **text_inputs, tgt_lang=tgt_lang, generate_speech=False + ) + translated_text_from_text = processor.decode( + output_tokens[0].tolist()[0], skip_special_tokens=True + ) + + return AsrOutput(text=translated_text_from_text) @lru_cache -def load_pipe(model_id: str): +def load_pipe( + model_id: str, +) -> typing.Tuple[ + transformers.AutomaticSpeechRecognitionPipeline, + transformers.SeamlessM4TProcessor, + transformers.SeamlessM4Tv2Model, +]: print(f"Loading asr model {model_id!r}...") pipe = transformers.pipeline( - "automatic-speech-recognition", + task="automatic-speech-recognition", model=model_id, device=gooey_gpu.DEVICE_ID, torch_dtype=torch.float16, ) - return pipe + processor = transformers.AutoProcessor.from_pretrained(model_id) + model = transformers.SeamlessM4Tv2Model.from_pretrained(model_id) + return pipe, processor, model setup_queues( diff --git a/sadtalker/SadTalker b/sadtalker/SadTalker new file mode 160000 index 0000000..cd4c046 --- /dev/null +++ b/sadtalker/SadTalker @@ -0,0 +1 @@ +Subproject commit cd4c0465ae0b54a6f85af57f5c65fec9fe23e7f8 diff --git a/scripts/run-dev.sh b/scripts/run-dev.sh index 18df3bf..9fa8d05 100755 --- a/scripts/run-dev.sh +++ b/scripts/run-dev.sh @@ -62,6 +62,9 @@ docker run \ -e U2NET_MODEL_IDS=" u2net "\ + -e SEAMLESS_MODEL_IDS=" + facebook/seamless-m4t-v2-large + "\ -e C_FORCE_ROOT=1 \ -e BROKER_URL=${BROKER_URL:-"amqp://"} \ -e RESULT_BACKEND=${RESULT_BACKEND:-"redis://"} \ @@ -70,7 +73,7 @@ docker run \ -e HUGGING_FACE_HUB_TOKEN=$HUGGING_FACE_HUB_TOKEN \ -v $HOME/.cache/huggingface:/root/.cache/huggingface \ -v $HOME/.cache/torch:/root/.cache/torch \ - --net host --runtime=nvidia --gpus all \ + --net host \ --memory 14g \ -it --rm --name $IMG \ $IMG:latest From 32964e50635f52dca31c8ade728480a2e7c6ad5e Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Sun, 9 Jun 2024 20:25:31 -0700 Subject: [PATCH 2/4] refactor into separate gpu tasks --- common/seamless_asr.py | 117 ++++++++++++++++++++++++----------------- 1 file changed, 70 insertions(+), 47 deletions(-) diff --git a/common/seamless_asr.py b/common/seamless_asr.py index 6078645..78b7300 100644 --- a/common/seamless_asr.py +++ b/common/seamless_asr.py @@ -25,7 +25,6 @@ class SeamlessM4TPipeline(BaseModel): class SeamlessM4TInputs(BaseModel): audio: str | None = None # required for ASR, S2ST, and S2TT text: str | None = None # required for T2ST and T2TT - task: typing.Literal["T2ST", "T2TT", "ASR"] = "ASR" # we do not need S2ST and S2TT src_lang: str | None = None # required for T2ST and T2TT tgt_lang: str | None = None # ignored for ASR (only src_lang is used) # seamless uses ISO 639-3 codes for languages @@ -37,36 +36,46 @@ class SeamlessM4TInputs(BaseModel): speaker_id: int = 0 # only used for T2ST, value in [0, 200) -@app.task(name="seamless") +@app.task(name="seamless.t2st") @gooey_gpu.endpoint -def seamless_asr( +def seamless_text2speech_translation( pipeline: SeamlessM4TPipeline, inputs: SeamlessM4TInputs, -) -> AsrOutput | None: - pipe, processor, model = load_pipe(pipeline.model_id) +) -> None: + _, processor, model = load_pipe(pipeline.model_id) tgt_lang = inputs.tgt_lang or inputs.src_lang or "eng" - if inputs.task == "ASR": - assert inputs.audio is not None - - audio = requests.get(inputs.audio).content + assert inputs.text is not None + assert inputs.src_lang is not None + text_inputs = processor( + text=inputs.text, src_lang=inputs.src_lang, return_tensors="pt" + ) - previous_src_lang = pipe.tokenizer.src_lang - if inputs.src_lang: - pipe.tokenizer.src_lang = inputs.src_lang + audio_array_from_text = ( + model.generate(**text_inputs, tgt_lang=tgt_lang, speaker_id=inputs.speaker_id)[ + 0 + ] + .cpu() + .numpy() + .squeeze() + ) - prediction = pipe( - audio, - # see https://colab.research.google.com/drive/1rS1L4YSJqKUH_3YxIQHBI982zso23wor#scrollTo=Ca4YYdtATxzo&line=5&uniqifier=1 - chunk_length_s=inputs.chunk_length_s, - stride_length_s=inputs.stride_length_s, - batch_size=inputs.batch_size, - generate_kwargs=dict(tgt_lang=tgt_lang), - ) + bytes_wav = bytes() + byte_io = io.BytesIO(bytes_wav) + write(byte_io, 16000, audio_array_from_text) + audio_bytes = byte_io.read() + gooey_gpu.upload_audio_from_bytes(audio_bytes, pipeline.upload_urls[0]) + return - pipe.tokenizer.src_lang = previous_src_lang - return prediction +@app.task(name="seamless.t2tt") +@gooey_gpu.endpoint +def seamless_text2text_translation( + pipeline: SeamlessM4TPipeline, + inputs: SeamlessM4TInputs, +) -> AsrOutput | None: + _, processor, model = load_pipe(pipeline.model_id) + tgt_lang = inputs.tgt_lang or inputs.src_lang or "eng" assert inputs.text is not None assert inputs.src_lang is not None @@ -74,31 +83,45 @@ def seamless_asr( text=inputs.text, src_lang=inputs.src_lang, return_tensors="pt" ) - if inputs.task == "T2ST": - audio_array_from_text = ( - model.generate( - **text_inputs, tgt_lang=tgt_lang, speaker_id=inputs.speaker_id - )[0] - .cpu() - .numpy() - .squeeze() - ) - - bytes_wav = bytes() - byte_io = io.BytesIO(bytes_wav) - write(byte_io, 16000, audio_array_from_text) - audio_bytes = byte_io.read() - gooey_gpu.upload_audio_from_bytes(audio_bytes, pipeline.upload_urls[0]) - return - if inputs.task == "T2TT": - output_tokens = model.generate( - **text_inputs, tgt_lang=tgt_lang, generate_speech=False - ) - translated_text_from_text = processor.decode( - output_tokens[0].tolist()[0], skip_special_tokens=True - ) - - return AsrOutput(text=translated_text_from_text) + output_tokens = model.generate( + **text_inputs, tgt_lang=tgt_lang, generate_speech=False + ) + translated_text_from_text = processor.decode( + output_tokens[0].tolist()[0], skip_special_tokens=True + ) + + return AsrOutput(text=translated_text_from_text) + + +@app.task(name="seamless") +@gooey_gpu.endpoint +def seamless_asr( + pipeline: SeamlessM4TPipeline, + inputs: SeamlessM4TInputs, +) -> AsrOutput | None: + pipe, _, _ = load_pipe(pipeline.model_id) + tgt_lang = inputs.tgt_lang or inputs.src_lang or "eng" + + assert inputs.audio is not None + + audio = requests.get(inputs.audio).content + + previous_src_lang = pipe.tokenizer.src_lang + if inputs.src_lang: + pipe.tokenizer.src_lang = inputs.src_lang + + prediction = pipe( + audio, + # see https://colab.research.google.com/drive/1rS1L4YSJqKUH_3YxIQHBI982zso23wor#scrollTo=Ca4YYdtATxzo&line=5&uniqifier=1 + chunk_length_s=inputs.chunk_length_s, + stride_length_s=inputs.stride_length_s, + batch_size=inputs.batch_size, + generate_kwargs=dict(tgt_lang=tgt_lang), + ) + + pipe.tokenizer.src_lang = previous_src_lang + + return prediction @lru_cache From a5e5fb1436a51a3064024e8c77ebb75b388856f2 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Tue, 13 Aug 2024 00:16:11 +0530 Subject: [PATCH 3/4] Handle exceptions in Celery worker initialization, set GPU memory limit based on environment variable and add gpu limits to deployment configuration --- celeryconfig.py | 13 +++++++++++-- chart/templates/deployment.yaml | 4 ++++ gooey_gpu.py | 11 +++++++++++ 3 files changed, 26 insertions(+), 2 deletions(-) diff --git a/celeryconfig.py b/celeryconfig.py index f52884e..9f0bd18 100644 --- a/celeryconfig.py +++ b/celeryconfig.py @@ -1,7 +1,9 @@ import os +import traceback import typing from celery import Celery +from celery.exceptions import WorkerShutdown from celery.signals import worker_init from kombu import Queue @@ -33,8 +35,15 @@ def setup_queues( queue_prefix: str = os.environ.get("QUEUE_PREFIX", "gooey-gpu"), ): def init(**kwargs): - for model_id in model_ids: - load_fn(model_id) + model_id = None + try: + for model_id in model_ids: + load_fn(model_id) + except: + # for some reason, celery seems to swallow exceptions in init + print(f"Error loading {model_id}:") + traceback.print_exc() + raise WorkerShutdown() init_fns.append(init) diff --git a/chart/templates/deployment.yaml b/chart/templates/deployment.yaml index 8dccbc3..5b2e416 100644 --- a/chart/templates/deployment.yaml +++ b/chart/templates/deployment.yaml @@ -46,6 +46,10 @@ spec: value: "{{ $value }}" {{- end }} {{- end }} + {{- range $name, $value := .limits }} + - name: "RESOURCE_LIMITS_{{ $name | upper }}" + value: "{{ $value }}" + {{- end }} livenessProbe: exec: command: [ "bash", "-c", "celery inspect ping -d celery@$HOSTNAME" ] diff --git a/gooey_gpu.py b/gooey_gpu.py index 3adaa7a..a73f77a 100644 --- a/gooey_gpu.py +++ b/gooey_gpu.py @@ -36,6 +36,17 @@ or "/root/.cache/gooey-gpu/checkpoints" ) +try: + gpu_limit_gib = float(os.environ["RESOURCE_LIMITS_GPU"].removesuffix("Gi")) +except (KeyError, ValueError): + print("RESOURCE_LIMITS_GPU environment variable not set to a valid value.") +else: + total_mem_bytes = torch.cuda.mem_get_info()[1] + fraction = gpu_limit_gib * 1024**3 / total_mem_bytes + torch.cuda.set_per_process_memory_fraction(fraction) + print(f"GPU limit set to {gpu_limit_gib}Gi ({fraction:.2%})") + + if SENTRY_DSN: sentry_sdk.init( dsn=SENTRY_DSN, From 73ff9c5bd076bf5fd1025338126132edb67bff8d Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Tue, 13 Aug 2024 00:17:57 +0530 Subject: [PATCH 4/4] Add seamless text-to-speech and text-to-text translation with updated dependencies - Create `seamless_v2.py` for new seamless text-to-speech and text-to-text translation tasks. - Update Dockerfile to use CUDA 12.4.1 and Ubuntu 22.04. - Upgrade dependencies in `requirements.txt`. - Modify `model-values.yaml` to update image reference, GPU, and memory limits. - Simplify `seamless_asr.py` for seamless ASR only (as huggingface expects a different model class for asr vs other tasks) - Fix `run-dev.sh` to enable NVIDIA runtime and GPU usage. --- chart/model-values.yaml | 5 +- common/Dockerfile | 2 +- common/deepfloyd.py | 137 --------------- common/pipeline_if_sr_patch.py | 308 --------------------------------- common/requirements.txt | 18 +- common/seamless_asr.py | 122 +++---------- common/seamless_v2.py | 83 +++++++++ scripts/run-dev.sh | 2 +- 8 files changed, 121 insertions(+), 556 deletions(-) delete mode 100644 common/deepfloyd.py delete mode 100644 common/pipeline_if_sr_patch.py create mode 100644 common/seamless_v2.py diff --git a/chart/model-values.yaml b/chart/model-values.yaml index 34be57c..38c6501 100644 --- a/chart/model-values.yaml +++ b/chart/model-values.yaml @@ -195,9 +195,10 @@ deployments: bark - name: "common-seamless" - image: "us-docker.pkg.dev/dara-c1b52/cloudbuild/gooey-gpu/common@sha256:5fb0ffa128cbdda86747fedf5ef68e9df8256735d8535149c6fffa41a3749883" + image: "crgooeyprodwestus1.azurecr.io/gooey-gpu-common:3" limits: - memory: "19Gi" + gpu: "10Gi" + memory: "28Gi" # (220 / 80) * 10 cpu: "1" env: IMPORTS: |- diff --git a/common/Dockerfile b/common/Dockerfile index 0f31852..6db0dec 100644 --- a/common/Dockerfile +++ b/common/Dockerfile @@ -1,4 +1,4 @@ -FROM nvidia/cuda:12.2.0-devel-ubuntu20.04 +FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04 ARG DEBIAN_FRONTEND=noninteractive diff --git a/common/deepfloyd.py b/common/deepfloyd.py deleted file mode 100644 index ed2b2b4..0000000 --- a/common/deepfloyd.py +++ /dev/null @@ -1,137 +0,0 @@ -from functools import lru_cache - -import torch -from diffusers import ( - IFPipeline, - StableDiffusionUpscalePipeline, -) -from fastapi import APIRouter -from pydantic import BaseModel -from transformers import T5EncoderModel - -import gooey_gpu -from common.diffusion import safety_checker_wrapper -from common.pipeline_if_sr_patch import IFSuperResolutionPipelinePatch - -app = APIRouter(prefix="/deepfloyd_if") - - -class PipelineInfo(BaseModel): - upload_urls: list[str] = [] - model_id: tuple[str, str, str] - seed: int = 42 - disable_safety_checker: bool = False - - -class DeepfloydInputs(BaseModel): - prompt: list[str] - negative_prompt: list[str] = None - num_inference_steps: tuple[int, int, int] = (100, 50, 75) - num_images_per_prompt: int = 1 - guidance_scale: tuple[float, float, float] = (7, 4, 9) - - -class Text2ImgInputs(DeepfloydInputs): - width: int - height: int - - -@app.post("/text2img/") -@gooey_gpu.endpoint -def text2img(pipeline: PipelineInfo, inputs: Text2ImgInputs): - output_images = _run_model(pipeline, inputs) - gooey_gpu.upload_images(output_images, pipeline.upload_urls) - - -@gooey_gpu.gpu_task -def _run_model(pipeline: PipelineInfo, inputs: Text2ImgInputs): - pipe1 = load_pipe1(pipeline.model_id[0]) - pipe2 = load_pipe2(pipeline.model_id[1]) - pipe3 = load_pipe3(pipeline.model_id[2]) - - inputs.prompt *= inputs.num_images_per_prompt - if inputs.negative_prompt: - inputs.negative_prompt *= inputs.num_images_per_prompt - - with gooey_gpu.use_models(pipe1), torch.inference_mode(): - generator = torch.Generator().manual_seed(pipeline.seed) - # custom safety checker impl - safety_checker_wrapper(pipe1, disabled=pipeline.disable_safety_checker) - # Create text embeddings - prompt_embeds, negative_embeds = pipe1.encode_prompt( - inputs.prompt, negative_prompt=inputs.negative_prompt - ) - # The main diffusion process - images = pipe1( - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_embeds, - guidance_scale=inputs.guidance_scale[0], - num_inference_steps=inputs.num_inference_steps[0], - output_type="pt", - generator=generator, - width=inputs.width // 16, - height=inputs.height // 16, - ).images - - with gooey_gpu.use_models(pipe2), torch.inference_mode(): - # custom safety checker impl - safety_checker_wrapper(pipe2, disabled=pipeline.disable_safety_checker) - # Super Resolution 64x64 to 256x256 - images = pipe2( - image=images, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_embeds, - guidance_scale=inputs.guidance_scale[1], - num_inference_steps=inputs.num_inference_steps[1], - output_type="pt", - generator=generator, - ).images - - with gooey_gpu.use_models(pipe3), torch.inference_mode(): - # custom safety checker impl - safety_checker_wrapper(pipe3, disabled=pipeline.disable_safety_checker) - # Super Resolution 256x256 to 1024x1024 - output_images = pipe3( - image=images, - prompt=inputs.prompt, - negative_prompt=inputs.negative_prompt, - guidance_scale=inputs.guidance_scale[2], - num_inference_steps=inputs.num_inference_steps[2], - generator=generator, - ).images - - return output_images - - -@lru_cache -def load_pipe1(model_id: str): - # text_encoder = T5EncoderModel.from_pretrained( - # model_id, - # subfolder="text_encoder", - # load_in_8bit=True, - # variant="8bit", - # ) - return IFPipeline.from_pretrained( - model_id, - # text_encoder=text_encoder, - variant="fp16", - torch_dtype=torch.float16, - ) - - -@lru_cache -def load_pipe2(model_id: str): - return IFSuperResolutionPipelinePatch.from_pretrained( - model_id, - text_encoder=None, # no use of text encoder => memory savings! - variant="fp16", - torch_dtype=torch.float16, - ) - - -@lru_cache -def load_pipe3(model_id: str): - return StableDiffusionUpscalePipeline.from_pretrained( - model_id, - torch_dtype=torch.float16, - ) diff --git a/common/pipeline_if_sr_patch.py b/common/pipeline_if_sr_patch.py deleted file mode 100644 index 4cbc7b2..0000000 --- a/common/pipeline_if_sr_patch.py +++ /dev/null @@ -1,308 +0,0 @@ -from typing import Any, Callable, Dict, List, Optional, Union - -import PIL -import numpy as np -import torch -import torch.nn.functional as F -from diffusers import IFSuperResolutionPipeline -from diffusers.pipelines.deepfloyd_if import IFPipelineOutput -from diffusers.utils import ( - randn_tensor, -) - - -class IFSuperResolutionPipelinePatch(IFSuperResolutionPipeline): - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]] = None, - image: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor] = None, - num_inference_steps: int = 50, - timesteps: List[int] = None, - guidance_scale: float = 4.0, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - noise_level: int = 250, - clean_caption: bool = True, - ): - """ - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - image (`PIL.Image.Image`, `np.ndarray`, `torch.FloatTensor`): - The image to be upscaled. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - timesteps (`List[int]`, *optional*): - Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` - timesteps are used. Must be in descending order. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). - noise_level (`int`, *optional*, defaults to 250): - The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)` - clean_caption (`bool`, *optional*, defaults to `True`): - Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to - be installed. If the dependencies are not installed, the embeddings will be created from the raw - prompt. - - Examples: - - Returns: - [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When - returning a tuple, the first element is a list with the generated images, and the second element is a list - of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) - or watermarked content, according to the `safety_checker`. - """ - # 1. Check inputs. Raise error if not correct - - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - self.check_inputs( - prompt, - image, - batch_size, - noise_level, - callback_steps, - negative_prompt, - prompt_embeds, - negative_prompt_embeds, - ) - - # 2. Define call parameters - - # height = self.unet.config.sample_size - # width = self.unet.config.sample_size - height, width = image.shape[2:] - width *= 4 - height *= 4 - - device = self._execution_device - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 3. Encode input prompt - prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt, - do_classifier_free_guidance, - num_images_per_prompt=num_images_per_prompt, - device=device, - negative_prompt=negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - clean_caption=clean_caption, - ) - - if do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - # 4. Prepare timesteps - if timesteps is not None: - self.scheduler.set_timesteps(timesteps=timesteps, device=device) - timesteps = self.scheduler.timesteps - num_inference_steps = len(timesteps) - else: - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - # 5. Prepare intermediate images - num_channels = self.unet.config.in_channels // 2 - intermediate_images = self.prepare_intermediate_images( - batch_size * num_images_per_prompt, - num_channels, - height, - width, - prompt_embeds.dtype, - device, - generator, - ) - - # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 7. Prepare upscaled image and noise level - image = self.preprocess_image(image, num_images_per_prompt, device) - upscaled = F.interpolate( - image, (height, width), mode="bilinear", align_corners=True - ) - - noise_level = torch.tensor( - [noise_level] * upscaled.shape[0], device=upscaled.device - ) - noise = randn_tensor( - upscaled.shape, - generator=generator, - device=upscaled.device, - dtype=upscaled.dtype, - ) - upscaled = self.image_noising_scheduler.add_noise( - upscaled, noise, timesteps=noise_level - ) - - if do_classifier_free_guidance: - noise_level = torch.cat([noise_level] * 2) - - # HACK: see comment in `enable_model_cpu_offload` - if ( - hasattr(self, "text_encoder_offload_hook") - and self.text_encoder_offload_hook is not None - ): - self.text_encoder_offload_hook.offload() - - # 8. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - model_input = torch.cat([intermediate_images, upscaled], dim=1) - - model_input = ( - torch.cat([model_input] * 2) - if do_classifier_free_guidance - else model_input - ) - model_input = self.scheduler.scale_model_input(model_input, t) - - # predict the noise residual - noise_pred = self.unet( - model_input, - t, - encoder_hidden_states=prompt_embeds, - class_labels=noise_level, - cross_attention_kwargs=cross_attention_kwargs, - ).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred_uncond, _ = noise_pred_uncond.split( - model_input.shape[1] // 2, dim=1 - ) - noise_pred_text, predicted_variance = noise_pred_text.split( - model_input.shape[1] // 2, dim=1 - ) - noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) - noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) - - # compute the previous noisy sample x_t -> x_t-1 - intermediate_images = self.scheduler.step( - noise_pred, t, intermediate_images, **extra_step_kwargs - ).prev_sample - - # call the callback, if provided - if i == len(timesteps) - 1 or ( - (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 - ): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, intermediate_images) - - image = intermediate_images - - if output_type == "pil": - # 9. Post-processing - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - - # 10. Run safety checker - image, nsfw_detected, watermark_detected = self.run_safety_checker( - image, device, prompt_embeds.dtype - ) - - # 11. Convert to PIL - image = self.numpy_to_pil(image) - - # 12. Apply watermark - if self.watermarker is not None: - self.watermarker.apply_watermark(image, self.unet.config.sample_size) - elif output_type == "pt": - nsfw_detected = None - watermark_detected = None - - if ( - hasattr(self, "unet_offload_hook") - and self.unet_offload_hook is not None - ): - self.unet_offload_hook.offload() - else: - # 9. Post-processing - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - - # 10. Run safety checker - image, nsfw_detected, watermark_detected = self.run_safety_checker( - image, device, prompt_embeds.dtype - ) - - # Offload last model to CPU - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.final_offload_hook.offload() - - if not return_dict: - return (image, nsfw_detected, watermark_detected) - - return IFPipelineOutput( - images=image, - nsfw_detected=nsfw_detected, - watermark_detected=watermark_detected, - ) diff --git a/common/requirements.txt b/common/requirements.txt index 63027cf..a932a1b 100644 --- a/common/requirements.txt +++ b/common/requirements.txt @@ -7,19 +7,19 @@ accelerate ~= 0.20.3 celery ~= 5.3.0 pydantic ~= 1.10.9 redis ~= 4.5.5 -transformers ~= 4.40.1 +transformers ~= 4.44.0 ## pytorch ---extra-index-url https://download.pytorch.org/whl/cu118 -torch ~= 2.0.0 -torchvision ~= 0.15.1 -torchaudio ~= 2.0.1 +--extra-index-url https://download.pytorch.org/whl/cu124 +torch ~= 2.4.0 +torchvision ~= 0.19.0 +torchaudio ~= 2.4.0 ## huggingface diffusers -diffusers ~= 0.21.1 -sentencepiece ~= 0.1.99 -torchsde ~= 0.2.5 -xformers ~= 0.0.20 +diffusers ~= 0.30.0 +sentencepiece ~= 0.2.0 +torchsde ~= 0.2.6 +xformers ~= 0.0.27 ## controlnet controlnet_aux ~= 0.0.1 diff --git a/common/seamless_asr.py b/common/seamless_asr.py index 78b7300..f7934ef 100644 --- a/common/seamless_asr.py +++ b/common/seamless_asr.py @@ -7,109 +7,40 @@ import transformers from pydantic import BaseModel -import io -from scipy.io.wavfile import write - import gooey_gpu from api import AsrOutput from celeryconfig import app, setup_queues -class SeamlessM4TPipeline(BaseModel): - upload_urls: typing.List[str] = [] - model_id: typing.Literal["facebook/seamless-m4t-v2-large"] = ( - "facebook/seamless-m4t-v2-large" - ) +class SeamlessASRPipeline(BaseModel): + model_id: str -class SeamlessM4TInputs(BaseModel): - audio: str | None = None # required for ASR, S2ST, and S2TT - text: str | None = None # required for T2ST and T2TT - src_lang: str | None = None # required for T2ST and T2TT - tgt_lang: str | None = None # ignored for ASR (only src_lang is used) - # seamless uses ISO 639-3 codes for languages +class SeamlessASRInputs(BaseModel): + audio: str + src_lang: str | None = None + tgt_lang: str | None = None chunk_length_s: float = 30 stride_length_s: typing.Tuple[float, float] = (6, 0) batch_size: int = 16 - speaker_id: int = 0 # only used for T2ST, value in [0, 200) - -@app.task(name="seamless.t2st") -@gooey_gpu.endpoint -def seamless_text2speech_translation( - pipeline: SeamlessM4TPipeline, - inputs: SeamlessM4TInputs, -) -> None: - _, processor, model = load_pipe(pipeline.model_id) - tgt_lang = inputs.tgt_lang or inputs.src_lang or "eng" - - assert inputs.text is not None - assert inputs.src_lang is not None - text_inputs = processor( - text=inputs.text, src_lang=inputs.src_lang, return_tensors="pt" - ) - - audio_array_from_text = ( - model.generate(**text_inputs, tgt_lang=tgt_lang, speaker_id=inputs.speaker_id)[ - 0 - ] - .cpu() - .numpy() - .squeeze() - ) - - bytes_wav = bytes() - byte_io = io.BytesIO(bytes_wav) - write(byte_io, 16000, audio_array_from_text) - audio_bytes = byte_io.read() - gooey_gpu.upload_audio_from_bytes(audio_bytes, pipeline.upload_urls[0]) - return - - -@app.task(name="seamless.t2tt") -@gooey_gpu.endpoint -def seamless_text2text_translation( - pipeline: SeamlessM4TPipeline, - inputs: SeamlessM4TInputs, -) -> AsrOutput | None: - _, processor, model = load_pipe(pipeline.model_id) - tgt_lang = inputs.tgt_lang or inputs.src_lang or "eng" - - assert inputs.text is not None - assert inputs.src_lang is not None - text_inputs = processor( - text=inputs.text, src_lang=inputs.src_lang, return_tensors="pt" - ) - - output_tokens = model.generate( - **text_inputs, tgt_lang=tgt_lang, generate_speech=False - ) - translated_text_from_text = processor.decode( - output_tokens[0].tolist()[0], skip_special_tokens=True - ) - - return AsrOutput(text=translated_text_from_text) - - -@app.task(name="seamless") +@app.task(name="seamless.asr") @gooey_gpu.endpoint def seamless_asr( - pipeline: SeamlessM4TPipeline, - inputs: SeamlessM4TInputs, -) -> AsrOutput | None: - pipe, _, _ = load_pipe(pipeline.model_id) - tgt_lang = inputs.tgt_lang or inputs.src_lang or "eng" - - assert inputs.audio is not None - + pipeline: SeamlessASRPipeline, + inputs: SeamlessASRInputs, +) -> AsrOutput: audio = requests.get(inputs.audio).content + pipe = load_pipe(pipeline.model_id) previous_src_lang = pipe.tokenizer.src_lang if inputs.src_lang: pipe.tokenizer.src_lang = inputs.src_lang + tgt_lang = inputs.tgt_lang or inputs.src_lang + prediction = pipe( audio, # see https://colab.research.google.com/drive/1rS1L4YSJqKUH_3YxIQHBI982zso23wor#scrollTo=Ca4YYdtATxzo&line=5&uniqifier=1 @@ -125,23 +56,18 @@ def seamless_asr( @lru_cache -def load_pipe( - model_id: str, -) -> typing.Tuple[ - transformers.AutomaticSpeechRecognitionPipeline, - transformers.SeamlessM4TProcessor, - transformers.SeamlessM4Tv2Model, -]: - print(f"Loading asr model {model_id!r}...") - pipe = transformers.pipeline( - task="automatic-speech-recognition", - model=model_id, - device=gooey_gpu.DEVICE_ID, - torch_dtype=torch.float16, +def load_pipe(model_id: str) -> transformers.AutomaticSpeechRecognitionPipeline: + print(f"Loading seamless m4t pipeline {model_id!r}...") + pipe = typing.cast( + transformers.AutomaticSpeechRecognitionPipeline, + transformers.pipeline( + task="automatic-speech-recognition", + model=model_id, + device=gooey_gpu.DEVICE_ID, + torch_dtype=torch.float16, + ), ) - processor = transformers.AutoProcessor.from_pretrained(model_id) - model = transformers.SeamlessM4Tv2Model.from_pretrained(model_id) - return pipe, processor, model + return pipe setup_queues( diff --git a/common/seamless_v2.py b/common/seamless_v2.py new file mode 100644 index 0000000..e7a8c7e --- /dev/null +++ b/common/seamless_v2.py @@ -0,0 +1,83 @@ +import typing +from functools import lru_cache + +import transformers +from pydantic import BaseModel + +import gooey_gpu +from celeryconfig import app + + +class SeamlessPipeline(BaseModel): + upload_urls: typing.List[str] = [] + model_id: str + + +class SeamlessT2STInputs(BaseModel): + text: str + src_lang: str + tgt_lang: str + speaker_id: int = 0 # [0, 200) + + +@app.task(name="seamless.t2st") +@gooey_gpu.endpoint +def seamless_text_to_speech_translation( + pipeline: SeamlessPipeline, + inputs: SeamlessT2STInputs, +) -> None: + model, processor = load_model(pipeline.model_id) + text_inputs = processor( + text=inputs.text, src_lang=inputs.src_lang, return_tensors="pt" + ).to(gooey_gpu.DEVICE_ID) + audio_array_from_text = ( + model.generate( + **text_inputs, tgt_lang=inputs.tgt_lang, speaker_id=inputs.speaker_id + )[0] + .cpu() + .numpy() + .squeeze() + ) + gooey_gpu.upload_audio(audio_array_from_text, pipeline.upload_urls[0]) + + +class SeamlessT2TTInputs(BaseModel): + text: str + src_lang: str + tgt_lang: str + + +@app.task(name="seamless.t2tt") +@gooey_gpu.endpoint +def seamless_text2text_translation( + pipeline: SeamlessPipeline, + inputs: SeamlessT2TTInputs, +) -> str: + model, processor = load_model(pipeline.model_id) + text_inputs = processor( + text=inputs.text, src_lang=inputs.src_lang, return_tensors="pt" + ).to(gooey_gpu.DEVICE_ID) + output_tokens = model.generate( + **text_inputs, tgt_lang=inputs.tgt_lang, generate_speech=False + ) + translated_text_from_text = processor.decode( + output_tokens[0].tolist()[0], skip_special_tokens=True + ) + return translated_text_from_text + + +@lru_cache +def load_model(model_id: str) -> typing.Tuple[ + transformers.SeamlessM4Tv2Model, + transformers.SeamlessM4TProcessor, +]: + print(f"Loading seamless m4t model {model_id!r}...") + model = typing.cast( + transformers.SeamlessM4Tv2Model, + transformers.AutoModel.from_pretrained(model_id).to(gooey_gpu.DEVICE_ID), + ) + processor = typing.cast( + transformers.SeamlessM4TProcessor, + transformers.AutoProcessor.from_pretrained(model_id), + ) + return model, processor diff --git a/scripts/run-dev.sh b/scripts/run-dev.sh index 873a91c..79c53bb 100755 --- a/scripts/run-dev.sh +++ b/scripts/run-dev.sh @@ -75,7 +75,7 @@ docker run \ -e HUGGING_FACE_HUB_TOKEN=$HUGGING_FACE_HUB_TOKEN \ -v $HOME/.cache/huggingface:/root/.cache/huggingface \ -v $HOME/.cache/torch:/root/.cache/torch \ - --net host \ + --net host --runtime=nvidia --gpus all \ --memory 14g \ -it --rm --name $IMG \ $IMG:latest