From 785d3748e10c6c2fa5b21129aa8e35905876a171 Mon Sep 17 00:00:00 2001 From: Dogtiti <499960698@qq.com> Date: Sat, 6 Jul 2024 13:05:09 +0800 Subject: [PATCH 01/27] feat: support baidu model --- .gitignore | 2 +- app/api/auth.ts | 3 + app/api/baidu/[...path]/route.ts | 176 +++++++++++++++++++++ app/client/api.ts | 5 + app/client/platforms/baidu.ts | 252 +++++++++++++++++++++++++++++++ app/components/exporter.tsx | 2 + app/components/home.tsx | 2 + app/components/settings.tsx | 62 ++++++++ app/config/server.ts | 17 ++- app/constant.ts | 32 ++++ app/locales/cn.ts | 16 ++ app/store/access.ts | 10 ++ app/store/chat.ts | 4 + app/utils/model.ts | 7 +- 14 files changed, 586 insertions(+), 4 deletions(-) create mode 100644 app/api/baidu/[...path]/route.ts create mode 100644 app/client/platforms/baidu.ts diff --git a/.gitignore b/.gitignore index b00b0e325a4..a24c6e047d5 100644 --- a/.gitignore +++ b/.gitignore @@ -43,4 +43,4 @@ dev .env *.key -*.key.pub \ No newline at end of file +*.key.pub diff --git a/app/api/auth.ts b/app/api/auth.ts index 2b4702aedc3..cce8847f4dd 100644 --- a/app/api/auth.ts +++ b/app/api/auth.ts @@ -73,6 +73,9 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) { case ModelProvider.Claude: systemApiKey = serverConfig.anthropicApiKey; break; + case ModelProvider.Ernie: + systemApiKey = serverConfig.baiduApiKey; + break; case ModelProvider.GPT: default: if (req.nextUrl.pathname.includes("azure/deployments")) { diff --git a/app/api/baidu/[...path]/route.ts b/app/api/baidu/[...path]/route.ts new file mode 100644 index 00000000000..27676d29df8 --- /dev/null +++ b/app/api/baidu/[...path]/route.ts @@ -0,0 +1,176 @@ +import { getServerSideConfig } from "@/app/config/server"; +import { + BAIDU_BASE_URL, + ApiPath, + ModelProvider, + BAIDU_OATUH_URL, + ServiceProvider, +} from "@/app/constant"; +import { prettyObject } from "@/app/utils/format"; +import { NextRequest, NextResponse } from "next/server"; +import { auth } from "@/app/api/auth"; +import { isModelAvailableInServer } from "@/app/utils/model"; + +const serverConfig = getServerSideConfig(); + +async function handle( + req: NextRequest, + { params }: { params: { path: string[] } }, +) { + console.log("[Baidu Route] params ", params); + + if (req.method === "OPTIONS") { + return NextResponse.json({ body: "OK" }, { status: 200 }); + } + + const authResult = auth(req, ModelProvider.Ernie); + if (authResult.error) { + return NextResponse.json(authResult, { + status: 401, + }); + } + + try { + const response = await request(req); + return response; + } catch (e) { + console.error("[Baidu] ", e); + return NextResponse.json(prettyObject(e)); + } +} + +export const GET = handle; +export const POST = handle; + +export const runtime = "edge"; +export const preferredRegion = [ + "arn1", + "bom1", + "cdg1", + "cle1", + "cpt1", + "dub1", + "fra1", + "gru1", + "hnd1", + "iad1", + "icn1", + "kix1", + "lhr1", + "pdx1", + "sfo1", + "sin1", + "syd1", +]; + +async function request(req: NextRequest) { + const controller = new AbortController(); + + let path = `${req.nextUrl.pathname}`.replaceAll(ApiPath.Baidu, ""); + + let baseUrl = serverConfig.baiduUrl || BAIDU_BASE_URL; + + if (!baseUrl.startsWith("http")) { + baseUrl = `https://${baseUrl}`; + } + + if (baseUrl.endsWith("/")) { + baseUrl = baseUrl.slice(0, -1); + } + + console.log("[Proxy] ", path); + console.log("[Base Url]", baseUrl); + + const timeoutId = setTimeout( + () => { + controller.abort(); + }, + 10 * 60 * 1000, + ); + + const { access_token } = await getAccessToken(); + const fetchUrl = `${baseUrl}${path}?access_token=${access_token}`; + + const fetchOptions: RequestInit = { + headers: { + "Content-Type": "application/json", + }, + method: req.method, + body: req.body, + redirect: "manual", + // @ts-ignore + duplex: "half", + signal: controller.signal, + }; + + // #1815 try to refuse some request to some models + if (serverConfig.customModels && req.body) { + try { + const clonedBody = await req.text(); + fetchOptions.body = clonedBody; + + const jsonBody = JSON.parse(clonedBody) as { model?: string }; + + // not undefined and is false + if ( + isModelAvailableInServer( + serverConfig.customModels, + jsonBody?.model as string, + ServiceProvider.Baidu as string, + ) + ) { + return NextResponse.json( + { + error: true, + message: `you are not allowed to use ${jsonBody?.model} model`, + }, + { + status: 403, + }, + ); + } + } catch (e) { + console.error(`[Baidu] filter`, e); + } + } + console.log("[Baidu request]", fetchOptions.headers, req.method); + try { + const res = await fetch(fetchUrl, fetchOptions); + + console.log("[Baidu response]", res.status, " ", res.headers, res.url); + // to prevent browser prompt for credentials + const newHeaders = new Headers(res.headers); + newHeaders.delete("www-authenticate"); + // to disable nginx buffering + newHeaders.set("X-Accel-Buffering", "no"); + + return new Response(res.body, { + status: res.status, + statusText: res.statusText, + headers: newHeaders, + }); + } finally { + clearTimeout(timeoutId); + } +} + +/** + * 使用 AK,SK 生成鉴权签名(Access Token) + * @return 鉴权签名信息 + */ +async function getAccessToken(): Promise<{ + access_token: string; + expires_in: number; + error?: number; +}> { + const AK = serverConfig.baiduApiKey; + const SK = serverConfig.baiduSecretKey; + const res = await fetch( + `${BAIDU_OATUH_URL}?grant_type=client_credentials&client_id=${AK}&client_secret=${SK}`, + { + method: "POST", + }, + ); + const resJson = await res.json(); + return resJson; +} diff --git a/app/client/api.ts b/app/client/api.ts index 41ccbd8e1c0..74e0ef9a996 100644 --- a/app/client/api.ts +++ b/app/client/api.ts @@ -9,6 +9,8 @@ import { ChatMessage, ModelType, useAccessStore, useChatStore } from "../store"; import { ChatGPTApi } from "./platforms/openai"; import { GeminiProApi } from "./platforms/google"; import { ClaudeApi } from "./platforms/anthropic"; +import { ErnieApi } from "./platforms/baidu"; + export const ROLES = ["system", "user", "assistant"] as const; export type MessageRole = (typeof ROLES)[number]; @@ -104,6 +106,9 @@ export class ClientApi { case ModelProvider.Claude: this.llm = new ClaudeApi(); break; + case ModelProvider.Ernie: + this.llm = new ErnieApi(); + break; default: this.llm = new ChatGPTApi(); } diff --git a/app/client/platforms/baidu.ts b/app/client/platforms/baidu.ts new file mode 100644 index 00000000000..e2f6f12dd22 --- /dev/null +++ b/app/client/platforms/baidu.ts @@ -0,0 +1,252 @@ +"use client"; +import { + ApiPath, + Baidu, + DEFAULT_API_HOST, + REQUEST_TIMEOUT_MS, +} from "@/app/constant"; +import { useAccessStore, useAppConfig, useChatStore } from "@/app/store"; + +import { + ChatOptions, + getHeaders, + LLMApi, + LLMModel, + MultimodalContent, +} from "../api"; +import Locale from "../../locales"; +import { + EventStreamContentType, + fetchEventSource, +} from "@fortaine/fetch-event-source"; +import { prettyObject } from "@/app/utils/format"; +import { getClientConfig } from "@/app/config/client"; +import { getMessageTextContent, isVisionModel } from "@/app/utils"; + +export interface OpenAIListModelResponse { + object: string; + data: Array<{ + id: string; + object: string; + root: string; + }>; +} + +interface RequestPayload { + messages: { + role: "system" | "user" | "assistant"; + content: string | MultimodalContent[]; + }[]; + stream?: boolean; + model: string; + temperature: number; + presence_penalty: number; + frequency_penalty: number; + top_p: number; + max_tokens?: number; +} + +export class ErnieApi implements LLMApi { + path(path: string): string { + const accessStore = useAccessStore.getState(); + + let baseUrl = ""; + + if (accessStore.useCustomConfig) { + baseUrl = accessStore.baiduUrl; + } + + if (baseUrl.length === 0) { + const isApp = !!getClientConfig()?.isApp; + baseUrl = isApp ? DEFAULT_API_HOST + "/api/proxy/baidu" : ApiPath.Baidu; + } + + if (baseUrl.endsWith("/")) { + baseUrl = baseUrl.slice(0, baseUrl.length - 1); + } + if (!baseUrl.startsWith("http") && !baseUrl.startsWith(ApiPath.Baidu)) { + baseUrl = "https://" + baseUrl; + } + + console.log("[Proxy Endpoint] ", baseUrl, path); + + return [baseUrl, path].join("/"); + } + + extractMessage(res: any) { + return res.choices?.at(0)?.message?.content ?? ""; + } + + async chat(options: ChatOptions) { + const visionModel = isVisionModel(options.config.model); + const messages = options.messages.map((v) => ({ + role: v.role, + content: visionModel ? v.content : getMessageTextContent(v), + })); + + const modelConfig = { + ...useAppConfig.getState().modelConfig, + ...useChatStore.getState().currentSession().mask.modelConfig, + ...{ + model: options.config.model, + }, + }; + + const requestPayload: RequestPayload = { + messages, + stream: options.config.stream, + model: modelConfig.model, + temperature: modelConfig.temperature, + presence_penalty: modelConfig.presence_penalty, + frequency_penalty: modelConfig.frequency_penalty, + top_p: modelConfig.top_p, + }; + + console.log("[Request] Baidu payload: ", requestPayload); + + const shouldStream = !!options.config.stream; + const controller = new AbortController(); + options.onController?.(controller); + + try { + const chatPath = this.path(Baidu.ChatPath(modelConfig.model)); + const chatPayload = { + method: "POST", + body: JSON.stringify(requestPayload), + signal: controller.signal, + headers: getHeaders(), + }; + + // make a fetch request + const requestTimeoutId = setTimeout( + () => controller.abort(), + REQUEST_TIMEOUT_MS, + ); + + if (shouldStream) { + let responseText = ""; + let remainText = ""; + let finished = false; + + // animate response to make it looks smooth + function animateResponseText() { + if (finished || controller.signal.aborted) { + responseText += remainText; + console.log("[Response Animation] finished"); + if (responseText?.length === 0) { + options.onError?.(new Error("empty response from server")); + } + return; + } + + if (remainText.length > 0) { + const fetchCount = Math.max(1, Math.round(remainText.length / 60)); + const fetchText = remainText.slice(0, fetchCount); + responseText += fetchText; + remainText = remainText.slice(fetchCount); + options.onUpdate?.(responseText, fetchText); + } + + requestAnimationFrame(animateResponseText); + } + + // start animaion + animateResponseText(); + + const finish = () => { + if (!finished) { + finished = true; + options.onFinish(responseText + remainText); + } + }; + + controller.signal.onabort = finish; + + fetchEventSource(chatPath, { + ...chatPayload, + async onopen(res) { + clearTimeout(requestTimeoutId); + const contentType = res.headers.get("content-type"); + console.log("[Baidu] request response content type: ", contentType); + + if (contentType?.startsWith("text/plain")) { + responseText = await res.clone().text(); + return finish(); + } + + if ( + !res.ok || + !res.headers + .get("content-type") + ?.startsWith(EventStreamContentType) || + res.status !== 200 + ) { + const responseTexts = [responseText]; + let extraInfo = await res.clone().text(); + try { + const resJson = await res.clone().json(); + extraInfo = prettyObject(resJson); + } catch {} + + if (res.status === 401) { + responseTexts.push(Locale.Error.Unauthorized); + } + + if (extraInfo) { + responseTexts.push(extraInfo); + } + + responseText = responseTexts.join("\n\n"); + + return finish(); + } + }, + onmessage(msg) { + if (msg.data === "[DONE]" || finished) { + return finish(); + } + const text = msg.data; + try { + const json = JSON.parse(text); + const delta = json?.result; + if (delta) { + remainText += delta; + } + } catch (e) { + console.error("[Request] parse error", text, msg); + } + }, + onclose() { + finish(); + }, + onerror(e) { + options.onError?.(e); + throw e; + }, + openWhenHidden: true, + }); + } else { + const res = await fetch(chatPath, chatPayload); + clearTimeout(requestTimeoutId); + + const resJson = await res.json(); + const message = this.extractMessage(resJson); + options.onFinish(message); + } + } catch (e) { + console.log("[Request] failed to make a chat request", e); + options.onError?.(e as Error); + } + } + async usage() { + return { + used: 0, + total: 0, + }; + } + + async models(): Promise { + return []; + } +} +export { Baidu }; diff --git a/app/components/exporter.tsx b/app/components/exporter.tsx index 7281fc2f12d..ec0060c7245 100644 --- a/app/components/exporter.tsx +++ b/app/components/exporter.tsx @@ -321,6 +321,8 @@ export function PreviewActions(props: { api = new ClientApi(ModelProvider.GeminiPro); } else if (config.modelConfig.providerName == ServiceProvider.Anthropic) { api = new ClientApi(ModelProvider.Claude); + } else if (config.modelConfig.providerName == ServiceProvider.Baidu) { + api = new ClientApi(ModelProvider.Ernie); } else { api = new ClientApi(ModelProvider.GPT); } diff --git a/app/components/home.tsx b/app/components/home.tsx index addb5e80373..00af1f4ba77 100644 --- a/app/components/home.tsx +++ b/app/components/home.tsx @@ -175,6 +175,8 @@ export function useLoadData() { api = new ClientApi(ModelProvider.GeminiPro); } else if (config.modelConfig.providerName == ServiceProvider.Anthropic) { api = new ClientApi(ModelProvider.Claude); + } else if (config.modelConfig.providerName == ServiceProvider.Baidu) { + api = new ClientApi(ModelProvider.Ernie); } else { api = new ClientApi(ModelProvider.GPT); } diff --git a/app/components/settings.tsx b/app/components/settings.tsx index db08b48a9ff..7db09940d4b 100644 --- a/app/components/settings.tsx +++ b/app/components/settings.tsx @@ -53,6 +53,7 @@ import Link from "next/link"; import { Anthropic, Azure, + Baidu, Google, OPENAI_BASE_URL, Path, @@ -1187,6 +1188,67 @@ export function Settings() { )} + {accessStore.provider === ServiceProvider.Baidu && ( + <> + + + accessStore.update( + (access) => + (access.baiduUrl = e.currentTarget.value), + ) + } + > + + + { + accessStore.update( + (access) => + (access.baiduApiKey = e.currentTarget.value), + ); + }} + /> + + + { + accessStore.update( + (access) => + (access.baiduSecretKey = e.currentTarget.value), + ); + }} + /> + + + )} )} diff --git a/app/config/server.ts b/app/config/server.ts index b7c85ce6a5f..2d09c547961 100644 --- a/app/config/server.ts +++ b/app/config/server.ts @@ -35,6 +35,16 @@ declare global { // google tag manager GTM_ID?: string; + // anthropic only + ANTHROPIC_URL?: string; + ANTHROPIC_API_KEY?: string; + ANTHROPIC_API_VERSION?: string; + + // baidu only + BAIDU_URL?: string; + BAIDU_API_KEY?: string; + BAIDU_SECRET_KEY?: string; + // custom template for preprocessing user input DEFAULT_INPUT_TEMPLATE?: string; } @@ -92,7 +102,7 @@ export const getServerSideConfig = () => { const isAzure = !!process.env.AZURE_URL; const isGoogle = !!process.env.GOOGLE_API_KEY; const isAnthropic = !!process.env.ANTHROPIC_API_KEY; - + const isBaidu = !!process.env.BAIDU_API_KEY; // const apiKeyEnvVar = process.env.OPENAI_API_KEY ?? ""; // const apiKeys = apiKeyEnvVar.split(",").map((v) => v.trim()); // const randomIndex = Math.floor(Math.random() * apiKeys.length); @@ -124,6 +134,11 @@ export const getServerSideConfig = () => { anthropicApiVersion: process.env.ANTHROPIC_API_VERSION, anthropicUrl: process.env.ANTHROPIC_URL, + isBaidu, + baiduUrl: process.env.BAIDU_URL, + baiduApiKey: getApiKey(process.env.BAIDU_API_KEY), + baiduSecretKey: process.env.BAIDU_SECRET_KEY, + gtmId: process.env.GTM_ID, needCode: ACCESS_CODES.size > 0, diff --git a/app/constant.ts b/app/constant.ts index d44b5b8173b..6ffc0e0b3f8 100644 --- a/app/constant.ts +++ b/app/constant.ts @@ -14,6 +14,10 @@ export const ANTHROPIC_BASE_URL = "https://api.anthropic.com"; export const GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/"; +export const BAIDU_BASE_URL = "https://aip.baidubce.com"; + +export const BAIDU_OATUH_URL = `${BAIDU_BASE_URL}/oauth/2.0/token`; + export enum Path { Home = "/", Chat = "/chat", @@ -28,6 +32,7 @@ export enum ApiPath { Azure = "/api/azure", OpenAI = "/api/openai", Anthropic = "/api/anthropic", + Baidu = "/api/baidu", } export enum SlotID { @@ -71,12 +76,14 @@ export enum ServiceProvider { Azure = "Azure", Google = "Google", Anthropic = "Anthropic", + Baidu = "Baidu", } export enum ModelProvider { GPT = "GPT", GeminiPro = "GeminiPro", Claude = "Claude", + Ernie = "Ernie", } export const Anthropic = { @@ -104,6 +111,12 @@ export const Google = { ChatPath: (modelName: string) => `v1beta/models/${modelName}:generateContent`, }; +export const Baidu = { + ExampleEndpoint: "https://aip.baidubce.com", + ChatPath: (modelName: string) => + `/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/${modelName}`, +}; + export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang // export const DEFAULT_SYSTEM_TEMPLATE = ` // You are ChatGPT, a large language model trained by {{ServiceProvider}}. @@ -173,6 +186,16 @@ const anthropicModels = [ "claude-3-5-sonnet-20240620", ]; +const baiduModels = [ + "ernie-4.0-turbo-8k", + "completions_pro=ernie-4.0-8k", + "ernie-4.0-8k-preview", + "completions_adv_pro=ernie-4.0-8k-preview-0518", + "ernie-4.0-8k-latest", + "completions=ernie-3.5-8k", + "ernie-3.5-8k-0205", +]; + export const DEFAULT_MODELS = [ ...openaiModels.map((name) => ({ name, @@ -210,6 +233,15 @@ export const DEFAULT_MODELS = [ providerType: "anthropic", }, })), + ...baiduModels.map((name) => ({ + name, + available: true, + provider: { + id: "baidu", + providerName: "Baidu", + providerType: "baidu", + }, + })), ] as const; export const CHAT_PAGE_SIZE = 15; diff --git a/app/locales/cn.ts b/app/locales/cn.ts index 2ff94e32d43..a872ee75ada 100644 --- a/app/locales/cn.ts +++ b/app/locales/cn.ts @@ -347,6 +347,22 @@ const cn = { SubTitle: "选择一个特定的 API 版本", }, }, + Baidu: { + ApiKey: { + Title: "接口密钥", + SubTitle: "使用自定义 Baidu API Key 绕过密码访问限制", + Placeholder: "Baidu API Key", + }, + SecretKey: { + Title: "接口密钥", + SubTitle: "使用自定义 Baidu Secret Key 绕过密码访问限制", + Placeholder: "Baidu Secret Key", + }, + Endpoint: { + Title: "接口地址", + SubTitle: "样例:", + }, + }, CustomModel: { Title: "自定义模型名", SubTitle: "增加自定义模型可选项,使用英文逗号隔开", diff --git a/app/store/access.ts b/app/store/access.ts index 03780779e72..7e6d01b3461 100644 --- a/app/store/access.ts +++ b/app/store/access.ts @@ -47,6 +47,11 @@ const DEFAULT_ACCESS_STATE = { anthropicApiVersion: "2023-06-01", anthropicUrl: "", + // baidu + baiduUrl: "", + baiduApiKey: "", + baiduSecretKey: "", + // server config needCode: true, hideUserApiKey: false, @@ -83,6 +88,10 @@ export const useAccessStore = createPersistStore( return ensure(get(), ["anthropicApiKey"]); }, + isValidBaidu() { + return ensure(get(), ["baiduApiKey", "baiduSecretKey"]); + }, + isAuthorized() { this.fetch(); @@ -92,6 +101,7 @@ export const useAccessStore = createPersistStore( this.isValidAzure() || this.isValidGoogle() || this.isValidAnthropic() || + this.isValidBaidu() || !this.enabledAccessControl() || (this.enabledAccessControl() && ensure(get(), ["accessCode"])) ); diff --git a/app/store/chat.ts b/app/store/chat.ts index 44d41830a63..45ab479d9d0 100644 --- a/app/store/chat.ts +++ b/app/store/chat.ts @@ -368,6 +368,8 @@ export const useChatStore = createPersistStore( api = new ClientApi(ModelProvider.GeminiPro); } else if (modelConfig.providerName == ServiceProvider.Anthropic) { api = new ClientApi(ModelProvider.Claude); + } else if (modelConfig.providerName == ServiceProvider.Baidu) { + api = new ClientApi(ModelProvider.Ernie); } else { api = new ClientApi(ModelProvider.GPT); } @@ -552,6 +554,8 @@ export const useChatStore = createPersistStore( api = new ClientApi(ModelProvider.GeminiPro); } else if (modelConfig.providerName == ServiceProvider.Anthropic) { api = new ClientApi(ModelProvider.Claude); + } else if (modelConfig.providerName == ServiceProvider.Baidu) { + api = new ClientApi(ModelProvider.Ernie); } else { api = new ClientApi(ModelProvider.GPT); } diff --git a/app/utils/model.ts b/app/utils/model.ts index 249987726ad..6a02ed7eb4b 100644 --- a/app/utils/model.ts +++ b/app/utils/model.ts @@ -24,10 +24,13 @@ export function collectModelTable( // default models models.forEach((m) => { + // supoort name=displayName eg:completions_pro=ernie-4.0-8k + const [name, displayName] = m.name?.split("="); // using @ as fullName - modelTable[`${m.name}@${m?.provider?.id}`] = { + modelTable[`${name}@${m?.provider?.id}`] = { ...m, - displayName: m.name, // 'provider' is copied over if it exists + name, + displayName: displayName || name, // 'provider' is copied over if it exists }; }); From 9b3b4494ba6ff6a517ca17376d2550b1aa651c00 Mon Sep 17 00:00:00 2001 From: Dogtiti <499960698@qq.com> Date: Sat, 6 Jul 2024 14:59:37 +0800 Subject: [PATCH 02/27] wip: doubao --- app/api/auth.ts | 3 + app/api/bytedance/[...path]/route.ts | 160 +++++++++++++++++ app/client/api.ts | 5 + app/client/platforms/bytedance.ts | 260 +++++++++++++++++++++++++++ app/components/exporter.tsx | 2 + app/components/home.tsx | 2 + app/config/server.ts | 9 + app/constant.ts | 21 +++ app/store/access.ts | 9 + app/store/chat.ts | 4 + 10 files changed, 475 insertions(+) create mode 100644 app/api/bytedance/[...path]/route.ts create mode 100644 app/client/platforms/bytedance.ts diff --git a/app/api/auth.ts b/app/api/auth.ts index 2b4702aedc3..9c334f2fed9 100644 --- a/app/api/auth.ts +++ b/app/api/auth.ts @@ -73,6 +73,9 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) { case ModelProvider.Claude: systemApiKey = serverConfig.anthropicApiKey; break; + case ModelProvider.Doubao: + systemApiKey = serverConfig.bytedanceApiKey; + break; case ModelProvider.GPT: default: if (req.nextUrl.pathname.includes("azure/deployments")) { diff --git a/app/api/bytedance/[...path]/route.ts b/app/api/bytedance/[...path]/route.ts new file mode 100644 index 00000000000..bffb60f6c3a --- /dev/null +++ b/app/api/bytedance/[...path]/route.ts @@ -0,0 +1,160 @@ +import { getServerSideConfig } from "@/app/config/server"; +import { + BYTEDANCE_BASE_URL, + ApiPath, + ModelProvider, + ServiceProvider, +} from "@/app/constant"; +import { prettyObject } from "@/app/utils/format"; +import { NextRequest, NextResponse } from "next/server"; +import { auth } from "@/app/api/auth"; +import { isModelAvailableInServer } from "@/app/utils/model"; + +const serverConfig = getServerSideConfig(); + +async function handle( + req: NextRequest, + { params }: { params: { path: string[] } }, +) { + console.log("[ByteDance Route] params ", params); + + if (req.method === "OPTIONS") { + return NextResponse.json({ body: "OK" }, { status: 200 }); + } + + const authResult = auth(req, ModelProvider.Doubao); + if (authResult.error) { + return NextResponse.json(authResult, { + status: 401, + }); + } + + try { + const response = await request(req); + return response; + } catch (e) { + console.error("[ByteDance] ", e); + return NextResponse.json(prettyObject(e)); + } +} + +export const GET = handle; +export const POST = handle; + +export const runtime = "edge"; +export const preferredRegion = [ + "arn1", + "bom1", + "cdg1", + "cle1", + "cpt1", + "dub1", + "fra1", + "gru1", + "hnd1", + "iad1", + "icn1", + "kix1", + "lhr1", + "pdx1", + "sfo1", + "sin1", + "syd1", +]; + +async function request(req: NextRequest) { + const controller = new AbortController(); + + let path = `${req.nextUrl.pathname}`.replaceAll(ApiPath.ByteDance, ""); + + let baseUrl = serverConfig.bytedanceUrl || BYTEDANCE_BASE_URL; + + if (!baseUrl.startsWith("http")) { + baseUrl = `https://${baseUrl}`; + } + + if (baseUrl.endsWith("/")) { + baseUrl = baseUrl.slice(0, -1); + } + + console.log("[Proxy] ", path); + console.log("[Base Url]", baseUrl); + + const timeoutId = setTimeout( + () => { + controller.abort(); + }, + 10 * 60 * 1000, + ); + + const fetchUrl = `${baseUrl}${path}`; + + const fetchOptions: RequestInit = { + headers: { + "Content-Type": "application/json", + Authorization: req.headers.get("Authorization") ?? "", + }, + method: req.method, + body: req.body, + redirect: "manual", + // @ts-ignore + duplex: "half", + signal: controller.signal, + }; + + // #1815 try to refuse some request to some models + if (serverConfig.customModels && req.body) { + try { + const clonedBody = await req.text(); + fetchOptions.body = clonedBody; + + const jsonBody = JSON.parse(clonedBody) as { model?: string }; + + // not undefined and is false + if ( + isModelAvailableInServer( + serverConfig.customModels, + jsonBody?.model as string, + ServiceProvider.ByteDance as string, + ) + ) { + return NextResponse.json( + { + error: true, + message: `you are not allowed to use ${jsonBody?.model} model`, + }, + { + status: 403, + }, + ); + } + } catch (e) { + console.error(`[ByteDance] filter`, e); + } + } + console.log("[ByteDance request]", fetchOptions.headers, req.method); + try { + const res = await fetch(fetchUrl, fetchOptions); + + console.log( + "[ByteDance response]", + res.status, + " ", + res.headers, + res.url, + ); + // to prevent browser prompt for credentials + const newHeaders = new Headers(res.headers); + newHeaders.delete("www-authenticate"); + // to disable nginx buffering + newHeaders.set("X-Accel-Buffering", "no"); + + return new Response(res.body, { + status: res.status, + statusText: res.statusText, + headers: newHeaders, + }); + } finally { + clearTimeout(timeoutId); + } +} diff --git a/app/client/api.ts b/app/client/api.ts index 41ccbd8e1c0..ee43fc7cc12 100644 --- a/app/client/api.ts +++ b/app/client/api.ts @@ -9,6 +9,8 @@ import { ChatMessage, ModelType, useAccessStore, useChatStore } from "../store"; import { ChatGPTApi } from "./platforms/openai"; import { GeminiProApi } from "./platforms/google"; import { ClaudeApi } from "./platforms/anthropic"; +import { DoubaoApi } from "./platforms/bytedance"; + export const ROLES = ["system", "user", "assistant"] as const; export type MessageRole = (typeof ROLES)[number]; @@ -104,6 +106,9 @@ export class ClientApi { case ModelProvider.Claude: this.llm = new ClaudeApi(); break; + case ModelProvider.Doubao: + this.llm = new DoubaoApi(); + break; default: this.llm = new ChatGPTApi(); } diff --git a/app/client/platforms/bytedance.ts b/app/client/platforms/bytedance.ts new file mode 100644 index 00000000000..92c1fd55826 --- /dev/null +++ b/app/client/platforms/bytedance.ts @@ -0,0 +1,260 @@ +"use client"; +import { + ApiPath, + ByteDance, + DEFAULT_API_HOST, + REQUEST_TIMEOUT_MS, +} from "@/app/constant"; +import { useAccessStore, useAppConfig, useChatStore } from "@/app/store"; + +import { + ChatOptions, + getHeaders, + LLMApi, + LLMModel, + MultimodalContent, +} from "../api"; +import Locale from "../../locales"; +import { + EventStreamContentType, + fetchEventSource, +} from "@fortaine/fetch-event-source"; +import { prettyObject } from "@/app/utils/format"; +import { getClientConfig } from "@/app/config/client"; +import { getMessageTextContent, isVisionModel } from "@/app/utils"; + +export interface OpenAIListModelResponse { + object: string; + data: Array<{ + id: string; + object: string; + root: string; + }>; +} + +interface RequestPayload { + messages: { + role: "system" | "user" | "assistant"; + content: string | MultimodalContent[]; + }[]; + stream?: boolean; + model: string; + temperature: number; + presence_penalty: number; + frequency_penalty: number; + top_p: number; + max_tokens?: number; +} + +export class DoubaoApi implements LLMApi { + path(path: string): string { + const accessStore = useAccessStore.getState(); + + let baseUrl = ""; + + if (accessStore.useCustomConfig) { + baseUrl = accessStore.bytedanceUrl; + } + + if (baseUrl.length === 0) { + const isApp = !!getClientConfig()?.isApp; + baseUrl = isApp + ? DEFAULT_API_HOST + "/api/proxy/bytedance" + : ApiPath.ByteDance; + } + + if (baseUrl.endsWith("/")) { + baseUrl = baseUrl.slice(0, baseUrl.length - 1); + } + if (!baseUrl.startsWith("http") && !baseUrl.startsWith(ApiPath.ByteDance)) { + baseUrl = "https://" + baseUrl; + } + + console.log("[Proxy Endpoint] ", baseUrl, path); + + return [baseUrl, path].join("/"); + } + + extractMessage(res: any) { + return res.choices?.at(0)?.message?.content ?? ""; + } + + async chat(options: ChatOptions) { + const visionModel = isVisionModel(options.config.model); + const messages = options.messages.map((v) => ({ + role: v.role, + content: visionModel ? v.content : getMessageTextContent(v), + })); + + const modelConfig = { + ...useAppConfig.getState().modelConfig, + ...useChatStore.getState().currentSession().mask.modelConfig, + ...{ + model: options.config.model, + }, + }; + + const requestPayload: RequestPayload = { + messages, + stream: options.config.stream, + model: modelConfig.model, + temperature: modelConfig.temperature, + presence_penalty: modelConfig.presence_penalty, + frequency_penalty: modelConfig.frequency_penalty, + top_p: modelConfig.top_p, + }; + + console.log("[Request] ByteDance payload: ", requestPayload); + + const shouldStream = !!options.config.stream; + const controller = new AbortController(); + options.onController?.(controller); + + try { + const chatPath = this.path(ByteDance.ChatPath); + const chatPayload = { + method: "POST", + body: JSON.stringify(requestPayload), + signal: controller.signal, + headers: getHeaders(), + }; + + // make a fetch request + const requestTimeoutId = setTimeout( + () => controller.abort(), + REQUEST_TIMEOUT_MS, + ); + + if (shouldStream) { + let responseText = ""; + let remainText = ""; + let finished = false; + + // animate response to make it looks smooth + function animateResponseText() { + if (finished || controller.signal.aborted) { + responseText += remainText; + console.log("[Response Animation] finished"); + if (responseText?.length === 0) { + options.onError?.(new Error("empty response from server")); + } + return; + } + + if (remainText.length > 0) { + const fetchCount = Math.max(1, Math.round(remainText.length / 60)); + const fetchText = remainText.slice(0, fetchCount); + responseText += fetchText; + remainText = remainText.slice(fetchCount); + options.onUpdate?.(responseText, fetchText); + } + + requestAnimationFrame(animateResponseText); + } + + // start animaion + animateResponseText(); + + const finish = () => { + if (!finished) { + finished = true; + options.onFinish(responseText + remainText); + } + }; + + controller.signal.onabort = finish; + + fetchEventSource(chatPath, { + ...chatPayload, + async onopen(res) { + clearTimeout(requestTimeoutId); + const contentType = res.headers.get("content-type"); + console.log( + "[ByteDance] request response content type: ", + contentType, + ); + + if (contentType?.startsWith("text/plain")) { + responseText = await res.clone().text(); + return finish(); + } + + if ( + !res.ok || + !res.headers + .get("content-type") + ?.startsWith(EventStreamContentType) || + res.status !== 200 + ) { + const responseTexts = [responseText]; + let extraInfo = await res.clone().text(); + try { + const resJson = await res.clone().json(); + extraInfo = prettyObject(resJson); + } catch {} + + if (res.status === 401) { + responseTexts.push(Locale.Error.Unauthorized); + } + + if (extraInfo) { + responseTexts.push(extraInfo); + } + + responseText = responseTexts.join("\n\n"); + + return finish(); + } + }, + onmessage(msg) { + if (msg.data === "[DONE]" || finished) { + return finish(); + } + const text = msg.data; + try { + const json = JSON.parse(text); + const choices = json.choices as Array<{ + delta: { content: string }; + }>; + const delta = choices[0]?.delta?.content; + if (delta) { + remainText += delta; + } + } catch (e) { + console.error("[Request] parse error", text, msg); + } + }, + onclose() { + finish(); + }, + onerror(e) { + options.onError?.(e); + throw e; + }, + openWhenHidden: true, + }); + } else { + const res = await fetch(chatPath, chatPayload); + clearTimeout(requestTimeoutId); + + const resJson = await res.json(); + const message = this.extractMessage(resJson); + options.onFinish(message); + } + } catch (e) { + console.log("[Request] failed to make a chat request", e); + options.onError?.(e as Error); + } + } + async usage() { + return { + used: 0, + total: 0, + }; + } + + async models(): Promise { + return []; + } +} +export { ByteDance }; diff --git a/app/components/exporter.tsx b/app/components/exporter.tsx index 7281fc2f12d..1cc531eb88b 100644 --- a/app/components/exporter.tsx +++ b/app/components/exporter.tsx @@ -321,6 +321,8 @@ export function PreviewActions(props: { api = new ClientApi(ModelProvider.GeminiPro); } else if (config.modelConfig.providerName == ServiceProvider.Anthropic) { api = new ClientApi(ModelProvider.Claude); + } else if (config.modelConfig.providerName == ServiceProvider.ByteDance) { + api = new ClientApi(ModelProvider.Doubao); } else { api = new ClientApi(ModelProvider.GPT); } diff --git a/app/components/home.tsx b/app/components/home.tsx index addb5e80373..7da20df2256 100644 --- a/app/components/home.tsx +++ b/app/components/home.tsx @@ -175,6 +175,8 @@ export function useLoadData() { api = new ClientApi(ModelProvider.GeminiPro); } else if (config.modelConfig.providerName == ServiceProvider.Anthropic) { api = new ClientApi(ModelProvider.Claude); + } else if (config.modelConfig.providerName == ServiceProvider.ByteDance) { + api = new ClientApi(ModelProvider.Doubao); } else { api = new ClientApi(ModelProvider.GPT); } diff --git a/app/config/server.ts b/app/config/server.ts index b7c85ce6a5f..d50dbf1a1e4 100644 --- a/app/config/server.ts +++ b/app/config/server.ts @@ -32,6 +32,10 @@ declare global { GOOGLE_API_KEY?: string; GOOGLE_URL?: string; + // bytedance only + BYTEDANCE_URL?: string; + BYTEDANCE_API_KEY?: string; + // google tag manager GTM_ID?: string; @@ -92,6 +96,7 @@ export const getServerSideConfig = () => { const isAzure = !!process.env.AZURE_URL; const isGoogle = !!process.env.GOOGLE_API_KEY; const isAnthropic = !!process.env.ANTHROPIC_API_KEY; + const isBytedance = !!process.env.BYTEDANCE_API_KEY; // const apiKeyEnvVar = process.env.OPENAI_API_KEY ?? ""; // const apiKeys = apiKeyEnvVar.split(",").map((v) => v.trim()); @@ -126,6 +131,10 @@ export const getServerSideConfig = () => { gtmId: process.env.GTM_ID, + isBytedance, + bytedanceApiKey: getApiKey(process.env.BYTEDANCE_API_KEY), + bytedanceUrl: process.env.BYTEDANCE_URL, + needCode: ACCESS_CODES.size > 0, code: process.env.CODE, codes: ACCESS_CODES, diff --git a/app/constant.ts b/app/constant.ts index d44b5b8173b..1ed292d219a 100644 --- a/app/constant.ts +++ b/app/constant.ts @@ -14,6 +14,8 @@ export const ANTHROPIC_BASE_URL = "https://api.anthropic.com"; export const GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/"; +export const BYTEDANCE_BASE_URL = "https://ark.cn-beijing.volces.com"; + export enum Path { Home = "/", Chat = "/chat", @@ -28,6 +30,7 @@ export enum ApiPath { Azure = "/api/azure", OpenAI = "/api/openai", Anthropic = "/api/anthropic", + ByteDance = "/api/bytedance", } export enum SlotID { @@ -71,12 +74,14 @@ export enum ServiceProvider { Azure = "Azure", Google = "Google", Anthropic = "Anthropic", + ByteDance = "ByteDance", } export enum ModelProvider { GPT = "GPT", GeminiPro = "GeminiPro", Claude = "Claude", + Doubao = "Doubao", } export const Anthropic = { @@ -104,6 +109,11 @@ export const Google = { ChatPath: (modelName: string) => `v1beta/models/${modelName}:generateContent`, }; +export const ByteDance = { + ExampleEndpoint: "https://ark.cn-beijing.volces.com/api/v3/chat/completions", + ChatPath: "/api/v3/chat/completions", +}; + export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang // export const DEFAULT_SYSTEM_TEMPLATE = ` // You are ChatGPT, a large language model trained by {{ServiceProvider}}. @@ -173,6 +183,8 @@ const anthropicModels = [ "claude-3-5-sonnet-20240620", ]; +const bytedanceModels = ["ep-20240520082937-424bw=Doubao-lite-4k"]; + export const DEFAULT_MODELS = [ ...openaiModels.map((name) => ({ name, @@ -210,6 +222,15 @@ export const DEFAULT_MODELS = [ providerType: "anthropic", }, })), + ...bytedanceModels.map((name) => ({ + name, + available: true, + provider: { + id: "bytedance", + providerName: "ByteDance", + providerType: "bytedance", + }, + })), ] as const; export const CHAT_PAGE_SIZE = 15; diff --git a/app/store/access.ts b/app/store/access.ts index 03780779e72..b04748b8cef 100644 --- a/app/store/access.ts +++ b/app/store/access.ts @@ -47,6 +47,10 @@ const DEFAULT_ACCESS_STATE = { anthropicApiVersion: "2023-06-01", anthropicUrl: "", + // bytedance + bytedanceApiKey: "", + bytedanceUrl: "", + // server config needCode: true, hideUserApiKey: false, @@ -83,6 +87,10 @@ export const useAccessStore = createPersistStore( return ensure(get(), ["anthropicApiKey"]); }, + isValidByteDance() { + return ensure(get(), ["bytedanceApiKey"]); + }, + isAuthorized() { this.fetch(); @@ -92,6 +100,7 @@ export const useAccessStore = createPersistStore( this.isValidAzure() || this.isValidGoogle() || this.isValidAnthropic() || + this.isValidByteDance() || !this.enabledAccessControl() || (this.enabledAccessControl() && ensure(get(), ["accessCode"])) ); diff --git a/app/store/chat.ts b/app/store/chat.ts index 44d41830a63..475d436d972 100644 --- a/app/store/chat.ts +++ b/app/store/chat.ts @@ -368,6 +368,8 @@ export const useChatStore = createPersistStore( api = new ClientApi(ModelProvider.GeminiPro); } else if (modelConfig.providerName == ServiceProvider.Anthropic) { api = new ClientApi(ModelProvider.Claude); + } else if (modelConfig.providerName == ServiceProvider.ByteDance) { + api = new ClientApi(ModelProvider.Doubao); } else { api = new ClientApi(ModelProvider.GPT); } @@ -552,6 +554,8 @@ export const useChatStore = createPersistStore( api = new ClientApi(ModelProvider.GeminiPro); } else if (modelConfig.providerName == ServiceProvider.Anthropic) { api = new ClientApi(ModelProvider.Claude); + } else if (modelConfig.providerName == ServiceProvider.ByteDance) { + api = new ClientApi(ModelProvider.Doubao); } else { api = new ClientApi(ModelProvider.GPT); } From f3e3f083774ab01db558a213a0b180fe995ad2c4 Mon Sep 17 00:00:00 2001 From: Dogtiti <499960698@qq.com> Date: Sat, 6 Jul 2024 21:25:00 +0800 Subject: [PATCH 03/27] fix: apiClient --- app/client/api.ts | 2 ++ 1 file changed, 2 insertions(+) diff --git a/app/client/api.ts b/app/client/api.ts index a3d5a36e0c5..f650139f906 100644 --- a/app/client/api.ts +++ b/app/client/api.ts @@ -225,6 +225,8 @@ export function getClientApi(provider: ServiceProvider): ClientApi { return new ClientApi(ModelProvider.GeminiPro); case ServiceProvider.Anthropic: return new ClientApi(ModelProvider.Claude); + case ServiceProvider.Baidu: + return new ClientApi(ModelProvider.Ernie); default: return new ClientApi(ModelProvider.GPT); } From 1caa61f4c0e8d35bfff2dd670925f8c1ceb8267a Mon Sep 17 00:00:00 2001 From: Dogtiti <499960698@qq.com> Date: Sat, 6 Jul 2024 22:59:20 +0800 Subject: [PATCH 04/27] feat: swap name and displayName for bytedance in custom models --- app/client/api.ts | 2 ++ app/config/server.ts | 6 +++--- app/constant.ts | 9 ++++++++- app/utils/model.ts | 12 ++++++++++-- 4 files changed, 23 insertions(+), 6 deletions(-) diff --git a/app/client/api.ts b/app/client/api.ts index d2eeca46a32..f2e83c391ce 100644 --- a/app/client/api.ts +++ b/app/client/api.ts @@ -225,6 +225,8 @@ export function getClientApi(provider: ServiceProvider): ClientApi { return new ClientApi(ModelProvider.GeminiPro); case ServiceProvider.Anthropic: return new ClientApi(ModelProvider.Claude); + case ServiceProvider.ByteDance: + return new ClientApi(ModelProvider.Doubao); default: return new ClientApi(ModelProvider.GPT); } diff --git a/app/config/server.ts b/app/config/server.ts index d50dbf1a1e4..0f57d2d6d44 100644 --- a/app/config/server.ts +++ b/app/config/server.ts @@ -32,13 +32,13 @@ declare global { GOOGLE_API_KEY?: string; GOOGLE_URL?: string; + // google tag manager + GTM_ID?: string; + // bytedance only BYTEDANCE_URL?: string; BYTEDANCE_API_KEY?: string; - // google tag manager - GTM_ID?: string; - // custom template for preprocessing user input DEFAULT_INPUT_TEMPLATE?: string; } diff --git a/app/constant.ts b/app/constant.ts index 1ed292d219a..5b52073bb5a 100644 --- a/app/constant.ts +++ b/app/constant.ts @@ -183,7 +183,14 @@ const anthropicModels = [ "claude-3-5-sonnet-20240620", ]; -const bytedanceModels = ["ep-20240520082937-424bw=Doubao-lite-4k"]; +const bytedanceModels = [ + "Doubao-lite-4k", + "Doubao-lite-32k", + "Doubao-lite-128k", + "Doubao-pro-4k", + "Doubao-pro-32k", + "Doubao-pro-128k", +]; export const DEFAULT_MODELS = [ ...openaiModels.map((name) => ({ diff --git a/app/utils/model.ts b/app/utils/model.ts index 249987726ad..62ecc55b3be 100644 --- a/app/utils/model.ts +++ b/app/utils/model.ts @@ -39,7 +39,7 @@ export function collectModelTable( const available = !m.startsWith("-"); const nameConfig = m.startsWith("+") || m.startsWith("-") ? m.slice(1) : m; - const [name, displayName] = nameConfig.split("="); + let [name, displayName] = nameConfig.split("="); // enable or disable all models if (name === "all") { @@ -50,9 +50,17 @@ export function collectModelTable( // 1. find model by name(), and set available value let count = 0; for (const fullName in modelTable) { - if (fullName.split("@").shift() == name) { + const [modelName, providerName] = fullName.split("@"); + if (modelName === name) { count += 1; modelTable[fullName]["available"] = available; + // swap name and displayName for bytedance + if (providerName === "bytedance") { + const tempName = name; + name = displayName; + displayName = tempName; + modelTable[fullName]["name"] = name; + } if (displayName) { modelTable[fullName]["displayName"] = displayName; } From 9bdd37bb631198f8c75b995b47ba87a1e6639c14 Mon Sep 17 00:00:00 2001 From: Dogtiti <499960698@qq.com> Date: Sun, 7 Jul 2024 21:59:56 +0800 Subject: [PATCH 05/27] feat: qwen --- app/api/alibaba/[...path]/route.ts | 175 +++++++++++++++++++ app/api/auth.ts | 3 + app/client/api.ts | 7 + app/client/platforms/alibaba.ts | 260 +++++++++++++++++++++++++++++ app/client/platforms/openai.ts | 2 +- app/config/server.ts | 9 + app/constant.ts | 29 ++++ app/store/access.ts | 9 + 8 files changed, 493 insertions(+), 1 deletion(-) create mode 100644 app/api/alibaba/[...path]/route.ts create mode 100644 app/client/platforms/alibaba.ts diff --git a/app/api/alibaba/[...path]/route.ts b/app/api/alibaba/[...path]/route.ts new file mode 100644 index 00000000000..e30eacbdb7c --- /dev/null +++ b/app/api/alibaba/[...path]/route.ts @@ -0,0 +1,175 @@ +import { getServerSideConfig } from "@/app/config/server"; +import { + Alibaba, + ALIBABA_BASE_URL, + ApiPath, + ModelProvider, + ServiceProvider, +} from "@/app/constant"; +import { prettyObject } from "@/app/utils/format"; +import { NextRequest, NextResponse } from "next/server"; +import { auth } from "@/app/api/auth"; +import { isModelAvailableInServer } from "@/app/utils/model"; +import type { RequestPayload } from "@/app/client/platforms/openai"; + +const serverConfig = getServerSideConfig(); + +async function handle( + req: NextRequest, + { params }: { params: { path: string[] } }, +) { + console.log("[Alibaba Route] params ", params); + + if (req.method === "OPTIONS") { + return NextResponse.json({ body: "OK" }, { status: 200 }); + } + + const authResult = auth(req, ModelProvider.Qwen); + if (authResult.error) { + return NextResponse.json(authResult, { + status: 401, + }); + } + + try { + const response = await request(req); + return response; + } catch (e) { + console.error("[Alibaba] ", e); + return NextResponse.json(prettyObject(e)); + } +} + +export const GET = handle; +export const POST = handle; + +export const runtime = "edge"; +export const preferredRegion = [ + "arn1", + "bom1", + "cdg1", + "cle1", + "cpt1", + "dub1", + "fra1", + "gru1", + "hnd1", + "iad1", + "icn1", + "kix1", + "lhr1", + "pdx1", + "sfo1", + "sin1", + "syd1", +]; + +async function request(req: NextRequest) { + const controller = new AbortController(); + + // alibaba use base url or just remove the path + let path = `${req.nextUrl.pathname}`.replaceAll( + ApiPath.Alibaba + "/" + Alibaba.ChatPath, + "", + ); + + let baseUrl = serverConfig.alibabaUrl || ALIBABA_BASE_URL; + + if (!baseUrl.startsWith("http")) { + baseUrl = `https://${baseUrl}`; + } + + if (baseUrl.endsWith("/")) { + baseUrl = baseUrl.slice(0, -1); + } + + console.log("[Proxy] ", path); + console.log("[Base Url]", baseUrl); + + const timeoutId = setTimeout( + () => { + controller.abort(); + }, + 10 * 60 * 1000, + ); + + const fetchUrl = `${baseUrl}${path}`; + + const clonedBody = await req.text(); + + const { messages, model, stream, top_p, ...rest } = JSON.parse( + clonedBody, + ) as RequestPayload; + + const requestBody = { + model, + input: { + messages, + }, + parameters: { + ...rest, + top_p: top_p === 1 ? 0.99 : top_p, // qwen top_p is should be < 1 + result_format: "message", + incremental_output: true, + }, + }; + + const fetchOptions: RequestInit = { + headers: { + "Content-Type": "application/json", + Authorization: req.headers.get("Authorization") ?? "", + "X-DashScope-SSE": stream ? "enable" : "disable", + }, + method: req.method, + body: JSON.stringify(requestBody), + redirect: "manual", + // @ts-ignore + duplex: "half", + signal: controller.signal, + }; + + // #1815 try to refuse some request to some models + if (serverConfig.customModels && req.body) { + try { + // not undefined and is false + if ( + isModelAvailableInServer( + serverConfig.customModels, + model as string, + ServiceProvider.Alibaba as string, + ) + ) { + return NextResponse.json( + { + error: true, + message: `you are not allowed to use ${model} model`, + }, + { + status: 403, + }, + ); + } + } catch (e) { + console.error(`[Alibaba] filter`, e); + } + } + console.log("[Alibaba request]", fetchOptions.headers, req.method); + try { + const res = await fetch(fetchUrl, fetchOptions); + + console.log("[Alibaba response]", res.status, " ", res.headers, res.url); + // to prevent browser prompt for credentials + const newHeaders = new Headers(res.headers); + newHeaders.delete("www-authenticate"); + // to disable nginx buffering + newHeaders.set("X-Accel-Buffering", "no"); + + return new Response(res.body, { + status: res.status, + statusText: res.statusText, + headers: newHeaders, + }); + } finally { + clearTimeout(timeoutId); + } +} diff --git a/app/api/auth.ts b/app/api/auth.ts index 2b4702aedc3..b9f23d4c4ce 100644 --- a/app/api/auth.ts +++ b/app/api/auth.ts @@ -73,6 +73,9 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) { case ModelProvider.Claude: systemApiKey = serverConfig.anthropicApiKey; break; + case ModelProvider.Qwen: + systemApiKey = serverConfig.alibabaApiKey; + break; case ModelProvider.GPT: default: if (req.nextUrl.pathname.includes("azure/deployments")) { diff --git a/app/client/api.ts b/app/client/api.ts index 528a5598aec..3677415ce03 100644 --- a/app/client/api.ts +++ b/app/client/api.ts @@ -9,6 +9,8 @@ import { ChatMessage, ModelType, useAccessStore, useChatStore } from "../store"; import { ChatGPTApi } from "./platforms/openai"; import { GeminiProApi } from "./platforms/google"; import { ClaudeApi } from "./platforms/anthropic"; +import { QwenApi } from "./platforms/alibaba"; + export const ROLES = ["system", "user", "assistant"] as const; export type MessageRole = (typeof ROLES)[number]; @@ -104,6 +106,9 @@ export class ClientApi { case ModelProvider.Claude: this.llm = new ClaudeApi(); break; + case ModelProvider.Qwen: + this.llm = new QwenApi(); + break; default: this.llm = new ChatGPTApi(); } @@ -220,6 +225,8 @@ export function getClientApi(provider: ServiceProvider): ClientApi { return new ClientApi(ModelProvider.GeminiPro); case ServiceProvider.Anthropic: return new ClientApi(ModelProvider.Claude); + case ServiceProvider.Alibaba: + return new ClientApi(ModelProvider.Qwen); default: return new ClientApi(ModelProvider.GPT); } diff --git a/app/client/platforms/alibaba.ts b/app/client/platforms/alibaba.ts new file mode 100644 index 00000000000..eefdb017fd0 --- /dev/null +++ b/app/client/platforms/alibaba.ts @@ -0,0 +1,260 @@ +"use client"; +import { + ApiPath, + Alibaba, + DEFAULT_API_HOST, + REQUEST_TIMEOUT_MS, +} from "@/app/constant"; +import { useAccessStore, useAppConfig, useChatStore } from "@/app/store"; + +import { + ChatOptions, + getHeaders, + LLMApi, + LLMModel, + MultimodalContent, +} from "../api"; +import Locale from "../../locales"; +import { + EventStreamContentType, + fetchEventSource, +} from "@fortaine/fetch-event-source"; +import { prettyObject } from "@/app/utils/format"; +import { getClientConfig } from "@/app/config/client"; +import { getMessageTextContent, isVisionModel } from "@/app/utils"; + +export interface OpenAIListModelResponse { + object: string; + data: Array<{ + id: string; + object: string; + root: string; + }>; +} + +interface RequestPayload { + messages: { + role: "system" | "user" | "assistant"; + content: string | MultimodalContent[]; + }[]; + stream?: boolean; + model: string; + temperature: number; + presence_penalty: number; + frequency_penalty: number; + top_p: number; + max_tokens?: number; +} + +export class QwenApi implements LLMApi { + path(path: string): string { + const accessStore = useAccessStore.getState(); + + let baseUrl = ""; + + if (accessStore.useCustomConfig) { + baseUrl = accessStore.alibabaUrl; + } + + if (baseUrl.length === 0) { + const isApp = !!getClientConfig()?.isApp; + baseUrl = isApp + ? DEFAULT_API_HOST + "/api/proxy/alibaba" + : ApiPath.Alibaba; + } + + if (baseUrl.endsWith("/")) { + baseUrl = baseUrl.slice(0, baseUrl.length - 1); + } + if (!baseUrl.startsWith("http") && !baseUrl.startsWith(ApiPath.Alibaba)) { + baseUrl = "https://" + baseUrl; + } + + console.log("[Proxy Endpoint] ", baseUrl, path); + + return [baseUrl, path].join("/"); + } + + extractMessage(res: any) { + return res.choices?.at(0)?.message?.content ?? ""; + } + + async chat(options: ChatOptions) { + const visionModel = isVisionModel(options.config.model); + const messages = options.messages.map((v) => ({ + role: v.role, + content: visionModel ? v.content : getMessageTextContent(v), + })); + + const modelConfig = { + ...useAppConfig.getState().modelConfig, + ...useChatStore.getState().currentSession().mask.modelConfig, + ...{ + model: options.config.model, + }, + }; + + const requestPayload: RequestPayload = { + messages, + stream: options.config.stream, + model: modelConfig.model, + temperature: modelConfig.temperature, + presence_penalty: modelConfig.presence_penalty, + frequency_penalty: modelConfig.frequency_penalty, + top_p: modelConfig.top_p, + }; + + console.log("[Request] Alibaba payload: ", requestPayload); + + const shouldStream = !!options.config.stream; + const controller = new AbortController(); + options.onController?.(controller); + + try { + const chatPath = this.path(Alibaba.ChatPath); + const chatPayload = { + method: "POST", + body: JSON.stringify(requestPayload), + signal: controller.signal, + headers: getHeaders(), + }; + + // make a fetch request + const requestTimeoutId = setTimeout( + () => controller.abort(), + REQUEST_TIMEOUT_MS, + ); + + if (shouldStream) { + let responseText = ""; + let remainText = ""; + let finished = false; + + // animate response to make it looks smooth + function animateResponseText() { + if (finished || controller.signal.aborted) { + responseText += remainText; + console.log("[Response Animation] finished"); + if (responseText?.length === 0) { + options.onError?.(new Error("empty response from server")); + } + return; + } + + if (remainText.length > 0) { + const fetchCount = Math.max(1, Math.round(remainText.length / 60)); + const fetchText = remainText.slice(0, fetchCount); + responseText += fetchText; + remainText = remainText.slice(fetchCount); + options.onUpdate?.(responseText, fetchText); + } + + requestAnimationFrame(animateResponseText); + } + + // start animaion + animateResponseText(); + + const finish = () => { + if (!finished) { + finished = true; + options.onFinish(responseText + remainText); + } + }; + + controller.signal.onabort = finish; + + fetchEventSource(chatPath, { + ...chatPayload, + async onopen(res) { + clearTimeout(requestTimeoutId); + const contentType = res.headers.get("content-type"); + console.log( + "[Alibaba] request response content type: ", + contentType, + ); + + if (contentType?.startsWith("text/plain")) { + responseText = await res.clone().text(); + return finish(); + } + + if ( + !res.ok || + !res.headers + .get("content-type") + ?.startsWith(EventStreamContentType) || + res.status !== 200 + ) { + const responseTexts = [responseText]; + let extraInfo = await res.clone().text(); + try { + const resJson = await res.clone().json(); + extraInfo = prettyObject(resJson); + } catch {} + + if (res.status === 401) { + responseTexts.push(Locale.Error.Unauthorized); + } + + if (extraInfo) { + responseTexts.push(extraInfo); + } + + responseText = responseTexts.join("\n\n"); + + return finish(); + } + }, + onmessage(msg) { + if (msg.data === "[DONE]" || finished) { + return finish(); + } + const text = msg.data; + try { + const json = JSON.parse(text); + const choices = json.output.choices as Array<{ + message: { content: string }; + }>; + const delta = choices[0]?.message?.content; + if (delta) { + remainText += delta; + } + } catch (e) { + console.error("[Request] parse error", text, msg); + } + }, + onclose() { + finish(); + }, + onerror(e) { + options.onError?.(e); + throw e; + }, + openWhenHidden: true, + }); + } else { + const res = await fetch(chatPath, chatPayload); + clearTimeout(requestTimeoutId); + + const resJson = await res.json(); + const message = this.extractMessage(resJson); + options.onFinish(message); + } + } catch (e) { + console.log("[Request] failed to make a chat request", e); + options.onError?.(e as Error); + } + } + async usage() { + return { + used: 0, + total: 0, + }; + } + + async models(): Promise { + return []; + } +} +export { Alibaba }; diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts index 8615172a311..bba359429fc 100644 --- a/app/client/platforms/openai.ts +++ b/app/client/platforms/openai.ts @@ -42,7 +42,7 @@ export interface OpenAIListModelResponse { }>; } -interface RequestPayload { +export interface RequestPayload { messages: { role: "system" | "user" | "assistant"; content: string | MultimodalContent[]; diff --git a/app/config/server.ts b/app/config/server.ts index b7c85ce6a5f..62624a8e4cd 100644 --- a/app/config/server.ts +++ b/app/config/server.ts @@ -35,6 +35,10 @@ declare global { // google tag manager GTM_ID?: string; + // alibaba only + ALIBABA_URL?: string; + ALIBABA_API_KEY?: string; + // custom template for preprocessing user input DEFAULT_INPUT_TEMPLATE?: string; } @@ -92,6 +96,7 @@ export const getServerSideConfig = () => { const isAzure = !!process.env.AZURE_URL; const isGoogle = !!process.env.GOOGLE_API_KEY; const isAnthropic = !!process.env.ANTHROPIC_API_KEY; + const isAlibaba = !!process.env.ALIBABA_API_KEY; // const apiKeyEnvVar = process.env.OPENAI_API_KEY ?? ""; // const apiKeys = apiKeyEnvVar.split(",").map((v) => v.trim()); @@ -124,6 +129,10 @@ export const getServerSideConfig = () => { anthropicApiVersion: process.env.ANTHROPIC_API_VERSION, anthropicUrl: process.env.ANTHROPIC_URL, + isAlibaba, + alibabaUrl: process.env.ALIBABA_URL, + alibabaApiKey: getApiKey(process.env.ALIBABA_API_KEY), + gtmId: process.env.GTM_ID, needCode: ACCESS_CODES.size > 0, diff --git a/app/constant.ts b/app/constant.ts index d44b5b8173b..01c212b242e 100644 --- a/app/constant.ts +++ b/app/constant.ts @@ -14,6 +14,9 @@ export const ANTHROPIC_BASE_URL = "https://api.anthropic.com"; export const GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/"; +export const ALIBABA_BASE_URL = + "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"; + export enum Path { Home = "/", Chat = "/chat", @@ -28,6 +31,7 @@ export enum ApiPath { Azure = "/api/azure", OpenAI = "/api/openai", Anthropic = "/api/anthropic", + Alibaba = "/api/alibaba", } export enum SlotID { @@ -71,12 +75,14 @@ export enum ServiceProvider { Azure = "Azure", Google = "Google", Anthropic = "Anthropic", + Alibaba = "Alibaba", } export enum ModelProvider { GPT = "GPT", GeminiPro = "GeminiPro", Claude = "Claude", + Qwen = "Qwen", } export const Anthropic = { @@ -104,6 +110,10 @@ export const Google = { ChatPath: (modelName: string) => `v1beta/models/${modelName}:generateContent`, }; +export const Alibaba = { + ChatPath: "chat/completions", +}; + export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang // export const DEFAULT_SYSTEM_TEMPLATE = ` // You are ChatGPT, a large language model trained by {{ServiceProvider}}. @@ -173,6 +183,16 @@ const anthropicModels = [ "claude-3-5-sonnet-20240620", ]; +const alibabaModes = [ + "qwen-turbo", + "qwen-plus", + "qwen-max", + "qwen-max-0428", + "qwen-max-0403", + "qwen-max-0107", + "qwen-max-longcontext", +]; + export const DEFAULT_MODELS = [ ...openaiModels.map((name) => ({ name, @@ -210,6 +230,15 @@ export const DEFAULT_MODELS = [ providerType: "anthropic", }, })), + ...alibabaModes.map((name) => ({ + name, + available: true, + provider: { + id: "alibaba", + providerName: "Alibaba", + providerType: "alibaba", + }, + })), ] as const; export const CHAT_PAGE_SIZE = 15; diff --git a/app/store/access.ts b/app/store/access.ts index 03780779e72..5ea45904952 100644 --- a/app/store/access.ts +++ b/app/store/access.ts @@ -47,6 +47,10 @@ const DEFAULT_ACCESS_STATE = { anthropicApiVersion: "2023-06-01", anthropicUrl: "", + // alibaba + alibabaUrl: "", + alibabaApiKey: "", + // server config needCode: true, hideUserApiKey: false, @@ -83,6 +87,10 @@ export const useAccessStore = createPersistStore( return ensure(get(), ["anthropicApiKey"]); }, + isValidAlibaba() { + return ensure(get(), ["alibabaApiKey"]); + }, + isAuthorized() { this.fetch(); @@ -92,6 +100,7 @@ export const useAccessStore = createPersistStore( this.isValidAzure() || this.isValidGoogle() || this.isValidAnthropic() || + this.isValidAlibaba() || !this.enabledAccessControl() || (this.enabledAccessControl() && ensure(get(), ["accessCode"])) ); From 71af2628eb8d791070fc2b4818f6f46c9068c962 Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Tue, 9 Jul 2024 00:32:18 +0800 Subject: [PATCH 06/27] hotfix: old AZURE_URL config error: "DeploymentNotFound". #4945 #4930 --- app/api/common.ts | 25 +++++++++++++++++++++++++ app/utils/model.ts | 10 ++++++++-- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/app/api/common.ts b/app/api/common.ts index b2fae6df24f..5223646d264 100644 --- a/app/api/common.ts +++ b/app/api/common.ts @@ -66,6 +66,31 @@ export async function requestOpenai(req: NextRequest) { "/api/azure/", "", )}?api-version=${azureApiVersion}`; + + // Forward compatibility: + // if display_name(deployment_name) not set, and '{deploy-id}' in AZURE_URL + // then using default '{deploy-id}' + if (serverConfig.customModels) { + const modelName = path.split("/")[1]; + let realDeployName = ""; + serverConfig.customModels + .split(",") + .filter((v) => !!v && !v.startsWith("-") && v.includes(modelName)) + .forEach((m) => { + const [fullName, displayName] = m.split("="); + const [_, providerName] = fullName.split("@"); + if (providerName === "azure" && !displayName) { + const [_, deployId] = serverConfig.azureUrl.split("deployments/"); + if (deployId) { + realDeployName = deployId; + } + } + }); + if (realDeployName) { + console.log("[Replace with DeployId", realDeployName); + path = path.replaceAll(modelName, realDeployName); + } + } } const fetchUrl = `${baseUrl}/${path}`; diff --git a/app/utils/model.ts b/app/utils/model.ts index 249987726ad..0b160f1013b 100644 --- a/app/utils/model.ts +++ b/app/utils/model.ts @@ -47,10 +47,16 @@ export function collectModelTable( (model) => (model.available = available), ); } else { - // 1. find model by name(), and set available value + // 1. find model by name, and set available value + const [customModelName, customProviderName] = name.split("@"); let count = 0; for (const fullName in modelTable) { - if (fullName.split("@").shift() == name) { + const [modelName, providerName] = fullName.split("@"); + if ( + customModelName == modelName && + (customProviderName === undefined || + customProviderName === providerName) + ) { count += 1; modelTable[fullName]["available"] = available; if (displayName) { From 34ab37f31e1fe968c86a4ddc8421a1bfe6a20a27 Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Tue, 9 Jul 2024 00:47:35 +0800 Subject: [PATCH 07/27] update CUSTOM_MODELS config for Azure mode. --- README.md | 4 ++++ README_CN.md | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/README.md b/README.md index c77d2023c97..2cac1088a6b 100644 --- a/README.md +++ b/README.md @@ -181,6 +181,7 @@ Specify OpenAI organization ID. ### `AZURE_URL` (optional) > Example: https://{azure-resource-url}/openai/deployments/{deploy-name} +> if you config deployment name in `CUSTOM_MODELS`, you can remove `{deploy-name}` in `AZURE_URL` Azure deploy url. @@ -245,6 +246,9 @@ To control custom models, use `+` to add a custom model, use `-` to hide a model User `-all` to disable all default models, `+all` to enable all default models. +For Azure: use `modelName@azure=deploymentName` to customize model name and deployment name. +> Example: `+gpt-3.5-turbo@azure=gpt35` will show option `gpt35(Azure)` in model list. + ### `DEFAULT_MODEL` (optional) Change default model diff --git a/README_CN.md b/README_CN.md index 970ecdef2f9..c6cbf653978 100644 --- a/README_CN.md +++ b/README_CN.md @@ -95,6 +95,7 @@ OpenAI 接口代理 URL,如果你手动配置了 openai 接口代理,请填 ### `AZURE_URL` (可选) > 形如:https://{azure-resource-url}/openai/deployments/{deploy-name} +> 如果你已经在`CUSTOM_MODELS`中参考`displayName`的方式配置了{deploy-name},那么可以从`AZURE_URL`中移除`{deploy-name}` Azure 部署地址。 @@ -156,6 +157,10 @@ anthropic claude Api Url. 用来控制模型列表,使用 `+` 增加一个模型,使用 `-` 来隐藏一个模型,使用 `模型名=展示名` 来自定义模型的展示名,用英文逗号隔开。 +在Azure的模式下,支持使用`modelName@azure=deploymentName`的方式配置模型名称和部署名称(deploy-name) +> 示例:`+gpt-3.5-turbo@azure=gpt35`这个配置会在模型列表显示一个`gpt35(Azure)`的选项 + + ### `DEFAULT_MODEL` (可选) 更改默认模型 From 6ac9789a1c4065c19cdd1bab7a808fbc54c0b1a2 Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Tue, 9 Jul 2024 12:16:37 +0800 Subject: [PATCH 08/27] hotfix --- app/store/config.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/store/config.ts b/app/store/config.ts index 4b0a34f4f08..1eaafe12b1d 100644 --- a/app/store/config.ts +++ b/app/store/config.ts @@ -49,7 +49,7 @@ export const DEFAULT_CONFIG = { modelConfig: { model: "gpt-3.5-turbo" as ModelType, - providerName: "Openai" as ServiceProvider, + providerName: "OpenAI" as ServiceProvider, temperature: 0.5, top_p: 1, max_tokens: 4000, From f68cd2c5c04a33dda4187ee7db4bbfb4026b9e40 Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Tue, 9 Jul 2024 12:27:44 +0800 Subject: [PATCH 09/27] review code --- app/client/platforms/baidu.ts | 10 +++++----- app/constant.ts | 23 +++++++++++++++++------ app/utils/model.ts | 6 ++---- 3 files changed, 24 insertions(+), 15 deletions(-) diff --git a/app/client/platforms/baidu.ts b/app/client/platforms/baidu.ts index e2f6f12dd22..4fc3d2f6462 100644 --- a/app/client/platforms/baidu.ts +++ b/app/client/platforms/baidu.ts @@ -2,7 +2,7 @@ import { ApiPath, Baidu, - DEFAULT_API_HOST, + BAIDU_BASE_URL, REQUEST_TIMEOUT_MS, } from "@/app/constant"; import { useAccessStore, useAppConfig, useChatStore } from "@/app/store"; @@ -21,7 +21,7 @@ import { } from "@fortaine/fetch-event-source"; import { prettyObject } from "@/app/utils/format"; import { getClientConfig } from "@/app/config/client"; -import { getMessageTextContent, isVisionModel } from "@/app/utils"; +import { getMessageTextContent } from "@/app/utils"; export interface OpenAIListModelResponse { object: string; @@ -58,7 +58,8 @@ export class ErnieApi implements LLMApi { if (baseUrl.length === 0) { const isApp = !!getClientConfig()?.isApp; - baseUrl = isApp ? DEFAULT_API_HOST + "/api/proxy/baidu" : ApiPath.Baidu; + // do not use proxy for baidubce api + baseUrl = isApp ? BAIDU_BASE_URL : ApiPath.Baidu; } if (baseUrl.endsWith("/")) { @@ -78,10 +79,9 @@ export class ErnieApi implements LLMApi { } async chat(options: ChatOptions) { - const visionModel = isVisionModel(options.config.model); const messages = options.messages.map((v) => ({ role: v.role, - content: visionModel ? v.content : getMessageTextContent(v), + content: getMessageTextContent(v), })); const modelConfig = { diff --git a/app/constant.ts b/app/constant.ts index 6ffc0e0b3f8..0fd4d1c2492 100644 --- a/app/constant.ts +++ b/app/constant.ts @@ -112,9 +112,20 @@ export const Google = { }; export const Baidu = { - ExampleEndpoint: "https://aip.baidubce.com", - ChatPath: (modelName: string) => - `/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/${modelName}`, + ExampleEndpoint: BAIDU_BASE_URL, + ChatPath: (modelName: string) => { + let endpoint = modelName; + if (modelName === "ernie-4.0-8k") { + endpoint = "completions_pro"; + } + if (modelName === "ernie-4.0-8k-preview-0518") { + endpoint = "completions_adv_pro"; + } + if (modelName === "ernie-3.5-8k") { + endpoint = "completions"; + } + return `/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/${endpoint}`; + }, }; export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang @@ -188,11 +199,11 @@ const anthropicModels = [ const baiduModels = [ "ernie-4.0-turbo-8k", - "completions_pro=ernie-4.0-8k", + "ernie-4.0-8k", "ernie-4.0-8k-preview", - "completions_adv_pro=ernie-4.0-8k-preview-0518", + "ernie-4.0-8k-preview-0518", "ernie-4.0-8k-latest", - "completions=ernie-3.5-8k", + "ernie-3.5-8k", "ernie-3.5-8k-0205", ]; diff --git a/app/utils/model.ts b/app/utils/model.ts index 6a02ed7eb4b..7c778888e15 100644 --- a/app/utils/model.ts +++ b/app/utils/model.ts @@ -24,13 +24,11 @@ export function collectModelTable( // default models models.forEach((m) => { - // supoort name=displayName eg:completions_pro=ernie-4.0-8k - const [name, displayName] = m.name?.split("="); // using @ as fullName - modelTable[`${name}@${m?.provider?.id}`] = { + modelTable[`${m.name}@${m?.provider?.id}`] = { ...m, name, - displayName: displayName || name, // 'provider' is copied over if it exists + displayName: m.name, // 'provider' is copied over if it exists }; }); From 011b76e4e720be49db847a12ba02a78961a0159e Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Tue, 9 Jul 2024 13:39:39 +0800 Subject: [PATCH 10/27] review code --- app/utils/model.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/app/utils/model.ts b/app/utils/model.ts index 7c778888e15..249987726ad 100644 --- a/app/utils/model.ts +++ b/app/utils/model.ts @@ -27,7 +27,6 @@ export function collectModelTable( // using @ as fullName modelTable[`${m.name}@${m?.provider?.id}`] = { ...m, - name, displayName: m.name, // 'provider' is copied over if it exists }; }); From fadd7f6eb4cb9d70fb9758ee52c85aac768dc1be Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Tue, 9 Jul 2024 14:50:40 +0800 Subject: [PATCH 11/27] try getAccessToken in app, fixbug to fetch in none stream mode --- app/api/baidu/[...path]/route.ts | 41 +++++++++++++------------------- app/client/platforms/baidu.ts | 37 +++++++++++++++++++++------- app/constant.ts | 2 +- 3 files changed, 47 insertions(+), 33 deletions(-) diff --git a/app/api/baidu/[...path]/route.ts b/app/api/baidu/[...path]/route.ts index 27676d29df8..5444ba4fe85 100644 --- a/app/api/baidu/[...path]/route.ts +++ b/app/api/baidu/[...path]/route.ts @@ -10,6 +10,7 @@ import { prettyObject } from "@/app/utils/format"; import { NextRequest, NextResponse } from "next/server"; import { auth } from "@/app/api/auth"; import { isModelAvailableInServer } from "@/app/utils/model"; +import { getAccessToken } from "@/app/utils/baidu"; const serverConfig = getServerSideConfig(); @@ -30,6 +31,18 @@ async function handle( }); } + if (!serverConfig.baiduApiKey || !serverConfig.baiduSecretKey) { + return NextResponse.json( + { + error: true, + message: `missing BAIDU_API_KEY or BAIDU_SECRET_KEY in server env vars`, + }, + { + status: 401, + }, + ); + } + try { const response = await request(req); return response; @@ -88,7 +101,10 @@ async function request(req: NextRequest) { 10 * 60 * 1000, ); - const { access_token } = await getAccessToken(); + const { access_token } = await getAccessToken( + serverConfig.baiduApiKey, + serverConfig.baiduSecretKey, + ); const fetchUrl = `${baseUrl}${path}?access_token=${access_token}`; const fetchOptions: RequestInit = { @@ -133,11 +149,9 @@ async function request(req: NextRequest) { console.error(`[Baidu] filter`, e); } } - console.log("[Baidu request]", fetchOptions.headers, req.method); try { const res = await fetch(fetchUrl, fetchOptions); - console.log("[Baidu response]", res.status, " ", res.headers, res.url); // to prevent browser prompt for credentials const newHeaders = new Headers(res.headers); newHeaders.delete("www-authenticate"); @@ -153,24 +167,3 @@ async function request(req: NextRequest) { clearTimeout(timeoutId); } } - -/** - * 使用 AK,SK 生成鉴权签名(Access Token) - * @return 鉴权签名信息 - */ -async function getAccessToken(): Promise<{ - access_token: string; - expires_in: number; - error?: number; -}> { - const AK = serverConfig.baiduApiKey; - const SK = serverConfig.baiduSecretKey; - const res = await fetch( - `${BAIDU_OATUH_URL}?grant_type=client_credentials&client_id=${AK}&client_secret=${SK}`, - { - method: "POST", - }, - ); - const resJson = await res.json(); - return resJson; -} diff --git a/app/client/platforms/baidu.ts b/app/client/platforms/baidu.ts index 4fc3d2f6462..188b78bf963 100644 --- a/app/client/platforms/baidu.ts +++ b/app/client/platforms/baidu.ts @@ -6,6 +6,7 @@ import { REQUEST_TIMEOUT_MS, } from "@/app/constant"; import { useAccessStore, useAppConfig, useChatStore } from "@/app/store"; +import { getAccessToken } from "@/app/utils/baidu"; import { ChatOptions, @@ -74,16 +75,20 @@ export class ErnieApi implements LLMApi { return [baseUrl, path].join("/"); } - extractMessage(res: any) { - return res.choices?.at(0)?.message?.content ?? ""; - } - async chat(options: ChatOptions) { const messages = options.messages.map((v) => ({ role: v.role, content: getMessageTextContent(v), })); + // "error_code": 336006, "error_msg": "the length of messages must be an odd number", + if (messages.length % 2 === 0) { + messages.unshift({ + role: "user", + content: " ", + }); + } + const modelConfig = { ...useAppConfig.getState().modelConfig, ...useChatStore.getState().currentSession().mask.modelConfig, @@ -92,9 +97,10 @@ export class ErnieApi implements LLMApi { }, }; + const shouldStream = !!options.config.stream; const requestPayload: RequestPayload = { messages, - stream: options.config.stream, + stream: shouldStream, model: modelConfig.model, temperature: modelConfig.temperature, presence_penalty: modelConfig.presence_penalty, @@ -104,12 +110,27 @@ export class ErnieApi implements LLMApi { console.log("[Request] Baidu payload: ", requestPayload); - const shouldStream = !!options.config.stream; const controller = new AbortController(); options.onController?.(controller); try { - const chatPath = this.path(Baidu.ChatPath(modelConfig.model)); + let chatPath = this.path(Baidu.ChatPath(modelConfig.model)); + + // getAccessToken can not run in browser, because cors error + if (!!getClientConfig()?.isApp) { + const accessStore = useAccessStore.getState(); + if (accessStore.useCustomConfig) { + if (accessStore.isValidBaidu()) { + const { access_token } = await getAccessToken( + accessStore.baiduApiKey, + accessStore.baiduSecretKey, + ); + chatPath = `${chatPath}${ + chatPath.includes("?") ? "&" : "?" + }access_token=${access_token}`; + } + } + } const chatPayload = { method: "POST", body: JSON.stringify(requestPayload), @@ -230,7 +251,7 @@ export class ErnieApi implements LLMApi { clearTimeout(requestTimeoutId); const resJson = await res.json(); - const message = this.extractMessage(resJson); + const message = resJson?.result; options.onFinish(message); } } catch (e) { diff --git a/app/constant.ts b/app/constant.ts index 0fd4d1c2492..3d48dbb6235 100644 --- a/app/constant.ts +++ b/app/constant.ts @@ -124,7 +124,7 @@ export const Baidu = { if (modelName === "ernie-3.5-8k") { endpoint = "completions"; } - return `/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/${endpoint}`; + return `rpc/2.0/ai_custom/v1/wenxinworkshop/chat/${endpoint}`; }, }; From b14a0f24ae2b5d3dee298f6f573016b2356d7fac Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Tue, 9 Jul 2024 14:57:19 +0800 Subject: [PATCH 12/27] update locales --- app/locales/cn.ts | 4 ++-- app/locales/en.ts | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/app/locales/cn.ts b/app/locales/cn.ts index a872ee75ada..d7268807cb3 100644 --- a/app/locales/cn.ts +++ b/app/locales/cn.ts @@ -350,12 +350,12 @@ const cn = { Baidu: { ApiKey: { Title: "接口密钥", - SubTitle: "使用自定义 Baidu API Key 绕过密码访问限制", + SubTitle: "使用自定义 Baidu API Key", Placeholder: "Baidu API Key", }, SecretKey: { Title: "接口密钥", - SubTitle: "使用自定义 Baidu Secret Key 绕过密码访问限制", + SubTitle: "使用自定义 Baidu Secret Key", Placeholder: "Baidu Secret Key", }, Endpoint: { diff --git a/app/locales/en.ts b/app/locales/en.ts index aa153f52369..3c0d8851fa4 100644 --- a/app/locales/en.ts +++ b/app/locales/en.ts @@ -334,6 +334,22 @@ const en: LocaleType = { SubTitle: "Select and input a specific API version", }, }, + Baidu: { + ApiKey: { + Title: "Baidu API Key", + SubTitle: "Use a custom Baidu API Key", + Placeholder: "Baidu API Key", + }, + SecretKey: { + Title: "Baidu Secret Key", + SubTitle: "Use a custom Baidu Secret Key", + Placeholder: "Baidu Secret Key", + }, + Endpoint: { + Title: "Endpoint Address", + SubTitle: "Example:", + }, + }, CustomModel: { Title: "Custom Models", SubTitle: "Custom model options, seperated by comma", From 230e3823a90df67800f29be43d40e87ab42c1a76 Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Tue, 9 Jul 2024 15:02:44 +0800 Subject: [PATCH 13/27] update readme --- README.md | 12 ++++++++++++ README_CN.md | 12 ++++++++++++ 2 files changed, 24 insertions(+) diff --git a/README.md b/README.md index c77d2023c97..feaf197c4d1 100644 --- a/README.md +++ b/README.md @@ -212,6 +212,18 @@ anthropic claude Api version. anthropic claude Api Url. +### `BAIDU_API_KEY` (optional) + +Baidu Api Key. + +### `BAIDU_SECRET_KEY` (optional) + +Baidu Secret Key. + +### `BAIDU_URL` (optional) + +Baidu Api Url. + ### `HIDE_USER_API_KEY` (optional) > Default: Empty diff --git a/README_CN.md b/README_CN.md index 970ecdef2f9..827d4850fe4 100644 --- a/README_CN.md +++ b/README_CN.md @@ -126,6 +126,18 @@ anthropic claude Api version. anthropic claude Api Url. +### `BAIDU_API_KEY` (可选) + +Baidu Api Key. + +### `BAIDU_SECRET_KEY` (可选) + +Baidu Secret Key. + +### `BAIDU_URL` (可选) + +Baidu Api Url. + ### `HIDE_USER_API_KEY` (可选) 如果你不想让用户自行填入 API Key,将此环境变量设置为 1 即可。 From 147fc9a35a39187babb2b5aae156d47949547423 Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Tue, 9 Jul 2024 15:10:23 +0800 Subject: [PATCH 14/27] fix ts type error --- app/api/baidu/[...path]/route.ts | 4 ++-- app/api/common.ts | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/app/api/baidu/[...path]/route.ts b/app/api/baidu/[...path]/route.ts index 5444ba4fe85..94c9963c7e9 100644 --- a/app/api/baidu/[...path]/route.ts +++ b/app/api/baidu/[...path]/route.ts @@ -102,8 +102,8 @@ async function request(req: NextRequest) { ); const { access_token } = await getAccessToken( - serverConfig.baiduApiKey, - serverConfig.baiduSecretKey, + serverConfig.baiduApiKey as string, + serverConfig.baiduSecretKey as string, ); const fetchUrl = `${baseUrl}${path}?access_token=${access_token}`; diff --git a/app/api/common.ts b/app/api/common.ts index 5223646d264..1ffac7fce15 100644 --- a/app/api/common.ts +++ b/app/api/common.ts @@ -70,7 +70,7 @@ export async function requestOpenai(req: NextRequest) { // Forward compatibility: // if display_name(deployment_name) not set, and '{deploy-id}' in AZURE_URL // then using default '{deploy-id}' - if (serverConfig.customModels) { + if (serverConfig.customModels && serverConfig.azureUrl) { const modelName = path.split("/")[1]; let realDeployName = ""; serverConfig.customModels @@ -80,7 +80,9 @@ export async function requestOpenai(req: NextRequest) { const [fullName, displayName] = m.split("="); const [_, providerName] = fullName.split("@"); if (providerName === "azure" && !displayName) { - const [_, deployId] = serverConfig.azureUrl.split("deployments/"); + const [_, deployId] = (serverConfig?.azureUrl ?? "").split( + "deployments/", + ); if (deployId) { realDeployName = deployId; } From f2a35f11140b4ee41828ad9024fee88ceebb24b0 Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Tue, 9 Jul 2024 16:38:22 +0800 Subject: [PATCH 15/27] add missing file --- app/utils/baidu.ts | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 app/utils/baidu.ts diff --git a/app/utils/baidu.ts b/app/utils/baidu.ts new file mode 100644 index 00000000000..ddeb17bd59d --- /dev/null +++ b/app/utils/baidu.ts @@ -0,0 +1,23 @@ +import { BAIDU_OATUH_URL } from "../constant"; +/** + * 使用 AK,SK 生成鉴权签名(Access Token) + * @return 鉴权签名信息 + */ +export async function getAccessToken( + clientId: string, + clientSecret: string, +): Promise<{ + access_token: string; + expires_in: number; + error?: number; +}> { + const res = await fetch( + `${BAIDU_OATUH_URL}?grant_type=client_credentials&client_id=${clientId}&client_secret=${clientSecret}`, + { + method: "POST", + mode: "cors", + }, + ); + const resJson = await res.json(); + return resJson; +} From b3023543d67589c30f1c1ffd8f68fd712bc6c1aa Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Tue, 9 Jul 2024 16:55:33 +0800 Subject: [PATCH 16/27] update --- app/utils/model.ts | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/app/utils/model.ts b/app/utils/model.ts index adfbe287b1f..a3a014877a9 100644 --- a/app/utils/model.ts +++ b/app/utils/model.ts @@ -61,9 +61,7 @@ export function collectModelTable( modelTable[fullName]["available"] = available; // swap name and displayName for bytedance if (providerName === "bytedance") { - const tempName = name; - name = displayName; - displayName = tempName; + [name, displayName] = [displayName, name]; modelTable[fullName]["name"] = name; } if (displayName) { From 9d7e19cebf762ac7cd58e579040bd41c4d2cc15e Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Tue, 9 Jul 2024 18:05:23 +0800 Subject: [PATCH 17/27] display doubao model name when select model --- app/api/bytedance/[...path]/route.ts | 9 +-------- app/client/platforms/bytedance.ts | 12 ++++-------- app/components/chat.tsx | 26 +++++++++++++++++++++++--- 3 files changed, 28 insertions(+), 19 deletions(-) diff --git a/app/api/bytedance/[...path]/route.ts b/app/api/bytedance/[...path]/route.ts index bffb60f6c3a..336c837f037 100644 --- a/app/api/bytedance/[...path]/route.ts +++ b/app/api/bytedance/[...path]/route.ts @@ -132,17 +132,10 @@ async function request(req: NextRequest) { console.error(`[ByteDance] filter`, e); } } - console.log("[ByteDance request]", fetchOptions.headers, req.method); + try { const res = await fetch(fetchUrl, fetchOptions); - console.log( - "[ByteDance response]", - res.status, - " ", - res.headers, - res.url, - ); // to prevent browser prompt for credentials const newHeaders = new Headers(res.headers); newHeaders.delete("www-authenticate"); diff --git a/app/client/platforms/bytedance.ts b/app/client/platforms/bytedance.ts index 92c1fd55826..ce401e68dd5 100644 --- a/app/client/platforms/bytedance.ts +++ b/app/client/platforms/bytedance.ts @@ -2,7 +2,7 @@ import { ApiPath, ByteDance, - DEFAULT_API_HOST, + BYTEDANCE_BASE_URL, REQUEST_TIMEOUT_MS, } from "@/app/constant"; import { useAccessStore, useAppConfig, useChatStore } from "@/app/store"; @@ -58,9 +58,7 @@ export class DoubaoApi implements LLMApi { if (baseUrl.length === 0) { const isApp = !!getClientConfig()?.isApp; - baseUrl = isApp - ? DEFAULT_API_HOST + "/api/proxy/bytedance" - : ApiPath.ByteDance; + baseUrl = isApp ? BYTEDANCE_BASE_URL : ApiPath.ByteDance; } if (baseUrl.endsWith("/")) { @@ -94,9 +92,10 @@ export class DoubaoApi implements LLMApi { }, }; + const shouldStream = !!options.config.stream; const requestPayload: RequestPayload = { messages, - stream: options.config.stream, + stream: shouldStream, model: modelConfig.model, temperature: modelConfig.temperature, presence_penalty: modelConfig.presence_penalty, @@ -104,9 +103,6 @@ export class DoubaoApi implements LLMApi { top_p: modelConfig.top_p, }; - console.log("[Request] ByteDance payload: ", requestPayload); - - const shouldStream = !!options.config.stream; const controller = new AbortController(); options.onController?.(controller); diff --git a/app/components/chat.tsx b/app/components/chat.tsx index b1bdf757f44..ace404c1082 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -467,6 +467,14 @@ export function ChatActions(props: { return filteredModels; } }, [allModels]); + const currentModelName = useMemo(() => { + const model = models.find( + (m) => + m.name == currentModel && + m?.provider?.providerName == currentProviderName, + ); + return model?.displayName ?? ""; + }, [models, currentModel, currentProviderName]); const [showModelSelector, setShowModelSelector] = useState(false); const [showUploadImage, setShowUploadImage] = useState(false); @@ -489,7 +497,11 @@ export function ChatActions(props: { session.mask.modelConfig.providerName = nextModel?.provider ?.providerName as ServiceProvider; }); - showToast(nextModel.name); + showToast( + nextModel?.provider?.providerName == "ByteDance" + ? nextModel.displayName + : nextModel.name, + ); } }, [chatStore, currentModel, models]); @@ -571,7 +583,7 @@ export function ChatActions(props: { setShowModelSelector(true)} - text={currentModel} + text={currentModelName} icon={} /> @@ -596,7 +608,15 @@ export function ChatActions(props: { providerName as ServiceProvider; session.mask.syncGlobalConfig = false; }); - showToast(model); + if (providerName == "ByteDance") { + const selectedModel = models.find( + (m) => + m.name == model && m?.provider.providerName == providerName, + ); + showToast(selectedModel?.displayName ?? ""); + } else { + showToast(model); + } }} /> )} From 1149d455890bfb73df98026d9fad11ecbfa88e52 Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Tue, 9 Jul 2024 18:06:59 +0800 Subject: [PATCH 18/27] remove check vision model --- app/client/platforms/bytedance.ts | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/app/client/platforms/bytedance.ts b/app/client/platforms/bytedance.ts index ce401e68dd5..7677cafe12b 100644 --- a/app/client/platforms/bytedance.ts +++ b/app/client/platforms/bytedance.ts @@ -21,7 +21,7 @@ import { } from "@fortaine/fetch-event-source"; import { prettyObject } from "@/app/utils/format"; import { getClientConfig } from "@/app/config/client"; -import { getMessageTextContent, isVisionModel } from "@/app/utils"; +import { getMessageTextContent } from "@/app/utils"; export interface OpenAIListModelResponse { object: string; @@ -78,10 +78,9 @@ export class DoubaoApi implements LLMApi { } async chat(options: ChatOptions) { - const visionModel = isVisionModel(options.config.model); const messages = options.messages.map((v) => ({ role: v.role, - content: visionModel ? v.content : getMessageTextContent(v), + content: getMessageTextContent(v), })); const modelConfig = { From 9d2a633f5e900c67343797a92de41635cdcbe25d Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Tue, 9 Jul 2024 18:15:43 +0800 Subject: [PATCH 19/27] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E6=96=87=E6=A1=A3?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E9=85=8D=E7=BD=AE=E8=B1=86=E5=8C=85=E7=9A=84?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 11 +++++++++++ README_CN.md | 11 +++++++++++ 2 files changed, 22 insertions(+) diff --git a/README.md b/README.md index 467bfbbe0f6..0815b723f62 100644 --- a/README.md +++ b/README.md @@ -225,6 +225,14 @@ Baidu Secret Key. Baidu Api Url. +### `BYTEDANCE_API_KEY` (optional) + +ByteDance Api Key. + +### `BYTEDANCE_URL` (optional) + +ByteDance Api Url. + ### `HIDE_USER_API_KEY` (optional) > Default: Empty @@ -261,6 +269,9 @@ User `-all` to disable all default models, `+all` to enable all default models. For Azure: use `modelName@azure=deploymentName` to customize model name and deployment name. > Example: `+gpt-3.5-turbo@azure=gpt35` will show option `gpt35(Azure)` in model list. +For ByteDance: use `modelName@bytedance=deploymentName` to customize model name and deployment name. +> Example: `+Doubao-lite-4k@bytedance=ep-xxxxx-xxx` will show option `Doubao-lite-4k(ByteDance)` in model list. + ### `DEFAULT_MODEL` (optional) Change default model diff --git a/README_CN.md b/README_CN.md index e6c4d2011d8..321efe441dd 100644 --- a/README_CN.md +++ b/README_CN.md @@ -139,6 +139,14 @@ Baidu Secret Key. Baidu Api Url. +### `BYTEDANCE_API_KEY` (可选) + +ByteDance Api Key. + +### `BYTEDANCE_URL` (可选) + +ByteDance Api Url. + ### `HIDE_USER_API_KEY` (可选) 如果你不想让用户自行填入 API Key,将此环境变量设置为 1 即可。 @@ -172,6 +180,9 @@ Baidu Api Url. 在Azure的模式下,支持使用`modelName@azure=deploymentName`的方式配置模型名称和部署名称(deploy-name) > 示例:`+gpt-3.5-turbo@azure=gpt35`这个配置会在模型列表显示一个`gpt35(Azure)`的选项 +在ByteDance的模式下,支持使用`modelName@bytedance=deploymentName`的方式配置模型名称和部署名称(deploy-name) +> 示例: `+Doubao-lite-4k@bytedance=ep-xxxxx-xxx`这个配置会在模型列表显示一个`Doubao-lite-4k(ByteDance)`的选项 + ### `DEFAULT_MODEL` (可选) From 82be426f78449840158adab56a88aa94dfcfc2c7 Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Tue, 9 Jul 2024 18:19:34 +0800 Subject: [PATCH 20/27] fix eslint error --- app/components/chat.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/components/chat.tsx b/app/components/chat.tsx index ace404c1082..40e02cb57ac 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -611,7 +611,7 @@ export function ChatActions(props: { if (providerName == "ByteDance") { const selectedModel = models.find( (m) => - m.name == model && m?.provider.providerName == providerName, + m.name == model && m?.provider?.providerName == providerName, ); showToast(selectedModel?.displayName ?? ""); } else { From bb349a03dac8e006c4d125779c506efa98283286 Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Tue, 9 Jul 2024 19:21:27 +0800 Subject: [PATCH 21/27] fix get headers for bytedance --- app/client/api.ts | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/app/client/api.ts b/app/client/api.ts index ff81f53721d..147b11ad211 100644 --- a/app/client/api.ts +++ b/app/client/api.ts @@ -179,6 +179,8 @@ export function getHeaders() { const isGoogle = modelConfig.providerName == ServiceProvider.Google; const isAzure = modelConfig.providerName === ServiceProvider.Azure; const isAnthropic = modelConfig.providerName === ServiceProvider.Anthropic; + const isBaidu = modelConfig.providerName == ServiceProvider.Baidu; + const isByteDance = modelConfig.providerName === ServiceProvider.ByteDance; const isEnabledAccessControl = accessStore.enabledAccessControl(); const apiKey = isGoogle ? accessStore.googleApiKey @@ -186,8 +188,18 @@ export function getHeaders() { ? accessStore.azureApiKey : isAnthropic ? accessStore.anthropicApiKey + : isByteDance + ? accessStore.bytedanceApiKey : accessStore.openaiApiKey; - return { isGoogle, isAzure, isAnthropic, apiKey, isEnabledAccessControl }; + return { + isGoogle, + isAzure, + isAnthropic, + isBaidu, + isByteDance, + apiKey, + isEnabledAccessControl, + }; } function getAuthHeader(): string { @@ -203,10 +215,18 @@ export function getHeaders() { function validString(x: string): boolean { return x?.length > 0; } - const { isGoogle, isAzure, isAnthropic, apiKey, isEnabledAccessControl } = - getConfig(); + const { + isGoogle, + isAzure, + isAnthropic, + isBaidu, + apiKey, + isEnabledAccessControl, + } = getConfig(); // when using google api in app, not set auth header if (isGoogle && clientConfig?.isApp) return headers; + // when using baidu api in app, not set auth header + if (isBaidu && clientConfig?.isApp) return headers; const authHeader = getAuthHeader(); From 3628d68d9a7eaf6fa9e9f210f382cc88b6769bea Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Tue, 9 Jul 2024 19:56:52 +0800 Subject: [PATCH 22/27] update --- app/api/alibaba/[...path]/route.ts | 7 +------ app/client/api.ts | 4 ++++ app/client/platforms/alibaba.ts | 13 ++++--------- app/constant.ts | 6 +++--- 4 files changed, 12 insertions(+), 18 deletions(-) diff --git a/app/api/alibaba/[...path]/route.ts b/app/api/alibaba/[...path]/route.ts index e30eacbdb7c..b2c42ac7807 100644 --- a/app/api/alibaba/[...path]/route.ts +++ b/app/api/alibaba/[...path]/route.ts @@ -68,10 +68,7 @@ async function request(req: NextRequest) { const controller = new AbortController(); // alibaba use base url or just remove the path - let path = `${req.nextUrl.pathname}`.replaceAll( - ApiPath.Alibaba + "/" + Alibaba.ChatPath, - "", - ); + let path = `${req.nextUrl.pathname}`.replaceAll(ApiPath.Alibaba, ""); let baseUrl = serverConfig.alibabaUrl || ALIBABA_BASE_URL; @@ -153,11 +150,9 @@ async function request(req: NextRequest) { console.error(`[Alibaba] filter`, e); } } - console.log("[Alibaba request]", fetchOptions.headers, req.method); try { const res = await fetch(fetchUrl, fetchOptions); - console.log("[Alibaba response]", res.status, " ", res.headers, res.url); // to prevent browser prompt for credentials const newHeaders = new Headers(res.headers); newHeaders.delete("www-authenticate"); diff --git a/app/client/api.ts b/app/client/api.ts index 6f6ff622248..c0c71480cd0 100644 --- a/app/client/api.ts +++ b/app/client/api.ts @@ -185,6 +185,7 @@ export function getHeaders() { const isAnthropic = modelConfig.providerName === ServiceProvider.Anthropic; const isBaidu = modelConfig.providerName == ServiceProvider.Baidu; const isByteDance = modelConfig.providerName === ServiceProvider.ByteDance; + const isAlibaba = modelConfig.providerName === ServiceProvider.Alibaba; const isEnabledAccessControl = accessStore.enabledAccessControl(); const apiKey = isGoogle ? accessStore.googleApiKey @@ -194,6 +195,8 @@ export function getHeaders() { ? accessStore.anthropicApiKey : isByteDance ? accessStore.bytedanceApiKey + : isAlibaba + ? accessStore.alibabaApiKey : accessStore.openaiApiKey; return { isGoogle, @@ -201,6 +204,7 @@ export function getHeaders() { isAnthropic, isBaidu, isByteDance, + isAlibaba, apiKey, isEnabledAccessControl, }; diff --git a/app/client/platforms/alibaba.ts b/app/client/platforms/alibaba.ts index eefdb017fd0..72126d7287a 100644 --- a/app/client/platforms/alibaba.ts +++ b/app/client/platforms/alibaba.ts @@ -2,7 +2,7 @@ import { ApiPath, Alibaba, - DEFAULT_API_HOST, + ALIBABA_BASE_URL, REQUEST_TIMEOUT_MS, } from "@/app/constant"; import { useAccessStore, useAppConfig, useChatStore } from "@/app/store"; @@ -58,9 +58,7 @@ export class QwenApi implements LLMApi { if (baseUrl.length === 0) { const isApp = !!getClientConfig()?.isApp; - baseUrl = isApp - ? DEFAULT_API_HOST + "/api/proxy/alibaba" - : ApiPath.Alibaba; + baseUrl = isApp ? ALIBABA_BASE_URL : ApiPath.Alibaba; } if (baseUrl.endsWith("/")) { @@ -76,14 +74,13 @@ export class QwenApi implements LLMApi { } extractMessage(res: any) { - return res.choices?.at(0)?.message?.content ?? ""; + return res?.output?.choices?.at(0)?.message?.content ?? ""; } async chat(options: ChatOptions) { - const visionModel = isVisionModel(options.config.model); const messages = options.messages.map((v) => ({ role: v.role, - content: visionModel ? v.content : getMessageTextContent(v), + content: getMessageTextContent(v), })); const modelConfig = { @@ -104,8 +101,6 @@ export class QwenApi implements LLMApi { top_p: modelConfig.top_p, }; - console.log("[Request] Alibaba payload: ", requestPayload); - const shouldStream = !!options.config.stream; const controller = new AbortController(); options.onController?.(controller); diff --git a/app/constant.ts b/app/constant.ts index 8b5bd23061a..c07adad25be 100644 --- a/app/constant.ts +++ b/app/constant.ts @@ -19,8 +19,7 @@ export const BAIDU_OATUH_URL = `${BAIDU_BASE_URL}/oauth/2.0/token`; export const BYTEDANCE_BASE_URL = "https://ark.cn-beijing.volces.com"; -export const ALIBABA_BASE_URL = - "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"; +export const ALIBABA_BASE_URL = "https://dashscope.aliyuncs.com/api/"; export enum Path { Home = "/", @@ -144,7 +143,8 @@ export const ByteDance = { }; export const Alibaba = { - ChatPath: "chat/completions", + ExampleEndpoint: ALIBABA_BASE_URL, + ChatPath: "v1/services/aigc/text-generation/generation", }; export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang From 7573a19dc91749ee1246e1df33950be87ef74c58 Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Tue, 9 Jul 2024 20:01:58 +0800 Subject: [PATCH 23/27] add custom settings --- app/components/settings.tsx | 46 +++++++++++++++++++++++++++++++++++++ app/locales/cn.ts | 11 +++++++++ app/locales/en.ts | 11 +++++++++ 3 files changed, 68 insertions(+) diff --git a/app/components/settings.tsx b/app/components/settings.tsx index 7db09940d4b..3d77a26317b 100644 --- a/app/components/settings.tsx +++ b/app/components/settings.tsx @@ -54,6 +54,7 @@ import { Anthropic, Azure, Baidu, + ByteDance, Google, OPENAI_BASE_URL, Path, @@ -1249,6 +1250,51 @@ export function Settings() { )} + + {accessStore.provider === ServiceProvider.ByteDance && ( + <> + + + accessStore.update( + (access) => + (access.bytedanceUrl = e.currentTarget.value), + ) + } + > + + + { + accessStore.update( + (access) => + (access.bytedanceApiKey = + e.currentTarget.value), + ); + }} + /> + + + )} )} diff --git a/app/locales/cn.ts b/app/locales/cn.ts index d7268807cb3..d605268703a 100644 --- a/app/locales/cn.ts +++ b/app/locales/cn.ts @@ -363,6 +363,17 @@ const cn = { SubTitle: "样例:", }, }, + ByteDance: { + ApiKey: { + Title: "接口密钥", + SubTitle: "使用自定义 ByteDance API Key", + Placeholder: "ByteDance API Key", + }, + Endpoint: { + Title: "接口地址", + SubTitle: "样例:", + }, + }, CustomModel: { Title: "自定义模型名", SubTitle: "增加自定义模型可选项,使用英文逗号隔开", diff --git a/app/locales/en.ts b/app/locales/en.ts index 3c0d8851fa4..136a5bbaccb 100644 --- a/app/locales/en.ts +++ b/app/locales/en.ts @@ -350,6 +350,17 @@ const en: LocaleType = { SubTitle: "Example:", }, }, + ByteDance: { + ApiKey: { + Title: "ByteDance API Key", + SubTitle: "Use a custom ByteDance API Key", + Placeholder: "ByteDance API Key", + }, + Endpoint: { + Title: "Endpoint Address", + SubTitle: "Example:", + }, + }, CustomModel: { Title: "Custom Models", SubTitle: "Custom model options, seperated by comma", From e3b3a4fefa64efe3c0d49faa403709900729dc23 Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Tue, 9 Jul 2024 20:09:03 +0800 Subject: [PATCH 24/27] add custom settings --- app/components/settings.tsx | 45 +++++++++++++++++++++++++++++++++++++ app/locales/cn.ts | 11 +++++++++ app/locales/en.ts | 11 +++++++++ 3 files changed, 67 insertions(+) diff --git a/app/components/settings.tsx b/app/components/settings.tsx index 3d77a26317b..ba119d1a0f0 100644 --- a/app/components/settings.tsx +++ b/app/components/settings.tsx @@ -55,6 +55,7 @@ import { Azure, Baidu, ByteDance, + Alibaba, Google, OPENAI_BASE_URL, Path, @@ -1295,6 +1296,50 @@ export function Settings() { )} + + {accessStore.provider === ServiceProvider.Alibaba && ( + <> + + + accessStore.update( + (access) => + (access.alibabaUrl = e.currentTarget.value), + ) + } + > + + + { + accessStore.update( + (access) => + (access.alibabaApiKey = e.currentTarget.value), + ); + }} + /> + + + )} )} diff --git a/app/locales/cn.ts b/app/locales/cn.ts index d605268703a..728bdbc5968 100644 --- a/app/locales/cn.ts +++ b/app/locales/cn.ts @@ -374,6 +374,17 @@ const cn = { SubTitle: "样例:", }, }, + Alibaba: { + ApiKey: { + Title: "接口密钥", + SubTitle: "使用自定义阿里云API Key", + Placeholder: "Alibaba Cloud API Key", + }, + Endpoint: { + Title: "接口地址", + SubTitle: "样例:", + }, + }, CustomModel: { Title: "自定义模型名", SubTitle: "增加自定义模型可选项,使用英文逗号隔开", diff --git a/app/locales/en.ts b/app/locales/en.ts index 136a5bbaccb..f18f5a19e97 100644 --- a/app/locales/en.ts +++ b/app/locales/en.ts @@ -361,6 +361,17 @@ const en: LocaleType = { SubTitle: "Example:", }, }, + Alibaba: { + ApiKey: { + Title: "Alibaba API Key", + SubTitle: "Use a custom Alibaba Cloud API Key", + Placeholder: "Alibaba Cloud API Key", + }, + Endpoint: { + Title: "Endpoint Address", + SubTitle: "Example:", + }, + }, CustomModel: { Title: "Custom Models", SubTitle: "Custom model options, seperated by comma", From 814aaa4a69daec51f05f7308714b417598d69c45 Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Tue, 9 Jul 2024 20:15:20 +0800 Subject: [PATCH 25/27] update config for alibaba(qwen) --- README.md | 8 ++++++++ README_CN.md | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/README.md b/README.md index 0815b723f62..24967c16403 100644 --- a/README.md +++ b/README.md @@ -233,6 +233,14 @@ ByteDance Api Key. ByteDance Api Url. +### `ALIBABA_API_KEY` (optional) + +Alibaba Cloud Api Key. + +### `ALIBABA_URL` (optional) + +Alibaba Cloud Api Url. + ### `HIDE_USER_API_KEY` (optional) > Default: Empty diff --git a/README_CN.md b/README_CN.md index 321efe441dd..5400bb276fa 100644 --- a/README_CN.md +++ b/README_CN.md @@ -147,6 +147,14 @@ ByteDance Api Key. ByteDance Api Url. +### `ALIBABA_API_KEY` (可选) + +阿里云(千问)Api Key. + +### `ALIBABA_URL` (可选) + +阿里云(千问)Api Url. + ### `HIDE_USER_API_KEY` (可选) 如果你不想让用户自行填入 API Key,将此环境变量设置为 1 即可。 From cd4784c54a213fd38f5b4d8c3093814f29b9e7fa Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Tue, 9 Jul 2024 21:14:38 +0800 Subject: [PATCH 26/27] update --- app/components/settings.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/components/settings.tsx b/app/components/settings.tsx index 3d77a26317b..4d19fa76ed6 100644 --- a/app/components/settings.tsx +++ b/app/components/settings.tsx @@ -1256,7 +1256,7 @@ export function Settings() { From 044c16da4ccb00c28f6de71f68adcb20bde5f3ea Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Tue, 9 Jul 2024 21:17:32 +0800 Subject: [PATCH 27/27] update --- app/components/settings.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/components/settings.tsx b/app/components/settings.tsx index ba119d1a0f0..1467f706bc6 100644 --- a/app/components/settings.tsx +++ b/app/components/settings.tsx @@ -1302,7 +1302,7 @@ export function Settings() {