feat: add the search tool tavily to the agent #82

Merged · 1 commit · Apr 10, 2024
2 changes: 1 addition & 1 deletion lui/package.json
@@ -1,6 +1,6 @@
{
"name": "petercat-lui",
"version": "0.0.3",
"version": "0.0.4",
"description": "A react library developed with dumi",
"module": "dist/index.js",
"types": "dist/index.d.ts",
2 changes: 1 addition & 1 deletion lui/src/Assistant/index.md
@@ -6,7 +6,7 @@ atomId: Assistant

```tsx
import React from 'react';
- import { Assistant } from 'lui';
+ import { Assistant } from 'petercat-lui';

export default () => (
<Assistant
2 changes: 1 addition & 1 deletion lui/src/Assistant/index.tsx
@@ -48,7 +48,7 @@ const Assistant = (props: AssistantProps) => {
className="fixed right-0 top-0 h-full flex flex-row z-[999] overflow-hidden text-left text-black bg-gradient-to-r from-f2e9ed via-e9eefb to-f0eeea shadow-[0px_0px_1px_#919eab3d]"
style={{ width: drawerWidth, zIndex: 9999 }}
>
- <Chat {...props} />
+ <Chat {...props} drawerWidth={drawerWidth} />
<div className="absolute top-0 right-0 m-1">
<ActionIcon
icon={<CloseCircleFilled />}
2 changes: 1 addition & 1 deletion lui/src/Chat/index.md
@@ -1,7 +1,7 @@

```jsx
import React from 'react';
- import { Chat } from 'lui';
+ import { Chat } from 'petercat-lui';


export default () => (
36 changes: 23 additions & 13 deletions lui/src/Chat/index.tsx
@@ -6,14 +6,14 @@ import type {
} from '@ant-design/pro-chat';
import { ProChat } from '@ant-design/pro-chat';
import { Markdown } from '@ant-design/pro-editor';
- import StopBtn from 'lui/StopBtn';
- import { theme } from 'lui/Theme';
- import ThoughtChain from 'lui/ThoughtChain';
- import { Role } from 'lui/interface';
- import { BOT_INFO } from 'lui/mock';
- import { streamChat } from 'lui/services/ChatController';
- import { handleStream } from 'lui/utils';
import React, { ReactNode, memo, useRef, useState, type FC } from 'react';
+ import StopBtn from '../StopBtn';
+ import { theme } from '../Theme';
+ import ThoughtChain from '../ThoughtChain';
+ import { Role } from '../interface';
+ import { BOT_INFO } from '../mock';
+ import { streamChat } from '../services/ChatController';
+ import { handleStream } from '../utils';
import Actions from './inputArea/actions';

const { getDesignToken } = theme;
@@ -23,15 +23,19 @@ export interface ChatProps {
assistantMeta?: MetaData;
helloMessage?: string;
host?: string;
+ drawerWidth?: number;
slot?: {
componentID: string;
renderFunc: (data: any) => React.ReactNode;
}[];
}

- const Chat: FC<ChatProps> = memo(({ helloMessage, host }) => {
+ const Chat: FC<ChatProps> = memo(({ helloMessage, host, drawerWidth }) => {
const proChatRef = useRef<ProChatInstance>();
const [chats, setChats] = useState<ChatMessage<Record<string, any>>[]>();
+ const messageMinWidth = drawerWidth
+   ? `calc(${drawerWidth}px - 90px)`
+   : '100%';
return (
<div
className="h-full w-full"
@@ -60,10 +64,16 @@ const Chat: FC<ChatProps> = memo(({ helloMessage, host }) => {
},
contentRender: (props: ChatItemProps, defaultDom: ReactNode) => {
const originData = props.originData || {};
+ if (originData?.role === Role.user) {
+   return defaultDom;
+ }
const message = originData.content;
+ const defaultMessageContent = (
+   <div style={{ minWidth: messageMinWidth }}>{defaultDom}</div>
+ );

if (!message || !message.startsWith('<TOOL>')) {
- return defaultDom;
+ return defaultMessageContent;
}

const [toolStr, answerStr] = message.split('<ANSWER>');
@@ -75,23 +85,23 @@ const Chat: FC<ChatProps> = memo(({ helloMessage, host }) => {

if (!match) {
console.error('No valid JSON found in input');
- return defaultDom;
+ return defaultMessageContent;
}

try {
const config = JSON.parse(match[1]);
const { type, extra } = config;

if (![Role.knowledge, Role.tool].includes(type)) {
- return defaultDom;
+ return defaultMessageContent;
}

const { status, source } = extra;

return (
<div
className="p-2 bg-white rounded-md "
- style={{ minWidth: 'calc(375px - 90px)' }}
+ style={{ minWidth: messageMinWidth }}
>
<div className="mb-1">
<ThoughtChain
@@ -107,7 +117,7 @@ const Chat: FC<ChatProps> = memo(({ helloMessage, host }) => {
);
} catch (error) {
console.error(`JSON parse error: ${error}`);
- return defaultDom;
+ return defaultMessageContent;
}
},
}}
2 changes: 1 addition & 1 deletion lui/src/StopBtn/index.md
@@ -5,7 +5,7 @@ atomId: StopBtn
# StopBtn
``` tsx
import React from 'react';
- import { StopBtn } from 'lui';
+ import { StopBtn } from 'petercat-lui';

export default () => <StopBtn visible={true} />;
```
2 changes: 1 addition & 1 deletion lui/src/ThoughtChain/index.md
@@ -6,7 +6,7 @@ atomId: ThoughtChain

```tsx
import React from 'react';
- import { ThoughtChain } from 'lui';
+ import { ThoughtChain } from 'petercat-lui';

export default () => (
<ThoughtChain
37 changes: 6 additions & 31 deletions lui/src/ThoughtChain/index.tsx
@@ -1,10 +1,8 @@
import {
- ApiOutlined,
CheckCircleOutlined,
CloseCircleOutlined,
DownOutlined,
ExclamationCircleOutlined,
- FileTextOutlined,
LoadingOutlined,
UnorderedListOutlined,
UpOutlined,
@@ -72,35 +70,12 @@ const ThoughtChain: React.FC<ThoughtChainProps> = (params) => {
<DownOutlined className={`${getColorClass(status!)}`} />
</span>
),
- children: (
-   <Collapse
-     ghost
-     size="small"
-     expandIcon={(panelProps) => {
-       const {
-         status: itemStatus,
-         knowledgeName,
-         pluginName,
-       } = (panelProps as IExtraInfo) || {};
-
-       if (itemStatus === Status.loading) {
-         return <LoadingOutlined className="text-blue-600 text-xs" />;
-       } else if (knowledgeName) {
-         return <FileTextOutlined className="text-gray-900 text-xs" />;
-       } else if (pluginName) {
-         return <ApiOutlined className="text-gray-900 text-xs" />;
-       }
-       return <></>;
-     }}
-   >
-     {safeJsonParse(content?.data) ? (
-       <Highlight language="json" theme="light" type="block">
-         {JSON.stringify(safeJsonParse(content?.data), null, 2)}
-       </Highlight>
-     ) : (
-       <>{content?.data}</>
-     )}
-   </Collapse>
+ children: safeJsonParse(content?.data) ? (
+   <Highlight language="json" theme="light" type="block">
+     {JSON.stringify(safeJsonParse(content?.data), null, 2)}
+   </Highlight>
+ ) : (
+   <>{content?.data}</>
+ ),
},
];
Expand Down
2 changes: 1 addition & 1 deletion lui/src/mock/inputArea.mock.ts
@@ -1,4 +1,4 @@
- import { IBot } from 'lui/interface';
+ import { IBot } from '../interface';

export const DEFAULT_HELLO_MESSAGE =
'我是你的私人助理Kate, 我有许多惊人的能力,比如你可以对我说我想创建一个机器人';
2 changes: 1 addition & 1 deletion lui/src/services/ChatController.ts
@@ -1,4 +1,4 @@
- import { IPrompt } from 'lui/interface';
+ import { IPrompt } from '../interface';

/**
* Chat api
2 changes: 1 addition & 1 deletion lui/src/utils/chatTranslator.ts
@@ -1,5 +1,5 @@
import { map } from 'lodash';
- import { Role } from 'lui/interface';
+ import { Role } from '../interface';

export const convertChunkToJson = (rawData: string) => {
const regex = /data:(.*)/;
2 changes: 1 addition & 1 deletion package.json
@@ -30,7 +30,7 @@
"axios": "^1.6.7",
"concurrently": "^8.2.2",
"dayjs": "^1.11.10",
"petercat-lui": "^0.0.3",
"petercat-lui": "^0.0.4",
"eslint": "8.46.0",
"eslint-config-next": "13.4.12",
"framer-motion": "^10.16.15",
5 changes: 3 additions & 2 deletions server/.env.example
@@ -103,5 +103,6 @@ DOCKER_SOCKET_LOCATION=/var/run/docker.sock
GOOGLE_PROJECT_ID=GOOGLE_PROJECT_ID
GOOGLE_PROJECT_NUMBER=GOOGLE_PROJECT_NUMBER

# GitHub Access Token
GITHUB_TOKEN=GITHUB_TOKEN

+ #TAVILY_API_KEY
+ TAVILY_API_KEY=TAVILY_API_KEY
18 changes: 12 additions & 6 deletions server/agent/stream.py
@@ -1,5 +1,6 @@
import datetime
import json
+ import os
import uuid
from langchain.tools import tool
from typing import AsyncIterator
@@ -12,11 +13,15 @@
from langchain.prompts import MessagesPlaceholder
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.prompts import ChatPromptTemplate
+ from langchain.utilities.tavily_search import TavilySearchAPIWrapper
+ from langchain.tools.tavily_search import TavilySearchResults
from langchain_openai import ChatOpenAI
+ from uilts.env import get_env_variable
from tools import issue
from tools import sourcecode
from langchain_core.messages import AIMessage, FunctionMessage, HumanMessage

TAVILY_API_KEY = get_env_variable("TAVILY_API_KEY")

prompt = ChatPromptTemplate.from_messages(
[
@@ -56,10 +61,11 @@ def get_datetime() -> datetime:
TOOLS = ["get_datetime", "create_issue", "get_issues", "search_issues", "search_code"]


- def _create_agent_with_tools(openai_api_key: str ) -> AgentExecutor:
-     openai_api_key=openai_api_key
-     llm = ChatOpenAI(model="gpt-4", temperature=0.2, streaming=True)
-     tools = []
+ def _create_agent_with_tools(open_api_key: str) -> AgentExecutor:
+     llm = ChatOpenAI(model="gpt-4-1106-preview", temperature=0.2, streaming=True, max_tokens=1500, openai_api_key=open_api_key)
+     search = TavilySearchAPIWrapper()
+     tavily_tool = TavilySearchResults(api_wrapper=search)
+     tools = [tavily_tool]

for requested_tool in TOOLS:
if requested_tool not in TOOL_MAPPING:
@@ -104,10 +110,10 @@ def chat_history_transform(messages: list[Message]):
return transformed_messages


- async def agent_chat(input_data: ChatData, openai_api_key) -> AsyncIterator[str]:
+ async def agent_chat(input_data: ChatData, open_api_key: str) -> AsyncIterator[str]:
try:
messages = input_data.messages
- agent_executor = _create_agent_with_tools(openai_api_key)
+ agent_executor = _create_agent_with_tools(open_api_key)
print(chat_history_transform(messages))
async for event in agent_executor.astream_events(
{
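The new search tool can be exercised on its own before it is handed to the agent. A minimal sketch, assuming `TAVILY_API_KEY` is exported as `server/.env.example` now documents (the query string is illustrative only):

```python
# Minimal sketch: run the Tavily search tool outside the agent loop.
# Assumes TAVILY_API_KEY is set in the environment.
from langchain.utilities.tavily_search import TavilySearchAPIWrapper
from langchain.tools.tavily_search import TavilySearchResults

search = TavilySearchAPIWrapper()                      # reads TAVILY_API_KEY
tavily_tool = TavilySearchResults(api_wrapper=search)

# The tool takes a plain query string and returns a list of
# {"url": ..., "content": ...} dicts that the agent can cite.
print(tavily_tool.run("what's new in the latest LangChain release?"))
```

Because `tools = [tavily_tool]` seeds the tool list before the `TOOL_MAPPING` loop runs, the agent gets web search on top of the existing GitHub issue and source-code tools.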
6 changes: 3 additions & 3 deletions server/main.py
@@ -1,14 +1,14 @@
- import os
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from data_class import DalleData, ChatData
from openai_api import dalle
from langchain_api import chat
from agent import stream
+ from uilts.env import get_env_variable
import uvicorn

open_api_key = os.getenv("OPENAI_API_KEY")
open_api_key = get_env_variable("OPENAI_API_KEY")

app = FastAPI(
title="Bo-meta Server",
@@ -46,4 +46,4 @@ def run_agent_chat(input_data: ChatData):
return StreamingResponse(result, media_type="text/event-stream")

if __name__ == "__main__":
- uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8080")))
+ uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8080")))
1 change: 1 addition & 0 deletions server/requirements.txt
@@ -10,3 +10,4 @@ langchain-openai
PyGithub
python-multipart
httpx[socks]
+ load_dotenv
6 changes: 2 additions & 4 deletions server/tools/issue.py
@@ -1,14 +1,12 @@
import json
- import os
from typing import Optional
from github import Github
from langchain.tools import tool

- GITHUB_TOKEN = os.getenv('GITHUB_TOKEN')
+ from uilts.env import get_env_variable

DEFAULT_REPO_NAME = "ant-design/ant-design"

- g = Github(GITHUB_TOKEN)
+ g = Github()

@tool
def create_issue(repo_name, title, body):
6 changes: 2 additions & 4 deletions server/tools/sourcecode.py
@@ -1,15 +1,13 @@
- import os
from typing import List, Optional
from github import Github
from github.ContentFile import ContentFile
from langchain.tools import tool
+ from uilts.env import get_env_variable


- GITHUB_TOKEN = os.getenv('GITHUB_TOKEN')

DEFAULT_REPO_NAME = "ant-design/ant-design"

- g = Github(GITHUB_TOKEN)
+ g = Github()

@tool
def search_code(
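Both `tools/issue.py` and `tools/sourcecode.py` now build the client with `Github()` and no token, which falls back to unauthenticated access and a much lower API rate limit. If authenticated access is still wanted, a possible follow-up (not part of this diff) would be to route the token through the env helper both files already import:

```python
# Hypothetical follow-up, not in this PR: keep authenticated GitHub access
# by reading the token through the new uilts/env.py helper.
from github import Github
from uilts.env import get_env_variable

github_token = get_env_variable("GITHUB_TOKEN")
g = Github(github_token) if github_token else Github()
```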
17 changes: 17 additions & 0 deletions server/uilts/env.py
@@ -0,0 +1,17 @@
from dotenv import load_dotenv
import os

# Define a method to load an environmental variable and return its value
def get_env_variable(key: str, default=None):
"""
Retrieve the specified environment variable. Return the specified default value if the variable does not exist.

:param key: The name of the environment variable to retrieve.
:param default: The default value to return if the environment variable does not exist.
:return: The value of the environment variable, or the default value if it does not exist.
"""
# Load the .env file
load_dotenv(verbose=True, override=True)

# Get the environment variable, returning the default value if it does not exist
return os.getenv(key, default)
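A quick usage sketch for the new helper (the key names below just mirror ones used elsewhere in this PR):

```python
# Illustrative only: how callers use get_env_variable.
from uilts.env import get_env_variable

tavily_key = get_env_variable("TAVILY_API_KEY")   # None if unset and no default given
port = get_env_variable("PORT", "8080")           # falls back to the default
```

Because `load_dotenv(verbose=True, override=True)` runs inside the function, values in `.env` take precedence over variables already exported in the shell on every lookup.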