From 7459b1fc401c31d4fa79ab02a655c5c6f71c7615 Mon Sep 17 00:00:00 2001 From: "J.C. Zhong" Date: Fri, 1 Sep 2023 03:57:35 +0000 Subject: [PATCH] comments --- containers/bundled_querybook_config.yaml | 25 ++-- .../datasources_socketio/ai_assistant.py | 2 +- .../server/lib/ai_assistant/ai_socket.py | 2 +- .../assistants/openai_assistant.py | 8 +- .../lib/ai_assistant/base_ai_assistant.py | 2 +- .../lib/ai_assistant/tools/table_schema.py | 28 +++-- querybook/server/lib/vector_store/__init__.py | 2 +- .../server/lib/vector_store/all_embeddings.py | 10 +- .../lib/vector_store/all_vector_stores.py | 9 +- .../lib/vector_store/base_vector_store.py | 5 +- querybook/server/logic/elasticsearch.py | 4 +- querybook/server/logic/metastore.py | 2 +- querybook/server/logic/vector_store.py | 110 ++++++++++++++---- querybook/server/scripts/init_vector_store.py | 49 +------- querybook/server/tasks/log_query_per_table.py | 20 ++-- .../components/AIAssistant/AutoFixButton.tsx | 38 +++--- .../AIAssistant/QueryGenerationModal.tsx | 74 +++++++----- .../CodeMirrorTooltip/TableTooltip.tsx | 3 +- querybook/webapp/const/aiAssistant.ts | 9 ++ querybook/webapp/hooks/useAISocket.ts | 29 ++--- 20 files changed, 249 insertions(+), 182 deletions(-) diff --git a/containers/bundled_querybook_config.yaml b/containers/bundled_querybook_config.yaml index 49adbb689..a0d7d954e 100644 --- a/containers/bundled_querybook_config.yaml +++ b/containers/bundled_querybook_config.yaml @@ -8,15 +8,18 @@ ELASTICSEARCH_HOST: http://elasticsearch:9200 # Uncomment for email # EMAILER_CONN: dockerhostforward -AI_ASSISTANT_PROVIDER: openai -AI_ASSISTANT_CONFIG: - model_name: gpt-3.5-turbo-0613 - temperature: 0 +# Uncomment below to enable AI Assistant +# AI_ASSISTANT_PROVIDER: openai +# AI_ASSISTANT_CONFIG: +# model_name: gpt-3.5-turbo +# temperature: 0 -EMBEDDINGS_PROVIDER: openai -EMBEDDINGS_CONFIG: ~ -VECTOR_STORE_PROVIDER: opensearch -VECTOR_STORE_CONFIG: - embeddings_arg_name: 'embedding_function' - opensearch_url: http://elasticsearch:9200 - index_name: 'vector_index_v1' +# Uncomment below to enable vector store to support embedding based table search. +# Please check langchain doc for the configs of each provider. +# EMBEDDINGS_PROVIDER: openai +# EMBEDDINGS_CONFIG: ~ +# VECTOR_STORE_PROVIDER: opensearch +# VECTOR_STORE_CONFIG: +# embeddings_arg_name: 'embedding_function' +# opensearch_url: http://elasticsearch:9200 +# index_name: 'vector_index_v1' diff --git a/querybook/server/datasources_socketio/ai_assistant.py b/querybook/server/datasources_socketio/ai_assistant.py index 84d4c13d1..263e8cc91 100644 --- a/querybook/server/datasources_socketio/ai_assistant.py +++ b/querybook/server/datasources_socketio/ai_assistant.py @@ -8,7 +8,7 @@ def text_to_sql(payload={}): original_query = payload["original_query"] query_engine_id = payload["query_engine_id"] - tables = payload.get("tables") + tables = payload.get("tables", []) question = payload["question"] ai_assistant.generate_sql_query( query_engine_id=query_engine_id, diff --git a/querybook/server/lib/ai_assistant/ai_socket.py b/querybook/server/lib/ai_assistant/ai_socket.py index e587c442e..4fb916e96 100644 --- a/querybook/server/lib/ai_assistant/ai_socket.py +++ b/querybook/server/lib/ai_assistant/ai_socket.py @@ -32,7 +32,7 @@ def send_delta_data(self, data: str): def send_delta_end(self): self._send("delta_end") - def send_tables(self, data: list[str]): + def send_tables_for_sql_gen(self, data: list[str]): self._send("tables", data) def send_error(self, error: str): diff --git a/querybook/server/lib/ai_assistant/assistants/openai_assistant.py b/querybook/server/lib/ai_assistant/assistants/openai_assistant.py index 460602c88..3107ee5dd 100644 --- a/querybook/server/lib/ai_assistant/assistants/openai_assistant.py +++ b/querybook/server/lib/ai_assistant/assistants/openai_assistant.py @@ -24,10 +24,12 @@ def _get_error_msg(self, error) -> str: return super()._get_error_msg(error) def _get_llm(self, callback_handler=None): + if not callback_handler: + # non-streaming + return ChatOpenAI(**self._config) + return ChatOpenAI( **self._config, - streaming=True if callback_handler else False, + streaming=True, callback_manager=CallbackManager([callback_handler]) - if callback_handler - else None, ) diff --git a/querybook/server/lib/ai_assistant/base_ai_assistant.py b/querybook/server/lib/ai_assistant/base_ai_assistant.py index 7ebb9630a..9bf586ab9 100644 --- a/querybook/server/lib/ai_assistant/base_ai_assistant.py +++ b/querybook/server/lib/ai_assistant/base_ai_assistant.py @@ -114,7 +114,7 @@ def generate_sql_query( if not tables: tables = get_vector_store().search_tables(question) if tables: - socket.send_tables(tables) + socket.send_tables_for_sql_gen(tables) # not finding any relevant tables if not tables: diff --git a/querybook/server/lib/ai_assistant/tools/table_schema.py b/querybook/server/lib/ai_assistant/tools/table_schema.py index 90c383294..59b5df06c 100644 --- a/querybook/server/lib/ai_assistant/tools/table_schema.py +++ b/querybook/server/lib/ai_assistant/tools/table_schema.py @@ -5,10 +5,26 @@ from models.metastore import DataTable, DataTableColumn +def _get_column_prompt(column: DataTableColumn) -> str: + prompt = "" + + prompt += f"- Column Name: {column.name}\n" + prompt += f" Data Type: {column.type}\n" + if column.description: + prompt += f" Description: {column.description}\n" + elif column.data_elements: + # use data element's description when column's description is empty + # TODO: only handling the REF data element for now. Need to handle ARRAY, MAP and etc in the future. + prompt += f" Description: {column.data_elements[0].description}\n" + prompt += f" Data Element: {column.data_elements[0].name}\n" + + return prompt + + def _get_table_schema_prompt( table: DataTable, should_skip_column: Callable[[DataTableColumn], bool] = None, -): +) -> str: """Generate table schema prompt. The format will be like: Table Name: [Name_of_table_1] @@ -38,15 +54,7 @@ def _get_table_schema_prompt( if should_skip_column and should_skip_column(column): continue - prompt += f"- Column Name: {column.name}\n" - prompt += f" Data Type: {column.type}\n" - if column.description: - prompt += f" Description: {column.description}\n" - elif column.data_elements: - # use data element's description when column's description is empty - # TODO: only handling the REF data element for now. Need to handle ARRAY, MAP and etc in the future. - prompt += f" Description: {column.data_elements[0].description}\n" - prompt += f" Data Element: {column.data_elements[0].name}\n" + prompt += _get_column_prompt(column) return prompt diff --git a/querybook/server/lib/vector_store/__init__.py b/querybook/server/lib/vector_store/__init__.py index 3a27d033d..032d760f3 100644 --- a/querybook/server/lib/vector_store/__init__.py +++ b/querybook/server/lib/vector_store/__init__.py @@ -46,8 +46,8 @@ def get_vector_store(): ) kwargs = { - embeddings_arg_name: get_embeddings(), **vector_store_config, + embeddings_arg_name: get_embeddings(), } if "embeddings_arg_name" in kwargs: del kwargs["embeddings_arg_name"] diff --git a/querybook/server/lib/vector_store/all_embeddings.py b/querybook/server/lib/vector_store/all_embeddings.py index f67ce7fa2..ba883f9b4 100644 --- a/querybook/server/lib/vector_store/all_embeddings.py +++ b/querybook/server/lib/vector_store/all_embeddings.py @@ -1,20 +1,14 @@ from typing import Type -from lib.utils.import_helper import import_module_with_default from langchain.embeddings.base import Embeddings -from langchain.embeddings import OpenAIEmbeddings - - -PROVIDED_EMBEDDINGS = {"openai": OpenAIEmbeddings} +from lib.utils.import_helper import import_module_with_default -ALL_PLUGIN_EMBEDDINGS = import_module_with_default( +ALL_EMBEDDINGS = import_module_with_default( "vector_store_plugin", "ALL_PLUGIN_EMBEDDINGS", default={}, ) -ALL_EMBEDDINGS = {**PROVIDED_EMBEDDINGS, **ALL_PLUGIN_EMBEDDINGS} - def get_embeddings_class(name: str) -> Type[Embeddings]: if name in ALL_EMBEDDINGS: diff --git a/querybook/server/lib/vector_store/all_vector_stores.py b/querybook/server/lib/vector_store/all_vector_stores.py index 643e2f6c4..e9b2e5f46 100644 --- a/querybook/server/lib/vector_store/all_vector_stores.py +++ b/querybook/server/lib/vector_store/all_vector_stores.py @@ -1,21 +1,14 @@ from typing import Type - from lib.utils.import_helper import import_module_with_default -from lib.vector_store.stores.opensearch import OpenSearchVectorStore - from lib.vector_store.base_vector_store import VectorStoreBase -PROVIDED_VECTOR_STORES = {"opensearch": OpenSearchVectorStore} - -ALL_PLUGIN_VECTOR_STORES = import_module_with_default( +ALL_VECTOR_STORES = import_module_with_default( "vector_store_plugin", "ALL_PLUGIN_VECTOR_STORES", default={}, ) -ALL_VECTOR_STORES = {**PROVIDED_VECTOR_STORES, **ALL_PLUGIN_VECTOR_STORES} - def get_vector_store_class(name: str) -> Type[VectorStoreBase]: if name in ALL_VECTOR_STORES: diff --git a/querybook/server/lib/vector_store/base_vector_store.py b/querybook/server/lib/vector_store/base_vector_store.py index 444664429..4e5371b35 100644 --- a/querybook/server/lib/vector_store/base_vector_store.py +++ b/querybook/server/lib/vector_store/base_vector_store.py @@ -26,6 +26,7 @@ def should_skip_query_execution( Override this method to implement custom logic for your vector store.""" query = query_execution.query + # TODO: add more filters # skip queries if it starts with "select * from" pattern = r"^\s*select\s+\*\s+from" if re.match(pattern, query, re.IGNORECASE): @@ -55,8 +56,8 @@ def search_tables(self, text: str, threshold=0.6, k=3): for table, score in tables: table_score_dict[table] = max(score, table_score_dict.get(table, 0)) - unique_tables = sorted( + sorted_tables = sorted( table_score_dict.items(), key=lambda x: x[1], reverse=True ) - return [t for t, _ in unique_tables[:k]] + return [t for t, _ in sorted_tables[:k]] diff --git a/querybook/server/logic/elasticsearch.py b/querybook/server/logic/elasticsearch.py index 0fd398375..acf053771 100644 --- a/querybook/server/logic/elasticsearch.py +++ b/querybook/server/logic/elasticsearch.py @@ -44,7 +44,7 @@ get_query_execution_by_id, get_successful_query_executions_by_data_cell_id, ) -from logic.vector_store import delete_table_doc, log_table +from logic.vector_store import delete_table_doc, record_table from models.user import User from models.datadoc import DataCellType from models.board import Board @@ -640,7 +640,7 @@ def update_table_by_id(table_id, session=None): _update(index_name, table_id, updated_body) # update it in vector store as well - log_table(table=table, session=session) + record_table(table=table, session=session) except Exception: # Otherwise insert as new LOG.error("failed to upsert {}. Will pass.".format(table_id)) diff --git a/querybook/server/logic/metastore.py b/querybook/server/logic/metastore.py index edcc6a509..8eb30e56a 100644 --- a/querybook/server/logic/metastore.py +++ b/querybook/server/logic/metastore.py @@ -768,7 +768,7 @@ def get_table_query_samples_count(table_id, session): @with_session -def get_tables_by_query_execution_id(query_execution_id, session): +def get_tables_by_query_execution_id(query_execution_id, session=None): return ( session.query(DataTable) .join(DataTableQueryExecution) diff --git a/querybook/server/logic/vector_store.py b/querybook/server/logic/vector_store.py index 89808f2f5..592adf3e6 100644 --- a/querybook/server/logic/vector_store.py +++ b/querybook/server/logic/vector_store.py @@ -1,12 +1,16 @@ import hashlib -from typing import Optional from app.db import with_session +from const.query_execution import QueryExecutionStatus from langchain.docstore.document import Document from lib.ai_assistant import ai_assistant from lib.logger import get_logger from lib.vector_store import get_vector_store -from logic.metastore import get_table_by_id, get_tables_by_query_execution_id +from logic.metastore import ( + get_all_table, + get_table_by_id, + get_tables_by_query_execution_id, +) from logic.query_execution import get_query_execution_by_id from models.metastore import DataTable from models.query_execution import QueryExecution @@ -40,9 +44,8 @@ def _get_query_doc_id(query: str) -> str: @with_session -def log_query_execution( - query_execution_id: Optional[int] = None, - query_execution: Optional[QueryExecution] = None, +def record_query_execution( + query_execution: QueryExecution, session=None, ): """Generate summary of the query execution and log it to the vector store.""" @@ -50,15 +53,10 @@ def log_query_execution( if not get_vector_store(): return - try: - if query_execution is None: - query_execution = get_query_execution_by_id( - query_execution_id, session=session - ) - - if not query_execution: - return + if not query_execution: + return + try: tables = get_tables_by_query_execution_id(query_execution.id, session=session) if not tables: return @@ -87,23 +85,32 @@ def log_query_execution( @with_session -def log_table( - table_id: Optional[int] = None, - table: Optional[DataTable] = None, +def record_query_execution_by_id( + query_execution_id: int, session=None, ): - """Generate summary of the table and log it to the vector store.""" # vector store is not configured if not get_vector_store(): return - try: - if table is None: - table = get_table_by_id(table_id, session=session) + query_execution = get_query_execution_by_id(query_execution_id, session=session) + record_query_execution(query_execution, session=session) - if table is None: - return +@with_session +def record_table( + table: DataTable, + session=None, +): + """Generate summary of the table and record it to the vector store.""" + # vector store is not configured + if not get_vector_store(): + return + + if not table: + return + + try: if get_vector_store().should_skip_table(table): return @@ -122,6 +129,19 @@ def log_table( LOG.error(f"Failed to log table to vector store: {e}") +@with_session +def record_table_by_id( + table_id: int, + session=None, +): + # vector store is not configured + if not get_vector_store(): + return + + table = get_table_by_id(table_id, session=session) + record_table(table=table, session=session) + + def delete_table_doc(table_id: int): """Delete table summary doc from vector store by table id.""" @@ -130,3 +150,49 @@ def delete_table_doc(table_id: int): return doc_id = _get_table_doc_id(table_id) get_vector_store().delete([doc_id]) + + +@with_session +def ingest_tables(batch_size=100, session=None): + offset = 0 + + while True: + tables = get_all_table( + limit=batch_size, + offset=offset, + session=session, + ) + + for table in tables: + full_table_name = f"{table.data_schema.name}.{table.name}" + print(f"Ingesting table: {full_table_name}") + record_table(table=table, session=session) + + if len(tables) < batch_size: + break + + offset += batch_size + + +@with_session +def ingest_query_executions(batch_size=100, session=None): + offset = 0 + + while True: + # TODO: there may be many highly similar queries, we should not ingest all of them. + query_executions = ( + session.query(QueryExecution) + .filter(QueryExecution.status == QueryExecutionStatus.DONE) + .offset(offset) + .limit(batch_size) + .all() + ) + + for qe in query_executions: + print(f"Ingesting query execution: {qe.id}") + record_query_execution(query_execution=qe, session=session) + + if len(query_executions) < batch_size: + break + + offset += batch_size diff --git a/querybook/server/scripts/init_vector_store.py b/querybook/server/scripts/init_vector_store.py index 28be9091f..101fc5675 100644 --- a/querybook/server/scripts/init_vector_store.py +++ b/querybook/server/scripts/init_vector_store.py @@ -1,47 +1,4 @@ -from app.db import with_session -from const.query_execution import QueryExecutionStatus -from logic.metastore import get_all_table -from logic.vector_store import log_query_execution, log_table -from models.query_execution import QueryExecution +from logic.vector_store import ingest_tables, ingest_query_executions - -@with_session -def ingest_tables(batch_size=100, session=None): - offset = 0 - - while True: - tables = get_all_table( - limit=batch_size, - offset=offset, - session=session, - ) - - for table in tables: - log_table(table=table, session=session) - - if len(tables) < batch_size: - break - - offset += batch_size - - -@with_session -def ingest_query_executions(batch_size=100, session=None): - offset = 0 - - while True: - query_executions = ( - session.query(QueryExecution) - .filter(QueryExecution.status == QueryExecutionStatus.DONE) - .offset(offset) - .limit(batch_size) - .all() - ) - - for qe in query_executions: - log_query_execution(query_execution=qe, session=session) - - if len(query_executions) < batch_size: - break - - offset += batch_size +ingest_tables() +ingest_query_executions() diff --git a/querybook/server/tasks/log_query_per_table.py b/querybook/server/tasks/log_query_per_table.py index fbcb27f0b..e1453a3f1 100644 --- a/querybook/server/tasks/log_query_per_table.py +++ b/querybook/server/tasks/log_query_per_table.py @@ -41,6 +41,17 @@ def log_query_per_table_task(self, query_execution_id, execution_type): create_lineage_from_query( query_execution, metastore_id, datadoc_cell, session=session ) + + # log the ad-hoc(not scheduled) query to vector store for table search + if ( + datadoc_cell is not None + and execution_type == QueryExecutionType.ADHOC.value + ): + vs_logic.record_query_execution_by_id( + query_execution_id=query_execution_id, + session=session, + ) + if datadoc_cell is None or not datadoc_cell.doc.public: return @@ -50,7 +61,6 @@ def log_query_per_table_task(self, query_execution_id, execution_type): query_execution_id, metastore_id, datadoc_cell.id, - execution_type, session=session, ) @@ -123,7 +133,6 @@ def log_table_per_statement( query_execution_id, metastore_id, cell_id, - execution_type, session=None, ): metastore_loader = get_metastore_loader(metastore_id, session=session) @@ -154,10 +163,3 @@ def log_table_per_statement( query_execution_id=query_execution_id, session=session, ) - - # log the ad-hoc query with tables to vector store for table search - if execution_type == QueryExecutionType.ADHOC.value: - vs_logic.log_query_execution( - query_execution_id=query_execution_id, - session=session, - ) diff --git a/querybook/webapp/components/AIAssistant/AutoFixButton.tsx b/querybook/webapp/components/AIAssistant/AutoFixButton.tsx index f58e548ea..dd0f363e0 100644 --- a/querybook/webapp/components/AIAssistant/AutoFixButton.tsx +++ b/querybook/webapp/components/AIAssistant/AutoFixButton.tsx @@ -19,12 +19,7 @@ interface IProps { onUpdateQuery?: (query: string, run?: boolean) => any; } -export const AutoFixButton = ({ - query, - queryExecutionId, - onUpdateQuery, -}: IProps) => { - const [show, setShow] = useState(false); +const useSQLFix = () => { const [data, setData] = useState<{ [key: string]: string }>({}); const socket = useAISocket(AICommandType.SQL_FIX, ({ data }) => { @@ -39,6 +34,23 @@ export const AutoFixButton = ({ const fixedQuery = trimSQLQuery(rawFixedQuery); + return { + socket, + fixed: Object.keys(data).length > 0, // If has data, then it has been fixed + explanation, + suggestion, + fixedQuery, + }; +}; + +export const AutoFixButton = ({ + query, + queryExecutionId, + onUpdateQuery, +}: IProps) => { + const [showModal, setShowModal] = useState(false); + const { socket, fixed, explanation, suggestion, fixedQuery } = useSQLFix(); + const bottomDOM = socket.loading ? (
@@ -92,8 +104,8 @@ export const AutoFixButton = ({ icon="Bug" title="Auto fix" onClick={() => { - setShow(true); - if (Object.keys(data).length === 0) { + setShowModal(true); + if (!fixed) { socket.emit({ query_execution_id: queryExecutionId, }); @@ -107,11 +119,11 @@ export const AutoFixButton = ({ } }} /> - {show && ( + {showModal && ( { socket.cancel(); - setShow(false); + setShowModal(false); }} bottomDOM={bottomDOM} className="AutoFixModal" diff --git a/querybook/webapp/components/AIAssistant/QueryGenerationModal.tsx b/querybook/webapp/components/AIAssistant/QueryGenerationModal.tsx index cb97484c6..696cb6ec3 100644 --- a/querybook/webapp/components/AIAssistant/QueryGenerationModal.tsx +++ b/querybook/webapp/components/AIAssistant/QueryGenerationModal.tsx @@ -39,26 +39,44 @@ const useTablesInQuery = (query, language) => { const [tables, setTables] = useState([]); useEffect(() => { - if (!!query) { - analyzeCode(query, 'autocomplete', language).then( - (codeAnalysis) => { - const tableReferences: TableToken[] = [].concat.apply( - [], - Object.values(codeAnalysis?.lineage.references ?? {}) - ); - setTables( - tableReferences.map( - ({ schema, name }) => `${schema}.${name}` - ) - ); - } - ); + if (!query) { + return; } + + analyzeCode(query, 'autocomplete', language).then((codeAnalysis) => { + const tableReferences: TableToken[] = [].concat.apply( + [], + Object.values(codeAnalysis?.lineage.references ?? {}) + ); + setTables( + tableReferences.map(({ schema, name }) => `${schema}.${name}`) + ); + }); }, [query, language]); return tables; }; +const useSQLGeneration = ( + onData: ({ type: string, data: object }) => void +): { + generating: boolean; + generateSQL: (data: { + query_engine_id: number; + tables: string[]; + question: string; + original_query: string; + }) => void; + cancelGeneration: () => void; +} => { + const socket = useAISocket(AICommandType.TEXT_TO_SQL, onData); + return { + generating: socket.loading, + generateSQL: socket.emit, + cancelGeneration: socket.cancel, + }; +}; + export const QueryGenerationModal = ({ query = '', engineId, @@ -88,13 +106,15 @@ export const QueryGenerationModal = ({ } }, []); - const socket = useAISocket(AICommandType.TEXT_TO_SQL, onData); + // const socket = useAISocket(AICommandType.TEXT_TO_SQL, onData); + const { generating, generateSQL, cancelGeneration } = + useSQLGeneration(onData); useEffect(() => { - if (!socket.loading) { + if (!generating) { setTables(uniq([...tablesInQuery, ...tables])); } - }, [tablesInQuery, socket.loading]); + }, [tablesInQuery, generating]); const { explanation, query: rawNewQuery, data } = streamData; @@ -105,11 +125,11 @@ export const QueryGenerationModal = ({ const onKeyDown = useCallback( (event: React.KeyboardEvent) => { if ( - !socket.loading && + !generating && matchKeyPress(event, 'Enter') && !event.shiftKey ) { - socket.emit({ + generateSQL({ query_engine_id: engineId, tables: tables, question: question, @@ -126,13 +146,13 @@ export const QueryGenerationModal = ({ }); } }, - [engineId, question, tables, query, socket.emit, socket.loading] + [engineId, question, tables, query, generateSQL, generating] ); const questionBarDOM = (
- +
- {socket.loading && ( + {generating && (
); - const bottomDOM = newQuery && !socket.loading && ( + const bottomDOM = newQuery && !generating && (
} - disableHighlight={socket.loading} + disableHighlight={generating} hideEmptyQuery={true} />
diff --git a/querybook/webapp/components/CodeMirrorTooltip/TableTooltip.tsx b/querybook/webapp/components/CodeMirrorTooltip/TableTooltip.tsx index a73caa3ad..67025e0e8 100644 --- a/querybook/webapp/components/CodeMirrorTooltip/TableTooltip.tsx +++ b/querybook/webapp/components/CodeMirrorTooltip/TableTooltip.tsx @@ -121,8 +121,7 @@ export const TableTooltipByName: React.FunctionComponent<{ useEffect(() => { const fetchTable = async () => { try { - const schemaName = tableFullName.split('.')[0]; - const tableName = tableFullName.split('.')[1]; + const [schemaName, tableName] = tableFullName.split('.'); const table: any = await dispatch( dataSourcesActions.fetchDataTableByNameIfNeeded( schemaName, diff --git a/querybook/webapp/const/aiAssistant.ts b/querybook/webapp/const/aiAssistant.ts index 2579f2196..d2b60b04d 100644 --- a/querybook/webapp/const/aiAssistant.ts +++ b/querybook/webapp/const/aiAssistant.ts @@ -5,4 +5,13 @@ export enum AICommandType { TEXT_TO_SQL = 'text_to_sql', } +export enum AISocketEvent { + DATA = 'data', + DELTA_DATA = 'delta_data', + DELTA_END = 'delta_end', + TABLES = 'tables', + CLOSE = 'close', + ERROR = 'error', +} + export const AI_ASSISTANT_NAMESPACE = '/ai_assistant'; diff --git a/querybook/webapp/hooks/useAISocket.ts b/querybook/webapp/hooks/useAISocket.ts index 35974ada5..410ed3caf 100644 --- a/querybook/webapp/hooks/useAISocket.ts +++ b/querybook/webapp/hooks/useAISocket.ts @@ -1,19 +1,20 @@ -import { List } from 'immutable'; -import { useCallback, useEffect, useRef, useState } from 'react'; +import { useCallback, useRef, useState } from 'react'; import toast from 'react-hot-toast'; -import { AICommandType } from 'const/aiAssistant'; +import { AICommandType, AISocketEvent } from 'const/aiAssistant'; import aiAssistantSocket from 'lib/ai-assistant/ai-assistant-socketio'; import { DeltaStreamParser } from 'lib/stream'; -export function useAISocket( - commandType: AICommandType, - onData: (data: { type?: string; data: { [key: string]: string } }) => void -): { +export interface AISocket { loading: boolean; emit: (payload: object) => void; cancel: () => void; -} { +} + +export function useAISocket( + commandType: AICommandType, + onData: (data: { type?: string; data: { [key: string]: string } }) => void +): AISocket { const [loading, setLoading] = useState(false); const deltaStreamParserRef = useRef( @@ -24,29 +25,29 @@ export function useAISocket( (event, payload) => { const parser = deltaStreamParserRef.current; switch (event) { - case 'data': + case AISocketEvent.DATA: onData({ data: { data: payload } }); break; - case 'delta_data': + case AISocketEvent.DELTA_DATA: parser.parse(payload); onData({ data: parser.result }); break; - case 'delta_end': + case AISocketEvent.DELTA_END: parser.close(); onData({ data: parser.result }); break; - case 'tables': + case AISocketEvent.TABLES: onData({ type: 'tables', data: payload }); break; - case 'close': + case AISocketEvent.CLOSE: close(); break; - case 'error': + case AISocketEvent.ERROR: toast.error(payload); close(); break;