Skip to content

Commit

Permalink
feat: add embedding based table search support (#1314)
Browse files Browse the repository at this point in the history
* feat: add embedding based table search support

* update

* build fail

* linter

* test failure

* comments

* nodetest

* opensearch volumne path
  • Loading branch information
jczhong84 authored Sep 2, 2023
1 parent 15da8e2 commit 839a96b
Show file tree
Hide file tree
Showing 41 changed files with 1,249 additions and 556 deletions.
19 changes: 18 additions & 1 deletion containers/bundled_querybook_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,23 @@ FLASK_SECRET_KEY: SOME_RANDOM_SECRET_KEY

DATABASE_CONN: mysql+pymysql://test:passw0rd@mysql:3306/querybook2?charset=utf8mb4
REDIS_URL: redis://redis:6379/0
ELASTICSEARCH_HOST: elasticsearch:9200
ELASTICSEARCH_HOST: http://elasticsearch:9200
# ELASTICSEARCH_CONNECTION_TYPE: aws
# Uncomment for email
# EMAILER_CONN: dockerhostforward

# Uncomment below to enable AI Assistant
# AI_ASSISTANT_PROVIDER: openai
# AI_ASSISTANT_CONFIG:
# model_name: gpt-3.5-turbo
# temperature: 0

# Uncomment below to enable vector store to support embedding based table search.
# Please check langchain doc for the configs of each provider.
# EMBEDDINGS_PROVIDER: openai
# EMBEDDINGS_CONFIG: ~
# VECTOR_STORE_PROVIDER: opensearch
# VECTOR_STORE_CONFIG:
# embeddings_arg_name: 'embedding_function'
# opensearch_url: http://elasticsearch:9200
# index_name: 'vector_index_v1'
9 changes: 5 additions & 4 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,13 @@ services:
command: ['--character-set-server=utf8mb4', '--collation-server=utf8mb4_unicode_ci']
elasticsearch:
container_name: querybook_elasticsearch
image: docker.elastic.co/elasticsearch/elasticsearch:7.16.2
image: opensearchproject/opensearch:2.9.0
environment:
cluster.name: docker-cluster
bootstrap.memory_lock: 'true'
discovery.type: single-node
ES_JAVA_OPTS: -Xms750m -Xmx750m
plugins.security.disabled: true
OPENSEARCH_JAVA_OPTS: -Xms750m -Xmx750m
ulimits:
memlock:
soft: -1
Expand All @@ -113,7 +114,7 @@ services:
soft: 65536
hard: 65536
volumes:
- esdata1:/usr/share/elasticsearch/data
- osdata1:/usr/share/opensearch/data
ports:
- 9200:9200
healthcheck:
Expand Down Expand Up @@ -161,7 +162,7 @@ services:

volumes:
my-db:
esdata1:
osdata1:
driver: local
# file:
# driver: local
Expand Down
2 changes: 2 additions & 0 deletions plugins/vector_store_plugin/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ALL_PLUGIN_VECTOR_STORES = {}
ALL_PLUGIN_EMBEDDINGS = {}
12 changes: 7 additions & 5 deletions querybook/config/querybook_default_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,13 @@ EVENT_LOGGER_NAME: ~
STATS_LOGGER_NAME: ~

# --------------- AI Assistant ---------------
# Example config for OpenAI
# AI_ASSISTANT_PROVIDER: openai
# AI_ASSISTANT_CONFIG:
# model_name: gpt-3.5-turbo
# temperature: 0
AI_ASSISTANT_PROVIDER: ~
AI_ASSISTANT_CONFIG:
model_name: ~

EMBEDDINGS_PROVIDER: ~
EMBEDDINGS_CONFIG: ~
VECTOR_STORE_PROVIDER: ~
VECTOR_STORE_CONFIG:
embeddings_arg_name: 'embedding_function'
index_name: 'vector_index_v1'
9 changes: 3 additions & 6 deletions querybook/server/const/ai_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,9 @@

# KEEP IT CONSISTENT AS webapp/const/aiAssistant.ts
class AICommandType(Enum):
SQL_FIX = "SQL_FIX"
SQL_TITLE = "SQL_TITLE"
TEXT_TO_SQL = "TEXT_TO_SQL"
RESET_MEMORY = "RESET_MEMORY"
SQL_FIX = "sql_fix"
SQL_TITLE = "sql_title"
TEXT_TO_SQL = "text_to_sql"


AI_ASSISTANT_NAMESPACE = "/ai_assistant"
AI_ASSISTANT_REQUEST_EVENT = "ai_assistant_request"
AI_ASSISTANT_RESPONSE_EVENT = "ai_assistant_response"
35 changes: 27 additions & 8 deletions querybook/server/datasources_socketio/ai_assistant.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,32 @@
from const.ai_assistant import (
AI_ASSISTANT_NAMESPACE,
AI_ASSISTANT_REQUEST_EVENT,
)
from const.ai_assistant import AI_ASSISTANT_NAMESPACE, AICommandType
from lib.ai_assistant import ai_assistant

from .helper import register_socket


@register_socket(AI_ASSISTANT_REQUEST_EVENT, namespace=AI_ASSISTANT_NAMESPACE)
def ai_assistant_request(command_type: str, payload={}):
from lib.ai_assistant import ai_assistant
@register_socket(AICommandType.TEXT_TO_SQL.value, namespace=AI_ASSISTANT_NAMESPACE)
def text_to_sql(payload={}):
original_query = payload["original_query"]
query_engine_id = payload["query_engine_id"]
tables = payload.get("tables", [])
question = payload["question"]
ai_assistant.generate_sql_query(
query_engine_id=query_engine_id,
tables=tables,
question=question,
original_query=original_query,
)

ai_assistant.handle_ai_command(command_type, payload)

@register_socket(AICommandType.SQL_TITLE.value, namespace=AI_ASSISTANT_NAMESPACE)
def sql_title(payload={}):
query = payload["query"]
ai_assistant.generate_title_from_query(query=query)


@register_socket(AICommandType.SQL_FIX.value, namespace=AI_ASSISTANT_NAMESPACE)
def sql_fix(payload={}):
query_execution_id = payload["query_execution_id"]
ai_assistant.query_auto_fix(
query_execution_id=query_execution_id,
)
5 changes: 5 additions & 0 deletions querybook/server/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,8 @@ class QuerybookSettings(object):
# AI Assistant
AI_ASSISTANT_PROVIDER = get_env_config("AI_ASSISTANT_PROVIDER")
AI_ASSISTANT_CONFIG = get_env_config("AI_ASSISTANT_CONFIG") or {}

VECTOR_STORE_PROVIDER = get_env_config("VECTOR_STORE_PROVIDER")
VECTOR_STORE_CONFIG = get_env_config("VECTOR_STORE_CONFIG") or {}
EMBEDDINGS_PROVIDER = get_env_config("EMBEDDINGS_PROVIDER")
EMBEDDINGS_CONFIG = get_env_config("EMBEDDINGS_CONFIG") or {}
59 changes: 59 additions & 0 deletions querybook/server/lib/ai_assistant/ai_socket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import functools

from flask import request

from app.flask_app import socketio
from const.ai_assistant import AI_ASSISTANT_NAMESPACE, AICommandType


class AIWebSocket:
def __init__(self, socketio, command_type: AICommandType):
self.socketio = socketio
self.command_type = command_type
self.room = request.sid

def _send(self, event_type, payload: dict = None):
self.socketio.emit(
self.command_type.value,
(
event_type,
payload,
),
namespace=AI_ASSISTANT_NAMESPACE,
room=self.room,
)

def send_data(self, data: dict):
self._send("data", data)

def send_delta_data(self, data: str):
self._send("delta_data", data)

def send_delta_end(self):
self._send("delta_end")

def send_tables_for_sql_gen(self, data: list[str]):
self._send("tables", data)

def send_error(self, error: str):
self._send("error", error)
self.close()

def close(self):
self._send("close")


def with_ai_socket(command_type: AICommandType):
def decorator_fn(fn):
@functools.wraps(fn)
def func(*args, **kwargs):
if not kwargs.get("socket"):
kwargs["socket"] = AIWebSocket(socketio, command_type)

result = fn(*args, **kwargs)

return result

return func

return decorator_fn
17 changes: 10 additions & 7 deletions querybook/server/lib/ai_assistant/assistants/openai_assistant.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from lib.ai_assistant.base_ai_assistant import BaseAIAssistant
from lib.logger import get_logger

from langchain.chat_models import ChatOpenAI
from langchain.callbacks.manager import CallbackManager
import openai
from langchain.callbacks.manager import CallbackManager
from langchain.chat_models import ChatOpenAI

from lib.ai_assistant.base_ai_assistant import BaseAIAssistant
from lib.logger import get_logger

LOG = get_logger(__file__)

Expand All @@ -24,9 +23,13 @@ def _get_error_msg(self, error) -> str:

return super()._get_error_msg(error)

def _get_llm(self, callback_handler):
def _get_llm(self, callback_handler=None):
if not callback_handler:
# non-streaming
return ChatOpenAI(**self._config)

return ChatOpenAI(
**self._config,
streaming=True,
callback_manager=CallbackManager([callback_handler]),
callback_manager=CallbackManager([callback_handler])
)
Loading

0 comments on commit 839a96b

Please sign in to comment.