jczhong84 committed Sep 1, 2023
1 parent b8c32ce commit 5c85caf
Showing 20 changed files with 249 additions and 182 deletions.
25 changes: 14 additions & 11 deletions containers/bundled_querybook_config.yaml
@@ -8,15 +8,18 @@ ELASTICSEARCH_HOST: http://elasticsearch:9200
 # Uncomment for email
 # EMAILER_CONN: dockerhostforward

-AI_ASSISTANT_PROVIDER: openai
-AI_ASSISTANT_CONFIG:
-    model_name: gpt-3.5-turbo-0613
-    temperature: 0
+# Uncomment below to enable AI Assistant
+# AI_ASSISTANT_PROVIDER: openai
+# AI_ASSISTANT_CONFIG:
+#    model_name: gpt-3.5-turbo
+#    temperature: 0

-EMBEDDINGS_PROVIDER: openai
-EMBEDDINGS_CONFIG: ~
-VECTOR_STORE_PROVIDER: opensearch
-VECTOR_STORE_CONFIG:
-    embeddings_arg_name: 'embedding_function'
-    opensearch_url: http://elasticsearch:9200
-    index_name: 'vector_index_v1'
+# Uncomment below to enable vector store to support embedding based table search.
+# Please check langchain doc for the configs of each provider.
+# EMBEDDINGS_PROVIDER: openai
+# EMBEDDINGS_CONFIG: ~
+# VECTOR_STORE_PROVIDER: opensearch
+# VECTOR_STORE_CONFIG:
+#    embeddings_arg_name: 'embedding_function'
+#    opensearch_url: http://elasticsearch:9200
+#    index_name: 'vector_index_v1'
2 changes: 1 addition & 1 deletion querybook/server/datasources_socketio/ai_assistant.py
@@ -8,7 +8,7 @@
 def text_to_sql(payload={}):
     original_query = payload["original_query"]
     query_engine_id = payload["query_engine_id"]
-    tables = payload.get("tables")
+    tables = payload.get("tables", [])
     question = payload["question"]
     ai_assistant.generate_sql_query(
         query_engine_id=query_engine_id,
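Why the added default matters, as a minimal sketch (the payload shape is assumed from the handler above): payload.get("tables") yields None when the key is absent, which breaks iteration downstream, while payload.get("tables", []) keeps the value a list either way.

    payload = {"original_query": "SELECT 1", "query_engine_id": 1, "question": "top users?"}

    tables = payload.get("tables")  # None when the key is absent
    safe_tables = payload.get("tables", [])  # [] when the key is absent

    # Truthiness checks treat None and [] alike ...
    assert not tables and not safe_tables

    # ... but iteration does not: this loops zero times,
    for t in safe_tables:
        print(t)

    # while `for t in tables:` would raise TypeError: 'NoneType' object is not iterable.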
2 changes: 1 addition & 1 deletion querybook/server/lib/ai_assistant/ai_socket.py
@@ -32,7 +32,7 @@ def send_delta_data(self, data: str):
     def send_delta_end(self):
         self._send("delta_end")

-    def send_tables(self, data: list[str]):
+    def send_tables_for_sql_gen(self, data: list[str]):
         self._send("tables", data)

     def send_error(self, error: str):
@@ -24,10 +24,12 @@ def _get_error_msg(self, error) -> str:
         return super()._get_error_msg(error)

     def _get_llm(self, callback_handler=None):
+        if not callback_handler:
+            # non-streaming
+            return ChatOpenAI(**self._config)
+
         return ChatOpenAI(
             **self._config,
-            streaming=True if callback_handler else False,
-            callback_manager=CallbackManager([callback_handler])
-            if callback_handler
-            else None,
+            streaming=True,
+            callback_manager=CallbackManager([callback_handler]),
         )
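A hedged sketch of the two call paths this refactor separates (the config values are assumptions; ChatOpenAI and CallbackManager are the langchain classes this module already imports):

    from langchain.callbacks.manager import CallbackManager
    from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
    from langchain.chat_models import ChatOpenAI

    config = {"model_name": "gpt-3.5-turbo", "temperature": 0}  # assumed config

    # Without a callback handler: a plain, non-streaming client.
    llm = ChatOpenAI(**config)

    # With a handler: streaming is unconditionally on and the handler
    # receives each generated token.
    streaming_llm = ChatOpenAI(
        **config,
        streaming=True,
        callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
    )

Splitting the paths lets the non-streaming case skip the callback plumbing entirely instead of threading streaming=False and callback_manager=None through a single call.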
2 changes: 1 addition & 1 deletion querybook/server/lib/ai_assistant/base_ai_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def generate_sql_query(
if not tables:
tables = get_vector_store().search_tables(question)
if tables:
socket.send_tables(tables)
socket.send_tables_for_sql_gen(tables)

# not finding any relevant tables
if not tables:
Expand Down
28 changes: 18 additions & 10 deletions querybook/server/lib/ai_assistant/tools/table_schema.py
@@ -5,10 +5,26 @@
 from models.metastore import DataTable, DataTableColumn


+def _get_column_prompt(column: DataTableColumn) -> str:
+    prompt = ""
+
+    prompt += f"- Column Name: {column.name}\n"
+    prompt += f"  Data Type: {column.type}\n"
+    if column.description:
+        prompt += f"  Description: {column.description}\n"
+    elif column.data_elements:
+        # use data element's description when column's description is empty
+        # TODO: only handling the REF data element for now. Need to handle ARRAY, MAP and etc in the future.
+        prompt += f"  Description: {column.data_elements[0].description}\n"
+        prompt += f"  Data Element: {column.data_elements[0].name}\n"
+
+    return prompt
+
+
 def _get_table_schema_prompt(
     table: DataTable,
     should_skip_column: Callable[[DataTableColumn], bool] = None,
-):
+) -> str:
     """Generate table schema prompt. The format will be like:
         Table Name: [Name_of_table_1]
@@ -38,15 +54,7 @@ def _get_table_schema_prompt(
         if should_skip_column and should_skip_column(column):
             continue

-        prompt += f"- Column Name: {column.name}\n"
-        prompt += f"  Data Type: {column.type}\n"
-        if column.description:
-            prompt += f"  Description: {column.description}\n"
-        elif column.data_elements:
-            # use data element's description when column's description is empty
-            # TODO: only handling the REF data element for now. Need to handle ARRAY, MAP and etc in the future.
-            prompt += f"  Description: {column.data_elements[0].description}\n"
-            prompt += f"  Data Element: {column.data_elements[0].name}\n"
+        prompt += _get_column_prompt(column)

     return prompt
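For illustration, the text _get_column_prompt would produce for a hypothetical column that has no description of its own but carries a REF data element (all values made up):

    - Column Name: user_id
      Data Type: bigint
      Description: Unique identifier of a user account
      Data Element: user_identifier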
2 changes: 1 addition & 1 deletion querybook/server/lib/vector_store/__init__.py
@@ -46,8 +46,8 @@ def get_vector_store():
     )

     kwargs = {
-        embeddings_arg_name: get_embeddings(),
         **vector_store_config,
+        embeddings_arg_name: get_embeddings(),
     }
     if "embeddings_arg_name" in kwargs:
         del kwargs["embeddings_arg_name"]
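A minimal standalone sketch of the kwargs assembly after this reorder (the helper name is hypothetical; the key names mirror VECTOR_STORE_CONFIG above). Spreading the user config first means the real embeddings object wins even if the config accidentally defines the same key; the embeddings_arg_name marker is then dropped because it is routing metadata, not a constructor argument:

    def build_store_kwargs(vector_store_config: dict, embeddings) -> dict:
        # Hypothetical standalone version of the assembly in get_vector_store().
        embeddings_arg_name = vector_store_config["embeddings_arg_name"]
        kwargs = {
            **vector_store_config,            # user config first ...
            embeddings_arg_name: embeddings,  # ... so the embeddings object wins
        }
        # Drop the marker key; no vector store constructor accepts it.
        del kwargs["embeddings_arg_name"]
        return kwargs


    config = {
        "embeddings_arg_name": "embedding_function",
        "opensearch_url": "http://elasticsearch:9200",
        "index_name": "vector_index_v1",
    }
    print(build_store_kwargs(config, "<embeddings object>"))
    # {'opensearch_url': 'http://elasticsearch:9200',
    #  'index_name': 'vector_index_v1',
    #  'embedding_function': '<embeddings object>'}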
10 changes: 2 additions & 8 deletions querybook/server/lib/vector_store/all_embeddings.py
@@ -1,20 +1,14 @@
 from typing import Type

-from lib.utils.import_helper import import_module_with_default
 from langchain.embeddings.base import Embeddings
-from langchain.embeddings import OpenAIEmbeddings
-
-
-PROVIDED_EMBEDDINGS = {"openai": OpenAIEmbeddings}
+from lib.utils.import_helper import import_module_with_default

-ALL_PLUGIN_EMBEDDINGS = import_module_with_default(
+ALL_EMBEDDINGS = import_module_with_default(
     "vector_store_plugin",
     "ALL_PLUGIN_EMBEDDINGS",
     default={},
 )
-
-ALL_EMBEDDINGS = {**PROVIDED_EMBEDDINGS, **ALL_PLUGIN_EMBEDDINGS}


 def get_embeddings_class(name: str) -> Type[Embeddings]:
     if name in ALL_EMBEDDINGS:
9 changes: 1 addition & 8 deletions querybook/server/lib/vector_store/all_vector_stores.py
@@ -1,21 +1,14 @@
 from typing import Type

-
 from lib.utils.import_helper import import_module_with_default
-from lib.vector_store.stores.opensearch import OpenSearchVectorStore
-
 from lib.vector_store.base_vector_store import VectorStoreBase

-PROVIDED_VECTOR_STORES = {"opensearch": OpenSearchVectorStore}
-
-ALL_PLUGIN_VECTOR_STORES = import_module_with_default(
+ALL_VECTOR_STORES = import_module_with_default(
     "vector_store_plugin",
     "ALL_PLUGIN_VECTOR_STORES",
     default={},
 )
-
-ALL_VECTOR_STORES = {**PROVIDED_VECTOR_STORES, **ALL_PLUGIN_VECTOR_STORES}


 def get_vector_store_class(name: str) -> Type[VectorStoreBase]:
     if name in ALL_VECTOR_STORES:
5 changes: 3 additions & 2 deletions querybook/server/lib/vector_store/base_vector_store.py
@@ -26,6 +26,7 @@ def should_skip_query_execution(
         Override this method to implement custom logic for your vector store."""
         query = query_execution.query

+        # TODO: add more filters
         # skip queries if it starts with "select * from"
        pattern = r"^\s*select\s+\*\s+from"
         if re.match(pattern, query, re.IGNORECASE):
@@ -55,8 +56,8 @@ def search_tables(self, text: str, threshold=0.6, k=3):
         for table, score in tables:
             table_score_dict[table] = max(score, table_score_dict.get(table, 0))

-        unique_tables = sorted(
+        sorted_tables = sorted(
             table_score_dict.items(), key=lambda x: x[1], reverse=True
         )

-        return [t for t, _ in unique_tables[:k]]
+        return [t for t, _ in sorted_tables[:k]]
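A worked sketch of the dedup-and-rank step above (table names and scores invented): each table keeps its best score across duplicate hits, then only the top k survive. The rename fits because the dict has already deduplicated by this point; what defines the list is its ordering.

    # (table_name, similarity_score) hits, possibly containing duplicates.
    hits = [("web.events", 0.82), ("web.users", 0.71), ("web.events", 0.64)]

    table_score_dict = {}
    for table, score in hits:
        # Keep the best score seen for each table.
        table_score_dict[table] = max(score, table_score_dict.get(table, 0))

    sorted_tables = sorted(table_score_dict.items(), key=lambda x: x[1], reverse=True)

    k = 3
    print([t for t, _ in sorted_tables[:k]])  # ['web.events', 'web.users']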
4 changes: 2 additions & 2 deletions querybook/server/logic/elasticsearch.py
@@ -44,7 +44,7 @@
     get_query_execution_by_id,
     get_successful_query_executions_by_data_cell_id,
 )
-from logic.vector_store import delete_table_doc, log_table
+from logic.vector_store import delete_table_doc, record_table
 from models.user import User
 from models.datadoc import DataCellType
 from models.board import Board
@@ -640,7 +640,7 @@ def update_table_by_id(table_id, session=None):
         _update(index_name, table_id, updated_body)

         # update it in vector store as well
-        log_table(table=table, session=session)
+        record_table(table=table, session=session)
     except Exception:
         # Otherwise insert as new
         LOG.error("failed to upsert {}. Will pass.".format(table_id))
2 changes: 1 addition & 1 deletion querybook/server/logic/metastore.py
@@ -768,7 +768,7 @@ def get_table_query_samples_count(table_id, session):


 @with_session
-def get_tables_by_query_execution_id(query_execution_id, session):
+def get_tables_by_query_execution_id(query_execution_id, session=None):
     return (
         session.query(DataTable)
         .join(DataTableQueryExecution)
110 changes: 88 additions & 22 deletions querybook/server/logic/vector_store.py
@@ -1,12 +1,16 @@
 import hashlib
-from typing import Optional

 from app.db import with_session
+from const.query_execution import QueryExecutionStatus
 from langchain.docstore.document import Document
 from lib.ai_assistant import ai_assistant
 from lib.logger import get_logger
 from lib.vector_store import get_vector_store
-from logic.metastore import get_table_by_id, get_tables_by_query_execution_id
+from logic.metastore import (
+    get_all_table,
+    get_table_by_id,
+    get_tables_by_query_execution_id,
+)
 from logic.query_execution import get_query_execution_by_id
 from models.metastore import DataTable
 from models.query_execution import QueryExecution
@@ -40,25 +44,19 @@ def _get_query_doc_id(query: str) -> str:

 @with_session
-def log_query_execution(
-    query_execution_id: Optional[int] = None,
-    query_execution: Optional[QueryExecution] = None,
+def record_query_execution(
+    query_execution: QueryExecution,
     session=None,
 ):
     """Generate summary of the query execution and log it to the vector store."""
     # vector store is not configured
     if not get_vector_store():
         return

-    try:
-        if query_execution is None:
-            query_execution = get_query_execution_by_id(
-                query_execution_id, session=session
-            )
-
-        if not query_execution:
-            return
+    if not query_execution:
+        return
+
+    try:
         tables = get_tables_by_query_execution_id(query_execution.id, session=session)
         if not tables:
             return
@@ -87,23 +85,32 @@


 @with_session
-def log_table(
-    table_id: Optional[int] = None,
-    table: Optional[DataTable] = None,
+def record_query_execution_by_id(
+    query_execution_id: int,
     session=None,
 ):
-    """Generate summary of the table and log it to the vector store."""
     # vector store is not configured
     if not get_vector_store():
         return

-    try:
-        if table is None:
-            table = get_table_by_id(table_id, session=session)
+    query_execution = get_query_execution_by_id(query_execution_id, session=session)
+    record_query_execution(query_execution, session=session)

-        if table is None:
-            return
+
+@with_session
+def record_table(
+    table: DataTable,
+    session=None,
+):
+    """Generate summary of the table and record it to the vector store."""
+    # vector store is not configured
+    if not get_vector_store():
+        return
+
+    if not table:
+        return
+
+    try:
         if get_vector_store().should_skip_table(table):
             return
@@ -122,6 +129,19 @@ def log_table(
         LOG.error(f"Failed to log table to vector store: {e}")


+@with_session
+def record_table_by_id(
+    table_id: int,
+    session=None,
+):
+    # vector store is not configured
+    if not get_vector_store():
+        return
+
+    table = get_table_by_id(table_id, session=session)
+    record_table(table=table, session=session)
+
+
 def delete_table_doc(table_id: int):
     """Delete table summary doc from vector store by table id."""

@@ -130,3 +150,49 @@ def delete_table_doc(table_id: int):
         return
     doc_id = _get_table_doc_id(table_id)
     get_vector_store().delete([doc_id])
+
+
+@with_session
+def ingest_tables(batch_size=100, session=None):
+    offset = 0
+
+    while True:
+        tables = get_all_table(
+            limit=batch_size,
+            offset=offset,
+            session=session,
+        )
+
+        for table in tables:
+            full_table_name = f"{table.data_schema.name}.{table.name}"
+            print(f"Ingesting table: {full_table_name}")
+            record_table(table=table, session=session)
+
+        if len(tables) < batch_size:
+            break
+
+        offset += batch_size
+
+
+@with_session
+def ingest_query_executions(batch_size=100, session=None):
+    offset = 0
+
+    while True:
+        # TODO: there may be many highly similar queries, we should not ingest all of them.
+        query_executions = (
+            session.query(QueryExecution)
+            .filter(QueryExecution.status == QueryExecutionStatus.DONE)
+            .offset(offset)
+            .limit(batch_size)
+            .all()
+        )
+
+        for qe in query_executions:
+            print(f"Ingesting query execution: {qe.id}")
+            record_query_execution(query_execution=qe, session=session)
+
+        if len(query_executions) < batch_size:
+            break
+
+        offset += batch_size
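A hedged usage sketch for backfilling an existing deployment with the two new helpers (assumes a Python shell inside the Querybook server container, where the DB session wiring that @with_session relies on is available):

    from logic.vector_store import ingest_tables, ingest_query_executions

    # One-off backfill: summarize and index every table, then every
    # successful query execution, in batches of 100.
    ingest_tables(batch_size=100)
    ingest_query_executions(batch_size=100)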