Skip to content

Commit

Permalink
feature: additional cohere support and example script (#212)
Browse files Browse the repository at this point in the history
  • Loading branch information
HowieG authored May 17, 2024
1 parent ed9d7b3 commit 1ef9ef5
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 38 deletions.
56 changes: 48 additions & 8 deletions agentops/llm_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from importlib.metadata import version
from packaging.version import Version, parse
from .log_config import logger
from .event import LLMEvent, ErrorEvent
from .event import LLMEvent, ActionEvent, ToolEvent, ErrorEvent
from .helpers import get_ISO_time, check_call_stack_for_agent_id
import inspect
from typing import Optional
Expand Down Expand Up @@ -228,7 +228,7 @@ async def async_generator():
return response

def _handle_response_cohere(self, response, kwargs, init_timestamp):
# TODO: """Handle responses for Cohere versions >v5.4.0"""
"""Handle responses for Cohere versions >v5.4.0"""
from cohere.types.non_streamed_chat_response import NonStreamedChatResponse
from cohere.types.streamed_chat_response import (
StreamedChatResponse,
Expand All @@ -248,6 +248,8 @@ def _handle_response_cohere(self, response, kwargs, init_timestamp):
params=kwargs
)

self.action_events = {}

def handle_stream_chunk(chunk):

# We take the first chunk and accumulate the deltas from all subsequent chunks to build one full chat completion
Expand All @@ -261,21 +263,59 @@ def handle_stream_chunk(chunk):

try:
if isinstance(chunk, StreamedChatResponse_StreamEnd):
# Streaming is done. Record LLMEvent
# self.llm_event.returns.finish_reason = chunk.is_finished
# StreamedChatResponse_TextGeneration = LLMEvent
self.llm_event.completion = {
"role": "assistant", "content": self.llm_event.completion}
"role": "assistant", "content": chunk.response.text}
self.llm_event.end_timestamp = get_ISO_time()

self.client.record(self.llm_event)

# StreamedChatResponse_SearchResults = ActionEvent
search_results = chunk.response.search_results
for search_result in search_results:
query = search_result.search_query
if query.generation_id in self.action_events:
action_event = self.action_events[query.generation_id]
search_result_dict = search_result.dict()
del search_result_dict["search_query"]
action_event.returns = search_result_dict
action_event.end_timestamp = get_ISO_time()

# StreamedChatResponse_CitationGeneration = ActionEvent
documents = {doc['id']: doc for doc in chunk.response.documents}
citations = chunk.response.citations
for citation in citations:
citation_id = f"{citation.start}.{citation.end}"
if citation_id in self.action_events:
action_event = self.action_events[citation_id]
citation_dict = citation.dict()
# Replace document_ids with the actual documents
citation_dict['documents'] = [documents[doc_id]
for doc_id in citation_dict['document_ids'] if doc_id in documents]
del citation_dict['document_ids']

action_event.returns = citation_dict
action_event.end_timestamp = get_ISO_time()

for key, action_event in self.action_events.items():
self.client.record(action_event)

elif isinstance(chunk, StreamedChatResponse_TextGeneration):
self.llm_event.completion += chunk.text
elif isinstance(chunk, StreamedChatResponse_ToolCallsGeneration):
pass
elif isinstance(chunk, StreamedChatResponse_CitationGeneration):
pass
for citation in chunk.citations:
self.action_events[f"{citation.start}.{citation.end}"] = ActionEvent(
action_type="citation",
init_timestamp=get_ISO_time(),
params=citation.text)
elif isinstance(chunk, StreamedChatResponse_SearchQueriesGeneration):
for query in chunk.search_queries:
self.action_events[query.generation_id] = ActionEvent(
action_type="search_query",
init_timestamp=get_ISO_time(),
params=query.text)
elif isinstance(chunk, StreamedChatResponse_SearchResults):
pass

except Exception as e:
Expand Down Expand Up @@ -498,7 +538,7 @@ def override_api(self):
self._override_method(api, method_path, module)

if api == 'cohere':
# Patch cohere vx.x.x+ methods
# Patch cohere v5.4.0+ methods
module_version = version(api)
if module_version is None:
logger.warning(f'🖇 AgentOps: Cannot determine Cohere version. Only Cohere>=5.4.0 supported.')
Expand Down
30 changes: 0 additions & 30 deletions examples/cohere.py

This file was deleted.

46 changes: 46 additions & 0 deletions examples/cohere_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import cohere
import agentops
from dotenv import load_dotenv

load_dotenv()

# Start an AgentOps session so the two chat calls below are tracked.
agentops.init(tags=["cohere", "agentops-demo"])
co = cohere.Client()


def collect_stream_text(stream):
    """Consume a Cohere chat stream, echoing chunks as they arrive.

    Prints each "text-generation" chunk immediately and dumps the final
    "stream-end" event, then returns the accumulated completion text.
    """
    text = ""
    for event in stream:
        if event.event_type == "text-generation":
            text += event.text
            print(event.text, end='')
        elif event.event_type == "stream-end":
            print("\n")
            print(event)
            print("\n")
    return text


# First turn: ask about AgentOps with the web-search connector enabled.
stream = co.chat_stream(
    message="Tell me everything you can about AgentOps",
    connectors=[{"id": "web-search"}]
)
response = collect_stream_text(stream)

# Second turn: feed the first answer back in as chat history.
stream = co.chat_stream(
    chat_history=[
        {"role": "SYSTEM", "message": "You are Adam Silverman: die-hard advocate of AgentOps, leader in AI Agent observability"},
        {
            "role": "CHATBOT",
            # Bug fix: this was a plain string, so the literal text
            # "{response}" was sent to the model instead of the first
            # turn's answer. An f-string performs the intended interpolation.
            "message": f"How's your day going? I'd like to tell you about AgentOps: {response}",
        },
    ],
    message="Based on your newfound knowledge of AgentOps, is Cohere a suitable partner for them and how could they integrate?",
    connectors=[{"id": "web-search"}]
)
response = collect_stream_text(stream)

agentops.end_session('Success')

0 comments on commit 1ef9ef5

Please sign in to comment.