fix(vertex): async / streaming was missing output fields (#2253)
nirga authored Nov 5, 2024
1 parent 2e1598c commit 97584b0
Showing 2 changed files with 53 additions and 50 deletions.
@@ -32,60 +32,70 @@
"object": "GenerativeModel",
"method": "generate_content",
"span_name": "vertexai.generate_content",
"is_async": False,
},
{
"package": "vertexai.generative_models",
"object": "GenerativeModel",
"method": "generate_content_async",
"span_name": "vertexai.generate_content_async",
"is_async": True,
},
{
"package": "vertexai.preview.generative_models",
"object": "GenerativeModel",
"method": "generate_content",
"span_name": "vertexai.generate_content",
"is_async": False,
},
{
"package": "vertexai.preview.generative_models",
"object": "GenerativeModel",
"method": "generate_content_async",
"span_name": "vertexai.generate_content_async",
"is_async": True,
},
{
"package": "vertexai.language_models",
"object": "TextGenerationModel",
"method": "predict",
"span_name": "vertexai.predict",
"is_async": False,
},
{
"package": "vertexai.language_models",
"object": "TextGenerationModel",
"method": "predict_async",
"span_name": "vertexai.predict_async",
"is_async": True,
},
{
"package": "vertexai.language_models",
"object": "TextGenerationModel",
"method": "predict_streaming",
"span_name": "vertexai.predict_streaming",
"is_async": False,
},
{
"package": "vertexai.language_models",
"object": "TextGenerationModel",
"method": "predict_streaming_async",
"span_name": "vertexai.predict_streaming_async",
"is_async": True,
},
{
"package": "vertexai.language_models",
"object": "ChatSession",
"method": "send_message",
"span_name": "vertexai.send_message",
"is_async": False,
},
{
"package": "vertexai.language_models",
"object": "ChatSession",
"method": "send_message_streaming",
"span_name": "vertexai.send_message_streaming",
"is_async": False,
},
]
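
The added "is_async" flag is what the instrumentor keys on when choosing between the sync and async wrappers (see the _instrument hunk further down). A minimal, runnable sketch of that dispatch, with stand-in wrapper functions and a trimmed method list in place of the real _wrap/_awrap helpers and the full list above:

def _sync_wrapper(method_config):
    return f"sync wrapper for {method_config['span_name']}"

def _async_wrapper(method_config):
    return f"async wrapper for {method_config['span_name']}"

methods = [
    {"method": "generate_content", "span_name": "vertexai.generate_content", "is_async": False},
    {"method": "generate_content_async", "span_name": "vertexai.generate_content_async", "is_async": True},
]

for m in methods:
    # the flag, not the method name, decides which wrapper is installed
    wrapper = _async_wrapper(m) if m.get("is_async") else _sync_wrapper(m)
    print(wrapper)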

@@ -150,73 +160,63 @@ def _set_input_attributes(span, args, kwargs, llm_model):


@dont_throw
def _set_response_attributes(span, response, llm_model):
def _set_response_attributes(span, llm_model, generation_text, token_usage):
_set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, llm_model)

if hasattr(response, "text"):
if hasattr(response, "_raw_response") and hasattr(
response._raw_response, "usage_metadata"
):
_set_span_attribute(
span,
SpanAttributes.LLM_USAGE_TOTAL_TOKENS,
response._raw_response.usage_metadata.total_token_count,
)
_set_span_attribute(
span,
SpanAttributes.LLM_USAGE_COMPLETION_TOKENS,
response._raw_response.usage_metadata.candidates_token_count,
)
_set_span_attribute(
span,
SpanAttributes.LLM_USAGE_PROMPT_TOKENS,
response._raw_response.usage_metadata.prompt_token_count,
)

if isinstance(response.text, list):
for index, item in enumerate(response):
prefix = f"{SpanAttributes.LLM_COMPLETIONS}.{index}"
_set_span_attribute(span, f"{prefix}.content", item.text)
elif isinstance(response.text, str):
_set_span_attribute(
span, f"{SpanAttributes.LLM_COMPLETIONS}.0.content", response.text
)
else:
if isinstance(response, list):
for index, item in enumerate(response):
prefix = f"{SpanAttributes.LLM_COMPLETIONS}.{index}"
_set_span_attribute(span, f"{prefix}.content", item)
elif isinstance(response, str):
_set_span_attribute(
span, f"{SpanAttributes.LLM_COMPLETIONS}.0.content", response
)
if token_usage:
_set_span_attribute(
span,
SpanAttributes.LLM_USAGE_TOTAL_TOKENS,
token_usage.total_token_count,
)
_set_span_attribute(
span,
SpanAttributes.LLM_USAGE_COMPLETION_TOKENS,
token_usage.candidates_token_count,
)
_set_span_attribute(
span,
SpanAttributes.LLM_USAGE_PROMPT_TOKENS,
token_usage.prompt_token_count,
)

return
_set_span_attribute(span, f"{SpanAttributes.LLM_COMPLETIONS}.0.role", "assistant")
_set_span_attribute(
span,
f"{SpanAttributes.LLM_COMPLETIONS}.0.content",
generation_text,
)
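
The new signature is what lets the streaming paths report output fields at all: instead of receiving a whole response object and branching on response.text, the helper now takes the already-extracted generation text and token usage, so the non-streaming handler and both streaming builders can feed it the same way. A small illustration with a stub standing in for the real helper (the argument order mirrors the diff; the values are made up):

def record(span, llm_model, generation_text, token_usage):
    # stand-in for _set_response_attributes: just show what each path passes
    print(llm_model, repr(generation_text), token_usage)

usage = {"prompt_token_count": 12, "candidates_token_count": 3, "total_token_count": 15}

# non-streaming: response.candidates[0].text and response.usage_metadata
record("span", "gemini-pro", "Hello!", usage)
# streaming: text accumulated chunk by chunk, usage from the last chunk that carried it
record("span", "gemini-pro", "Hel" + "lo!", usage)
# streaming where no chunk reported usage: the usage block is simply skipped
record("span", "gemini-pro", "Hello!", None)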


def _build_from_streaming_response(span, response, llm_model):
complete_response = ""
token_usage = None
for item in response:
item_to_yield = item
complete_response += str(item.text)
if item.usage_metadata:
token_usage = item.usage_metadata

yield item_to_yield

_set_response_attributes(span, complete_response, llm_model)
_set_response_attributes(span, llm_model, complete_response, token_usage)

span.set_status(Status(StatusCode.OK))
span.end()


async def _abuild_from_streaming_response(span, response, llm_model):
complete_response = ""
token_usage = None
async for item in response:
item_to_yield = item
complete_response += str(item.text)
if item.usage_metadata:
token_usage = item.usage_metadata

yield item_to_yield

_set_response_attributes(span, complete_response, llm_model)
_set_response_attributes(span, llm_model, complete_response, token_usage)

span.set_status(Status(StatusCode.OK))
span.end()
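
Both streaming builders follow the same accumulate-then-record pattern: every chunk is passed through to the caller untouched, its text is appended to a running string, and the last usage_metadata seen is kept. A self-contained sketch of that pattern, with a fake chunk type in place of the Vertex streaming response:

class FakeChunk:
    def __init__(self, text, usage_metadata=None):
        self.text = text
        self.usage_metadata = usage_metadata

def passthrough_with_capture(chunks):
    complete_response = ""
    token_usage = None
    for chunk in chunks:
        complete_response += str(chunk.text)
        if chunk.usage_metadata:
            token_usage = chunk.usage_metadata
        yield chunk  # the caller still sees every chunk unchanged
    # everything needed for _set_response_attributes is now in hand
    print(complete_response, token_usage)

stream = [FakeChunk("Hel"), FakeChunk("lo!", usage_metadata={"total_token_count": 5})]
for _ in passthrough_with_capture(stream):
    pass
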
@@ -231,7 +231,9 @@ def _handle_request(span, args, kwargs, llm_model):
@dont_throw
def _handle_response(span, response, llm_model):
if span.is_recording():
_set_response_attributes(span, response, llm_model)
_set_response_attributes(
span, llm_model, response.candidates[0].text, response.usage_metadata
)

span.set_status(Status(StatusCode.OK))
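
On the non-streaming path, the generated text now comes from the first candidate and the usage block straight off the response. A toy illustration of those field accesses using a stand-in object rather than the real Vertex GenerationResponse:

from types import SimpleNamespace

response = SimpleNamespace(
    candidates=[SimpleNamespace(text="Hello!")],
    usage_metadata=SimpleNamespace(
        prompt_token_count=12, candidates_token_count=3, total_token_count=15
    ),
)
print(response.candidates[0].text, response.usage_metadata.total_token_count)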

@@ -351,7 +353,7 @@ def _instrument(self, **kwargs):
f"{wrap_object}.{wrap_method}",
(
_awrap(tracer, wrapped_method)
if wrap_method == "predict_async"
if wrapped_method.get("is_async")
else _wrap(tracer, wrapped_method)
),
)
15 changes: 8 additions & 7 deletions packages/sample-app/sample_app/vertexai_streaming.py
@@ -1,15 +1,16 @@
import asyncio
import vertexai
from traceloop.sdk import Traceloop
from traceloop.sdk.decorators import workflow
from traceloop.sdk.decorators import aworkflow
from vertexai.generative_models import GenerativeModel

Traceloop.init(app_name="stream_prediction_service")

vertexai.init()


@workflow("stream_prediction")
def streaming_prediction() -> str:
@aworkflow("stream_prediction")
async def streaming_prediction() -> str:
"""Streaming Text Example with a Large Language Model"""

model = GenerativeModel(
@@ -27,10 +28,10 @@ def streaming_prediction() -> str:

contents = [prompt]

response = model.generate_content(contents)

return response.text
response = await model.generate_content_async(contents, stream=True)
async for chunk in response:
print(chunk.text)


if __name__ == "__main__":
print(streaming_prediction())
asyncio.run(streaming_prediction())
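
For comparison, a synchronous streaming variant of this sample (not part of the commit) could look roughly like the sketch below, assuming the same Traceloop.init() / vertexai.init() setup as above; the model name is a placeholder:

from traceloop.sdk.decorators import workflow
from vertexai.generative_models import GenerativeModel


@workflow("stream_prediction_sync")
def streaming_prediction_sync() -> None:
    """Synchronous streaming counterpart, sketched for comparison only."""
    model = GenerativeModel("gemini-pro")  # placeholder model name
    response = model.generate_content(["Tell me a short joke."], stream=True)
    for chunk in response:
        print(chunk.text)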
