diff --git a/agentops/llms/providers/taskweaver.py b/agentops/llms/providers/taskweaver.py
index 7c0e8d92..b9df2419 100644
--- a/agentops/llms/providers/taskweaver.py
+++ b/agentops/llms/providers/taskweaver.py
@@ -1,6 +1,5 @@
 import pprint
-from typing import Optional, Generator
-import inspect
+from typing import Optional
 import json
 
 from agentops.event import ErrorEvent, LLMEvent
@@ -18,96 +17,34 @@ class TaskWeaverProvider(InstrumentedProvider):
     def __init__(self, client):
         super().__init__(client)
         self._provider_name = "TaskWeaver"
-        logger.info(f"TaskWeaver provider initialized with client: {client}")
 
     def handle_response(self, response, kwargs, init_timestamp, session: Optional[Session] = None) -> dict:
         """Handle responses for TaskWeaver"""
-        logger.info(f"[HANDLE_RESPONSE] Start handling response type: {type(response)}")
-        logger.info(f"[HANDLE_RESPONSE] Session: {session}")
-        logger.info(f"[HANDLE_RESPONSE] Processing response: {response}")
+        llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs)
 
         try:
-            messages = kwargs.get("messages", [])
-            conversations = []
-            current_conversation = []
-
-            # Group messages by conversation and role
-            for msg in messages:
-                if msg['role'] == 'user' and "Let's start the new conversation!" in msg.get('content', ''):
-                    if current_conversation:
-                        conversations.append(current_conversation)
-                    current_conversation = []
-                current_conversation.append(msg)
-
-                # Record system messages immediately
-                if msg['role'] == 'system':
-                    system_event = LLMEvent(
-                        init_timestamp=init_timestamp,
-                        params=kwargs,
-                        prompt=[msg],
-                        completion=None,
-                        model=kwargs.get("model", "unknown"),
-                    )
-                    if session is not None:
-                        system_event.session_id = session.session_id
-                    system_event.agent_id = check_call_stack_for_agent_id()
-                    system_event.end_timestamp = get_ISO_time()
-                    self._safe_record(session, system_event)
-                    logger.info("[HANDLE_RESPONSE] System message event recorded")
-
-            if current_conversation:
-                conversations.append(current_conversation)
-
-            # Process the current response
-            if isinstance(response, dict):
-                content = response.get("content", "")
-                try:
-                    if content and isinstance(content, str) and content.startswith('{"response":'):
-                        parsed = json.loads(content)
-                        taskweaver_response = parsed.get("response", {})
-                        role = taskweaver_response.get("send_to")
-
-                        # Record LLM event for the current role
-                        llm_event = LLMEvent(
-                            init_timestamp=init_timestamp,
-                            params=kwargs,
-                            prompt=current_conversation,
-                            completion={
-                                "role": "assistant",
-                                "content": taskweaver_response.get("message", ""),
-                                "metadata": {
-                                    "plan": taskweaver_response.get("plan"),
-                                    "current_plan_step": taskweaver_response.get("current_plan_step"),
-                                    "send_to": role,
-                                    "init_plan": taskweaver_response.get("init_plan")
-                                },
-                            },
-                            model=kwargs.get("model", "unknown"),
-                        )
-                        if session is not None:
-                            llm_event.session_id = session.session_id
-                        llm_event.end_timestamp = get_ISO_time()
-                        self._safe_record(session, llm_event)
-                        logger.info(f"[HANDLE_RESPONSE] LLM event recorded for role: {role}")
-
-                except json.JSONDecodeError as e:
-                    logger.error(f"[HANDLE_RESPONSE] JSON decode error: {str(e)}")
-                    raise
-
-            logger.info(f"[HANDLE_RESPONSE] Completion: {llm_event.completion}")
+            response_dict = response.get("response", {})
+            llm_event.init_timestamp = init_timestamp
+            llm_event.params = kwargs
+            llm_event.returns = response_dict
+            llm_event.agent_id = check_call_stack_for_agent_id()
+            llm_event.model = kwargs.get("model", "unknown")
+            llm_event.prompt = kwargs.get("messages")
+            llm_event.completion = response_dict.get("message", "")
+            llm_event.end_timestamp = get_ISO_time()
+            self._safe_record(session, llm_event)
         except Exception as e:
-            logger.error(f"[HANDLE_RESPONSE] Error processing response: {str(e)}", exc_info=True)
             error_event = ErrorEvent(
-                trigger_event=llm_event if 'llm_event' in locals() else None,
+                trigger_event=llm_event,
                 exception=e,
                 details={"response": str(response), "kwargs": str(kwargs)}
             )
             self._safe_record(session, error_event)
             kwargs_str = pprint.pformat(kwargs)
             response_str = pprint.pformat(response)
-            logger.warning(
-                f"[HANDLE_RESPONSE] Failed to process response:\n"
+            logger.error(
+                f"Unable to parse response for LLM call. Skipping upload to AgentOps\n"
                 f"response:\n {response_str}\n"
                 f"kwargs:\n {kwargs_str}\n"
             )
@@ -115,96 +52,108 @@ def handle_response(self, response, kwargs, init_timestamp, session: Optional[Se
         return response
 
     def override(self):
-        try:
-            from taskweaver.llm import llm_completion_config_map
-
-            service_names, service_classes = zip(*llm_completion_config_map.items())
+        """Override TaskWeaver's chat completion methods"""
+        global original_chat_completion
 
-            logger.info("[OVERRIDE] Starting to patch LLM services")
-
-            def patched_chat_completion(service_self, messages, stream=True, temperature=None, max_tokens=None, top_p=None, stop=None, **kwargs) -> Generator:
-                logger.info(f"[PATCHED] Starting patched chat completion for {service_self.__class__.__name__}")
-                logger.info(f"[PATCHED] Stream mode: {stream}")
-
+        try:
+            from taskweaver.llm.openai import OpenAIService
+            from taskweaver.llm.anthropic import AnthropicService
+            from taskweaver.llm.azure_ml import AzureMLService
+            from taskweaver.llm.groq import GroqService
+            from taskweaver.llm.ollama import OllamaService
+            from taskweaver.llm.qwen import QWenService
+            from taskweaver.llm.zhipuai import ZhipuAIService
+
+            # Create our own mapping of services
+            service_mapping = {
+                "openai": OpenAIService,
+                "azure": OpenAIService,
+                "azure_ad": OpenAIService,
+                "anthropic": AnthropicService,
+                "azure_ml": AzureMLService,
+                "groq": GroqService,
+                "ollama": OllamaService,
+                "qwen": QWenService,
+                "zhipuai": ZhipuAIService
+            }
+
+            def patched_chat_completion(service, *args, **kwargs):
                 init_timestamp = get_ISO_time()
-                session = kwargs.pop("session", None)
-                logger.info(f"[PATCHED] Session: {session}")
-
-                # Get model information from service instance
-                model_name = "unknown"
-                if hasattr(service_self, "config"):
-                    config = service_self.config
-                    if hasattr(config, "model"):
-                        model_name = config.model or "unknown"
-                    elif hasattr(config, "llm_module_config") and hasattr(config.llm_module_config, "model"):
-                        model_name = config.llm_module_config.model or "unknown"
+                session = kwargs.get("session", None)
+                if "session" in kwargs.keys():
+                    del kwargs["session"]
 
-                original_messages = messages.copy()
-                logger.info(f"[PATCHED] Calling original with messages: {messages}")
-
-                result = original(
-                    service_self,
-                    messages=messages,
-                    stream=stream,
-                    temperature=temperature,
-                    max_tokens=max_tokens,
-                    top_p=top_p,
-                    stop=stop,
-                    **kwargs
+                result = original_chat_completion(service, *args, **kwargs)
+                kwargs.update(
+                    {
+                        "model": self._get_model_name(service),
+                        "messages": args[0],
+                        "stream": args[1],
+                        "temperature": args[2],
+                        "max_tokens": args[3],
+                        "top_p": args[4],
+                        "stop": args[5],
+                    }
                 )
-                logger.info(f"[PATCHED] Got result type: {type(result)}")
-
-                if stream:
-                    logger.info("[PATCHED] Handling streaming response")
-                    accumulated_response = {"role": "assistant", "content": ""}
-                    for response in result:
-                        logger.info(f"[PATCHED] Stream chunk: {response}")
-                        if isinstance(response, dict) and "content" in response:
-                            accumulated_response["content"] += response["content"]
+
+                if kwargs["stream"]:
+                    accumulated_content = ""
+                    for chunk in result:
+                        if isinstance(chunk, dict) and "content" in chunk:
+                            accumulated_content += chunk["content"]
                         else:
-                            accumulated_response["content"] += str(response)
-                        yield response
-
-                    logger.info(f"[PATCHED] Recording accumulated response: {accumulated_response}")
-                    kwargs["messages"] = original_messages
-                    kwargs["model"] = model_name
-                    self.handle_response(accumulated_response, kwargs, init_timestamp, session=session)
+                            accumulated_content += chunk
+                        yield chunk
+                    accumulated_content = json.loads(accumulated_content)
+                    return self.handle_response(accumulated_content, kwargs, init_timestamp, session=session)
                 else:
-                    logger.info("[PATCHED] Handling non-streaming response")
-                    response = next(result) if hasattr(result, '__next__') else result
-                    logger.info(f"[PATCHED] Non-stream response: {response}")
-                    kwargs["messages"] = original_messages
-                    kwargs["model"] = model_name
-                    self.handle_response(response, kwargs, init_timestamp, session=session)
-                    return response
-
-            # Patch all services
-            for service in service_classes:
-                try:
-                    logger.info(f"[OVERRIDE] Attempting to patch {service.__name__}")
-                    if not hasattr(service, '_original_chat_completion'):
-                        original = service.chat_completion
-                        service._original_chat_completion = original
-                        service.chat_completion = patched_chat_completion
-                        logger.info(f"[OVERRIDE] Successfully patched {service.__name__}")
-                except Exception as e:
-                    logger.error(f"[OVERRIDE] Failed to patch {service.__name__}: {str(e)}", exc_info=True)
+                    return self.handle_response(result, kwargs, init_timestamp, session=session)
 
+            for service_name, service_class in service_mapping.items():
+                original_chat_completion = service_class.chat_completion
+                service_class.chat_completion = patched_chat_completion
+
         except Exception as e:
-            logger.error(f"[OVERRIDE] Failed to patch services: {str(e)}", exc_info=True)
+            logger.error(f"Failed to patch method: {str(e)}", exc_info=True)
 
     def undo_override(self):
+        """Restore original TaskWeaver chat completion methods"""
        try:
-            from taskweaver.llm import llm_completion_config_map
-
-            # Check each service for patching and undo if found
-            for service_name, service_class in llm_completion_config_map.items():
-                if hasattr(service_class, '_original_chat_completion'):
-                    service_class.chat_completion = service_class._original_chat_completion
-                    delattr(service_class, '_original_chat_completion')
-                    logger.info(f"[UNDO] Restored original methods for {service_class.__name__}")
-                    # Break after finding the patched service since we only patch one
-                    break
+            from taskweaver.llm.openai import OpenAIService
+            from taskweaver.llm.anthropic import AnthropicService
+            from taskweaver.llm.azure_ml import AzureMLService
+            from taskweaver.llm.groq import GroqService
+            from taskweaver.llm.ollama import OllamaService
+            from taskweaver.llm.qwen import QWenService
+            from taskweaver.llm.zhipuai import ZhipuAIService
+
+            # Create our own mapping of services
+            service_mapping = {
+                "openai": OpenAIService,
+                "azure": OpenAIService,
+                "azure_ad": OpenAIService,
+                "anthropic": AnthropicService,
+                "azure_ml": AzureMLService,
+                "groq": GroqService,
+                "ollama": OllamaService,
+                "qwen": QWenService,
+                "zhipuai": ZhipuAIService
+            }
+
+            if original_chat_completion is not None:
+                for service_name, service_class in service_mapping.items():
+                    service_class.chat_completion = original_chat_completion
        except Exception as e:
-            logger.error(f"[UNDO] Failed to restore original methods: {str(e)}", exc_info=True)
\ No newline at end of file
+            logger.error(f"Failed to restore original method: {str(e)}", exc_info=True)
+
+    def _get_model_name(self, service) -> str:
+        """Extract model name from service instance"""
+        model_name = "unknown"
+        if hasattr(service, "config"):
+            config = service.config
+            if hasattr(config, "model"):
+                model_name = config.model or "unknown"
+            elif hasattr(config, "llm_module_config") and hasattr(config.llm_module_config, "model"):
+                model_name = config.llm_module_config.model or "unknown"
+        return model_name
\ No newline at end of file
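
For reference, a minimal usage sketch of the provider once this patch is applied. It is illustrative only: it assumes an already-configured AgentOps client instance (named `ao_client` here purely for the example) and simply drives the two entry points the diff defines, `override()` and `undo_override()`.

# Illustrative sketch, not part of the patch; `ao_client` is assumed to be a
# configured AgentOps client instance.
from agentops.llms.providers.taskweaver import TaskWeaverProvider

provider = TaskWeaverProvider(ao_client)
provider.override()        # patch chat_completion on every mapped TaskWeaver service
# ... run TaskWeaver as usual; each chat_completion call is recorded as an LLMEvent ...
provider.undo_override()   # restore the original chat_completion methods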