Skip to content

Commit

Permalink
Refactor config_util.py to handle gitlab host input
Browse files Browse the repository at this point in the history
  • Loading branch information
yangbobo2021 authored and kagami-l committed May 24, 2024
1 parent 619ac65 commit f0e33f9
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 15 deletions.
29 changes: 14 additions & 15 deletions merico/pr/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,12 @@
import sys

# add the current directory to the path
from os.path import abspath, dirname

from lib.ide_service import IDEService

sys.path.append(dirname(dirname(abspath(__file__))))

# add new model configs to algo.MAX_TOKENS
import pr_agent.algo as algo

from lib.ide_service import IDEService
from merico.pr.config_util import get_model_max_input

algo.MAX_TOKENS["gpt-4-turbo-preview"] = 128000
algo.MAX_TOKENS["claude-3-opus"] = 100000
algo.MAX_TOKENS["claude-3-sonnet"] = 100000
Expand All @@ -42,12 +39,9 @@
algo.MAX_TOKENS["BAAI/bge-base-en-v1.5"] = 512
algo.MAX_TOKENS["sentence-transformers/msmarco-bert-base-dot-v5"] = 512
algo.MAX_TOKENS["bert-base-uncased"] = 512
if os.environ.get("LLM_MODEL", "gpt-3.5-turbo-1106") not in algo.MAX_TOKENS:
current_model = os.environ.get("LLM_MODEL", "gpt-3.5-turbo-1106")
IDEService().ide_logging(
"info", f"{current_model}'s max tokens is not config, we use it as default 16000"
)
algo.MAX_TOKENS[os.environ.get("LLM_MODEL", "gpt-3.5-turbo-1106")] = 16000

current_model = os.environ.get("LLM_MODEL", "gpt-3.5-turbo-1106")
algo.MAX_TOKENS[current_model] = get_model_max_input(current_model)


# add new git provider
Expand All @@ -62,7 +56,8 @@ def get_git_provider():


import pr_agent.git_providers as git_providers
from providers.devchat_provider import DevChatProvider

from merico.pr.providers.devchat_provider import DevChatProvider

git_providers._GIT_PROVIDERS["devchat"] = DevChatProvider
_get_git_provider_old = git_providers.get_git_provider
Expand Down Expand Up @@ -103,8 +98,12 @@ def close(self):
)


from config_util import get_gitlab_host, get_repo_type, read_server_access_token_with_input
from custom_suggestions_config import get_custom_suggestions_system_prompt
from merico.pr.config_util import (
get_gitlab_host,
get_repo_type,
read_server_access_token_with_input,
)
from merico.pr.custom_suggestions_config import get_custom_suggestions_system_prompt

# set openai key and api base
get_settings().set("OPENAI.KEY", os.environ.get("OPENAI_API_KEY", ""))
Expand Down
16 changes: 16 additions & 0 deletions merico/pr/config_util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
import os

import yaml

from lib.chatmark import Radio, TextEditor


Expand Down Expand Up @@ -224,3 +226,17 @@ def get_gitlab_host(pr_url):
gitlab_host_map[pr_host] = host
_save_config_value("gitlab_host_map", gitlab_host_map)
return host


def get_model_max_input(model, default=6000):
    """Return the configured max input token count for *model*.

    Looks up ``models.<model>.max_input_tokens`` in the user config file
    ``~/.chat/config.yml``. Falls back to *default* when the file is
    missing, unreadable, malformed, or does not define the model.

    :param model: model name key under the ``models`` section of the config.
    :param default: value returned when no explicit setting is found
        (6000, matching the historical hard-coded fallback).
    :return: the configured ``max_input_tokens`` for *model*, or *default*.
    """
    config_file = os.path.expanduser("~/.chat/config.yml")
    try:
        # safe_load accepts a stream directly; no need to read() first.
        with open(config_file, "r", encoding="utf-8") as file:
            parsed_yaml = yaml.safe_load(file)
        # safe_load returns None for an empty file; treat it as no config.
        models = (parsed_yaml or {}).get("models", {})
        model_config = models.get(model)
        # Non-dict entries (e.g. a bare string) previously fell into the
        # broad except and yielded the default; keep that behavior explicit.
        if isinstance(model_config, dict):
            return model_config.get("max_input_tokens", default)
        return default
    except Exception:
        # Best-effort read: any I/O or parse problem falls back to default.
        return default

0 comments on commit f0e33f9

Please sign in to comment.