From 37e0a6fff7c03105a912bf10b25f8b0a9527bbfa Mon Sep 17 00:00:00 2001 From: ch-liuzhide Date: Mon, 5 Aug 2024 18:20:29 +0800 Subject: [PATCH 1/4] feat:search doc add param bot_id --- server/README.md | 2 +- server/README.zh-CN.md | 88 +++++++++++++++++++- server/bot/builder.py | 12 +-- server/data_class.py | 16 ++-- server/rag_helper/github_file_loader.py | 14 ++-- server/rag_helper/retrieval.py | 105 ++++++++++++++---------- server/rag_helper/task.py | 17 ++-- server/routers/rag.py | 38 ++++----- server/sql/rag_docs.sql | 18 ++-- server/tools/knowledge.py | 5 +- 10 files changed, 215 insertions(+), 100 deletions(-) diff --git a/server/README.md b/server/README.md index 8d39d15d..693a3490 100644 --- a/server/README.md +++ b/server/README.md @@ -1,2 +1,2 @@ -English | [简体中文](./README.zh-CN.md.md) +English | [简体中文](./README.zh-CN.md) diff --git a/server/README.zh-CN.md b/server/README.zh-CN.md index ccae84be..1c471f34 100644 --- a/server/README.zh-CN.md +++ b/server/README.zh-CN.md @@ -4,6 +4,15 @@ PeterCat 服务端,采用 FastAPI 框架开发。 # 功能模块 +## 存储 +采用 [supabase](https://supabase.com) 作为数据库进行存储。 +作为开发者你需要熟悉该平台以下功能 +- project 管理平台:https://supabase.com/dashboard/project/{projectId}, 请开发者联系管理员赋予相关权限。 +- 进入 Project 管理平台后,左边菜单栏中的 Table Editor、SQL Editor、Database 是你的好帮手。 + - Table Editor 支持直接修改数据; + - SQL Editor 是一个可以在线编写 SQL 并执行的可视化客户端;你可以在其中创建表、删除表、创建函数、删除函数等操作。 + - Database 中提供了数据库的的综合管理; + ## github ### webhook 代码目录 @@ -36,4 +45,81 @@ Webhook URL \> 填入smee channel url, eg: https://smee.io/Q2VVS0casGnhZV 6. 在 demo repository 发起 issue 或者 pull-request,在 smee 、本地将能同步看到请求。 -7. 在测试完毕后记得将 Webhook URL 改回去, eg:http://pertercat.chat/api/github/app/webhook \ No newline at end of file +7. 
在测试完毕后记得将 Webhook URL 改回去, eg:http://pertercat.chat/api/github/app/webhook + +## RAG +### API +> server/routers/rag.py +#### rag/add_knowledge_by_doc +新增知识库。执行将 github 上指定的仓库中的文档进行 Embedding 化后,存储在 supabase 中,对应的 table 为 `rag_docs`。 + +#### rag/search_knowledge +搜索知识。将输入的 query 进行 Embedding 化后,与 supabase 中存储的知识进行匹配,返回匹配结果。 + +### 数据库 +建议将 DB 相关操作备份在 /server/sql/rag_docs.sql 中,方便追踪。 +#### 创建知识库 +```sql +create extension +if not exists vector; + +-- Create a table to store your rag_docs +create table rag_docs +( + id uuid primary key, + content text, + -- corresponds to Document.pageContent + metadata jsonb, + -- corresponds to Document.metadata + embedding vector (1536), + -- 1536 works for OpenAI embeddings, change if needed + -- per request info + repo_name varchar, + commit_id varchar, + bot_id varchar, + file_sha varchar, + file_path varchar +); +``` +### 创建 Function +为了实现知识库的 Embedding 查询,需要创建一个 Function。 +[supabase 文档教程](https://supabase.com/docs/guides/ai/vector-columns#querying-a-vector--embedding) + +> 如果 Function 的入参发生了变化,需要将该function 进行删除后重新创建。事实上建议在项目上线后创建新版本的函数,保留历史函数。 +```sql +-- 删除函数 +drop function if exists match_rag_docs_v1; +-- 新建函数 +create or replace function match_rag_docs_v1 + ( + query_embedding vector (1536), + query_bot_id text, + filter jsonb default '{}', + query_limit integer default 4 +) returns table +( + id uuid, + content text, + metadata jsonb, + embedding vector, + similarity float +) language plpgsql as $$ +#variable_conflict use_column +begin + return query + select + id, + content, + metadata, + embedding, + 1 - (rag_docs.embedding <=> query_embedding + ) as similarity + from rag_docs + where metadata @> filter + and bot_id = query_bot_id + -- <=> 为 embedding 比较函数 + order by rag_docs.embedding <=> query_embedding + limit query_limit; +end; +$$; +``` diff --git a/server/bot/builder.py b/server/bot/builder.py index cf94f528..ab45c0bf 100644 --- a/server/bot/builder.py +++ b/server/bot/builder.py @@ -3,7 +3,7 @@ from 
db.supabase.client import get_client from prompts.bot_template import generate_prompt_by_repo_name from rag_helper.task import add_task -from data_class import GitDocConfig +from data_class import RAGGitDOCConfig g = Github() @@ -38,22 +38,22 @@ async def bot_info_generator( except Exception as e: print(f"An error occurred: {e}") return None - + def trigger_rag_task (repo_name: str, bot_id: str): try: repo = g.get_repo(repo_name) default_branch = repo.default_branch - config = GitDocConfig( + config = RAGGitDOCConfig( repo_name=repo_name, branch=default_branch, bot_id=bot_id, - file_path='', - commit_id='' + file_path="", + commit_id="", ) add_task(config ) except Exception as e: print(f"trigger_rag_task error: {e}") - + async def bot_builder( uid: str, diff --git a/server/data_class.py b/server/data_class.py index 611cbb3f..6f1f983e 100644 --- a/server/data_class.py +++ b/server/data_class.py @@ -16,15 +16,18 @@ class ChatData(BaseModel): prompt: Optional[str] = None bot_id: Optional[str] = None + class ExecuteMessage(BaseModel): type: str repo: str path: str + class S3Config(BaseModel): s3_bucket: str file_path: Optional[str] = None + class GitIssueConfig(BaseModel): repo_name: str page: Optional[int] = None @@ -33,14 +36,17 @@ class GitIssueConfig(BaseModel): per_page: Optional[int] = 30 """Number of items per page. Defaults to 30 in the GitHub API.""" - state: Optional[Literal["open", "closed", "all"]] = 'all' + state: Optional[Literal["open", "closed", "all"]] = "all" """Filter on issue state. Can be one of: 'open', 'closed', 'all'.""" class GitDocConfig(BaseModel): repo_name: str """File path of the documentation file. 
eg:'docs/blog/build-ghost.zh-CN.md'""" - file_path: Optional[str] = '', - branch: Optional[str] = 'main' - commit_id: Optional[str] = '', - bot_id: Optional[str] = '', + file_path: Optional[str] = "" + branch: Optional[str] = "main" + commit_id: Optional[str] = "" + + +class RAGGitDOCConfig(GitDocConfig): + bot_id: str diff --git a/server/rag_helper/github_file_loader.py b/server/rag_helper/github_file_loader.py index 89bf302b..e246d58f 100644 --- a/server/rag_helper/github_file_loader.py +++ b/server/rag_helper/github_file_loader.py @@ -7,6 +7,8 @@ from typing import Callable, Dict, List, Optional from github import Github from langchain_core.documents import Document + + class GithubFileLoader: repo: str github: Github @@ -19,13 +21,13 @@ class GithubFileLoader: github_api_url: str = "https://api.github.com" def __init__(self, **data: Dict): - self.repo = data['repo'] - self.file_path = data['file_path'] - self.branch = data['branch'] - self.file_filter = data['file_filter'] + self.repo = data["repo"] + self.file_path = data["file_path"] + self.branch = data["branch"] + self.file_filter = data["file_filter"] self.github = Github() - if 'commit_id' in data: - self.commit_id = data['commit_id'] + if "commit_id" in data and data["commit_id"]: + self.commit_id = data["commit_id"] else: self.commit_id = self.get_commit_id_by_branch(self.branch) diff --git a/server/rag_helper/retrieval.py b/server/rag_helper/retrieval.py index 874a86ac..d63d7473 100644 --- a/server/rag_helper/retrieval.py +++ b/server/rag_helper/retrieval.py @@ -1,16 +1,16 @@ import json -from typing import Any +from typing import Any, Dict, Optional from langchain_community.vectorstores import SupabaseVectorStore from langchain_openai import OpenAIEmbeddings +from langchain_core.documents import Document +import numpy as np -from data_class import GitDocConfig, GitIssueConfig, S3Config +from data_class import GitDocConfig, GitIssueConfig, RAGGitDOCConfig, S3Config from db.supabase.client import 
get_client from rag_helper.github_file_loader import GithubFileLoader from utils.env import get_env_variable - -supabase_url = get_env_variable("SUPABASE_URL") -supabase_key = get_env_variable("SUPABASE_SERVICE_KEY") +from urllib.parse import quote TABLE_NAME = "rag_docs" QUERY_NAME = "match_rag_docs" @@ -18,28 +18,13 @@ CHUNK_OVERLAP = 200 -def convert_document_to_dict(document): - return document.page_content, - - -def init_retriever(): - embeddings = OpenAIEmbeddings() - vector_store = SupabaseVectorStore( - embedding=embeddings, - client=get_client(), - table_name=TABLE_NAME, - query_name=QUERY_NAME, - chunk_size=CHUNK_SIZE, - ) - - return vector_store.as_retriever() - - def init_s3_Loader(config: S3Config): from langchain_community.document_loaders import S3DirectoryLoader + loader = S3DirectoryLoader(config.s3_bucket, prefix=config.file_path) return loader + # TODO init_github_issue_loader # def init_github_issue_loader(config: GitIssueConfig): # from langchain_community.document_loaders import GitHubIssuesLoader @@ -60,7 +45,7 @@ def init_github_file_loader(config: GitDocConfig): branch=config.branch, file_path=config.file_path, file_filter=lambda file_path: file_path.endswith(".md"), - commit_id=config.commit_id + commit_id=config.commit_id, ) return loader @@ -69,7 +54,9 @@ def supabase_embedding(documents, **kwargs: Any): from langchain_text_splitters import CharacterTextSplitter try: - text_splitter = CharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP) + text_splitter = CharacterTextSplitter( + chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP + ) docs = text_splitter.split_documents(documents) embeddings = OpenAIEmbeddings() vector_store = SupabaseVectorStore.from_documents( @@ -79,13 +66,14 @@ def supabase_embedding(documents, **kwargs: Any): table_name=TABLE_NAME, query_name=QUERY_NAME, chunk_size=CHUNK_SIZE, - **kwargs + **kwargs, ) return vector_store except Exception as e: print(e) return None + # TODO this feature is not 
implemented yet # def add_knowledge_by_issues(config: GitIssueConfig): # try: @@ -109,32 +97,32 @@ def supabase_embedding(documents, **kwargs: Any): # }) -def add_knowledge_by_doc(config: GitDocConfig): +def add_knowledge_by_doc(config: RAGGitDOCConfig): loader = init_github_file_loader(config) documents = loader.load() supabase = get_client() is_added_query = ( supabase.table(TABLE_NAME) - .select("id, repo_name, commit_id, file_path") - .eq('repo_name', config.repo_name) - .eq('commit_id', loader.commit_id) - .eq('file_path', config.file_path) + .select("id, repo_name, commit_id, file_path, bot_id") + .eq("repo_name", config.repo_name) + .eq("commit_id", loader.commit_id) + .eq("file_path", config.file_path) + .eq("bot_id", config.bot_id) .execute() ) if not is_added_query.data: is_equal_query = ( - supabase.table(TABLE_NAME) - .select("*") - .eq('file_sha', loader.file_sha) + supabase.table(TABLE_NAME).select("*").eq("file_sha", loader.file_sha) ).execute() if not is_equal_query.data: + # If there is no file with the same file_sha, perform embedding. 
store = supabase_embedding( documents, repo_name=config.repo_name, commit_id=loader.commit_id, file_sha=loader.file_sha, file_path=config.file_path, - bot_id=config.bot_id + bot_id=config.bot_id, ) return store else: @@ -143,23 +131,52 @@ def add_knowledge_by_doc(config: GitDocConfig): **{k: v for k, v in item.items() if k != "id"}, "repo_name": config.repo_name, "commit_id": loader.commit_id, - "file_path": config.file_path + "file_path": config.file_path, + "bot_id": config.bot_id, } for item in is_equal_query.data ] - insert_result = ( - supabase.table(TABLE_NAME) - .insert(new_commit_list) - .execute() - ) + insert_result = supabase.table(TABLE_NAME).insert(new_commit_list).execute() return insert_result else: return True -def search_knowledge(query: str): - retriever = init_retriever() - docs = retriever.invoke(query) - documents_as_dicts = [convert_document_to_dict(doc) for doc in docs] +def search_knowledge( + query: str, + bot_id: str, + meta_filter={}, +): + """ + use supabase vector store to search knowledge + https://supabase.com/docs/guides/ai/vector-columns#querying-a-vector--embedding + """ + embeddings = OpenAIEmbeddings().embed_query(query) + client = get_client() + query_builder = client.rpc( + QUERY_NAME, + { + "query_embedding": embeddings, + "filter": meta_filter, + "query_bot_id": bot_id, + "query_limit": 10, + }, + ) + res = query_builder.execute() + docs = [ + ( + Document( + metadata=search.get("metadata", {}), # type: ignore + page_content=search.get("content", ""), + ), + search.get("similarity", 0.0), + # Supabase returns a vector type as its string represation (!). + # This is a hack to convert the string to numpy array. 
+ np.fromstring(search.get("embedding", "").strip("[]"), np.float32, sep=","), + ) + for search in res.data + if search.get("content") + ] + documents_as_dicts = [doc[0].page_content for doc in docs] json_output = json.dumps(documents_as_dicts, ensure_ascii=False) return json_output diff --git a/server/rag_helper/task.py b/server/rag_helper/task.py index e0164e64..fb2d9dcb 100644 --- a/server/rag_helper/task.py +++ b/server/rag_helper/task.py @@ -4,7 +4,7 @@ from github import Github from github import Repository -from data_class import GitDocConfig +from data_class import GitDocConfig, RAGGitDOCConfig from db.supabase.client import get_client from rag_helper import retrieval @@ -22,8 +22,13 @@ class TaskStatus(Enum): ERROR = auto() -def add_task(config: GitDocConfig, - extra: Optional[Dict[str, Optional[str]]] = {"node_type": None, "from_task_id": None}): +def add_task( + config: RAGGitDOCConfig, + extra: Optional[Dict[str, Optional[str]]] = { + "node_type": None, + "from_task_id": None, + }, +): repo = g.get_repo(config.repo_name) commit_id = config.commit_id if config.commit_id else repo.get_branch(config.branch).commit.sha @@ -131,20 +136,18 @@ def handle_blob_task(task): ) retrieval.add_knowledge_by_doc( - GitDocConfig( + RAGGitDOCConfig( repo_name=task["repo_name"], file_path=task["path"], commit_id=task["commit_id"], - bot_id=task["bot_id"] + bot_id=task["bot_id"], ) ) - return (supabase.table(TABLE_NAME).update( {"status": TaskStatus.COMPLETED.name}) .eq("id", task["id"]) .execute()) - def trigger_task(task_id: Optional[str]): task = get_task_by_id(task_id) if task_id else get_oldest_task() if task is None: diff --git a/server/routers/rag.py b/server/routers/rag.py index 1e6c99ed..c0460c44 100644 --- a/server/routers/rag.py +++ b/server/routers/rag.py @@ -3,7 +3,7 @@ from fastapi import APIRouter, Depends -from data_class import GitDocConfig, GitIssueConfig +from data_class import GitIssueConfig, RAGGitDOCConfig from rag_helper import retrieval, task from 
verify.rate_limit import verify_rate_limit @@ -15,24 +15,21 @@ @router.post("/rag/add_knowledge_by_doc", dependencies=[Depends(verify_rate_limit)]) -def add_knowledge_by_doc(config: GitDocConfig): +def add_knowledge_by_doc(config: RAGGitDOCConfig): try: result = retrieval.add_knowledge_by_doc(config) - if (result): - return json.dumps({ - "success": True, - "message": "Knowledge added successfully!", - }) + if result: + return json.dumps( + { + "success": True, + "message": "Knowledge added successfully!", + } + ) else: - return json.dumps({ - "success": False, - "message": "Knowledge not added!" - }) + return json.dumps({"success": False, "message": "Knowledge not added!"}) except Exception as e: - return json.dumps({ - "success": False, - "message": str(e) - }) + return json.dumps({"success": False, "message": str(e)}) + # TODO this feature is not implemented yet # @router.post("/rag/add_knowledge_by_issues", dependencies=[Depends(verify_rate_limit)]) @@ -42,21 +39,18 @@ def add_knowledge_by_doc(config: GitDocConfig): @router.post("/rag/search_knowledge", dependencies=[Depends(verify_rate_limit)]) -def search_knowledge(query: str): - data = retrieval.search_knowledge(query) +def search_knowledge(query: str, bot_id: str, filter: dict = {}): + data = retrieval.search_knowledge(query, bot_id, filter) return data @router.post("/rag/add_task", dependencies=[Depends(verify_rate_limit)]) -def add_task(config: GitDocConfig): +def add_task(config: RAGGitDOCConfig): try: data = task.add_task(config) return data except Exception as e: - return json.dumps({ - "success": False, - "message": str(e) - }) + return json.dumps({"success": False, "message": str(e)}) @router.post("/rag/trigger_task", dependencies=[Depends(verify_rate_limit)]) diff --git a/server/sql/rag_docs.sql b/server/sql/rag_docs.sql index cf0d3a0e..1c83beba 100644 --- a/server/sql/rag_docs.sql +++ b/server/sql/rag_docs.sql @@ -24,18 +24,21 @@ create table rag_docs ); -- Drop the existing function if it already 
exists -drop function if exists match_rag_docs -(vector, jsonb); +drop function if exists match_rag_docs; -- Create a function to search for rag_docs -create function match_rag_docs ( +create function match_rag_docs + ( query_embedding vector (1536), - filter jsonb default '{}' + query_bot_id text, + filter jsonb default '{}', + query_limit integer default 4 ) returns table ( id uuid, content text, metadata jsonb, + embedding vector, similarity float ) language plpgsql as $$ #variable_conflict use_column @@ -45,10 +48,13 @@ begin id, content, metadata, + embedding, 1 - (rag_docs.embedding <=> query_embedding ) as similarity from rag_docs where metadata @> filter - order by rag_docs.embedding <=> query_embedding; + and bot_id = query_bot_id + order by rag_docs.embedding <=> query_embedding + limit query_limit; end; -$$; +$$; \ No newline at end of file diff --git a/server/tools/knowledge.py b/server/tools/knowledge.py index 08899e14..e4f26faa 100644 --- a/server/tools/knowledge.py +++ b/server/tools/knowledge.py @@ -5,15 +5,16 @@ @tool def search_knowledge( query: str, + bot_id: str ): """ Search for information based on the query. When use this tool, do not translate the search query. Use the original query language to search. eg: When user's question is 'Ant Design 有哪些新特性?', the query should be 'Ant Design 有哪些新特性?'. :param query: The user's question. + :param bot_id: The bot's unique id. 
""" try: - return retrieval.search_knowledge(query) + return retrieval.search_knowledge(query, bot_id) except Exception as e: print(f"An error occurred: {e}") return None - \ No newline at end of file From c9276d01f7e20c283ad83d53be93134f9021f904 Mon Sep 17 00:00:00 2001 From: ch-liuzhide Date: Tue, 6 Aug 2024 10:42:42 +0800 Subject: [PATCH 2/4] refactor(rag): still use SupabaseVectorStore retiver --- server/README.zh-CN.md | 22 +++++++------- server/rag_helper/retrieval.py | 54 ++++++++++++++-------------------- server/sql/rag_docs.sql | 12 +++----- 3 files changed, 37 insertions(+), 51 deletions(-) diff --git a/server/README.zh-CN.md b/server/README.zh-CN.md index 1c471f34..2f88ffa7 100644 --- a/server/README.zh-CN.md +++ b/server/README.zh-CN.md @@ -1,7 +1,7 @@ [English](./README.md) | 简体中文 # 介绍 -PeterCat 服务端,采用 FastAPI 框架开发。 +PeterCat 服务端,采用 FastAPI 框架开发。使用了 supabase 作为数据存储方案。 # 功能模块 ## 存储 @@ -85,17 +85,19 @@ create table rag_docs 为了实现知识库的 Embedding 查询,需要创建一个 Function。 [supabase 文档教程](https://supabase.com/docs/guides/ai/vector-columns#querying-a-vector--embedding) -> 如果 Function 的入参发生了变化,需要将该function 进行删除后重新创建。事实上建议在项目上线后创建新版本的函数,保留历史函数。 +> 建议: +> 1. 如果 Function 的入参发生了变化,需要将该function 进行删除后重新创建。事实上建议在项目上线后创建新版本的函数,保留历史函数。 +> 2. 
将函数备份在本项目中 server/sql/rag_docs.sql +#### 示例 +这些 sql 可以在 SQL Editor 中执行。 ```sql -- 删除函数 drop function if exists match_rag_docs_v1; -- 新建函数 -create or replace function match_rag_docs_v1 +create function match_rag_docs_v1 ( query_embedding vector (1536), - query_bot_id text, - filter jsonb default '{}', - query_limit integer default 4 + filter jsonb default '{}' ) returns table ( id uuid, @@ -115,11 +117,9 @@ begin 1 - (rag_docs.embedding <=> query_embedding ) as similarity from rag_docs - where metadata @> filter - and bot_id = query_bot_id - -- <=> 为 embedding 比较函数 - order by rag_docs.embedding <=> query_embedding - limit query_limit; + where metadata @> jsonb_extract_path(filter, 'metadata') + and bot_id = jsonb_extract_path_text(filter, 'bot_id') + order by rag_docs.embedding <=> query_embedding; end; $$; ``` diff --git a/server/rag_helper/retrieval.py b/server/rag_helper/retrieval.py index d63d7473..242b069a 100644 --- a/server/rag_helper/retrieval.py +++ b/server/rag_helper/retrieval.py @@ -12,6 +12,24 @@ from utils.env import get_env_variable from urllib.parse import quote + +def convert_document_to_dict(document): + return (document.page_content,) + + +def init_retriever(search_kwargs): + embeddings = OpenAIEmbeddings() + vector_store = SupabaseVectorStore( + embedding=embeddings, + client=get_client(), + table_name=TABLE_NAME, + query_name=QUERY_NAME, + chunk_size=CHUNK_SIZE, + ) + + return vector_store.as_retriever(search_kwargs=search_kwargs) + + TABLE_NAME = "rag_docs" QUERY_NAME = "match_rag_docs" CHUNK_SIZE = 2000 @@ -145,38 +163,10 @@ def add_knowledge_by_doc(config: RAGGitDOCConfig): def search_knowledge( query: str, bot_id: str, - meta_filter={}, + meta_filter: Dict[str, Any] = {}, ): - """ - use supabase vector store to search knowledge - https://supabase.com/docs/guides/ai/vector-columns#querying-a-vector--embedding - """ - embeddings = OpenAIEmbeddings().embed_query(query) - client = get_client() - query_builder = client.rpc( - QUERY_NAME, - { - 
"query_embedding": embeddings, - "filter": meta_filter, - "query_bot_id": bot_id, - "query_limit": 10, - }, - ) - res = query_builder.execute() - docs = [ - ( - Document( - metadata=search.get("metadata", {}), # type: ignore - page_content=search.get("content", ""), - ), - search.get("similarity", 0.0), - # Supabase returns a vector type as its string represation (!). - # This is a hack to convert the string to numpy array. - np.fromstring(search.get("embedding", "").strip("[]"), np.float32, sep=","), - ) - for search in res.data - if search.get("content") - ] - documents_as_dicts = [doc[0].page_content for doc in docs] + retriever = init_retriever({"filter": {"metadata": meta_filter, "bot_id": bot_id}}) + docs = retriever.invoke(query) + documents_as_dicts = [convert_document_to_dict(doc) for doc in docs] json_output = json.dumps(documents_as_dicts, ensure_ascii=False) return json_output diff --git a/server/sql/rag_docs.sql b/server/sql/rag_docs.sql index 1c83beba..d3f45d32 100644 --- a/server/sql/rag_docs.sql +++ b/server/sql/rag_docs.sql @@ -26,13 +26,10 @@ create table rag_docs -- Drop the existing function if it already exists drop function if exists match_rag_docs; --- Create a function to search for rag_docs create function match_rag_docs ( query_embedding vector (1536), - query_bot_id text, - filter jsonb default '{}', - query_limit integer default 4 + filter jsonb default '{}' ) returns table ( id uuid, @@ -52,9 +49,8 @@ begin 1 - (rag_docs.embedding <=> query_embedding ) as similarity from rag_docs - where metadata @> filter - and bot_id = query_bot_id - order by rag_docs.embedding <=> query_embedding - limit query_limit; + where metadata @> jsonb_extract_path(filter, 'metadata') + and bot_id = jsonb_extract_path_text(filter, 'bot_id') + order by rag_docs.embedding <=> query_embedding; end; $$; \ No newline at end of file From e985c7b4b8da63929de8ab843071f70399026ceb Mon Sep 17 00:00:00 2001 From: ch-liuzhide Date: Tue, 6 Aug 2024 10:46:56 +0800 Subject: 
[PATCH 3/4] move code line --- server/rag_helper/retrieval.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/server/rag_helper/retrieval.py b/server/rag_helper/retrieval.py index 242b069a..78e4e1b2 100644 --- a/server/rag_helper/retrieval.py +++ b/server/rag_helper/retrieval.py @@ -13,6 +13,12 @@ from urllib.parse import quote +TABLE_NAME = "rag_docs" +QUERY_NAME = "match_rag_docs" +CHUNK_SIZE = 2000 +CHUNK_OVERLAP = 200 + + def convert_document_to_dict(document): return (document.page_content,) @@ -30,12 +36,6 @@ def init_retriever(search_kwargs): return vector_store.as_retriever(search_kwargs=search_kwargs) -TABLE_NAME = "rag_docs" -QUERY_NAME = "match_rag_docs" -CHUNK_SIZE = 2000 -CHUNK_OVERLAP = 200 - - def init_s3_Loader(config: S3Config): from langchain_community.document_loaders import S3DirectoryLoader From a89be4dfa99bf7827cdb70ac786f939f50230f45 Mon Sep 17 00:00:00 2001 From: ch-liuzhide Date: Tue, 6 Aug 2024 17:23:18 +0800 Subject: [PATCH 4/4] feat(rag): support passing the bot_id parameter as a default parameter in the qa --- client/app/factory/edit/[id]/page.tsx | 15 +++ client/public/icons/BookIcon.tsx | 19 ++++ server/agent/base.py | 131 ++++++++++++++------------ server/agent/bot_builder.py | 1 - server/agent/qa_chat.py | 43 ++++++--- server/bot/builder.py | 4 +- server/data_class.py | 2 +- server/prompts/bot_template.py | 4 +- server/rag_helper/retrieval.py | 6 +- server/rag_helper/task.py | 6 +- server/routers/chat.py | 2 +- server/routers/rag.py | 6 +- server/tools/knowledge.py | 36 +++---- 13 files changed, 169 insertions(+), 106 deletions(-) create mode 100644 client/public/icons/BookIcon.tsx diff --git a/client/app/factory/edit/[id]/page.tsx b/client/app/factory/edit/[id]/page.tsx index 7674f6a1..bbadf779 100644 --- a/client/app/factory/edit/[id]/page.tsx +++ b/client/app/factory/edit/[id]/page.tsx @@ -18,6 +18,7 @@ import AIBtnIcon from '@/public/icons/AIBtnIcon'; import ChatIcon from '@/public/icons/ChatIcon'; 
import ConfigIcon from '@/public/icons/ConfigIcon'; import SaveIcon from '@/public/icons/SaveIcon'; +import BookIcon from '@/public/icons/BookIcon'; import { useBot } from '@/app/contexts/BotContext'; import 'react-toastify/dist/ReactToastify.css'; @@ -247,6 +248,19 @@ export default function Edit({ params }: { params: { id: string } }) { 重新生成配置 )} + {isEdit ? ( + + ) : null} {isEdit && } @@ -377,6 +391,7 @@ export default function Edit({ params }: { params: { id: string } }) { style={{ backgroundColor: '#FCFCFC', }} + token={params.id} apiDomain={API_HOST} apiUrl="/api/chat/stream_qa" prompt={botProfile?.prompt} diff --git a/client/public/icons/BookIcon.tsx b/client/public/icons/BookIcon.tsx new file mode 100644 index 00000000..f43d03a4 --- /dev/null +++ b/client/public/icons/BookIcon.tsx @@ -0,0 +1,19 @@ +const BookIcon = () => ( + + + + +); +export default BookIcon; diff --git a/server/agent/base.py b/server/agent/base.py index 0b7c8b8a..1c28f409 100644 --- a/server/agent/base.py +++ b/server/agent/base.py @@ -1,5 +1,6 @@ import json from typing import AsyncIterator, Dict, Callable, Optional + # import uuid from langchain.agents import AgentExecutor from data_class import ChatData, Message @@ -17,17 +18,19 @@ from utils.env import get_env_variable OPEN_API_KEY = get_env_variable("OPENAI_API_KEY") -TAVILY_API_KEY = get_env_variable("TAVILY_API_KEY") +TAVILY_API_KEY = get_env_variable("TAVILY_API_KEY") + class AgentBuilder: - + def __init__( - self, - prompt: str, - tools: Dict[str, Callable], - enable_tavily: Optional[bool] = True, + self, + prompt: str, + tools: Dict[str, Callable], + enable_tavily: Optional[bool] = True, temperature: Optional[int] = 0.2, - max_tokens: Optional[int] = 1500 + max_tokens: Optional[int] = 1500, + runtime_invoke_context: Optional[Dict] = {}, ): """ @class `Builde AgentExecutor based on tools and prompt` @@ -45,25 +48,27 @@ def __init__( self.agent_executor = self._create_agent_with_tools() def init_tavily_tools(self): - # init 
Tavily + # init Tavily search = TavilySearchAPIWrapper() tavily_tool = TavilySearchResults(api_wrapper=search) return [tavily_tool] - + def _create_agent_with_tools(self) -> AgentExecutor: - llm = ChatOpenAI(model="gpt-4o", temperature=self.temperature, streaming=True, max_tokens=self.max_tokens, openai_api_key=OPEN_API_KEY) + llm = ChatOpenAI( + model="gpt-4o", + temperature=self.temperature, + streaming=True, + max_tokens=self.max_tokens, + openai_api_key=OPEN_API_KEY, + ) + + tools = self.init_tavily_tools() if self.enable_tavily else [] - tools = self.init_tavily_tools() if self.enable_tavily else [] - for tool in self.tools.values(): tools.append(tool) if tools: - llm_with_tools = llm.bind( - tools=[convert_to_openai_tool(tool) for tool in tools] - ) - else: - llm_with_tools = llm + llm = llm.bind_tools([convert_to_openai_tool(tool) for tool in tools]) self.prompt = self.get_prompt() agent = ( @@ -75,12 +80,18 @@ def _create_agent_with_tools(self) -> AgentExecutor: "chat_history": lambda x: x["chat_history"], } | self.prompt - | llm_with_tools + | llm | OpenAIToolsAgentOutputParser() ) - return AgentExecutor(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True, max_iterations=5) - + return AgentExecutor( + agent=agent, + tools=tools, + verbose=True, + handle_parsing_errors=True, + max_iterations=5, + ) + def get_prompt(self): return ChatPromptTemplate.from_messages( [ @@ -95,7 +106,7 @@ def get_prompt(self): def chat_history_transform(messages: list[Message]): transformed_messages = [] for message in messages: - print('message', message) + print("message", message) if message.role == "user": transformed_messages.append(HumanMessage(content=message.content)) elif message.role == "assistant": @@ -108,7 +119,6 @@ async def run_stream_chat(self, input_data: ChatData) -> AsyncIterator[str]: try: messages = input_data.messages print(self.chat_history_transform(messages)) - async for event in self.agent_executor.astream_events( { "input": 
messages[len(messages) - 1].content, @@ -118,18 +128,14 @@ async def run_stream_chat(self, input_data: ChatData) -> AsyncIterator[str]: ): kind = event["event"] if kind == "on_chain_start": - if ( - event["name"] == "agent" - ): + if event["name"] == "agent": print( - f"Starting agent: {event['name']} " - f"with input: {event['data'].get('input')}" + f"Starting agent: {event['name']} " + f"with input: {event['data'].get('input')}" ) elif kind == "on_chain_end": - if ( - event["name"] == "agent" - ): - print ( + if event["name"] == "agent": + print( f"Done agent: {event['name']} " f"with output: {event['data'].get('output')['output']}" ) @@ -137,50 +143,59 @@ async def run_stream_chat(self, input_data: ChatData) -> AsyncIterator[str]: # id = str(uuid.uuid4()) content = event["data"]["chunk"].content if content: - json_output = json.dumps({ - "type": "message", - "content": content, - }, ensure_ascii=False) + json_output = json.dumps( + { + "type": "message", + "content": content, + }, + ensure_ascii=False, + ) yield f"{json_output}\n\n" elif kind == "on_tool_start": children_value = event["data"].get("input", {}) - json_output = json.dumps({ - "type": "tool", - "extra": { - "source": f"已调用工具: {event['name']}", - "pluginName": "GitHub", - "data": json.dumps(children_value, ensure_ascii=False), - "status": "loading" - } - }, ensure_ascii=False) - + json_output = json.dumps( + { + "type": "tool", + "extra": { + "source": f"已调用工具: {event['name']}", + "pluginName": "GitHub", + "data": json.dumps(children_value, ensure_ascii=False), + "status": "loading", + }, + }, + ensure_ascii=False, + ) + yield f"{json_output}\n\n" elif kind == "on_tool_end": children_value = event["data"].get("output", {}) - json_output = json.dumps({ - "type": "tool", - "extra": { - "source": f"已调用工具: {event['name']}", - "pluginName": "GitHub", - "data": children_value, - "status": "success" + json_output = json.dumps( + { + "type": "tool", + "extra": { + "source": f"已调用工具: {event['name']}", + 
"pluginName": "GitHub", + "data": children_value, + "status": "success", + }, }, - }, ensure_ascii=False) + ensure_ascii=False, + ) yield f"{json_output}\n\n" except Exception as e: yield f"error: {str(e)}\n\n" - + async def run_chat(self, input_data: ChatData) -> str: try: messages = input_data.messages - print('history', self.chat_history_transform(messages)) - + print("history", self.chat_history_transform(messages)) + return self.agent_executor.invoke( { "input": messages[len(messages) - 1].content, "chat_history": self.chat_history_transform(messages), }, - return_only_outputs=True, - ) + return_only_outputs=True, + ) except Exception as e: return f"error: {str(e)}\n" diff --git a/server/agent/bot_builder.py b/server/agent/bot_builder.py index f3adb59c..f1145e3d 100644 --- a/server/agent/bot_builder.py +++ b/server/agent/bot_builder.py @@ -5,7 +5,6 @@ from tools import bot_builder - TOOL_MAPPING = { "create_bot": bot_builder.create_bot, "edit_bot": bot_builder.edit_bot, diff --git a/server/agent/qa_chat.py b/server/agent/qa_chat.py index 0f9152d3..34e4b63c 100644 --- a/server/agent/qa_chat.py +++ b/server/agent/qa_chat.py @@ -6,34 +6,47 @@ from tools import issue, sourcecode, knowledge -TOOL_MAPPING = { - "search_knowledge": knowledge.search_knowledge, - "create_issue": issue.create_issue, - "get_issues": issue.get_issues, - "search_issues": issue.search_issues, - "search_code": sourcecode.search_code, -} +def get_tools(bot_id): + return { + "search_knowledge": knowledge.factory(bot_id=bot_id), + "create_issue": issue.create_issue, + "get_issues": issue.get_issues, + "search_issues": issue.search_issues, + "search_code": sourcecode.search_code, + } + def init_prompt(input_data: ChatData): if input_data.prompt: - prompt = input_data.prompt + prompt = input_data.prompt elif input_data.bot_id: try: supabase = get_client() - res = supabase.table("bots").select('prompt').eq('id', input_data.bot_id).execute() - prompt = res.data[0]['prompt'] + res = ( + 
supabase.table("bots") + .select("prompt") + .eq("id", input_data.bot_id) + .execute() + ) + prompt = res.data[0]["prompt"] except Exception as e: print(e) - prompt = generate_prompt_by_repo_name("ant-design") + prompt = generate_prompt_by_repo_name("ant-design") else: - prompt = generate_prompt_by_repo_name("ant-design") - + prompt = generate_prompt_by_repo_name("ant-design") + return prompt + def agent_stream_chat(input_data: ChatData) -> AsyncIterator[str]: - agent = AgentBuilder(prompt=init_prompt(input_data), tools=TOOL_MAPPING) + agent = AgentBuilder( + prompt=init_prompt(input_data), tools=get_tools(bot_id=input_data.bot_id) + ) return agent.run_stream_chat(input_data) + def agent_chat(input_data: ChatData) -> AsyncIterator[str]: - agent = AgentBuilder(prompt=init_prompt(input_data), tools=TOOL_MAPPING) + agent = AgentBuilder( + prompt=init_prompt(input_data), tools=get_tools(input_data.bot_id) + ) return agent.run_chat(input_data) diff --git a/server/bot/builder.py b/server/bot/builder.py index ab45c0bf..4f08fabd 100644 --- a/server/bot/builder.py +++ b/server/bot/builder.py @@ -3,7 +3,7 @@ from db.supabase.client import get_client from prompts.bot_template import generate_prompt_by_repo_name from rag_helper.task import add_task -from data_class import RAGGitDOCConfig +from data_class import RAGGitDocConfig g = Github() @@ -43,7 +43,7 @@ def trigger_rag_task (repo_name: str, bot_id: str): try: repo = g.get_repo(repo_name) default_branch = repo.default_branch - config = RAGGitDOCConfig( + config = RAGGitDocConfig( repo_name=repo_name, branch=default_branch, bot_id=bot_id, diff --git a/server/data_class.py b/server/data_class.py index 6f1f983e..b1bd3f66 100644 --- a/server/data_class.py +++ b/server/data_class.py @@ -48,5 +48,5 @@ class GitDocConfig(BaseModel): commit_id: Optional[str] = "" -class RAGGitDOCConfig(GitDocConfig): +class RAGGitDocConfig(GitDocConfig): bot_id: str diff --git a/server/prompts/bot_template.py b/server/prompts/bot_template.py index 
e8637148..897b1a42 100644 --- a/server/prompts/bot_template.py +++ b/server/prompts/bot_template.py @@ -1,6 +1,3 @@ - - - PROMPT = """ # Character You are a skilled assistant dedicated to {repo_name}, capable of delivering comprehensive insights and solutions pertaining to {repo_name}. You excel in fixing code issues correlated with {repo_name}. @@ -26,5 +23,6 @@ - With your multilingual capability, always respond in the user's language. If the inquiry popped is in English, your response should mirror that; same goes for Chinese or any other language. """ + def generate_prompt_by_repo_name(repo_name: str): return PROMPT.format(repo_name=repo_name) diff --git a/server/rag_helper/retrieval.py b/server/rag_helper/retrieval.py index 78e4e1b2..ff60eed8 100644 --- a/server/rag_helper/retrieval.py +++ b/server/rag_helper/retrieval.py @@ -6,7 +6,7 @@ from langchain_core.documents import Document import numpy as np -from data_class import GitDocConfig, GitIssueConfig, RAGGitDOCConfig, S3Config +from data_class import GitDocConfig, GitIssueConfig, RAGGitDocConfig, S3Config from db.supabase.client import get_client from rag_helper.github_file_loader import GithubFileLoader from utils.env import get_env_variable @@ -20,7 +20,7 @@ def convert_document_to_dict(document): - return (document.page_content,) + return document.page_content def init_retriever(search_kwargs): @@ -115,7 +115,7 @@ def supabase_embedding(documents, **kwargs: Any): # }) -def add_knowledge_by_doc(config: RAGGitDOCConfig): +def add_knowledge_by_doc(config: RAGGitDocConfig): loader = init_github_file_loader(config) documents = loader.load() supabase = get_client() diff --git a/server/rag_helper/task.py b/server/rag_helper/task.py index fb2d9dcb..8757447c 100644 --- a/server/rag_helper/task.py +++ b/server/rag_helper/task.py @@ -4,7 +4,7 @@ from github import Github from github import Repository -from data_class import GitDocConfig, RAGGitDOCConfig +from data_class import GitDocConfig, RAGGitDocConfig from 
db.supabase.client import get_client from rag_helper import retrieval @@ -23,7 +23,7 @@ class TaskStatus(Enum): def add_task( - config: RAGGitDOCConfig, + config: RAGGitDocConfig, extra: Optional[Dict[str, Optional[str]]] = { "node_type": None, "from_task_id": None, @@ -136,7 +136,7 @@ def handle_blob_task(task): ) retrieval.add_knowledge_by_doc( - RAGGitDOCConfig( + RAGGitDocConfig( repo_name=task["repo_name"], file_path=task["path"], commit_id=task["commit_id"], diff --git a/server/routers/chat.py b/server/routers/chat.py index 2e56e198..6a52ff3e 100644 --- a/server/routers/chat.py +++ b/server/routers/chat.py @@ -22,6 +22,7 @@ def run_qa_chat(input_data: ChatData): result = qa_chat.agent_stream_chat(input_data) return StreamingResponse(result, media_type="text/event-stream") + @router.post("/qa", dependencies=[Depends(verify_rate_limit)]) async def run_issue_helper(input_data: ChatData): result = await qa_chat.agent_chat(input_data) @@ -34,4 +35,3 @@ def run_bot_builder(input_data: ChatData, user_id: str = Cookie(None), bot_id: O return StreamingResponse(generate_auth_failed_stream(), media_type="text/event-stream") result = bot_builder.agent_stream_chat(input_data, user_id, bot_id) return StreamingResponse(result, media_type="text/event-stream") - diff --git a/server/routers/rag.py b/server/routers/rag.py index c0460c44..eb834521 100644 --- a/server/routers/rag.py +++ b/server/routers/rag.py @@ -3,7 +3,7 @@ from fastapi import APIRouter, Depends -from data_class import GitIssueConfig, RAGGitDOCConfig +from data_class import GitIssueConfig, RAGGitDocConfig from rag_helper import retrieval, task from verify.rate_limit import verify_rate_limit @@ -15,7 +15,7 @@ @router.post("/rag/add_knowledge_by_doc", dependencies=[Depends(verify_rate_limit)]) -def add_knowledge_by_doc(config: RAGGitDOCConfig): +def add_knowledge_by_doc(config: RAGGitDocConfig): try: result = retrieval.add_knowledge_by_doc(config) if result: @@ -45,7 +45,7 @@ def search_knowledge(query: str, 
bot_id: str, filter: dict = {}):
 
 
 @router.post("/rag/add_task", dependencies=[Depends(verify_rate_limit)])
-def add_task(config: RAGGitDOCConfig):
+def add_task(config: RAGGitDocConfig):
     try:
         data = task.add_task(config)
         return data
diff --git a/server/tools/knowledge.py b/server/tools/knowledge.py
index e4f26faa..e40f781f 100644
--- a/server/tools/knowledge.py
+++ b/server/tools/knowledge.py
@@ -1,20 +1,24 @@
-from langchain.tools import tool
+from langchain_core.tools import InjectedToolArg, tool
+from typing_extensions import Annotated
 
 from rag_helper import retrieval
 
-@tool
-def search_knowledge(
-    query: str,
-    bot_id: str
-):
-    """
-    Search for information based on the query. When use this tool, do not translate the search query. Use the original query language to search. eg: When user's question is 'Ant Design 有哪些新特性?', the query should be 'Ant Design 有哪些新特性?'.
+def factory(bot_id: str):
+    bot_id = bot_id
 
-    :param query: The user's question.
-    :param bot_id: The bot's unique id.
-    """
-    try:
-        return retrieval.search_knowledge(query, bot_id)
-    except Exception as e:
-        print(f"An error occurred: {e}")
-        return None
+    @tool(parse_docstring=True)
+    def search_knowledge(
+        query: str,
+    ) -> str:
+        """Search for information based on the query. When using this tool, do not translate the search query. Use the original query language to search. E.g.: When the user's question is 'Ant Design 有哪些新特性?', the query should be 'Ant Design 有哪些新特性?'.
+
+        Args:
+            query: The user's question.
+        """
+        try:
+            return retrieval.search_knowledge(query, bot_id)
+        except Exception as e:
+            print(f"An error occurred: {e}")
+            return None
+
+    return search_knowledge