From deac40947253f276b79e3529ea82cb568efb5fee Mon Sep 17 00:00:00 2001 From: Jayesh Date: Thu, 14 Nov 2024 17:03:40 +0530 Subject: [PATCH 1/7] vertex ai support with openai library --- .changeset/witty-months-train.md | 5 ++ .../function_calling_weather.py | 2 + .../livekit/plugins/openai/llm.py | 53 ++++++++++++++++++- .../livekit/plugins/openai/models.py | 14 +++++ 4 files changed, 73 insertions(+), 1 deletion(-) create mode 100644 .changeset/witty-months-train.md diff --git a/.changeset/witty-months-train.md b/.changeset/witty-months-train.md new file mode 100644 index 000000000..8191b36c8 --- /dev/null +++ b/.changeset/witty-months-train.md @@ -0,0 +1,5 @@ +--- +"livekit-plugins-openai": patch +--- + +vertex ai support with openai library diff --git a/examples/voice-pipeline-agent/function_calling_weather.py b/examples/voice-pipeline-agent/function_calling_weather.py index f5bc3135b..8ea32d3a2 100644 --- a/examples/voice-pipeline-agent/function_calling_weather.py +++ b/examples/voice-pipeline-agent/function_calling_weather.py @@ -67,6 +67,8 @@ async def entrypoint(ctx: JobContext): vad=ctx.proc.userdata["vad"], stt=deepgram.STT(), llm=openai.LLM(), + # To use vertex AI LLM + # llm=openai.LLM# .with_vertexai(project_id="your-project-id", location="us-central1"), tts=openai.TTS(), fnc_ctx=fnc_ctx, chat_ctx=initial_chat_ctx, diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py index 56f1598fd..4cbc10d0f 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py @@ -37,6 +37,7 @@ CerebrasChatModels, ChatModels, DeepSeekChatModels, + GeminiModels, GroqChatModels, OctoChatModels, PerplexityChatModels, @@ -161,6 +162,54 @@ def with_cerebras( temperature=temperature, ) + @staticmethod + def with_vertexai( + *, + model: str | GeminiModels = "gemini-1.5-flash-002", + project_id: str, + location: str = "us-central1", + client: openai.AsyncClient | None = None, + user: str | None = None, + temperature: float | None = None, + ) -> LLM: + """ + Create a new instance of VertexAI LLM. + + ``project_id`` must be set to your VERTEXAI PROJECT ID, either using the argument or by setting + the ``VERTEXAI_PROJECT_ID`` environmental variable. + """ + + project_id = project_id or os.environ.get("VERTEXAI_PROJECT_ID") + if project_id is None: + raise ValueError( + "VERTEXAI_PROJECT_ID is required, either set project_id argument or set VERTEXAI_PROJECT_ID environmental variable" + ) + location = location or os.environ.get("VERTEXAI_LOCATION") + if location is None: + raise ValueError( + "VERTEXAI_LOCATION is required, either set location argument or set VERTEXAI_LOCATION environmental variable" + ) + + from google.auth import default + from google.auth.transport import requests + + credentials, _ = default( + scopes=["https://www.googleapis.com/auth/cloud-platform"] + ) + auth_request = requests.Request() + credentials.refresh(auth_request) + base_url = f"https://{location}-aiplatform.googleapis.com/v1beta1/projects/{project_id}/locations/{location}/endpoints/openapi" + api_key = credentials.token + + return LLM( + model=model, + api_key=api_key, + base_url=base_url, + client=client, + user=user, + temperature=temperature, + ) + @staticmethod def with_fireworks( *, @@ -524,6 +573,7 @@ def __init__( self._tool_call_id: str | None = None self._fnc_name: str | None = None self._fnc_raw_arguments: str | None = None + self._tool_index: int | None = None async def _main_task(self) -> None: if not self._oai_stream: @@ -577,10 +627,11 @@ def _parse_choice(self, id: str, choice: Choice) -> llm.ChatChunk | None: continue # oai may add other tools in the future call_chunk = None - if self._tool_call_id and tool.id and tool.id != self._tool_call_id: + if self._tool_call_id and tool.id and tool.index != self._tool_index: call_chunk = self._try_build_function(id, choice) if tool.function.name: + self._tool_index = tool.index self._tool_call_id = tool.id self._fnc_name = tool.function.name self._fnc_raw_arguments = tool.function.arguments or "" diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/models.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/models.py index 9e2e5dd18..bc4d7d1c1 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/models.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/models.py @@ -79,6 +79,20 @@ "deepseek-chat", ] +GeminiModels = Literal[ + "gemini-1.0-pro", + "gemini-1.0-pro-vision", + "gemini-1.0-pro-vision-001", + "gemini-1.5-flash", + "gemini-1.5-flash-002", + "gemini-1.5-flash-8b", + "gemini-1.5-flash-preview-0514", + "gemini-1.5-pro", + "gemini-1.5-pro-002", + "gemini-1.5-pro-preview-0409", + "gemini-1.5-pro-preview-0514", +] + TogetherChatModels = Literal[ "Austism/chronos-hermes-13b", "Gryphe/MythoMax-L2-13b", From 358c385f1e10c3f5f90d9e1c0dde33916fa781eb Mon Sep 17 00:00:00 2001 From: Jayesh Date: Thu, 14 Nov 2024 17:05:34 +0530 Subject: [PATCH 2/7] vertex ai support with openai library --- examples/voice-pipeline-agent/function_calling_weather.py | 2 +- examples/voice-pipeline-agent/minimal_assistant.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/voice-pipeline-agent/function_calling_weather.py b/examples/voice-pipeline-agent/function_calling_weather.py index 8ea32d3a2..0eb3da843 100644 --- a/examples/voice-pipeline-agent/function_calling_weather.py +++ b/examples/voice-pipeline-agent/function_calling_weather.py @@ -68,7 +68,7 @@ async def entrypoint(ctx: JobContext): stt=deepgram.STT(), llm=openai.LLM(), # To use vertex AI LLM - # llm=openai.LLM# .with_vertexai(project_id="your-project-id", location="us-central1"), + # llm=openai.LLM.with_vertexai(project_id="your-project-id", location="us-central1"), tts=openai.TTS(), fnc_ctx=fnc_ctx, chat_ctx=initial_chat_ctx, diff --git a/examples/voice-pipeline-agent/minimal_assistant.py b/examples/voice-pipeline-agent/minimal_assistant.py index 4b94bd5b7..ebc525661 100644 --- a/examples/voice-pipeline-agent/minimal_assistant.py +++ b/examples/voice-pipeline-agent/minimal_assistant.py @@ -48,6 +48,8 @@ async def entrypoint(ctx: JobContext): vad=ctx.proc.userdata["vad"], stt=deepgram.STT(model=dg_model), llm=openai.LLM(), + # To use vertex AI LLM + # llm=openai.LLM.with_vertexai(project_id="your-project-id", location="us-central1"), tts=openai.TTS(), chat_ctx=initial_ctx, ) From 6994969949cbd81ddf0cf6b1966c17643fcbacda Mon Sep 17 00:00:00 2001 From: Jayesh Date: Thu, 14 Nov 2024 17:09:47 +0530 Subject: [PATCH 3/7] google auth dependency --- livekit-plugins/livekit-plugins-openai/setup.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/livekit-plugins/livekit-plugins-openai/setup.py b/livekit-plugins/livekit-plugins-openai/setup.py index d71123011..ee69baae9 100644 --- a/livekit-plugins/livekit-plugins-openai/setup.py +++ b/livekit-plugins/livekit-plugins-openai/setup.py @@ -47,7 +47,11 @@ license="Apache-2.0", packages=setuptools.find_namespace_packages(include=["livekit.*"]), python_requires=">=3.9.0", - install_requires=["livekit-agents[codecs, images]>=0.11", "openai>=1.50"], + install_requires=[ + "livekit-agents[codecs, images]>=0.11", + "openai>=1.50", + "google-auth", + ], package_data={"livekit.plugins.openai": ["py.typed"]}, project_urls={ "Documentation": "https://docs.livekit.io", From e9a25433913e1d7866ed1bbfb9fb51dbb27a2c22 Mon Sep 17 00:00:00 2001 From: Jayesh Date: Thu, 14 Nov 2024 17:15:14 +0530 Subject: [PATCH 4/7] typecheck --- .../livekit-plugins-openai/livekit/plugins/openai/llm.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py index 4cbc10d0f..1e024aa19 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py @@ -166,8 +166,8 @@ def with_cerebras( def with_vertexai( *, model: str | GeminiModels = "gemini-1.5-flash-002", - project_id: str, - location: str = "us-central1", + project_id: str | None = None, + location: str | None = "us-central1", client: openai.AsyncClient | None = None, user: str | None = None, temperature: float | None = None, @@ -541,6 +541,7 @@ def chat( temperature = self._opts.temperature messages = _build_oai_context(chat_ctx, id(self)) + logger.info(f"messages: {messages}") cmp = self._client.chat.completions.create( messages=messages, From 8b11f42060f71502c41d0283109a94de852dd902 Mon Sep 17 00:00:00 2001 From: Jayesh Date: Thu, 14 Nov 2024 19:52:33 +0530 Subject: [PATCH 5/7] catch import error and extra require --- .../livekit/plugins/openai/llm.py | 13 +++++++++---- livekit-plugins/livekit-plugins-openai/setup.py | 4 +++- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py index 1e024aa19..3029eef57 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py @@ -175,8 +175,8 @@ def with_vertexai( """ Create a new instance of VertexAI LLM. - ``project_id`` must be set to your VERTEXAI PROJECT ID, either using the argument or by setting - the ``VERTEXAI_PROJECT_ID`` environmental variable. + `project_id` must be set to your VERTEXAI PROJECT ID, either using the argument or by setting + the `VERTEXAI_PROJECT_ID` environmental variable. """ project_id = project_id or os.environ.get("VERTEXAI_PROJECT_ID") @@ -190,8 +190,13 @@ def with_vertexai( "VERTEXAI_LOCATION is required, either set location argument or set VERTEXAI_LOCATION environmental variable" ) - from google.auth import default - from google.auth.transport import requests + try: + from google.auth import default + from google.auth.transport import requests + except ImportError: + raise ImportError( + "Google Auth dependencies not found. Please install with: `pip install livekit-plugins-openai[vertex]`" + ) credentials, _ = default( scopes=["https://www.googleapis.com/auth/cloud-platform"] diff --git a/livekit-plugins/livekit-plugins-openai/setup.py b/livekit-plugins/livekit-plugins-openai/setup.py index ee69baae9..95449d4b1 100644 --- a/livekit-plugins/livekit-plugins-openai/setup.py +++ b/livekit-plugins/livekit-plugins-openai/setup.py @@ -50,8 +50,10 @@ install_requires=[ "livekit-agents[codecs, images]>=0.11", "openai>=1.50", - "google-auth", ], + extras_require={ + "vertex": ["google-auth"], + }, package_data={"livekit.plugins.openai": ["py.typed"]}, project_urls={ "Documentation": "https://docs.livekit.io", From bddc372859418e0e9d10b7f14aad7857c7ef5642 Mon Sep 17 00:00:00 2001 From: jayesh Date: Fri, 15 Nov 2024 05:24:28 +0530 Subject: [PATCH 6/7] added auto refresher class --- .../livekit/plugins/openai/llm.py | 50 ++++++++++++++----- .../livekit/plugins/openai/models.py | 20 +++----- 2 files changed, 46 insertions(+), 24 deletions(-) diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py index 3029eef57..17b1d2b56 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py @@ -37,12 +37,12 @@ CerebrasChatModels, ChatModels, DeepSeekChatModels, - GeminiModels, GroqChatModels, OctoChatModels, PerplexityChatModels, TelnyxChatModels, TogetherChatModels, + VertexModels, XAIChatModels, ) from .utils import AsyncAzureADTokenProvider, build_oai_message @@ -165,7 +165,7 @@ def with_cerebras( @staticmethod def with_vertexai( *, - model: str | GeminiModels = "gemini-1.5-flash-002", + model: str | VertexModels = "google/gemini-1.5-pro", project_id: str | None = None, location: str | None = "us-central1", client: openai.AsyncClient | None = None, @@ -189,27 +189,53 @@ def with_vertexai( raise ValueError( "VERTEXAI_LOCATION is required, either set location argument or set VERTEXAI_LOCATION environmental variable" ) + _gac = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS") + if _gac is None: + raise ValueError( + "`GOOGLE_APPLICATION_CREDENTIALS` environment variable is not set. please set it to the path of the service account key file." + ) try: - from google.auth import default - from google.auth.transport import requests + import google.auth + import google.auth.transport.requests except ImportError: raise ImportError( "Google Auth dependencies not found. Please install with: `pip install livekit-plugins-openai[vertex]`" ) - credentials, _ = default( - scopes=["https://www.googleapis.com/auth/cloud-platform"] + class OpenAICredentialsRefresher: + def __init__(self, **kwargs: Any) -> None: + self.client = openai.AsyncClient(**kwargs, api_key="DUMMY") + self.creds, self.project = google.auth.default( + scopes=["https://www.googleapis.com/auth/cloud-platform"] + ) + + def __getattr__(self, name: str) -> Any: + if not self.creds.valid: + auth_req = google.auth.transport.requests.Request() + self.creds.refresh(auth_req) + + if not self.creds.valid: + raise RuntimeError("Unable to refresh auth") + + self.client.api_key = self.creds.token + return getattr(self.client, name) + + client = OpenAICredentialsRefresher( + base_url=f"https://{location}-aiplatform.googleapis.com/v1beta1/projects/{project_id}/locations/{location}/endpoints/openapi", + http_client=httpx.AsyncClient( + timeout=httpx.Timeout(connect=15.0, read=5.0, write=5.0, pool=5.0), + follow_redirects=True, + limits=httpx.Limits( + max_connections=50, + max_keepalive_connections=50, + keepalive_expiry=120, + ), + ), ) - auth_request = requests.Request() - credentials.refresh(auth_request) - base_url = f"https://{location}-aiplatform.googleapis.com/v1beta1/projects/{project_id}/locations/{location}/endpoints/openapi" - api_key = credentials.token return LLM( model=model, - api_key=api_key, - base_url=base_url, client=client, user=user, temperature=temperature, diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/models.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/models.py index bc4d7d1c1..c2667665d 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/models.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/models.py @@ -79,18 +79,14 @@ "deepseek-chat", ] -GeminiModels = Literal[ - "gemini-1.0-pro", - "gemini-1.0-pro-vision", - "gemini-1.0-pro-vision-001", - "gemini-1.5-flash", - "gemini-1.5-flash-002", - "gemini-1.5-flash-8b", - "gemini-1.5-flash-preview-0514", - "gemini-1.5-pro", - "gemini-1.5-pro-002", - "gemini-1.5-pro-preview-0409", - "gemini-1.5-pro-preview-0514", +VertexModels = Literal[ + "google/gemini-1.5-flash", + "google/gemini-1.5-pro", + "google/gemini-1.0-pro-vision", + "google/gemini-1.0-pro-vision-001", + "google/gemini-1.0-pro-002", + "google/gemini-1.0-pro-001", + "google/gemini-1.0-pro", ] TogetherChatModels = Literal[ From 05beea17f484418b9227f7908b06f649bd16c0ec Mon Sep 17 00:00:00 2001 From: jayesh Date: Fri, 15 Nov 2024 05:32:59 +0530 Subject: [PATCH 7/7] type fix --- .../livekit-plugins-openai/livekit/plugins/openai/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py index 17b1d2b56..25382490f 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py @@ -168,7 +168,7 @@ def with_vertexai( model: str | VertexModels = "google/gemini-1.5-pro", project_id: str | None = None, location: str | None = "us-central1", - client: openai.AsyncClient | None = None, + client: Any | None = None, user: str | None = None, temperature: float | None = None, ) -> LLM: