Skip to content

Commit

Permalink
chat in browser (#242)
Browse files Browse the repository at this point in the history
* chat in browser

* remove jinja2 comment seems irrelavant

* remove jinja2 comment seems irrelavant

* remove debug prints

* use torchchat as entry point
  • Loading branch information
Olivia-liu authored and malfet committed Jul 17, 2024
1 parent d174616 commit 391a846
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 1 deletion.
52 changes: 52 additions & 0 deletions chat_in_browser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# -*- coding: UTF-8 -*-
from flask import Flask, render_template, request
from cli import add_arguments_for_generate, arg_init, check_args
from generate import main as generate_main
import subprocess
import sys


convo = ""

def create_app(*args):
app = Flask(__name__)

import subprocess
# create a new process and set up pipes for communication
proc = subprocess.Popen(["python", "generate.py", *args],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE)

@app.route('/')
def main():
output = ""
while True:
line = proc.stdout.readline()
if line.decode('utf-8').startswith("What is your prompt?"):
break
output += line.decode('utf-8').strip() + "\n"
return render_template('chat.html', convo="Hello! What is your prompt?")

@app.route('/chat', methods=['POST'])
def chat():
# Retrieve the HTTP POST request parameter value from 'request.form' dictionary
_prompt = request.form.get('prompt')
proc.stdin.write((_prompt + "\n").encode('utf-8'))
proc.stdin.flush()

output = ""
while True:
line = proc.stdout.readline()
if line.decode('utf-8').startswith("What is your prompt?"):
break
output += line.decode('utf-8').strip() + "\n"

global convo

if _prompt:
convo += "Your prompt:\n" + _prompt + "\n\n"
convo += "My response:\n" + output + "\n\n"

return render_template('chat.html', convo=convo)

return app
3 changes: 3 additions & 0 deletions cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ def add_arguments_for_export(parser):
# Only export specific options should be here
_add_arguments_common(parser)

def add_arguments_for_browser(parser):
# Only export specific options should be here
_add_arguments_common(parser)

def _add_arguments_common(parser):
# TODO: Refactor this so that only common options are here
Expand Down
2 changes: 1 addition & 1 deletion generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ def _main(
for i in range(start, generator_args.num_samples):
device_sync(device=builder_args.device)
if i >= 0 and generator_args.chat_mode:
prompt = input("What is your prompt? ")
prompt = input("What is your prompt? \n")
if builder_args.is_chat_model:
prompt = f"{B_INST} {prompt.strip()} {E_INST}"
encoded = encode_tokens(
Expand Down
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,6 @@ wheel
cmake
ninja
zstd

# Browser mode
flask
15 changes: 15 additions & 0 deletions templates/chat.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<title>torchchat</title>
</head>
<body>
<pre>{{ convo }}</pre>
<form action="chat" method="post">
<label for="username">Prompt: </label>
<input type="text" id="prompt" name="prompt"><br>
<input type="submit" value="SEND">
</form>
</body>
</html>
14 changes: 14 additions & 0 deletions torchchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@

import argparse
import logging
import subprocess
import sys

from cli import (
add_arguments_for_eval,
add_arguments_for_export,
add_arguments_for_generate,
add_arguments_for_browser,
arg_init,
check_args,
)
Expand All @@ -34,6 +37,9 @@
parser_export = subparsers.add_parser("export")
add_arguments_for_export(parser_export)

parser_browser = subparsers.add_parser("browser")
add_arguments_for_browser(parser_browser)

args = parser.parse_args()
args = arg_init(args)
logging.basicConfig(
Expand All @@ -54,5 +60,13 @@
from export import main as export_main

export_main(args)
elif args.subcommand == "browser":
# TODO: add check_args()

# Assume the user wants "chat" when entering "browser". TODO: add support for "generate" as well
args_plus_chat = ['"{}"'.format(s) for s in sys.argv[2:]] + ["\"--chat\""]
formatted_args = ", ".join(args_plus_chat)
command = ["flask", "--app", "chat_in_browser:create_app(" + formatted_args + ")", "run"]
subprocess.run(command)
else:
raise RuntimeError("Must specify valid subcommands: generate, export, eval")

0 comments on commit 391a846

Please sign in to comment.