Skip to content

Commit

Permalink
Add reflection + validation steps in get_feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
wongjingping committed Mar 25, 2024
1 parent 2d72b3c commit ccabaad
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 24 deletions.
4 changes: 2 additions & 2 deletions defog/admin_methods.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Dict, Optional
from typing import Dict, List, Optional
import requests
import pandas as pd

Expand Down Expand Up @@ -142,7 +142,7 @@ def get_quota(self) -> Optional[Dict]:

def update_golden_queries(
self,
golden_queries: dict = None,
golden_queries: List[Dict] = None,
golden_queries_path: str = None,
scrub: bool = True,
):
Expand Down
4 changes: 3 additions & 1 deletion defog/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,9 @@ def query():
print(resp)

print()
get_feedback(df.api_key, df.db_type, user_question, sql_generated)
get_feedback(
df.api_key, df.db_type, user_question, sql_generated, df.base_url
)
query = prompt("Please enter another query, or type 'e' to exit: ")


Expand Down
167 changes: 161 additions & 6 deletions defog/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,9 @@ def identify_categorical_columns(
return rows


def get_feedback(api_key: str, db_type: str, user_question: str, sql_generated: str):
def get_feedback(
api_key: str, db_type: str, user_question: str, sql_generated: str, base_url: str
):
"""
Get feedback from the user on whether the query was good or bad, and why.
"""
Expand Down Expand Up @@ -162,19 +164,172 @@ def get_feedback(api_key: str, db_type: str, user_question: str, sql_generated:
if feedback_text != "":
data["feedback_text"] = feedback_text
requests.post(
"https://api.defog.ai/feedback",
f"{base_url}/feedback",
json=data,
timeout=1,
)
if feedback == "y":
print("Thank you for the feedback!")
else:
elif feedback == "n":
data = {
"api_key": api_key,
"question": user_question,
"sql_generated": sql_generated,
"error": feedback_text,
}
print(
"Thank you for the feedback! We retrain our models every week, and you should see much better performance on these kinds of queries in another week.\n"
"Thank you for the feedback, let us see how can we improve this for you...\n"
)
print(
f"If you continue to get these errors, please consider updating the metadata in your schema by editing the CSV generated and running `defog update <url>`, or by updating your glossary.\n"
response = requests.post(
f"{base_url}/reflect_on_error",
json=data,
)

if response.status_code == 200:
response_dict = response.json()
feedback = response_dict.get("feedback")
if feedback:
print(f"Here is our automated assessment:\n{feedback}\n")
# 1) validate and update glossary
instruction_set = response_dict.get("instruction_set")
if instruction_set:
print(
f"We came up with the following additions for improving your glossary:\n{instruction_set}"
)
add_to_glossary = prompt(
"If you would like to add these suggestions to your glossary, please enter 'y'. If you would like to amend it, just type in your edits and hit enter. Otherwise, enter 'n'.\n"
)
if add_to_glossary == "y":
md_resp = requests.post(
f"{base_url}/get_metadata",
json={"api_key": api_key},
)
md_resp_dict = md_resp.json()
glossary_current = md_resp_dict.get("glossary", "")
glossary_updated = glossary_current + "\n" + instruction_set
requests.post(
f"{base_url}/update_glossary",
json={
"api_key": api_key,
"glossary": glossary_updated,
},
)
print("Glossary updated successfully.\n")
elif add_to_glossary != "n":
md_resp = requests.post(
f"{base_url}/get_metadata",
json={"api_key": api_key},
)
md_resp_dict = md_resp.json()
glossary_current = md_resp_dict.get("glossary", "")
glossary_updated = glossary_current + "\n" + add_to_glossary
requests.post(
f"{base_url}/update_glossary",
json={
"api_key": api_key,
"glossary": glossary_updated,
},
)
print("Glossary updated successfully.\n")
else:
print("Glossary not updated.\n")

# 2) validate and update column descriptions in metadata
new_column_descriptions = response_dict.get("column_descriptions")
if new_column_descriptions:
print(
f"We came up with the following suggestions for improving your column descriptions:\n{new_column_descriptions}"
)
# get original metadata
r = requests.post(
f"{base_url}/get_metadata",
json={"api_key": api_key},
)
resp = r.json()
md = resp.get("table_metadata", {})
# we will be editing md in place
column_changed = False
for new_column_description in new_column_descriptions:
table_name = new_column_description.get("table_name")
column_name = new_column_description.get("column_name")
description = new_column_description.get("description")
if table_name in md:
for column in md[table_name]:
if column.get("column_name") == column_name:
print(
f"\nCurrent description for {column_name}: {column.get('column_description')}"
)
print(
f"Suggested description for {column_name}: {description}"
)
replace = prompt(
"Would you like to replace this description with our suggestion? Please enter 'y' to replace, or your own description to amend. Otherwise, enter 'n' to skip.\n"
)
if replace == "y":
column["column_description"] = description
print("Updated description.")
column_changed = True
break
elif replace != "n":
column["column_description"] = replace
print("Updated description.")
column_changed = True
break
else:
print("Description not updated.")
break
if column_changed:
requests.post(
f"{base_url}/update_metadata",
json={
"api_key": api_key,
"table_metadata": new_column_description,
"db_type": db_type,
},
)
print("Metadata updated successfully.\n")
else:
print("No metadata changes to update.\n")
# 3) validate and update reference_queries
new_reference_queries = response_dict.get("reference_queries")
if (
isinstance(new_reference_queries, list)
and len(new_reference_queries) > 0
):
reference_queries_to_add = []
print(
f"We came up with the following suggestions for adding as your reference queries:"
)
for new_reference_query in new_reference_queries:
question = new_reference_query.get("question")
sql = new_reference_query.get("sql")
print(f"Question: {question}\nSQL: {sql}")
update_reference_queries = prompt(
"Would you like to add this as one of your reference queries? Please hit 'y' to add, or anything else to skip to the next suggestion.\n"
)
if update_reference_queries == "y":
reference_queries_to_add.append(new_reference_query)
if len(reference_queries_to_add) > 0:
r = requests.post(
f"{base_url}/update_golden_queries",
json={
"api_key": api_key,
"golden_queries": reference_queries_to_add,
"scrub": True,
},
)
if r.status_code == 200:
print(
f"{len(reference_queries_to_add)} reference queries added successfully."
)
else:
print("Reference queries not updated.")
else:
print("No reference queries to update.")
print()
else:
print("There was an error in getting suggestions. Our apologies!")

except Exception as e:
write_logs(f"Error in get_feedback:\n{e}")
pass
40 changes: 25 additions & 15 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,39 +109,49 @@ class TestGetFeedback(unittest.TestCase):
@patch("defog.util.prompt", return_value="y")
@patch("requests.post")
def test_positive_feedback(self, mock_post, mock_prompt):
get_feedback("api_key", "db_type", "user_question", "sql_generated")
mock_post.assert_called_once()
get_feedback("api_key", "db_type", "user_question", "sql_generated", "base_url")
assert mock_post.call_count == 1
self.assertIn("good", mock_post.call_args.kwargs["json"]["feedback"])
self.assertNotIn("feedback_text", mock_post.call_args.kwargs["json"])

@patch("defog.util.prompt", side_effect=["n", "bad query"])
@patch("defog.util.prompt", side_effect=["n", "bad query", "", "", ""])
@patch("requests.post")
def test_negative_feedback_with_text(self, mock_post, mock_prompt):
get_feedback("api_key", "db_type", "user_question", "sql_generated")
mock_post.assert_called_once()
self.assertIn("bad", mock_post.call_args.kwargs["json"]["feedback"])
self.assertIn("bad query", mock_post.call_args.kwargs["json"]["feedback_text"])
get_feedback("api_key", "db_type", "user_question", "sql_generated", "base_url")
# 2 calls: 1 to /feedback, 1 to /reflect_on_error
assert mock_post.call_count == 2
self.assertIn("api_key", mock_post.call_args.kwargs["json"]["api_key"])
self.assertIn("user_question", mock_post.call_args.kwargs["json"]["question"])
self.assertIn(
"sql_generated", mock_post.call_args.kwargs["json"]["sql_generated"]
)
self.assertIn("bad query", mock_post.call_args.kwargs["json"]["error"])

@patch("defog.util.prompt", side_effect=["n", ""])
@patch("defog.util.prompt", side_effect=["n", "", "", "", ""])
@patch("requests.post")
def test_negative_feedback_without_text(self, mock_post, mock_prompt):
get_feedback("api_key", "db_type", "user_question", "sql_generated")
mock_post.assert_called_once()
self.assertIn("bad", mock_post.call_args.kwargs["json"]["feedback"])
self.assertNotIn("feedback_text", mock_post.call_args.kwargs["json"])
get_feedback("api_key", "db_type", "user_question", "sql_generated", "base_url")
# 2 calls: 1 to /feedback, 1 to /reflect_on_error
assert mock_post.call_count == 2
self.assertIn("api_key", mock_post.call_args.kwargs["json"]["api_key"])
self.assertIn("user_question", mock_post.call_args.kwargs["json"]["question"])
self.assertIn(
"sql_generated", mock_post.call_args.kwargs["json"]["sql_generated"]
)
self.assertIn("", mock_post.call_args.kwargs["json"]["error"])

@patch("defog.util.prompt", side_effect=["invalid", "y"])
@patch("requests.post")
def test_invalid_then_valid_input(self, mock_post, mock_prompt):
get_feedback("api_key", "db_type", "user_question", "sql_generated")
mock_post.assert_called_once()
get_feedback("api_key", "db_type", "user_question", "sql_generated", "base_url")
assert mock_post.call_count == 1
self.assertIn("good", mock_post.call_args.kwargs["json"]["feedback"])
self.assertNotIn("feedback_text", mock_post.call_args.kwargs["json"])

@patch("defog.util.prompt", side_effect=[""])
@patch("requests.post")
def test_skip_input(self, mock_post, mock_prompt):
get_feedback("api_key", "db_type", "user_question", "sql_generated")
get_feedback("api_key", "db_type", "user_question", "sql_generated", "base_url")
mock_post.assert_not_called()


Expand Down

0 comments on commit ccabaad

Please sign in to comment.