Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
HowieG committed Aug 13, 2024
1 parent 25fa488 commit a411e45
Showing 1 changed file with 50 additions and 3 deletions.
53 changes: 50 additions & 3 deletions agentops/time_travel.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,56 @@ def fetch_prompt_override_from_time_travel_cache(kwargs):
return

if TimeTravel()._prompt_override_map:
search_prompt = str({"messages": kwargs["messages"]})
result_from_cache = TimeTravel()._prompt_override_map.get(search_prompt)
return json.loads(result_from_cache)
from openai.types.chat.chat_completion_user_message_param import (
ChatCompletionUserMessageParam,
)

messages = kwargs["messages"]

from typing import List

def validate_and_parse_messages(
messages: List[dict],
) -> List[ChatCompletionUserMessageParam]:
parsed_messages = []
for message in messages:
try:
parsed_message = ChatCompletionUserMessageParam(
content=message["content"],
role=message["role"],
name=message.get("name", ""),
)
parsed_messages.append(parsed_message)
except KeyError as e:
raise ValueError(f"Missing required field in message: {e}")
return parsed_messages

# Validate and parse the messages
messages = validate_and_parse_messages(messages)

parsed_prompts = []
for key in TimeTravel()._prompt_override_map.keys():
try:
prompt_messages = json.loads(key).get("messages", [])
parsed_messages = validate_and_parse_messages(prompt_messages)
parsed_prompts.append(parsed_messages)
except (json.JSONDecodeError, ValueError) as e:
print(f"Error parsing messages for key {key}: {e}")

def compare_messages(messages, parsed_messages):
if len(messages) != len(parsed_messages):
return False
for i in range(len(messages)):
if messages[i]["content"] != parsed_messages[i]["content"]:
return False
return True

for key, parsed_prompt in zip(
TimeTravel()._prompt_override_map.keys(), parsed_prompts
):
if compare_messages(messages, parsed_prompt):
result_from_cache = TimeTravel()._prompt_override_map[key]
return json.loads(result_from_cache)


def check_time_travel_active():
Expand Down

0 comments on commit a411e45

Please sign in to comment.