Skip to content

Commit

Permalink
Merge pull request ChatGPTNextWeb#5173 from ConnectAI-E/feature/dalle
Browse files Browse the repository at this point in the history
add dalle3 model
  • Loading branch information
Dogtiti authored Aug 5, 2024
2 parents a6b7432 + 4a8e85c commit fec80c6
Show file tree
Hide file tree
Showing 9 changed files with 141 additions and 33 deletions.
3 changes: 2 additions & 1 deletion app/client/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import {
ServiceProvider,
} from "../constant";
import { ChatMessage, ModelType, useAccessStore, useChatStore } from "../store";
import { ChatGPTApi } from "./platforms/openai";
import { ChatGPTApi, DalleRequestPayload } from "./platforms/openai";
import { GeminiProApi } from "./platforms/google";
import { ClaudeApi } from "./platforms/anthropic";
import { ErnieApi } from "./platforms/baidu";
Expand Down Expand Up @@ -42,6 +42,7 @@ export interface LLMConfig {
stream?: boolean;
presence_penalty?: number;
frequency_penalty?: number;
size?: DalleRequestPayload["size"];
}

export interface ChatOptions {
Expand Down
115 changes: 84 additions & 31 deletions app/client/platforms/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,13 @@ import {
} from "@/app/constant";
import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
import { collectModelsWithDefaultModel } from "@/app/utils/model";
import { preProcessImageContent } from "@/app/utils/chat";
import {
preProcessImageContent,
uploadImage,
base64Image2Blob,
} from "@/app/utils/chat";
import { cloudflareAIGatewayUrl } from "@/app/utils/cloudflare";
import { DalleSize } from "@/app/typing";

import {
ChatOptions,
Expand All @@ -33,6 +38,7 @@ import {
getMessageTextContent,
getMessageImages,
isVisionModel,
isDalle3 as _isDalle3,
} from "@/app/utils";

export interface OpenAIListModelResponse {
Expand All @@ -58,6 +64,14 @@ export interface RequestPayload {
max_tokens?: number;
}

export interface DalleRequestPayload {
model: string;
prompt: string;
response_format: "url" | "b64_json";
n: number;
size: DalleSize;
}

export class ChatGPTApi implements LLMApi {
private disableListModels = true;

Expand Down Expand Up @@ -100,20 +114,31 @@ export class ChatGPTApi implements LLMApi {
return cloudflareAIGatewayUrl([baseUrl, path].join("/"));
}

extractMessage(res: any) {
return res.choices?.at(0)?.message?.content ?? "";
async extractMessage(res: any) {
if (res.error) {
return "```\n" + JSON.stringify(res, null, 4) + "\n```";
}
// dalle3 model return url, using url create image message
if (res.data) {
let url = res.data?.at(0)?.url ?? "";
const b64_json = res.data?.at(0)?.b64_json ?? "";
if (!url && b64_json) {
// uploadImage
url = await uploadImage(base64Image2Blob(b64_json, "image/png"));
}
return [
{
type: "image_url",
image_url: {
url,
},
},
];
}
return res.choices?.at(0)?.message?.content ?? res;
}

async chat(options: ChatOptions) {
const visionModel = isVisionModel(options.config.model);
const messages: ChatOptions["messages"] = [];
for (const v of options.messages) {
const content = visionModel
? await preProcessImageContent(v.content)
: getMessageTextContent(v);
messages.push({ role: v.role, content });
}

const modelConfig = {
...useAppConfig.getState().modelConfig,
...useChatStore.getState().currentSession().mask.modelConfig,
Expand All @@ -123,26 +148,52 @@ export class ChatGPTApi implements LLMApi {
},
};

const requestPayload: RequestPayload = {
messages,
stream: options.config.stream,
model: modelConfig.model,
temperature: modelConfig.temperature,
presence_penalty: modelConfig.presence_penalty,
frequency_penalty: modelConfig.frequency_penalty,
top_p: modelConfig.top_p,
// max_tokens: Math.max(modelConfig.max_tokens, 1024),
// Please do not ask me why not send max_tokens, no reason, this param is just shit, I dont want to explain anymore.
};
let requestPayload: RequestPayload | DalleRequestPayload;

const isDalle3 = _isDalle3(options.config.model);
if (isDalle3) {
const prompt = getMessageTextContent(
options.messages.slice(-1)?.pop() as any,
);
requestPayload = {
model: options.config.model,
prompt,
// URLs are only valid for 60 minutes after the image has been generated.
response_format: "b64_json", // using b64_json, and save image in CacheStorage
n: 1,
size: options.config?.size ?? "1024x1024",
};
} else {
const visionModel = isVisionModel(options.config.model);
const messages: ChatOptions["messages"] = [];
for (const v of options.messages) {
const content = visionModel
? await preProcessImageContent(v.content)
: getMessageTextContent(v);
messages.push({ role: v.role, content });
}

// add max_tokens to vision model
if (visionModel && modelConfig.model.includes("preview")) {
requestPayload["max_tokens"] = Math.max(modelConfig.max_tokens, 4000);
requestPayload = {
messages,
stream: options.config.stream,
model: modelConfig.model,
temperature: modelConfig.temperature,
presence_penalty: modelConfig.presence_penalty,
frequency_penalty: modelConfig.frequency_penalty,
top_p: modelConfig.top_p,
// max_tokens: Math.max(modelConfig.max_tokens, 1024),
// Please do not ask me why not send max_tokens, no reason, this param is just shit, I dont want to explain anymore.
};

// add max_tokens to vision model
if (visionModel && modelConfig.model.includes("preview")) {
requestPayload["max_tokens"] = Math.max(modelConfig.max_tokens, 4000);
}
}

console.log("[Request] openai payload: ", requestPayload);

const shouldStream = !!options.config.stream;
const shouldStream = !isDalle3 && !!options.config.stream;
const controller = new AbortController();
options.onController?.(controller);

Expand All @@ -168,13 +219,15 @@ export class ChatGPTApi implements LLMApi {
model?.provider?.providerName === ServiceProvider.Azure,
);
chatPath = this.path(
Azure.ChatPath(
(isDalle3 ? Azure.ImagePath : Azure.ChatPath)(
(model?.displayName ?? model?.name) as string,
useCustomConfig ? useAccessStore.getState().azureApiVersion : "",
),
);
} else {
chatPath = this.path(OpenaiPath.ChatPath);
chatPath = this.path(
isDalle3 ? OpenaiPath.ImagePath : OpenaiPath.ChatPath,
);
}
const chatPayload = {
method: "POST",
Expand All @@ -186,7 +239,7 @@ export class ChatGPTApi implements LLMApi {
// make a fetch request
const requestTimeoutId = setTimeout(
() => controller.abort(),
REQUEST_TIMEOUT_MS,
isDalle3 ? REQUEST_TIMEOUT_MS * 2 : REQUEST_TIMEOUT_MS, // dalle3 using b64_json is slow.
);

if (shouldStream) {
Expand Down Expand Up @@ -317,7 +370,7 @@ export class ChatGPTApi implements LLMApi {
clearTimeout(requestTimeoutId);

const resJson = await res.json();
const message = this.extractMessage(resJson);
const message = await this.extractMessage(resJson);
options.onFinish(message);
}
} catch (e) {
Expand Down
35 changes: 35 additions & 0 deletions app/components/chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import AutoIcon from "../icons/auto.svg";
import BottomIcon from "../icons/bottom.svg";
import StopIcon from "../icons/pause.svg";
import RobotIcon from "../icons/robot.svg";
import SizeIcon from "../icons/size.svg";
import PluginIcon from "../icons/plugin.svg";

import {
Expand All @@ -60,13 +61,15 @@ import {
getMessageTextContent,
getMessageImages,
isVisionModel,
isDalle3,
} from "../utils";

import { uploadImage as uploadImageRemote } from "@/app/utils/chat";

import dynamic from "next/dynamic";

import { ChatControllerPool } from "../client/controller";
import { DalleSize } from "../typing";
import { Prompt, usePromptStore } from "../store/prompt";
import Locale from "../locales";

Expand Down Expand Up @@ -481,6 +484,11 @@ export function ChatActions(props: {
const [showPluginSelector, setShowPluginSelector] = useState(false);
const [showUploadImage, setShowUploadImage] = useState(false);

const [showSizeSelector, setShowSizeSelector] = useState(false);
const dalle3Sizes: DalleSize[] = ["1024x1024", "1792x1024", "1024x1792"];
const currentSize =
chatStore.currentSession().mask.modelConfig?.size ?? "1024x1024";

useEffect(() => {
const show = isVisionModel(currentModel);
setShowUploadImage(show);
Expand Down Expand Up @@ -624,6 +632,33 @@ export function ChatActions(props: {
/>
)}

{isDalle3(currentModel) && (
<ChatAction
onClick={() => setShowSizeSelector(true)}
text={currentSize}
icon={<SizeIcon />}
/>
)}

{showSizeSelector && (
<Selector
defaultSelectedValue={currentSize}
items={dalle3Sizes.map((m) => ({
title: m,
value: m,
}))}
onClose={() => setShowSizeSelector(false)}
onSelection={(s) => {
if (s.length === 0) return;
const size = s[0];
chatStore.updateCurrentSession((session) => {
session.mask.modelConfig.size = size;
});
showToast(size);
}}
/>
)}

<ChatAction
onClick={() => setShowPluginSelector(true)}
text={Locale.Plugin.Name}
Expand Down
7 changes: 6 additions & 1 deletion app/constant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ export const Anthropic = {

export const OpenaiPath = {
ChatPath: "v1/chat/completions",
ImagePath: "v1/images/generations",
UsagePath: "dashboard/billing/usage",
SubsPath: "dashboard/billing/subscription",
ListModelPath: "v1/models",
Expand All @@ -154,7 +155,10 @@ export const OpenaiPath = {
export const Azure = {
ChatPath: (deployName: string, apiVersion: string) =>
`deployments/${deployName}/chat/completions?api-version=${apiVersion}`,
ExampleEndpoint: "https://{resource-url}/openai/deployments/{deploy-id}",
// https://<your_resource_name>.openai.azure.com/openai/deployments/<your_deployment_name>/images/generations?api-version=<api_version>
ImagePath: (deployName: string, apiVersion: string) =>
`deployments/${deployName}/images/generations?api-version=${apiVersion}`,
ExampleEndpoint: "https://{resource-url}/openai",
};

export const Google = {
Expand Down Expand Up @@ -256,6 +260,7 @@ const openaiModels = [
"gpt-4-vision-preview",
"gpt-4-turbo-2024-04-09",
"gpt-4-1106-preview",
"dall-e-3",
];

const googleModels = [
Expand Down
1 change: 1 addition & 0 deletions app/icons/size.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
5 changes: 5 additions & 0 deletions app/store/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import { nanoid } from "nanoid";
import { createPersistStore } from "../utils/store";
import { collectModelsWithDefaultModel } from "../utils/model";
import { useAccessStore } from "./access";
import { isDalle3 } from "../utils";

export type ChatMessage = RequestMessage & {
date: string;
Expand Down Expand Up @@ -541,6 +542,10 @@ export const useChatStore = createPersistStore(
const config = useAppConfig.getState();
const session = get().currentSession();
const modelConfig = session.mask.modelConfig;
// skip summarize when using dalle3?
if (isDalle3(modelConfig.model)) {
return;
}

const api: ClientApi = getClientApi(modelConfig.providerName);

Expand Down
2 changes: 2 additions & 0 deletions app/store/config.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { LLMModel } from "../client/api";
import { DalleSize } from "../typing";
import { getClientConfig } from "../config/client";
import {
DEFAULT_INPUT_TEMPLATE,
Expand Down Expand Up @@ -61,6 +62,7 @@ export const DEFAULT_CONFIG = {
compressMessageLengthThreshold: 1000,
enableInjectSystemPrompts: true,
template: config?.template ?? DEFAULT_INPUT_TEMPLATE,
size: "1024x1024" as DalleSize,
},
};

Expand Down
2 changes: 2 additions & 0 deletions app/typing.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ export interface RequestMessage {
role: MessageRole;
content: string;
}

export type DalleSize = "1024x1024" | "1792x1024" | "1024x1792";
4 changes: 4 additions & 0 deletions app/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,7 @@ export function isVisionModel(model: string) {
visionKeywords.some((keyword) => model.includes(keyword)) || isGpt4Turbo
);
}

export function isDalle3(model: string) {
return "dall-e-3" === model;
}

0 comments on commit fec80c6

Please sign in to comment.