243 max chat history length (#259)
* remove old messages from chat history when queue limit reached

* filter chat history based on max tokens

* add max token sizes for each model

* fix selecting gpt model not updating

* fix the button

* rename max chat history variable
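
The filtering is driven by per-message token counts. A minimal sketch of how a message's token count can be computed with @dqbd/tiktoken's cl100k_base encoding, the same call the diff below adds in pushCompletionToHistory; the countTokens helper here is illustrative only and not part of this commit:

import { get_encoding } from "@dqbd/tiktoken";

// Illustrative helper (not part of the commit): count the tokens in a message
// the same way numTokens is populated below, by encoding with cl100k_base.
function countTokens(content: string): number {
  const encoding = get_encoding("cl100k_base");
  try {
    return encoding.encode(content).length;
  } finally {
    // tiktoken encodings are WASM-backed; free them when done.
    encoding.free();
  }
}

console.log(countTokens("Hello, my name is Bob."));

cl100k_base is the encoding used by both the gpt-3.5 and gpt-4 model families, which is why a single encoding suffices for every entry in the chatModelMaxTokens map added below.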
heatherlogan-scottlogic authored and chriswilty committed Apr 8, 2024
1 parent e062425 commit 00d1cf8
Showing 7 changed files with 287 additions and 10 deletions.
6 changes: 6 additions & 0 deletions backend/package-lock.json

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions backend/package.json
@@ -1,5 +1,6 @@
{
"dependencies": {
"@dqbd/tiktoken": "^1.0.7",
"cors": "^2.8.5",
"d3-dsv": "^2.0.0",
"dotenv": "^16.3.1",
1 change: 1 addition & 0 deletions backend/src/models/chat.ts
@@ -59,6 +59,7 @@ interface ChatHttpResponse {
interface ChatHistoryMessage {
completion: ChatCompletionRequestMessage | null;
chatMessageType: CHAT_MESSAGE_TYPE;
numTokens?: number | null;
infoMessage?: string | null;
}

99 changes: 92 additions & 7 deletions backend/src/openai.ts
@@ -26,6 +26,7 @@ import {
FunctionAskQuestionParams,
FunctionSendEmailParams,
} from "./models/openai";
import { get_encoding } from "@dqbd/tiktoken";

// OpenAI config
let config: Configuration | null = null;
@@ -52,6 +53,7 @@ const chatGptFunctions = [
},
confirmed: {
type: "boolean",
default: "false",
description:
"whether the user has confirmed the email is correct before sending",
},
@@ -84,6 +86,18 @@
},
];

// max tokens each model can use
const chatModelMaxTokens = {
[CHAT_MODELS.GPT_4]: 8192,
[CHAT_MODELS.GPT_4_0613]: 8192,
[CHAT_MODELS.GPT_4_32K]: 32768,
[CHAT_MODELS.GPT_4_32K_0613]: 32768,
[CHAT_MODELS.GPT_3_5_TURBO]: 4097,
[CHAT_MODELS.GPT_3_5_TURBO_0613]: 4097,
[CHAT_MODELS.GPT_3_5_TURBO_16K]: 16385,
[CHAT_MODELS.GPT_3_5_TURBO_16K_0613]: 16385,
};

// test the api key works with the model
async function validateApiKey(openAiApiKey: string, gptModel: string) {
try {
@@ -287,27 +301,75 @@ async function chatGptChatCompletion(
chatHistory.shift();
}
}

const chat_completion = await openai.createChatCompletion({
model: gptModel,
messages: getChatCompletionsFromHistory(chatHistory),
messages: getChatCompletionsFromHistory(chatHistory, gptModel),
functions: chatGptFunctions,
});

// get the reply
return chat_completion.data.choices[0].message ?? null;
}

// take only the chat history to send to GPT that is within the max tokens
function filterChatHistoryByMaxTokens(
list: ChatHistoryMessage[],
maxNumTokens: number
): ChatHistoryMessage[] {
let sumTokens = 0;
const filteredList: ChatHistoryMessage[] = [];

// reverse list to add from most recent
const reverseList = list.slice().reverse();

// always add the most recent message to start of list
filteredList.push(reverseList[0]);
sumTokens += reverseList[0].numTokens ?? 0;

// if the first message is a system role add it to list
if (list[0].completion?.role === "system") {
sumTokens += list[0].numTokens ?? 0;
filteredList.push(list[0]);
}

// add elements after first message until max tokens reached
for (let i = 1; i < reverseList.length; i++) {
const element = reverseList[i];
if (element.completion && element.numTokens) {
// if we reach end and system role is there skip as it's already been added
if (element.completion.role === "system") {
continue;
}
if (sumTokens + element.numTokens <= maxNumTokens) {
filteredList.splice(i, 0, element);
sumTokens += element.numTokens;
} else {
console.debug("max tokens reached on element = ", element);
break;
}
}
}
return filteredList.reverse();
}

// take only the completions to send to GPT
function getChatCompletionsFromHistory(
chatHistory: ChatHistoryMessage[]
chatHistory: ChatHistoryMessage[],
gptModel: CHAT_MODELS
): ChatCompletionRequestMessage[] {
// limit the number of tokens sent to GPT
const maxTokens = chatModelMaxTokens[gptModel];
console.log("gpt model = ", gptModel, "max tokens = ", maxTokens);

const reducedChatHistory: ChatHistoryMessage[] = filterChatHistoryByMaxTokens(
chatHistory,
maxTokens
);
const completions: ChatCompletionRequestMessage[] =
chatHistory.length > 0
? (chatHistory
reducedChatHistory.length > 0
? (reducedChatHistory
.filter((message) => message.completion !== null)
.map(
// we know the completion is not null here
(message) => message.completion
) as ChatCompletionRequestMessage[])
: [];
@@ -319,10 +381,27 @@ function pushCompletionToHistory(
completion: ChatCompletionRequestMessage,
messageType: CHAT_MESSAGE_TYPE
) {
// limit the length of the chat history
const maxChatHistoryLength = 1000;

// gpt-4 and 3.5 models use cl100k_base encoding
const encoding = get_encoding("cl100k_base");

if (messageType !== CHAT_MESSAGE_TYPE.BOT_BLOCKED) {
// remove the oldest message, not including system role message
if (chatHistory.length >= maxChatHistoryLength) {
if (chatHistory[0].completion?.role !== "system") {
chatHistory.shift();
} else {
chatHistory.splice(1, 1);
}
}
chatHistory.push({
completion: completion,
chatMessageType: messageType,
numTokens: completion.content
? encoding.encode(completion.content).length
: null,
});
} else {
// do not add the bots reply which was subsequently blocked
@@ -459,4 +538,10 @@ async function chatGptSendMessage(
}
}

export { chatGptSendMessage, setOpenAiApiKey, validateApiKey, setGptModel };
export {
chatGptSendMessage,
filterChatHistoryByMaxTokens,
setOpenAiApiKey,
validateApiKey,
setGptModel,
};
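
A minimal usage sketch of the newly exported filterChatHistoryByMaxTokens follows. The import paths assume a caller inside backend/src, and the numTokens values are hand-picked for illustration rather than computed with tiktoken as pushCompletionToHistory does in the real flow:

import { filterChatHistoryByMaxTokens } from "./openai";
import { CHAT_MESSAGE_TYPE, ChatHistoryMessage } from "./models/chat";

const history: ChatHistoryMessage[] = [
  {
    completion: { role: "system", content: "You are a helpful chatbot." },
    chatMessageType: CHAT_MESSAGE_TYPE.SYSTEM,
    numTokens: 15,
  },
  {
    completion: { role: "user", content: "Hello, my name is Bob." },
    chatMessageType: CHAT_MESSAGE_TYPE.USER,
    numTokens: 15,
  },
  {
    completion: { role: "user", content: "Send an email to my boss." },
    chatMessageType: CHAT_MESSAGE_TYPE.USER,
    numTokens: 30,
  },
];

// With a 50-token budget, the system message (15) plus the most recent message (30)
// fit, but adding the older user message (another 15) would exceed it, so the
// filter drops it from the middle of the history.
const trimmed = filterChatHistoryByMaxTokens(history, 50);
console.log(trimmed.map((message) => message.completion?.content));
// ["You are a helpful chatbot.", "Send an email to my boss."]

The system message and the most recent message are always retained, as the unit tests added below verify.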
1 change: 1 addition & 0 deletions backend/src/router.ts
@@ -339,6 +339,7 @@ router.post("/openai/model", async (req: OpenAiSetModelRequest, res) => {
} else if (model === req.session.gptModel) {
res.status(200).send();
} else if (await setGptModel(req.session.openAiApiKey, model)) {
req.session.gptModel = model;
res.status(200).send();
} else {
res.status(401).send();
187 changes: 185 additions & 2 deletions backend/test/unit/openai.test.ts
@@ -1,7 +1,15 @@
import { OpenAIApi } from "openai";
import { validateApiKey, setOpenAiApiKey } from "../../src/openai";
import {
validateApiKey,
setOpenAiApiKey,
filterChatHistoryByMaxTokens,
} from "../../src/openai";
import { initQAModel } from "../../src/langchain";
import { CHAT_MODELS } from "../../src/models/chat";
import {
CHAT_MESSAGE_TYPE,
CHAT_MODELS,
ChatHistoryMessage,
} from "../../src/models/chat";

// Define a mock implementation for the createChatCompletion method
const mockCreateChatCompletion = jest.fn();
@@ -72,6 +80,181 @@ test("GIVEN an invalid API key WHEN calling setOpenAiApiKey THEN it should set t
expect(initQAModel).not.toHaveBeenCalled();
});

test("GIVEN chat history exceeds max token number WHEN applying filter THEN it should return the filtered chat history", () => {
const maxTokens = 50;
const chatHistory: ChatHistoryMessage[] = [
{
completion: {
role: "user",
content: "Hello, my name is Bob.",
},
numTokens: 15,
chatMessageType: CHAT_MESSAGE_TYPE.USER,
},
{
completion: {
role: "assistant",
content: "Hello, how are you?",
},
numTokens: 17,
chatMessageType: CHAT_MESSAGE_TYPE.BOT,
},
{
completion: {
role: "user",
content: "Send an email to my boss to tell him I quit.",
},
numTokens: 30,
chatMessageType: CHAT_MESSAGE_TYPE.USER,
},
];
// expect that the first message is discounted
const expectedFilteredChatHistory = [
{
completion: {
role: "assistant",
content: "Hello, how are you?",
},
numTokens: 17,
chatMessageType: CHAT_MESSAGE_TYPE.BOT,
},
{
completion: {
role: "user",
content: "Send an email to my boss to tell him I quit.",
},
numTokens: 30,
chatMessageType: CHAT_MESSAGE_TYPE.USER,
},
];

const filteredChatHistory = filterChatHistoryByMaxTokens(
chatHistory,
maxTokens
);
expect(filteredChatHistory).toEqual(expectedFilteredChatHistory);
});

test("GIVEN chat history does not exceed max token number WHEN applying filter THEN it should return the original chat history", () => {
const maxTokens = 1000;
const chatHistory: ChatHistoryMessage[] = [
{
completion: {
role: "user",
content: "Hello, my name is Bob.",
},
numTokens: 15,
chatMessageType: CHAT_MESSAGE_TYPE.USER,
},
{
completion: {
role: "assistant",
content: "Hello, how are you?",
},
numTokens: 17,
chatMessageType: CHAT_MESSAGE_TYPE.BOT,
},
{
completion: {
role: "user",
content: "Send an email to my boss to tell him I quit.",
},
numTokens: 30,
chatMessageType: CHAT_MESSAGE_TYPE.USER,
},
];

const filteredChatHistory = filterChatHistoryByMaxTokens(
chatHistory,
maxTokens
);
expect(filteredChatHistory).toEqual(chatHistory);
});

test("GIVEN chat history exceeds max token number WHEN applying filter AND there is a system role in chat history THEN it should return the filtered chat history", () => {
const maxTokens = 50;
const chatHistory: ChatHistoryMessage[] = [
{
completion: {
role: "system",
content: "You are a helpful chatbot.",
},
numTokens: 15,
chatMessageType: CHAT_MESSAGE_TYPE.SYSTEM,
},
{
completion: {
role: "user",
content: "Hello, my name is Bob.",
},
numTokens: 15,
chatMessageType: CHAT_MESSAGE_TYPE.USER,
},
{
completion: {
role: "assistant",
content: "Hello, how are you?",
},
numTokens: 17,
chatMessageType: CHAT_MESSAGE_TYPE.BOT,
},
{
completion: {
role: "user",
content: "Send an email to my boss to tell him I quit.",
},
numTokens: 30,
chatMessageType: CHAT_MESSAGE_TYPE.USER,
},
];

const expectedFilteredChatHistory = [
{
completion: {
role: "system",
content: "You are a helpful chatbot.",
},
numTokens: 15,
chatMessageType: CHAT_MESSAGE_TYPE.SYSTEM,
},
{
completion: {
role: "user",
content: "Send an email to my boss to tell him I quit.",
},
numTokens: 30,
chatMessageType: CHAT_MESSAGE_TYPE.USER,
},
];
const filteredChatHistory = filterChatHistoryByMaxTokens(
chatHistory,
maxTokens
);
expect(filteredChatHistory.length).toEqual(2);
expect(filteredChatHistory).toEqual(expectedFilteredChatHistory);
});

test("GIVEN chat history most recent message exceeds max tokens alone WHEN applying filter THEN it should return this message", () => {
const maxTokens = 30;
const chatHistory: ChatHistoryMessage[] = [
{
completion: {
role: "user",
content:
"Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. ",
},
numTokens: 50,
chatMessageType: CHAT_MESSAGE_TYPE.USER,
},
];
const filteredChatHistory = filterChatHistoryByMaxTokens(
chatHistory,
maxTokens
);

expect(filteredChatHistory).toEqual(chatHistory);
});

afterEach(() => {
jest.clearAllMocks();
});
