diff --git a/backend/src/controller/chatController.ts b/backend/src/controller/chatController.ts index 9b1e01cc2..8571064cf 100644 --- a/backend/src/controller/chatController.ts +++ b/backend/src/controller/chatController.ts @@ -10,7 +10,7 @@ import { OpenAiChatRequest } from '@src/models/api/OpenAiChatRequest'; import { OpenAiClearRequest } from '@src/models/api/OpenAiClearRequest'; import { OpenAiGetHistoryRequest } from '@src/models/api/OpenAiGetHistoryRequest'; import { - ChatDefenceReport, + DefenceReport, ChatHttpResponse, ChatModel, LevelHandlerResponse, @@ -20,7 +20,7 @@ import { import { ChatMessage, ChatInfoMessage, - chatInfoMessageType, + chatInfoMessageTypes, } from '@src/models/chatMessage'; import { Defence } from '@src/models/defence'; import { EmailInfo } from '@src/models/email'; @@ -33,9 +33,7 @@ import { import { handleChatError } from './handleError'; -function combineChatDefenceReports( - reports: ChatDefenceReport[] -): ChatDefenceReport { +function combineDefenceReports(reports: DefenceReport[]): DefenceReport { return { blockedReason: reports .filter((report) => report.blockedReason !== null) @@ -100,17 +98,17 @@ async function handleChatWithoutDefenceDetection( chatHistory: ChatMessage[], defences: Defence[] ): Promise<LevelHandlerResponse> { + console.log(`User message: '${message}'`); + const updatedChatHistory = createNewUserMessages(message).reduce( pushMessageToHistory, chatHistory ); - // get the chatGPT reply const openAiReply = await chatGptSendMessage( updatedChatHistory, defences, chatModel, - message, currentLevel ); @@ -146,11 +144,16 @@ async function handleChatWithDefenceDetection( defences ); + console.log( `User message: '${ messageTransformation?.transformedMessageCombined ?? message }'` ); + const openAiReplyPromise = chatGptSendMessage( chatHistoryWithNewUserMessages, defences, chatModel, - messageTransformation?.transformedMessageCombined ?? message, currentLevel ); @@ -168,7 +171,7 @@ async function handleChatWithDefenceDetection( const defenceReports = outputDefenceReport ?
[inputDefenceReport, outputDefenceReport] : [inputDefenceReport]; - const combinedDefenceReport = combineChatDefenceReports(defenceReports); + const combinedDefenceReport = combineDefenceReports(defenceReports); // if blocked, restore original chat history and add user message to chat history without completion const updatedChatHistory = combinedDefenceReport.isBlocked @@ -196,7 +199,6 @@ async function handleChatWithDefenceDetection( } async function handleChatToGPT(req: OpenAiChatRequest, res: Response) { - // set reply params const initChatResponse: ChatHttpResponse = { reply: '', defenceReport: { @@ -232,9 +234,6 @@ async function handleChatToGPT(req: OpenAiChatRequest, res: Response) { ); return; } - const totalSentEmails: EmailInfo[] = [ - ...req.session.levelState[currentLevel].sentEmails, - ]; // use default model for levels, allow user to select in sandbox const chatModel = @@ -283,7 +282,11 @@ async function handleChatToGPT(req: OpenAiChatRequest, res: Response) { } let updatedChatHistory = levelResult.chatHistory; - totalSentEmails.push(...levelResult.chatResponse.sentEmails); + + const totalSentEmails: EmailInfo[] = [ + ...req.session.levelState[currentLevel].sentEmails, + ...levelResult.chatResponse.sentEmails, + ]; const updatedChatResponse: ChatHttpResponse = { ...initChatResponse, @@ -291,7 +294,6 @@ async function handleChatToGPT(req: OpenAiChatRequest, res: Response) { }; if (updatedChatResponse.defenceReport.isBlocked) { - // chatReponse.reply is empty if blocked updatedChatHistory = pushMessageToHistory(updatedChatHistory, { chatMessageType: 'BOT_BLOCKED', infoMessage: @@ -326,7 +328,6 @@ async function handleChatToGPT(req: OpenAiChatRequest, res: Response) { }); } - // update state req.session.levelState[currentLevel].chatHistory = updatedChatHistory; req.session.levelState[currentLevel].sentEmails = totalSentEmails; @@ -376,7 +377,7 @@ function handleAddInfoToChatHistory( if ( infoMessage && chatMessageType && - chatInfoMessageType.includes(chatMessageType) && + chatInfoMessageTypes.includes(chatMessageType) && level !== undefined && level >= LEVEL_NAMES.LEVEL_1 ) { diff --git a/backend/src/defence.ts b/backend/src/defence.ts index a381470f3..f067c5371 100644 --- a/backend/src/defence.ts +++ b/backend/src/defence.ts @@ -1,7 +1,7 @@ import { defaultDefences } from './defaultDefences'; -import { queryPromptEvaluationModel } from './langchain'; +import { evaluatePrompt } from './langchain'; import { - ChatDefenceReport, + DefenceReport, MessageTransformation, SingleDefenceReport, TransformedChatMessage, @@ -20,14 +20,12 @@ import { } from './promptTemplates'; function activateDefence(id: DEFENCE_ID, defences: Defence[]) { - // return the updated list of defences return defences.map((defence) => defence.id === id ? { ...defence, isActive: true } : defence ); } function deactivateDefence(id: DEFENCE_ID, defences: Defence[]) { - // return the updated list of defences return defences.map((defence) => defence.id === id ? { ...defence, isActive: false } : defence ); @@ -38,7 +36,6 @@ function configureDefence( defences: Defence[], config: DefenceConfigItem[] ): Defence[] { - // return the updated list of defences return defences.map((defence) => defence.id === id ? 
{ ...defence, config } : defence ); @@ -95,7 +92,6 @@ function getFilterList(defences: Defence[], type: DEFENCE_ID) { } function getSystemRole( defences: Defence[], - // by default, use sandbox currentLevel: LEVEL_NAMES = LEVEL_NAMES.SANDBOX ) { switch (currentLevel) { @@ -183,14 +179,12 @@ function escapeXml(unsafe: string) { }); } -// function to detect any XML tags in user input function containsXMLTags(input: string) { const tagRegex = /<\/?[a-zA-Z][\w-]*(?:\b[^>]*\/\s*|[^>]*>|[?]>)/g; const foundTags: string[] = input.match(tagRegex) ?? []; return foundTags.length > 0; } -// apply XML tagging defence to input message function transformXmlTagging( message: string, defences: Defence[] @@ -213,7 +207,6 @@ function generateRandomString(length: number) { ).join(''); } -// apply random sequence enclosure defence to input message function transformRandomSequenceEnclosure( message: string, defences: Defence[] @@ -250,7 +243,6 @@ function combineTransformedMessage(transformedMessage: TransformedChatMessage) { ); } -//apply defence string transformations to original message function transformMessage( message: string, defences: Defence[] @@ -284,7 +276,6 @@ function transformMessage( }; } -// detects triggered defences in original message and blocks the message if necessary async function detectTriggeredInputDefences( message: string, defences: Defence[] @@ -299,7 +290,6 @@ async function detectTriggeredInputDefences( return combineDefenceReports(singleDefenceReports); } -// detects triggered defences in bot output and blocks the message if necessary function detectTriggeredOutputDefences(message: string, defences: Defence[]) { const singleDefenceReports = [detectFilterBotOutput(message, defences)]; return combineDefenceReports(singleDefenceReports); @@ -307,7 +297,7 @@ function detectTriggeredOutputDefences(message: string, defences: Defence[]) { function combineDefenceReports( defenceReports: SingleDefenceReport[] -): ChatDefenceReport { +): DefenceReport { const isBlocked = defenceReports.some((report) => report.blockedReason); const blockedReason = isBlocked ? defenceReports @@ -451,15 +441,16 @@ async function detectEvaluationLLM( ): Promise<SingleDefenceReport> { const defence = DEFENCE_ID.PROMPT_EVALUATION_LLM; // to save money and processing time, and to reduce risk of rate limiting, we only run if defence is active + // this means that, unlike the other defences, the user won't get alerts when the defence is not active, e.g.
"your last prompt would have been blocked by the prompt evaluation LLM" if (isDefenceActive(DEFENCE_ID.PROMPT_EVALUATION_LLM, defences)) { const promptEvalLLMPrompt = getPromptEvalPromptFromConfig(defences); - const evaluationResult = await queryPromptEvaluationModel( + const promptIsMalicious = await evaluatePrompt( message, promptEvalLLMPrompt ); - if (evaluationResult.isMalicious) { + if (promptIsMalicious) { console.debug('LLM evaluation defence active and prompt is malicious.'); return { diff --git a/backend/src/document.ts b/backend/src/document.ts index 9b18dffab..4ac0b04d7 100644 --- a/backend/src/document.ts +++ b/backend/src/document.ts @@ -103,12 +103,11 @@ async function initDocumentVectors() { ); // embed and store the splits - will use env variable for API key - const embeddings = new OpenAIEmbeddings(); const docVector = await MemoryVectorStore.fromDocuments( commonAndLevelDocuments, - embeddings + new OpenAIEmbeddings() ); - // store the document vectors for the level + docVectors.push({ level, docVector, diff --git a/backend/src/langchain.ts b/backend/src/langchain.ts index 1803732e5..8341e07dc 100644 --- a/backend/src/langchain.ts +++ b/backend/src/langchain.ts @@ -4,7 +4,7 @@ import { OpenAI } from 'langchain/llms/openai'; import { PromptTemplate } from 'langchain/prompts'; import { getDocumentVectors } from './document'; -import { CHAT_MODELS, ChatAnswer } from './models/chat'; +import { CHAT_MODELS } from './models/chat'; import { PromptEvaluationChainReply, QaChainReply } from './models/langchain'; import { LEVEL_NAMES } from './models/level'; import { getOpenAIKey, getValidOpenAIModelsList } from './openai'; @@ -23,7 +23,6 @@ function makePromptTemplate( templateNameForLogging: string ): PromptTemplate { if (!configPrompt) { - // use the default Prompt configPrompt = defaultPrompt; } const fullPrompt = `${configPrompt}\n${mainPrompt}`; @@ -40,10 +39,8 @@ function getChatModel() { function initQAModel(level: LEVEL_NAMES, Prompt: string) { const openAIApiKey = getOpenAIKey(); const documentVectors = getDocumentVectors()[level].docVector; - // use gpt-4 if avaliable to apiKey const modelName = getChatModel(); - // initialise model const model = new ChatOpenAI({ modelName, streaming: true, @@ -63,7 +60,6 @@ function initQAModel(level: LEVEL_NAMES, Prompt: string) { function initPromptEvaluationModel(configPromptEvaluationPrompt: string) { const openAIApiKey = getOpenAIKey(); - // use gpt-4 if avaliable to apiKey const modelName = getChatModel(); const promptEvalTemplate = makePromptTemplate( @@ -79,87 +75,75 @@ function initPromptEvaluationModel(configPromptEvaluationPrompt: string) { openAIApiKey, }); - const chain = new LLMChain({ + console.debug(`Prompt evaluation model initialised with model: ${modelName}`); + + return new LLMChain({ llm, prompt: promptEvalTemplate, outputKey: 'promptEvalOutput', }); - - console.debug(`Prompt evaluation model initialised with model: ${modelName}`); - return chain; } -// ask the question and return models answer async function queryDocuments( question: string, Prompt: string, currentLevel: LEVEL_NAMES -) { +): Promise { try { const qaChain = initQAModel(currentLevel, Prompt); - // get start time const startTime = Date.now(); console.debug('Calling QA model...'); const response = (await qaChain.call({ query: question, })) as QaChainReply; - // log the time taken - console.debug(`QA model call took ${Date.now() - startTime}ms`); + console.debug(`QA model call took ${Date.now() - startTime}ms`); console.debug(`QA model response: 
${response.text}`); - const result: ChatAnswer = { - reply: response.text, - questionAnswered: true, - }; - return result; + + return response.text; } catch (error) { console.error('Error calling QA model: ', error); - return { - reply: 'I cannot answer that question right now.', - questionAnswered: false, - }; + return 'I cannot answer that question right now.'; } } -// ask LLM whether the prompt is malicious -async function queryPromptEvaluationModel( - input: string, - promptEvalPrompt: string -) { +async function evaluatePrompt(input: string, promptEvalPrompt: string) { try { console.debug(`Checking '${input}' for malicious prompts`); const promptEvaluationChain = initPromptEvaluationModel(promptEvalPrompt); - // get start time const startTime = Date.now(); console.debug('Calling prompt evaluation model...'); + const response = (await promptEvaluationChain.call({ prompt: input, })) as PromptEvaluationChainReply; - // log the time taken + console.debug( `Prompt evaluation model call took ${Date.now() - startTime}ms` ); - const promptEvaluation = formatEvaluationOutput(response.promptEvalOutput); + const promptEvaluation = interpretEvaluationOutput( + response.promptEvalOutput + ); console.debug(`Prompt evaluation: ${JSON.stringify(promptEvaluation)}`); return promptEvaluation; } catch (error) { console.error('Error calling prompt evaluation model: ', error); - return { isMalicious: false }; + return false; } } -function formatEvaluationOutput(response: string) { +function interpretEvaluationOutput(response: string) { // remove all non-alphanumeric characters const cleanResponse = response.replace(/\W/g, '').toLowerCase(); if (cleanResponse === 'yes' || cleanResponse === 'no') { - return { isMalicious: cleanResponse === 'yes' }; + return cleanResponse === 'yes'; } else { console.debug( `Did not get a valid response from the prompt evaluation model. 
Original response: ${response}` ); - return { isMalicious: false }; + return false; } } -export { queryDocuments, queryPromptEvaluationModel }; +export { queryDocuments, evaluatePrompt }; diff --git a/backend/src/models/chat.ts b/backend/src/models/chat.ts index 155b3b38d..9b9e7d321 100644 --- a/backend/src/models/chat.ts +++ b/backend/src/models/chat.ts @@ -36,7 +36,7 @@ interface ChatModelConfiguration { presencePenalty: number; } -interface ChatDefenceReport { +interface DefenceReport { blockedReason: string | null; isBlocked: boolean; alertedDefences: DEFENCE_ID[]; @@ -61,11 +61,6 @@ interface ToolCallResponse { chatHistory: ChatMessage[]; } -interface ChatAnswer { - reply: string; - questionAnswered: boolean; -} - interface ChatMalicious { isMalicious: boolean; reason: string; @@ -98,7 +93,7 @@ interface MessageTransformation { interface ChatHttpResponse { reply: string; - defenceReport: ChatDefenceReport; + defenceReport: DefenceReport; transformedMessage?: TransformedChatMessage; wonLevel: boolean; isError: boolean; @@ -123,8 +118,7 @@ const defaultChatModel: ChatModel = { }; export type { - ChatAnswer, - ChatDefenceReport, + DefenceReport, ChatGptReply, ChatMalicious, ChatModel, diff --git a/backend/src/models/chatMessage.ts b/backend/src/models/chatMessage.ts index 940983c4c..97962bb2e 100644 --- a/backend/src/models/chatMessage.ts +++ b/backend/src/models/chatMessage.ts @@ -67,4 +67,4 @@ export type { CHAT_INFO_MESSAGE_TYPES, }; -export { chatInfoMessageTypes as chatInfoMessageType }; +export { chatInfoMessageTypes }; diff --git a/backend/src/openai.ts b/backend/src/openai.ts index 219cadee5..553dd0883 100644 --- a/backend/src/openai.ts +++ b/backend/src/openai.ts @@ -113,10 +113,9 @@ const getOpenAIKey = (() => { */ async function getValidModelsFromOpenAI() { try { - const openAI = getOpenAI(); - const models: OpenAI.ModelsPage = await openAI.models.list(); + const models: OpenAI.ModelsPage = await getOpenAI().models.list(); - // get the model ids that are supported by our app + // get the model ids that are supported by our app. Non-chat models like Dall-e and whisper are not supported. const validModels = models.data .map((model) => model.id) .filter((id) => Object.values(CHAT_MODELS).includes(id as CHAT_MODELS)) @@ -132,8 +131,7 @@ async function getValidModelsFromOpenAI() { } function getOpenAI() { - const apiKey = getOpenAIKey(); - return new OpenAI({ apiKey }); + return new OpenAI({ apiKey: getOpenAIKey() }); } function isChatGptFunction(functionName: string) { @@ -152,14 +150,13 @@ async function handleAskQuestionFunction( const configQAPrompt = isDefenceActive(DEFENCE_ID.QA_LLM, defences) ? 
getQAPromptFromConfig(defences) : ''; - return { - reply: ( - await queryDocuments(params.question, configQAPrompt, currentLevel) - ).reply, - }; + return await queryDocuments(params.question, configQAPrompt, currentLevel); } else { - console.error('No arguments provided to askQuestion function'); - return { reply: "Reply with 'I don't know what to ask'" }; + console.error( + 'Incorrect arguments provided to askQuestion function:', + functionCallArgs + ); + return "Reply with 'I don't know what to ask'"; } } @@ -220,12 +217,11 @@ async function chatGptCallFunction( sentEmails.push(...emailFunctionOutput.sentEmails); } } else if (functionName === 'askQuestion') { - const askQuestionFunctionOutput = await handleAskQuestionFunction( + functionReply = await handleAskQuestionFunction( functionCall.arguments, currentLevel, defences ); - functionReply = askQuestionFunctionOutput.reply; } } else { console.error(`Unknown function: ${functionName}`); @@ -245,24 +241,26 @@ async function chatGptCallFunction( async function chatGptChatCompletion( chatHistory: ChatMessage[], chatModel: ChatModel, - openai: OpenAI + openAI: OpenAI ) { const updatedChatHistory = [...chatHistory]; console.debug('Talking to model: ', JSON.stringify(chatModel)); - // get start time const startTime = new Date().getTime(); console.debug('Calling OpenAI chat completion...'); try { - const chat_completion = await openai.chat.completions.create({ + const chat_completion = await openAI.chat.completions.create({ model: chatModel.id, temperature: chatModel.configuration.temperature, top_p: chatModel.configuration.topP, frequency_penalty: chatModel.configuration.frequencyPenalty, presence_penalty: chatModel.configuration.presencePenalty, - messages: getChatCompletionsFromHistory(updatedChatHistory, chatModel.id), + messages: getChatCompletionsInContextWindow( + updatedChatHistory, + chatModel.id + ), tools: chatGptTools, }); console.debug( @@ -293,26 +291,23 @@ async function chatGptChatCompletion( } } -function getChatCompletionsFromHistory( +function getChatCompletionsInContextWindow( chatHistory: ChatMessage[], gptModel: CHAT_MODELS ): ChatCompletionMessageParam[] { - // take only completions to send to model - const completions = chatHistory.reduce( - (result, chatMessage) => { - if ('completion' in chatMessage && chatMessage.completion) { - result.push(chatMessage.completion); - } - return result; - }, - [] - ); + const completions = chatHistory + .map((chatMessage) => + 'completion' in chatMessage ? chatMessage.completion : null + ) + .filter( + (completion) => completion !== null + ) as ChatCompletionMessageParam[]; console.debug( 'Number of tokens in total chat history. prompt_tokens=', countTotalPromptTokens(completions) ); - // limit the number of tokens sent to GPT to fit inside context window + const maxTokens = chatModelMaxTokens[gptModel] * 0.95; // 95% of max tokens to allow for response tokens const reducedCompletions = filterChatHistoryByMaxTokens( completions, @@ -345,7 +340,7 @@ async function performToolCalls( currentLevel ); - // return after getting function reply. may change when we support other tool types. We assume only one function call in toolCalls + // We assume only one function call in toolCalls, and so we return after getting function reply. 
return { functionCallReply, chatHistory: pushMessageToHistory(chatHistory, { @@ -371,20 +366,17 @@ async function getFinalReplyAfterAllToolCalls( const sentEmails = []; let wonLevel = false; - const openai = getOpenAI(); let gptReply: ChatGptReply | null = null; - + const openAI = getOpenAI(); do { gptReply = await chatGptChatCompletion( updatedChatHistory, chatModel, - openai + openAI ); updatedChatHistory = gptReply.chatHistory; - // check if GPT wanted to call a tool if (gptReply.completion?.tool_calls) { - // push the function call to the chat updatedChatHistory = pushMessageToHistory(updatedChatHistory, { completion: gptReply.completion, chatMessageType: 'FUNCTION_CALL', @@ -418,10 +410,10 @@ async function chatGptSendMessage( chatHistory: ChatMessage[], defences: Defence[], chatModel: ChatModel, - message: string, currentLevel: LEVEL_NAMES = LEVEL_NAMES.SANDBOX ) { - console.log(`User message: '${message}'`); + // this method just calls getFinalReplyAfterAllToolCalls then reformats the output. Does it need to exist? + const finalToolCallResponse = await getFinalReplyAfterAllToolCalls( chatHistory, defences, @@ -429,23 +421,21 @@ async function chatGptSendMessage( currentLevel ); - const updatedChatHistory = finalToolCallResponse.chatHistory; - const sentEmails = finalToolCallResponse.sentEmails; - const chatResponse: ChatResponse = { completion: finalToolCallResponse.gptReply.completion, wonLevel: finalToolCallResponse.wonLevel, openAIErrorMessage: finalToolCallResponse.gptReply.openAIErrorMessage, }; - if (!chatResponse.completion?.content || chatResponse.openAIErrorMessage) { - return { chatResponse, chatHistory, sentEmails }; - } + const successfulReply = + chatResponse.completion?.content && !chatResponse.openAIErrorMessage; return { chatResponse, - chatHistory: updatedChatHistory, - sentEmails, + chatHistory: successfulReply + ? finalToolCallResponse.chatHistory + : chatHistory, + sentEmails: finalToolCallResponse.sentEmails, }; } diff --git a/backend/src/utils/token.ts b/backend/src/utils/token.ts index 99b54026c..d431bd6c8 100644 --- a/backend/src/utils/token.ts +++ b/backend/src/utils/token.ts @@ -7,6 +7,7 @@ import { promptTokensEstimate, stringTokens } from 'openai-chat-tokens'; import { CHAT_MODELS } from '@src/models/chat'; import { chatGptTools } from '@src/openai'; +// The size of each model's context window in number of tokens. 
https://platform.openai.com/docs/models const chatModelMaxTokens = { [CHAT_MODELS.GPT_4_TURBO]: 128000, [CHAT_MODELS.GPT_4]: 8192, diff --git a/backend/test/integration/openai.test.ts b/backend/test/integration/openai.test.ts index 2f4072d35..be1d5e5cf 100644 --- a/backend/test/integration/openai.test.ts +++ b/backend/test/integration/openai.test.ts @@ -19,17 +19,14 @@ jest.mock('openai', () => ({ })), })); -// mock the queryPromptEvaluationModel function +// mock the evaluatePrompt function jest.mock('@src/langchain', () => { const originalModule = jest.requireActual('@src/langchain'); return { ...originalModule, - queryPromptEvaluationModel: () => { - return { - isMalicious: false, - reason: '', - }; + evaluatePrompt: () => { + return false; }, }; }); @@ -49,8 +46,15 @@ function chatResponseAssistant(content: string) { describe('OpenAI Integration Tests', () => { test('GIVEN OpenAI initialised WHEN sending message THEN reply is returned', async () => { - const message = 'Hello'; - const initChatHistory: ChatMessage[] = []; + const chatHistoryWithMessage: ChatMessage[] = [ + { + chatMessageType: 'USER', + completion: { + role: 'user', + content: 'Hi', + }, + }, + ]; const defences: Defence[] = defaultDefences; const chatModel: ChatModel = { id: CHAT_MODELS.GPT_4, @@ -65,10 +69,9 @@ describe('OpenAI Integration Tests', () => { mockCreateChatCompletion.mockResolvedValueOnce(chatResponseAssistant('Hi')); const reply = await chatGptSendMessage( - initChatHistory, + chatHistoryWithMessage, defences, - chatModel, - message + chatModel ); expect(reply).toBeDefined(); diff --git a/backend/test/unit/controller/chatController.test.ts b/backend/test/unit/controller/chatController.test.ts index 6416ea713..57cb24384 100644 --- a/backend/test/unit/controller/chatController.test.ts +++ b/backend/test/unit/controller/chatController.test.ts @@ -13,7 +13,7 @@ import { OpenAiChatRequest } from '@src/models/api/OpenAiChatRequest'; import { OpenAiClearRequest } from '@src/models/api/OpenAiClearRequest'; import { OpenAiGetHistoryRequest } from '@src/models/api/OpenAiGetHistoryRequest'; import { - ChatDefenceReport, + DefenceReport, ChatModel, ChatResponse, MessageTransformation, @@ -232,7 +232,7 @@ describe('handleChatToGPT unit tests', () => { function triggeredDefencesMockReturn( blockedReason: string, triggeredDefence: DEFENCE_ID - ): Promise<ChatDefenceReport> { + ): Promise<DefenceReport> { return new Promise((resolve, reject) => { try { resolve({ @@ -240,7 +240,7 @@ isBlocked: true, alertedDefences: [], triggeredDefences: [triggeredDefence], - } as ChatDefenceReport); + } as DefenceReport); } catch (err) { reject(err); } }); @@ -516,7 +516,6 @@ describe('handleChatToGPT unit tests', () => { [...existingHistory, newUserChatMessage], [], mockChatModel, - 'What is the answer to life the universe and everything?', LEVEL_NAMES.LEVEL_1 ); @@ -602,7 +601,7 @@ describe('handleChatToGPT unit tests', () => { isBlocked: false, alertedDefences: [], triggeredDefences: [], - } as ChatDefenceReport); + } as DefenceReport); await handleChatToGPT(req, res); @@ -610,7 +609,6 @@ describe('handleChatToGPT unit tests', () => { [...existingHistory, newUserChatMessage], [], mockChatModel, - 'send an email to bob@example.com saying hi', LEVEL_NAMES.SANDBOX ); @@ -703,7 +701,7 @@ describe('handleChatToGPT unit tests', () => { isBlocked: false, alertedDefences: [], triggeredDefences: [], - } as ChatDefenceReport); + } as DefenceReport); await handleChatToGPT(req, res); @@ -711,7 +709,6 @@ describe('handleChatToGPT unit
tests', () => { [...existingHistory, ...newTransformationChatMessages], [], mockChatModel, - '[pre message] hello bot [post message]', LEVEL_NAMES.SANDBOX ); diff --git a/backend/test/unit/defence.ts/defence.test.ts b/backend/test/unit/defence.ts/defence.test.ts index 9e01c78e4..b50abfb62 100644 --- a/backend/test/unit/defence.ts/defence.test.ts +++ b/backend/test/unit/defence.ts/defence.test.ts @@ -27,9 +27,7 @@ import { jest.mock('@src/langchain'); beforeEach(() => { - jest - .mocked(langchain.queryPromptEvaluationModel) - .mockResolvedValue({ isMalicious: false }); + jest.mocked(langchain.evaluatePrompt).mockResolvedValue(false); }); const botOutputFilterTriggeredResponse = @@ -291,7 +289,7 @@ test('GIVEN the prompt evaluation LLM prompt has not been configured WHEN detect ); await detectTriggeredInputDefences(message, defences); - expect(langchain.queryPromptEvaluationModel).toHaveBeenCalledWith( + expect(langchain.evaluatePrompt).toHaveBeenCalledWith( message, promptEvalPrompt ); @@ -312,7 +310,7 @@ test('GIVEN the prompt evaluation LLM prompt has been configured WHEN detecting ); await detectTriggeredInputDefences(message, defences); - expect(langchain.queryPromptEvaluationModel).toHaveBeenCalledWith( + expect(langchain.evaluatePrompt).toHaveBeenCalledWith( message, newPromptEvalPrompt ); diff --git a/backend/test/unit/langchain.ts/formatPromptEvaluation.test.ts b/backend/test/unit/langchain.ts/formatPromptEvaluation.test.ts deleted file mode 100644 index 10da8ed0c..000000000 --- a/backend/test/unit/langchain.ts/formatPromptEvaluation.test.ts +++ /dev/null @@ -1,53 +0,0 @@ -import { test, expect, jest } from '@jest/globals'; - -import { queryPromptEvaluationModel } from '@src/langchain'; - -const mockPromptEvalChain = { - call: jest.fn<() => Promise<{ promptEvalOutput: string }>>(), -}; - -// mock chains -jest.mock('langchain/chains', () => { - return { - LLMChain: jest.fn().mockImplementation(() => { - return mockPromptEvalChain; - }), - }; -}); - -test('GIVEN prompt evaluation llm responds with a correctly formatted yes decision WHEN we query the llm THEN answers with is malicious', async () => { - mockPromptEvalChain.call.mockResolvedValue({ - promptEvalOutput: 'yes.', - }); - const formattedOutput = await queryPromptEvaluationModel('input', 'prompt'); - - expect(formattedOutput).toEqual({ - isMalicious: true, - }); -}); - -test('GIVEN prompt evaluation llm responds with a correctly formatted no decision WHEN we query the llm THEN answers with is not malicious', async () => { - mockPromptEvalChain.call.mockResolvedValue({ - promptEvalOutput: 'no.', - }); - const formattedOutput = await queryPromptEvaluationModel('input', 'prompt'); - - expect(formattedOutput).toEqual({ - isMalicious: false, - }); -}); - -test('GIVEN prompt evaluation llm responds with an incorrectly formatted decision WHEN we query the llm THEN answers with is not malicious and logs debug message', async () => { - const logSpy = jest.spyOn(console, 'debug'); - - mockPromptEvalChain.call.mockResolvedValue({ - promptEvalOutput: 'Sure is!', - }); - const formattedOutput = await queryPromptEvaluationModel('input', 'prompt'); - - expect(formattedOutput).toEqual({ - isMalicious: false, - }); - expect(logSpy).toHaveBeenCalled(); - logSpy.mockRestore(); -}); diff --git a/backend/test/unit/langchain.ts/initialisePromptEvaluationModel.test.ts b/backend/test/unit/langchain.ts/initialisePromptEvaluationModel.test.ts index 7d0a10696..07fc9f8cd 100644 --- a/backend/test/unit/langchain.ts/initialisePromptEvaluationModel.test.ts 
+++ b/backend/test/unit/langchain.ts/initialisePromptEvaluationModel.test.ts @@ -2,7 +2,7 @@ import { afterEach, test, jest, expect } from '@jest/globals'; import { OpenAI } from 'langchain/llms/openai'; import { PromptTemplate } from 'langchain/prompts'; -import { queryPromptEvaluationModel } from '@src/langchain'; +import { evaluatePrompt } from '@src/langchain'; import { promptEvalContextTemplate, promptEvalPrompt, @@ -43,7 +43,7 @@ afterEach(() => { }); test('WHEN we query the prompt evaluation model THEN it is initialised', async () => { - await queryPromptEvaluationModel('some input', promptEvalPrompt); + await evaluatePrompt('some input', promptEvalPrompt); expect(mockFromTemplate).toHaveBeenCalledTimes(1); expect(mockFromTemplate).toHaveBeenCalledWith( `${promptEvalPrompt}\n${promptEvalContextTemplate}` @@ -53,11 +53,9 @@ test('WHEN we query the prompt evaluation model THEN it is initialised', async ( test('GIVEN the prompt evaluation model is not initialised WHEN it is asked to evaluate an input it returns not malicious', async () => { mockPromptEvalChain.call.mockResolvedValueOnce({ promptEvalOutput: '' }); - const result = await queryPromptEvaluationModel('message', 'Prompt'); + const result = await evaluatePrompt('message', 'Prompt'); - expect(result).toEqual({ - isMalicious: false, - }); + expect(result).toEqual(false); }); test('GIVEN the users api key supports GPT-4 WHEN the prompt evaluation model is initialised THEN it is initialised with GPT-4', async () => { @@ -65,7 +63,7 @@ test('GIVEN the users api key supports GPT-4 WHEN the prompt evaluation model is const prompt = 'this is a test prompt. '; - await queryPromptEvaluationModel('some input', prompt); + await evaluatePrompt('some input', prompt); expect(OpenAI).toHaveBeenCalledWith({ modelName: 'gpt-4', @@ -79,7 +77,7 @@ test('GIVEN the users api key does not support GPT-4 WHEN the prompt evaluation const prompt = 'this is a test prompt. 
'; - await queryPromptEvaluationModel('some input', prompt); + await evaluatePrompt('some input', prompt); expect(OpenAI).toHaveBeenCalledWith({ modelName: 'gpt-3.5-turbo', diff --git a/backend/test/unit/langchain.ts/initialiseQAModel.test.ts b/backend/test/unit/langchain.ts/initialiseQAModel.test.ts index 252d41c80..1c27c1b47 100644 --- a/backend/test/unit/langchain.ts/initialiseQAModel.test.ts +++ b/backend/test/unit/langchain.ts/initialiseQAModel.test.ts @@ -117,7 +117,7 @@ test('GIVEN the QA LLM WHEN a question is asked THEN it is initialised AND it an expect(mockFromLLM).toHaveBeenCalledTimes(1); expect(mockRetrievalQAChain.call).toHaveBeenCalledTimes(1); - expect(answer.reply).toEqual('The CEO is Bill.'); + expect(answer).toEqual('The CEO is Bill.'); }); test('GIVEN the users api key supports GPT-4 WHEN the QA model is initialised THEN it is initialised with GPT-4', async () => { diff --git a/frontend/src/components/ChatBox/ChatBox.tsx b/frontend/src/components/ChatBox/ChatBox.tsx index 74febd351..25fc32f13 100644 --- a/frontend/src/components/ChatBox/ChatBox.tsx +++ b/frontend/src/components/ChatBox/ChatBox.tsx @@ -98,9 +98,7 @@ function ChatBox({ message: response.reply, type: 'ERROR_MSG', }); - } - // add it to the list of messages - else if (response.defenceReport.isBlocked) { + } else if (response.defenceReport.isBlocked) { addChatMessage({ type: 'BOT_BLOCKED', message: response.defenceReport.blockedReason, @@ -111,7 +109,6 @@ function ChatBox({ message: response.reply, }); } - // add altered defences to the chat response.defenceReport.alertedDefences.forEach((triggeredDefence) => { // get user-friendly defence name const defenceName = ALL_DEFENCES.find((defence) => { diff --git a/frontend/src/components/HandbookOverlay/HandbookOverlay.tsx b/frontend/src/components/HandbookOverlay/HandbookOverlay.tsx index ca99d13ef..5c4f7118b 100644 --- a/frontend/src/components/HandbookOverlay/HandbookOverlay.tsx +++ b/frontend/src/components/HandbookOverlay/HandbookOverlay.tsx @@ -5,10 +5,10 @@ import useIsOverflow from '@src/hooks/useIsOverflow'; import { HANDBOOK_PAGES } from '@src/models/handbook'; import { LEVEL_NAMES, LevelSystemRole } from '@src/models/level'; -import HandbookAttacks from './HandbookAttacks'; -import HandbookGlossary from './HandbookGlossary'; import HandbookSpine from './HandbookSpine'; -import HandbookSystemRole from './HandbookSystemRole'; +import HandbookAttacks from './Pages/HandbookAttacks'; +import HandbookGlossary from './Pages/HandbookGlossary'; +import HandbookSystemRole from './Pages/HandbookSystemRole'; import './HandbookOverlay.css'; diff --git a/frontend/src/Attacks.ts b/frontend/src/components/HandbookOverlay/Pages/Attacks.ts similarity index 97% rename from frontend/src/Attacks.ts rename to frontend/src/components/HandbookOverlay/Pages/Attacks.ts index 5d9a77c0d..dde97f945 100644 --- a/frontend/src/Attacks.ts +++ b/frontend/src/components/HandbookOverlay/Pages/Attacks.ts @@ -1,4 +1,4 @@ -import { ATTACK_TYPES, AttackInfo } from './models/attack'; +import { ATTACK_TYPES, AttackInfo } from '@src/models/attack'; const ATTACKS_LEVEL_2: AttackInfo[] = [ { diff --git a/frontend/src/Glossary.ts b/frontend/src/components/HandbookOverlay/Pages/Glossary.ts similarity index 100% rename from frontend/src/Glossary.ts rename to frontend/src/components/HandbookOverlay/Pages/Glossary.ts diff --git a/frontend/src/components/HandbookOverlay/HandbookAttacks.test.tsx b/frontend/src/components/HandbookOverlay/Pages/HandbookAttacks.test.tsx similarity index 96% rename from 
frontend/src/components/HandbookOverlay/HandbookAttacks.test.tsx rename to frontend/src/components/HandbookOverlay/Pages/HandbookAttacks.test.tsx index 2d8e3482e..ebad2aa20 100644 --- a/frontend/src/components/HandbookOverlay/HandbookAttacks.test.tsx +++ b/frontend/src/components/HandbookOverlay/Pages/HandbookAttacks.test.tsx @@ -1,9 +1,9 @@ import { render, screen } from '@testing-library/react'; import { describe, expect, test } from 'vitest'; -import { ATTACKS_LEVEL_2, ATTACKS_LEVEL_3, ATTACKS_ALL } from '@src/Attacks'; import { LEVEL_NAMES } from '@src/models/level'; +import { ATTACKS_LEVEL_2, ATTACKS_LEVEL_3, ATTACKS_ALL } from './Attacks'; import HandbookAttacks from './HandbookAttacks'; describe('HandbookAttacks component tests', () => { diff --git a/frontend/src/components/HandbookOverlay/HandbookAttacks.tsx b/frontend/src/components/HandbookOverlay/Pages/HandbookAttacks.tsx similarity index 94% rename from frontend/src/components/HandbookOverlay/HandbookAttacks.tsx rename to frontend/src/components/HandbookOverlay/Pages/HandbookAttacks.tsx index 1379f5735..72d142a31 100644 --- a/frontend/src/components/HandbookOverlay/HandbookAttacks.tsx +++ b/frontend/src/components/HandbookOverlay/Pages/HandbookAttacks.tsx @@ -1,7 +1,8 @@ -import { ATTACKS_ALL, ATTACKS_LEVEL_2, ATTACKS_LEVEL_3 } from '@src/Attacks'; import { AttackInfo } from '@src/models/attack'; import { LEVEL_NAMES } from '@src/models/level'; +import { ATTACKS_ALL, ATTACKS_LEVEL_2, ATTACKS_LEVEL_3 } from './Attacks'; + import './HandbookPage.css'; function HandbookAttacks({ currentLevel }: { currentLevel: LEVEL_NAMES }) { diff --git a/frontend/src/components/HandbookOverlay/HandbookGlossary.test.tsx b/frontend/src/components/HandbookOverlay/Pages/HandbookGlossary.test.tsx similarity index 98% rename from frontend/src/components/HandbookOverlay/HandbookGlossary.test.tsx rename to frontend/src/components/HandbookOverlay/Pages/HandbookGlossary.test.tsx index 1a67d03b6..d227b7a40 100644 --- a/frontend/src/components/HandbookOverlay/HandbookGlossary.test.tsx +++ b/frontend/src/components/HandbookOverlay/Pages/HandbookGlossary.test.tsx @@ -1,9 +1,9 @@ import { render, screen } from '@testing-library/react'; import { describe, expect, test } from 'vitest'; -import { GLOSSARY } from '@src/Glossary'; import { LEVEL_NAMES } from '@src/models/level'; +import { GLOSSARY } from './Glossary'; import HandbookGlossary from './HandbookGlossary'; describe('HandbookGlossary component tests', () => { diff --git a/frontend/src/components/HandbookOverlay/HandbookGlossary.tsx b/frontend/src/components/HandbookOverlay/Pages/HandbookGlossary.tsx similarity index 96% rename from frontend/src/components/HandbookOverlay/HandbookGlossary.tsx rename to frontend/src/components/HandbookOverlay/Pages/HandbookGlossary.tsx index 23da7c7cb..300f2f14d 100644 --- a/frontend/src/components/HandbookOverlay/HandbookGlossary.tsx +++ b/frontend/src/components/HandbookOverlay/Pages/HandbookGlossary.tsx @@ -1,6 +1,7 @@ -import { GLOSSARY } from '@src/Glossary'; import { LEVEL_NAMES } from '@src/models/level'; +import { GLOSSARY } from './Glossary'; + import './HandbookPage.css'; function HandbookGlossary({ currentLevel }: { currentLevel: LEVEL_NAMES }) { diff --git a/frontend/src/components/HandbookOverlay/HandbookPage.css b/frontend/src/components/HandbookOverlay/Pages/HandbookPage.css similarity index 100% rename from frontend/src/components/HandbookOverlay/HandbookPage.css rename to frontend/src/components/HandbookOverlay/Pages/HandbookPage.css diff --git 
a/frontend/src/components/HandbookOverlay/HandbookSystemRole.test.tsx b/frontend/src/components/HandbookOverlay/Pages/HandbookSystemRole.test.tsx similarity index 100% rename from frontend/src/components/HandbookOverlay/HandbookSystemRole.test.tsx rename to frontend/src/components/HandbookOverlay/Pages/HandbookSystemRole.test.tsx diff --git a/frontend/src/components/HandbookOverlay/HandbookSystemRole.tsx b/frontend/src/components/HandbookOverlay/Pages/HandbookSystemRole.tsx similarity index 100% rename from frontend/src/components/HandbookOverlay/HandbookSystemRole.tsx rename to frontend/src/components/HandbookOverlay/Pages/HandbookSystemRole.tsx diff --git a/frontend/src/models/chat.ts b/frontend/src/models/chat.ts index 51e3d36d2..4fa18c496 100644 --- a/frontend/src/models/chat.ts +++ b/frontend/src/models/chat.ts @@ -43,7 +43,7 @@ interface CustomChatModelConfiguration { max: number; } -interface ChatDefenceReport { +interface DefenceReport { blockedReason: string; isBlocked: boolean; alertedDefences: DEFENCE_ID[]; @@ -65,7 +65,7 @@ interface TransformedChatMessage { interface ChatResponse { reply: string; - defenceReport: ChatDefenceReport; + defenceReport: DefenceReport; transformedMessage?: TransformedChatMessage; wonLevel: boolean; isError: boolean;