-
Notifications
You must be signed in to change notification settings - Fork 237
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[FEAT] TaskWeaver Integration (#541)
* taskweaver basic tracking * linting * some logging occurs * more debug info to understand llm info flow * saving model info now in `LLMEvent` * get service mappings from taskweaver * remove taskweaver code from agentops init * fix incorrect use of agent_id in events * clean and refactor code for taskweaver LLM tracking * convert `LLMEvent` to `ActionEvent` * improved event handling * add `ActionEvent` for recording `json_schema` * linting * add microsoft and taskweaver logos * add default tags `taskweaver` * cast message as string * add session image * add documentation for taskweaver * linting * overhauled handler code * linting * use correct logger import * stutter fix * add warning for stutter --------- Co-authored-by: teocns <[email protected]>
- Loading branch information
Showing
9 changed files
with
472 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
import pprint | ||
from typing import Optional | ||
import json | ||
|
||
from agentops.event import ErrorEvent, LLMEvent, ActionEvent | ||
from agentops.session import Session | ||
from agentops.log_config import logger | ||
from agentops.helpers import get_ISO_time, check_call_stack_for_agent_id | ||
from agentops.llms.providers.instrumented_provider import InstrumentedProvider | ||
from agentops.singleton import singleton | ||
|
||
|
||
@singleton | ||
class TaskWeaverProvider(InstrumentedProvider): | ||
original_chat_completion = None | ||
|
||
def __init__(self, client): | ||
super().__init__(client) | ||
self._provider_name = "TaskWeaver" | ||
self.client.add_default_tags(["taskweaver"]) | ||
|
||
def handle_response(self, response, kwargs, init_timestamp, session: Optional[Session] = None) -> dict: | ||
"""Handle responses for TaskWeaver""" | ||
llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs) | ||
action_event = ActionEvent(init_timestamp=init_timestamp) | ||
|
||
try: | ||
response_dict = response.get("response", {}) | ||
|
||
action_event.params = kwargs.get("json_schema", None) | ||
action_event.returns = response_dict | ||
action_event.end_timestamp = get_ISO_time() | ||
self._safe_record(session, action_event) | ||
except Exception as e: | ||
error_event = ErrorEvent( | ||
trigger_event=action_event, exception=e, details={"response": str(response), "kwargs": str(kwargs)} | ||
) | ||
self._safe_record(session, error_event) | ||
kwargs_str = pprint.pformat(kwargs) | ||
response_str = pprint.pformat(response) | ||
logger.error( | ||
f"Unable to parse response for Action call. Skipping upload to AgentOps\n" | ||
f"response:\n {response_str}\n" | ||
f"kwargs:\n {kwargs_str}\n" | ||
) | ||
|
||
try: | ||
llm_event.init_timestamp = init_timestamp | ||
llm_event.params = kwargs | ||
llm_event.returns = response_dict | ||
llm_event.agent_id = check_call_stack_for_agent_id() | ||
llm_event.model = kwargs.get("model", "unknown") | ||
llm_event.prompt = kwargs.get("messages") | ||
llm_event.completion = response_dict.get("message", "") | ||
llm_event.end_timestamp = get_ISO_time() | ||
self._safe_record(session, llm_event) | ||
|
||
except Exception as e: | ||
error_event = ErrorEvent( | ||
trigger_event=llm_event, exception=e, details={"response": str(response), "kwargs": str(kwargs)} | ||
) | ||
self._safe_record(session, error_event) | ||
kwargs_str = pprint.pformat(kwargs) | ||
response_str = pprint.pformat(response) | ||
logger.error( | ||
f"Unable to parse response for LLM call. Skipping upload to AgentOps\n" | ||
f"response:\n {response_str}\n" | ||
f"kwargs:\n {kwargs_str}\n" | ||
) | ||
|
||
return response | ||
|
||
def override(self): | ||
"""Override TaskWeaver's chat completion methods""" | ||
try: | ||
from taskweaver.llm import llm_completion_config_map | ||
|
||
def create_patched_chat_completion(original_method): | ||
"""Create a new patched chat_completion function with bound original method""" | ||
|
||
def patched_chat_completion(service, *args, **kwargs): | ||
init_timestamp = get_ISO_time() | ||
session = kwargs.get("session", None) | ||
if "session" in kwargs.keys(): | ||
del kwargs["session"] | ||
|
||
result = original_method(service, *args, **kwargs) | ||
kwargs.update( | ||
{ | ||
"model": self._get_model_name(service), | ||
"messages": args[0], | ||
"stream": args[1], | ||
"temperature": args[2], | ||
"max_tokens": args[3], | ||
"top_p": args[4], | ||
"stop": args[5], | ||
} | ||
) | ||
|
||
if kwargs["stream"]: | ||
accumulated_content = "" | ||
for chunk in result: | ||
if isinstance(chunk, dict) and "content" in chunk: | ||
accumulated_content += chunk["content"] | ||
else: | ||
accumulated_content += chunk | ||
yield chunk | ||
accumulated_content = json.loads(accumulated_content) | ||
return self.handle_response(accumulated_content, kwargs, init_timestamp, session=session) | ||
else: | ||
return self.handle_response(result, kwargs, init_timestamp, session=session) | ||
|
||
return patched_chat_completion | ||
|
||
for service_name, service_class in llm_completion_config_map.items(): | ||
if not hasattr(service_class, "_original_chat_completion"): | ||
service_class._original_chat_completion = service_class.chat_completion | ||
service_class.chat_completion = create_patched_chat_completion( | ||
service_class._original_chat_completion | ||
) | ||
|
||
except Exception as e: | ||
logger.error(f"Failed to patch method: {str(e)}", exc_info=True) | ||
|
||
def undo_override(self): | ||
"""Restore original TaskWeaver chat completion methods""" | ||
try: | ||
from taskweaver.llm import llm_completion_config_map | ||
|
||
for service_name, service_class in llm_completion_config_map.items(): | ||
service_class.chat_completion = service_class._original_chat_completion | ||
delattr(service_class, "_original_chat_completion") | ||
|
||
except Exception as e: | ||
logger.error(f"Failed to restore original method: {str(e)}", exc_info=True) | ||
|
||
def _get_model_name(self, service) -> str: | ||
"""Extract model name from service instance""" | ||
model_name = "unknown" | ||
if hasattr(service, "config"): | ||
config = service.config | ||
if hasattr(config, "model"): | ||
model_name = config.model or "unknown" | ||
elif hasattr(config, "llm_module_config") and hasattr(config.llm_module_config, "model"): | ||
model_name = config.llm_module_config.model or "unknown" | ||
return model_name |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
from taskweaver.module.event_emitter import ( | ||
SessionEventHandlerBase, | ||
SessionEventType, | ||
RoundEventType, | ||
PostEventType, | ||
) | ||
import agentops | ||
from agentops.event import ActionEvent, ErrorEvent, ToolEvent | ||
from datetime import datetime, timezone | ||
from typing import Dict, Any | ||
from agentops.log_config import logger | ||
|
||
ATTACHMENT_TOOLS = [ | ||
"thought", | ||
"reply_type", | ||
"reply_content", | ||
"verification", | ||
"code_error", | ||
"execution_status", | ||
"execution_result", | ||
"artifact_paths", | ||
"revise_message", | ||
"function", | ||
"web_exploring_plan", | ||
"web_exploring_screenshot", | ||
"web_exploring_link", | ||
] | ||
|
||
|
||
class TaskWeaverEventHandler(SessionEventHandlerBase): | ||
def __init__(self): | ||
super().__init__() | ||
self._message_buffer: Dict[str, Dict[str, Any]] = {} | ||
self._attachment_buffer: Dict[str, Dict[str, Any]] = {} | ||
self._active_agents: Dict[str, str] = {} | ||
|
||
def _get_or_create_agent(self, role: str): | ||
"""Get existing agent ID or create new agent for role+round combination""" | ||
if role not in self._active_agents: | ||
agent_id = agentops.create_agent(name=role) | ||
if agent_id: | ||
self._active_agents[role] = agent_id | ||
return self._active_agents.get(role) | ||
|
||
def handle_session(self, type: SessionEventType, msg: str, extra: Any, **kwargs: Any): | ||
agentops.record(ActionEvent(action_type=type.value, params={"extra": extra, "message": msg})) | ||
|
||
def handle_round(self, type: RoundEventType, msg: str, extra: Any, round_id: str, **kwargs: Any): | ||
if type == RoundEventType.round_error: | ||
agentops.record( | ||
ErrorEvent(error_type=type.value, details={"round_id": round_id, "message": msg, "extra": extra}) | ||
) | ||
logger.error(f"Could not record the Round event: {msg}") | ||
self.cleanup_round() | ||
else: | ||
agentops.record( | ||
ActionEvent( | ||
action_type=type.value, | ||
params={"round_id": round_id, "extra": extra}, | ||
returns=msg, | ||
) | ||
) | ||
if type == RoundEventType.round_end: | ||
self.cleanup_round() | ||
|
||
def handle_post(self, type: PostEventType, msg: str, extra: Any, post_id: str, round_id: str, **kwargs: Any): | ||
role = extra.get("role", "Planner") | ||
agent_id = self._get_or_create_agent(role=role) | ||
|
||
if type == PostEventType.post_error: | ||
agentops.record( | ||
ErrorEvent( | ||
error_type=type.value, | ||
details={"post_id": post_id, "round_id": round_id, "message": msg, "extra": extra}, | ||
) | ||
) | ||
logger.error(f"Could not record the Post event: {msg}") | ||
|
||
elif type == PostEventType.post_start or type == PostEventType.post_end: | ||
agentops.record( | ||
ActionEvent( | ||
action_type=type.value, | ||
params={"post_id": post_id, "round_id": round_id, "extra": extra}, | ||
returns=msg, | ||
agent_id=agent_id, | ||
) | ||
) | ||
|
||
elif type == PostEventType.post_status_update: | ||
agentops.record( | ||
ActionEvent( | ||
action_type=type.value, | ||
params={"post_id": post_id, "round_id": round_id, "extra": extra}, | ||
returns=msg, | ||
agent_id=agent_id, | ||
) | ||
) | ||
|
||
elif type == PostEventType.post_attachment_update: | ||
attachment_id = extra["id"] | ||
attachment_type = extra["type"].value | ||
is_end = extra["is_end"] | ||
|
||
if attachment_id not in self._attachment_buffer: | ||
self._attachment_buffer[attachment_id] = { | ||
"role": attachment_type, | ||
"content": [], | ||
"init_timestamp": datetime.now(timezone.utc).isoformat(), | ||
"end_timestamp": None, | ||
} | ||
|
||
self._attachment_buffer[attachment_id]["content"].append(str(msg)) | ||
|
||
if is_end: | ||
self._attachment_buffer[attachment_id]["end_timestamp"] = datetime.now(timezone.utc).isoformat() | ||
complete_message = "".join(self._attachment_buffer[attachment_id]["content"]) | ||
|
||
if attachment_type in ATTACHMENT_TOOLS: | ||
agentops.record( | ||
ToolEvent( | ||
name=type.value, | ||
init_timestamp=self._attachment_buffer[attachment_id]["init_timestamp"], | ||
end_timestamp=self._attachment_buffer[attachment_id]["end_timestamp"], | ||
params={ | ||
"post_id": post_id, | ||
"round_id": round_id, | ||
"attachment_id": attachment_id, | ||
"attachment_type": self._attachment_buffer[attachment_id]["role"], | ||
"extra": extra, | ||
}, | ||
returns=complete_message, | ||
agent_id=agent_id, | ||
) | ||
) | ||
else: | ||
agentops.record( | ||
ActionEvent( | ||
action_type=type.value, | ||
init_timestamp=self._attachment_buffer[attachment_id]["init_timestamp"], | ||
end_timestamp=self._attachment_buffer[attachment_id]["end_timestamp"], | ||
params={ | ||
"post_id": post_id, | ||
"round_id": round_id, | ||
"attachment_id": attachment_id, | ||
"attachment_type": self._attachment_buffer[attachment_id]["role"], | ||
"extra": extra, | ||
}, | ||
returns=complete_message, | ||
agent_id=agent_id, | ||
) | ||
) | ||
|
||
self._attachment_buffer.pop(attachment_id, None) | ||
|
||
elif type == PostEventType.post_message_update: | ||
is_end = extra["is_end"] | ||
|
||
if post_id not in self._message_buffer: | ||
self._message_buffer[post_id] = { | ||
"content": [], | ||
"init_timestamp": datetime.now(timezone.utc).isoformat(), | ||
"end_timestamp": None, | ||
} | ||
|
||
self._message_buffer[post_id]["content"].append(str(msg)) | ||
|
||
if is_end: | ||
self._message_buffer[post_id]["end_timestamp"] = datetime.now(timezone.utc).isoformat() | ||
complete_message = "".join(self._message_buffer[post_id]["content"]) | ||
agentops.record( | ||
ActionEvent( | ||
action_type=type.value, | ||
init_timestamp=self._message_buffer[post_id]["init_timestamp"], | ||
end_timestamp=self._message_buffer[post_id]["end_timestamp"], | ||
params={ | ||
"post_id": post_id, | ||
"round_id": round_id, | ||
"extra": extra, | ||
}, | ||
returns=complete_message, | ||
agent_id=agent_id, | ||
) | ||
) | ||
|
||
self._message_buffer.pop(post_id, None) | ||
|
||
def cleanup_round(self): | ||
"""Cleanup agents and buffers for a completed round""" | ||
self._active_agents.clear() | ||
self._message_buffer.clear() | ||
self._attachment_buffer.clear() |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Oops, something went wrong.