diff --git a/defog/admin_methods.py b/defog/admin_methods.py index d7fbe39..b2a4f98 100644 --- a/defog/admin_methods.py +++ b/defog/admin_methods.py @@ -1,5 +1,5 @@ import json -from typing import Dict, Optional +from typing import Dict, List, Optional import requests import pandas as pd @@ -142,7 +142,7 @@ def get_quota(self) -> Optional[Dict]: def update_golden_queries( self, - golden_queries: dict = None, + golden_queries: List[Dict] = None, golden_queries_path: str = None, scrub: bool = True, ): diff --git a/defog/cli.py b/defog/cli.py index 3a7b76e..d3520a8 100644 --- a/defog/cli.py +++ b/defog/cli.py @@ -534,7 +534,9 @@ def query(): print(resp) print() - get_feedback(df.api_key, df.db_type, user_question, sql_generated) + get_feedback( + df.api_key, df.db_type, user_question, sql_generated, df.base_url + ) query = prompt("Please enter another query, or type 'e' to exit: ") diff --git a/defog/util.py b/defog/util.py index d29e5bd..0eba04e 100644 --- a/defog/util.py +++ b/defog/util.py @@ -132,7 +132,9 @@ def identify_categorical_columns( return rows -def get_feedback(api_key: str, db_type: str, user_question: str, sql_generated: str): +def get_feedback( + api_key: str, db_type: str, user_question: str, sql_generated: str, base_url: str +): """ Get feedback from the user on whether the query was good or bad, and why. """ @@ -162,19 +164,172 @@ def get_feedback(api_key: str, db_type: str, user_question: str, sql_generated: if feedback_text != "": data["feedback_text"] = feedback_text requests.post( - "https://api.defog.ai/feedback", + f"{base_url}/feedback", json=data, timeout=1, ) if feedback == "y": print("Thank you for the feedback!") - else: + elif feedback == "n": + data = { + "api_key": api_key, + "question": user_question, + "sql_generated": sql_generated, + "error": feedback_text, + } print( - "Thank you for the feedback! We retrain our models every week, and you should see much better performance on these kinds of queries in another week.\n" + "Thank you for the feedback, let us see how can we improve this for you...\n" ) - print( - f"If you continue to get these errors, please consider updating the metadata in your schema by editing the CSV generated and running `defog update `, or by updating your glossary.\n" + response = requests.post( + f"{base_url}/reflect_on_error", + json=data, ) + + if response.status_code == 200: + response_dict = response.json() + feedback = response_dict.get("feedback") + if feedback: + print(f"Here is our automated assessment:\n{feedback}\n") + # 1) validate and update glossary + instruction_set = response_dict.get("instruction_set") + if instruction_set: + print( + f"We came up with the following additions for improving your glossary:\n{instruction_set}" + ) + add_to_glossary = prompt( + "If you would like to add these suggestions to your glossary, please enter 'y'. If you would like to amend it, just type in your edits and hit enter. Otherwise, enter 'n'.\n" + ) + if add_to_glossary == "y": + md_resp = requests.post( + f"{base_url}/get_metadata", + json={"api_key": api_key}, + ) + md_resp_dict = md_resp.json() + glossary_current = md_resp_dict.get("glossary", "") + glossary_updated = glossary_current + "\n" + instruction_set + requests.post( + f"{base_url}/update_glossary", + json={ + "api_key": api_key, + "glossary": glossary_updated, + }, + ) + print("Glossary updated successfully.\n") + elif add_to_glossary != "n": + md_resp = requests.post( + f"{base_url}/get_metadata", + json={"api_key": api_key}, + ) + md_resp_dict = md_resp.json() + glossary_current = md_resp_dict.get("glossary", "") + glossary_updated = glossary_current + "\n" + add_to_glossary + requests.post( + f"{base_url}/update_glossary", + json={ + "api_key": api_key, + "glossary": glossary_updated, + }, + ) + print("Glossary updated successfully.\n") + else: + print("Glossary not updated.\n") + + # 2) validate and update column descriptions in metadata + new_column_descriptions = response_dict.get("column_descriptions") + if new_column_descriptions: + print( + f"We came up with the following suggestions for improving your column descriptions:\n{new_column_descriptions}" + ) + # get original metadata + r = requests.post( + f"{base_url}/get_metadata", + json={"api_key": api_key}, + ) + resp = r.json() + md = resp.get("table_metadata", {}) + # we will be editing md in place + column_changed = False + for new_column_description in new_column_descriptions: + table_name = new_column_description.get("table_name") + column_name = new_column_description.get("column_name") + description = new_column_description.get("description") + if table_name in md: + for column in md[table_name]: + if column.get("column_name") == column_name: + print( + f"\nCurrent description for {column_name}: {column.get('column_description')}" + ) + print( + f"Suggested description for {column_name}: {description}" + ) + replace = prompt( + "Would you like to replace this description with our suggestion? Please enter 'y' to replace, or your own description to amend. Otherwise, enter 'n' to skip.\n" + ) + if replace == "y": + column["column_description"] = description + print("Updated description.") + column_changed = True + break + elif replace != "n": + column["column_description"] = replace + print("Updated description.") + column_changed = True + break + else: + print("Description not updated.") + break + if column_changed: + requests.post( + f"{base_url}/update_metadata", + json={ + "api_key": api_key, + "table_metadata": new_column_description, + "db_type": db_type, + }, + ) + print("Metadata updated successfully.\n") + else: + print("No metadata changes to update.\n") + # 3) validate and update reference_queries + new_reference_queries = response_dict.get("reference_queries") + if ( + isinstance(new_reference_queries, list) + and len(new_reference_queries) > 0 + ): + reference_queries_to_add = [] + print( + f"We came up with the following suggestions for adding as your reference queries:" + ) + for new_reference_query in new_reference_queries: + question = new_reference_query.get("question") + sql = new_reference_query.get("sql") + print(f"Question: {question}\nSQL: {sql}") + update_reference_queries = prompt( + "Would you like to add this as one of your reference queries? Please hit 'y' to add, or anything else to skip to the next suggestion.\n" + ) + if update_reference_queries == "y": + reference_queries_to_add.append(new_reference_query) + if len(reference_queries_to_add) > 0: + r = requests.post( + f"{base_url}/update_golden_queries", + json={ + "api_key": api_key, + "golden_queries": reference_queries_to_add, + "scrub": True, + }, + ) + if r.status_code == 200: + print( + f"{len(reference_queries_to_add)} reference queries added successfully." + ) + else: + print("Reference queries not updated.") + else: + print("No reference queries to update.") + print() + else: + print("There was an error in getting suggestions. Our apologies!") + except Exception as e: write_logs(f"Error in get_feedback:\n{e}") pass diff --git a/tests/test_util.py b/tests/test_util.py index 2fde072..69de0c1 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -109,39 +109,49 @@ class TestGetFeedback(unittest.TestCase): @patch("defog.util.prompt", return_value="y") @patch("requests.post") def test_positive_feedback(self, mock_post, mock_prompt): - get_feedback("api_key", "db_type", "user_question", "sql_generated") - mock_post.assert_called_once() + get_feedback("api_key", "db_type", "user_question", "sql_generated", "base_url") + assert mock_post.call_count == 1 self.assertIn("good", mock_post.call_args.kwargs["json"]["feedback"]) self.assertNotIn("feedback_text", mock_post.call_args.kwargs["json"]) - @patch("defog.util.prompt", side_effect=["n", "bad query"]) + @patch("defog.util.prompt", side_effect=["n", "bad query", "", "", ""]) @patch("requests.post") def test_negative_feedback_with_text(self, mock_post, mock_prompt): - get_feedback("api_key", "db_type", "user_question", "sql_generated") - mock_post.assert_called_once() - self.assertIn("bad", mock_post.call_args.kwargs["json"]["feedback"]) - self.assertIn("bad query", mock_post.call_args.kwargs["json"]["feedback_text"]) + get_feedback("api_key", "db_type", "user_question", "sql_generated", "base_url") + # 2 calls: 1 to /feedback, 1 to /reflect_on_error + assert mock_post.call_count == 2 + self.assertIn("api_key", mock_post.call_args.kwargs["json"]["api_key"]) + self.assertIn("user_question", mock_post.call_args.kwargs["json"]["question"]) + self.assertIn( + "sql_generated", mock_post.call_args.kwargs["json"]["sql_generated"] + ) + self.assertIn("bad query", mock_post.call_args.kwargs["json"]["error"]) - @patch("defog.util.prompt", side_effect=["n", ""]) + @patch("defog.util.prompt", side_effect=["n", "", "", "", ""]) @patch("requests.post") def test_negative_feedback_without_text(self, mock_post, mock_prompt): - get_feedback("api_key", "db_type", "user_question", "sql_generated") - mock_post.assert_called_once() - self.assertIn("bad", mock_post.call_args.kwargs["json"]["feedback"]) - self.assertNotIn("feedback_text", mock_post.call_args.kwargs["json"]) + get_feedback("api_key", "db_type", "user_question", "sql_generated", "base_url") + # 2 calls: 1 to /feedback, 1 to /reflect_on_error + assert mock_post.call_count == 2 + self.assertIn("api_key", mock_post.call_args.kwargs["json"]["api_key"]) + self.assertIn("user_question", mock_post.call_args.kwargs["json"]["question"]) + self.assertIn( + "sql_generated", mock_post.call_args.kwargs["json"]["sql_generated"] + ) + self.assertIn("", mock_post.call_args.kwargs["json"]["error"]) @patch("defog.util.prompt", side_effect=["invalid", "y"]) @patch("requests.post") def test_invalid_then_valid_input(self, mock_post, mock_prompt): - get_feedback("api_key", "db_type", "user_question", "sql_generated") - mock_post.assert_called_once() + get_feedback("api_key", "db_type", "user_question", "sql_generated", "base_url") + assert mock_post.call_count == 1 self.assertIn("good", mock_post.call_args.kwargs["json"]["feedback"]) self.assertNotIn("feedback_text", mock_post.call_args.kwargs["json"]) @patch("defog.util.prompt", side_effect=[""]) @patch("requests.post") def test_skip_input(self, mock_post, mock_prompt): - get_feedback("api_key", "db_type", "user_question", "sql_generated") + get_feedback("api_key", "db_type", "user_question", "sql_generated", "base_url") mock_post.assert_not_called()