From 20352afcae2bc0e50f473f9873c00f6659455698 Mon Sep 17 00:00:00 2001 From: Michael Gschwind Date: Wed, 17 Apr 2024 13:30:43 -0700 Subject: [PATCH 1/7] chat mode improvements --- build/builder.py | 22 +++++++++++++++++++++- cli.py | 5 +++++ generate.py | 10 ++++------ 3 files changed, 30 insertions(+), 7 deletions(-) diff --git a/build/builder.py b/build/builder.py index 674a58afc..8397aede5 100644 --- a/build/builder.py +++ b/build/builder.py @@ -35,7 +35,8 @@ class BuilderArgs: precision: torch.dtype = torch.float32 setup_caches: bool = False use_tp: bool = False - + is_chat_model: bool = False + def __post_init__(self): if not ( (self.checkpoint_path and self.checkpoint_path.is_file()) @@ -66,6 +67,24 @@ def __post_init__(self): @classmethod def from_args(cls, args): # -> BuilderArgs: + is_chat_mode = False + if args.is_chat_mode: + is_chat_mode = True + else: + for path in [ + args.checkpoint_path, + args.checkpoint_dir, + args.dso_path, + args.pte_path, + args.gguf_path + ]: + path = str(path) + if path.endswith('/'): + path = path[:-1] + path_basename = os.path.basename(path) + if "chat" in path_basename: + args.is_chat_mode = True + return cls( checkpoint_path=args.checkpoint_path, checkpoint_dir=args.checkpoint_dir, @@ -78,6 +97,7 @@ def from_args(cls, args): # -> BuilderArgs: precision=name_to_dtype(args.dtype), setup_caches=(args.output_dso_path or args.output_pte_path), use_tp=False, + is_chat_mode=is_chat_mode, ) @classmethod diff --git a/cli.py b/cli.py index d34b94e38..e8e37a170 100644 --- a/cli.py +++ b/cli.py @@ -78,6 +78,11 @@ def _add_arguments_common(parser): action="store_true", help="Use torchchat to for an interactive chat session.", ) + parser.add_argument( + "--is-chat-model", + action="store_true", + help="Use torchchat to for an interactive chat session.", + ) parser.add_argument( "--gui", action="store_true", diff --git a/generate.py b/generate.py index a199c7e4b..16861a4ef 100644 --- a/generate.py +++ b/generate.py @@ -27,6 +27,7 @@ from cli import add_arguments_for_generate, arg_init, check_args from quantize import set_precision +B_INST, E_INST = "[INST]", "[/INST]" @dataclass class GeneratorArgs: @@ -339,11 +340,8 @@ def _main( set_precision(builder_args.precision) is_speculative = speculative_builder_args.checkpoint_path is not None - is_chat = "chat" in str(os.path.basename(builder_args.checkpoint_path)) - if is_chat: - raise RuntimeError( - "need to stop filename based kludgery, at a minimum need to look at all pathnames. in particular, this now fails because chat is part of the pathname, yuck!" - ) + if generator_args.chat_mode and not builder_args.is_chat_model: + raise RuntimeError("You need to use --is-chat-model to indicate model has chat support.") tokenizer = _initialize_tokenizer(tokenizer_args) @@ -406,7 +404,7 @@ def _main( device_sync(device=builder_args.device) if i >= 0 and generator_args.chat_mode: prompt = input("What is your prompt? ") - if is_chat: + if builder_args.is_chat_model: prompt = f"{B_INST} {prompt.strip()} {E_INST}" encoded = encode_tokens( tokenizer, prompt, bos=True, device=builder_args.device From 1845215e83ffcacb8ea00d1e5e865578ed9d1481 Mon Sep 17 00:00:00 2001 From: Michael Gschwind Date: Wed, 17 Apr 2024 13:38:39 -0700 Subject: [PATCH 2/7] disable int4 on macos/x86 because of old nightlies --- .github/workflows/eager-dtype.yml | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/.github/workflows/eager-dtype.yml b/.github/workflows/eager-dtype.yml index d73832dde..db430d9db 100644 --- a/.github/workflows/eager-dtype.yml +++ b/.github/workflows/eager-dtype.yml @@ -77,9 +77,12 @@ jobs: echo "******************************************" echo "******** INT4 group-wise quantized *******" echo "******************************************" - - python generate.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager - cat ./output_eager + + echo "INT4 should work on MacOS on x86, but cannot be tested" + echo "because nightlies are too old!" + + # python generate.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager + # cat ./output_eager echo "tests complete for ${DTYPE}" done From 434adf7a2467c456e2edc61b63dcd4c2ccb4958c Mon Sep 17 00:00:00 2001 From: Michael Gschwind Date: Wed, 17 Apr 2024 13:44:56 -0700 Subject: [PATCH 3/7] typo --- build/builder.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/build/builder.py b/build/builder.py index 8397aede5..3e2a8d604 100644 --- a/build/builder.py +++ b/build/builder.py @@ -67,9 +67,9 @@ def __post_init__(self): @classmethod def from_args(cls, args): # -> BuilderArgs: - is_chat_mode = False + is_chat_model = False if args.is_chat_mode: - is_chat_mode = True + is_chat_model = True else: for path in [ args.checkpoint_path, @@ -83,7 +83,7 @@ def from_args(cls, args): # -> BuilderArgs: path = path[:-1] path_basename = os.path.basename(path) if "chat" in path_basename: - args.is_chat_mode = True + is_chat_model = True return cls( checkpoint_path=args.checkpoint_path, @@ -97,7 +97,7 @@ def from_args(cls, args): # -> BuilderArgs: precision=name_to_dtype(args.dtype), setup_caches=(args.output_dso_path or args.output_pte_path), use_tp=False, - is_chat_mode=is_chat_mode, + is_chat_model=is_chat_model, ) @classmethod From 7e0bee2bcc8938a69c28b9a5150b7cb4a970294d Mon Sep 17 00:00:00 2001 From: Michael Gschwind Date: Wed, 17 Apr 2024 13:56:52 -0700 Subject: [PATCH 4/7] typo --- build/builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build/builder.py b/build/builder.py index 3e2a8d604..cb5970726 100644 --- a/build/builder.py +++ b/build/builder.py @@ -68,7 +68,7 @@ def __post_init__(self): @classmethod def from_args(cls, args): # -> BuilderArgs: is_chat_model = False - if args.is_chat_mode: + if args.is_chat_model: is_chat_model = True else: for path in [ From 70b06b65f8937272e0c42662294b38f7ee707fbe Mon Sep 17 00:00:00 2001 From: Michael Gschwind Date: Wed, 17 Apr 2024 14:05:53 -0700 Subject: [PATCH 5/7] typo --- build/builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build/builder.py b/build/builder.py index cb5970726..ac50c6c16 100644 --- a/build/builder.py +++ b/build/builder.py @@ -81,7 +81,7 @@ def from_args(cls, args): # -> BuilderArgs: path = str(path) if path.endswith('/'): path = path[:-1] - path_basename = os.path.basename(path) + path_basename = os.path.basename(path) if "chat" in path_basename: is_chat_model = True From c261cbd7844c037d33daf1620fea48be53f8165d Mon Sep 17 00:00:00 2001 From: Michael Gschwind Date: Wed, 17 Apr 2024 14:38:14 -0700 Subject: [PATCH 6/7] convert runtime error to arning --- generate.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/generate.py b/generate.py index 16861a4ef..e64ce8f33 100644 --- a/generate.py +++ b/generate.py @@ -341,7 +341,15 @@ def _main( is_speculative = speculative_builder_args.checkpoint_path is not None if generator_args.chat_mode and not builder_args.is_chat_model: - raise RuntimeError("You need to use --is-chat-model to indicate model has chat support.") + print(""" +******************************************************* + This model is not known to support the chat function. + We will enable chat mode based on your instructions. + If the model is not trained to support chat, it will + produce nonsensical or false output. +******************************************************* + """) + # raise RuntimeError("You need to use --is-chat-model to indicate model has chat support.") tokenizer = _initialize_tokenizer(tokenizer_args) From 45495ff0af7bbcb2c08b177456539bba7048cb68 Mon Sep 17 00:00:00 2001 From: Michael Gschwind Date: Wed, 17 Apr 2024 14:39:33 -0700 Subject: [PATCH 7/7] wording of option texts --- cli.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cli.py b/cli.py index e8e37a170..ef43e7c77 100644 --- a/cli.py +++ b/cli.py @@ -76,12 +76,12 @@ def _add_arguments_common(parser): parser.add_argument( "--chat", action="store_true", - help="Use torchchat to for an interactive chat session.", + help="Use torchchat for an interactive chat session.", ) parser.add_argument( "--is-chat-model", action="store_true", - help="Use torchchat to for an interactive chat session.", + help="Indicate that the model was trained to support chat functionality.", ) parser.add_argument( "--gui",