Skip to content
This repository has been archived by the owner on Dec 6, 2023. It is now read-only.

Commit

Permalink
prompt template as args
Browse files Browse the repository at this point in the history
  • Loading branch information
filopedraz committed Nov 8, 2023
1 parent 0543e13 commit 4b2c295
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
8 changes: 7 additions & 1 deletion cht-llama-cpp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,22 @@
MODEL_PATH = f"./ml/models/{os.getenv('MODEL_ID', 'yarn-mistral-7b-128k.Q4_K_M')}.gguf"
# Mistral gguf follows ChatML syntax
# https://github.com/openai/openai-python/blob/main/chatml.md
PROMPT_TEMPLATE_STRING = '{"system_prompt_template": "<|im_start|>system\\n{}\\n<|im_end|>\\n", "default_system_text": "You are an helpful AI assistant.", "user_prompt_template": "<|im_start|>user\\n{}\\n<|im_end|>\\n", "assistant_prompt_template": "<|im_start|>assistant\\n{}\\n<|im_end|>\\n", "request_assistant_response_token": "<|im_start|>assistant\\n", "template_format": "chatml"}' # noqa

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", help="Path to GGUF", default=MODEL_PATH)
parser.add_argument("--port", help="Port to run model server on", type=int, default=8000)
parser.add_argument("--ctx", help="Context dimension", type=int, default=4096)
parser.add_argument(
"--prompt_template",
help="Prompt Template",
type=str,
default='{"system_prompt_template": "<|im_start|>system\\n{}\\n<|im_end|>\\n", "default_system_text": "You are an helpful AI assistant.", "user_prompt_template": "<|im_start|>user\\n{}\\n<|im_end|>\\n", "assistant_prompt_template": "<|im_start|>assistant\\n{}\\n<|im_end|>\\n", "request_assistant_response_token": "<|im_start|>assistant\\n", "template_format": "chatml"}', # noqa
) # noqa
args = parser.parse_args()
MODEL_PATH = args.model_path
MODEL_CTX = args.ctx
PROMPT_TEMPLATE_STRING = args.prompt_template

logging.basicConfig(
format="%(asctime)s %(levelname)-8s %(message)s",
Expand Down
2 changes: 1 addition & 1 deletion cht-llama-cpp/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def generate(
@classmethod
def get_model(cls, model_path, prompt_template_jsonstr, n_ctx):
chat_format = "llama-2"
if "mistral" in model_path:
if prompt_template_jsonstr != "" and "mistral" in model_path:
cls.PROMPT_TEMPLATE = json.loads(prompt_template_jsonstr)
chat_format = cls.PROMPT_TEMPLATE.get("template_format", "chatml")
if cls.model is None:
Expand Down

0 comments on commit 4b2c295

Please sign in to comment.