-
Notifications
You must be signed in to change notification settings - Fork 70
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Call Analyst to perform validation (#187)
We currently maintain copies of the validation logic in both the internal Analyst codepaths as well as this OSS app. Often, the OSS app can become out of date. Instead of performing validation locally, we will simply call Analyst with the current YAML string, as it performs validation at inference time. Any error returned is shown to the user. The diff for this PR seems big but it's mostly deleting unnecessary code + tests.
- Loading branch information
1 parent
7a3afa7
commit 2f1f675
Showing
4 changed files
with
142 additions
and
392 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import json | ||
import re | ||
from typing import Dict, Any | ||
|
||
import requests | ||
import streamlit as st | ||
from snowflake.connector import SnowflakeConnection | ||
|
||
API_ENDPOINT = "https://{HOST}/api/v2/cortex/analyst/message" | ||
|
||
|
||
@st.cache_data(ttl=60, show_spinner=False) | ||
def send_message( | ||
_conn: SnowflakeConnection, semantic_model: str, messages: list[dict[str, str]] | ||
) -> Dict[str, Any]: | ||
""" | ||
Calls the REST API with a list of messages and returns the response. | ||
Args: | ||
_conn: SnowflakeConnection, used to grab the token for auth. | ||
messages: list of chat messages to pass to the Analyst API. | ||
semantic_model: stringified YAML of the semantic model. | ||
Returns: The raw ChatMessage response from Analyst. | ||
""" | ||
request_body = { | ||
"messages": messages, | ||
"semantic_model": semantic_model, | ||
} | ||
|
||
if st.session_state["sis"]: | ||
import _snowflake | ||
|
||
resp = _snowflake.send_snow_api_request( # type: ignore | ||
"POST", | ||
f"/api/v2/cortex/analyst/message", | ||
{}, | ||
{}, | ||
request_body, | ||
{}, | ||
30000, | ||
) | ||
if resp["status"] < 400: | ||
json_resp: Dict[str, Any] = json.loads(resp["content"]) | ||
return json_resp | ||
else: | ||
err_body = json.loads(resp["content"]) | ||
if "message" in err_body: | ||
# Certain errors have a message payload with a link to the github repo, which we should remove. | ||
error_msg = re.sub( | ||
r"\s*Please use https://github\.com/Snowflake-Labs/semantic-model-generator.*", | ||
"", | ||
err_body["message"], | ||
) | ||
raise ValueError(error_msg) | ||
raise ValueError(err_body) | ||
|
||
else: | ||
host = st.session_state.host_name | ||
resp = requests.post( | ||
API_ENDPOINT.format( | ||
HOST=host, | ||
), | ||
json=request_body, | ||
headers={ | ||
"Authorization": f'Snowflake Token="{_conn.rest.token}"', # type: ignore[union-attr] | ||
"Content-Type": "application/json", | ||
}, | ||
) | ||
if resp.status_code < 400: | ||
json_resp: Dict[str, Any] = resp.json() | ||
return json_resp | ||
else: | ||
err_body = json.loads(resp.text) | ||
if "message" in err_body: | ||
# Certain errors have a message payload with a link to the github repo, which we should remove. | ||
error_msg = re.sub( | ||
r"\s*Please use https://github\.com/Snowflake-Labs/semantic-model-generator.*", | ||
"", | ||
err_body["message"], | ||
) | ||
raise ValueError(error_msg) | ||
raise ValueError(err_body) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.