From 9bdf776eac914a0dcad932d91d7740b580a4f6e4 Mon Sep 17 00:00:00 2001 From: Matthew Harris Date: Wed, 17 Jul 2024 17:10:43 -0400 Subject: [PATCH] Added an LLM recipe validation function to recipe manager as well as a script that will automatically generate recipes which are then LLM judged. This is rough work to provide pointers for doing in an agentic flow, but don't merge this PR as-is --- docker-compose.yml | 1 + management/code_gen.py | 58 +++++++++++++++++ management/recipe_sync.py | 85 +++++++++++++++++++++---- templates/validate_recipe_prompt.jinja2 | 23 +++++++ 4 files changed, 156 insertions(+), 11 deletions(-) create mode 100644 management/code_gen.py create mode 100644 templates/validate_recipe_prompt.jinja2 diff --git a/docker-compose.yml b/docker-compose.yml index ad7569ed..b389860d 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -139,6 +139,7 @@ services: - ./utils:/app/utils - ./templates:/app/templates - ./db/recipedb:/app/db + - ./tests:/app/tests volumes: pgdata2: shared-data: \ No newline at end of file diff --git a/management/code_gen.py b/management/code_gen.py new file mode 100644 index 00000000..64d6fec0 --- /dev/null +++ b/management/code_gen.py @@ -0,0 +1,58 @@ +import json +import os +import readline +import shutil +import sys + +import pandas as pd +from dotenv import load_dotenv +from recipe_sync import create_new_recipe, llm_validate_recipe + +load_dotenv() + +input_data = "./tests/humanitarian_user_inputs_short.csv" +work_dir = "./work/checked_out" + +env_cmd = " python " +author = "matt" + +data = pd.read_csv(input_data) + +user_inputs = data["user_input"] + +# +# This code will read an input file of user questions, +# automatically generate recipes and have an LLM review the output +# +# + + +results = [] + +for input in user_inputs[0:3]: + print(input) + + input = input + " /nochecks" + + create_new_recipe(input, author) + print("\n\n") + + # Find most recent directory by timestamp in ./management/work + dirs = 
os.listdir(work_dir) + dirs = sorted(dirs, key=lambda x: os.path.getmtime(f"{work_dir}/{x}"), reverse=True) + recent_dir = work_dir + "/" + dirs[0] + "/recipe.py" + + validation_result = llm_validate_recipe(input, recent_dir) + + r = { + "input": input, + "validation_result": validation_result["answer"], + "validation_reason": validation_result["reason"], + } + + results.append(r) + + print("\n\n") + +results = pd.DataFrame(results) +results.to_csv("results.csv") diff --git a/management/recipe_sync.py b/management/recipe_sync.py index e67b7836..97d3c0f5 100644 --- a/management/recipe_sync.py +++ b/management/recipe_sync.py @@ -294,7 +294,6 @@ def extract_code_sections(recipe_path): raise ValueError( f"Code separator '{code_separator}' not found in the recipe file '{recipe_path}'." ) - sys.exit() content = content.split("\n") @@ -319,7 +318,6 @@ def extract_code_sections(recipe_path): raise ValueError( f"Function code or calling code not found in the recipe file '{recipe_path}'." ) - sys.exit() return { "function_code": function_code, @@ -1020,7 +1018,6 @@ def create_new_recipe(recipe_intent, recipe_author): print("Running recipe to capture errors for LLM ...") result = run_recipe(recipe_path) - print(result.stderr) # If there was an error, call edit recipe to try and fix it one round if result.returncode != 0: @@ -1101,6 +1098,41 @@ def llm_edit_recipe(recipe_path, llm_prompt, recipe_author): print("\n\nRecipe editing done") +def llm_validate_recipe(user_input, recipe_path): + + recipe_folder = os.path.dirname(recipe_path) + + with open(recipe_path, "r") as file: + recipe_code = file.read() + + metadata_path = os.path.join(recipe_folder, "metadata.json") + with open(metadata_path, "r") as file: + metadata = json.load(file) + + result_type = metadata["sample_result_type"] + result = metadata["sample_result"] + + validation_prompt = environment.get_template("validate_recipe_prompt.jinja2") + prompt = validation_prompt.render( + user_input=user_input, 
recipe_code=recipe_code, recipe_result=result + ) + + if len(prompt.split(" ")) > 8000: + return { + "answer": "error", + "user_input": user_input, + "reason": "Prompt too long, please shorten recipe code or result", + } + + if result_type == "image": + llm_result = call_llm("", prompt, image=result) + else: + llm_result = call_llm("", prompt) + + print(llm_result) + return llm_result + + def update_metadata_file_results(recipe_folder, output): """ Update the metadata file for a given recipe folder with the provided result. @@ -1118,6 +1150,8 @@ def update_metadata_file_results(recipe_folder, output): with open(metadata_path, "r") as file: metadata = json.load(file) + print(output) + if output["result"]["type"] == "image": png_file = output["result"]["file"] @@ -1304,18 +1338,26 @@ def validate_output(output): # Remove any lines with DEBUG in them output = re.sub(r"DEBUG.*\n", "", output) + error = None + try: output = json.loads(output) print("JSON output parsed successfully") + # Now check for required fields + for f in required_output_json_fields: + if f not in output: + error = f"Output of recipe must contain field {f}" + print(error) + if "type" not in output["result"]: + error = 'Output of recipe must contain field "type" in output["result"]' + print(error) except json.JSONDecodeError: print("Output: \n\n") print(output) - raise ValueError("Output of recipe must be JSON") + error = "Output of recipe must be JSON" + print(error) - # Now check for required fields - for f in required_output_json_fields: - if f not in output: - raise ValueError(f"Output of recipe must contain field {f}") + return error def run_recipe(recipe_path): @@ -1339,8 +1381,23 @@ def run_recipe(recipe_path): if output_start_string in result.stdout: output = result.stdout.split(output_start_string)[1] # output is JSON - validate_output(output) - output = json.loads(output) + error = validate_output(output) + if error is None: + output = json.loads(output) + + # Check for required fields + 
required_output_json_fields = ["result"] + for f in required_output_json_fields: + if f not in output: + error = f"Output of recipe must contain field {f}" + print(error) + result.stderr += f"{error}" + result.returncode = 1 + break + + else: + result.stderr += f"{error}" + result.returncode = 1 else: error_str = "ERROR: Output of recipe must contain 'OUTPUT:'" print(error_str) @@ -1636,6 +1693,9 @@ def main(): group.add_argument( "--edit_recipe", action="store_true", help="Create a new blank recipe" ) + group.add_argument( + "--validate_recipe", action="store_true", help="Validate a recipe using LLM" + ) group.add_argument( "--info", action="store_true", help="Get information about the data available" ) @@ -1673,7 +1733,7 @@ def main(): elif args.check_in: check_in(args.recipe_author) elif args.create_recipe: - recipe_intent = args.recipe_intent.lower().replace(" ", "_") + recipe_intent = args.recipe_intent.replace(" ", "_").lower() create_new_recipe(recipe_intent, args.recipe_author) elif args.delete_recipe: delete_recipe(args.recipe_custom_id) @@ -1683,6 +1743,9 @@ def main(): save_as_memory(args.recipe_path) elif args.edit_recipe: llm_edit_recipe(args.recipe_path, args.llm_prompt, args.recipe_author) + elif args.validate_recipe: + recipe_intent = args.recipe_intent + llm_validate_recipe(recipe_intent, args.recipe_path) elif args.rebuild: rebuild(args.recipe_author) elif args.dump_db: diff --git a/templates/validate_recipe_prompt.jinja2 b/templates/validate_recipe_prompt.jinja2 new file mode 100644 index 00000000..e4aaebca --- /dev/null +++ b/templates/validate_recipe_prompt.jinja2 @@ -0,0 +1,23 @@ +{# templates/validate_recipe_prompt.jinja2 #} + +The user requested this: + +{{ user_input }} + +The recipe code is: + +{{ recipe_code }} + +The recipe output is: + +{{ recipe_result }} + +Did the recipe output match the user request? 
+ +Provide your answer as a valid JSON string in the following format: + +{ + "answer": "", + "reason": "", + "user_input": "" +} \ No newline at end of file