
Commit

feat: Added support for custom models
GaiZhenbiao committed May 4, 2024
1 parent 921af92 commit e888600
Showing 15 changed files with 186 additions and 392 deletions.
17 changes: 15 additions & 2 deletions modules/config.py
@@ -5,9 +5,11 @@
 import sys
 import commentjson as json
 import colorama
+from collections import defaultdict
 
 from . import shared
 from . import presets
+from .presets import i18n
 
 
 __all__ = [
@@ -100,14 +102,25 @@ def load_config_to_environ(key_list):
 sensitive_id = config.get("sensitive_id", "")
 sensitive_id = os.environ.get("SENSITIVE_ID", sensitive_id)
 
+if "extra_model_metadata" in config:
+    presets.MODEL_METADATA.update(config["extra_model_metadata"])
+    logging.info(i18n("已添加 {extra_model_quantity} 个额外的模型元数据").format(extra_model_quantity=len(config["extra_model_metadata"])))
+
+_model_metadata = {}
+for k, v in presets.MODEL_METADATA.items():
+    temp_dict = presets.DEFAULT_METADATA.copy()
+    temp_dict.update(v)
+    _model_metadata[k] = temp_dict
+presets.MODEL_METADATA = _model_metadata
+
 if "available_models" in config:
     presets.MODELS = config["available_models"]
-    logging.info(f"已设置可用模型:{config['available_models']}")
+    logging.info(i18n("已设置可用模型:{available_models}").format(available_models=config["available_models"]))
 
 # 模型配置
 if "extra_models" in config:
     presets.MODELS.extend(config["extra_models"])
-    logging.info(f"已添加额外的模型:{config['extra_models']}")
+    logging.info(i18n("已添加额外的模型:{extra_models}").format(extra_models=config["extra_models"]))
 
 HIDE_MY_KEY = config.get("hide_my_key", False)
 
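For context, a minimal sketch of the config keys that the updated modules/config.py reads (available_models, extra_models, extra_model_metadata). The key names come from the diff above; the model names and the "model_type" field are assumptions made only for illustration, not values from this commit.

# Hypothetical parsed config (Python dict), showing the custom-model keys handled above.
# Model names and the "model_type" field are illustrative assumptions.
config = {
    "available_models": ["GPT3.5 Turbo", "GPT4 Turbo"],   # replaces presets.MODELS outright
    "extra_models": ["my-custom-model"],                   # appended to presets.MODELS
    "extra_model_metadata": {
        "my-custom-model": {"model_type": "OpenAI"},       # merged into presets.MODEL_METADATA
    },
}
# Each MODEL_METADATA entry is then overlaid onto a copy of presets.DEFAULT_METADATA,
# so any field not set explicitly falls back to the defaults.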
2 changes: 1 addition & 1 deletion modules/models/Claude.py
@@ -11,7 +11,7 @@ def __init__(self, model_name, api_secret) -> None:
         self.api_secret = api_secret
         if None in [self.api_secret]:
             raise Exception("请在配置文件或者环境变量中设置Claude的API Secret")
-        self.claude_client = Anthropic(api_key=self.api_secret)
+        self.claude_client = Anthropic(api_key=self.api_secret, base_url=self.api_host)
 
     def _get_claude_style_history(self):
         history = []
15 changes: 9 additions & 6 deletions modules/models/DALLE3.py
@@ -7,8 +7,11 @@
 
 class OpenAI_DALLE3_Client(BaseLLMModel):
     def __init__(self, model_name, api_key, user_name="") -> None:
-        super().__init__(model_name=model_name, user=user_name)
-        self.api_key = api_key
+        super().__init__(model_name=model_name, user=user_name, config={"api_key": api_key})
+        if self.api_host is not None:
+            self.chat_completion_url, self.images_completion_url, self.openai_api_base, self.balance_api_url, self.usage_api_url = shared.format_openai_host(self.api_host)
+        else:
+            self.api_host, self.chat_completion_url, self.images_completion_url, self.openai_api_base, self.balance_api_url, self.usage_api_url = shared.state.api_host, shared.state.chat_completion_url, shared.state.images_completion_url, shared.state.openai_api_base, shared.state.balance_api_url, shared.state.usage_api_url
         self._refresh_header()
 
     def _get_dalle3_prompt(self):
@@ -24,7 +27,7 @@ def get_answer_at_once(self, stream=False):
             "Authorization": f"Bearer {self.api_key}"
         }
         payload = {
-            "model": "dall-e-3",
+            "model": self.model_name,
             "prompt": prompt,
             "n": 1,
             "size": "1024x1024",
@@ -35,13 +38,13 @@
         else:
             timeout = TIMEOUT_ALL
 
-        if shared.state.images_completion_url != IMAGES_COMPLETION_URL:
-            logging.debug(f"使用自定义API URL: {shared.state.images_completion_url}")
+        if self.images_completion_url != IMAGES_COMPLETION_URL:
+            logging.debug(f"使用自定义API URL: {self.images_completion_url}")
 
         with retrieve_proxy():
             try:
                 response = requests.post(
-                    shared.state.images_completion_url,
+                    self.images_completion_url,
                     headers=headers,
                     json=payload,
                     stream=stream,
3 changes: 1 addition & 2 deletions modules/models/GoogleGemini.py
@@ -17,8 +17,7 @@
 
 class GoogleGeminiClient(BaseLLMModel):
     def __init__(self, model_name, api_key, user_name="") -> None:
-        super().__init__(model_name=model_name, user=user_name)
-        self.api_key = api_key
+        super().__init__(model_name=model_name, user=user_name, config={"api_key": api_key})
         if "vision" in model_name.lower():
             self.multimodal = True
         else:
5 changes: 2 additions & 3 deletions modules/models/GooglePaLM.py
@@ -4,8 +4,7 @@
 
 class Google_PaLM_Client(BaseLLMModel):
     def __init__(self, model_name, api_key, user_name="") -> None:
-        super().__init__(model_name=model_name, user=user_name)
-        self.api_key = api_key
+        super().__init__(model_name=model_name, user=user_name, config={"api_key": api_key})
 
     def _get_palm_style_input(self):
         new_history = []
@@ -20,7 +19,7 @@ def get_answer_at_once(self):
         palm.configure(api_key=self.api_key)
         messages = self._get_palm_style_input()
         response = palm.chat(context=self.system_prompt, messages=messages,
-                             temperature=self.temperature, top_p=self.top_p)
+                             temperature=self.temperature, top_p=self.top_p, model=self.model_name)
         if response.last is not None:
             return response.last, len(response.last)
         else:
4 changes: 2 additions & 2 deletions modules/models/Groq.py
@@ -18,10 +18,10 @@
 
 class Groq_Client(BaseLLMModel):
     def __init__(self, model_name, api_key, user_name="") -> None:
-        super().__init__(model_name=model_name, user=user_name)
-        self.api_key = api_key
+        super().__init__(model_name=model_name, user=user_name, api_key=api_key)
         self.client = Groq(
             api_key=os.environ.get("GROQ_API_KEY"),
+            base_url=self.api_host,
         )
 
     def _get_groq_style_input(self):
(Diffs for the remaining 9 changed files are not shown.)

0 comments on commit e888600
