clean and refactor code for taskweaver LLM tracking
the-praxs committed Dec 19, 2024
1 parent d794b9f commit 5523131
Showing 1 changed file with 107 additions and 158 deletions.
265 changes: 107 additions & 158 deletions agentops/llms/providers/taskweaver.py
@@ -1,6 +1,5 @@
 import pprint
-from typing import Optional, Generator
-import inspect
+from typing import Optional
 import json
 
 from agentops.event import ErrorEvent, LLMEvent
@@ -18,193 +17,143 @@ class TaskWeaverProvider(InstrumentedProvider):
     def __init__(self, client):
         super().__init__(client)
         self._provider_name = "TaskWeaver"
+        logger.info(f"TaskWeaver provider initialized with client: {client}")

     def handle_response(self, response, kwargs, init_timestamp, session: Optional[Session] = None) -> dict:
         """Handle responses for TaskWeaver"""
-        logger.info(f"[HANDLE_RESPONSE] Start handling response type: {type(response)}")
-        logger.info(f"[HANDLE_RESPONSE] Session: {session}")
-        logger.info(f"[HANDLE_RESPONSE] Processing response: {response}")
+        llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs)
         try:
-            messages = kwargs.get("messages", [])
-            conversations = []
-            current_conversation = []
-
-            # Group messages by conversation and role
-            for msg in messages:
-                if msg['role'] == 'user' and "Let's start the new conversation!" in msg.get('content', ''):
-                    if current_conversation:
-                        conversations.append(current_conversation)
-                    current_conversation = []
-                current_conversation.append(msg)
-
-                # Record system messages immediately
-                if msg['role'] == 'system':
-                    system_event = LLMEvent(
-                        init_timestamp=init_timestamp,
-                        params=kwargs,
-                        prompt=[msg],
-                        completion=None,
-                        model=kwargs.get("model", "unknown"),
-                    )
-                    if session is not None:
-                        system_event.session_id = session.session_id
-                    system_event.agent_id = check_call_stack_for_agent_id()
-                    system_event.end_timestamp = get_ISO_time()
-                    self._safe_record(session, system_event)
-                    logger.info("[HANDLE_RESPONSE] System message event recorded")
-
-            if current_conversation:
-                conversations.append(current_conversation)
-
-            # Process the current response
-            if isinstance(response, dict):
-                content = response.get("content", "")
-                try:
-                    if content and isinstance(content, str) and content.startswith('{"response":'):
-                        parsed = json.loads(content)
-                        taskweaver_response = parsed.get("response", {})
-                        role = taskweaver_response.get("send_to")
-
-                        # Record LLM event for the current role
-                        llm_event = LLMEvent(
-                            init_timestamp=init_timestamp,
-                            params=kwargs,
-                            prompt=current_conversation,
-                            completion={
-                                "role": "assistant",
-                                "content": taskweaver_response.get("message", ""),
-                                "metadata": {
-                                    "plan": taskweaver_response.get("plan"),
-                                    "current_plan_step": taskweaver_response.get("current_plan_step"),
-                                    "send_to": role,
-                                    "init_plan": taskweaver_response.get("init_plan")
-                                },
-                            },
-                            model=kwargs.get("model", "unknown"),
-                        )
-                        if session is not None:
-                            llm_event.session_id = session.session_id
-                        llm_event.end_timestamp = get_ISO_time()
-                        self._safe_record(session, llm_event)
-                        logger.info(f"[HANDLE_RESPONSE] LLM event recorded for role: {role}")
-
-                except json.JSONDecodeError as e:
-                    logger.error(f"[HANDLE_RESPONSE] JSON decode error: {str(e)}")
-                    raise
-
-            logger.info(f"[HANDLE_RESPONSE] Completion: {llm_event.completion}")
+            response_dict = response.get("response", {})
+            llm_event.init_timestamp = init_timestamp
+            llm_event.params = kwargs
+            llm_event.returns = response_dict
+            llm_event.agent_id = check_call_stack_for_agent_id()
+            llm_event.model = kwargs.get("model", "unknown")
+            llm_event.prompt = kwargs.get("messages")
+            llm_event.completion = response_dict.get("message", "")
+            llm_event.end_timestamp = get_ISO_time()
+            self._safe_record(session, llm_event)

         except Exception as e:
-            logger.error(f"[HANDLE_RESPONSE] Error processing response: {str(e)}", exc_info=True)
             error_event = ErrorEvent(
-                trigger_event=llm_event if 'llm_event' in locals() else None,
+                trigger_event=llm_event,
                 exception=e,
                 details={"response": str(response), "kwargs": str(kwargs)}
             )
             self._safe_record(session, error_event)
             kwargs_str = pprint.pformat(kwargs)
             response_str = pprint.pformat(response)
-            logger.warning(
-                f"[HANDLE_RESPONSE] Failed to process response:\n"
+            logger.error(
                 f"Unable to parse response for LLM call. Skipping upload to AgentOps\n"
                 f"response:\n {response_str}\n"
                 f"kwargs:\n {kwargs_str}\n"
             )
 
+        return response

     def override(self):
-        try:
-            from taskweaver.llm import llm_completion_config_map
-
-            service_names, service_classes = zip(*llm_completion_config_map.items())
-
-            logger.info("[OVERRIDE] Starting to patch LLM services")
-
-            def patched_chat_completion(service_self, messages, stream=True, temperature=None, max_tokens=None, top_p=None, stop=None, **kwargs) -> Generator:
-                logger.info(f"[PATCHED] Starting patched chat completion for {service_self.__class__.__name__}")
-                logger.info(f"[PATCHED] Stream mode: {stream}")
+        """Override TaskWeaver's chat completion methods"""
+        global original_chat_completion
+
+        try:
+            from taskweaver.llm.openai import OpenAIService
+            from taskweaver.llm.anthropic import AnthropicService
+            from taskweaver.llm.azure_ml import AzureMLService
+            from taskweaver.llm.groq import GroqService
+            from taskweaver.llm.ollama import OllamaService
+            from taskweaver.llm.qwen import QWenService
+            from taskweaver.llm.zhipuai import ZhipuAIService

+            # Create our own mapping of services
+            service_mapping = {
+                "openai": OpenAIService,
+                "azure": OpenAIService,
+                "azure_ad": OpenAIService,
+                "anthropic": AnthropicService,
+                "azure_ml": AzureMLService,
+                "groq": GroqService,
+                "ollama": OllamaService,
+                "qwen": QWenService,
+                "zhipuai": ZhipuAIService
+            }

+            def patched_chat_completion(service, *args, **kwargs):
                 init_timestamp = get_ISO_time()
-                session = kwargs.pop("session", None)
-                logger.info(f"[PATCHED] Session: {session}")
-
-                # Get model information from service instance
-                model_name = "unknown"
-                if hasattr(service_self, "config"):
-                    config = service_self.config
-                    if hasattr(config, "model"):
-                        model_name = config.model or "unknown"
-                    elif hasattr(config, "llm_module_config") and hasattr(config.llm_module_config, "model"):
-                        model_name = config.llm_module_config.model or "unknown"
+                session = kwargs.get("session", None)
+                if "session" in kwargs.keys():
+                    del kwargs["session"]

-                original_messages = messages.copy()
-                logger.info(f"[PATCHED] Calling original with messages: {messages}")
-
-                result = original(
-                    service_self,
-                    messages=messages,
-                    stream=stream,
-                    temperature=temperature,
-                    max_tokens=max_tokens,
-                    top_p=top_p,
-                    stop=stop,
-                    **kwargs
-                )
+                result = original_chat_completion(service, *args, **kwargs)
+                kwargs.update(
+                    {
+                        "model": self._get_model_name(service),
+                        "messages": args[0],
+                        "stream": args[1],
+                        "temperature": args[2],
+                        "max_tokens": args[3],
+                        "top_p": args[4],
+                        "stop": args[5],
+                    }
+                )
logger.info(f"[PATCHED] Got result type: {type(result)}")

if stream:
logger.info("[PATCHED] Handling streaming response")
accumulated_response = {"role": "assistant", "content": ""}
for response in result:
logger.info(f"[PATCHED] Stream chunk: {response}")
if isinstance(response, dict) and "content" in response:
accumulated_response["content"] += response["content"]

if kwargs["stream"]:
accumulated_content = ""
+                    for chunk in result:
+                        if isinstance(chunk, dict) and "content" in chunk:
+                            accumulated_content += chunk["content"]
                         else:
-                            accumulated_response["content"] += str(response)
-                        yield response
-
-                    logger.info(f"[PATCHED] Recording accumulated response: {accumulated_response}")
-                    kwargs["messages"] = original_messages
-                    kwargs["model"] = model_name
-                    self.handle_response(accumulated_response, kwargs, init_timestamp, session=session)
+                            accumulated_content += chunk
+                        yield chunk
+                    accumulated_content = json.loads(accumulated_content)
+                    return self.handle_response(accumulated_content, kwargs, init_timestamp, session=session)
                 else:
-                    logger.info("[PATCHED] Handling non-streaming response")
-                    response = next(result) if hasattr(result, '__next__') else result
-                    logger.info(f"[PATCHED] Non-stream response: {response}")
-                    kwargs["messages"] = original_messages
-                    kwargs["model"] = model_name
-                    self.handle_response(response, kwargs, init_timestamp, session=session)
-                    return response
-
-            # Patch all services
-            for service in service_classes:
-                try:
-                    logger.info(f"[OVERRIDE] Attempting to patch {service.__name__}")
-                    if not hasattr(service, '_original_chat_completion'):
-                        original = service.chat_completion
-                        service._original_chat_completion = original
-                        service.chat_completion = patched_chat_completion
-                    logger.info(f"[OVERRIDE] Successfully patched {service.__name__}")
-                except Exception as e:
-                    logger.error(f"[OVERRIDE] Failed to patch {service.__name__}: {str(e)}", exc_info=True)
+                    return self.handle_response(result, kwargs, init_timestamp, session=session)

+            for service_name, service_class in service_mapping.items():
+                original_chat_completion = service_class.chat_completion
+                service_class.chat_completion = patched_chat_completion

         except Exception as e:
-            logger.error(f"[OVERRIDE] Failed to patch services: {str(e)}", exc_info=True)
+            logger.error(f"Failed to patch method: {str(e)}", exc_info=True)

     def undo_override(self):
         """Restore original TaskWeaver chat completion methods"""
         try:
-            from taskweaver.llm import llm_completion_config_map
-
-            # Check each service for patching and undo if found
-            for service_name, service_class in llm_completion_config_map.items():
-                if hasattr(service_class, '_original_chat_completion'):
-                    service_class.chat_completion = service_class._original_chat_completion
-                    delattr(service_class, '_original_chat_completion')
-                    logger.info(f"[UNDO] Restored original methods for {service_class.__name__}")
-                    # Break after finding the patched service since we only patch one
-                    break
+            from taskweaver.llm.openai import OpenAIService
+            from taskweaver.llm.anthropic import AnthropicService
+            from taskweaver.llm.azure_ml import AzureMLService
+            from taskweaver.llm.groq import GroqService
+            from taskweaver.llm.ollama import OllamaService
+            from taskweaver.llm.qwen import QWenService
+            from taskweaver.llm.zhipuai import ZhipuAIService

+            # Create our own mapping of services
+            service_mapping = {
+                "openai": OpenAIService,
+                "azure": OpenAIService,
+                "azure_ad": OpenAIService,
+                "anthropic": AnthropicService,
+                "azure_ml": AzureMLService,
+                "groq": GroqService,
+                "ollama": OllamaService,
+                "qwen": QWenService,
+                "zhipuai": ZhipuAIService
+            }

+            if original_chat_completion is not None:
+                for service_name, service_class in service_mapping.items():
+                    service_class.chat_completion = original_chat_completion

         except Exception as e:
-            logger.error(f"[UNDO] Failed to restore original methods: {str(e)}", exc_info=True)
+            logger.error(f"Failed to restore original method: {str(e)}", exc_info=True)

+    def _get_model_name(self, service) -> str:
+        """Extract model name from service instance"""
+        model_name = "unknown"
+        if hasattr(service, "config"):
+            config = service.config
+            if hasattr(config, "model"):
+                model_name = config.model or "unknown"
+            elif hasattr(config, "llm_module_config") and hasattr(config.llm_module_config, "model"):
+                model_name = config.llm_module_config.model or "unknown"
+        return model_name
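
A note on the streaming branch above: the patched chat_completion is itself a generator, so it can forward each chunk to the caller unchanged while accumulating the full payload, then parse and record it once the stream is exhausted. Below is a minimal, self-contained sketch of that pattern; record_event is a hypothetical stand-in for the LLMEvent construction and _safe_record call.

import json
from typing import Any, Dict, Generator, Iterable


def record_event(completion: Dict[str, Any]) -> None:
    # Hypothetical stand-in for building an LLMEvent and calling _safe_record.
    print(f"recorded: {completion}")


def wrapped_stream(stream: Iterable[str]) -> Generator[str, None, None]:
    accumulated = ""
    for chunk in stream:
        accumulated += chunk  # keep a running copy of the streamed payload
        yield chunk           # forward each chunk to the caller unchanged
    # Only reached once the caller has drained the stream: parse and record once.
    record_event(json.loads(accumulated))


# Usage: iterate the wrapper exactly as you would the raw stream.
for piece in wrapped_stream(['{"response": {"mes', 'sage": "hi"}}']):
    print(piece)

One property of this design worth noting: the recording fires only if the caller drains the generator, since the code after the for loop runs only when the stream is exhausted.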

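The override/undo_override pair follows the usual class-level monkey-patching recipe that the refactor switches to: save the original unbound method in a module-level reference, assign a wrapper in its place, and restore the saved reference on teardown. A compact sketch of that recipe on a hypothetical Service class:

from typing import Callable, Optional

original_greet: Optional[Callable] = None


class Service:
    def greet(self, name: str) -> str:
        return f"hello {name}"


def override() -> None:
    global original_greet
    original_greet = Service.greet  # save the original unbound method

    def patched_greet(service: Service, name: str) -> str:
        # Delegate to the saved original, then add instrumentation around it.
        result = original_greet(service, name)
        print(f"instrumented: {result!r}")
        return result

    Service.greet = patched_greet  # patch at the class, so every instance is affected


def undo_override() -> None:
    if original_greet is not None:
        Service.greet = original_greet  # restore the saved reference


override()
Service().greet("world")  # prints "instrumented: 'hello world'"
undo_override()
Service().greet("world")  # back to the unpatched behavior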

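The new _get_model_name helper centralizes the model lookup that the old patched function performed inline: try config.model first, then the nested config.llm_module_config.model, defaulting to "unknown". A small usage sketch of the same fallback chain; SimpleNamespace stands in for TaskWeaver's real config objects, which is an assumption here:

from types import SimpleNamespace


def get_model_name(service) -> str:
    # Same fallback chain as the provider's _get_model_name helper.
    model_name = "unknown"
    if hasattr(service, "config"):
        config = service.config
        if hasattr(config, "model"):
            model_name = config.model or "unknown"
        elif hasattr(config, "llm_module_config") and hasattr(config.llm_module_config, "model"):
            model_name = config.llm_module_config.model or "unknown"
    return model_name


flat = SimpleNamespace(config=SimpleNamespace(model="gpt-4o"))
nested = SimpleNamespace(config=SimpleNamespace(llm_module_config=SimpleNamespace(model="qwen-max")))
bare = SimpleNamespace()

print(get_model_name(flat))    # gpt-4o
print(get_model_name(nested))  # qwen-max
print(get_model_name(bare))    # unknown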