-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
95 lines (70 loc) · 2.68 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import fire
from dotenv import load_dotenv
from utils.launch import Launcher, LaunchConfig
from utils.pipeline import Pipeline, PairwisePipeline
# Load variables from a .env file into the process environment (python-dotenv).
# This runs before Launcher() and before any model wrapper is imported.
# NOTE(review): presumably this is how API keys reach the model wrappers when the
# *_api_key arguments are left at "" — confirm.
load_dotenv()
# One module-level Launcher is shared by every run_* entry point and by main().
# Those functions use it for device selection (get_device) and to start the
# Gradio UI (launch_gradio).
launcher = Launcher()
def run_chatgpt(openai_api_key: str = "", **kwargs):
    """Launch a Gradio app that serves only the ChatGPT model.

    Args:
        openai_api_key: Passed to ``ChatGPT(api_key=...)``. Defaults to ``""``.
            NOTE(review): presumably ChatGPT then falls back to a key from the
            environment populated by ``load_dotenv()`` — confirm.
        **kwargs: Forwarded to ``LaunchConfig``. A caller-supplied ``title``
            overrides the default ``"MAIA (ChatGPT Only)"``.
    """
    # Imported inside the function, so this backend's dependencies are only
    # needed when it is actually launched.
    from models.chatgpt.core import ChatGPT

    chatgpt = ChatGPT(api_key=openai_api_key)
    # Merge the default title under the caller's kwargs. Passing title= next to
    # **kwargs raises TypeError if the caller also supplies a title.
    config = LaunchConfig(**{"title": "MAIA (ChatGPT Only)", **kwargs})
    launcher.launch_gradio(chatgpt, config)
def run_whisperx(**kwargs):
    """Launch a Gradio app that serves only the WhisperX model.

    WhisperX is created on the device chosen by ``launcher.get_device()``, with
    ``device_index=0``, ``compute_type="float16"`` and ``batch_size=16``.

    Args:
        **kwargs: Forwarded to ``LaunchConfig``. A caller-supplied ``title``
            overrides the default ``"MAIA (WhisperX Only)"``.
    """
    # Imported inside the function, so WhisperX's dependencies are only needed
    # when this entry point is used.
    from models.whisperx.core import WhisperX

    whisper = WhisperX(
        device=launcher.get_device(),
        device_index=0,
        compute_type="float16",
        batch_size=16,
    )
    # Merge the default title under the caller's kwargs. Passing title= next to
    # **kwargs raises TypeError if the caller also supplies a title.
    config = LaunchConfig(**{"title": "MAIA (WhisperX Only)", **kwargs})
    launcher.launch_gradio(whisper, config)
def run_alpaca(**kwargs):
    """Launch a Gradio app that serves only the Alpaca model.

    Alpaca is built from the ``decapoda-research/llama-7b-hf`` base model plus
    the ``tloen/alpaca-lora-7b`` LoRA weights, with ``load_8bit=True``, on the
    device chosen by ``launcher.get_device()``.

    Args:
        **kwargs: Forwarded to ``LaunchConfig``. A caller-supplied ``title``
            overrides the default ``"MAIA (Alpaca Only)"``.
    """
    # Imported inside the function, so Alpaca's dependencies are only needed
    # when this entry point is used.
    from models.alpaca.core import Alpaca

    alpaca = Alpaca(
        device=launcher.get_device(),
        load_8bit=True,
        # NOTE(review): this Hugging Face Hub repo has reportedly been taken
        # down — confirm it still resolves, or switch to a maintained mirror.
        base_model="decapoda-research/llama-7b-hf",
        lora_weights="tloen/alpaca-lora-7b",
    )
    # Merge the default title under the caller's kwargs. Passing title= next to
    # **kwargs raises TypeError if the caller also supplies a title.
    config = LaunchConfig(**{"title": "MAIA (Alpaca Only)", **kwargs})
    launcher.launch_gradio(alpaca, config)
def run_bard(bard_api_key: str = "", **kwargs):
    """Launch a Gradio app that serves only the Bard model.

    Args:
        bard_api_key: Passed to ``Bard(api_key=...)``. Defaults to ``""``.
            NOTE(review): presumably Bard then falls back to a key from the
            environment populated by ``load_dotenv()`` — confirm.
        **kwargs: Forwarded to ``LaunchConfig``. A caller-supplied ``title``
            overrides the default ``"MAIA (Bard Only)"``.
    """
    # Imported inside the function, so Bard's dependencies are only needed when
    # this entry point is used.
    from models.bard.core import Bard

    bard = Bard(api_key=bard_api_key)
    # Merge the default title under the caller's kwargs. Passing title= next to
    # **kwargs raises TypeError if the caller also supplies a title.
    config = LaunchConfig(**{"title": "MAIA (Bard Only)", **kwargs})
    launcher.launch_gradio(bard, config)
def run_palm(google_api_key: str = "", **kwargs):
    """Launch a Gradio app that serves only the PaLM model.

    Args:
        google_api_key: Passed to ``PaLM(api_key=...)``. Defaults to ``""``.
            NOTE(review): presumably PaLM then falls back to a key from the
            environment populated by ``load_dotenv()`` — confirm.
        **kwargs: Forwarded to ``LaunchConfig``. A caller-supplied ``title``
            overrides the default ``"MAIA (PaLM Only)"``.
    """
    # Imported inside the function, so PaLM's dependencies are only needed when
    # this entry point is used.
    from models.palm.core import PaLM

    palm = PaLM(api_key=google_api_key)
    # Merge the default title under the caller's kwargs. Passing title= next to
    # **kwargs raises TypeError if the caller also supplies a title.
    config = LaunchConfig(**{"title": "MAIA (PaLM Only)", **kwargs})
    launcher.launch_gradio(palm, config)
def run_googletts(google_tts_api_key: str = "", **kwargs):
    """Launch a Gradio app that serves only the Google TTS model.

    Args:
        google_tts_api_key: Passed to ``GoogleTTS(api_key=...)``. Defaults to
            ``""``. NOTE(review): presumably GoogleTTS then falls back to a key
            from the environment populated by ``load_dotenv()`` — confirm.
        **kwargs: Forwarded to ``LaunchConfig``. A caller-supplied ``title``
            overrides the default ``"MAIA (GoogleTTS Only)"``.
    """
    # Imported inside the function, so GoogleTTS's dependencies are only needed
    # when this entry point is used.
    from models.googletts.core import GoogleTTS

    # Named for what it is. The original local name `papago` was a
    # copy-paste leftover.
    tts = GoogleTTS(api_key=google_tts_api_key)
    # Merge the default title under the caller's kwargs. Passing title= next to
    # **kwargs raises TypeError if the caller also supplies a title.
    config = LaunchConfig(**{"title": "MAIA (GoogleTTS Only)", **kwargs})
    launcher.launch_gradio(tts, config)
def main(**kwargs):
    """Launch the full MAIA Gradio app (the script's command-line entry point).

    Builds a ``PairwisePipeline`` from three parts:

    * a WhisperX transcriber (``transcribe_model``);
    * a ``BasePrompter`` over ChatGPT (``generate_model_1``);
    * an ``AugmentedPrompter`` over ChatGPT (``generate_model_2``).

    The pipeline is launched with the ``ConversationForm`` UI.

    Args:
        **kwargs: Forwarded unchanged to ``LaunchConfig``. When the file is run
            as a script, they come from the command-line flags parsed by
            ``fire``.
    """
    # Imported inside the function, so only the backends this pipeline uses
    # have to be installed. The previously imported but unused PaLM backend is
    # no longer pulled in.
    from models.whisperx.core import WhisperX
    from models.chatgpt.core import ChatGPT
    from conversation.prompter import BasePrompter, AugmentedPrompter
    from conversation.form import ConversationForm

    whisper = WhisperX(
        device=launcher.get_device(),
        device_index=0,
        # NOTE(review): run_whisperx uses "float16"; confirm that float32 here
        # is intentional (e.g. for devices without fp16 support).
        compute_type="float32",
        batch_size=16,
    )

    # Both prompters wrap the same model class, presumably so the pairwise
    # comparison differs only in prompting strategy — confirm.
    model_class = ChatGPT
    base_prompter = BasePrompter(model_class)
    augmented_prompter = AugmentedPrompter(model_class)

    pipeline = PairwisePipeline(
        transcribe_model=whisper,
        generate_model_1=base_prompter,
        generate_model_2=augmented_prompter,
    )
    config = LaunchConfig(**kwargs)
    launcher.launch_gradio(pipeline, config, ConversationForm)
if __name__ == "__main__":
    # Only main() is handed to Fire, so command-line flags land in main's
    # **kwargs. The run_* single-model launchers are not reachable from the CLI.
    fire.Fire(component=main)