Skip to content

Commit

Permalink
[FEAT] TaskWeaver Integration (#541)
Browse files Browse the repository at this point in the history
* taskweaver basic tracking

* linting

* some logging occurs

* more debug info to understand llm info flow

* saving model info now in `LLMEvent`

* get service mappings from taskweaver

* remove taskweaver code from agentops init

* fix incorrect use of agent_id in events

* clean and refactor code for taskweaver LLM tracking

* convert `LLMEvent` to `ActionEvent`

* improved event handling

* add `ActionEvent` for recording `json_schema`

* linting

* add microsoft and taskweaver logos

* add default tags `taskweaver`

* cast message as string

* add session image

* add documentation for taskweaver

* linting

* overhauled handler code

* linting

* use correct logger import

* stutter fix

* add warning for stutter

---------

Co-authored-by: teocns <[email protected]>
  • Loading branch information
the-praxs and teocns authored Dec 23, 2024
1 parent f844e0a commit 5d4ff2f
Show file tree
Hide file tree
Showing 9 changed files with 472 additions and 0 deletions.
146 changes: 146 additions & 0 deletions agentops/llms/providers/taskweaver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import pprint
from typing import Optional
import json

from agentops.event import ErrorEvent, LLMEvent, ActionEvent
from agentops.session import Session
from agentops.log_config import logger
from agentops.helpers import get_ISO_time, check_call_stack_for_agent_id
from agentops.llms.providers.instrumented_provider import InstrumentedProvider
from agentops.singleton import singleton


@singleton
class TaskWeaverProvider(InstrumentedProvider):
original_chat_completion = None

def __init__(self, client):
super().__init__(client)
self._provider_name = "TaskWeaver"
self.client.add_default_tags(["taskweaver"])

def handle_response(self, response, kwargs, init_timestamp, session: Optional[Session] = None) -> dict:
"""Handle responses for TaskWeaver"""
llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs)
action_event = ActionEvent(init_timestamp=init_timestamp)

try:
response_dict = response.get("response", {})

action_event.params = kwargs.get("json_schema", None)
action_event.returns = response_dict
action_event.end_timestamp = get_ISO_time()
self._safe_record(session, action_event)
except Exception as e:
error_event = ErrorEvent(
trigger_event=action_event, exception=e, details={"response": str(response), "kwargs": str(kwargs)}
)
self._safe_record(session, error_event)
kwargs_str = pprint.pformat(kwargs)
response_str = pprint.pformat(response)
logger.error(
f"Unable to parse response for Action call. Skipping upload to AgentOps\n"
f"response:\n {response_str}\n"
f"kwargs:\n {kwargs_str}\n"
)

try:
llm_event.init_timestamp = init_timestamp
llm_event.params = kwargs
llm_event.returns = response_dict
llm_event.agent_id = check_call_stack_for_agent_id()
llm_event.model = kwargs.get("model", "unknown")
llm_event.prompt = kwargs.get("messages")
llm_event.completion = response_dict.get("message", "")
llm_event.end_timestamp = get_ISO_time()
self._safe_record(session, llm_event)

except Exception as e:
error_event = ErrorEvent(
trigger_event=llm_event, exception=e, details={"response": str(response), "kwargs": str(kwargs)}
)
self._safe_record(session, error_event)
kwargs_str = pprint.pformat(kwargs)
response_str = pprint.pformat(response)
logger.error(
f"Unable to parse response for LLM call. Skipping upload to AgentOps\n"
f"response:\n {response_str}\n"
f"kwargs:\n {kwargs_str}\n"
)

return response

def override(self):
"""Override TaskWeaver's chat completion methods"""
try:
from taskweaver.llm import llm_completion_config_map

def create_patched_chat_completion(original_method):
"""Create a new patched chat_completion function with bound original method"""

def patched_chat_completion(service, *args, **kwargs):
init_timestamp = get_ISO_time()
session = kwargs.get("session", None)
if "session" in kwargs.keys():
del kwargs["session"]

result = original_method(service, *args, **kwargs)
kwargs.update(
{
"model": self._get_model_name(service),
"messages": args[0],
"stream": args[1],
"temperature": args[2],
"max_tokens": args[3],
"top_p": args[4],
"stop": args[5],
}
)

if kwargs["stream"]:
accumulated_content = ""
for chunk in result:
if isinstance(chunk, dict) and "content" in chunk:
accumulated_content += chunk["content"]
else:
accumulated_content += chunk
yield chunk
accumulated_content = json.loads(accumulated_content)
return self.handle_response(accumulated_content, kwargs, init_timestamp, session=session)
else:
return self.handle_response(result, kwargs, init_timestamp, session=session)

return patched_chat_completion

for service_name, service_class in llm_completion_config_map.items():
if not hasattr(service_class, "_original_chat_completion"):
service_class._original_chat_completion = service_class.chat_completion
service_class.chat_completion = create_patched_chat_completion(
service_class._original_chat_completion
)

except Exception as e:
logger.error(f"Failed to patch method: {str(e)}", exc_info=True)

def undo_override(self):
"""Restore original TaskWeaver chat completion methods"""
try:
from taskweaver.llm import llm_completion_config_map

for service_name, service_class in llm_completion_config_map.items():
service_class.chat_completion = service_class._original_chat_completion
delattr(service_class, "_original_chat_completion")

except Exception as e:
logger.error(f"Failed to restore original method: {str(e)}", exc_info=True)

def _get_model_name(self, service) -> str:
"""Extract model name from service instance"""
model_name = "unknown"
if hasattr(service, "config"):
config = service.config
if hasattr(config, "model"):
model_name = config.model or "unknown"
elif hasattr(config, "llm_module_config") and hasattr(config.llm_module_config, "model"):
model_name = config.llm_module_config.model or "unknown"
return model_name
14 changes: 14 additions & 0 deletions agentops/llms/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .providers.anthropic import AnthropicProvider
from .providers.mistral import MistralProvider
from .providers.ai21 import AI21Provider
from .providers.taskweaver import TaskWeaverProvider

original_func = {}
original_create = None
Expand Down Expand Up @@ -54,6 +55,9 @@ class LlmTracker:
"client.answer.create",
),
},
"taskweaver": {
"0.0.1": ("chat_completion", "chat_completion_stream"),
},
}

def __init__(self, client):
Expand Down Expand Up @@ -164,6 +168,15 @@ def override_api(self):
else:
logger.warning(f"Only LlamaStackClient>=0.0.53 supported. v{module_version} found.")

if api == "taskweaver":
module_version = version(api)

if Version(module_version) >= parse("0.0.1"):
provider = TaskWeaverProvider(self.client)
provider.override()
else:
logger.warning(f"Only TaskWeaver>=0.0.1 supported. v{module_version} found.")

def stop_instrumenting(self):
OpenAiProvider(self.client).undo_override()
GroqProvider(self.client).undo_override()
Expand All @@ -174,3 +187,4 @@ def stop_instrumenting(self):
MistralProvider(self.client).undo_override()
AI21Provider(self.client).undo_override()
LlamaStackClientProvider(self.client).undo_override()
TaskWeaverProvider(self.client).undo_override()
191 changes: 191 additions & 0 deletions agentops/partners/taskweaver_event_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
from taskweaver.module.event_emitter import (
SessionEventHandlerBase,
SessionEventType,
RoundEventType,
PostEventType,
)
import agentops
from agentops.event import ActionEvent, ErrorEvent, ToolEvent
from datetime import datetime, timezone
from typing import Dict, Any
from agentops.log_config import logger

ATTACHMENT_TOOLS = [
"thought",
"reply_type",
"reply_content",
"verification",
"code_error",
"execution_status",
"execution_result",
"artifact_paths",
"revise_message",
"function",
"web_exploring_plan",
"web_exploring_screenshot",
"web_exploring_link",
]


class TaskWeaverEventHandler(SessionEventHandlerBase):
def __init__(self):
super().__init__()
self._message_buffer: Dict[str, Dict[str, Any]] = {}
self._attachment_buffer: Dict[str, Dict[str, Any]] = {}
self._active_agents: Dict[str, str] = {}

def _get_or_create_agent(self, role: str):
"""Get existing agent ID or create new agent for role+round combination"""
if role not in self._active_agents:
agent_id = agentops.create_agent(name=role)
if agent_id:
self._active_agents[role] = agent_id
return self._active_agents.get(role)

def handle_session(self, type: SessionEventType, msg: str, extra: Any, **kwargs: Any):
agentops.record(ActionEvent(action_type=type.value, params={"extra": extra, "message": msg}))

def handle_round(self, type: RoundEventType, msg: str, extra: Any, round_id: str, **kwargs: Any):
if type == RoundEventType.round_error:
agentops.record(
ErrorEvent(error_type=type.value, details={"round_id": round_id, "message": msg, "extra": extra})
)
logger.error(f"Could not record the Round event: {msg}")
self.cleanup_round()
else:
agentops.record(
ActionEvent(
action_type=type.value,
params={"round_id": round_id, "extra": extra},
returns=msg,
)
)
if type == RoundEventType.round_end:
self.cleanup_round()

def handle_post(self, type: PostEventType, msg: str, extra: Any, post_id: str, round_id: str, **kwargs: Any):
role = extra.get("role", "Planner")
agent_id = self._get_or_create_agent(role=role)

if type == PostEventType.post_error:
agentops.record(
ErrorEvent(
error_type=type.value,
details={"post_id": post_id, "round_id": round_id, "message": msg, "extra": extra},
)
)
logger.error(f"Could not record the Post event: {msg}")

elif type == PostEventType.post_start or type == PostEventType.post_end:
agentops.record(
ActionEvent(
action_type=type.value,
params={"post_id": post_id, "round_id": round_id, "extra": extra},
returns=msg,
agent_id=agent_id,
)
)

elif type == PostEventType.post_status_update:
agentops.record(
ActionEvent(
action_type=type.value,
params={"post_id": post_id, "round_id": round_id, "extra": extra},
returns=msg,
agent_id=agent_id,
)
)

elif type == PostEventType.post_attachment_update:
attachment_id = extra["id"]
attachment_type = extra["type"].value
is_end = extra["is_end"]

if attachment_id not in self._attachment_buffer:
self._attachment_buffer[attachment_id] = {
"role": attachment_type,
"content": [],
"init_timestamp": datetime.now(timezone.utc).isoformat(),
"end_timestamp": None,
}

self._attachment_buffer[attachment_id]["content"].append(str(msg))

if is_end:
self._attachment_buffer[attachment_id]["end_timestamp"] = datetime.now(timezone.utc).isoformat()
complete_message = "".join(self._attachment_buffer[attachment_id]["content"])

if attachment_type in ATTACHMENT_TOOLS:
agentops.record(
ToolEvent(
name=type.value,
init_timestamp=self._attachment_buffer[attachment_id]["init_timestamp"],
end_timestamp=self._attachment_buffer[attachment_id]["end_timestamp"],
params={
"post_id": post_id,
"round_id": round_id,
"attachment_id": attachment_id,
"attachment_type": self._attachment_buffer[attachment_id]["role"],
"extra": extra,
},
returns=complete_message,
agent_id=agent_id,
)
)
else:
agentops.record(
ActionEvent(
action_type=type.value,
init_timestamp=self._attachment_buffer[attachment_id]["init_timestamp"],
end_timestamp=self._attachment_buffer[attachment_id]["end_timestamp"],
params={
"post_id": post_id,
"round_id": round_id,
"attachment_id": attachment_id,
"attachment_type": self._attachment_buffer[attachment_id]["role"],
"extra": extra,
},
returns=complete_message,
agent_id=agent_id,
)
)

self._attachment_buffer.pop(attachment_id, None)

elif type == PostEventType.post_message_update:
is_end = extra["is_end"]

if post_id not in self._message_buffer:
self._message_buffer[post_id] = {
"content": [],
"init_timestamp": datetime.now(timezone.utc).isoformat(),
"end_timestamp": None,
}

self._message_buffer[post_id]["content"].append(str(msg))

if is_end:
self._message_buffer[post_id]["end_timestamp"] = datetime.now(timezone.utc).isoformat()
complete_message = "".join(self._message_buffer[post_id]["content"])
agentops.record(
ActionEvent(
action_type=type.value,
init_timestamp=self._message_buffer[post_id]["init_timestamp"],
end_timestamp=self._message_buffer[post_id]["end_timestamp"],
params={
"post_id": post_id,
"round_id": round_id,
"extra": extra,
},
returns=complete_message,
agent_id=agent_id,
)
)

self._message_buffer.pop(post_id, None)

def cleanup_round(self):
"""Cleanup agents and buffers for a completed round"""
self._active_agents.clear()
self._message_buffer.clear()
self._attachment_buffer.clear()
1 change: 1 addition & 0 deletions docs/images/external/microsoft/microsoft_logo.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 5d4ff2f

Please sign in to comment.