diff --git a/chat_in_browser.py b/chat_in_browser.py new file mode 100644 index 000000000..94c660b9e --- /dev/null +++ b/chat_in_browser.py @@ -0,0 +1,52 @@ +# -*- coding: UTF-8 -*- +from flask import Flask, render_template, request +from cli import add_arguments_for_generate, arg_init, check_args +from generate import main as generate_main +import subprocess +import sys + + +convo = "" + +def create_app(*args): + app = Flask(__name__) + + import subprocess + # create a new process and set up pipes for communication + proc = subprocess.Popen(["python", "generate.py", *args], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE) + + @app.route('/') + def main(): + output = "" + while True: + line = proc.stdout.readline() + if line.decode('utf-8').startswith("What is your prompt?"): + break + output += line.decode('utf-8').strip() + "\n" + return render_template('chat.html', convo="Hello! What is your prompt?") + + @app.route('/chat', methods=['POST']) + def chat(): + # Retrieve the HTTP POST request parameter value from 'request.form' dictionary + _prompt = request.form.get('prompt') + proc.stdin.write((_prompt + "\n").encode('utf-8')) + proc.stdin.flush() + + output = "" + while True: + line = proc.stdout.readline() + if line.decode('utf-8').startswith("What is your prompt?"): + break + output += line.decode('utf-8').strip() + "\n" + + global convo + + if _prompt: + convo += "Your prompt:\n" + _prompt + "\n\n" + convo += "My response:\n" + output + "\n\n" + + return render_template('chat.html', convo=convo) + + return app diff --git a/cli.py b/cli.py index ef43e7c77..0b0c3b46a 100644 --- a/cli.py +++ b/cli.py @@ -54,6 +54,9 @@ def add_arguments_for_export(parser): # Only export specific options should be here _add_arguments_common(parser) +def add_arguments_for_browser(parser): + # Only export specific options should be here + _add_arguments_common(parser) def _add_arguments_common(parser): # TODO: Refactor this so that only common options are here diff --git a/generate.py b/generate.py index 65dd60910..170ec7fbd 100644 --- a/generate.py +++ b/generate.py @@ -452,7 +452,7 @@ def _main( for i in range(start, generator_args.num_samples): device_sync(device=builder_args.device) if i >= 0 and generator_args.chat_mode: - prompt = input("What is your prompt? ") + prompt = input("What is your prompt? \n") if builder_args.is_chat_model: prompt = f"{B_INST} {prompt.strip()} {E_INST}" encoded = encode_tokens( diff --git a/requirements.txt b/requirements.txt index 19e028bcd..e9eb5b7c6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -27,3 +27,6 @@ wheel cmake ninja zstd + +# Browser mode +flask diff --git a/templates/chat.html b/templates/chat.html new file mode 100644 index 000000000..55590919e --- /dev/null +++ b/templates/chat.html @@ -0,0 +1,15 @@ + + + + + torchchat + + +
{{ convo }}
+
+ +
+ +
+ + diff --git a/torchchat.py b/torchchat.py index 9008b9032..86c6d4f4f 100644 --- a/torchchat.py +++ b/torchchat.py @@ -6,11 +6,14 @@ import argparse import logging +import subprocess +import sys from cli import ( add_arguments_for_eval, add_arguments_for_export, add_arguments_for_generate, + add_arguments_for_browser, arg_init, check_args, ) @@ -34,6 +37,9 @@ parser_export = subparsers.add_parser("export") add_arguments_for_export(parser_export) + parser_browser = subparsers.add_parser("browser") + add_arguments_for_browser(parser_browser) + args = parser.parse_args() args = arg_init(args) logging.basicConfig( @@ -54,5 +60,13 @@ from export import main as export_main export_main(args) + elif args.subcommand == "browser": + # TODO: add check_args() + + # Assume the user wants "chat" when entering "browser". TODO: add support for "generate" as well + args_plus_chat = ['"{}"'.format(s) for s in sys.argv[2:]] + ["\"--chat\""] + formatted_args = ", ".join(args_plus_chat) + command = ["flask", "--app", "chat_in_browser:create_app(" + formatted_args + ")", "run"] + subprocess.run(command) else: raise RuntimeError("Must specify valid subcommands: generate, export, eval")