Skip to content

Commit

Permalink
add crud for golden queries in cli
Browse files Browse the repository at this point in the history
  • Loading branch information
wongjingping authored and rishsriv committed Mar 6, 2024
1 parent 56a5865 commit 5e7a621
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 8 deletions.
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,8 @@ ipython_config.py
.python-version

*.DS_Store
local_test.py
local_test.py

defog_metadata.csv
golden_queries.csv
golden_queries.json
17 changes: 12 additions & 5 deletions defog/admin_methods.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import requests
import pandas as pd

Expand Down Expand Up @@ -177,6 +178,7 @@ def delete_golden_queries(
},
)
resp = r.json()
print("All golden queries have now been deleted.")
else:
if golden_queries is None:
golden_queries = (
Expand All @@ -191,7 +193,6 @@ def delete_golden_queries(
},
)
resp = r.json()
print("All golden queries have now been deleted.")
return resp


Expand All @@ -204,13 +205,19 @@ def get_golden_queries(self, format="csv", export_path=None):
json={"api_key": self.api_key},
)
resp = r.json()
golden_queries = resp["golden_queries"]
if format == "csv":
if export_path is None:
export_path = "golden_queries.csv"
pd.DataFrame(resp["golden_queries"]).to_csv(export_path, index=False)
print(f"Golden queries exported to {export_path}")
return True
pd.DataFrame(golden_queries).to_csv(export_path, index=False)
print(f"{len(golden_queries)} golden queries exported to {export_path}")
return golden_queries
elif format == "json":
return resp["golden_queries"]
if export_path is None:
export_path = "golden_queries.json"
with open(export_path, "w") as f:
json.dump(resp, f, indent=4)
print(f"{len(golden_queries)} golden queries exported to {export_path}")
return golden_queries
else:
raise ValueError("format must be either 'csv' or 'json'.")
99 changes: 97 additions & 2 deletions defog/cli.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import datetime
import decimal
import pwinput
import json
import os
import pandas as pd
import pwinput
import re
import shutil
import subprocess
import sys
import requests
import time

import defog
Expand Down Expand Up @@ -43,6 +43,8 @@ def main():
update()
elif sys.argv[1] == "query":
query()
elif sys.argv[1] == "golden":
golden()
elif sys.argv[1] == "deploy":
deploy()
elif sys.argv[1] == "quota":
Expand Down Expand Up @@ -400,6 +402,99 @@ def query():
query = prompt("Please enter another query, or type 'e' to exit: ")


def golden():
"""
Allow the user to get, add, or delete golden queries.
These are used as references during query generation (akin to k-shot learning),
and can significantly improve the performance of the query generation model.
"""
# get action from sys.argv or prompt
if len(sys.argv) < 3:
print(
"defog golden requires an action. Please enter 'get', 'add', or 'delete':"
)
action = prompt().strip().lower()
else:
action = sys.argv[2].lower()
while action not in ["get", "add", "delete", "exit"]:
print("Please enter 'get', 'add', 'delete', or 'exit' to exit:")
action = prompt().strip().lower()

dfg = defog.Defog()
if action == "get":
# get format from sys.argv or prompt
if len(sys.argv) < 4:
print(
"defog golden get requires an export format. Please enter 'json' or 'csv':"
)
format = prompt().strip().lower()
else:
format = sys.argv[3].lower()
while format not in ["json", "csv"]:
print("Please enter 'json' or 'csv':")
format = prompt().strip().lower()
# make request to get golden queries
print("Getting golden queries...")
golden_queries = dfg.get_golden_queries(format)
if golden_queries:
print(f"\nYou have {len(golden_queries)} golden queries:\n")
for pair in golden_queries:
question = pair["question"]
sql = pair["sql"]
print(f"Question:\n{question}\nSQL:\n{sql}\n")
elif action == "add":
# get path from sys.argv or prompt
if len(sys.argv) < 4:
print(
"defog golden add requires a path to a JSON or CSV file containing golden queries:"
)
path = prompt().strip()
else:
path = sys.argv[3]
while (
not os.path.exists(path)
and not path.endswith(".json")
and not path.endswith(".csv")
):
print("File not found. Please enter a valid path:")
path = prompt().strip()
# if path ends with json, read in json and pass in golden_queries
if path.endswith(".json"):
with open(path, "r") as f:
golden_queries = json.load(f)
dfg.update_golden_queries(golden_queries=golden_queries)
# if path ends with csv, pass in csv file
elif path.endswith(".csv"):
dfg.update_golden_queries(golden_queries_path=path)
elif action == "delete":
# get path from sys.argv or prompt
if len(sys.argv) < 4:
print(
"defog golden delete requires either a path to a JSON or CSV file containing golden queries for deletion, or 'all' to delete all golden queries."
)
path = prompt().strip()
else:
path = sys.argv[3]
while (
not os.path.exists(path)
and not path.endswith(".json")
and not path.endswith(".csv")
and path != "all"
):
print(
"File not found. Please enter a valid path, or 'all' to delete all golden queries:"
)
path = prompt().strip()
if path == "all":
dfg.delete_golden_queries(all=True)
elif path.endswith(".json"):
with open(path, "r") as f:
golden_queries = json.load(f)
dfg.delete_golden_queries(golden_queries=golden_queries)
elif path.endswith(".csv"):
dfg.delete_golden_queries(golden_queries_path=path)


def deploy():
"""
Deploy a cloud function that can be used to run queries.
Expand Down

0 comments on commit 5e7a621

Please sign in to comment.