From 1609f57d60177438950482a9b28db45ad267999c Mon Sep 17 00:00:00 2001 From: technicca Date: Wed, 24 Apr 2024 19:07:43 +0300 Subject: [PATCH] fix passing base_ur; --- l2mac/llm_providers/general.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/l2mac/llm_providers/general.py b/l2mac/llm_providers/general.py index 5eba9c60..1b1db5e5 100644 --- a/l2mac/llm_providers/general.py +++ b/l2mac/llm_providers/general.py @@ -75,6 +75,7 @@ def get_llm_config(config, logger, name, rate_limiter): "stop": config.llm_settings.stop, "stream": config.llm_settings.api_stream, "api_key": config.llm.api_key, + "base_url": config.llm.base_url, "_open_ai_rate_limit_requests_per_minute": config.llm_settings.rate_limit_requests_per_minute, "_logger": logger, "_name": name, @@ -180,8 +181,9 @@ async def async_chat_completion_rl_inner(**kwargs): kwargs.get("_name", None) rate_limiter = kwargs.get("_rate_limiter", None) api_type = kwargs.get("api_type", ApiType.openai) + base_url = kwargs.get("base_url") if api_type == ApiType.openai: - aclient = AsyncOpenAI(api_key=kwargs["api_key"]) + aclient = AsyncOpenAI(api_key=kwargs["api_key"], base_url=kwargs["base_url"]) elif api_type == ApiType.azure: aclient = AsyncAzureOpenAI( api_key=kwargs["api_key"], api_version=kwargs["api_version"], azure_endpoint=kwargs["azure_endpoint"] @@ -196,6 +198,7 @@ async def async_chat_completion_rl_inner(**kwargs): "_rate_limiter", "stream", "api_type", + "base_url" } kwargs = {k: v for k, v in kwargs.items() if k not in keys_to_remove} perf_counter() @@ -242,7 +245,7 @@ def chat_completion_rl_inner(**kwargs): rate_limiter = kwargs.get("_rate_limiter", None) api_type = kwargs.get("api_type", ApiType.openai) if api_type == ApiType.openai: - client = OpenAI(api_key=kwargs["api_key"]) + client = OpenAI(api_key=kwargs["api_key"], base_url=kwargs["base_url"]) elif api_type == ApiType.azure: client = AzureOpenAI( api_key=kwargs["api_key"], api_version=kwargs["api_version"], azure_endpoint=kwargs["azure_endpoint"] @@ -257,6 +260,7 @@ def chat_completion_rl_inner(**kwargs): "_rate_limiter", "stream", "api_type", + "base_url" } kwargs = {k: v for k, v in kwargs.items() if k not in keys_to_remove} perf_counter()