diff --git a/ads/llm/autogen/__init__.py b/ads/llm/autogen/__init__.py index e69de29bb..72e03c615 100644 --- a/ads/llm/autogen/__init__.py +++ b/ads/llm/autogen/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ diff --git a/ads/llm/autogen/constants.py b/ads/llm/autogen/constants.py new file mode 100644 index 000000000..75d3bcd32 --- /dev/null +++ b/ads/llm/autogen/constants.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + + +class Events: + KEY = "event_name" + + EXCEPTION = "exception" + LLM_CALL = "llm_call" + TOOL_CALL = "tool_call" + NEW_AGENT = "new_agent" + NEW_CLIENT = "new_client" + RECEIVED_MESSAGE = "received_message" + SESSION_START = "logging_session_start" + SESSION_STOP = "logging_session_stop" diff --git a/ads/llm/autogen/reports/__init__.py b/ads/llm/autogen/reports/__init__.py new file mode 100644 index 000000000..72e03c615 --- /dev/null +++ b/ads/llm/autogen/reports/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ diff --git a/ads/llm/autogen/reports/base.py b/ads/llm/autogen/reports/base.py new file mode 100644 index 000000000..4f081eb76 --- /dev/null +++ b/ads/llm/autogen/reports/base.py @@ -0,0 +1,67 @@ +# Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +import json +import logging +import os + +from jinja2 import Environment, FileSystemLoader + +logger = logging.getLogger(__name__) + + +class BaseReport: + """Base class containing utilities for generating reports.""" + + @staticmethod + def format_json_string(s) -> str: + """Formats the JSON string in markdown.""" + return f"```json\n{json.dumps(json.loads(s), indent=2)}\n```" + + @staticmethod + def _parse_date_time(datetime_string: str): + """Parses a datetime string in the logs into date and time. + Keeps only the seconds in the time. + """ + date_str, time_str = datetime_string.split(" ", 1) + time_str = time_str.split(".", 1)[0] + return date_str, time_str + + @staticmethod + def _preview_message(message: str, max_length=30) -> str: + """Shows the beginning part of a string message.""" + # Return the entire string if it is less than the max_length + if len(message) <= max_length: + return message + # Go backward until we find the first whitespace + idx = 30 + while not message[idx].isspace() and idx > 0: + idx -= 1 + # If we found a whitespace + if idx > 0: + return message[:idx] + "..." + # If we didn't find a whitespace + return message[:30] + "..." 
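A quick illustration of how the `BaseReport` helpers above behave. This snippet is not part of the diff; the sample values are invented.

```python
# Illustrative only -- exercises the BaseReport helpers defined in reports/base.py.
from ads.llm.autogen.reports.base import BaseReport

# Wraps a JSON payload in a markdown ``json`` code fence, pretty-printed with indent=2.
print(BaseReport.format_json_string('{"event_name": "llm_call"}'))

# Splits a log timestamp into date and time, dropping sub-second precision.
print(BaseReport._parse_date_time("2024-11-27 10:15:30.123456"))
# -> ('2024-11-27', '10:15:30')

# Keeps roughly the first 30 characters, cut back to a whitespace boundary, plus "...".
print(BaseReport._preview_message("This is a fairly long chat message that gets truncated"))
```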
+ + @classmethod + def _render_template(cls, template_path, **kwargs) -> str: + """Render Jinja template with kwargs.""" + template_dir = os.path.join(os.path.dirname(__file__), "templates") + environment = Environment( + loader=FileSystemLoader(template_dir), autoescape=True + ) + template = environment.get_template(template_path) + try: + html = template.render(**kwargs) + except Exception: + logger.error( + "Unable to render template %s with data:\n%s", + template_path, + str(kwargs), + ) + return cls._render_template( + template_path=template_path, + sender=kwargs.get("sender", "N/A"), + content="TEMPLATE RENDER ERROR", + timestamp=kwargs.get("timestamp", ""), + ) + return html diff --git a/ads/llm/autogen/reports/data.py b/ads/llm/autogen/reports/data.py new file mode 100644 index 000000000..9e70cfa7a --- /dev/null +++ b/ads/llm/autogen/reports/data.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +"""Contains the data structure for logging and reporting.""" +import copy +import json +from dataclasses import asdict, dataclass, field +from typing import Optional, Union + +from ads.llm.autogen.constants import Events + + +@dataclass +class LogData: + """Base class for the data field of LogRecord.""" + + def to_dict(self): + """Convert the log data to dictionary.""" + return asdict(self) + + +@dataclass +class LogRecord: + """Represents a log record. + + The `data` field is for pre-defined structured data, which should be an instance of LogData. + The `kwargs` field is for freeform key value pairs. + """ + + session_id: str + thread_id: int + timestamp: str + event_name: str + source_id: Optional[int] = None + source_name: Optional[str] = None + # Structured data for specific type of logs + data: Optional[LogData] = None + # Freeform data + kwargs: dict = field(default_factory=dict) + + def to_dict(self): + """Convert the log record to dictionary.""" + return asdict(self) + + def to_string(self): + """Serialize the log record to JSON string.""" + return json.dumps(self.to_dict(), default=str) + + @classmethod + def from_dict(cls, data: dict) -> "LogRecord": + """Initializes a LogRecord object from dictionary.""" + event_mapping = { + Events.NEW_AGENT: AgentData, + Events.TOOL_CALL: ToolCallData, + Events.LLM_CALL: LLMCompletionData, + } + if Events.KEY not in data: + raise KeyError("event_name not found in data.") + + data = copy.deepcopy(data) + + event_name = data["event_name"] + if event_name in event_mapping and data.get("data"): + data["data"] = event_mapping[event_name](**data.pop("data")) + + return cls(**data) + + +@dataclass +class AgentData(LogData): + """Represents agent log Data.""" + + agent_name: str + agent_class: str + agent_module: Optional[str] = None + is_manager: Optional[bool] = None + + +@dataclass +class LLMCompletionData(LogData): + """Represents LLM completion log data.""" + + invocation_id: str + request: dict + response: dict + start_time: str + end_time: str + cost: Optional[float] = None + is_cached: Optional[bool] = None + + +@dataclass +class ToolCallData(LogData): + """Represents tool call log data.""" + + tool_name: str + start_time: str + end_time: str + agent_name: str + agent_class: str + agent_module: Optional[str] = None + input_args: dict = field(default_factory=dict) + returns: Optional[Union[str, list, dict, tuple]] = None diff --git a/ads/llm/autogen/reports/session.py 
b/ads/llm/autogen/reports/session.py new file mode 100644 index 000000000..8992f7c0b --- /dev/null +++ b/ads/llm/autogen/reports/session.py @@ -0,0 +1,526 @@ +# Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +"""Module for building session report.""" +import copy +import json +import logging +from dataclasses import dataclass +from typing import List, Optional + +import fsspec +import pandas as pd +import plotly.express as px +import report_creator as rc + +from ads.common.auth import default_signer +from ads.llm.autogen.constants import Events +from ads.llm.autogen.reports.base import BaseReport +from ads.llm.autogen.reports.data import ( + AgentData, + LLMCompletionData, + LogRecord, + ToolCallData, +) +from ads.llm.autogen.reports.utils import escape_html, get_duration, is_json_string + +logger = logging.getLogger(__name__) + + +@dataclass +class AgentInvocation: + """Represents an agent invocation.""" + + log: LogRecord + header: str = "" + description: str = "" + duration: Optional[float] = None + + +class SessionReport(BaseReport): + """Class for building session report from session log file.""" + + def __init__(self, log_file: str, auth: Optional[dict] = None) -> None: + """Initialize the session report with log file. + It is assumed that the file contains logs for a single session. + + Parameters + ---------- + log_file : str + Path or URI of the log file. + auth : dict, optional + Authentication signer/config for OCI, by default None + """ + self.log_file: str = log_file + if self.log_file.startswith("oci://"): + auth = auth or default_signer() + with fsspec.open(self.log_file, mode="r", **auth) as f: + self.log_lines = f.readlines() + else: + with open(self.log_file, encoding="utf-8") as f: + self.log_lines = f.readlines() + self.logs: List[LogRecord] = self._parse_logs() + + # Parse logs to get entities for building the report + # Agents + self.agents: List[AgentData] = self._parse_agents() + self.managers: List[AgentData] = self._parse_managers() + # Events + self.start_event: LogRecord = self._parse_start_event() + self.session_id: str = self.start_event.session_id + self.llm_calls: List[AgentInvocation] = self._parse_llm_calls() + self.tool_calls: List[AgentInvocation] = self._parse_tool_calls() + self.invocations: List[AgentInvocation] = self._parse_invocations() + + self.received_message_logs = self._parse_received_messages() + + def _parse_logs(self) -> List[LogRecord]: + """Parses the logs form strings into LogRecord objects.""" + logs = [] + for i, log in enumerate(self.log_lines): + try: + logs.append(LogRecord.from_dict(json.loads(log))) + except Exception as e: + logger.error( + "Error when parsing log record at line %s:\n%s", str(i + 1), str(e) + ) + continue + # Sort the logs by timestamp + logs = sorted(logs, key=lambda x: x.timestamp) + return logs + + def _parse_agents(self) -> List[AgentData]: + """Parses the logs to identify unique agents. + AutoGen may have new_agent multiple times. + Here we identify the agents by the unique tuple of (name, module, class). 
+ """ + new_agent_logs = self.filter_by_event(Events.NEW_AGENT) + agents = {} + for log in new_agent_logs: + agent: AgentData = log.data + agents[(agent.agent_name, agent.agent_module, agent.agent_class)] = agent + return list(agents.values()) + + def _parse_managers(self) -> List[AgentData]: + """Parses the logs to get chat managers.""" + managers = [] + for agent in self.agents: + if agent.is_manager: + managers.append(agent) + return managers + + def _parse_start_event(self) -> LogRecord: + """Parses the logs to get the first logging_session_start event log.""" + records = self.filter_by_event(event_name=Events.SESSION_START) + if not records: + raise ValueError("logging_session_start event is not found in the logs.") + records = sorted(records, key=lambda x: x.timestamp) + return records[0] + + def _parse_llm_calls(self) -> List[AgentInvocation]: + """Parses the logs to get the LLM calls.""" + records = self.filter_by_event(Events.LLM_CALL) + invocations = [] + for record in records: + log_data: LLMCompletionData = record.data + source_name = record.source_name + request = log_data.request + # If there is no request, the log is invalid. + if not request: + continue + + header = f"{source_name} invoking {request.get('model')}" + if log_data.is_cached: + header += " (Cached)" + invocations.append( + AgentInvocation( + header=header, + log=record, + duration=get_duration(log_data.start_time, log_data.end_time), + ) + ) + return invocations + + def _parse_tool_calls(self) -> List[AgentInvocation]: + """Parses the logs to get the tool calls.""" + records = self.filter_by_event(Events.TOOL_CALL) + invocations = [] + for record in records: + log_data: ToolCallData = record.data + source_name = record.source_name + invocations.append( + AgentInvocation( + log=record, + header=f"{source_name} invoking {log_data.tool_name}", + duration=get_duration(log_data.start_time, log_data.end_time), + ) + ) + return invocations + + def _parse_invocations(self) -> List[AgentInvocation]: + """Add numbering to the combined list of LLM and tool calls.""" + invocations = self.llm_calls + self.tool_calls + invocations = sorted(invocations, key=lambda x: x.log.data.start_time) + for i, invocation in enumerate(invocations): + invocation.header = f"{str(i + 1)} {invocation.header}" + return invocations + + def _parse_received_messages(self) -> List[LogRecord]: + """Parses the logs to get the received_message events.""" + managers = [manager.agent_name for manager in self.managers] + logs = self.filter_by_event(Events.RECEIVED_MESSAGE) + if not logs: + return [] + logs = sorted(logs, key=lambda x: x.timestamp) + logs = [log for log in logs if log.kwargs.get("sender") not in managers] + return logs + + def filter_by_event(self, event_name: str) -> List[LogRecord]: + """Filters the logs by event name. + + Parameters + ---------- + event_name : str + Name of the event. + + Returns + ------- + List[LogRecord] + A list of LogRecord objects for the event. 
+ """ + filtered_logs = [] + for log in self.logs: + if log.event_name == event_name: + filtered_logs.append(log) + return filtered_logs + + def _build_flowchart(self): + """Builds the flowchart of agent chats.""" + senders = [] + for log in self.received_message_logs: + sender = log.kwargs.get("sender") + senders.append(sender) + + diagram_src = "graph LR\n" + prev_sender = None + links = [] + # Conversation Flow + for sender in senders: + if prev_sender is None: + link = f"START([START]) --> {sender}" + else: + link = f"{prev_sender} --> {sender}" + if link not in links: + links.append(link) + prev_sender = sender + links.append(f"{prev_sender} --> END([END])") + # Tool Calls + for invocation in self.tool_calls: + tool = invocation.log.data.tool_name + agent = invocation.log.data.agent_name + if tool and agent: + link = f"{agent} <--> {tool}[[{tool}]]" + if link not in links: + links.append(link) + + diagram_src += "\n".join(links) + return rc.Diagram(src=diagram_src, label="Flowchart") + + def _build_timeline_tab(self): + """Builds the plotly timeline chart.""" + if not self.invocations: + return rc.Text("No LLM or Tool Calls.", label="Timeline") + invocations = [] + for invocation in self.invocations: + invocations.append( + { + "start_time": invocation.log.data.start_time, + "end_time": invocation.log.data.end_time, + "header": invocation.header, + "duration": invocation.duration, + } + ) + df = pd.DataFrame(invocations) + fig = px.timeline( + df, + x_start="start_time", + x_end="end_time", + y="header", + labels={"header": "Invocation"}, + color="duration", + color_continuous_scale="rdylgn_r", + height=max(len(df.index) * 50, 500), + ) + fig.update_layout(showlegend=False) + fig.update_yaxes(autorange="reversed") + return rc.Block( + rc.Widget(fig, label="Timeline"), self._build_flowchart(), label="Timeline" + ) + + def _format_messages(self, messages: List[dict]): + """Formats the LLM call messages to be displayed in the report.""" + text = "" + for message in messages: + text += f"**{message.get('role')}**:\n{message.get('content')}\n\n" + return text + + def _build_llm_call(self, invocation: AgentInvocation): + """Builds the LLM call details.""" + log_data: LLMCompletionData = invocation.log.data + request = log_data.request + response = log_data.response + + start_date, start_time = self._parse_date_time(log_data.start_time) + + request_value = f"{str(len(request.get('messages')))} messages" + tools = request.get("tools", []) + if tools: + request_value += f", {str(len(tools))} tools" + + response_message = response.get("choices")[0].get("message") + response_text = response_message.get("content") or "" + tool_calls = response_message.get("tool_calls") + if tool_calls: + response_text += "\n\n**Tool Calls**:" + for tool_call in tool_calls: + func = tool_call.get("function") + response_text += f"\n\n`{func.get('name')}(**{func.get('arguments')})`" + + metrics = [ + rc.Metric(heading="Time", value=start_time, label=start_date), + rc.Metric( + heading="Messages", + value=len(request.get("messages", [])), + ), + rc.Metric(heading="Tools", value=len(tools)), + rc.Metric(heading="Duration", value=invocation.duration, unit="s"), + rc.Metric( + heading="Cached", + value="Yes" if log_data.is_cached else "No", + ), + rc.Metric(heading="Cost", value=log_data.cost), + ] + + usage = response.get("usage") + if isinstance(usage, dict): + for k, v in usage.items(): + if not v: + continue + metrics.append( + rc.Metric(heading=str(k).replace("_", " ").title(), value=v) + ) + + return rc.Block( + 
rc.Block(rc.Group(*metrics, label=invocation.header)), + rc.Group( + rc.Block( + rc.Markdown( + self._format_messages(request.get("messages")), label="Request" + ), + rc.Collapse( + rc.Json(request), + label="JSON", + ), + ), + rc.Block( + rc.Markdown(response_text, label="Response"), + rc.Collapse( + rc.Json(response), + label="JSON", + ), + ), + ), + ) + + def _build_tool_call(self, invocation: AgentInvocation): + """Builds the tool call details.""" + log_data: ToolCallData = invocation.log.data + request = log_data.to_dict() + response = request.pop("returns", {}) + + start_date, start_time = self._parse_date_time(log_data.start_time) + tool_call_args = log_data.input_args + if is_json_string(tool_call_args): + tool_call_args = self.format_json_string(tool_call_args) + + if is_json_string(response): + response = self.format_json_string(response) + + metrics = [ + rc.Metric(heading="Time", value=start_time, label=start_date), + rc.Metric(heading="Duration", value=invocation.duration, unit="s"), + ] + + return rc.Block( + rc.Block(rc.Group(*metrics, label=invocation.header)), + rc.Group( + rc.Block( + rc.Markdown( + (log_data.tool_name or "") + "\n\n" + tool_call_args, + label="Request", + ), + rc.Collapse( + rc.Json(request), + label="JSON", + ), + ), + rc.Block(rc.Text("", label="Response"), rc.Markdown(response)), + ), + ) + + def _build_invocations_tab(self) -> rc.Block: + """Builds the invocations tab.""" + blocks = [] + for invocation in self.invocations: + event_name = invocation.log.event_name + if event_name == Events.LLM_CALL: + blocks.append(self._build_llm_call(invocation)) + elif event_name == Events.TOOL_CALL: + blocks.append(self._build_tool_call(invocation)) + return rc.Block( + *blocks, + label="Invocations", + ) + + def _build_chat_tab(self) -> rc.Block: + """Builds the chat tab.""" + if not self.received_message_logs: + return rc.Text("No messages received in this session.", label="Chats") + # The agent sending the first message will be placed on the right. 
+ # All other agents will be placed on the left + host = self.received_message_logs[0].kwargs.get("sender") + blocks = [] + + for log in self.received_message_logs: + context = copy.deepcopy(log.kwargs) + context.update(log.to_dict()) + sender = context.get("sender") + message = context.get("message", "") + # Content + if isinstance(message, dict) and "content" in message: + content = message.get("content", "") + if is_json_string(content): + context["json_content"] = json.dumps(json.loads(content), indent=2) + context["content"] = content + else: + context["content"] = message + if context["content"] is None: + context["content"] = "" + # Tool call + if isinstance(message, dict) and "tool_calls" in message: + tool_calls = message.get("tool_calls") + if tool_calls: + tool_call_signatures = [] + for tool_call in tool_calls: + func = tool_call.get("function") + if not func: + continue + tool_call_signatures.append( + f'{func.get("name")}(**{func.get("arguments", "{}")})' + ) + context["tool_calls"] = tool_call_signatures + if sender == host: + html = self._render_template("chat_box_rt.html", **context) + else: + html = self._render_template("chat_box_lt.html", **context) + blocks.append(rc.Html(html)) + + return rc.Block( + *blocks, + label="Chats", + ) + + def _build_logs_tab(self) -> rc.Block: + """Builds the logs tab.""" + blocks = [] + for log_line in self.log_lines: + if is_json_string(log_line): + log = json.loads(log_line) + label = log.get( + "event_name", self._preview_message(log.get("message", "")) + ) + blocks.append(rc.Collapse(rc.Json(escape_html(log)), label=label)) + else: + log = log_line + blocks.append( + rc.Collapse(rc.Text(log), label=self._preview_message(log_line)) + ) + + return rc.Block( + *blocks, + label="Logs", + ) + + def _build_errors_tab(self) -> Optional[rc.Block]: + """Builds the error tab to show exception.""" + errors = self.filter_by_event(Events.EXCEPTION) + if not errors: + return None + blocks = [] + for error in errors: + label = f'{error.kwargs.get("exc_type", "")} - {error.kwargs.get("exc_value", "")}' + variables: dict = error.kwargs.get("locals", {}) + table = "| Variable | Value |\n|---|---|\n" + table += "\n".join([f"| {k} | {v} |" for k, v in variables.items()]) + blocks += [ + rc.Unformatted(text=error.kwargs.get("traceback", ""), label=label), + rc.Markdown(table), + ] + return rc.Block(*blocks, label="Error") + + def build(self, output_file: str): + """Builds the session report. + + Parameters + ---------- + output_file : str + Local path or OCI object storage URI to save the report HTML file. 
+ """ + + if not self.managers: + agent_label = "" + elif len(self.managers) == 1: + agent_label = "+1 chat manager" + else: + agent_label = f"+{str(len(self.managers))} chat managers" + + blocks = [ + self._build_timeline_tab(), + self._build_invocations_tab(), + self._build_chat_tab(), + self._build_logs_tab(), + ] + + error_block = self._build_errors_tab() + if error_block: + blocks.append(error_block) + + with rc.ReportCreator( + title=f"AutoGen Session: {self.session_id}", + description=f"Started at {self.start_event.timestamp}", + footer="Created with ❤️ by Oracle ADS", + ) as report: + + view = rc.Block( + rc.Group( + rc.Metric( + heading="Agents", + value=len(self.agents) - len(self.managers), + label=agent_label, + ), + rc.Metric( + heading="Events", + value=len(self.logs), + ), + rc.Metric( + heading="LLM Calls", + value=len(self.llm_calls), + ), + rc.Metric( + heading="Tool Calls", + value=len(self.tool_calls), + ), + ), + rc.Select(blocks=blocks), + ) + + report.save(view, output_file) diff --git a/ads/llm/autogen/reports/templates/chat_box.html b/ads/llm/autogen/reports/templates/chat_box.html new file mode 100644 index 000000000..62d792888 --- /dev/null +++ b/ads/llm/autogen/reports/templates/chat_box.html @@ -0,0 +1,13 @@ +

{{ sender }}
to {{ source_name }}

+

{{ timestamp }}

+
+{% if json_content %} +
{{ json_content }}
+{% else %} +

{{ content }}

+{% endif %} +{% if tool_calls %} +{% for tool_call in tool_calls %} +
{{ tool_call }}
+{% endfor %} +{% endif %} \ No newline at end of file diff --git a/ads/llm/autogen/reports/templates/chat_box_lt.html b/ads/llm/autogen/reports/templates/chat_box_lt.html new file mode 100644 index 000000000..da766bb1a --- /dev/null +++ b/ads/llm/autogen/reports/templates/chat_box_lt.html @@ -0,0 +1,5 @@ +
+
+ {% include "chat_box.html" %} +
+
\ No newline at end of file diff --git a/ads/llm/autogen/reports/templates/chat_box_rt.html b/ads/llm/autogen/reports/templates/chat_box_rt.html new file mode 100644 index 000000000..126c903a0 --- /dev/null +++ b/ads/llm/autogen/reports/templates/chat_box_rt.html @@ -0,0 +1,6 @@ +
+
+ {% include "chat_box.html" %} +
+
\ No newline at end of file diff --git a/ads/llm/autogen/reports/utils.py b/ads/llm/autogen/reports/utils.py new file mode 100644 index 000000000..baaacc315 --- /dev/null +++ b/ads/llm/autogen/reports/utils.py @@ -0,0 +1,56 @@ +# Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +import html +import json +from datetime import datetime + + +def parse_datetime(s): + return datetime.strptime(s, "%Y-%m-%d %H:%M:%S.%f") + + +def get_duration(start_time: str, end_time: str) -> float: + """Gets the duration in seconds between `start_time` and `end_time`. + Each of the value should be a time in string format of + `%Y-%m-%d %H:%M:%S.%f` + + The duration is calculated by parsing the two strings, + then subtracting the `end_time` from `start_time`. + + If either `start_time` or `end_time` is not presented, + 0 will be returned. + + Parameters + ---------- + start_time : str + The start time. + end_time : str + The end time. + + Returns + ------- + float + Duration in seconds. + """ + if not start_time or not end_time: + return 0 + return (parse_datetime(end_time) - parse_datetime(start_time)).total_seconds() + + +def is_json_string(s): + """Checks if a string contains valid JSON.""" + try: + json.loads(s) + except Exception: + return False + return True + + +def escape_html(obj): + if isinstance(obj, dict): + return {k: escape_html(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [escape_html(v) for v in obj] + elif isinstance(obj, str): + return html.escape(obj) + return html.escape(str(obj)) diff --git a/ads/llm/autogen/v02/__init__.py b/ads/llm/autogen/v02/__init__.py new file mode 100644 index 000000000..83f271279 --- /dev/null +++ b/ads/llm/autogen/v02/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +from ads.llm.autogen.v02.client import LangChainModelClient, register_custom_client diff --git a/ads/llm/autogen/client_v02.py b/ads/llm/autogen/v02/client.py similarity index 92% rename from ads/llm/autogen/client_v02.py rename to ads/llm/autogen/v02/client.py index 8dd9b6c9e..10e7b02ab 100644 --- a/ads/llm/autogen/client_v02.py +++ b/ads/llm/autogen/v02/client.py @@ -1,6 +1,5 @@ -# coding: utf-8 -# Copyright (c) 2016, 2024, Oracle and/or its affiliates. All rights reserved. -# This software is dual-licensed to you under the Universal Permissive License (UPL) 1.0 as shown at https://oss.oracle.com/licenses/upl or Apache License 2.0 as shown at http://www.apache.org/licenses/LICENSE-2.0. You may choose either license. +# Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ """This module contains the custom LLM client for AutoGen v0.2 to use LangChain chat models. 
https://microsoft.github.io/autogen/0.2/blog/2024/01/26/Custom-Models/ @@ -72,14 +71,14 @@ import importlib import json import logging -from typing import Any, Dict, List, Union +from dataclasses import asdict, dataclass from types import SimpleNamespace +from typing import Any, Dict, List, Union from autogen import ModelClient from autogen.oai.client import OpenAIWrapper, PlaceHolderClient from langchain_core.messages import AIMessage - logger = logging.getLogger(__name__) # custom_clients is a dictionary mapping the name of the class to the actual class @@ -177,6 +176,13 @@ def function_call(self): return self.tool_calls +@dataclass +class Usage: + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + + class LangChainModelClient(ModelClient): """Represents a model client wrapping a LangChain chat model.""" @@ -202,8 +208,8 @@ def __init__(self, config: dict, **kwargs) -> None: # Import the LangChain class if "langchain_cls" not in config: raise ValueError("Missing langchain_cls in LangChain Model Client config.") - module_cls = config.pop("langchain_cls") - module_name, cls_name = str(module_cls).rsplit(".", 1) + self.langchain_cls = config.pop("langchain_cls") + module_name, cls_name = str(self.langchain_cls).rsplit(".", 1) langchain_module = importlib.import_module(module_name) langchain_cls = getattr(langchain_module, cls_name) @@ -232,7 +238,14 @@ def create(self, params) -> ModelClient.ModelClientResponseProtocol: streaming = params.get("stream", False) # TODO: num_of_responses num_of_responses = params.get("n", 1) - messages = params.pop("messages", []) + + messages = copy.deepcopy(params.get("messages", [])) + + # OCI Gen AI does not allow empty message. + if str(self.langchain_cls).endswith("oci_generative_ai.ChatOCIGenAI"): + for message in messages: + if len(message.get("content", "")) == 0: + message["content"] = " " invoke_params = copy.deepcopy(self.invoke_params) @@ -241,7 +254,6 @@ def create(self, params) -> ModelClient.ModelClientResponseProtocol: model = self.model.bind_tools( [_convert_to_langchain_tool(tool) for tool in tools] ) - # invoke_params["tools"] = tools invoke_params.update(self.function_call_params) else: model = self.model @@ -249,6 +261,7 @@ def create(self, params) -> ModelClient.ModelClientResponseProtocol: response = SimpleNamespace() response.choices = [] response.model = self.model_name + response.usage = Usage() if streaming and messages: # If streaming is enabled and has messages, then iterate over the chunks of the response. @@ -279,4 +292,4 @@ def cost(self, response: ModelClient.ModelClientResponseProtocol) -> float: @staticmethod def get_usage(response: ModelClient.ModelClientResponseProtocol) -> Dict: """Return usage summary of the response using RESPONSE_USAGE_KEYS.""" - return {} + return asdict(response.usage) diff --git a/ads/llm/autogen/v02/log_handlers/__init__.py b/ads/llm/autogen/v02/log_handlers/__init__.py new file mode 100644 index 000000000..72e03c615 --- /dev/null +++ b/ads/llm/autogen/v02/log_handlers/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ diff --git a/ads/llm/autogen/v02/log_handlers/oci_file_handler.py b/ads/llm/autogen/v02/log_handlers/oci_file_handler.py new file mode 100644 index 000000000..0a1713749 --- /dev/null +++ b/ads/llm/autogen/v02/log_handlers/oci_file_handler.py @@ -0,0 +1,83 @@ +# Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +import io +import json +import logging +import os +import threading + +import fsspec + +from ads.common.auth import default_signer + +logger = logging.getLogger(__name__) + + +class OCIFileHandler(logging.FileHandler): + """Log handler for saving log file to OCI object storage.""" + + def __init__( + self, + filename: str, + session_id: str, + mode: str = "a", + encoding: str | None = None, + delay: bool = False, + errors: str | None = None, + auth: dict | None = None, + ) -> None: + self.session_id = session_id + self.auth = auth + + if filename.startswith("oci://"): + self.baseFilename = filename + else: + self.baseFilename = os.path.abspath(os.path.expanduser(filename)) + os.makedirs(os.path.dirname(self.baseFilename), exist_ok=True) + + # The following code are from the `FileHandler.__init__()` + self.mode = mode + self.encoding = encoding + if "b" not in mode: + self.encoding = io.text_encoding(encoding) + self.errors = errors + self.delay = delay + + if delay: + # We don't open the stream, but we still need to call the + # Handler constructor to set level, formatter, lock etc. + logging.Handler.__init__(self) + self.stream = None + else: + logging.StreamHandler.__init__(self, self._open()) + + def _open(self): + """ + Open the current base file with the (original) mode and encoding. + Return the resulting stream. + """ + auth = self.auth or default_signer() + return fsspec.open( + self.baseFilename, + self.mode, + encoding=self.encoding, + errors=self.errors, + **auth, + ).open() + + def format(self, record: logging.LogRecord): + """Formats the log record as JSON payload and add session_id.""" + msg = record.getMessage() + try: + data = json.loads(msg) + except Exception as e: + data = {"message": msg} + + if "session_id" not in data: + data["session_id"] = self.session_id + if "thread_id" not in data: + data["thread_id"] = threading.get_ident() + + record.msg = json.dumps(data) + return super().format(record) + diff --git a/ads/llm/autogen/v02/loggers/__init__.py b/ads/llm/autogen/v02/loggers/__init__.py new file mode 100644 index 000000000..15635dc09 --- /dev/null +++ b/ads/llm/autogen/v02/loggers/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python +# Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +from ads.llm.autogen.v02.loggers.metric_logger import MetricLogger +from ads.llm.autogen.v02.loggers.session_logger import SessionLogger diff --git a/ads/llm/autogen/v02/loggers/metric_logger.py b/ads/llm/autogen/v02/loggers/metric_logger.py new file mode 100644 index 000000000..886089568 --- /dev/null +++ b/ads/llm/autogen/v02/loggers/metric_logger.py @@ -0,0 +1,320 @@ +#!/usr/bin/env python +# Copyright (c) 2024 Oracle and/or its affiliates. 
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +import logging +from datetime import datetime +from typing import Any, Dict, List, Optional, Union +from uuid import UUID + +import oci +from autogen import Agent, ConversableAgent, OpenAIWrapper +from autogen.logger.base_logger import BaseLogger, LLMConfig +from autogen.logger.logger_utils import get_current_ts +from oci.monitoring import MonitoringClient +from pydantic import BaseModel, Field + +import ads +import ads.config +from ads.llm.autogen.v02.loggers.utils import serialize_response + +logger = logging.getLogger(__name__) + + +class MetricName: + """Constants for metric name.""" + + TOOL_CALL = "tool_call" + CHAT_COMPLETION = "chat_completion_count" + COST = "chat_completion_cost" + SESSION_START = "session_start" + SESSION_STOP = "session_stop" + + +class MetricDimension: + """Constants for metric dimension.""" + + AGENT_NAME = "agent_name" + APP_NAME = "app_name" + MODEL = "model" + SESSION_ID = "session_id" + TOOL_NAME = "tool_name" + + +class Metric(BaseModel): + """Represents the metric to be logged.""" + + name: str + value: float + timestamp: str + dimensions: dict = Field(default_factory=dict) + + +class MetricLogger(BaseLogger): + """AutoGen logger for agent metrics.""" + + def __init__( + self, + namespace: str, + app_name: Optional[str] = None, + compartment_id: Optional[str] = None, + session_id: Optional[str] = None, + region: Optional[str] = None, + resource_group: Optional[str] = None, + log_agent_name: bool = False, + log_tool_name: bool = False, + log_model_name: bool = False, + ): + """Initialize the metric logger. + + Parameters + ---------- + namespace : str + Namespace for posting the metric + app_name : str + Application name, which will be a metric dimension if specified. + compartment_id : str, optional + Compartment OCID for posting the metric. + If compartment_id is not specified, + ADS will try to fetch the compartment OCID from environment variable. + session_id : str, optional + Session ID to be saved as a metric dimension, by default None. + region : str, optional + OCI region for posting the metric, by default None. + If region is not specified, the region from the authentication signer will be used. + resource_group : str, optional + Resource group for the metric, by default None + log_agent_name : bool, optional + Whether to log agent name as a metric dimension, by default True. + log_tool_name : bool, optional + Whether to log tool name as a metric dimension, by default True. + log_model_name : bool, optional + Whether to log model name as a metric dimension, by default True. + + """ + self.app_name = app_name + self.session_id = session_id + self.compartment_id = compartment_id or ads.config.COMPARTMENT_OCID + if not self.compartment_id: + raise ValueError( + "Unable to determine compartment OCID for metric logger." + "Please specify the compartment_id." + ) + self.namespace = namespace + self.resource_group = resource_group + self.log_agent_name = log_agent_name + self.log_tool_name = log_tool_name + self.log_model_name = log_model_name + # Indicate if the logger has started. + self.started = False + + auth = ads.auth.default_signer() + + # Use the config/signer to determine the region if it not specified. 
+ signer = auth.get("signer") + config = auth.get("config", {}) + if not region: + if hasattr(signer, "region") and signer.region: + region = signer.region + elif config.get("region"): + region = config.get("region") + else: + raise ValueError( + "Unable to determine the region for OCI monitoring service. " + "Please specify the region using the `region` argument." + ) + + self.monitoring_client = MonitoringClient( + config=config, + signer=signer, + # Metrics should be submitted with the "telemetry-ingestion" endpoint instead. + # See note here: https://docs.oracle.com/iaas/api/#/en/monitoring/20180401/MetricData/PostMetricData + service_endpoint=f"https://telemetry-ingestion.{region}.oraclecloud.com", + ) + + def _post_metric(self, metric: Metric): + """Posts metric to OCI monitoring.""" + # Add app_name and session_id to dimensions + dimensions = metric.dimensions + if self.app_name: + dimensions[MetricDimension.APP_NAME] = self.app_name + if self.session_id: + dimensions[MetricDimension.SESSION_ID] = self.session_id + + logger.debug("Posting metrics:\n%s", str(metric)) + self.monitoring_client.post_metric_data( + post_metric_data_details=oci.monitoring.models.PostMetricDataDetails( + metric_data=[ + oci.monitoring.models.MetricDataDetails( + namespace=self.namespace, + compartment_id=self.compartment_id, + name=metric.name, + dimensions=dimensions, + datapoints=[ + oci.monitoring.models.Datapoint( + timestamp=datetime.strptime( + metric.timestamp.replace(" ", "T") + "Z", + "%Y-%m-%dT%H:%M:%S.%fZ", + ), + value=metric.value, + count=1, + ) + ], + resource_group=self.resource_group, + ) + ], + batch_atomicity="ATOMIC", + ), + ) + + def start(self): + """Starts the logger.""" + if self.session_id: + logger.info(f"Starting metric logging for session_id: {self.session_id}") + else: + logger.info("Starting metric logging.") + self.started = True + try: + metric = Metric( + name=MetricName.SESSION_START, + value=1, + timestamp=get_current_ts(), + ) + self._post_metric(metric=metric) + except Exception as e: + logger.error(f"MetricLogger Failed to log session start: {str(e)}") + return self.session_id + + def log_new_agent( + self, agent: ConversableAgent, init_args: Dict[str, Any] = {} + ) -> None: + """Metric logger does not log new agent.""" + pass + + def log_function_use( + self, + source: Union[str, Agent], + function: Any, + args: Dict[str, Any], + returns: Any, + ) -> None: + """ + Log a registered function(can be a tool) use from an agent or a string source. + """ + if not self.started: + return + agent_name = str(source.name) if hasattr(source, "name") else source + dimensions = {} + if self.log_tool_name: + dimensions[MetricDimension.TOOL_NAME] = function.__name__ + if self.log_agent_name: + dimensions[MetricDimension.AGENT_NAME] = agent_name + try: + self._post_metric( + Metric( + name=MetricName.TOOL_CALL, + value=1, + timestamp=get_current_ts(), + dimensions=dimensions, + ) + ) + except Exception as e: + logger.error(f"MetricLogger Failed to log tool call: {str(e)}") + + def log_chat_completion( + self, + invocation_id: UUID, + client_id: int, + wrapper_id: int, + source: Union[str, Agent], + request: Dict[str, Union[float, str, List[Dict[str, str]]]], + response: Union[str, Any], + is_cached: int, + cost: float, + start_time: str, + ) -> None: + """ + Log a chat completion. 
+ """ + if not self.started: + return + + try: + response: dict = serialize_response(response) + if "usage" not in response or not isinstance(response["usage"], dict): + return + # Post usage metric + agent_name = str(source.name) if hasattr(source, "name") else source + model = response.get("model", "N/A") + dimensions = {} + if self.log_model_name: + dimensions[MetricDimension.MODEL] = model + if self.log_agent_name: + dimensions[MetricDimension.AGENT_NAME] = agent_name + + # Chat completion count + self._post_metric( + Metric( + name=MetricName.CHAT_COMPLETION, + value=1, + timestamp=get_current_ts(), + dimensions=dimensions, + ) + ) + # Cost + if cost: + self._post_metric( + Metric( + name=MetricName.COST, + value=cost, + timestamp=get_current_ts(), + dimensions=dimensions, + ) + ) + # Usage + for key, val in response["usage"].items(): + self._post_metric( + Metric( + name=key, + value=val, + timestamp=get_current_ts(), + dimensions=dimensions, + ) + ) + + except Exception as e: + logger.error(f"MetricLogger Failed to log chat completion: {str(e)}") + + def log_new_wrapper( + self, + wrapper: OpenAIWrapper, + init_args: Dict[str, Union[LLMConfig, List[LLMConfig]]] = {}, + ) -> None: + """Metric logger does not log new wrapper.""" + pass + + def log_new_client(self, client, wrapper, init_args): + """Metric logger does not log new client.""" + pass + + def log_event(self, source, name, **kwargs): + """Metric logger does not log events.""" + pass + + def get_connection(self): + pass + + def stop(self): + """Stops the metric logger.""" + if not self.started: + return + self.started = False + try: + metric = Metric( + name=MetricName.SESSION_STOP, + value=1, + timestamp=get_current_ts(), + ) + self._post_metric(metric=metric) + except Exception as e: + logger.error(f"MetricLogger Failed to log session stop: {str(e)}") + logger.info("Metric logger stopped.") diff --git a/ads/llm/autogen/v02/loggers/session_logger.py b/ads/llm/autogen/v02/loggers/session_logger.py new file mode 100644 index 000000000..9ed21982f --- /dev/null +++ b/ads/llm/autogen/v02/loggers/session_logger.py @@ -0,0 +1,580 @@ +# Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +import importlib +import logging +import os +import tempfile +import threading +import traceback +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, List, Optional, Union +from urllib.parse import urlparse + +import autogen +import fsspec +import oci +from autogen import Agent, ConversableAgent, GroupChatManager, OpenAIWrapper +from autogen.logger.file_logger import ( + ChatCompletion, + F, + FileLogger, + get_current_ts, + safe_serialize, +) +from oci.object_storage import ObjectStorageClient +from oci.object_storage.models import ( + CreatePreauthenticatedRequestDetails, + PreauthenticatedRequest, +) + +import ads +from ads.common.auth import default_signer +from ads.llm.autogen.constants import Events +from ads.llm.autogen.reports.data import ( + AgentData, + LLMCompletionData, + LogRecord, + ToolCallData, +) +from ads.llm.autogen.reports.session import SessionReport +from ads.llm.autogen.v02 import runtime_logging +from ads.llm.autogen.v02.log_handlers.oci_file_handler import OCIFileHandler +from ads.llm.autogen.v02.loggers.utils import ( + serialize, + serialize_response, +) + +logger = logging.getLogger(__name__) + + +CONST_REPLY_FUNC_NAME = "reply_func_name" + + +@dataclass +class LoggingSession: + """Represents a logging session for a specific thread.""" + + session_id: str + log_dir: str + log_file: str + thread_id: int + pid: int + logger: logging.Logger + auth: dict = field(default_factory=dict) + report_file: Optional[str] = None + par_uri: Optional[str] = None + + @property + def report(self) -> str: + """HTML report path of the logging session. + If the a pre-authenticated link is generated for the report, + the pre-authenticated link will be returned. + + If the report is saved to OCI object storage, the URI will be return. + If the report is saved locally, the local path will be return. + If there is no report generated, `None` will be returned. + """ + if self.par_uri: + return self.par_uri + elif self.report_file: + return self.report_file + return None + + def __repr__(self) -> str: + """Shows the link to report if it is available, otherwise shows the log file link.""" + if self.report: + return self.report + return self.log_file + + def create_par_uri(self, oci_file: str, **kwargs) -> str: + """Creates a pre-authenticated request URI for a file on OCI object storage. + + Parameters + ---------- + oci_file : str + OCI file URI in the format of oci://bucket@namespace/prefix/to/file + auth : dict, optional + Dictionary containing the OCI authentication config and signer. + Defaults to `ads.common.auth.default_signer()`. 
+ + Returns + ------- + str + The pre-authenticated URI + """ + auth = self.auth or default_signer() + client = ObjectStorageClient(**auth) + parsed = urlparse(oci_file) + bucket = parsed.username + namespace = parsed.hostname + time_expires = kwargs.pop( + "time_expires", datetime.now(timezone.utc) + timedelta(weeks=1) + ) + access_type = kwargs.pop("access_type", "ObjectRead") + response: PreauthenticatedRequest = client.create_preauthenticated_request( + bucket_name=bucket, + namespace_name=namespace, + create_preauthenticated_request_details=CreatePreauthenticatedRequestDetails( + name=os.path.basename(oci_file), + object_name=str(parsed.path).lstrip("/"), + access_type=access_type, + time_expires=time_expires, + **kwargs, + ), + ).data + return response.full_path + + def create_report( + self, report_file: str, return_par_uri: bool = False, **kwargs + ) -> str: + """Creates a report in HTML format. + + Parameters + ---------- + report_file : str + The file path to save the report. + return_par_uri : bool, optional + If the report is saved in object storage, + whether to create a pre-authenticated link for the report, by default False. + This will be ignored if the report is not saved in object storage. + + Returns + ------- + str + The full path or pre-authenticated link of the report. + """ + auth = self.auth or default_signer() + report = SessionReport(log_file=self.log_file, auth=auth) + if report_file.startswith("oci://"): + with tempfile.TemporaryDirectory() as temp_dir: + # Save the report to local temp dir + temp_report = os.path.join(temp_dir, os.path.basename(report_file)) + report.build(temp_report) + # Upload to OCI object storage + fs = fsspec.filesystem("oci", **auth) + fs.put(temp_report, report_file) + if return_par_uri: + par_uri = self.create_par_uri(oci_file=report_file, **kwargs) + self.report_file = report_file + self.par_uri = par_uri + return par_uri + else: + report_file = os.path.abspath(os.path.expanduser(report_file)) + os.makedirs(os.path.dirname(report_file), exist_ok=True) + report.build(report_file) + self.report_file = report_file + return report_file + + +class SessionLogger(FileLogger): + """Logger for saving log file to OCI object storage.""" + + def __init__( + self, + log_dir: str, + report_dir: Optional[str] = None, + session_id: Optional[str] = None, + auth: Optional[dict] = None, + log_for_all_threads: str = False, + report_par_uri: bool = False, + par_kwargs: Optional[dict] = None, + ): + """Initialize a file logger for new session. + + Parameters + ---------- + log_dir : str + Directory for saving the log file. + session_id : str, optional + Session ID, by default None. + If the session ID is None, a new UUID4 will be generated. + The session ID will be used as the log filename. + auth: dict, optional + Dictionary containing the OCI authentication config and signer. + If auth is None, `ads.common.auth.default_signer()` will be used. + log_for_all_threads: + Indicate if the logger should handle logging for all threads. + Defaults to False, the logger will only log for the current thread. 
+ """ + self.report_dir = report_dir + self.report_par_uri = report_par_uri + self.par_kwargs = par_kwargs + self.log_for_all_threads = log_for_all_threads + + self.session = self.new_session( + log_dir=log_dir, session_id=session_id, auth=auth + ) + # Log only if started is True + self.started = False + + # Keep track of last check_termination_and_human_reply for calculating tool call duration + # This will be a dictionary mapping the IDs of the agents to their last timestamp + # of check_termination_and_human_reply + self.last_agent_checks = {} + + @property + def logger(self) -> Optional[logging.Logger]: + """Logger for the thread. + + This property is used to determine whether the log should be saved. + No log will be saved if the logger is None. + """ + if not self.started: + return None + thread_id = threading.get_ident() + if not self.log_for_all_threads and thread_id != self.session.thread_id: + return None + return self.session.logger + + @property + def session_id(self) -> Optional[str]: + """Session ID for the current session.""" + return self.session.session_id + + @property + def log_file(self) -> Optional[str]: + """Log file path for the current session.""" + return self.session.log_file + + @property + def report(self) -> Optional[str]: + """Report path/link for the session.""" + return self.session.report + + @property + def name(self) -> str: + """Name of the logger.""" + return self.session_id or "oci_file_logger" + + def new_session( + self, + log_dir: str, + session_id: Optional[str] = None, + auth: Optional[dict] = None, + ) -> LoggingSession: + """Creates a new logging session. + + Parameters + ---------- + log_dir : str + Directory for saving the log file. + session_id : str, optional + Session ID, by default None. + If the session ID is None, a new UUID4 will be generated. + The session ID will be used as the log filename. + auth: dict, optional + Dictionary containing the OCI authentication config and signer. + If auth is None, `ads.common.auth.default_signer()` will be used. + + + Returns + ------- + LoggingSession + The new logging session + """ + thread_id = threading.get_ident() + + session_id = str(session_id or uuid.uuid4()) + log_file = os.path.join(log_dir, f"{session_id}.log") + + # Prepare the logger + session_logger = logging.getLogger(session_id) + session_logger.setLevel(logging.INFO) + file_handler = OCIFileHandler(log_file, session_id=session_id, auth=auth) + session_logger.addHandler(file_handler) + + # Create logging session + session = LoggingSession( + session_id=session_id, + log_dir=log_dir, + log_file=log_file, + thread_id=thread_id, + pid=os.getpid(), + logger=session_logger, + auth=auth, + ) + + logger.info("Start logging session %s to file %s", session_id, log_file) + return session + + def generate_report( + self, + report_dir: Optional[str] = None, + report_par_uri: Optional[bool] = None, + **kwargs, + ) -> str: + """Generates a report for the session. + + Parameters + ---------- + report_dir : str, optional + Directory for saving the report, by default None + report_par_uri : bool, optional + Whether to create a pre-authenticated link for the report, by default None. + If the `report_par_uri` is not set, the value of `self.report_par_uri` will be used. + + Returns + ------- + str + The link to the report. + If the `report_dir` is local, the local file path will be returned. + If a pre-authenticated link is created, the link will be returned. 
+ """ + report_dir = report_dir or self.report_dir + report_par_uri = ( + report_par_uri if report_par_uri is not None else self.report_par_uri + ) + kwargs = kwargs or self.par_kwargs or {} + + report_file = os.path.join(self.report_dir, f"{self.session_id}.html") + report_link = self.session.create_report( + report_file=report_file, return_par_uri=self.report_par_uri, **kwargs + ) + print(f"ADS AutoGen Session Report: {report_link}") + return report_link + + def new_record(self, event_name: str, source: Any = None) -> LogRecord: + """Initialize a new log record. + + The record is not logged until `self.log()` is called. + """ + record = LogRecord( + session_id=self.session_id, + thread_id=threading.get_ident(), + timestamp=get_current_ts(), + event_name=event_name, + ) + if source: + record.source_id = id(source) + record.source_name = str(source.name) if hasattr(source, "name") else source + return record + + def log(self, record: LogRecord) -> None: + """Logs a record. + + Parameters + ---------- + data : dict + Data to be logged. + """ + # Do nothing if there is no logger for the thread. + if not self.logger: + return + + try: + self.logger.info(record.to_string()) + except Exception: + self.logger.info("Failed to log %s", record.event_name) + + def start(self) -> str: + """Start the logging session and return the session_id.""" + envs = { + "oracle-ads": ads.__version__, + "oci": oci.__version__, + "autogen": autogen.__version__, + } + libraries = [ + "langchain", + "langchain-core", + "langchain-community", + "langchain-openai", + "openai", + ] + for library in libraries: + try: + imported_library = importlib.import_module(library) + version = imported_library.__version__ + envs[library] = version + except Exception: + pass + self.started = True + self.log_event(source=self, name=Events.SESSION_START, environment=envs) + return self.session_id + + def stop(self) -> None: + """Stops the logging session.""" + self.log_event(source=self, name=Events.SESSION_STOP) + super().stop() + self.started = False + if self.report_dir: + try: + self.generate_report() + except Exception as e: + logger.error( + "Failed to create session report for AutoGen session %s\n%s", + self.session_id, + str(e), + ) + logger.debug(traceback.format_exc()) + + def log_chat_completion( + self, + invocation_id: uuid.UUID, + client_id: int, + wrapper_id: int, + source: Union[str, Agent], + request: Dict[str, Union[float, str, List[Dict[str, str]]]], + response: Union[str, ChatCompletion], + is_cached: int, + cost: float, + start_time: str, + ) -> None: + """ + Logs a chat completion. + """ + if not self.logger: + return + + record = self.new_record(event_name=Events.LLM_CALL, source=source) + record.data = LLMCompletionData( + invocation_id=str(invocation_id), + request=serialize(request), + response=serialize_response(response), + start_time=start_time, + end_time=get_current_ts(), + cost=cost, + is_cached=is_cached, + ) + record.kwargs = { + "client_id": client_id, + "wrapper_id": wrapper_id, + } + + self.log(record) + + def log_function_use( + self, source: Union[str, Agent], function: F, args: Dict[str, Any], returns: Any + ) -> None: + """ + Logs a registered function(can be a tool) use from an agent or a string source. 
+ """ + if not self.logger: + return + + source_id = id(source) + if source_id in self.last_agent_checks: + start_time = self.last_agent_checks[source_id] + else: + start_time = get_current_ts() + + record = self.new_record(Events.TOOL_CALL, source=source) + record.data = ToolCallData( + tool_name=function.__name__, + start_time=start_time, + end_time=record.timestamp, + agent_name=str(source.name) if hasattr(source, "name") else source, + agent_module=source.__module__, + agent_class=source.__class__.__name__, + input_args=safe_serialize(args), + returns=safe_serialize(returns), + ) + + self.log(record) + + def log_new_agent( + self, agent: ConversableAgent, init_args: Dict[str, Any] = {} + ) -> None: + """ + Logs a new agent instance. + """ + if not self.logger: + return + + record = self.new_record(event_name=Events.NEW_AGENT, source=agent) + record.data = AgentData( + agent_name=( + agent.name + if hasattr(agent, "name") and agent.name is not None + else str(agent) + ), + agent_module=agent.__module__, + agent_class=agent.__class__.__name__, + is_manager=isinstance(agent, GroupChatManager), + ) + record.kwargs = { + "wrapper_id": serialize( + agent.client.wrapper_id + if hasattr(agent, "client") and agent.client is not None + else "" + ), + "args": serialize(init_args), + } + self.log(record) + + def log_event( + self, source: Union[str, Agent], name: str, **kwargs: Dict[str, Any] + ) -> None: + """ + Logs an event. + """ + record = self.new_record(event_name=name) + record.source_id = id(source) + record.source_name = str(source.name) if hasattr(source, "name") else source + record.kwargs = kwargs + if isinstance(source, Agent): + if ( + CONST_REPLY_FUNC_NAME in kwargs + and kwargs[CONST_REPLY_FUNC_NAME] == "check_termination_and_human_reply" + ): + self.last_agent_checks[record.source_id] = record.timestamp + record.data = AgentData( + agent_name=record.source_name, + agent_module=source.__module__, + agent_class=source.__class__.__name__, + is_manager=isinstance(source, GroupChatManager), + ) + self.log(record) + + def log_new_wrapper(self, *args, **kwargs) -> None: + # Do not log new wrapper. + # This is not used at the moment. + return + + def log_new_client( + self, + client, + wrapper: OpenAIWrapper, + init_args: Dict[str, Any], + ) -> None: + if not self.logger: + return + + record = self.new_record(event_name=Events.NEW_CLIENT) + # init_args may contain credentials like api_key + record.kwargs = { + "client_id": id(client), + "wrapper_id": id(wrapper), + "class": client.__class__.__name__, + "args": serialize(init_args), + } + + self.log(record) + + def __repr__(self) -> str: + return self.session.__repr__() + + def __enter__(self) -> "SessionLogger": + """Starts the session logger + + Returns + ------- + SessionLogger + The session logger + """ + runtime_logging.start(self) + return self + + def __exit__(self, exc_type, exc_value, tb): + """Stops the session logger.""" + if exc_type: + record = self.new_record(event_name=Events.EXCEPTION) + record.kwargs = { + "exc_type": exc_type.__name__, + "exc_value": str(exc_value), + "traceback": "".join(traceback.format_tb(tb)), + "locals": serialize(tb.tb_frame.f_locals), + } + self.log(record) + runtime_logging.stop(self) diff --git a/ads/llm/autogen/v02/loggers/utils.py b/ads/llm/autogen/v02/loggers/utils.py new file mode 100644 index 000000000..247e11e4b --- /dev/null +++ b/ads/llm/autogen/v02/loggers/utils.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python +# Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
+import inspect
+import json
+from types import SimpleNamespace
+from typing import Any, Dict, List, Tuple, Union
+
+
+def is_json_serializable(obj: Any) -> bool:
+    """Checks if an object is JSON serializable.
+
+    Parameters
+    ----------
+    obj : Any
+        Any object.
+
+    Returns
+    -------
+    bool
+        True if the object is JSON serializable, otherwise False.
+    """
+    try:
+        json.dumps(obj)
+    except Exception:
+        return False
+    return True
+
+
+def serialize_response(response) -> dict:
+    """Serializes the LLM response to a dictionary."""
+    if isinstance(response, SimpleNamespace) or is_json_serializable(response):
+        # Convert SimpleNamespace to dict
+        return json.loads(json.dumps(response, default=vars))
+    elif hasattr(response, "dict") and callable(response.dict):
+        return json.loads(json.dumps(response.dict(), default=str))
+    elif hasattr(response, "model") and hasattr(response, "choices"):
+        return {
+            "model": response.model,
+            "choices": [
+                {"message": {"content": choice.message.content}}
+                for choice in response.choices
+            ],
+            "response": str(response),
+        }
+    return {
+        "model": "",
+        "choices": [{"message": {"content": response}}],
+        "response": str(response),
+    }
+
+
+def serialize(
+    obj: Union[int, float, str, bool, Dict[Any, Any], List[Any], Tuple[Any, ...], Any],
+    exclude: Tuple[str, ...] = ("api_key", "__class__"),
+    no_recursive: Tuple[Any, ...] = (),
+) -> Any:
+    """Serializes an object for logging purposes."""
+    try:
+        if isinstance(obj, (int, float, str, bool)):
+            return obj
+        elif callable(obj):
+            return inspect.getsource(obj).strip()
+        elif isinstance(obj, dict):
+            return {
+                str(k): (
+                    serialize(str(v))
+                    if isinstance(v, no_recursive)
+                    else serialize(v, exclude, no_recursive)
+                )
+                for k, v in obj.items()
+                if k not in exclude
+            }
+        elif isinstance(obj, (list, tuple)):
+            return [
+                (
+                    serialize(str(v))
+                    if isinstance(v, no_recursive)
+                    else serialize(v, exclude, no_recursive)
+                )
+                for v in obj
+            ]
+        else:
+            return str(obj)
+    except Exception:
+        return str(obj)
diff --git a/ads/llm/autogen/v02/runtime_logging.py b/ads/llm/autogen/v02/runtime_logging.py
new file mode 100644
index 000000000..7d65bdb12
--- /dev/null
+++ b/ads/llm/autogen/v02/runtime_logging.py
@@ -0,0 +1,163 @@
+# Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
+import logging
+import traceback
+from sqlite3 import Connection
+from typing import Any, Dict, List, Optional
+
+import autogen.runtime_logging
+from autogen.logger.base_logger import BaseLogger
+from autogen.logger.logger_factory import LoggerFactory
+
+logger = logging.getLogger(__name__)
+
+
+class LoggerManager(BaseLogger):
+    """Manages multiple AutoGen loggers."""
+
+    def __init__(self) -> None:
+        self.loggers: List[BaseLogger] = []
+        super().__init__()
+
+    def add_logger(self, logger: BaseLogger) -> None:
+        """Adds a new AutoGen logger."""
+        self.loggers.append(logger)
+
+    def _call_loggers(self, method: str, *args, **kwargs) -> None:
+        """Calls the specific method on each AutoGen logger in self.loggers."""
+        for autogen_logger in self.loggers:
+            try:
+                getattr(autogen_logger, method)(*args, **kwargs)
+            except Exception as e:
+                # Catch the logging exception so that the program will not be interrupted.
+                logger.error(
+                    "Failed to %s with %s: %s",
+                    method,
+                    autogen_logger.__class__.__name__,
+                    str(e),
+                )
+                logger.debug(traceback.format_exc())
+
+    def start(self) -> str:
+        """Starts all loggers."""
+        return self._call_loggers("start")
+
+    def stop(self) -> None:
+        """Stops all loggers."""
+        self._call_loggers("stop")
+        # Remove the loggers once they are stopped.
+        self.loggers = []
+
+    def get_connection(self) -> None | Connection:
+        return self._call_loggers("get_connection")
+
+    def log_chat_completion(self, *args, **kwargs) -> None:
+        return self._call_loggers("log_chat_completion", *args, **kwargs)
+
+    def log_new_agent(self, *args, **kwargs) -> None:
+        return self._call_loggers("log_new_agent", *args, **kwargs)
+
+    def log_event(self, *args, **kwargs) -> None:
+        return self._call_loggers("log_event", *args, **kwargs)
+
+    def log_new_wrapper(self, *args, **kwargs) -> None:
+        return self._call_loggers("log_new_wrapper", *args, **kwargs)
+
+    def log_new_client(self, *args, **kwargs) -> None:
+        return self._call_loggers("log_new_client", *args, **kwargs)
+
+    def log_function_use(self, *args, **kwargs) -> None:
+        return self._call_loggers("log_function_use", *args, **kwargs)
+
+    def __repr__(self) -> str:
+        return "\n\n".join(
+            [
+                f"{str(logger.__class__)}:\n{logger.__repr__()}"
+                for logger in self.loggers
+            ]
+        )
+
+
+def start(
+    autogen_logger: Optional[BaseLogger] = None,
+    logger_type: Optional[str] = None,
+    config: Optional[Dict[str, Any]] = None,
+) -> str:
+    """Starts logging with an AutoGen logger.
+    Specify your custom autogen_logger, or the logger_type and config to use a built-in logger.
+
+    Parameters
+    ----------
+    autogen_logger : BaseLogger, optional
+        An AutoGen logger, which should be a subclass of autogen.logger.base_logger.BaseLogger.
+    logger_type : str, optional
+        Logger type, which can be a built-in AutoGen logger type ("file" or "sqlite"), by default None.
+    config : dict, optional
+        Configurations for the built-in AutoGen logger, by default None.
+
+    Returns
+    -------
+    str
+        A unique session ID returned from starting the logger.
+    """
+    if autogen_logger and logger_type:
+        raise ValueError(
+            f"Please specify only autogen_logger ({autogen_logger}) "
+            f"or logger_type ({logger_type}), not both."
+        )
+
+    # Check if a logger is already configured
+    existing_logger = autogen.runtime_logging.autogen_logger
+    if not existing_logger:
+        # No logger is configured
+        logger_manager = LoggerManager()
+    elif isinstance(existing_logger, LoggerManager):
+        # Logger is already configured with ADS
+        logger_manager = existing_logger
+    else:
+        # Logger is configured but it is not via ADS
+        logger.warning("AutoGen is already configured with %s", str(existing_logger))
+        logger_manager = LoggerManager()
+        logger_manager.add_logger(existing_logger)
+
+    # Add AutoGen logger
+    if not autogen_logger:
+        autogen_logger = LoggerFactory.get_logger(
+            logger_type=logger_type, config=config
+        )
+    logger_manager.add_logger(autogen_logger)
+
+    session_id = None
+    try:
+        session_id = autogen_logger.start()
+        autogen.runtime_logging.is_logging = True
+        autogen.runtime_logging.autogen_logger = logger_manager
+    except Exception as e:
+        logger.error(f"Failed to start logging: {e}")
+    return session_id
+
+
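+# Usage sketch: start logging with a custom AutoGen logger and stop it when the
+# application finishes. Here "my_logger" is a placeholder for any BaseLogger
+# subclass instance, such as a session or metric logger from
+# ads.llm.autogen.v02.loggers.
+#
+#     from ads.llm.autogen.v02 import runtime_logging
+#
+#     session_id = runtime_logging.start(autogen_logger=my_logger)
+#     # ... create and run the AutoGen application ...
+#     runtime_logging.stop()
+
+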
+ """ + autogen_logger = autogen.runtime_logging.autogen_logger + if isinstance(autogen_logger, LoggerManager) and loggers: + for logger in loggers: + logger.stop() + if logger in autogen_logger.loggers: + autogen_logger.loggers.remove(logger) + else: + autogen.runtime_logging.stop() + return autogen_logger + + +def get_loggers() -> List[BaseLogger]: + """Gets a list of existing AutoGen loggers.""" + autogen_logger = autogen.runtime_logging.autogen_logger + if isinstance(autogen_logger, LoggerManager): + return autogen_logger.loggers + return [autogen_logger] diff --git a/docs/source/user_guide/large_language_model/autogen_integration.rst b/docs/source/user_guide/large_language_model/autogen_integration.rst index e21a8bd3e..7c8c17055 100644 --- a/docs/source/user_guide/large_language_model/autogen_integration.rst +++ b/docs/source/user_guide/large_language_model/autogen_integration.rst @@ -104,3 +104,81 @@ Following is an example LLM config for the OCI Generative AI service: }, } +Logging And Reporting +===================== + +ADS offers enhanced utilities integrating with OCI to log data for debugging and analysis: +* The ``SessionLogger`` saves events to a log file and generates report to for you to profile and debug the application. +* The ``MetricLogger`` sends the metrics to OCI monitoring service, allowing you to build dashboards to gain more insights about the application usage. + +Session Logger and Report +------------------------- + +To use the session logger, you need to specify a local directory or an OCI object storage location for saving the log files. +A unique session ID will be generated for each session. Each session will be logged into one file. +Optionally, you can specify the ``report_dir`` to generate a report at the end of each session. +If you are using an object storage location as ``report_dir``, you can also have a pre-authenticated link generated automatically for viewing and sharing the report. + +.. code-block:: python3 + + from ads.llm.autogen.v02.loggers import SessionLogger + + session_logger = SessionLogger( + # log_dir can be local dir or OCI object storage location in the form of oci://bucket@namespace/prefix + log_dir="", + # Location for saving the report. Can be local path or object storage location. + report_dir="", + # Specify session ID if you would like to resume a previous session or use your own session ID. + session_id=session_id, + # Set report_par_uri to True when using object storage to auto-generate PAR link. + report_par_uri=True, + ) + + # You may get the auto-generated session id once the logger is initialized + print(session_logger.session_id) + + # It is recommended to run your application with the context manager. + with session_logger: + # Create and run your AutoGen application + ... + + # Access the log file path + print(session_logger.log_file) + + # Report file path or pre-authenticated link + print(session_logger.report) + +The session report provides a comprehensive overview of the timeline, invocations, chat interactions, and logs in HTML format. It effectively visualizes the application's flow, facilitating efficient debugging and analysis. + +.. figure:: figures/autogen_report.png + :width: 800 + +Metric Logger +------------- +The agent metric logger emits agent metrics to `OCI Monitoring `_, +allowing you to integrate AutoGen application with OCI monitoring service to `build queries `_ and `dashboards `_, as well as `managing alarms `_. + +.. 
+Metric Logger
+-------------
+
+The agent metric logger emits agent metrics to `OCI Monitoring `_,
+allowing you to integrate your AutoGen application with the OCI Monitoring service to
+`build queries `_ and `dashboards `_, as well as `managing alarms `_.
+
+.. code-block:: python3
+
+    from ads.llm.autogen.v02 import runtime_logging
+    from ads.llm.autogen.v02.loggers import MetricLogger
+
+    monitoring_logger = MetricLogger(
+        # Metric namespace required by OCI monitoring.
+        namespace="",
+        # Optional application name, which will be a metric dimension if specified.
+        app_name="order_support",
+        # Compartment OCID for posting the metric.
+        compartment_id="",
+        # Optional session ID to be saved as a metric dimension.
+        session_id="",
+        # Whether to log agent name as a metric dimension.
+        log_agent_name=False,
+        # Whether to log model name as a metric dimension.
+        log_model_name=False,
+        # Whether to log tool name as a metric dimension.
+        log_tool_name=False,
+    )
+    # Start logging metrics
+    runtime_logging.start(monitoring_logger)
+
diff --git a/docs/source/user_guide/large_language_model/figures/autogen_report.png b/docs/source/user_guide/large_language_model/figures/autogen_report.png
new file mode 100644
index 000000000..6100eee01
Binary files /dev/null and b/docs/source/user_guide/large_language_model/figures/autogen_report.png differ
diff --git a/tests/unitary/with_extras/autogen/test_autogen_client.py b/tests/unitary/with_extras/autogen/test_autogen_client.py
index c8cce9121..57700ac9a 100644
--- a/tests/unitary/with_extras/autogen/test_autogen_client.py
+++ b/tests/unitary/with_extras/autogen/test_autogen_client.py
@@ -11,7 +11,7 @@
 import autogen
 from langchain_core.messages import AIMessage, ToolCall
 
-from ads.llm.autogen.client_v02 import (
+from ads.llm.autogen.v02.client import (
     LangChainModelClient,
     register_custom_client,
     custom_clients,