Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add embedding based table search support #1314

Merged
merged 8 commits into from
Sep 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion containers/bundled_querybook_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,23 @@ FLASK_SECRET_KEY: SOME_RANDOM_SECRET_KEY

DATABASE_CONN: mysql+pymysql://test:passw0rd@mysql:3306/querybook2?charset=utf8mb4
REDIS_URL: redis://redis:6379/0
ELASTICSEARCH_HOST: elasticsearch:9200
ELASTICSEARCH_HOST: http://elasticsearch:9200
# ELASTICSEARCH_CONNECTION_TYPE: aws
# Uncomment for email
# EMAILER_CONN: dockerhostforward

# Uncomment below to enable AI Assistant
# AI_ASSISTANT_PROVIDER: openai
# AI_ASSISTANT_CONFIG:
# model_name: gpt-3.5-turbo
# temperature: 0

# Uncomment below to enable the vector store, which supports embedding-based table search.
# Please check langchain doc for the configs of each provider.
# EMBEDDINGS_PROVIDER: openai
# EMBEDDINGS_CONFIG: ~
# VECTOR_STORE_PROVIDER: opensearch
# VECTOR_STORE_CONFIG:
# embeddings_arg_name: 'embedding_function'
# opensearch_url: http://elasticsearch:9200
# index_name: 'vector_index_v1'
9 changes: 5 additions & 4 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,13 @@ services:
command: ['--character-set-server=utf8mb4', '--collation-server=utf8mb4_unicode_ci']
elasticsearch:
container_name: querybook_elasticsearch
image: docker.elastic.co/elasticsearch/elasticsearch:7.16.2
image: opensearchproject/opensearch:2.9.0
environment:
cluster.name: docker-cluster
bootstrap.memory_lock: 'true'
discovery.type: single-node
ES_JAVA_OPTS: -Xms750m -Xmx750m
plugins.security.disabled: true
OPENSEARCH_JAVA_OPTS: -Xms750m -Xmx750m
ulimits:
memlock:
soft: -1
Expand All @@ -113,7 +114,7 @@ services:
soft: 65536
hard: 65536
volumes:
- esdata1:/usr/share/elasticsearch/data
- osdata1:/usr/share/opensearch/data
ports:
- 9200:9200
healthcheck:
Expand Down Expand Up @@ -161,7 +162,7 @@ services:

volumes:
my-db:
esdata1:
osdata1:
driver: local
# file:
# driver: local
Expand Down
2 changes: 2 additions & 0 deletions plugins/vector_store_plugin/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Plugin registries for embedding-based table search.
# Presumably keyed by provider name and populated by deployment-specific
# plugin code, then merged with the built-in providers — TODO confirm
# against the loader that consumes these dicts.
ALL_PLUGIN_VECTOR_STORES = {}
ALL_PLUGIN_EMBEDDINGS = {}
12 changes: 7 additions & 5 deletions querybook/config/querybook_default_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,13 @@ EVENT_LOGGER_NAME: ~
STATS_LOGGER_NAME: ~

# --------------- AI Assistant ---------------
# Example config for OpenAI
# AI_ASSISTANT_PROVIDER: openai
# AI_ASSISTANT_CONFIG:
# model_name: gpt-3.5-turbo
# temperature: 0
AI_ASSISTANT_PROVIDER: ~
AI_ASSISTANT_CONFIG:
model_name: ~

EMBEDDINGS_PROVIDER: ~
EMBEDDINGS_CONFIG: ~
VECTOR_STORE_PROVIDER: ~
VECTOR_STORE_CONFIG:
embeddings_arg_name: 'embedding_function'
index_name: 'vector_index_v1'
9 changes: 3 additions & 6 deletions querybook/server/const/ai_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,9 @@

# KEEP IT CONSISTENT AS webapp/const/aiAssistant.ts
class AICommandType(Enum):
SQL_FIX = "SQL_FIX"
SQL_TITLE = "SQL_TITLE"
TEXT_TO_SQL = "TEXT_TO_SQL"
RESET_MEMORY = "RESET_MEMORY"
SQL_FIX = "sql_fix"
SQL_TITLE = "sql_title"
TEXT_TO_SQL = "text_to_sql"


AI_ASSISTANT_NAMESPACE = "/ai_assistant"
AI_ASSISTANT_REQUEST_EVENT = "ai_assistant_request"
AI_ASSISTANT_RESPONSE_EVENT = "ai_assistant_response"
35 changes: 27 additions & 8 deletions querybook/server/datasources_socketio/ai_assistant.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,32 @@
from const.ai_assistant import (
AI_ASSISTANT_NAMESPACE,
AI_ASSISTANT_REQUEST_EVENT,
)
from const.ai_assistant import AI_ASSISTANT_NAMESPACE, AICommandType
from lib.ai_assistant import ai_assistant

from .helper import register_socket


@register_socket(AI_ASSISTANT_REQUEST_EVENT, namespace=AI_ASSISTANT_NAMESPACE)
def ai_assistant_request(command_type: str, payload={}):
from lib.ai_assistant import ai_assistant
@register_socket(AICommandType.TEXT_TO_SQL.value, namespace=AI_ASSISTANT_NAMESPACE)
def text_to_sql(payload={}):
original_query = payload["original_query"]
query_engine_id = payload["query_engine_id"]
tables = payload.get("tables", [])
question = payload["question"]
ai_assistant.generate_sql_query(
query_engine_id=query_engine_id,
tables=tables,
question=question,
original_query=original_query,
)

ai_assistant.handle_ai_command(command_type, payload)

@register_socket(AICommandType.SQL_TITLE.value, namespace=AI_ASSISTANT_NAMESPACE)
def sql_title(payload={}):
query = payload["query"]
ai_assistant.generate_title_from_query(query=query)


@register_socket(AICommandType.SQL_FIX.value, namespace=AI_ASSISTANT_NAMESPACE)
def sql_fix(payload={}):
query_execution_id = payload["query_execution_id"]
ai_assistant.query_auto_fix(
query_execution_id=query_execution_id,
)
5 changes: 5 additions & 0 deletions querybook/server/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,8 @@ class QuerybookSettings(object):
# AI Assistant
AI_ASSISTANT_PROVIDER = get_env_config("AI_ASSISTANT_PROVIDER")
AI_ASSISTANT_CONFIG = get_env_config("AI_ASSISTANT_CONFIG") or {}

VECTOR_STORE_PROVIDER = get_env_config("VECTOR_STORE_PROVIDER")
VECTOR_STORE_CONFIG = get_env_config("VECTOR_STORE_CONFIG") or {}
EMBEDDINGS_PROVIDER = get_env_config("EMBEDDINGS_PROVIDER")
EMBEDDINGS_CONFIG = get_env_config("EMBEDDINGS_CONFIG") or {}
59 changes: 59 additions & 0 deletions querybook/server/lib/ai_assistant/ai_socket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import functools

from flask import request

from app.flask_app import socketio
from const.ai_assistant import AI_ASSISTANT_NAMESPACE, AICommandType


class AIWebSocket:
    """Thin wrapper around a Socket.IO connection for AI assistant replies.

    Each instance is bound to the requesting client's room (``request.sid``)
    and to one AI command type; every event it emits is tagged with that
    command type's value so the webapp can route the payload.
    """

    def __init__(self, socketio, command_type: AICommandType):
        self.socketio = socketio
        self.command_type = command_type
        # Reply only to the client that issued the request.
        self.room = request.sid

    def _send(self, event_type, payload: dict = None):
        """Emit ``(event_type, payload)`` to the bound room under this command type."""
        message = (event_type, payload)
        self.socketio.emit(
            self.command_type.value,
            message,
            namespace=AI_ASSISTANT_NAMESPACE,
            room=self.room,
        )

    def send_data(self, data: dict):
        """Push a complete (non-streaming) payload."""
        self._send("data", data)

    def send_delta_data(self, data: str):
        """Push one incremental chunk of a streamed response."""
        self._send("delta_data", data)

    def send_delta_end(self):
        """Signal that the stream of delta chunks has finished."""
        self._send("delta_end")

    def send_tables_for_sql_gen(self, data: list[str]):
        """Push the table names selected for SQL generation."""
        self._send("tables", data)

    def send_error(self, error: str):
        """Report an error to the client, then close the exchange."""
        self._send("error", error)
        self.close()

    def close(self):
        """Tell the client no further events will follow for this command."""
        self._send("close")


def with_ai_socket(command_type: AICommandType):
    """Decorator factory that supplies an ``AIWebSocket`` to the wrapped handler.

    If the caller did not pass a truthy ``socket`` keyword argument, a new
    ``AIWebSocket`` bound to *command_type* is created and injected as
    ``kwargs["socket"]`` before the handler runs.
    """

    def decorator_fn(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # Falsy check (not just missing-key): a socket of None is replaced too.
            if not kwargs.get("socket"):
                kwargs["socket"] = AIWebSocket(socketio, command_type)
            return fn(*args, **kwargs)

        return wrapper

    return decorator_fn
17 changes: 10 additions & 7 deletions querybook/server/lib/ai_assistant/assistants/openai_assistant.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from lib.ai_assistant.base_ai_assistant import BaseAIAssistant
from lib.logger import get_logger

from langchain.chat_models import ChatOpenAI
from langchain.callbacks.manager import CallbackManager
import openai
from langchain.callbacks.manager import CallbackManager
from langchain.chat_models import ChatOpenAI

from lib.ai_assistant.base_ai_assistant import BaseAIAssistant
from lib.logger import get_logger

LOG = get_logger(__file__)

Expand All @@ -24,9 +23,13 @@ def _get_error_msg(self, error) -> str:

return super()._get_error_msg(error)

def _get_llm(self, callback_handler):
def _get_llm(self, callback_handler=None):
if not callback_handler:
# non-streaming
return ChatOpenAI(**self._config)

return ChatOpenAI(
**self._config,
streaming=True,
callback_manager=CallbackManager([callback_handler]),
callback_manager=CallbackManager([callback_handler])
)
jczhong84 marked this conversation as resolved.
Show resolved Hide resolved
Loading
Loading