Skip to content

Commit

Permalink
Merge pull request #2 from u66u/fix-base-url
Browse files Browse the repository at this point in the history
fix passing base_url
  • Loading branch information
samholt authored Apr 24, 2024
2 parents 741dba4 + 1609f57 commit aed84b8
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions l2mac/llm_providers/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def get_llm_config(config, logger, name, rate_limiter):
"stop": config.llm_settings.stop,
"stream": config.llm_settings.api_stream,
"api_key": config.llm.api_key,
"base_url": config.llm.base_url,
"_open_ai_rate_limit_requests_per_minute": config.llm_settings.rate_limit_requests_per_minute,
"_logger": logger,
"_name": name,
Expand Down Expand Up @@ -180,8 +181,9 @@ async def async_chat_completion_rl_inner(**kwargs):
kwargs.get("_name", None)
rate_limiter = kwargs.get("_rate_limiter", None)
api_type = kwargs.get("api_type", ApiType.openai)
base_url = kwargs.get("base_url")
if api_type == ApiType.openai:
aclient = AsyncOpenAI(api_key=kwargs["api_key"])
aclient = AsyncOpenAI(api_key=kwargs["api_key"], base_url=kwargs["base_url"])
elif api_type == ApiType.azure:
aclient = AsyncAzureOpenAI(
api_key=kwargs["api_key"], api_version=kwargs["api_version"], azure_endpoint=kwargs["azure_endpoint"]
Expand All @@ -196,6 +198,7 @@ async def async_chat_completion_rl_inner(**kwargs):
"_rate_limiter",
"stream",
"api_type",
"base_url"
}
kwargs = {k: v for k, v in kwargs.items() if k not in keys_to_remove}
perf_counter()
Expand Down Expand Up @@ -242,7 +245,7 @@ def chat_completion_rl_inner(**kwargs):
rate_limiter = kwargs.get("_rate_limiter", None)
api_type = kwargs.get("api_type", ApiType.openai)
if api_type == ApiType.openai:
client = OpenAI(api_key=kwargs["api_key"])
client = OpenAI(api_key=kwargs["api_key"], base_url=kwargs["base_url"])
elif api_type == ApiType.azure:
client = AzureOpenAI(
api_key=kwargs["api_key"], api_version=kwargs["api_version"], azure_endpoint=kwargs["azure_endpoint"]
Expand All @@ -257,6 +260,7 @@ def chat_completion_rl_inner(**kwargs):
"_rate_limiter",
"stream",
"api_type",
"base_url"
}
kwargs = {k: v for k, v in kwargs.items() if k not in keys_to_remove}
perf_counter()
Expand Down

0 comments on commit aed84b8

Please sign in to comment.