Skip to content

Commit

Permalink
feat: add cluster name stop start and reset functions to skypilot
Browse files Browse the repository at this point in the history
  • Loading branch information
ZackBradshaw committed Dec 10, 2023
1 parent 3f011ec commit 7442144
Showing 1 changed file with 151 additions and 135 deletions.
286 changes: 151 additions & 135 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sky
import boto3
from transformers import AutoTokenizer
from botocore.exceptions import NoCredentialsError
Expand Down Expand Up @@ -178,7 +179,15 @@ def load_tools():
print(f"all_tools_list: {all_tools_list}") # Debugging line
return gr.update(choices=all_tools_list)

def set_environ(OPENAI_API_KEY: str = "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",
def start_sky_pilot(cluster_name: str):
sky.start(cluster_name)

def stop_sky_pilot(cluster_name: str):
sky.stop(cluster_name)

def status_sky_pilot(cluster_name: str):
return sky.status(cluster_name)
def set_environ(OPENAI_API_KEY: str = "",
WOLFRAMALPH_APP_ID: str = "",
WEATHER_API_KEYS: str = "",
BING_SUBSCRIPT_KEY: str = "",
Expand Down Expand Up @@ -331,26 +340,25 @@ def fetch_tokenizer(model_name):
return f"Error loading tokenizer: {str(e)}"

# Add this function to handle the button click
import sky

def deploy_on_sky_pilot(model_name: str, tokenizer: str, accelerators: str):
# Create serving.yaml for SkyPilot deployment
serving_yaml = {
"resources": {
"accelerators": accelerators
},
"envs": {
# Create a SkyPilot Task
#TODO have ai generate a yaml file for the configuration the user desires add this as a tool
task = sky.Task(
setup="conda create -n vllm python=3.9 -y\nconda activate vllm\ngit clone https://github.com/vllm-project/vllm.git\ncd vllm\npip install .\npip install gradio",
run="conda activate vllm\necho 'Starting vllm api server...'\npython -u -m vllm.entrypoints.api_server --model $MODEL_NAME --tensor-parallel-size $SKYPILOT_NUM_GPUS_PER_NODE --tokenizer $TOKENIZER 2>&1 | tee api_server.log &\necho 'Waiting for vllm api server to start...'\nwhile ! `cat api_server.log | grep -q 'Uvicorn running on'`; do sleep 1; done\necho 'Starting gradio server...'\npython vllm/examples/gradio_webserver.py",
envs={
"MODEL_NAME": model_name,
"TOKENIZER": AutoTokenizer.from_pretrained(model_name)
},
"setup": "conda create -n vllm python=3.9 -y\nconda activate vllm\ngit clone https://github.com/vllm-project/vllm.git\ncd vllm\npip install .\npip install gradio",
"run": "conda activate vllm\necho 'Starting vllm api server...'\npython -u -m vllm.entrypoints.api_server --model $MODEL_NAME --tensor-parallel-size $SKYPILOT_NUM_GPUS_PER_NODE --tokenizer $TOKENIZER 2>&1 | tee api_server.log &\necho 'Waiting for vllm api server to start...'\nwhile ! `cat api_server.log | grep -q 'Uvicorn running on'`; do sleep 1; done\necho 'Starting gradio server...'\npython vllm/examples/gradio_webserver.py"
}

# Write serving.yaml to file
with open('serving.yaml', 'w') as f:
yaml.dump(serving_yaml, f)
resources={
"accelerators": accelerators
}
)

# Deploy on SkyPilot
os.system("sky launch serving.yaml")
# Launch the task on SkyPilot
sky.launch(task,cluster_name=cluster_name)

# Add this line where you define your Gradio interface

Expand All @@ -364,128 +372,136 @@ def deploy_on_sky_pilot(model_name: str, tokenizer: str, accelerators: str):

# with gr.Blocks(css=css, analytics_enabled=False, title=title, theme=ui.theme) as demo:
with gr.Blocks() as demo:
with gr.Row():
with gr.Column(scale=14):
gr.Markdown("")
with gr.Column(scale=1):
gr.Image(show_label=False, show_download_button=False, value="images/swarmslogobanner.png")

with gr.Tab("Key setting"):
OPENAI_API_KEY = gr.Textbox(label="OpenAI API KEY:", placeholder="sk-...", type="text")
WOLFRAMALPH_APP_ID = gr.Textbox(label="Wolframalpha app id:", placeholder="Key to use wlframalpha", type="text")
WEATHER_API_KEYS = gr.Textbox(label="Weather api key:", placeholder="Key to use weather api", type="text")
BING_SUBSCRIPT_KEY = gr.Textbox(label="Bing subscript key:", placeholder="Key to use bing search", type="text")
ALPHA_VANTAGE_KEY = gr.Textbox(label="Stock api key:", placeholder="Key to use stock api", type="text")
BING_MAP_KEY = gr.Textbox(label="Bing map key:", placeholder="Key to use bing map", type="text")
BAIDU_TRANSLATE_KEY = gr.Textbox(label="Baidu translation key:", placeholder="Key to use baidu translation", type="text")
RAPIDAPI_KEY = gr.Textbox(label="Rapidapi key:", placeholder="Key to use zillow, airbnb and job search", type="text")
SERPER_API_KEY = gr.Textbox(label="Serper key:", placeholder="Key to use google serper and google scholar", type="text")
GPLACES_API_KEY = gr.Textbox(label="Google places key:", placeholder="Key to use google places", type="text")
SCENEX_API_KEY = gr.Textbox(label="Scenex api key:", placeholder="Key to use sceneXplain", type="text")
STEAMSHIP_API_KEY = gr.Textbox(label="Steamship api key:", placeholder="Key to use image generation", type="text")
HUGGINGFACE_API_KEY = gr.Textbox(label="Huggingface api key:", placeholder="Key to use models in huggingface hub", type="text")
HUGGINGFACE_TOKEN = gr.Textbox(label="HuggingFace Token:", placeholder="Token for huggingface", type="text"),
AMADEUS_ID = gr.Textbox(label="Amadeus id:", placeholder="Id to use Amadeus", type="text")
AMADEUS_KEY = gr.Textbox(label="Amadeus key:", placeholder="Key to use Amadeus", type="text")
AWS_ACCESS_KEY_ID = gr.Textbox(label="AWS Access Key ID:", placeholder="AWS Access Key ID", type="text")
AWS_SECRET_ACCESS_KEY = gr.Textbox(label="AWS Secret Access Key:", placeholder="AWS Secret Access Key", type="text")
AWS_DEFAULT_REGION = gr.Textbox(label="AWS Default Region:", placeholder="AWS Default Region", type="text")
key_set_btn = gr.Button(value="Set keys!")


with gr.Tab("Chat with Tool"):
with gr.Row():
with gr.Column(scale=4):
with gr.Row():
with gr.Column(scale=0.85):
txt = gr.Textbox(show_label=False, placeholder="Question here. Use Shift+Enter to add new line.",
lines=1).style(container=False)
with gr.Column(scale=0.15, min_width=0):
buttonChat = gr.Button("Chat")

memory_utilization = gr.Slider(label="Memory Utilization:", min=0, max=1, step=0.1, default=0.5)

chatbot = gr.Chatbot(show_label=False, visible=True).style(height=600)
buttonClear = gr.Button("Clear History")
buttonStop = gr.Button("Stop", visible=False)

with gr.Column(scale=4):
with gr.Row():
with gr.Column(scale=1):
model_url = gr.Textbox(label="VLLM Model URL:", placeholder="URL to download VLLM model from Hugging Face", type="text");
buttonDownload = gr.Button("Download Model");
buttonDownload.click(fn=download_model, inputs=[model_url, memory_utilization]);
model_chosen = gr.Dropdown(
list(available_models),
value=DEFAULTMODEL,
multiselect=False,
label="Model provided",
info="Choose the model to solve your question, Default means ChatGPT."
with gr.Column(scale=14):
gr.Markdown("")
with gr.Column(scale=1):
gr.Image(show_label=False, show_download_button=False, value="images/swarmslogobanner.png")

with gr.Tab("Key setting"):
OPENAI_API_KEY = gr.Textbox(label="OpenAI API KEY:", placeholder="sk-...", type="text")
WOLFRAMALPH_APP_ID = gr.Textbox(label="Wolframalpha app id:", placeholder="Key to use wlframalpha", type="text")
WEATHER_API_KEYS = gr.Textbox(label="Weather api key:", placeholder="Key to use weather api", type="text")
BING_SUBSCRIPT_KEY = gr.Textbox(label="Bing subscript key:", placeholder="Key to use bing search", type="text")
ALPHA_VANTAGE_KEY = gr.Textbox(label="Stock api key:", placeholder="Key to use stock api", type="text")
BING_MAP_KEY = gr.Textbox(label="Bing map key:", placeholder="Key to use bing map", type="text")
BAIDU_TRANSLATE_KEY = gr.Textbox(label="Baidu translation key:", placeholder="Key to use baidu translation", type="text")
RAPIDAPI_KEY = gr.Textbox(label="Rapidapi key:", placeholder="Key to use zillow, airbnb and job search", type="text")
SERPER_API_KEY = gr.Textbox(label="Serper key:", placeholder="Key to use google serper and google scholar", type="text")
GPLACES_API_KEY = gr.Textbox(label="Google places key:", placeholder="Key to use google places", type="text")
SCENEX_API_KEY = gr.Textbox(label="Scenex api key:", placeholder="Key to use sceneXplain", type="text")
STEAMSHIP_API_KEY = gr.Textbox(label="Steamship api key:", placeholder="Key to use image generation", type="text")
HUGGINGFACE_API_KEY = gr.Textbox(label="Huggingface api key:", placeholder="Key to use models in huggingface hub", type="text")
AMADEUS_KEY = gr.Textbox(label="Amadeus key:", placeholder="Key to use Amadeus", type="text")
AMADEUS_ID = gr.Textbox(label="Amadeus ID:", placeholder="Amadeus ID",
type="text")
AWS_ACCESS_KEY_ID = gr.Textbox(label="AWS Access Key ID:", placeholder="AWS Access Key ID", type="text")
AWS_SECRET_ACCESS_KEY = gr.Textbox(label="AWS Secret Access Key:", placeholder="AWS Secret Access Key", type="text")
AWS_DEFAULT_REGION = gr.Textbox(label="AWS Default Region:", placeholder="AWS Default Region", type="text")
key_set_btn = gr.Button(value="Set keys!")


with gr.Tab("Chat with Tool"):
with gr.Row():
with gr.Column(scale=4):
with gr.Row():
with gr.Column(scale=0.85):
txt = gr.Textbox(show_label=False, placeholder="Question here. Use Shift+Enter to add new line.",
lines=1).style(container=False)
with gr.Column(scale=0.15, min_width=0):
buttonChat = gr.Button("Chat")

memory_utilization = gr.Slider(label="Memory Utilization:", min=0, max=1, step=0.1, default=0.5)

chatbot = gr.Chatbot(show_label=False, visible=True).style(height=600)
buttonClear = gr.Button("Clear History")
buttonStop = gr.Button("Stop", visible=False)

with gr.Column(scale=4):
with gr.Row():
with gr.Column(scale=1):
model_url = gr.Textbox(label="VLLM Model URL:", placeholder="URL to download VLLM model from Hugging Face", type="text");
buttonDownload = gr.Button("Download Model");
buttonDownload.click(fn=download_model, inputs=[model_url, memory_utilization]);
model_chosen = gr.Dropdown(
list(available_models),
value=DEFAULTMODEL,
multiselect=False,
label="Model provided",
info="Choose the model to solve your question, Default means ChatGPT."
)
tokenizer_output = gr.outputs.Textbox(label="Tokenizer")
cluster_name = gr.outputs.Textbox(label="Cluster")
model_chosen.change(fetch_tokenizer, outputs=tokenizer_output)
available_accelerators = ["A100", "V100", "P100", "K80", "T4", "P4"]
accelerators = gr.Dropdown(available_accelerators, label="Accelerators:")
buttonDeploy = gr.Button("Deploy on SkyPilot")

buttonDeploy.click(deploy_on_sky_pilot, [model_chosen, tokenizer_output, accelerators, HUGGINGFACE_API_KEY])
buttonStart = gr.Button("Start SkyPilot")
buttonStart.click(start_sky_pilot, [cluster_name])

buttonStop = gr.Button("Stop SkyPilot")
buttonStop.click(stop_sky_pilot, [cluster_name])

buttonStatus = gr.Button("Check SkyPilot Status")
buttonStatus.click(status_sky_pilot, [cluster_name])
with gr.Row():
tools_search = gr.Textbox(
lines=1,
label="Tools Search",
placeholder="Please input some text to search tools.",
)
tokenizer_output = gr.outputs.Textbox(label="Tokenizer")
model_chosen.change(fetch_tokenizer, outputs=tokenizer_output)
available_accelerators = ["A100", "V100", "P100", "K80", "T4", "P4"]
accelerators = gr.Dropdown(available_accelerators, label="Accelerators:")
buttonDeploy = gr.Button("Deploy on SkyPilot")

buttonDeploy.click(deploy_on_sky_pilot, [model_chosen, tokenizer_output, accelerators, HUGGINGFACE_TOKEN])
with gr.Row():
tools_search = gr.Textbox(
lines=1,
label="Tools Search",
placeholder="Please input some text to search tools.",
buttonSearch = gr.Button("Reset search condition")
tools_chosen = gr.CheckboxGroup(
choices=all_tools_list,
# value=["chemical-prop"],
label="Tools provided",
info="Choose the tools to solve your question.",
)
buttonSearch = gr.Button("Reset search condition")
tools_chosen = gr.CheckboxGroup(
choices=all_tools_list,
# value=["chemical-prop"],
label="Tools provided",
info="Choose the tools to solve your question.",
)


# TODO finish integrating model flow
# with gr.Tab("model"):
# create_inferance();
# def serve_iframe():
# return f'hi'

# TODO fix webgl galaxy backgroun
# def serve_iframe():
# return "<iframe src='http://localhost:8000/shader.html' width='100%' height='400'></iframe>"

# iface = gr.Interface(fn=serve_iframe, inputs=[], outputs=gr.outputs.HTML())

key_set_btn.click(fn=set_environ, inputs=[
OPENAI_API_KEY,
WOLFRAMALPH_APP_ID,
WEATHER_API_KEYS,
BING_SUBSCRIPT_KEY,
ALPHA_VANTAGE_KEY,
BING_MAP_KEY,
BAIDU_TRANSLATE_KEY,
RAPIDAPI_KEY,
SERPER_API_KEY,
GPLACES_API_KEY,
SCENEX_API_KEY,
STEAMSHIP_API_KEY,
HUGGINGFACE_API_KEY,
HUGGINGFACE_TOKEN,
AMADEUS_ID,
AMADEUS_KEY,
], outputs=key_set_btn)
key_set_btn.click(fn=load_tools, outputs=tools_chosen)

tools_search.change(retrieve, tools_search, tools_chosen)
buttonSearch.click(clear_retrieve, [], [tools_search, tools_chosen])

txt.submit(lambda: [gr.update(value=''), gr.update(visible=False), gr.update(visible=True)], [],
[txt, buttonClear, buttonStop])
inference_event = txt.submit(answer_by_tools, [txt, tools_chosen, model_chosen], [chatbot, buttonClear, buttonStop])
buttonChat.click(answer_by_tools, [txt, tools_chosen, model_chosen], [chatbot, buttonClear, buttonStop])
buttonStop.click(lambda: [gr.update(visible=True), gr.update(visible=False)], [], [buttonClear, buttonStop],
cancels=[inference_event])
buttonClear.click(clear_history, [], chatbot)


# TODO finish integrating model flow
# with gr.Tab("model"):
# create_inferance();
# def serve_iframe():
# return f'hi'

# TODO fix webgl galaxy backgroun
# def serve_iframe():
# return "<iframe src='http://localhost:8000/shader.html' width='100%' height='400'></iframe>"

# iface = gr.Interface(fn=serve_iframe, inputs=[], outputs=gr.outputs.HTML())

key_set_btn.click(fn=set_environ, inputs=[
OPENAI_API_KEY,
WOLFRAMALPH_APP_ID,
WEATHER_API_KEYS,
BING_SUBSCRIPT_KEY,
ALPHA_VANTAGE_KEY,
BING_MAP_KEY,
BAIDU_TRANSLATE_KEY,
RAPIDAPI_KEY,
SERPER_API_KEY,
GPLACES_API_KEY,
SCENEX_API_KEY,
STEAMSHIP_API_KEY,
HUGGINGFACE_API_KEY,
AMADEUS_ID,
AMADEUS_KEY,
], outputs=key_set_btn)
key_set_btn.click(fn=load_tools, outputs=tools_chosen)

tools_search.change(retrieve, tools_search, tools_chosen)
buttonSearch.click(clear_retrieve, [], [tools_search, tools_chosen])

txt.submit(lambda: [gr.update(value=''), gr.update(visible=False), gr.update(visible=True)], [],
[txt, buttonClear, buttonStop])
inference_event = txt.submit(answer_by_tools, [txt, tools_chosen, model_chosen], [chatbot, buttonClear, buttonStop])
buttonChat.click(answer_by_tools, [txt, tools_chosen, model_chosen], [chatbot, buttonClear, buttonStop])
buttonStop.click(lambda: [gr.update(visible=True), gr.update(visible=False)], [], [buttonClear, buttonStop],
cancels=[inference_event])
buttonClear.click(clear_history, [], chatbot)

# demo.queue().launch(share=False, inbrowser=True, server_name="127.0.0.1", server_port=7001)
demo.queue().launch()
Expand Down

0 comments on commit 7442144

Please sign in to comment.