Skip to content

Commit

Permalink
fix sql not found error for chat data
Browse files Browse the repository at this point in the history
Signed-off-by: shanhaikang.shk <[email protected]>
  • Loading branch information
GITHUBear committed Nov 22, 2024
1 parent 780ce80 commit 4d4a3e5
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 14 deletions.
43 changes: 29 additions & 14 deletions dbgpt/app/scene/chat_db/auto_execute/out_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ class SqlAction(NamedTuple):
sql: str
thoughts: Dict
display: str
direct_response: str

def to_dict(self) -> Dict[str, Dict]:
return {
"sql": self.sql,
"thoughts": self.thoughts,
"display": self.display,
"direct_response": self.direct_response,
}


Expand All @@ -48,7 +50,7 @@ def parse_prompt_response(self, model_out_text):
logger.info(f"clean prompt response: {clean_str}")
# Compatible with community pure sql output model
if self.is_sql_statement(clean_str):
return SqlAction(clean_str, "", "")
return SqlAction(clean_str, "", "", "")
else:
try:
response = json.loads(clean_str, strict=False)
Expand All @@ -59,28 +61,38 @@ def parse_prompt_response(self, model_out_text):
thoughts = response[key]
if key.strip() == "display_type":
display = response[key]
return SqlAction(sql, thoughts, display)
if key.strip() == "direct_response":
resp = response[key]
return SqlAction(sql, thoughts, display, resp)
except Exception as e:
logger.error(f"json load failed:{clean_str}")
return SqlAction("", clean_str, "")
return SqlAction("", clean_str, "", "")

def parse_view_response(self, speak, data, prompt_response) -> str:
param = {}
api_call_element = ET.Element("chart-view")
err_msg = None
success = False
try:
if not prompt_response.sql or len(prompt_response.sql) <= 0:
if (
not prompt_response.direct_response
or len(prompt_response.direct_response) <= 0
) and (not prompt_response.sql or len(prompt_response.sql) <= 0):
raise AppActionException("Can not find sql in response", speak)

df = data(prompt_response.sql)
param["type"] = prompt_response.display
param["sql"] = prompt_response.sql
param["data"] = json.loads(
df.to_json(orient="records", date_format="iso", date_unit="s")
)
view_json_str = json.dumps(param, default=serialize, ensure_ascii=False)
success = True
if prompt_response.sql:
df = data(prompt_response.sql)
param["type"] = prompt_response.display
param["sql"] = prompt_response.sql
param["data"] = json.loads(
df.to_json(orient="records", date_format="iso", date_unit="s")
)
view_json_str = json.dumps(param, default=serialize, ensure_ascii=False)
success = True
elif prompt_response.direct_response:
speak = prompt_response.direct_response
view_json_str = ""
success = True
except Exception as e:
logger.error("parse_view_response error!" + str(e))
err_param = {
Expand All @@ -93,8 +105,11 @@ def parse_view_response(self, speak, data, prompt_response) -> str:
view_json_str = json.dumps(err_param, default=serialize, ensure_ascii=False)

# api_call_element.text = view_json_str
api_call_element.set("content", view_json_str)
result = ET.tostring(api_call_element, encoding="utf-8")
if len(view_json_str) != 0:
api_call_element.set("content", view_json_str)
result = ET.tostring(api_call_element, encoding="utf-8")
else:
result = b""
if not success:
view_content = (
f'{speak} \\n <span style="color:red">ERROR!</span>'
Expand Down
1 change: 1 addition & 0 deletions dbgpt/app/scene/chat_db/auto_execute/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@

RESPONSE_FORMAT_SIMPLE = {
"thoughts": "thoughts summary to say to user",
"direct_response": "If the context is sufficient to answer user, reply directly without sql",
"sql": "SQL Query to run",
"display_type": "Data display method",
}
Expand Down

0 comments on commit 4d4a3e5

Please sign in to comment.