diff --git a/backend/src/controller/chatController.ts b/backend/src/controller/chatController.ts index 1ea284f6b..00d9b9ae6 100644 --- a/backend/src/controller/chatController.ts +++ b/backend/src/controller/chatController.ts @@ -23,7 +23,10 @@ import { Defence } from '@src/models/defence'; import { EmailInfo } from '@src/models/email'; import { LEVEL_NAMES } from '@src/models/level'; import { chatGptSendMessage } from '@src/openai'; -import { pushMessageToHistory } from '@src/utils/chat'; +import { + pushMessageToHistory, + setSystemRoleInChatHistory, +} from '@src/utils/chat'; import { handleChatError } from './handleError'; @@ -233,9 +236,12 @@ async function handleChatToGPT(req: OpenAiChatRequest, res: Response) { ? req.session.chatModel : defaultChatModel; - const currentChatHistory = [ - ...req.session.levelState[currentLevel].chatHistory, - ]; + const currentChatHistory = setSystemRoleInChatHistory( + currentLevel, + req.session.levelState[currentLevel].defences, + req.session.levelState[currentLevel].chatHistory + ); + const defences = [...req.session.levelState[currentLevel].defences]; let levelResult: LevelHandlerResponse; diff --git a/backend/src/openai.ts b/backend/src/openai.ts index cfccff91f..680e231f6 100644 --- a/backend/src/openai.ts +++ b/backend/src/openai.ts @@ -3,14 +3,9 @@ import { ChatCompletionMessageParam, ChatCompletionTool, ChatCompletionMessageToolCall, - ChatCompletionSystemMessageParam, } from 'openai/resources/chat/completions'; -import { - isDefenceActive, - getSystemRole, - getQAPromptFromConfig, -} from './defence'; +import { isDefenceActive, getQAPromptFromConfig } from './defence'; import { sendEmail } from './email'; import { queryDocuments } from './langchain'; import { @@ -250,48 +245,11 @@ async function chatGptCallFunction( async function chatGptChatCompletion( chatHistory: ChatHistoryMessage[], - defences: Defence[], chatModel: ChatModel, - openai: OpenAI, - // default to sandbox - currentLevel: LEVEL_NAMES = LEVEL_NAMES.SANDBOX -): Promise { + openai: OpenAI +) { const updatedChatHistory = [...chatHistory]; - // check if we need to set a system role - // system role is always active on levels - if ( - currentLevel !== LEVEL_NAMES.SANDBOX || - isDefenceActive(DEFENCE_ID.SYSTEM_ROLE, defences) - ) { - const completionConfig: ChatCompletionSystemMessageParam = { - role: 'system', - content: getSystemRole(defences, currentLevel), - }; - - // check to see if there's already a system role - const systemRole = chatHistory.find( - (message) => message.completion?.role === 'system' - ); - if (!systemRole) { - // add the system role to the start of the chat history - updatedChatHistory.unshift({ - completion: completionConfig, - chatMessageType: CHAT_MESSAGE_TYPE.SYSTEM, - }); - } else { - // replace with the latest system role - systemRole.completion = completionConfig; - } - } else { - // remove the system role from the chat history - while ( - updatedChatHistory.length > 0 && - updatedChatHistory[0].completion?.role === 'system' - ) { - updatedChatHistory.shift(); - } - } console.debug('Talking to model: ', JSON.stringify(chatModel)); // get start time @@ -376,7 +334,6 @@ async function performToolCalls( ): Promise { for (const toolCall of toolCalls) { // only tool type supported by openai is function - // eslint-disable-next-line @typescript-eslint/no-unnecessary-condition if (toolCall.type === 'function') { const functionCallReply = await chatGptCallFunction( @@ -385,6 +342,7 @@ async function performToolCalls( toolCall.function, currentLevel ); + // return after getting function reply. may change when we support other tool types. We assume only one function call in toolCalls return { functionCallReply, @@ -417,10 +375,8 @@ async function getFinalReplyAfterAllToolCalls( do { gptReply = await chatGptChatCompletion( updatedChatHistory, - defences, chatModel, - openai, - currentLevel + openai ); updatedChatHistory = gptReply.chatHistory; diff --git a/backend/src/utils/chat.ts b/backend/src/utils/chat.ts index 4dfab5eaa..ce6cf675e 100644 --- a/backend/src/utils/chat.ts +++ b/backend/src/utils/chat.ts @@ -1,4 +1,9 @@ -import { ChatHistoryMessage } from '@src/models/chat'; +import { ChatCompletionSystemMessageParam } from 'openai/resources/chat/completions'; + +import { getSystemRole, isDefenceActive } from '@src/defence'; +import { CHAT_MESSAGE_TYPE, ChatHistoryMessage } from '@src/models/chat'; +import { DEFENCE_ID, Defence } from '@src/models/defence'; +import { LEVEL_NAMES } from '@src/models/level'; function pushMessageToHistory( chatHistory: ChatHistoryMessage[], @@ -21,4 +26,46 @@ function pushMessageToHistory( return updatedChatHistory; } -export { pushMessageToHistory }; +function setSystemRoleInChatHistory( + currentLevel: LEVEL_NAMES, + defences: Defence[], + chatHistory: ChatHistoryMessage[] +) { + const systemRoleNeededInChatHistory = + currentLevel !== LEVEL_NAMES.SANDBOX || + isDefenceActive(DEFENCE_ID.SYSTEM_ROLE, defences); + + if (systemRoleNeededInChatHistory) { + const completionConfig: ChatCompletionSystemMessageParam = { + role: 'system', + content: getSystemRole(defences, currentLevel), + }; + + const existingSystemRole = chatHistory.find( + (message) => message.completion?.role === 'system' + ); + if (!existingSystemRole) { + return [ + { + completion: completionConfig, + chatMessageType: CHAT_MESSAGE_TYPE.SYSTEM, + }, + ...chatHistory, + ]; + } else { + return chatHistory.map((message) => { + if (message.completion?.role === 'system') { + return { ...existingSystemRole, completion: completionConfig }; + } else { + return message; + } + }); + } + } else { + return chatHistory.filter( + (message) => message.completion?.role !== 'system' + ); + } +} + +export { pushMessageToHistory, setSystemRoleInChatHistory }; diff --git a/backend/test/integration/openai.test.ts b/backend/test/integration/openai.test.ts index d851d9499..6719663f9 100644 --- a/backend/test/integration/openai.test.ts +++ b/backend/test/integration/openai.test.ts @@ -1,16 +1,9 @@ import { expect, jest, test, describe } from '@jest/globals'; import { defaultDefences } from '@src/defaultDefences'; -import { activateDefence, configureDefence } from '@src/defence'; -import { - CHAT_MESSAGE_TYPE, - CHAT_MODELS, - ChatHistoryMessage, - ChatModel, -} from '@src/models/chat'; -import { DEFENCE_ID, Defence } from '@src/models/defence'; +import { CHAT_MODELS, ChatHistoryMessage, ChatModel } from '@src/models/chat'; +import { Defence } from '@src/models/defence'; import { chatGptSendMessage } from '@src/openai'; -import { systemRoleDefault } from '@src/promptTemplates'; const mockCreateChatCompletion = jest.fn<() => Promise>>(); @@ -87,251 +80,4 @@ describe('OpenAI Integration Tests', () => { // restore the mock mockCreateChatCompletion.mockRestore(); }); - - test('GIVEN SYSTEM_ROLE defence is active WHEN sending message THEN system role is added to chat history', async () => { - const message = 'Hello'; - const initChatHistory: ChatHistoryMessage[] = []; - const chatModel: ChatModel = { - id: CHAT_MODELS.GPT_4, - configuration: { - temperature: 1, - topP: 1, - frequencyPenalty: 0, - presencePenalty: 0, - }, - }; - - // set the system role prompt - const defences = activateDefence(DEFENCE_ID.SYSTEM_ROLE, defaultDefences); - - // Mock the createChatCompletion function - mockCreateChatCompletion.mockResolvedValueOnce(chatResponseAssistant('Hi')); - - // send the message - const reply = await chatGptSendMessage( - initChatHistory, - defences, - chatModel, - message - ); - - const { chatResponse, chatHistory } = reply; - - expect(reply).toBeDefined(); - expect(chatResponse.completion?.content).toBe('Hi'); - // check the chat history has been updated - expect(chatHistory.length).toBe(1); - // system role is added to the start of the chat history - expect(chatHistory[0].completion?.role).toBe('system'); - expect(chatHistory[0].completion?.content).toBe(systemRoleDefault); - - // restore the mock - mockCreateChatCompletion.mockRestore(); - }); - - test('GIVEN SYSTEM_ROLE defence is active WHEN sending message THEN system role is added to the start of the chat history', async () => { - const message = 'Hello'; - const initChatHistory: ChatHistoryMessage[] = [ - { - completion: { - role: 'user', - content: "I'm a user", - }, - chatMessageType: CHAT_MESSAGE_TYPE.USER, - }, - { - completion: { - role: 'assistant', - content: "I'm an assistant", - }, - chatMessageType: CHAT_MESSAGE_TYPE.BOT, - }, - ]; - const chatModel: ChatModel = { - id: CHAT_MODELS.GPT_4, - configuration: { - temperature: 1, - topP: 1, - frequencyPenalty: 0, - presencePenalty: 0, - }, - }; - // activate the SYSTEM_ROLE defence - const defences = activateDefence(DEFENCE_ID.SYSTEM_ROLE, defaultDefences); - - // Mock the createChatCompletion function - mockCreateChatCompletion.mockResolvedValueOnce(chatResponseAssistant('Hi')); - - // send the message - const reply = await chatGptSendMessage( - initChatHistory, - defences, - chatModel, - message - ); - - const { chatResponse, chatHistory } = reply; - - expect(reply).toBeDefined(); - expect(chatResponse.completion?.content).toBe('Hi'); - // check the chat history has been updated - expect(chatHistory.length).toBe(3); - // system role is added to the start of the chat history - expect(chatHistory[0].completion?.role).toBe('system'); - expect(chatHistory[0].completion?.content).toBe(systemRoleDefault); - // rest of the chat history is in order - expect(chatHistory[1].completion?.role).toBe('user'); - expect(chatHistory[1].completion?.content).toBe("I'm a user"); - expect(chatHistory[2].completion?.role).toBe('assistant'); - expect(chatHistory[2].completion?.content).toBe("I'm an assistant"); - - // restore the mock - mockCreateChatCompletion.mockRestore(); - }); - - test('GIVEN SYSTEM_ROLE defence is inactive WHEN sending message THEN system role is removed from the chat history', async () => { - const message = 'Hello'; - const initChatHistory: ChatHistoryMessage[] = [ - { - completion: { - role: 'system', - content: 'You are a helpful assistant', - }, - chatMessageType: CHAT_MESSAGE_TYPE.SYSTEM, - }, - { - completion: { - role: 'user', - content: "I'm a user", - }, - chatMessageType: CHAT_MESSAGE_TYPE.USER, - }, - { - completion: { - role: 'assistant', - content: "I'm an assistant", - }, - chatMessageType: CHAT_MESSAGE_TYPE.BOT, - }, - ]; - const defences: Defence[] = defaultDefences; - const chatModel: ChatModel = { - id: CHAT_MODELS.GPT_4, - configuration: { - temperature: 1, - topP: 1, - frequencyPenalty: 0, - presencePenalty: 0, - }, - }; - - // Mock the createChatCompletion function - mockCreateChatCompletion.mockResolvedValueOnce(chatResponseAssistant('Hi')); - - // send the message - const reply = await chatGptSendMessage( - initChatHistory, - defences, - chatModel, - message - ); - - const { chatResponse, chatHistory } = reply; - - expect(reply).toBeDefined(); - expect(chatResponse.completion?.content).toBe('Hi'); - // check the chat history has been updated - expect(chatHistory.length).toBe(2); - // system role is removed from the start of the chat history - // rest of the chat history is in order - expect(chatHistory[0].completion?.role).toBe('user'); - expect(chatHistory[0].completion?.content).toBe("I'm a user"); - expect(chatHistory[1].completion?.role).toBe('assistant'); - expect(chatHistory[1].completion?.content).toBe("I'm an assistant"); - - // restore the mock - mockCreateChatCompletion.mockRestore(); - }); - - test( - 'GIVEN SYSTEM_ROLE defence is configured AND the system role is already in the chat history ' + - 'WHEN sending message THEN system role is replaced with default value in the chat history', - async () => { - const message = 'Hello'; - const initChatHistory: ChatHistoryMessage[] = [ - { - completion: { - role: 'system', - content: 'You are a helpful assistant', - }, - chatMessageType: CHAT_MESSAGE_TYPE.SYSTEM, - }, - { - completion: { - role: 'user', - content: "I'm a user", - }, - chatMessageType: CHAT_MESSAGE_TYPE.USER, - }, - { - completion: { - role: 'assistant', - content: "I'm an assistant", - }, - chatMessageType: CHAT_MESSAGE_TYPE.BOT, - }, - ]; - const chatModel: ChatModel = { - id: CHAT_MODELS.GPT_4, - configuration: { - temperature: 1, - topP: 1, - frequencyPenalty: 0, - presencePenalty: 0, - }, - }; - - const defences = configureDefence( - DEFENCE_ID.SYSTEM_ROLE, - activateDefence(DEFENCE_ID.SYSTEM_ROLE, defaultDefences), - [ - { - id: 'SYSTEM_ROLE', - value: 'You are not a helpful assistant', - }, - ] - ); - - // Mock the createChatCompletion function - mockCreateChatCompletion.mockResolvedValueOnce( - chatResponseAssistant('Hi') - ); - - // send the message - const reply = await chatGptSendMessage( - initChatHistory, - defences, - chatModel, - message - ); - - const { chatResponse, chatHistory } = reply; - - expect(reply).toBeDefined(); - expect(chatResponse.completion?.content).toBe('Hi'); - // system role is added to the start of the chat history - expect(chatHistory[0].completion?.role).toBe('system'); - expect(chatHistory[0].completion?.content).toBe( - 'You are not a helpful assistant' - ); - // rest of the chat history is in order - expect(chatHistory[1].completion?.role).toBe('user'); - expect(chatHistory[1].completion?.content).toBe("I'm a user"); - expect(chatHistory[2].completion?.role).toBe('assistant'); - expect(chatHistory[2].completion?.content).toBe("I'm an assistant"); - - // restore the mock - mockCreateChatCompletion.mockRestore(); - } - ); }); diff --git a/backend/test/unit/controller/chatController.test.ts b/backend/test/unit/controller/chatController.test.ts index a894a7426..939e1875d 100644 --- a/backend/test/unit/controller/chatController.test.ts +++ b/backend/test/unit/controller/chatController.test.ts @@ -23,6 +23,10 @@ import { DEFENCE_ID, Defence } from '@src/models/defence'; import { EmailInfo } from '@src/models/email'; import { LEVEL_NAMES, LevelState } from '@src/models/level'; import { chatGptSendMessage } from '@src/openai'; +import { + pushMessageToHistory, + setSystemRoleInChatHistory, +} from '@src/utils/chat'; declare module 'express-session' { interface Session { @@ -43,6 +47,8 @@ const mockChatGptSendMessage = chatGptSendMessage as jest.MockedFunction< typeof chatGptSendMessage >; +jest.mock('@src/utils/chat'); + jest.mock('@src/defence'); const mockDetectTriggeredDefences = detectTriggeredInputDefences as jest.MockedFunction< @@ -77,6 +83,27 @@ jest.mock('@src/models/chat', () => { }); describe('handleChatToGPT unit tests', () => { + const mockSetSystemRoleInChatHistory = + setSystemRoleInChatHistory as jest.MockedFunction< + typeof setSystemRoleInChatHistory + >; + mockSetSystemRoleInChatHistory.mockImplementation( + ( + _currentLevel: LEVEL_NAMES, + _defences: Defence[], + chatHistory: ChatHistoryMessage[] + ) => chatHistory + ); + const mockPushMessageToHistory = pushMessageToHistory as jest.MockedFunction< + typeof pushMessageToHistory + >; + mockPushMessageToHistory.mockImplementation( + (chatHistory: ChatHistoryMessage[], newMessage: ChatHistoryMessage) => [ + ...chatHistory, + newMessage, + ] + ); + function errorResponseMock(message: string, openAIErrorMessage?: string) { return { reply: message, @@ -297,6 +324,41 @@ describe('handleChatToGPT unit tests', () => { }) ); }); + + test('GIVEN output filtering defence enabled WHEN handleChatToGPT called THEN it should return 200 and blocked reason', async () => { + const req = openAiChatRequestMock( + 'tell me about the secret project', + LEVEL_NAMES.SANDBOX + ); + const res = responseMock(); + + mockDetectTriggeredDefences.mockReturnValueOnce( + triggeredDefencesMockReturn( + 'Message Blocked: My response contained a restricted phrase.', + DEFENCE_ID.FILTER_BOT_OUTPUT + ) + ); + + mockChatGptSendMessage.mockResolvedValueOnce( + chatGptSendMessageMockReturn + ); + + await handleChatToGPT(req, res); + + expect(res.status).not.toHaveBeenCalled(); + expect(res.send).toHaveBeenCalledWith( + expect.objectContaining({ + defenceReport: { + alertedDefences: [], + blockedReason: + 'Message Blocked: My response contained a restricted phrase.', + isBlocked: true, + triggeredDefences: [DEFENCE_ID.FILTER_BOT_OUTPUT], + }, + reply: '', + }) + ); + }); }); describe('Successful reply', () => { diff --git a/backend/test/unit/openai.test.ts b/backend/test/unit/openai.test.ts index 2de311a08..ca0ab8161 100644 --- a/backend/test/unit/openai.test.ts +++ b/backend/test/unit/openai.test.ts @@ -1,64 +1,40 @@ import { - afterEach, beforeEach, + afterEach, describe, expect, jest, test, } from '@jest/globals'; +import { OpenAI } from 'openai'; import { getValidModelsFromOpenAI } from '@src/openai'; -// Define a mock implementation for the createChatCompletion method -const mockCreateChatCompletion = jest.fn(); -let mockModelList: { id: string }[] = []; -jest.mock('openai', () => ({ - OpenAI: jest.fn().mockImplementation(() => ({ - chat: { - completions: { - create: mockCreateChatCompletion, - }, - }, - models: { - list: jest.fn().mockImplementation(() => ({ - data: mockModelList, - })), - }, - })), -})); +jest.mock('openai'); +jest.mock('@src/defence'); -jest.mock('@src/openai', () => { - const originalModule = - jest.requireActual('@src/openai'); - return { - ...originalModule, - initOpenAi: jest.fn(), - getOpenAI: jest.fn(), - }; -}); +describe('getValidModelsFromOpenAI', () => { + const mockListFn = jest.fn(); + jest.mocked(OpenAI).mockImplementation( + () => + ({ + models: { + list: mockListFn, + }, + } as unknown as jest.MockedObject) + ); -jest.mock('@src/langchain', () => { - const originalModule = - jest.requireActual('@src/langchain'); - return { - ...originalModule, - initQAModel: jest.fn(), - initDocumentVectors: jest.fn(), - }; -}); + beforeEach(() => { + process.env = {}; + }); -beforeEach(() => { - // clear environment variables - process.env = {}; -}); -afterEach(() => { - mockCreateChatCompletion.mockReset(); -}); + afterEach(() => { + mockListFn.mockReset(); + }); -describe('openAI unit tests', () => { - test('GIVEN the user has an openAI key WHEN getValidModelsFromOpenAI is called THEN it returns the models in CHAT_MODELS enum', async () => { + test('GIVEN the user has an openAI key WHEN getValidModelsFromOpenAI is called THEN it returns only the models that are also in the CHAT_MODELS enum', async () => { process.env.OPENAI_API_KEY = 'sk-12345'; - mockModelList = [ + const mockModelList = [ { id: 'gpt-3.5-turbo' }, { id: 'gpt-3.5-turbo-0613' }, { id: 'gpt-4' }, @@ -72,10 +48,13 @@ describe('openAI unit tests', () => { 'gpt-4', 'gpt-4-0613', ]; + + mockListFn.mockResolvedValueOnce({ + data: mockModelList, + } as OpenAI.ModelsPage); + const validModels = await getValidModelsFromOpenAI(); + expect(validModels).toEqual(expectedValidModels); }); }); -afterEach(() => { - jest.clearAllMocks(); -}); diff --git a/backend/test/unit/utils/chat.test.ts b/backend/test/unit/utils/chat.test.ts deleted file mode 100644 index 5eddf49e2..000000000 --- a/backend/test/unit/utils/chat.test.ts +++ /dev/null @@ -1,134 +0,0 @@ -import { expect, test, describe } from '@jest/globals'; - -import { CHAT_MESSAGE_TYPE, ChatHistoryMessage } from '@src/models/chat'; -import { pushMessageToHistory } from '@src/utils/chat'; - -describe('chat utils unit tests', () => { - const maxChatHistoryLength = 1000; - const systemRoleMessage: ChatHistoryMessage = { - completion: { - role: 'system', - content: 'You are an AI.', - }, - chatMessageType: CHAT_MESSAGE_TYPE.SYSTEM, - }; - const generalChatMessage: ChatHistoryMessage = { - completion: { - role: 'user', - content: 'hello world', - }, - chatMessageType: CHAT_MESSAGE_TYPE.USER, - }; - - test( - 'GIVEN no chat history ' + - 'WHEN adding a new chat message ' + - 'THEN new message is added', - () => { - const chatHistory: ChatHistoryMessage[] = []; - const updatedChatHistory = pushMessageToHistory( - chatHistory, - generalChatMessage - ); - expect(updatedChatHistory.length).toBe(1); - expect(updatedChatHistory[0]).toEqual(generalChatMessage); - } - ); - - test( - 'GIVEN chat history with length < maxChatHistoryLength ' + - 'WHEN adding a new chat message ' + - 'THEN new message is added', - () => { - const chatHistory: ChatHistoryMessage[] = [generalChatMessage]; - const updatedChatHistory = pushMessageToHistory( - chatHistory, - generalChatMessage - ); - expect(updatedChatHistory.length).toBe(2); - expect(updatedChatHistory[1]).toEqual(generalChatMessage); - } - ); - - test( - "GIVEN chat history with length === maxChatHistoryLength AND there's no system role" + - 'WHEN adding a new chat message ' + - 'THEN new message is added AND the oldest message is removed', - () => { - const chatHistory: ChatHistoryMessage[] = new Array( - maxChatHistoryLength - ).fill(generalChatMessage); - const updatedChatHistory = pushMessageToHistory( - chatHistory, - generalChatMessage - ); - expect(updatedChatHistory.length).toBe(maxChatHistoryLength); - expect(updatedChatHistory[0]).toEqual(generalChatMessage); - expect(updatedChatHistory[updatedChatHistory.length - 1]).toEqual( - generalChatMessage - ); - } - ); - - test( - 'GIVEN chat history with length === maxChatHistoryLength AND the oldest message is a system role message ' + - 'WHEN adding a new chat message ' + - 'THEN new message is added AND the oldest non-system-role message is removed', - () => { - const chatHistory: ChatHistoryMessage[] = new Array( - maxChatHistoryLength - ).fill(generalChatMessage); - chatHistory[0] = systemRoleMessage; - const updatedChatHistory = pushMessageToHistory( - chatHistory, - generalChatMessage - ); - expect(updatedChatHistory.length).toBe(maxChatHistoryLength); - expect(updatedChatHistory[0]).toEqual(systemRoleMessage); - expect(updatedChatHistory[updatedChatHistory.length - 1]).toEqual( - generalChatMessage - ); - } - ); - - test( - "GIVEN chat history with length > maxChatHistoryLength AND there's no system role" + - 'WHEN adding a new chat message ' + - 'THEN new message is added AND the oldest messages are removed until the length is maxChatHistoryLength', - () => { - const chatHistory: ChatHistoryMessage[] = new Array( - maxChatHistoryLength + 1 - ).fill(generalChatMessage); - const updatedChatHistory = pushMessageToHistory( - chatHistory, - generalChatMessage - ); - expect(updatedChatHistory.length).toBe(maxChatHistoryLength); - expect(updatedChatHistory[0]).toEqual(generalChatMessage); - expect(updatedChatHistory[updatedChatHistory.length - 1]).toEqual( - generalChatMessage - ); - } - ); - - test( - 'GIVEN chat history with length > maxChatHistoryLength AND the oldest message is a system role message ' + - 'WHEN adding a new chat message ' + - 'THEN new message is added AND the oldest non-system-role messages are removed until the length is maxChatHistoryLength', - () => { - const chatHistory: ChatHistoryMessage[] = new Array( - maxChatHistoryLength + 1 - ).fill(generalChatMessage); - chatHistory[0] = systemRoleMessage; - const updatedChatHistory = pushMessageToHistory( - chatHistory, - generalChatMessage - ); - expect(updatedChatHistory.length).toBe(maxChatHistoryLength); - expect(updatedChatHistory[0]).toEqual(systemRoleMessage); - expect(updatedChatHistory[updatedChatHistory.length - 1]).toEqual( - generalChatMessage - ); - } - ); -}); diff --git a/backend/test/unit/utils/chat.ts/pushMessageToHistory.test.ts b/backend/test/unit/utils/chat.ts/pushMessageToHistory.test.ts new file mode 100644 index 000000000..1af6a96cc --- /dev/null +++ b/backend/test/unit/utils/chat.ts/pushMessageToHistory.test.ts @@ -0,0 +1,132 @@ +import { expect, test } from '@jest/globals'; + +import { CHAT_MESSAGE_TYPE, ChatHistoryMessage } from '@src/models/chat'; +import { pushMessageToHistory } from '@src/utils/chat'; + +const maxChatHistoryLength = 1000; +const systemRoleMessage: ChatHistoryMessage = { + completion: { + role: 'system', + content: 'You are an AI.', + }, + chatMessageType: CHAT_MESSAGE_TYPE.SYSTEM, +}; +const generalChatMessage: ChatHistoryMessage = { + completion: { + role: 'user', + content: 'hello world', + }, + chatMessageType: CHAT_MESSAGE_TYPE.USER, +}; + +test( + 'GIVEN no chat history ' + + 'WHEN adding a new chat message ' + + 'THEN new message is added', + () => { + const chatHistory: ChatHistoryMessage[] = []; + const updatedChatHistory = pushMessageToHistory( + chatHistory, + generalChatMessage + ); + expect(updatedChatHistory.length).toBe(1); + expect(updatedChatHistory[0]).toEqual(generalChatMessage); + } +); + +test( + 'GIVEN chat history with length < maxChatHistoryLength ' + + 'WHEN adding a new chat message ' + + 'THEN new message is added', + () => { + const chatHistory: ChatHistoryMessage[] = [generalChatMessage]; + const updatedChatHistory = pushMessageToHistory( + chatHistory, + generalChatMessage + ); + expect(updatedChatHistory.length).toBe(2); + expect(updatedChatHistory[1]).toEqual(generalChatMessage); + } +); + +test( + "GIVEN chat history with length === maxChatHistoryLength AND there's no system role" + + 'WHEN adding a new chat message ' + + 'THEN new message is added AND the oldest message is removed', + () => { + const chatHistory: ChatHistoryMessage[] = new Array( + maxChatHistoryLength + ).fill(generalChatMessage); + const updatedChatHistory = pushMessageToHistory( + chatHistory, + generalChatMessage + ); + expect(updatedChatHistory.length).toBe(maxChatHistoryLength); + expect(updatedChatHistory[0]).toEqual(generalChatMessage); + expect(updatedChatHistory[updatedChatHistory.length - 1]).toEqual( + generalChatMessage + ); + } +); + +test( + 'GIVEN chat history with length === maxChatHistoryLength AND the oldest message is a system role message ' + + 'WHEN adding a new chat message ' + + 'THEN new message is added AND the oldest non-system-role message is removed', + () => { + const chatHistory: ChatHistoryMessage[] = new Array( + maxChatHistoryLength + ).fill(generalChatMessage); + chatHistory[0] = systemRoleMessage; + const updatedChatHistory = pushMessageToHistory( + chatHistory, + generalChatMessage + ); + expect(updatedChatHistory.length).toBe(maxChatHistoryLength); + expect(updatedChatHistory[0]).toEqual(systemRoleMessage); + expect(updatedChatHistory[updatedChatHistory.length - 1]).toEqual( + generalChatMessage + ); + } +); + +test( + "GIVEN chat history with length > maxChatHistoryLength AND there's no system role" + + 'WHEN adding a new chat message ' + + 'THEN new message is added AND the oldest messages are removed until the length is maxChatHistoryLength', + () => { + const chatHistory: ChatHistoryMessage[] = new Array( + maxChatHistoryLength + 1 + ).fill(generalChatMessage); + const updatedChatHistory = pushMessageToHistory( + chatHistory, + generalChatMessage + ); + expect(updatedChatHistory.length).toBe(maxChatHistoryLength); + expect(updatedChatHistory[0]).toEqual(generalChatMessage); + expect(updatedChatHistory[updatedChatHistory.length - 1]).toEqual( + generalChatMessage + ); + } +); + +test( + 'GIVEN chat history with length > maxChatHistoryLength AND the oldest message is a system role message ' + + 'WHEN adding a new chat message ' + + 'THEN new message is added AND the oldest non-system-role messages are removed until the length is maxChatHistoryLength', + () => { + const chatHistory: ChatHistoryMessage[] = new Array( + maxChatHistoryLength + 1 + ).fill(generalChatMessage); + chatHistory[0] = systemRoleMessage; + const updatedChatHistory = pushMessageToHistory( + chatHistory, + generalChatMessage + ); + expect(updatedChatHistory.length).toBe(maxChatHistoryLength); + expect(updatedChatHistory[0]).toEqual(systemRoleMessage); + expect(updatedChatHistory[updatedChatHistory.length - 1]).toEqual( + generalChatMessage + ); + } +); diff --git a/backend/test/unit/utils/chat.ts/setSystemRoleInChatHistory.test.ts b/backend/test/unit/utils/chat.ts/setSystemRoleInChatHistory.test.ts new file mode 100644 index 000000000..7bacf5bdb --- /dev/null +++ b/backend/test/unit/utils/chat.ts/setSystemRoleInChatHistory.test.ts @@ -0,0 +1,130 @@ +import { afterEach, expect, jest, test } from '@jest/globals'; + +import { isDefenceActive, getSystemRole } from '@src/defence'; +import { ChatHistoryMessage, CHAT_MESSAGE_TYPE } from '@src/models/chat'; +import { Defence, DEFENCE_ID } from '@src/models/defence'; +import { LEVEL_NAMES } from '@src/models/level'; +import { setSystemRoleInChatHistory } from '@src/utils/chat'; + +const systemRolePrompt = 'You are a helpful chatbot that answers questions.'; +const defencesSystemRoleInactive: Defence[] = [ + { + id: DEFENCE_ID.SYSTEM_ROLE, + config: [ + { + id: 'SYSTEM_ROLE', + value: systemRolePrompt, + }, + ], + isActive: false, + isTriggered: false, + }, +]; +const defencesSystemRoleActive = [ + { ...defencesSystemRoleInactive[0], isActive: true }, +]; +const chatHistoryWithoutSystemRole: ChatHistoryMessage[] = [ + { + completion: { role: 'user', content: 'What is two plus two?' }, + chatMessageType: CHAT_MESSAGE_TYPE.USER, + }, + { + completion: { role: 'assistant', content: 'Two plus two equals four.' }, + chatMessageType: CHAT_MESSAGE_TYPE.BOT, + }, +]; + +const chatHistoryWithSystemRole: ChatHistoryMessage[] = [ + { + completion: { role: 'system', content: systemRolePrompt }, + chatMessageType: CHAT_MESSAGE_TYPE.SYSTEM, + }, + ...chatHistoryWithoutSystemRole, +]; + +jest.mock('@src/defence'); +const mockIsDefenceActive = isDefenceActive as jest.MockedFunction< + typeof isDefenceActive +>; + +(getSystemRole as jest.MockedFunction).mockReturnValue( + systemRolePrompt +); + +afterEach(() => { + mockIsDefenceActive.mockReset(); + jest.clearAllMocks(); +}); + +test('GIVEN level 1 AND system role is not in chat history WHEN setSystemRoleInChatHistory is called THEN it adds the system role to the chat history', () => { + const chatHistory = setSystemRoleInChatHistory( + LEVEL_NAMES.LEVEL_1, + defencesSystemRoleActive, + chatHistoryWithoutSystemRole + ); + + expect(chatHistory).toEqual(chatHistoryWithSystemRole); +}); + +test('GIVEN level 1 AND system role is in chat history WHEN setSystemRoleInChatHistory is called THEN no change to the chat history', () => { + const chatHistory = setSystemRoleInChatHistory( + LEVEL_NAMES.LEVEL_1, + defencesSystemRoleActive, + chatHistoryWithSystemRole + ); + + expect(chatHistory).toEqual(chatHistoryWithSystemRole); +}); + +test('GIVEN Sandbox AND system role defence active AND system role is not in chat history WHEN setSystemRoleInChatHistory is called THEN it adds the system role to the chat history', () => { + mockIsDefenceActive.mockImplementation(() => true); + const chatHistory = setSystemRoleInChatHistory( + LEVEL_NAMES.SANDBOX, + defencesSystemRoleActive, + chatHistoryWithoutSystemRole + ); + + expect(chatHistory).toEqual(chatHistoryWithSystemRole); +}); + +test('GIVEN Sandbox AND system role defence active AND outdated system role in in chat history WHEN setSystemRoleInChatHistory is called THEN it updates the system role in the chat history', () => { + mockIsDefenceActive.mockImplementation(() => true); + + const mockChatHistoryWithOutdatedSystemRole: ChatHistoryMessage[] = [ + { + completion: { role: 'system', content: 'Yer a wizard, Harry.' }, + chatMessageType: CHAT_MESSAGE_TYPE.SYSTEM, + }, + ...chatHistoryWithoutSystemRole, + ]; + + const chatHistory = setSystemRoleInChatHistory( + LEVEL_NAMES.SANDBOX, + defencesSystemRoleActive, + mockChatHistoryWithOutdatedSystemRole + ); + + expect(chatHistory).toEqual(chatHistoryWithSystemRole); +}); + +test('GIVEN Sandbox AND system role defence not active AND system role is in chat history WHEN setSystemRoleInChatHistory is called THEN it removes the system role from the chat history', () => { + mockIsDefenceActive.mockImplementation(() => false); + const chatHistory = setSystemRoleInChatHistory( + LEVEL_NAMES.SANDBOX, + defencesSystemRoleActive, + chatHistoryWithSystemRole + ); + + expect(chatHistory).toEqual(chatHistoryWithoutSystemRole); +}); + +test('GIVEN Sandbox AND system role defence not active AND system role is not in chat history WHEN setSystemRoleInChatHistory is called THEN no change to the chat history', () => { + mockIsDefenceActive.mockImplementation(() => false); + const chatHistory = setSystemRoleInChatHistory( + LEVEL_NAMES.SANDBOX, + defencesSystemRoleActive, + chatHistoryWithoutSystemRole + ); + + expect(chatHistory).toEqual(chatHistoryWithoutSystemRole); +});