243 max chat history length #259

Merged: 6 commits, Sep 13, 2023
6 changes: 6 additions & 0 deletions backend/package-lock.json

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions backend/package.json
@@ -1,5 +1,6 @@
{
"dependencies": {
"@dqbd/tiktoken": "^1.0.7",
"cors": "^2.8.5",
"d3-dsv": "^2.0.0",
"dotenv": "^16.3.1",
1 change: 1 addition & 0 deletions backend/src/models/chat.ts
@@ -59,6 +59,7 @@ interface ChatHttpResponse {
interface ChatHistoryMessage {
completion: ChatCompletionRequestMessage | null;
chatMessageType: CHAT_MESSAGE_TYPE;
numTokens?: number | null;
infoMessage?: string | null;
}

99 changes: 92 additions & 7 deletions backend/src/openai.ts
@@ -26,6 +26,7 @@ import {
FunctionAskQuestionParams,
FunctionSendEmailParams,
} from "./models/openai";
import { get_encoding } from "@dqbd/tiktoken";

// OpenAI config
let config: Configuration | null = null;
@@ -52,6 +53,7 @@ const chatGptFunctions = [
},
confirmed: {
type: "boolean",
default: "false",
description:
"whether the user has confirmed the email is correct before sending",
},
@@ -84,6 +86,18 @@
},
];

// max tokens each model can use
const chatModelMaxTokens = {
[CHAT_MODELS.GPT_4]: 8192,
[CHAT_MODELS.GPT_4_0613]: 8192,
[CHAT_MODELS.GPT_4_32K]: 32768,
[CHAT_MODELS.GPT_4_32K_0613]: 32768,
[CHAT_MODELS.GPT_3_5_TURBO]: 4097,
[CHAT_MODELS.GPT_3_5_TURBO_0613]: 4097,
[CHAT_MODELS.GPT_3_5_TURBO_16K]: 16385,
[CHAT_MODELS.GPT_3_5_TURBO_16K_0613]: 16385,
};
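
These limits appear to correspond to the documented context-window size of each model. A minimal sketch (not part of this diff) of how such a map might be consumed, using a hypothetical getMaxTokensForModel helper and an assumed fallback value:

// Illustrative only: the helper and fallback below are hypothetical, not in this PR.
const DEFAULT_MAX_TOKENS = 4097; // assumed fallback: the smallest window in the map above

function getMaxTokensForModel(model: CHAT_MODELS): number {
  // look the model up in the map, falling back to the conservative default
  return chatModelMaxTokens[model] ?? DEFAULT_MAX_TOKENS;
}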

// test the api key works with the model
async function validateApiKey(openAiApiKey: string, gptModel: string) {
try {
@@ -287,27 +301,75 @@ async function chatGptChatCompletion(
chatHistory.shift();
}
}

const chat_completion = await openai.createChatCompletion({
model: gptModel,
messages: getChatCompletionsFromHistory(chatHistory),
messages: getChatCompletionsFromHistory(chatHistory, gptModel),
functions: chatGptFunctions,
});

// get the reply
return chat_completion.data.choices[0].message ?? null;
}

// take only the chat history to send to GPT that is within the max tokens
function filterChatHistoryByMaxTokens(
list: ChatHistoryMessage[],
maxNumTokens: number
): ChatHistoryMessage[] {
let sumTokens = 0;
const filteredList: ChatHistoryMessage[] = [];

// reverse list to add from most recent
const reverseList = list.slice().reverse();

// always add the most recent message to start of list
filteredList.push(reverseList[0]);
sumTokens += reverseList[0].numTokens ?? 0;

// if the first message is a system role add it to list
if (list[0].completion?.role === "system") {
sumTokens += list[0].numTokens ?? 0;
filteredList.push(list[0]);
}

// add elements after first message until max tokens reached
for (let i = 1; i < reverseList.length; i++) {
const element = reverseList[i];
if (element.completion && element.numTokens) {
// if we reach end and system role is there skip as it's already been added
if (element.completion.role === "system") {
continue;
}
if (sumTokens + element.numTokens <= maxNumTokens) {
filteredList.splice(i, 0, element);
sumTokens += element.numTokens;
} else {
console.debug("max tokens reached on element = ", element);
break;
}
}
}
return filteredList.reverse();
}
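
A minimal usage sketch of the filter (illustrative, not part of this diff), assuming the history entries already carry the numTokens counts computed in pushCompletionToHistory below:

// Trim a history to a 4097-token budget before building the completion request.
const trimmedHistory = filterChatHistoryByMaxTokens(chatHistory, 4097);
console.debug("kept", trimmedHistory.length, "of", chatHistory.length, "messages");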

// take only the completions to send to GPT
function getChatCompletionsFromHistory(
chatHistory: ChatHistoryMessage[]
chatHistory: ChatHistoryMessage[],
gptModel: CHAT_MODELS
): ChatCompletionRequestMessage[] {
// limit the number of tokens sent to GPT
const maxTokens = chatModelMaxTokens[gptModel];
console.log("gpt model = ", gptModel, "max tokens = ", maxTokens);

const reducedChatHistory: ChatHistoryMessage[] = filterChatHistoryByMaxTokens(
chatHistory,
maxTokens
);
const completions: ChatCompletionRequestMessage[] =
chatHistory.length > 0
? (chatHistory
reducedChatHistory.length > 0
? (reducedChatHistory
.filter((message) => message.completion !== null)
.map(
// we know the completion is not null here
(message) => message.completion
) as ChatCompletionRequestMessage[])
: [];
@@ -319,10 +381,27 @@ function pushCompletionToHistory(
completion: ChatCompletionRequestMessage,
messageType: CHAT_MESSAGE_TYPE
) {
// limit the length of the chat history
const maxChatHistoryLength = 1000;

// gpt-4 and 3.5 models use cl100k_base encoding
const encoding = get_encoding("cl100k_base");

if (messageType !== CHAT_MESSAGE_TYPE.BOT_BLOCKED) {
// remove the oldest message, not including system role message
if (chatHistory.length >= maxChatHistoryLength) {
if (chatHistory[0].completion?.role !== "system") {
chatHistory.shift();
} else {
chatHistory.splice(1, 1);
}
}
chatHistory.push({
completion: completion,
chatMessageType: messageType,
numTokens: completion.content
? encoding.encode(completion.content).length
: null,
});
} else {
// do not add the bot's reply, which was subsequently blocked
@@ -459,4 +538,10 @@ async function chatGptSendMessage(
}
}

export { chatGptSendMessage, setOpenAiApiKey, validateApiKey, setGptModel };
export {
chatGptSendMessage,
filterChatHistoryByMaxTokens,
setOpenAiApiKey,
validateApiKey,
setGptModel,
};
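
For reference, a self-contained token-counting sketch (illustrative, not part of this diff) using the same @dqbd/tiktoken API added above; the WASM-backed encoder returned by get_encoding can be released with free() once it is no longer needed:

import { get_encoding } from "@dqbd/tiktoken";

// Count the tokens a string occupies under cl100k_base, the encoding used
// by the gpt-3.5 and gpt-4 model families.
function countTokens(text: string): number {
  const encoding = get_encoding("cl100k_base");
  try {
    return encoding.encode(text).length; // one array entry per token
  } finally {
    encoding.free(); // release the memory held by the encoder
  }
}

console.log(countTokens("Hello, my name is Bob.")); // prints a small count (roughly 7)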
1 change: 1 addition & 0 deletions backend/src/router.ts
@@ -339,6 +339,7 @@ router.post("/openai/model", async (req: OpenAiSetModelRequest, res) => {
} else if (model === req.session.gptModel) {
res.status(200).send();
} else if (await setGptModel(req.session.openAiApiKey, model)) {
req.session.gptModel = model;
res.status(200).send();
} else {
res.status(401).send();
187 changes: 185 additions & 2 deletions backend/test/unit/openai.test.ts
@@ -1,7 +1,15 @@
import { OpenAIApi } from "openai";
import { validateApiKey, setOpenAiApiKey } from "../../src/openai";
import {
validateApiKey,
setOpenAiApiKey,
filterChatHistoryByMaxTokens,
} from "../../src/openai";
import { initQAModel } from "../../src/langchain";
import { CHAT_MODELS } from "../../src/models/chat";
import {
CHAT_MESSAGE_TYPE,
CHAT_MODELS,
ChatHistoryMessage,
} from "../../src/models/chat";

// Define a mock implementation for the createChatCompletion method
const mockCreateChatCompletion = jest.fn();
@@ -72,6 +80,181 @@ test("GIVEN an invalid API key WHEN calling setOpenAiApiKey THEN it should set t
expect(initQAModel).not.toHaveBeenCalled();
});

test("GIVEN chat history exceeds max token number WHEN applying filter THEN it should return the filtered chat history", () => {
const maxTokens = 50;
const chatHistory: ChatHistoryMessage[] = [
{
completion: {
role: "user",
content: "Hello, my name is Bob.",
},
numTokens: 15,
chatMessageType: CHAT_MESSAGE_TYPE.USER,
},
{
completion: {
role: "assistant",
content: "Hello, how are you?",
},
numTokens: 17,
chatMessageType: CHAT_MESSAGE_TYPE.BOT,
},
{
completion: {
role: "user",
content: "Send an email to my boss to tell him I quit.",
},
numTokens: 30,
chatMessageType: CHAT_MESSAGE_TYPE.USER,
},
];
// expect that the first (oldest) message is dropped
const expectedFilteredChatHistory = [
{
completion: {
role: "assistant",
content: "Hello, how are you?",
},
numTokens: 17,
chatMessageType: CHAT_MESSAGE_TYPE.BOT,
},
{
completion: {
role: "user",
content: "Send an email to my boss to tell him I quit.",
},
numTokens: 30,
chatMessageType: CHAT_MESSAGE_TYPE.USER,
},
];

const filteredChatHistory = filterChatHistoryByMaxTokens(
chatHistory,
maxTokens
);
expect(filteredChatHistory).toEqual(expectedFilteredChatHistory);
});

test("GIVEN chat history does not exceed max token number WHEN applying filter THEN it should return the original chat history", () => {
const maxTokens = 1000;
const chatHistory: ChatHistoryMessage[] = [
{
completion: {
role: "user",
content: "Hello, my name is Bob.",
},
numTokens: 15,
chatMessageType: CHAT_MESSAGE_TYPE.USER,
},
{
completion: {
role: "assistant",
content: "Hello, how are you?",
},
numTokens: 17,
chatMessageType: CHAT_MESSAGE_TYPE.BOT,
},
{
completion: {
role: "user",
content: "Send an email to my boss to tell him I quit.",
},
numTokens: 30,
chatMessageType: CHAT_MESSAGE_TYPE.USER,
},
];

const filteredChatHistory = filterChatHistoryByMaxTokens(
chatHistory,
maxTokens
);
expect(filteredChatHistory).toEqual(chatHistory);
});

test("GIVEN chat history exceeds max token number WHEN applying filter AND there is a system role in chat history THEN it should return the filtered chat history", () => {
const maxTokens = 50;
const chatHistory: ChatHistoryMessage[] = [
{
completion: {
role: "system",
content: "You are a helpful chatbot.",
},
numTokens: 15,
chatMessageType: CHAT_MESSAGE_TYPE.SYSTEM,
},
{
completion: {
role: "user",
content: "Hello, my name is Bob.",
},
numTokens: 15,
chatMessageType: CHAT_MESSAGE_TYPE.USER,
},
{
completion: {
role: "assistant",
content: "Hello, how are you?",
},
numTokens: 17,
chatMessageType: CHAT_MESSAGE_TYPE.BOT,
},
{
completion: {
role: "user",
content: "Send an email to my boss to tell him I quit.",
},
numTokens: 30,
chatMessageType: CHAT_MESSAGE_TYPE.USER,
},
];

const expectedFilteredChatHistory = [
{
completion: {
role: "system",
content: "You are a helpful chatbot.",
},
numTokens: 15,
chatMessageType: CHAT_MESSAGE_TYPE.SYSTEM,
},
{
completion: {
role: "user",
content: "Send an email to my boss to tell him I quit.",
},
numTokens: 30,
chatMessageType: CHAT_MESSAGE_TYPE.USER,
},
];
const filteredChatHistory = filterChatHistoryByMaxTokens(
chatHistory,
maxTokens
);
expect(filteredChatHistory.length).toEqual(2);
expect(filteredChatHistory).toEqual(expectedFilteredChatHistory);
});

test("GIVEN chat history most recent message exceeds max tokens alone WHEN applying filter THEN it should return this message", () => {
const maxTokens = 30;
const chatHistory: ChatHistoryMessage[] = [
{
completion: {
role: "user",
content:
"Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. ",
},
numTokens: 50,
chatMessageType: CHAT_MESSAGE_TYPE.USER,
},
];
const filteredChatHistory = filterChatHistoryByMaxTokens(
chatHistory,
maxTokens
);

expect(filteredChatHistory).toEqual(chatHistory);
});

afterEach(() => {
jest.clearAllMocks();
});