From f0e33f9e74f844deffe0111d847814257a518161 Mon Sep 17 00:00:00 2001 From: bobo Date: Fri, 24 May 2024 18:24:59 +0800 Subject: [PATCH] Refactor config_util.py to handle gitlab host input --- merico/pr/command.py | 29 ++++++++++++++--------------- merico/pr/config_util.py | 16 ++++++++++++++++ 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/merico/pr/command.py b/merico/pr/command.py index 6dc26a6..d4909d1 100644 --- a/merico/pr/command.py +++ b/merico/pr/command.py @@ -8,15 +8,12 @@ import sys # add the current directory to the path -from os.path import abspath, dirname - -from lib.ide_service import IDEService - -sys.path.append(dirname(dirname(abspath(__file__)))) - # add new model configs to algo.MAX_TOKENS import pr_agent.algo as algo +from lib.ide_service import IDEService +from merico.pr.config_util import get_model_max_input + algo.MAX_TOKENS["gpt-4-turbo-preview"] = 128000 algo.MAX_TOKENS["claude-3-opus"] = 100000 algo.MAX_TOKENS["claude-3-sonnet"] = 100000 @@ -42,12 +39,9 @@ algo.MAX_TOKENS["BAAI/bge-base-en-v1.5"] = 512 algo.MAX_TOKENS["sentence-transformers/msmarco-bert-base-dot-v5"] = 512 algo.MAX_TOKENS["bert-base-uncased"] = 512 -if os.environ.get("LLM_MODEL", "gpt-3.5-turbo-1106") not in algo.MAX_TOKENS: - current_model = os.environ.get("LLM_MODEL", "gpt-3.5-turbo-1106") - IDEService().ide_logging( - "info", f"{current_model}'s max tokens is not config, we use it as default 16000" - ) - algo.MAX_TOKENS[os.environ.get("LLM_MODEL", "gpt-3.5-turbo-1106")] = 16000 + +current_model = os.environ.get("LLM_MODEL", "gpt-3.5-turbo-1106") +algo.MAX_TOKENS[current_model] = get_model_max_input(current_model) # add new git provider @@ -62,7 +56,8 @@ def get_git_provider(): import pr_agent.git_providers as git_providers -from providers.devchat_provider import DevChatProvider + +from merico.pr.providers.devchat_provider import DevChatProvider git_providers._GIT_PROVIDERS["devchat"] = DevChatProvider _get_git_provider_old = git_providers.get_git_provider @@ -103,8 +98,12 @@ def close(self): ) -from config_util import get_gitlab_host, get_repo_type, read_server_access_token_with_input -from custom_suggestions_config import get_custom_suggestions_system_prompt +from merico.pr.config_util import ( + get_gitlab_host, + get_repo_type, + read_server_access_token_with_input, +) +from merico.pr.custom_suggestions_config import get_custom_suggestions_system_prompt # set openai key and api base get_settings().set("OPENAI.KEY", os.environ.get("OPENAI_API_KEY", "")) diff --git a/merico/pr/config_util.py b/merico/pr/config_util.py index 0d2bd9f..8b60158 100644 --- a/merico/pr/config_util.py +++ b/merico/pr/config_util.py @@ -1,6 +1,8 @@ import json import os +import yaml + from lib.chatmark import Radio, TextEditor @@ -224,3 +226,17 @@ def get_gitlab_host(pr_url): gitlab_host_map[pr_host] = host _save_config_value("gitlab_host_map", gitlab_host_map) return host + + +def get_model_max_input(model): + config_file = os.path.expanduser("~/.chat/config.yml") + try: + with open(config_file, "r", encoding="utf-8") as file: + yaml_contents = file.read() + parsed_yaml = yaml.safe_load(yaml_contents) + for model_t in parsed_yaml.get("models", {}): + if model_t == model: + return parsed_yaml["models"][model_t].get("max_input_tokens", 6000) + return 6000 + except Exception: + return 6000