Commit 041d51f
Merge branch 'main' into devin/1734589134-anthropic-streaming-context
the-praxs authored Dec 24, 2024
2 parents cc945a8 + 915cb59 commit 041d51f
Showing 26 changed files with 650 additions and 535 deletions.
8 changes: 6 additions & 2 deletions .github/workflows/python-testing.yml
@@ -21,8 +21,9 @@ on:
 
 jobs:
   build:
-
     runs-on: ubuntu-latest
+    env:
+      OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
 
     strategy:
       matrix:
@@ -33,8 +34,11 @@ jobs:
       - uses: actions/setup-python@v5
         with:
           cache: 'pip'
-          python-version: '3.11' # Use a default Python version for running tox
+          python-version: '3.11'
       - name: Install tox
         run: pip install tox
      - name: Run tests with tox
        run: tox
+        env:
+          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+          AGENTOPS_API_KEY: ${{ secrets.AGENTOPS_API_KEY }}
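
Note on the workflow change above: scoping the secrets to the tox step means tests can only read them from the process environment. A minimal sketch of a guard a test suite might use (hypothetical test, not part of this diff):

import os

import pytest

# Hypothetical guard: fork PRs don't receive repository secrets, so the
# variable may be unset; skip rather than fail in that case.
requires_openai = pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY"),
    reason="OPENAI_API_KEY not available (e.g. on fork PRs)",
)


@requires_openai
def test_openai_key_present():
    assert os.environ["OPENAI_API_KEY"].strip()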
146 changes: 146 additions & 0 deletions agentops/llms/providers/taskweaver.py
@@ -0,0 +1,146 @@
import pprint
from typing import Optional
import json

from agentops.event import ErrorEvent, LLMEvent, ActionEvent
from agentops.session import Session
from agentops.log_config import logger
from agentops.helpers import get_ISO_time, check_call_stack_for_agent_id
from agentops.llms.providers.instrumented_provider import InstrumentedProvider
from agentops.singleton import singleton


@singleton
class TaskWeaverProvider(InstrumentedProvider):
    original_chat_completion = None

    def __init__(self, client):
        super().__init__(client)
        self._provider_name = "TaskWeaver"
        self.client.add_default_tags(["taskweaver"])

    def handle_response(self, response, kwargs, init_timestamp, session: Optional[Session] = None) -> dict:
        """Handle responses for TaskWeaver"""
        llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs)
        action_event = ActionEvent(init_timestamp=init_timestamp)
        response_dict = {}  # default so the LLM block below never sees an unbound name

        try:
            response_dict = response.get("response", {})

            action_event.params = kwargs.get("json_schema", None)
            action_event.returns = response_dict
            action_event.end_timestamp = get_ISO_time()
            self._safe_record(session, action_event)
        except Exception as e:
            error_event = ErrorEvent(
                trigger_event=action_event, exception=e, details={"response": str(response), "kwargs": str(kwargs)}
            )
            self._safe_record(session, error_event)
            kwargs_str = pprint.pformat(kwargs)
            response_str = pprint.pformat(response)
            logger.error(
                f"Unable to parse response for Action call. Skipping upload to AgentOps\n"
                f"response:\n {response_str}\n"
                f"kwargs:\n {kwargs_str}\n"
            )

        try:
            llm_event.init_timestamp = init_timestamp
            llm_event.params = kwargs
            llm_event.returns = response_dict
            llm_event.agent_id = check_call_stack_for_agent_id()
            llm_event.model = kwargs.get("model", "unknown")
            llm_event.prompt = kwargs.get("messages")
            llm_event.completion = response_dict.get("message", "")
            llm_event.end_timestamp = get_ISO_time()
            self._safe_record(session, llm_event)

        except Exception as e:
            error_event = ErrorEvent(
                trigger_event=llm_event, exception=e, details={"response": str(response), "kwargs": str(kwargs)}
            )
            self._safe_record(session, error_event)
            kwargs_str = pprint.pformat(kwargs)
            response_str = pprint.pformat(response)
            logger.error(
                f"Unable to parse response for LLM call. Skipping upload to AgentOps\n"
                f"response:\n {response_str}\n"
                f"kwargs:\n {kwargs_str}\n"
            )

        return response

    def override(self):
        """Override TaskWeaver's chat completion methods"""
        try:
            from taskweaver.llm import llm_completion_config_map

            def create_patched_chat_completion(original_method):
                """Create a new patched chat_completion function with bound original method"""

                def patched_chat_completion(service, *args, **kwargs):
                    init_timestamp = get_ISO_time()
                    # Pop any AgentOps session threaded through as a keyword so
                    # it never reaches the wrapped TaskWeaver method.
                    session = kwargs.pop("session", None)

                    result = original_method(service, *args, **kwargs)
                    kwargs.update(
                        {
                            # Positional layout assumed from TaskWeaver's
                            # chat_completion(messages, stream, temperature, max_tokens, top_p, stop)
                            "model": self._get_model_name(service),
                            "messages": args[0],
                            "stream": args[1],
                            "temperature": args[2],
                            "max_tokens": args[3],
                            "top_p": args[4],
                            "stop": args[5],
                        }
                    )

                    if kwargs["stream"]:
                        # Stream through an inner generator: if `yield` appeared in
                        # patched_chat_completion itself, the whole function would
                        # become a generator and the non-streaming branch's return
                        # value would be silently lost in StopIteration.
                        def stream_wrapper():
                            accumulated_content = ""
                            for chunk in result:
                                if isinstance(chunk, dict) and "content" in chunk:
                                    accumulated_content += chunk["content"]
                                else:
                                    accumulated_content += chunk
                                yield chunk
                            # Record once the stream is fully consumed.
                            self.handle_response(
                                json.loads(accumulated_content), kwargs, init_timestamp, session=session
                            )

                        return stream_wrapper()
                    return self.handle_response(result, kwargs, init_timestamp, session=session)

                return patched_chat_completion

            for service_name, service_class in llm_completion_config_map.items():
                if not hasattr(service_class, "_original_chat_completion"):
                    service_class._original_chat_completion = service_class.chat_completion
                    service_class.chat_completion = create_patched_chat_completion(
                        service_class._original_chat_completion
                    )

        except Exception as e:
            logger.error(f"Failed to patch method: {str(e)}", exc_info=True)

    def undo_override(self):
        """Restore original TaskWeaver chat completion methods"""
        try:
            from taskweaver.llm import llm_completion_config_map

            for service_name, service_class in llm_completion_config_map.items():
                # Guard so undoing before override() ran is a no-op, not an AttributeError
                if hasattr(service_class, "_original_chat_completion"):
                    service_class.chat_completion = service_class._original_chat_completion
                    delattr(service_class, "_original_chat_completion")

        except Exception as e:
            logger.error(f"Failed to restore original method: {str(e)}", exc_info=True)

    def _get_model_name(self, service) -> str:
        """Extract model name from service instance"""
        model_name = "unknown"
        if hasattr(service, "config"):
            config = service.config
            if hasattr(config, "model"):
                model_name = config.model or "unknown"
            elif hasattr(config, "llm_module_config") and hasattr(config.llm_module_config, "model"):
                model_name = config.llm_module_config.model or "unknown"
        return model_name
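
For orientation, a minimal end-to-end sketch of how this provider gets engaged. The TaskWeaverApp entry point and send_message call are assumed from TaskWeaver's documentation, and the app directory is hypothetical; none of this is part of the commit:

import agentops

# agentops.init() runs LlmTracker.override_api(), which patches chat_completion
# on every service class in taskweaver.llm.llm_completion_config_map.
agentops.init(api_key="<AGENTOPS_API_KEY>")

from taskweaver.app.app import TaskWeaverApp  # assumed TaskWeaver entry point

app = TaskWeaverApp(app_dir="./taskweaver_project")  # hypothetical app dir
tw_session = app.get_session()

# Each LLM call TaskWeaver makes inside this round is now recorded as an
# LLMEvent (plus an ActionEvent) in the AgentOps session.
result = tw_session.send_message("Summarize ./data/sales.csv")

agentops.end_session("Success")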
14 changes: 14 additions & 0 deletions agentops/llms/tracker.py
@@ -16,6 +16,7 @@
 from .providers.anthropic import AnthropicProvider
 from .providers.mistral import MistralProvider
 from .providers.ai21 import AI21Provider
+from .providers.taskweaver import TaskWeaverProvider
 
 original_func = {}
 original_create = None
@@ -54,6 +55,9 @@ class LlmTracker:
                 "client.answer.create",
             ),
         },
+        "taskweaver": {
+            "0.0.1": ("chat_completion", "chat_completion_stream"),
+        },
     }
 
     def __init__(self, client):
@@ -164,6 +168,15 @@ def override_api(self):
             else:
                 logger.warning(f"Only LlamaStackClient>=0.0.53 supported. v{module_version} found.")
 
+        if api == "taskweaver":
+            module_version = version(api)
+
+            if Version(module_version) >= parse("0.0.1"):
+                provider = TaskWeaverProvider(self.client)
+                provider.override()
+            else:
+                logger.warning(f"Only TaskWeaver>=0.0.1 supported. v{module_version} found.")
+
     def stop_instrumenting(self):
         OpenAiProvider(self.client).undo_override()
         GroqProvider(self.client).undo_override()
@@ -174,3 +187,4 @@ def stop_instrumenting(self):
         MistralProvider(self.client).undo_override()
         AI21Provider(self.client).undo_override()
         LlamaStackClientProvider(self.client).undo_override()
+        TaskWeaverProvider(self.client).undo_override()
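
The version gate above follows the same pattern as the other providers. A standalone sketch of the check, assuming version() comes from importlib.metadata and Version/parse from packaging (consistent with how they are called here):

from importlib.metadata import version  # assumed import in tracker.py
from packaging.version import Version, parse  # assumed import in tracker.py


def taskweaver_supported() -> bool:
    """Mirror override_api's gate: patch only when TaskWeaver >= 0.0.1 is installed."""
    try:
        module_version = version("taskweaver")
    except Exception:  # PackageNotFoundError when taskweaver isn't installed
        return False
    return Version(module_version) >= parse("0.0.1")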