
[Bug] Streaming doesn't work | inference_stream is as fast as normal inference #97

Open
Fledermaus-20 opened this issue Oct 5, 2024 · 3 comments
Labels
bug Something isn't working

Comments

@Fledermaus-20

Describe the bug

When running tests against the TTS endpoint, I observed that streaming the audio response takes nearly the same amount of time as receiving a fully generated audio file. This seems counterintuitive, since streaming should typically start delivering the response faster, beginning with the first available data chunk. The code for the streaming endpoint is below.

To Reproduce

model_manager.py

import asyncio
import os
import torch
import numpy as np
import wave
from io import BytesIO
from fastapi import HTTPException, status
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
import logging

formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s')
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(handler)

class ModelManager:
    def __init__(self):
        self.tts_model = None
        self.model_loading_lock = asyncio.Lock()
        self.tts_model_path = '/code/models/TTS/'

    async def load_tts_model(self):
        if self.tts_model is None:
            async with self.model_loading_lock:
                if self.tts_model is None:
                    logger.info("Loading TTS model...")
                    device = "cuda" if torch.cuda.is_available() else "cpu"
                    logger.info(f"Using {device} as device.")
                    self.config = XttsConfig()
                    self.config.load_json(os.path.join(self.tts_model_path, "config.json"))
                    self.tts_model = Xtts.init_from_config(self.config)
                    self.tts_model.load_checkpoint(self.config, checkpoint_dir=self.tts_model_path, eval=True)
                    self.tts_model.to(device)
                    self.gpt, self.speaker = self.tts_model.get_conditioning_latents(self.tts_model_path + "speaker.wav")
                    logger.info("TTS model loaded.")

        return self.tts_model, self.speaker, self.gpt

tts_streaming.py

import torch
import numpy as np
import wave
from io import BytesIO
from fastapi import HTTPException, status, Request, Header
from model_manager import ModelManager
import logging

# Setup logging
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s')

handler = logging.StreamHandler()
handler.setFormatter(formatter)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(handler)

model_manager = ModelManager()

def wav_data_generator(frame_input, sample_rate=24000, sample_width=2, channels=1):
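    # Wraps the given PCM frames in a WAV container. In the streaming handler it
    # is called with b"" so that only the WAV header is sent before the raw PCM chunks.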
    wav_buf = BytesIO()
    with wave.open(wav_buf, "wb") as vfout:
        vfout.setnchannels(channels)
        vfout.setsampwidth(sample_width)
        vfout.setframerate(sample_rate)
        vfout.writeframes(frame_input)
    wav_buf.seek(0)
    return wav_buf.read()

def postprocess(wav):
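    # Convert the model's float waveform (a list of tensors or a single tensor) to 16-bit PCM samples.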
    if isinstance(wav, list):
        wav = torch.cat(wav, dim=0)
    wav = wav.clone().detach().cpu().numpy()
    wav = wav[None, : int(wav.shape[0])]
    wav = np.clip(wav, -1, 1)
    wav = (wav * 32767).astype(np.int16)
    return wav

async def text_to_speech_stream(text: str, language: str, chunk: int = 20):
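    # NOTE: is_language_supported is assumed to be defined or imported elsewhere
    # in the project; it is not included in this snippet.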
    if not is_language_supported(language):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Unsupported language: {language}")

    tts_model, speaker, gpt = await model_manager.load_tts_model()

    add_wav_header = True

    try:
        chunks = tts_model.inference_stream(
            text=text,
            language=language,
            gpt_cond_latent=gpt,
            speaker_embedding=speaker,
            stream_chunk_size=chunk,
            enable_text_splitting=True
        )

        for i, chunk in enumerate(chunks):
            chunk = postprocess(chunk)
            if i == 0 and add_wav_header:
                yield wav_data_generator(b"")
                yield chunk.tobytes()
            else:
                yield chunk.tobytes()

    except Exception as e:
        logger.error(f"Error in text-to-speech-stream: {e}")
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Error in text-to-speech")

main.py

from fastapi import FastAPI, Request, Header
from fastapi.responses import StreamingResponse
from tts_streaming import text_to_speech_stream

app = FastAPI()

@app.post("/text-to-speech")
async def speech_route(
    stream_chunk: int = Header(20),
    text: str = Header(...),
    language: str = Header(...),
):
    return StreamingResponse(
        text_to_speech_stream(text=text, language=language, chunk=stream_chunk),
        media_type="audio/wav"
    )
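
For reference, a small client along these lines (a rough sketch; it assumes the app runs on http://localhost:8000 and uses the requests library) can be used to check how soon the first chunk arrives compared to the total response time:

import time
import requests

headers = {
    "text": "Hello, this is a streaming test.",
    "language": "en",
    "stream-chunk": "20",  # FastAPI exposes the stream_chunk parameter as the "stream-chunk" header
}

start = time.perf_counter()
with requests.post("http://localhost:8000/text-to-speech", headers=headers, stream=True) as resp:
    resp.raise_for_status()
    first_chunk_at = None
    total_bytes = 0
    for data in resp.iter_content(chunk_size=None):
        if first_chunk_at is None:
            first_chunk_at = time.perf_counter() - start
        total_bytes += len(data)

print(f"Time to first chunk: {first_chunk_at:.2f}s")
print(f"Total time: {time.perf_counter() - start:.2f}s ({total_bytes} bytes)")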

Expected behavior

The behavior I'm expecting is that the TTS stream starts coming back much sooner than a request for the finished file takes to complete.

Logs

No response

Environment

It's running in a Docker container.

{
    "CUDA": {
        "GPU": [],
        "available": false,
        "version": "12.1"
    },
    "Packages": {
        "PyTorch_debug": false,
        "PyTorch_version": "2.4.1+cu121",
        "TTS": "0.24.2",
        "numpy": "1.26.4"
    },
    "System": {
        "OS": "Linux",
        "architecture": [
            "64bit",
            ""
        ],
        "processor": "x86_64",
        "python": "3.11.0rc1",
        "version": "#45~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Wed Sep 11 15:25:05 UTC 2"
    }
}

Additional context

Thanks in advance for any help.

@Fledermaus-20 Fledermaus-20 added the bug Something isn't working label Oct 5, 2024
@Fledermaus-20 Fledermaus-20 changed the title from "[Bug] Streaming doesn't work as intended" to "[Bug] Streaming doesn't work" on Oct 7, 2024
@eginhard
Member

eginhard commented Oct 8, 2024

You didn't share any timings, so it's hard to say what's going on. Note that streaming inference may take longer in total than normal inference; only the first chunks should arrive faster. Could you share (once the models are loaded/warmed up):

  • average time to first chunk for streaming inference
  • average time for complete synthesis with normal inference
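
Something along these lines would work for collecting those numbers (a rough, untested sketch that reuses the tts_model, gpt and speaker objects from model_manager.py):

import time

text = "Hello, this is a timing test."

# Normal inference: nothing is returned until the full waveform has been synthesized.
t0 = time.perf_counter()
tts_model.inference(text, "en", gpt, speaker)
print(f"normal inference (complete): {time.perf_counter() - t0:.2f}s")

# Streaming inference: measure how long it takes until the first chunk is yielded.
t0 = time.perf_counter()
chunks = tts_model.inference_stream(text, "en", gpt, speaker, stream_chunk_size=20)
next(chunks)
print(f"streaming inference (first chunk): {time.perf_counter() - t0:.2f}s")
for _ in chunks:  # drain the remaining chunks to get the total streaming time
    pass
print(f"streaming inference (complete): {time.perf_counter() - t0:.2f}s")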

@Fledermaus-20
Author

Sorry, I had forgotten. The times after the model is loaded/warmed up are:

  • Average time for complete synthesis with normal inference: 12.2383 seconds
  • Time to first chunk for streaming inference: 11.7466 seconds

@Fledermaus-20 Fledermaus-20 changed the title from "[Bug] Streaming doesn't work" to "[Bug] Streaming doesn't work | inference_stream is as fast as normal inference" on Oct 8, 2024
@yalsaffar

I have done this implementation, which might be useful for you (with additional threading to reduce latency):

https://github.com/yalsaffar/S3TVR/blob/main/models/TTS_utils.py#L232
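
The general idea (a rough sketch of the pattern, not the exact code from that repo) is to run the blocking inference_stream generator in a worker thread and pass its chunks to the async response through a queue, so the event loop is not blocked while a chunk is being generated:

import asyncio
import threading

async def stream_in_thread(tts_model, text, language, gpt, speaker, chunk_size=20):
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()
    done = object()  # sentinel marking the end of the stream

    def produce():
        # Runs in a worker thread; pushes raw audio chunks into the asyncio queue.
        try:
            for wav_chunk in tts_model.inference_stream(
                text=text,
                language=language,
                gpt_cond_latent=gpt,
                speaker_embedding=speaker,
                stream_chunk_size=chunk_size,
                enable_text_splitting=True,
            ):
                loop.call_soon_threadsafe(queue.put_nowait, wav_chunk)
        finally:
            loop.call_soon_threadsafe(queue.put_nowait, done)

    threading.Thread(target=produce, daemon=True).start()

    # Drain the queue on the event-loop side and yield PCM bytes as they arrive
    # (postprocess is the helper from tts_streaming.py above).
    while True:
        item = await queue.get()
        if item is done:
            break
        yield postprocess(item).tobytes()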
