Commit: streaming example

aldrinjenson committed Feb 20, 2024
1 parent f92e44a commit 38e4072
Showing 3 changed files with 347 additions and 28 deletions.
233 changes: 233 additions & 0 deletions api/api.py
@@ -0,0 +1,233 @@
import base64
import time
import tempfile
import logging
from typing import Dict
from modal import Image, Stub, web_endpoint, asgi_app
from fastapi import FastAPI, Response
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware

# Define the GPU type to be used for processing
GPU_TYPE = "T4"


def download_models():
"""
Downloads and initializes models required for speech processing, including a translator model and a VAD (Voice Activity Detection) model.
"""
from seamless_communication.inference import Translator
import torch

# Define model names for the translator and vocoder
model_name = "seamlessM4T_v2_large"
vocoder_name = (
"vocoder_v2" if model_name == "seamlessM4T_v2_large" else "vocoder_36langs"
)

# Initialize the translator model with specified parameters
Translator(
model_name,
vocoder_name,
device=torch.device("cuda:0"),
dtype=torch.float16,
)

# Load the silero-VAD model from the specified repository
USE_ONNX = False
torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad", onnx=USE_ONNX)


def base64_to_audio_file(b64_contents: str):
"""
Converts a base64 encoded string to an audio file and returns the path to the temporary audio file.
Parameters:
- b64_contents (str): Base64 encoded string of the audio file.
Returns:
- str: Path to the temporary audio file.
"""
# Decode the base64 string to audio data
audio_data = base64.b64decode(b64_contents)

# Create a temporary file to store the audio data
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
tmp_file.write(audio_data)
return tmp_file.name


def convert_to_mono_16k(input_file: str, output_file: str) -> None:
"""
Converts an audio file to mono channel with a sample rate of 16000 Hz.
Parameters:
- input_file (str): Path to the input audio file.
- output_file (str): Path where the converted audio file will be saved.
"""
from pydub import AudioSegment

sound = AudioSegment.from_file(input_file)
sound = sound.set_channels(1).set_frame_rate(16000)
sound.export(output_file, format="wav")


# Define the Docker image configuration for the processing environment
image = (
Image.from_registry("nvidia/cuda:12.2.0-devel-ubuntu20.04", add_python="3.10")
.apt_install("git", "ffmpeg")
.pip_install(
"fairseq2==0.2.*",
"sentencepiece",
"pydub",
"ffmpeg-python",
"torch==2.1.1",
"seamless_communication @ git+https://github.com/facebookresearch/seamless_communication.git", # torchaudio already included in seamless_communication
"faster-whisper",
)
.run_function(download_models, gpu=GPU_TYPE)
)

# Initialize the processing stub with the defined Docker image
stub = Stub(name="seamless_m4t_speech", image=image)


@stub.function(gpu=GPU_TYPE, timeout=600)
@web_endpoint(method="POST")
async def generate_seamlessm4t_speech(item: Dict):
"""
Processes the input speech audio, performs voice activity detection, and translates the speech from the source language to the target language.
Parameters:
- item (Dict): A dictionary containing the base64 encoded audio data, source language, and target language.
    Returns:
    - StreamingResponse: A streaming JSON response in which each chunk contains the start time, end time, and translated text of one detected speech segment. On error, a dict with a 500 code and message is returned instead.
"""
# import wave
import os

import torch
import torchaudio
from pydub import AudioSegment
from seamless_communication.inference import Translator
import json

try:
USE_ONNX = False
model, utils = torch.hub.load(
repo_or_dir="snakers4/silero-vad", model="silero_vad", onnx=USE_ONNX
)

(
get_speech_timestamps,
save_audio,
read_audio,
VADIterator,
collect_chunks,
) = utils

# Decode the base64 audio and convert it for processing
b64 = item["wav_base64"]
# source_lang = item["source"]
print(f"Target_lang: {item.get('target')}")
target_lang = item["target"]

fname = base64_to_audio_file(b64_contents=b64)
print(fname)
convert_to_mono_16k(fname, "output.wav")

# Perform voice activity detection on the processed audio
SAMPLING_RATE = 16000
wav = read_audio("output.wav", sampling_rate=SAMPLING_RATE)

# get speech timestamps from full audio file
speech_timestamps_seconds = get_speech_timestamps(
wav, model, sampling_rate=SAMPLING_RATE, return_seconds=True
)
print(speech_timestamps_seconds)
# translator = download_models()
model_name = "seamlessM4T_v2_large"
vocoder_name = (
"vocoder_v2" if model_name == "seamlessM4T_v2_large" else "vocoder_36langs"
)

translator = Translator(
model_name,
vocoder_name,
device=torch.device("cuda:0"),
dtype=torch.float16,
)

# duration = get_duration_wave(fname)
# print(f"Duration: {duration:.2f} seconds")

resample_rate = 16000

        # Collect VAD segment boundaries and the translated text for each chunk
timestamps_start = []
timestamps_end = []
text = []

        # Stream one translated chunk per VAD-detected speech segment
        async def generate():
            for segment in speech_timestamps_seconds:
                s = segment["start"]
                e = segment["end"]

timestamps_start.append(s)
timestamps_end.append(e)
newAudio = AudioSegment.from_wav("output.wav")

                # AudioSegment slices are indexed in milliseconds, so convert seconds to ms (20 s -> 20000 ms)
newAudio = newAudio[s * 1000 : e * 1000]
new_audio_name = "new_" + str(s) + ".wav"
newAudio.export(new_audio_name, format="wav")
waveform, sample_rate = torchaudio.load(new_audio_name)
resampler = torchaudio.transforms.Resample(
sample_rate, resample_rate, dtype=waveform.dtype
)
resampled_waveform = resampler(waveform)
torchaudio.save("resampled.wav", resampled_waveform, resample_rate)
translated_text, _ = translator.predict(
"resampled.wav", "s2tt", target_lang
)
text.append(str(translated_text[0]))
# os.remove(new_audio_name)
# os.remove("resampled.wav")
obj = {
"start": s,
"end": e,
"text": str(translated_text[0]),
}
print(obj)
time.sleep(0.5)

# yield json.dumps({"data": json.dumps(obj)})
yield json.dumps(obj)

response = StreamingResponse(generate(), media_type="application/json")
return response

# chunks = []
# for i in range(len(text)):
# chunks.append(
# {
# "start": timestamps_start[i],
# "end": timestamps_end[i],
# "text": text[i],
# }
# )

# full_text = " ".join([x["text"] for x in chunks])
# return {
# "code": 200,
# "message": "Speech generated successfully.",
# "chunks": chunks,
# "text": full_text,
# }

except Exception as e:
print(e)
logging.critical(e, exc_info=True)
return {"message": "Internal server error", "code": 500}
38 changes: 19 additions & 19 deletions ui/src/pages/generate.js
@@ -64,24 +64,24 @@ export default function dashboard() {
async function handleSubmit() {
reset(true);
const response = await handleTranscribe(uploadedFile, targetLanguage);
if (response) {
reset(false);
if (response.data.code !== 200) {
console.log(response);
toast.error(response.data.message);
} else {
setTranscribed(response.data.chunks);
const file = {
filename: uploadedFile.path,
size: uploadedFile.size,
transcribedData: response.data.chunks,
uploadDate: new Date(),
sourceLanguage: sourceLanguage,
targetLanguage: targetLanguage,
};
storeFileToLocalStorage(file);
}
}
// if (response) {
// reset(false);
// if (response.data.code !== 200) {
// console.log(response);
// toast.error(response.data.message);
// } else {
// setTranscribed(response.data.chunks);
// const file = {
// filename: uploadedFile.path,
// size: uploadedFile.size,
// transcribedData: response.data.chunks,
// uploadDate: new Date(),
// sourceLanguage: sourceLanguage,
// targetLanguage: targetLanguage,
// };
// storeFileToLocalStorage(file);
// }
// }
}

return (
@@ -116,7 +116,7 @@ export default function dashboard() {
/>
</div>
<button
disabled={disabled}
// disabled={disabled}
onClick={handleSubmit}
className={` ${
disabled
104 changes: 95 additions & 9 deletions ui/src/utils.js
@@ -36,18 +36,104 @@ export const handleTranscribe = async (file, targetLang) => {
wav_base64: base64Data,
target: targetLang,
};
let finalData = [];

try {
const response = await axios.post(
"https://kurianbenoy--seamless-m4t-speech-generate-seamlessm4t-speech.modal.run/",
requestData
);
return response;
} catch (error) {
return error;
}
fetch("https://aldrinjenson--vllm-mixtral.modal.run", {
method: "POST",
body: JSON.stringify(requestData),
headers: {
"Content-Type": "application/json",
},
})
.then(async (res) => {
console.log(res);
console.log(res.body);
const decoder = new TextDecoder();
const reader = res.body.getReader();

while (true) {
const { done, value } = await reader.read();
console.log(done, value);
if (done) break;
const chunk = decoder.decode(value, { stream: true });
console.log(chunk);
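          // NOTE: chunks arrive as concatenated JSON objects with no delimiter
          // (the server yields json.dumps(obj) back to back), so a robust client
          // would buffer them and parse incrementally, or the server could emit
          // newline-delimited JSON to allow simple line-by-line parsing.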

// const decodedValue = new TextDecoder().decode(value);
// console.log(decodedValue);
// const jsonData = JSON.parse(decodedValue);
// console.log(jsonData);
// finalData.push(jsonData);
}
})
.catch((err) => console.log("error: ", err));
// return finalData;
};

// export const handleTranscribe = async (file, targetLang) => {
// const base64Data = await fileToBase64(file);

// const requestData = {
// wav_base64: base64Data,
// target: targetLang,
// };
// console.log(requestData);

// try {
// const response = await axios.post(
// "https://aldrinjenson--vllm-mixtral.modal.run",
// requestData,
// { responseType: "stream" } // Set responseType to 'stream' to enable streaming response
// );

// // Create an array to hold the streaming JSON data
// const streamingData = [];

// // Create a new TextDecoder to decode the response stream
// const decoder = new TextDecoder("utf-8");

// // Read the response stream in chunks
// for await (const chunk of response.data) {
// // Decode the chunk and append it to the streamingData array
// console.log(chunk);
// const decodedChunk = decoder.decode(chunk, { stream: true });
// console.log(decodedChunk);
// streamingData.push(decodedChunk);
// }

// // Join the decoded chunks to form a single string
// const jsonDataString = streamingData.join("");

// // Parse the JSON array
// const jsonData = JSON.parse(jsonDataString);

// console.log(jsonData);
// return jsonData;
// } catch (error) {
// console.error("Error:", error);
// throw error; // Rethrow the error to handle it outside of this function
// }
// };

// export const handleTranscribe = async (file, targetLang) => {
// const base64Data = await fileToBase64(file);

// const requestData = {
// wav_base64: base64Data,
// target: targetLang,
// };

// try {
// const response = await axios.post(
// "https://aldrinjenson--vllm-mixtral.modal.run",
// requestData
// );
// console.log(response);
// return response;
// } catch (error) {
// return error;
// }
// };

export function formatTime(time) {
const hours = Math.floor(time / 3600);
const minutes = Math.floor((time % 3600) / 60);