Skip to content

Commit

Permalink
llm refactor progress
Browse files Browse the repository at this point in the history
  • Loading branch information
bboynton97 committed Aug 12, 2024
1 parent bc9ab5f commit 8313a69
Show file tree
Hide file tree
Showing 10 changed files with 440 additions and 283 deletions.
2 changes: 1 addition & 1 deletion agentops/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from termcolor import colored

from .event import Event, ErrorEvent
from .helpers import (
from .singleton import (
conditional_singleton,
)
from .session import Session, active_sessions
Expand Down
39 changes: 8 additions & 31 deletions agentops/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,42 +8,12 @@
import json
from importlib.metadata import version, PackageNotFoundError

from . import Client
from .log_config import logger
from uuid import UUID
from importlib.metadata import version
import subprocess

ao_instances = {}


def singleton(class_):
    """Class decorator: hand out one shared instance of *class_* per process.

    The instance is created lazily on the first call and cached in the
    module-level ``ao_instances`` registry keyed by the class object.
    """

    def getinstance(*args, **kwargs):
        # EAFP: a cache hit is the common case after the first call.
        try:
            return ao_instances[class_]
        except KeyError:
            ao_instances[class_] = class_(*args, **kwargs)
            return ao_instances[class_]

    return getinstance


def conditional_singleton(class_):
    """Class decorator: share one cached instance unless the caller opts out.

    Passing ``use_singleton=False`` at construction time bypasses the cache
    and returns a fresh, independent instance; the flag is consumed here and
    never forwarded to the wrapped class.
    """

    def getinstance(*args, **kwargs):
        if not kwargs.pop("use_singleton", True):
            return class_(*args, **kwargs)
        # EAFP: reuse the cached instance, constructing it only on first use.
        try:
            return ao_instances[class_]
        except KeyError:
            ao_instances[class_] = class_(*args, **kwargs)
            return ao_instances[class_]

    return getinstance


def clear_singletons():
    """Reset the singleton registry so subsequent lookups build fresh instances.

    Rebinds (rather than mutates) ``ao_instances`` so the decorators, which
    read the global name on every call, see an empty cache immediately.
    """
    global ao_instances
    ao_instances = dict()


def get_ISO_time():
"""
Expand Down Expand Up @@ -212,3 +182,10 @@ def wrapper(self, *args, **kwargs):
return func(self, *args, **kwargs)

return wrapper


def safe_record(session, event):
    """Record *event* on *session* when one is provided.

    Falls back to the process-wide ``Client`` singleton when *session* is
    ``None`` so events are never silently dropped.
    """
    recorder = Client() if session is None else session
    recorder.record(event)
6 changes: 0 additions & 6 deletions agentops/llms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,3 @@ def stop_instrumenting(self):
undo_override_openai_v1_async_completion()
undo_override_openai_v1_completion()
undo_override_ollama(self)

def _safe_record(self, session, event):
if session is not None:
session.record(event)
else:
self.client.record(event)
14 changes: 6 additions & 8 deletions agentops/llms/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def handle_stream_chunk(chunk, session: Optional[Session] = None):
"content": chunk.response.text,
}
tracker.llm_event.end_timestamp = get_ISO_time()
tracker._safe_record(session, tracker.llm_event)
safe_record(session, tracker.llm_event)

# StreamedChatResponse_SearchResults = ActionEvent
search_results = chunk.response.search_results
Expand Down Expand Up @@ -80,7 +80,7 @@ def handle_stream_chunk(chunk, session: Optional[Session] = None):
action_event.end_timestamp = get_ISO_time()

for key, action_event in tracker.action_events.items():
tracker._safe_record(session, action_event)
safe_record(session, action_event)

elif isinstance(chunk, StreamedChatResponse_TextGeneration):
tracker.llm_event.completion += chunk.text
Expand All @@ -106,7 +106,7 @@ def handle_stream_chunk(chunk, session: Optional[Session] = None):
pass

except Exception as e:
tracker._safe_record(
safe_record(
session, ErrorEvent(trigger_event=tracker.llm_event, exception=e)
)

Expand Down Expand Up @@ -166,11 +166,9 @@ def generator():
tracker.llm_event.completion_tokens = response.meta.tokens.output_tokens
tracker.llm_event.model = kwargs.get("model", "command-r-plus")

tracker._safe_record(session, tracker.llm_event)
safe_record(session, tracker.llm_event)
except Exception as e:
tracker._safe_record(
session, ErrorEvent(trigger_event=tracker.llm_event, exception=e)
)
safe_record(session, ErrorEvent(trigger_event=tracker.llm_event, exception=e))
kwargs_str = pprint.pformat(kwargs)
response = pprint.pformat(response)
logger.warning(
Expand All @@ -194,7 +192,7 @@ def patched_function(*args, **kwargs):
if "session" in kwargs.keys():
del kwargs["session"]
result = original_chat(*args, **kwargs)
return tracker._handle_response_cohere(
return tracker.handle_response_cohere(
result, kwargs, init_timestamp, session=session
)

Expand Down
10 changes: 4 additions & 6 deletions agentops/llms/groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ def handle_stream_chunk(chunk: ChatCompletionChunk):
}
tracker.llm_event.end_timestamp = get_ISO_time()

tracker._safe_record(session, tracker.llm_event)
safe_record(session, tracker.llm_event)
except Exception as e:
tracker._safe_record(
safe_record(
session, ErrorEvent(trigger_event=tracker.llm_event, exception=e)
)

Expand Down Expand Up @@ -111,11 +111,9 @@ async def async_generator():
tracker.llm_event.completion_tokens = response.usage.completion_tokens
tracker.llm_event.model = response.model

tracker._safe_record(session, tracker.llm_event)
safe_record(session, tracker.llm_event)
except Exception as e:
tracker._safe_record(
session, ErrorEvent(trigger_event=tracker.llm_event, exception=e)
)
safe_record(session, ErrorEvent(trigger_event=tracker.llm_event, exception=e))

kwargs_str = pprint.pformat(kwargs)
response = pprint.pformat(response)
Expand Down
26 changes: 26 additions & 0 deletions agentops/llms/instrumented_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from abc import ABC, abstractmethod
from typing import Optional

from agentops import Session


class InstrumentedProvider(ABC):
    """Abstract base for LLM providers instrumented by AgentOps.

    Concrete subclasses supply response handling plus the hooks that apply
    and remove the provider override.
    """

    # Class-level label surfaced via the read-only ``provider_name`` property.
    _provider_name: str = "InstrumentedModel"

    @abstractmethod
    def handle_response(
        self, response, kwargs, init_timestamp, session: Optional[Session] = None
    ) -> dict:
        """Handle a raw provider *response* for the given session, if any."""

    @abstractmethod
    def override(self):
        """Apply this provider's instrumentation override."""

    @abstractmethod
    def undo_override(self):
        """Remove this provider's instrumentation override."""

    @property
    def provider_name(self):
        # Read-only accessor over the class-level label.
        return self._provider_name
2 changes: 1 addition & 1 deletion agentops/llms/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def generator():
tracker.llm_event.prompt = kwargs["messages"]
tracker.llm_event.completion = response["message"]

tracker._safe_record(session, tracker.llm_event)
safe_record(session, tracker.llm_event)
return response


Expand Down
Loading

0 comments on commit 8313a69

Please sign in to comment.