diff --git a/.changeset/witty-months-train.md b/.changeset/witty-months-train.md new file mode 100644 index 000000000..8191b36c8 --- /dev/null +++ b/.changeset/witty-months-train.md @@ -0,0 +1,5 @@ +--- +"livekit-plugins-openai": patch +--- + +vertex ai support with openai library diff --git a/examples/voice-pipeline-agent/function_calling_weather.py b/examples/voice-pipeline-agent/function_calling_weather.py index f5bc3135b..0eb3da843 100644 --- a/examples/voice-pipeline-agent/function_calling_weather.py +++ b/examples/voice-pipeline-agent/function_calling_weather.py @@ -67,6 +67,8 @@ async def entrypoint(ctx: JobContext): vad=ctx.proc.userdata["vad"], stt=deepgram.STT(), llm=openai.LLM(), + # To use vertex AI LLM + # llm=openai.LLM.with_vertexai(project_id="your-project-id", location="us-central1"), tts=openai.TTS(), fnc_ctx=fnc_ctx, chat_ctx=initial_chat_ctx, diff --git a/examples/voice-pipeline-agent/minimal_assistant.py b/examples/voice-pipeline-agent/minimal_assistant.py index 4b94bd5b7..ebc525661 100644 --- a/examples/voice-pipeline-agent/minimal_assistant.py +++ b/examples/voice-pipeline-agent/minimal_assistant.py @@ -48,6 +48,8 @@ async def entrypoint(ctx: JobContext): vad=ctx.proc.userdata["vad"], stt=deepgram.STT(model=dg_model), llm=openai.LLM(), + # To use vertex AI LLM + # llm=openai.LLM.with_vertexai(project_id="your-project-id", location="us-central1"), tts=openai.TTS(), chat_ctx=initial_ctx, ) diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py index 56f1598fd..25382490f 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py @@ -42,6 +42,7 @@ PerplexityChatModels, TelnyxChatModels, TogetherChatModels, + VertexModels, XAIChatModels, ) from .utils import AsyncAzureADTokenProvider, build_oai_message @@ -161,6 +162,85 @@ def with_cerebras( 
@staticmethod
def with_vertexai(
    *,
    model: str | VertexModels = "google/gemini-1.5-pro",
    project_id: str | None = None,
    location: str | None = "us-central1",
    client: Any | None = None,
    user: str | None = None,
    temperature: float | None = None,
) -> LLM:
    """
    Create a new instance of VertexAI LLM.

    `project_id` must be set to your VERTEXAI PROJECT ID, either using the argument or by setting
    the `VERTEXAI_PROJECT_ID` environmental variable.

    Args:
        model: Model identifier forwarded to `LLM`.
        project_id: Vertex AI project id. Falls back to `VERTEXAI_PROJECT_ID`.
        location: Vertex AI region. Falls back to `VERTEXAI_LOCATION`, but because
            the default is "us-central1" the environment variable is only consulted
            when `location=None` (or an empty string) is passed explicitly.
        client: Optional pre-built client. When given it is passed to `LLM` as-is;
            otherwise a credential-refreshing proxy around `openai.AsyncClient`
            is created.
        user: Forwarded to `LLM`.
        temperature: Forwarded to `LLM`.

    Returns:
        An `LLM` bound to the Vertex AI OpenAI-compatible endpoint.

    Raises:
        ValueError: If the project id, the location, or the
            `GOOGLE_APPLICATION_CREDENTIALS` environment variable is missing or empty.
        ImportError: If a client has to be built and the Google Auth packages are
            not installed.
    """

    # Treat empty strings like "unset": an empty value would otherwise slip past
    # the check and be interpolated into the endpoint URL below.
    project_id = project_id or os.environ.get("VERTEXAI_PROJECT_ID")
    if not project_id:
        raise ValueError(
            "VERTEXAI_PROJECT_ID is required, either set project_id argument or set VERTEXAI_PROJECT_ID environmental variable"
        )

    location = location or os.environ.get("VERTEXAI_LOCATION")
    if not location:
        raise ValueError(
            "VERTEXAI_LOCATION is required, either set location argument or set VERTEXAI_LOCATION environmental variable"
        )

    _gac = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
    if not _gac:
        raise ValueError(
            "`GOOGLE_APPLICATION_CREDENTIALS` environment variable is not set. please set it to the path of the service account key file."
        )

    # Only build a client when the caller did not supply one; previously the
    # `client` argument was accepted and then unconditionally overwritten.
    if client is None:
        try:
            import google.auth
            # NOTE(review): this transport module is built on the `requests`
            # package -- verify the `vertex` extra actually installs it.
            import google.auth.transport.requests
        except ImportError as e:
            raise ImportError(
                "Google Auth dependencies not found. Please install with: `pip install livekit-plugins-openai[vertex]`"
            ) from e

        class OpenAICredentialsRefresher:
            """Proxy for `openai.AsyncClient` that installs a valid Google token
            as the API key every time an attribute is looked up on it."""

            def __init__(self, **kwargs: Any) -> None:
                # The real key is injected lazily in __getattr__.
                self.client = openai.AsyncClient(**kwargs, api_key="DUMMY")
                self.creds, self.project = google.auth.default(
                    scopes=["https://www.googleapis.com/auth/cloud-platform"]
                )

            def __getattr__(self, name: str) -> Any:
                # __getattr__ runs only when normal lookup fails. If one of our
                # own attributes is missing (partially initialised instance),
                # fail fast instead of recursing through self.creds / self.client.
                if name in ("client", "creds", "project"):
                    raise AttributeError(name)

                if not self.creds.valid:
                    # NOTE(review): refresh() is a synchronous call made from
                    # whatever context touches the client -- confirm that
                    # blocking here is acceptable.
                    auth_req = google.auth.transport.requests.Request()
                    self.creds.refresh(auth_req)

                    if not self.creds.valid:
                        raise RuntimeError("Unable to refresh auth")

                self.client.api_key = self.creds.token
                return getattr(self.client, name)

        client = OpenAICredentialsRefresher(
            base_url=f"https://{location}-aiplatform.googleapis.com/v1beta1/projects/{project_id}/locations/{location}/endpoints/openapi",
            http_client=httpx.AsyncClient(
                timeout=httpx.Timeout(connect=15.0, read=5.0, write=5.0, pool=5.0),
                follow_redirects=True,
                limits=httpx.Limits(
                    max_connections=50,
                    max_keepalive_connections=50,
                    keepalive_expiry=120,
                ),
            ),
        )

    return LLM(
        model=model,
        client=client,
        user=user,
        temperature=temperature,
    )
self._try_build_function(id, choice) if tool.function.name: + self._tool_index = tool.index self._tool_call_id = tool.id self._fnc_name = tool.function.name self._fnc_raw_arguments = tool.function.arguments or "" diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/models.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/models.py index 9e2e5dd18..c2667665d 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/models.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/models.py @@ -79,6 +79,16 @@ "deepseek-chat", ] +VertexModels = Literal[ + "google/gemini-1.5-flash", + "google/gemini-1.5-pro", + "google/gemini-1.0-pro-vision", + "google/gemini-1.0-pro-vision-001", + "google/gemini-1.0-pro-002", + "google/gemini-1.0-pro-001", + "google/gemini-1.0-pro", +] + TogetherChatModels = Literal[ "Austism/chronos-hermes-13b", "Gryphe/MythoMax-L2-13b", diff --git a/livekit-plugins/livekit-plugins-openai/setup.py b/livekit-plugins/livekit-plugins-openai/setup.py index d71123011..95449d4b1 100644 --- a/livekit-plugins/livekit-plugins-openai/setup.py +++ b/livekit-plugins/livekit-plugins-openai/setup.py @@ -47,7 +47,13 @@ license="Apache-2.0", packages=setuptools.find_namespace_packages(include=["livekit.*"]), python_requires=">=3.9.0", - install_requires=["livekit-agents[codecs, images]>=0.11", "openai>=1.50"], + install_requires=[ + "livekit-agents[codecs, images]>=0.11", + "openai>=1.50", + ], + extras_require={ + "vertex": ["google-auth"], + }, package_data={"livekit.plugins.openai": ["py.typed"]}, project_urls={ "Documentation": "https://docs.livekit.io",