From 7fb081a929c32d170cd61b79f30b4a68c3ca31f3 Mon Sep 17 00:00:00 2001 From: jp Date: Thu, 11 Jan 2024 14:22:16 +0800 Subject: [PATCH] Fix feedback logging - query.py: we exclude connection errors from logging - cli.py: refactor feedback solicitation after query generation to use a separate function `util.get_feedback` - remove an unused function - add tests --- defog/__init__.py | 12 ----- defog/cli.py | 123 ++++++++++++-------------------------------- defog/query.py | 27 +++++++--- defog/util.py | 48 +++++++++++++++++ tests/test_query.py | 33 +++++++++++- tests/test_util.py | 40 +++++++++++++- 6 files changed, 170 insertions(+), 113 deletions(-) diff --git a/defog/__init__.py b/defog/__init__.py index fb68267..35a7bc0 100644 --- a/defog/__init__.py +++ b/defog/__init__.py @@ -1005,18 +1005,6 @@ def get_predefined_queries(self): else: return [] - def execute_predefined_query(self, query): - """ - Executes a predefined query - """ - resp = execute_query( - query["query"], - self.api_key, - self.db_type, - self.db_creds, - ) - return resp - def update_db_schema_csv(self, path_to_csv): """ Update the DB schema via a CSV, rather than by via a Google Sheet diff --git a/defog/cli.py b/defog/cli.py index 5c42aea..40cdba4 100644 --- a/defog/cli.py +++ b/defog/cli.py @@ -10,8 +10,7 @@ import requests import defog -from defog import Defog -from defog.util import parse_update +from defog.util import get_feedback, parse_update from prompt_toolkit import prompt USAGE_STRING = """ @@ -327,8 +326,6 @@ def query(): else: query = sys.argv[2] - feedback_mode = False - feedback = "" user_question = "" sql_generated = "" while True: @@ -338,100 +335,44 @@ def query(): elif query == "": print("Your query cannot be empty.") query = prompt("Enter a query, or type 'e' to exit: ") - if feedback_mode: - if feedback not in ["y", "n"]: - pass - elif feedback == "y": - # send data to /feedback endpoint - try: - requests.post( - "https://api.defog.ai/feedback", - json={ - "api_key": df.api_key, - "feedback": "good", - "db_type": df.db_type, - "question": user_question, - "query": sql_generated, - }, - timeout=1, - ) - print("Thank you for the feedback!") - feedback_mode = False - except: - pass - - elif feedback == "n": - # send data to /feedback endpoint - feedback_text = prompt( - "Could you tell us why this was a bad query? This will help us improve the model for you. Just hit enter if you want to leave this blank.\n" - ) - try: - requests.post( - "https://api.defog.ai/feedback", - json={ - "api_key": df.api_key, - "feedback": "bad", - "text": feedback_text, - "db_type": df.db_type, - "question": user_question, - "query": sql_generated, - }, - timeout=1, - ) - except: - pass + user_question = query + resp = df.run_query(query, retries=3) + if not resp["ran_successfully"]: + if "query_generated" in resp: + print("Defog generated the following query to answer your question:\n") + print(f"\033[1m{resp['query_generated']}\033[0m\n") + print( - "Thank you for the feedback! We retrain our models every week, and you should see much better performance on these kinds of queries in another week.\n" + f"However, your query did not run successfully. The error message generated while running the query on your database was\n\n\033[1m{resp['error_message']}\033[0m\n." ) + print( f"If you continue to get these errors, please consider updating the metadata in your schema by editing the google sheet generated and running `defog update `, or by updating your glossary.\n" ) - feedback_mode = False - query = prompt("Enter another query, or type 'e' to exit: ") - else: - user_question = query - resp = df.run_query(query, retries=3) - if not resp["ran_successfully"]: - if "query_generated" in resp: - print( - "Defog generated the following query to answer your question:\n" - ) - print(f"\033[1m{resp['query_generated']}\033[0m\n") - - print( - f"However, your query did not run successfully. The error message generated while running the query on your database was\n\n\033[1m{resp['error_message']}\033[0m\n." - ) - - print( - f"If you continue to get these errors, please consider updating the metadata in your schema by editing the google sheet generated and running `defog update `, or by updating your glossary.\n" - ) - else: - print( - f"Defog was unable to generate a query for your question. The error message generated while running the query on your database was\n\n\033[1m{resp.get('error_message')}\033[0m\n." - ) - query = prompt("Enter another query, or type 'e' to exit: ") else: - sql_generated = resp.get("query_generated") - print("Defog generated the following query to answer your question:\n") - print(f"\033[1m{resp['query_generated']}\033[0m\n") - reason_for_query = resp.get("reason_for_query", "") - reason_for_query = reason_for_query.replace(". ", "\n\n") - - print("This was its reasoning for generating this query:\n") - print(f"\033[1m{reason_for_query}\033[0m\n") - - print("Results:\n") - # print results in tabular format using 'columns' and 'data' keys - try: - print_table(resp["columns"], resp["data"]) - except: - print(resp) - - print() - feedback_mode = True - feedback = prompt( - "Did Defog answer your question well? Just hit enter to skip (y/n):\n" + print( + f"Defog was unable to generate a query for your question. The error message generated while running the query on your database was\n\n\033[1m{resp.get('error_message')}\033[0m\n." ) + else: + sql_generated = resp.get("query_generated") + print("Defog generated the following query to answer your question:\n") + print(f"\033[1m{resp['query_generated']}\033[0m\n") + reason_for_query = resp.get("reason_for_query", "") + reason_for_query = reason_for_query.replace(". ", "\n\n") + + print("This was its reasoning for generating this query:\n") + print(f"\033[1m{reason_for_query}\033[0m\n") + + print("Results:\n") + # print results in tabular format using 'columns' and 'data' keys + try: + print_table(resp["columns"], resp["data"]) + except: + print(resp) + + print() + get_feedback(df.api_key, df.db_type, user_question, sql_generated) + query = prompt("Please enter another query, or type 'e' to exit: ") def deploy(): diff --git a/defog/query.py b/defog/query.py index 725f338..6043ede 100644 --- a/defog/query.py +++ b/defog/query.py @@ -1,4 +1,5 @@ import json +import re import requests from defog.util import write_logs @@ -139,18 +140,22 @@ def execute_query( retries: int = 3, schema: dict = None, ): + """ + Execute the query and retry with adaptive learning if there is an error. + Raises an Exception if there are no retries left, or if the error is a connection error. + """ err_msg = None try: return execute_query_once(db_type, db_creds, query) + (query,) except Exception as e: err_msg = str(e) - print( - "There was an error when running the previous query. Retrying with adaptive learning..." - ) - - # log this error to our feedback system + if is_connection_error(err_msg): + raise Exception( + f"There was a connection issue to your database:\n{err_msg}\n\nPlease check your database credentials and try again." + ) + # log this error to our feedback system first (this is a 1-way side-effect) try: - r = requests.post( + requests.post( "https://api.defog.ai/feedback", json={ "api_key": api_key, @@ -164,8 +169,9 @@ def execute_query( ) except: pass - + # log locally write_logs(str(e)) + # retry with adaptive learning while retries > 0: write_logs(f"Retries left: {retries}") try: @@ -196,3 +202,10 @@ def execute_query( write_logs(str(e)) retries -= 1 raise Exception(err_msg) + + +def is_connection_error(err_msg: str) -> bool: + return ( + isinstance(err_msg, str) + and re.search(r"connection.*failed", err_msg) is not None + ) diff --git a/defog/util.py b/defog/util.py index 087e138..7e2f474 100644 --- a/defog/util.py +++ b/defog/util.py @@ -1,6 +1,9 @@ import os from typing import List +from prompt_toolkit import prompt +import requests + def parse_update( args_list: List[str], attributes_list: List[str], config_dict: dict @@ -94,3 +97,48 @@ def identify_categorical_columns( top_values = [i[0] for i in top_values if i[0] is not None] rows[idx]["top_values"] = top_values return rows + + +def get_feedback(api_key: str, db_type: str, user_question: str, sql_generated: str): + """ + Get feedback from the user on whether the query was good or bad, and why. + """ + feedback = prompt( + "Did Defog answer your question well? Just hit enter to skip (y/n):\n" + ) + while feedback not in ["y", "n", ""]: + feedback = prompt("Please enter y or n:\n") + # get explanation for negative feedback + if feedback == "n": + feedback_text = prompt( + "Could you tell us why this was a bad query? This will help us improve the model for you. Just hit enter if you want to leave this blank.\n" + ) + else: + feedback_text = "" + try: + data = { + "api_key": api_key, + "feedback": "good" if feedback == "y" else "bad", + "db_type": db_type, + "question": user_question, + "query": sql_generated, + } + if feedback_text != "": + data["feedback_text"] = feedback_text + requests.post( + "https://api.defog.ai/feedback", + json=data, + timeout=1, + ) + if feedback == "y": + print("Thank you for the feedback!") + else: + print( + "Thank you for the feedback! We retrain our models every week, and you should see much better performance on these kinds of queries in another week.\n" + ) + print( + f"If you continue to get these errors, please consider updating the metadata in your schema by editing the google sheet generated and running `defog update `, or by updating your glossary.\n" + ) + except Exception as e: + write_logs(f"Error in get_feedback:\n{e}") + pass diff --git a/tests/test_query.py b/tests/test_query.py index c02ea38..80a0c32 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -4,7 +4,7 @@ import unittest from unittest import mock -from defog.query import execute_query_once, execute_query +from defog.query import is_connection_error, execute_query_once, execute_query class ExecuteQueryOnceTestCase(unittest.TestCase): @@ -166,5 +166,36 @@ def side_effect(db_type, db_creds, query): self.assertIn(json.dumps(json_req), lines[2]) +class TestConnectionError(unittest.TestCase): + def test_connection_failed(self): + self.assertTrue( + is_connection_error( + """connection to server on socket "/tmp/.s.PGSQL.5432" failed: No such file or directory + Is the server running locally and accepting connections on that socket?""" + ) + ) + + def test_not_connection_failed(self): + self.assertFalse( + is_connection_error( + 'psycopg2.errors.UndefinedTable: relation "nonexistent_table" does not exist' + ) + ) + self.assertFalse( + is_connection_error( + 'psycopg2.errors.SyntaxError: syntax error at or near "nonexistent_table"' + ) + ) + self.assertFalse( + is_connection_error( + 'psycopg2.errors.UndefinedColumn: column "nonexistent_column" does not exist' + ) + ) + + def test_empty_string(self): + self.assertFalse(is_connection_error("")) + self.assertFalse(is_connection_error(None)) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_util.py b/tests/test_util.py index 7171ef0..a1ecf78 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,8 +1,10 @@ import unittest -from defog.util import parse_update +from unittest.mock import patch +from defog.util import parse_update, get_feedback -class TestDefogUtil(unittest.TestCase): + +class TestParseUpdate(unittest.TestCase): def test_parse_update_1_key_edit(self): update_str = ["--app_name", "AWS"] attributes_list = ["app_name"] @@ -41,5 +43,39 @@ def test_parse_update_1_key_not_exists(self): ) +class TestGetFeedback(unittest.TestCase): + @patch("defog.util.prompt", return_value="y") + @patch("requests.post") + def test_positive_feedback(self, mock_post, mock_prompt): + get_feedback("api_key", "db_type", "user_question", "sql_generated") + mock_post.assert_called_once() + self.assertIn("good", mock_post.call_args.kwargs["json"]["feedback"]) + self.assertNotIn("feedback_text", mock_post.call_args.kwargs["json"]) + + @patch("defog.util.prompt", side_effect=["n", "bad query"]) + @patch("requests.post") + def test_negative_feedback_with_text(self, mock_post, mock_prompt): + get_feedback("api_key", "db_type", "user_question", "sql_generated") + mock_post.assert_called_once() + self.assertIn("bad", mock_post.call_args.kwargs["json"]["feedback"]) + self.assertIn("bad query", mock_post.call_args.kwargs["json"]["feedback_text"]) + + @patch("defog.util.prompt", side_effect=["n", ""]) + @patch("requests.post") + def test_negative_feedback_without_text(self, mock_post, mock_prompt): + get_feedback("api_key", "db_type", "user_question", "sql_generated") + mock_post.assert_called_once() + self.assertIn("bad", mock_post.call_args.kwargs["json"]["feedback"]) + self.assertNotIn("feedback_text", mock_post.call_args.kwargs["json"]) + + @patch("defog.util.prompt", side_effect=["invalid", "y"]) + @patch("requests.post") + def test_invalid_then_valid_input(self, mock_post, mock_prompt): + get_feedback("api_key", "db_type", "user_question", "sql_generated") + mock_post.assert_called_once() + self.assertIn("good", mock_post.call_args.kwargs["json"]["feedback"]) + self.assertNotIn("feedback_text", mock_post.call_args.kwargs["json"]) + + if __name__ == "__main__": unittest.main()