Merge pull request #5 from yoziru/fix-token-limit
fix token limit
yoziru committed Jul 28, 2024
2 parents 89f7754 + 0bf0193 commit b71e3ef
Showing 10 changed files with 621 additions and 755 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -43,7 +43,7 @@ docker run --rm -d -p 3000:3000 -e VLLM_URL=http://host.docker.internal:8000 ghc

If you're using Ollama, you need to set the `VLLM_MODEL`:
```
docker run --rm -d -p 3000:3000 -e VLLM_URL=http://host.docker.internal:11434 -e NEXT_PUBLIC_TOKEN_LIMIT=8192 -e VLLM_MODEL=llama3 ghcr.io/yoziru/nextjs-vllm-ui:latest
docker run --rm -d -p 3000:3000 -e VLLM_URL=http://host.docker.internal:11434 -e VLLM_TOKEN_LIMIT=8192 -e VLLM_MODEL=llama3 ghcr.io/yoziru/nextjs-vllm-ui:latest
```

Then go to [localhost:3000](http://localhost:3000) and start chatting with your favourite model!
9 changes: 5 additions & 4 deletions package.json
@@ -9,19 +9,19 @@
"lint": "next lint"
},
"dependencies": {
"@ai-sdk/openai": "^0.0.40",
"@hookform/resolvers": "^3.3.4",
"@radix-ui/react-dialog": "^1.0.5",
"@radix-ui/react-icons": "^1.3.0",
"@radix-ui/react-slot": "^1.0.2",
"@radix-ui/react-tabs": "^1.0.4",
"@radix-ui/react-tooltip": "^1.0.7",
"ai": "^3.0.15",
"ai": "^3.2.0",
"class-variance-authority": "^0.7.0",
"clsx": "^2.1.0",
"mistral-tokenizer-js": "^1.0.0",
"llama3-tokenizer-js": "^1.1.3",
"next": "14.1.4",
"next-themes": "^0.3.0",
"openai": "^4.30.0",
"react": "^18",
"react-code-blocks": "^0.1.6",
"react-dom": "^18",
@@ -33,7 +33,8 @@
"tailwind-merge": "^2.2.2",
"use-debounce": "^10.0.0",
"use-local-storage-state": "^19.2.0",
"uuid": "^9.0.1"
"uuid": "^9.0.1",
"zod": "^3.23.8"
},
"devDependencies": {
"@types/node": "^20.11.30",
166 changes: 52 additions & 114 deletions src/app/api/chat/route.ts
@@ -1,23 +1,20 @@
import {
createParser,
ParsedEvent,
ReconnectInterval,
} from "eventsource-parser";
streamText,
CoreMessage,
CoreUserMessage,
CoreSystemMessage,
CoreAssistantMessage,
} from "ai";
import { createOpenAI } from "@ai-sdk/openai";

// Allow streaming responses up to 30 seconds
export const maxDuration = 30;

import { NextRequest, NextResponse } from "next/server";
import {
ChatCompletionAssistantMessageParam,
ChatCompletionCreateParamsStreaming,
ChatCompletionMessageParam,
ChatCompletionSystemMessageParam,
ChatCompletionUserMessageParam,
} from "openai/resources/index.mjs";

import { encodeChat, tokenLimit } from "@/lib/token-counter";

const addSystemMessage = (
messages: ChatCompletionMessageParam[],
systemPrompt?: string
) => {

import { encodeChat } from "@/lib/token-counter";

const addSystemMessage = (messages: CoreMessage[], systemPrompt?: string) => {
// early exit if system prompt is empty
if (!systemPrompt || systemPrompt === "") {
return messages;
@@ -56,37 +53,39 @@ const addSystemMessage = (
};

const formatMessages = (
messages: ChatCompletionMessageParam[]
): ChatCompletionMessageParam[] => {
let mappedMessages: ChatCompletionMessageParam[] = [];
messages: CoreMessage[],
tokenLimit: number = 4096
): CoreMessage[] => {
let mappedMessages: CoreMessage[] = [];
let messagesTokenCounts: number[] = [];
const responseTokens = 512;
const tokenLimitRemaining = tokenLimit - responseTokens;
const reservedResponseTokens = 512;

const tokenLimitRemaining = tokenLimit - reservedResponseTokens;
let tokenCount = 0;

messages.forEach((m) => {
if (m.role === "system") {
mappedMessages.push({
role: "system",
content: m.content,
} as ChatCompletionSystemMessageParam);
} as CoreSystemMessage);
} else if (m.role === "user") {
mappedMessages.push({
role: "user",
content: m.content,
} as ChatCompletionUserMessageParam);
} as CoreUserMessage);
} else if (m.role === "assistant") {
mappedMessages.push({
role: "assistant",
content: m.content,
} as ChatCompletionAssistantMessageParam);
} as CoreAssistantMessage);
} else {
return;
}

// ignore typing
// tslint:disable-next-line
const messageTokens = encodeChat([m]);
const messageTokens = encodeChat([m]);
messagesTokenCounts.push(messageTokens);
tokenCount += messageTokens;
});
@@ -106,7 +105,8 @@ const formatMessages = (
return mappedMessages;
};

export async function POST(req: NextRequest): Promise<NextResponse> {
export async function POST(req: Request) {
// export async function POST(req: NextRequest): Promise<NextResponse> {
try {
const { messages, chatOptions } = await req.json();
if (!chatOptions.selectedModel || chatOptions.selectedModel === "") {
@@ -119,20 +119,34 @@ export async function POST(req: NextRequest): Promise<NextResponse> {
}
const apiKey = process.env.VLLM_API_KEY;

const tokenLimit = process.env.VLLM_TOKEN_LIMIT
? parseInt(process.env.VLLM_TOKEN_LIMIT)
: 4096;

const formattedMessages = formatMessages(
addSystemMessage(messages, chatOptions.systemPrompt)
addSystemMessage(messages, chatOptions.systemPrompt),
tokenLimit
);

const stream = await getOpenAIStream(
baseUrl,
chatOptions.selectedModel,
formattedMessages,
chatOptions.temperature,
apiKey,
);
return new NextResponse(stream, {
headers: { "Content-Type": "text/event-stream" },
// Call the language model
const customOpenai = createOpenAI({
baseUrl: baseUrl + "/v1",
apiKey: apiKey ?? "",
});

const result = await streamText({
model: customOpenai(chatOptions.selectedModel),
messages: formattedMessages,
temperature: chatOptions.temperature,
// async onFinish({ text, toolCalls, toolResults, usage, finishReason }) {
// // implement your own logic here, e.g. for storing messages
// // or recording token usage
// },
});

// Respond with the stream
return result.toAIStreamResponse();

} catch (error) {
console.error(error);
return NextResponse.json(
@@ -144,79 +158,3 @@ export async function POST(req: NextRequest): Promise<NextResponse> {
);
}
}

const getOpenAIStream = async (
apiUrl: string,
model: string,
messages: ChatCompletionMessageParam[],
temperature?: number,
apiKey?: string
): Promise<ReadableStream<Uint8Array>> => {
const encoder = new TextEncoder();
const decoder = new TextDecoder();
const headers = new Headers();
headers.set("Content-Type", "application/json");
if (apiKey !== undefined) {
headers.set("Authorization", `Bearer ${apiKey}`);
headers.set("api-key", apiKey);
}
const chatOptions: ChatCompletionCreateParamsStreaming = {
model: model,
// frequency_penalty: 0,
// max_tokens: 2000,
messages: messages,
// presence_penalty: 0,
stream: true,
temperature: temperature ?? 0.5,
// response_format: {
// type: "json_object",
// }
// top_p: 0.95,
};
const res = await fetch(apiUrl + "/v1/chat/completions", {
headers: headers,
method: "POST",
body: JSON.stringify(chatOptions),
});

if (res.status !== 200) {
const statusText = res.statusText;
const responseBody = await res.text();
console.error(`vLLM API response error: ${responseBody}`);
throw new Error(
`The vLLM API has encountered an error with a status code of ${res.status} ${statusText}: ${responseBody}`
);
}

return new ReadableStream({
async start(controller) {
const onParse = (event: ParsedEvent | ReconnectInterval) => {
if (event.type === "event") {
const data = event.data;

if (data === "[DONE]") {
controller.close();
return;
}

try {
const json = JSON.parse(data);
const text = json.choices[0].delta.content;
const queue = encoder.encode(text);
controller.enqueue(queue);
} catch (e) {
controller.error(e);
}
}
};

const parser = createParser(onParse);

for await (const chunk of res.body as any) {
// An extra newline is required to make AzureOpenAI work.
const str = decoder.decode(chunk).replace("[DONE]\n", "[DONE]\n\n");
parser.feed(str);
}
},
});
};
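Side note on the new route: it reserves 512 tokens for the model's reply (`reservedResponseTokens`) and keeps the prompt inside the configured `tokenLimit`; the loop that actually prunes older messages sits in the collapsed part of the diff above. The sketch below only illustrates that general idea — it is not the project's implementation, and `countTokens`/`trimToBudget` are hypothetical names (the real code uses `encodeChat`).

```ts
import { CoreMessage } from "ai";

// Hypothetical token estimator for illustration only; the project itself uses
// encodeChat from "@/lib/token-counter" (backed by llama3-tokenizer-js).
const countTokens = (m: CoreMessage): number =>
  typeof m.content === "string" ? Math.ceil(m.content.length / 4) + 3 : 3;

// Drop the oldest non-system messages until the prompt fits the remaining budget.
const trimToBudget = (
  messages: CoreMessage[],
  tokenLimit = 4096,
  reservedResponseTokens = 512
): CoreMessage[] => {
  const budget = tokenLimit - reservedResponseTokens;
  const trimmed = [...messages];
  let total = trimmed.reduce((sum, m) => sum + countTokens(m), 0);
  while (total > budget && trimmed.length > 1) {
    const idx = trimmed.findIndex((m) => m.role !== "system");
    if (idx === -1) break;
    total -= countTokens(trimmed[idx]);
    trimmed.splice(idx, 1);
  }
  return trimmed;
};
```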
14 changes: 14 additions & 0 deletions src/app/api/settings/route.ts
@@ -0,0 +1,14 @@
import { NextRequest, NextResponse } from "next/server";

export async function GET(req: NextRequest): Promise<NextResponse> {
const tokenLimit = process.env.VLLM_TOKEN_LIMIT
? parseInt(process.env.VLLM_TOKEN_LIMIT)
: 4096;

return NextResponse.json(
{
tokenLimit: tokenLimit,
},
{ status: 200 }
);
}
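The new settings route exposes the server-side `VLLM_TOKEN_LIMIT` (default 4096) to the browser as JSON. A minimal client call could look like the sketch below — essentially what the `getTokenLimit` helper added to `src/lib/token-counter.ts` does; `fetchTokenLimit` here is just an illustrative name.

```ts
// Sketch: read the configured token limit from the new endpoint.
// Expected response body: { "tokenLimit": 4096 }
const fetchTokenLimit = async (basePath = ""): Promise<number> => {
  const res = await fetch(basePath + "/api/settings");
  if (!res.ok) {
    throw new Error(`Failed to load settings: ${res.status} ${res.statusText}`);
  }
  const data = (await res.json()) as { tokenLimit: number };
  return data.tokenLimit;
};
```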
13 changes: 10 additions & 3 deletions src/components/chat/chat-bottombar.tsx
@@ -4,10 +4,11 @@ import React from "react";

import { PaperPlaneIcon, StopIcon } from "@radix-ui/react-icons";
import { ChatRequestOptions } from "ai";
import mistralTokenizer from "mistral-tokenizer-js";
import llama3Tokenizer from "llama3-tokenizer-js";
import TextareaAutosize from "react-textarea-autosize";

import { tokenLimit } from "@/lib/token-counter";
import { basePath, useHasMounted } from "@/lib/utils";
import { getTokenLimit } from "@/lib/token-counter";
import { Button } from "../ui/button";

interface ChatBottombarProps {
@@ -30,6 +31,7 @@ export default function ChatBottombar({
isLoading,
stop,
}: ChatBottombarProps) {
const hasMounted = useHasMounted();
const inputRef = React.useRef<HTMLTextAreaElement>(null);
const hasSelectedModel = selectedModel && selectedModel !== "";

@@ -39,7 +41,12 @@
handleSubmit(e as unknown as React.FormEvent<HTMLFormElement>);
}
};
const tokenCount = input ? mistralTokenizer.encode(input).length - 1 : 0;
const tokenCount = input ? llama3Tokenizer.encode(input).length - 1 : 0;

const [tokenLimit, setTokenLimit] = React.useState<number>(4096);
React.useEffect(() => {
getTokenLimit(basePath).then((limit) => setTokenLimit(limit));
}, [hasMounted]);

return (
<div>
2 changes: 1 addition & 1 deletion src/components/chat/chat-page.tsx
@@ -28,7 +28,7 @@ export default function ChatPage({ chatId, setChatId }: ChatPageProps) {
setMessages,
} = useChat({
api: basePath + "/api/chat",
streamMode: "text",
streamMode: "stream-data",
onError: (error) => {
toast.error("Something went wrong: " + error);
},
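The `streamMode` change from `"text"` to `"stream-data"` mirrors the server side: `result.toAIStreamResponse()` returns the AI SDK's data-stream format rather than plain text, so the hook must be told to parse it. A simplified sketch of the wiring, assuming the `ai/react` export that ships with `ai@3.2` (the real component also manages chat history, chat options, and an error toast):

```tsx
"use client";

import { useChat } from "ai/react";

export function ChatSketch() {
  const { messages, input, handleInputChange, handleSubmit } = useChat({
    api: "/api/chat",          // the streamText route shown above
    streamMode: "stream-data", // parse the AI SDK data stream from toAIStreamResponse()
    onError: (error) => console.error("Something went wrong:", error),
  });

  return (
    <form onSubmit={handleSubmit}>
      {messages.map((m) => (
        <p key={m.id}>
          {m.role}: {m.content}
        </p>
      ))}
      <input value={input} onChange={handleInputChange} />
    </form>
  );
}
```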
4 changes: 3 additions & 1 deletion src/components/chat/chat-topbar.tsx
@@ -19,7 +19,7 @@ import {
TooltipProvider,
TooltipTrigger,
} from "@/components/ui/tooltip";
import { encodeChat, tokenLimit } from "@/lib/token-counter";
import { encodeChat, getTokenLimit } from "@/lib/token-counter";
import { basePath, useHasMounted } from "@/lib/utils";
import { Sidebar } from "../sidebar";
import { ChatOptions } from "./chat-options";
@@ -44,6 +44,7 @@ export default function ChatTopbar({
const hasMounted = useHasMounted();

const currentModel = chatOptions && chatOptions.selectedModel;
const [tokenLimit, setTokenLimit] = React.useState<number>(4096);
const [error, setError] = React.useState<string | undefined>(undefined);

const fetchData = async () => {
@@ -72,6 +73,7 @@

useEffect(() => {
fetchData();
getTokenLimit(basePath).then((limit) => setTokenLimit(limit));
}, [hasMounted]);

if (!hasMounted) {
3 changes: 0 additions & 3 deletions src/lib/mistral-tokenizer-js.d.ts

This file was deleted.

25 changes: 16 additions & 9 deletions src/lib/token-counter.ts
@@ -1,19 +1,26 @@
import { Message } from "ai";
import mistralTokenizer from "mistral-tokenizer-js";
import { ChatCompletionMessageParam } from "openai/resources/index.mjs";
import { CoreMessage, Message } from "ai";
import llama3Tokenizer from "llama3-tokenizer-js";

export const tokenLimit = process.env.NEXT_PUBLIC_TOKEN_LIMIT ? parseInt(process.env.NEXT_PUBLIC_TOKEN_LIMIT) : 4096;
export const getTokenLimit = async (basePath: string) => {
const res = await fetch(basePath + "/api/settings");

export const encodeChat = (
messages: Message[] | ChatCompletionMessageParam[]
): number => {
if (!res.ok) {
const errorResponse = await res.json();
const errorMessage = `Connection to vLLM server failed: ${errorResponse.error} [${res.status} ${res.statusText}]`;
throw new Error(errorMessage);
}
const data = await res.json();
return data.tokenLimit;
};

export const encodeChat = (messages: Message[] | CoreMessage[]): number => {
const tokensPerMessage = 3;
let numTokens = 0;
for (const message of messages) {
numTokens += tokensPerMessage;
numTokens += mistralTokenizer.encode(message.role).length;
numTokens += llama3Tokenizer.encode(message.role).length;
if (typeof message.content === "string") {
numTokens += mistralTokenizer.encode(message.content).length;
numTokens += llama3Tokenizer.encode(message.content).length;
}
}
numTokens += 3;
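Taken together, `getTokenLimit` and `encodeChat` let callers check whether a conversation still fits the configured context window. A minimal sketch, reusing the 512-token response reserve from the chat route (`fitsInContext` is an illustrative helper, not part of the codebase):

```ts
import { CoreMessage } from "ai";
import { encodeChat, getTokenLimit } from "@/lib/token-counter";

const fitsInContext = async (messages: CoreMessage[]): Promise<boolean> => {
  const tokenLimit = await getTokenLimit(""); // basePath; "" when the app is served from the root
  const reservedResponseTokens = 512;         // same reserve as the chat route
  const promptTokens = encodeChat(messages);  // llama3-tokenizer-based estimate
  return promptTokens <= tokenLimit - reservedResponseTokens;
};
```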