From b522b29cd562a462d0c2061136d8b92cea5c28c0 Mon Sep 17 00:00:00 2001
From: Michael Gschwind <61328285+mikekgfb@users.noreply.github.com>
Date: Wed, 17 Apr 2024 14:44:16 -0700
Subject: [PATCH] chat mode improvements (#244)

* chat mode improvements

* disable int4 on macos/x86 because of old nightlies

* typo

* typo

* typo

* convert runtime error to warning

* wording of option texts
---
 .github/workflows/eager-dtype.yml |  9 ++++++---
 build/builder.py                  | 22 +++++++++++++++++++++-
 cli.py                            |  7 ++++++-
 generate.py                       | 18 ++++++++++++------
 4 files changed, 45 insertions(+), 11 deletions(-)

diff --git a/.github/workflows/eager-dtype.yml b/.github/workflows/eager-dtype.yml
index d73832dde..db430d9db 100644
--- a/.github/workflows/eager-dtype.yml
+++ b/.github/workflows/eager-dtype.yml
@@ -77,9 +77,12 @@ jobs:
           echo "******************************************"
           echo "******** INT4 group-wise quantized *******"
           echo "******************************************"
-
-          python generate.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
-          cat ./output_eager
+
+          echo "INT4 should work on MacOS on x86, but cannot be tested"
+          echo "because nightlies are too old!"
+
+          # python generate.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          # cat ./output_eager

           echo "tests complete for ${DTYPE}"
         done
diff --git a/build/builder.py b/build/builder.py
index 674a58afc..ac50c6c16 100644
--- a/build/builder.py
+++ b/build/builder.py
@@ -35,7 +35,8 @@ class BuilderArgs:
     precision: torch.dtype = torch.float32
     setup_caches: bool = False
     use_tp: bool = False
-
+    is_chat_model: bool = False
+
     def __post_init__(self):
         if not (
             (self.checkpoint_path and self.checkpoint_path.is_file())
@@ -66,6 +67,24 @@ def __post_init__(self):

     @classmethod
     def from_args(cls, args):  # -> BuilderArgs:
+        is_chat_model = False
+        if args.is_chat_model:
+            is_chat_model = True
+        else:
+            for path in [
+                args.checkpoint_path,
+                args.checkpoint_dir,
+                args.dso_path,
+                args.pte_path,
+                args.gguf_path
+            ]:
+                path = str(path)
+                if path.endswith('/'):
+                    path = path[:-1]
+                path_basename = os.path.basename(path)
+                if "chat" in path_basename:
+                    is_chat_model = True
+
         return cls(
             checkpoint_path=args.checkpoint_path,
             checkpoint_dir=args.checkpoint_dir,
@@ -78,6 +97,7 @@ def from_args(cls, args):  # -> BuilderArgs:
             precision=name_to_dtype(args.dtype),
             setup_caches=(args.output_dso_path or args.output_pte_path),
             use_tp=False,
+            is_chat_model=is_chat_model,
         )

     @classmethod
diff --git a/cli.py b/cli.py
index d34b94e38..ef43e7c77 100644
--- a/cli.py
+++ b/cli.py
@@ -76,7 +76,12 @@ def _add_arguments_common(parser):
     parser.add_argument(
         "--chat",
         action="store_true",
-        help="Use torchchat to for an interactive chat session.",
+        help="Use torchchat for an interactive chat session.",
+    )
+    parser.add_argument(
+        "--is-chat-model",
+        action="store_true",
+        help="Indicate that the model was trained to support chat functionality.",
     )
     parser.add_argument(
         "--gui",
diff --git a/generate.py b/generate.py
index 2dfdcac45..69f94fa86 100644
--- a/generate.py
+++ b/generate.py
@@ -27,6 +27,7 @@
 from cli import add_arguments_for_generate, arg_init, check_args
 from quantize import set_precision

+B_INST, E_INST = "[INST]", "[/INST]"

 @dataclass
 class GeneratorArgs:
@@ -343,11 +344,16 @@ def _main(
     set_precision(builder_args.precision)
     is_speculative = speculative_builder_args.checkpoint_path is not None

-    is_chat = "chat" in str(os.path.basename(builder_args.checkpoint_path))
-    if is_chat:
-        raise RuntimeError(
-            "need to stop filename based kludgery, at a minimum need to look at all pathnames. in particular, this now fails because chat is part of the pathname, yuck!"
-        )
+    if generator_args.chat_mode and not builder_args.is_chat_model:
+        print("""
+*******************************************************
+ This model is not known to support the chat function.
+ We will enable chat mode based on your instructions.
+ If the model is not trained to support chat, it will
+ produce nonsensical or false output.
+*******************************************************
+        """)
+        # raise RuntimeError("You need to use --is-chat-model to indicate model has chat support.")

     tokenizer = _initialize_tokenizer(tokenizer_args)
@@ -410,7 +416,7 @@ def _main(
         device_sync(device=builder_args.device)
         if i >= 0 and generator_args.chat_mode:
             prompt = input("What is your prompt? ")
-            if is_chat:
+            if builder_args.is_chat_model:
                 prompt = f"{B_INST} {prompt.strip()} {E_INST}"
             encoded = encode_tokens(
                 tokenizer, prompt, bos=True, device=builder_args.device
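
Note for reviewers: the core behavioral change is the path heuristic in `BuilderArgs.from_args` above, which replaces the old checkpoint-filename-only check that the deleted RuntimeError complained about. Below is a minimal standalone sketch of that logic for reference; the function name `guess_is_chat_model` and the sample paths are illustrative, not part of the patch:

```python
import os
from typing import Optional


def guess_is_chat_model(*paths: Optional[str], explicit_flag: bool = False) -> bool:
    """Return True if the model should be treated as chat-tuned.

    Mirrors the heuristic this patch adds to BuilderArgs.from_args():
    an explicit --is-chat-model flag wins; otherwise look for "chat"
    in the basename of any model artifact path.
    """
    if explicit_flag:
        return True
    for path in paths:
        if path is None:
            continue
        # Strip trailing slashes so a directory like ".../llama-2-7b-chat/"
        # still yields a meaningful basename.
        basename = os.path.basename(str(path).rstrip("/"))
        if "chat" in basename:
            return True
    return False


# Illustrative checks (paths are made up):
assert guess_is_chat_model("checkpoints/llama-2-7b-chat/")          # dir name matches
assert not guess_is_chat_model("checkpoints/llama-2-7b/model.pth")  # basename has no "chat"
assert guess_is_chat_model(None, explicit_flag=True)                # explicit flag wins
```

The patch itself feeds unset paths through `str(None)`, which yields the basename `"None"` and harmlessly never matches; the sketch skips `None` explicitly for clarity.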
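On the new `B_INST`/`E_INST` constants in generate.py: chat-tuned Llama 2 checkpoints expect each user turn wrapped in `[INST] ... [/INST]` tags, which is what `prompt = f"{B_INST} {prompt.strip()} {E_INST}"` produces in the chat loop. A sketch of the formatting, assuming only what the diff shows (the helper name is hypothetical):

```python
# B_INST/E_INST are taken from the patch; wrap_chat_prompt is an
# illustrative helper, not part of the diff.
B_INST, E_INST = "[INST]", "[/INST]"


def wrap_chat_prompt(prompt: str) -> str:
    """Wrap a raw user prompt in the instruction tags that
    chat-tuned Llama 2 checkpoints expect."""
    return f"{B_INST} {prompt.strip()} {E_INST}"


print(wrap_chat_prompt("  What is a tokenizer?  "))
# -> [INST] What is a tokenizer? [/INST]
```

Because the wrapping is gated on `builder_args.is_chat_model`, base (non-chat) checkpoints still receive the raw prompt and are never fed tags they were not trained on.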