forked from keithito/tacotron
-
Notifications
You must be signed in to change notification settings - Fork 103
/
demo_server.py
124 lines (110 loc) · 4.02 KB
/
demo_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from flask import Flask, request, send_file
from flask.views import MethodView
from hparams import hparams, hparams_debug_string
import argparse
import os
from util import audio
from synthesizer import Synthesizer
from flask_cors import CORS
import io
import numpy as np
import math
from synthesize_helper import synthesize_helper, replace_acronym, custom_splitter
app = Flask(__name__)
# Enable cross-origin requests on all routes so other origins can call /synthesize.
CORS(app)

# Whether Mimic2.get routes text through synthesize_helper() instead of calling
# synthesizer.synthesize() directly. Overwritten from args.synthezier_helper in
# the __main__ block; read on every request.
use_synthesize_helper = False

# Single-page demo UI served at "/". It GETs /synthesize?text=... and plays the
# returned WAV blob in an <audio> element.
html_body = '''<html><title>Demo</title>
<style>
body {padding: 16px; font-family: sans-serif; font-size: 14px; color: #444}
input {font-size: 14px; padding: 8px 12px; outline: none; border: 1px solid #ddd}
input:focus {box-shadow: 0 1px 2px rgba(0,0,0,.15)}
p {padding: 12px}
button {background: #28d; padding: 9px 14px; margin-left: 8px; border: none; outline: none;
        color: #fff; font-size: 14px; border-radius: 4px; cursor: pointer;}
button:hover {box-shadow: 0 1px 2px rgba(0,0,0,.15); opacity: 0.9;}
button:active {background: #29f;}
button[disabled] {opacity: 0.4; cursor: default}
</style>
<body>
<form>
  <input id="text" type="text" size="40" placeholder="Enter Text">
  <button id="button" name="synthesize">Speak</button>
</form>
<p id="message"></p>
<audio id="audio" controls autoplay hidden></audio>
<script>
function q(selector) {return document.querySelector(selector)}
q('#text').focus()
q('#button').addEventListener('click', function(e) {
  var text = q('#text').value.trim()
  if (text) {
    q('#message').textContent = 'Synthesizing...'
    q('#button').disabled = true
    q('#audio').hidden = true
    synthesize(text)
  }
  e.preventDefault()
  return false
})
function synthesize(text) {
  fetch('/synthesize?text=' + encodeURIComponent(text), {cache: 'no-cache'})
    .then(function(res) {
      if (!res.ok) throw Error(res.status + ' ' + res.statusText)
      return res.blob()
    }).then(function(blob) {
      q('#message').textContent = ''
      q('#button').disabled = false
      q('#audio').src = URL.createObjectURL(blob)
      q('#audio').hidden = false
    }).catch(function(err) {
      q('#message').textContent = 'Error: ' + err.message
      q('#button').disabled = false
    })
}
</script></body></html>
'''

# Shared model wrapper used by every request. Its checkpoint is loaded via
# synthesizer.load() in the __main__ block before app.run().
synthesizer = Synthesizer()
class Mimic2(MethodView):
    """GET /synthesize?text=...: synthesize speech and return it as audio/wav."""

    def get(self):
        """Normalize the ``text`` query parameter, synthesize it and send the WAV.

        Returns:
            A ``send_file`` response with mimetype ``audio/wav`` on success.
            A ``(message, 400)`` tuple when ``text`` is missing or blank, or
            normalizes to nothing. Previously the handler returned ``None``
            here, which Flask turns into a 500 error.
        """
        raw_text = request.args.get('text', '')
        # Validate before normalizing: custom_splitter must never see None.
        if not raw_text.strip():
            return 'Missing required query parameter "text"', 400

        text = " ".join(replace_acronym(custom_splitter(raw_text)))
        if not text:
            return 'Nothing to synthesize after text normalization', 400

        if use_synthesize_helper:
            wav = synthesize_helper(text, synthesizer)
        else:
            wav, _ = synthesizer.synthesize(text)

        # Named `wav_buffer` rather than `audio` to avoid shadowing the
        # module-level `util.audio` import.
        wav_buffer = io.BytesIO(wav)
        return send_file(wav_buffer, mimetype="audio/wav")
class UI(MethodView):
    """GET /: serve the static HTML demo page held in ``html_body``."""

    def get(self):
        """Return the demo page markup unchanged."""
        page = html_body
        return page
# Turn each MethodView into a view function; the endpoint names are kept as
# module-level attributes.
ui_view = UI.as_view('ui_view')
mimic2_api = Mimic2.as_view('mimic2_api')

# Both routes are read-only: the page at "/" and the WAV endpoint at "/synthesize".
app.add_url_rule('/', view_func=ui_view, methods=['GET'])
app.add_url_rule('/synthesize', view_func=mimic2_api, methods=['GET'])
if __name__ == '__main__':
    # Command-line entry point: parse options, configure the environment, load
    # the checkpoint into the shared synthesizer, then start the Flask server.
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', required=True,
                        help='Full path to model checkpoint')
    parser.add_argument('--port', type=int, default=3000)
    parser.add_argument('--ip', type=str, default='0.0.0.0')
    parser.add_argument('--hparams', default='',
                        help='Hyperparameter overrides as a comma-separated list of name=value pairs')
    parser.add_argument(
        '--gpu_assignment', default='0',
        help='Set the gpu the model should run on')
    # store_true: the helper is off by default, and passing the flag turns it on.
    # With store_false and default=False, the value was always False.
    # The original (misspelled) option name and its dest are kept so existing
    # scripts keep working; the correctly spelled alias maps to the same dest.
    parser.add_argument(
        '--synthezier_helper', '--synthesize_helper',
        dest='synthezier_helper', action='store_true',
        help='uses the synthesize helper during synthesis'
    )
    args = parser.parse_args()

    # This `if` runs at module scope, so this rebinds the global that
    # Mimic2.get reads on each request.
    use_synthesize_helper = args.synthezier_helper
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_assignment
    # NOTE(review): TF may read this at import time, and synthesizer was already
    # imported above. Confirm that it still takes effect here.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    hparams.parse(args.hparams)
    print(hparams_debug_string())
    synthesizer.load(args.checkpoint)
    app.run(host=args.ip, port=args.port)