From a411e45a8879007655946a0c3b022dbab70a5cd5 Mon Sep 17 00:00:00 2001 From: Howard Gil Date: Tue, 13 Aug 2024 01:02:16 -0700 Subject: [PATCH] WIP --- agentops/time_travel.py | 53 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 50 insertions(+), 3 deletions(-) diff --git a/agentops/time_travel.py b/agentops/time_travel.py index e466198d..8f364656 100644 --- a/agentops/time_travel.py +++ b/agentops/time_travel.py @@ -72,9 +72,56 @@ def fetch_prompt_override_from_time_travel_cache(kwargs): return if TimeTravel()._prompt_override_map: - search_prompt = str({"messages": kwargs["messages"]}) - result_from_cache = TimeTravel()._prompt_override_map.get(search_prompt) - return json.loads(result_from_cache) + from openai.types.chat.chat_completion_user_message_param import ( + ChatCompletionUserMessageParam, + ) + + messages = kwargs["messages"] + + from typing import List + + def validate_and_parse_messages( + messages: List[dict], + ) -> List[ChatCompletionUserMessageParam]: + parsed_messages = [] + for message in messages: + try: + parsed_message = ChatCompletionUserMessageParam( + content=message["content"], + role=message["role"], + name=message.get("name", ""), + ) + parsed_messages.append(parsed_message) + except KeyError as e: + raise ValueError(f"Missing required field in message: {e}") + return parsed_messages + + # Validate and parse the messages + messages = validate_and_parse_messages(messages) + + parsed_prompts = [] + for key in TimeTravel()._prompt_override_map.keys(): + try: + prompt_messages = json.loads(key).get("messages", []) + parsed_messages = validate_and_parse_messages(prompt_messages) + parsed_prompts.append(parsed_messages) + except (json.JSONDecodeError, ValueError) as e: + print(f"Error parsing messages for key {key}: {e}") + + def compare_messages(messages, parsed_messages): + if len(messages) != len(parsed_messages): + return False + for i in range(len(messages)): + if messages[i]["content"] != parsed_messages[i]["content"]: + return False + return True + + for key, parsed_prompt in zip( + TimeTravel()._prompt_override_map.keys(), parsed_prompts + ): + if compare_messages(messages, parsed_prompt): + result_from_cache = TimeTravel()._prompt_override_map[key] + return json.loads(result_from_cache) def check_time_travel_active():