Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: expose multimodal agent metrics #1080

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changeset/tough-boats-appear.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"livekit-plugins-openai": patch
"livekit-agents": patch
---

Expose multimodal agent metrics
4 changes: 4 additions & 0 deletions livekit-agents/livekit/agents/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from .base import (
AgentMetrics,
LLMMetrics,
MultimodalLLMError,
MultimodalLLMMetrics,
PipelineEOUMetrics,
PipelineLLMMetrics,
PipelineSTTMetrics,
Expand All @@ -15,6 +17,8 @@

__all__ = [
"LLMMetrics",
"MultimodalLLMError",
"MultimodalLLMMetrics",
"AgentMetrics",
"PipelineEOUMetrics",
"PipelineSTTMetrics",
Expand Down
26 changes: 26 additions & 0 deletions livekit-agents/livekit/agents/metrics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,31 @@ class PipelineVADMetrics(VADMetrics):
pass


@dataclass
class MultimodalLLMError(Error):
type: str | None
reason: str | None = None
code: str | None = None
message: str | None = None


@dataclass
class MultimodalLLMMetrics(LLMMetrics):
@dataclass
class InputTokenDetails:
cached_tokens: int
text_tokens: int
audio_tokens: int

@dataclass
class OutputTokenDetails:
text_tokens: int
audio_tokens: int

input_token_details: InputTokenDetails
output_token_details: OutputTokenDetails


AgentMetrics = Union[
STTMetrics,
LLMMetrics,
Expand All @@ -108,4 +133,5 @@ class PipelineVADMetrics(VADMetrics):
PipelineLLMMetrics,
PipelineTTSMetrics,
PipelineVADMetrics,
MultimodalLLMMetrics,
]
6 changes: 6 additions & 0 deletions livekit-agents/livekit/agents/multimodal/multimodal_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from livekit import rtc
from livekit.agents import llm, stt, tokenize, transcription, utils, vad
from livekit.agents.llm import ChatMessage
from livekit.agents.metrics import MultimodalLLMMetrics

from .._constants import ATTRIBUTE_AGENT_STATE
from .._types import AgentState
Expand All @@ -24,6 +25,7 @@
"agent_speech_interrupted",
"function_calls_collected",
"function_calls_finished",
"metrics_collected",
]


Expand Down Expand Up @@ -240,6 +242,10 @@ def _function_calls_collected(fnc_call_infos: list[llm.FunctionCallInfo]):
def _function_calls_finished(called_fncs: list[llm.CalledFunction]):
self.emit("function_calls_finished", called_fncs)

@self._session.on("metrics_collected")
def _metrics_collected(metrics: MultimodalLLMMetrics):
self.emit("metrics_collected", metrics)

def _update_state(self, state: AgentState, delay: float = 0.0):
"""Set the current state of the agent"""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import base64
import os
import time
from copy import deepcopy
from dataclasses import dataclass
from typing import AsyncIterable, Callable, Literal, Union, cast, overload
Expand All @@ -12,6 +13,7 @@
from livekit import rtc
from livekit.agents import llm, utils
from livekit.agents.llm import _oai_api
from livekit.agents.metrics import MultimodalLLMError, MultimodalLLMMetrics
from typing_extensions import TypedDict

from . import api_proto, remote_items
Expand All @@ -33,6 +35,7 @@
"response_done",
"function_calls_collected",
"function_calls_finished",
"metrics_collected",
]


Expand Down Expand Up @@ -66,6 +69,10 @@ class RealtimeResponse:
"""usage of the response"""
done_fut: asyncio.Future[None]
"""future that will be set when the response is completed"""
_created_timestamp: float
"""timestamp when the response was created"""
_first_token_timestamp: float | None = None
"""timestamp when the first token was received"""


@dataclass
Expand Down Expand Up @@ -703,6 +710,7 @@ def __init__(
loop: asyncio.AbstractEventLoop,
) -> None:
super().__init__()
self._label = f"{type(self).__module__}.{type(self).__name__}"
self._main_atask = asyncio.create_task(
self._main_task(), name="openai-realtime-session"
)
Expand Down Expand Up @@ -1210,6 +1218,7 @@ def _handle_response_created(
output=[],
usage=response.get("usage"),
done_fut=done_fut,
_created_timestamp=time.time(),
)
self._pending_responses[new_response.id] = new_response
self.emit("response_created", new_response)
Expand Down Expand Up @@ -1264,6 +1273,7 @@ def _handle_response_content_part_added(
content_type=content_type,
)
output.content.append(new_content)
response._first_token_timestamp = time.time()
self.emit("response_content_added", new_content)

def _handle_response_audio_delta(
Expand Down Expand Up @@ -1368,15 +1378,19 @@ def _handle_response_done(self, response_done: api_proto.ServerEvent.ResponseDon
response.status_details = response_data.get("status_details")
response.usage = response_data.get("usage")

metrics_error = None
cancelled = False
if response.status == "failed":
assert response.status_details is not None

error = response.status_details.get("error")
code: str | None = None
message: str | None = None
if error is not None:
code = error.get("code") # type: ignore
message = error.get("message") # type: ignore
error = response.status_details.get("error", {})
code: str | None = error.get("code") # type: ignore
message: str | None = error.get("message") # type: ignore
metrics_error = MultimodalLLMError(
type=response.status_details.get("type"),
code=code,
message=message,
)

logger.error(
"response generation failed",
Expand All @@ -1386,13 +1400,57 @@ def _handle_response_done(self, response_done: api_proto.ServerEvent.ResponseDon
assert response.status_details is not None
reason = response.status_details.get("reason")

metrics_error = MultimodalLLMError(
type=response.status_details.get("type"),
reason=reason, # type: ignore
)

logger.warning(
"response generation incomplete",
extra={"reason": reason, **self.logging_extra()},
)
elif response.status == "cancelled":
cancelled = True

self.emit("response_done", response)

# calculate metrics
ttft = -1.0
if response._first_token_timestamp is not None:
ttft = response._first_token_timestamp - response._created_timestamp
duration = time.time() - response._created_timestamp

usage = response.usage or {} # type: ignore
metrics = MultimodalLLMMetrics(
timestamp=response._created_timestamp,
request_id=response.id,
ttft=ttft,
duration=duration,
cancelled=cancelled,
label=self._label,
completion_tokens=usage.get("output_tokens", 0),
prompt_tokens=usage.get("input_tokens", 0),
total_tokens=usage.get("total_tokens", 0),
tokens_per_second=usage.get("output_tokens", 0) / duration,
error=metrics_error,
input_token_details=MultimodalLLMMetrics.InputTokenDetails(
cached_tokens=usage.get("input_token_details", {}).get(
"cached_tokens", 0
),
text_tokens=usage.get("input_token_details", {}).get("text_tokens", 0),
audio_tokens=usage.get("input_token_details", {}).get(
"audio_tokens", 0
),
),
output_token_details=MultimodalLLMMetrics.OutputTokenDetails(
text_tokens=usage.get("output_token_details", {}).get("text_tokens", 0),
audio_tokens=usage.get("output_token_details", {}).get(
"audio_tokens", 0
),
),
)
self.emit("metrics_collected", metrics)

def _get_content(self, ptr: _ContentPtr) -> RealtimeContent:
response = self._pending_responses[ptr["response_id"]]
output = response.output[ptr["output_index"]]
Expand Down
Loading