From f7ea245d823ce5d0d243a05df31322f62e9350e6 Mon Sep 17 00:00:00 2001 From: Chris Wilton-Magras Date: Fri, 13 Oct 2023 09:38:20 +0100 Subject: [PATCH] Fix jest mock issue, typo, other warnings --- backend/.eslintrc.cjs | 2 +- backend/jest.config.js | 1 + backend/src/defence.ts | 15 +-- backend/src/langchain.ts | 21 ++-- backend/test/unit/defence.test.ts | 25 +++-- backend/test/unit/langchain.test.ts | 164 +++++++++++++++------------- 6 files changed, 116 insertions(+), 112 deletions(-) diff --git a/backend/.eslintrc.cjs b/backend/.eslintrc.cjs index eb3a15ce1..bcc018144 100644 --- a/backend/.eslintrc.cjs +++ b/backend/.eslintrc.cjs @@ -22,7 +22,7 @@ module.exports = { checksVoidReturn: false, }, ], - + "@typescript-eslint/unbound-method": ["error", { ignoreStatic: true }], "func-style": ["error", "declaration"], "prefer-template": "error", }, diff --git a/backend/jest.config.js b/backend/jest.config.js index e576c72de..943bab343 100644 --- a/backend/jest.config.js +++ b/backend/jest.config.js @@ -3,4 +3,5 @@ module.exports = { modulePathIgnorePatterns: ["build", "coverage", "node_modules"], preset: "ts-jest", testEnvironment: "node", + silent: true, }; diff --git a/backend/src/defence.ts b/backend/src/defence.ts index 67a6b2bd7..a17bfd7e5 100644 --- a/backend/src/defence.ts +++ b/backend/src/defence.ts @@ -216,9 +216,7 @@ function getMaliciousPromptEvalPrePromptFromConfig(defences: DefenceInfo[]) { } function isDefenceActive(id: DEFENCE_TYPES, defences: DefenceInfo[]) { - return defences.find((defence) => defence.id === id && defence.isActive) - ? true - : false; + return defences.some((defence) => defence.id === id && defence.isActive); } function generateRandomString(string_length: number) { @@ -259,8 +257,7 @@ function transformRandomSequenceEnclosure( const randomString: string = generateRandomString( Number(getRandomSequenceEnclosureLength(defences)) ); - const introText: string = getRandomSequenceEnclosurePrePrompt(defences); - const transformedMessage: string = introText.concat( + return getRandomSequenceEnclosurePrePrompt(defences).concat( randomString, " {{ ", message, @@ -268,7 +265,6 @@ function transformRandomSequenceEnclosure( randomString, ". " ); - return transformedMessage; } // function to escape XML characters in user input to prevent hacking with XML tagging on @@ -303,12 +299,7 @@ function transformXmlTagging(message: string, defences: DefenceInfo[]) { const prePrompt = getXMLTaggingPrePrompt(defences); const openTag = ""; const closeTag = ""; - const transformedMessage: string = prePrompt.concat( - openTag, - escapeXml(message), - closeTag - ); - return transformedMessage; + return prePrompt.concat(openTag, escapeXml(message), closeTag); } //apply defence string transformations to original message diff --git a/backend/src/langchain.ts b/backend/src/langchain.ts index 6a6d7d6bf..bf282c168 100644 --- a/backend/src/langchain.ts +++ b/backend/src/langchain.ts @@ -64,8 +64,8 @@ async function getDocuments(filePath: string) { chunkSize: 1000, chunkOverlap: 0, }); - const splitDocs = await textSplitter.splitDocuments(docs); - return splitDocs; + + return await textSplitter.splitDocuments(docs); } // choose between the provided preprompt and the default preprompt and prepend it to the main prompt and return the PromptTemplate @@ -81,8 +81,7 @@ function makePromptTemplate( } const fullPrompt = `${configPrePrompt}\n${mainPrompt}`; console.debug(`${templateNameForLogging}: ${fullPrompt}`); - const template: PromptTemplate = PromptTemplate.fromTemplate(fullPrompt); - return template; + return PromptTemplate.fromTemplate(fullPrompt); } // create and store the document vectors for each level @@ -147,7 +146,7 @@ function initQAModel( // initialise the prompt evaluation model function initPromptEvaluationModel( configPromptInjectionEvalPrePrompt: string, - conficMaliciousPromptEvalPrePrompt: string, + configMaliciousPromptEvalPrePrompt: string, openAiApiKey: string ) { if (!openAiApiKey) { @@ -176,7 +175,7 @@ function initPromptEvaluationModel( // create chain to detect malicious prompts const maliciousPromptEvalTemplate = makePromptTemplate( - conficMaliciousPromptEvalPrePrompt, + configMaliciousPromptEvalPrePrompt, maliciousPromptEvalPrePrompt, maliciousPromptEvalMainPrompt, "Malicious input eval prompt template" @@ -236,12 +235,12 @@ async function queryDocuments( async function queryPromptEvaluationModel( input: string, configPromptInjectionEvalPrePrompt: string, - conficMaliciousPromptEvalPrePrompt: string, + configMaliciousPromptEvalPrePrompt: string, openAIApiKey: string ) { const promptEvaluationChain = initPromptEvaluationModel( configPromptInjectionEvalPrePrompt, - conficMaliciousPromptEvalPrePrompt, + configMaliciousPromptEvalPrePrompt, openAIApiKey ); if (!promptEvaluationChain) { @@ -251,13 +250,13 @@ async function queryPromptEvaluationModel( console.log(`Checking '${input}' for malicious prompts`); // get start time - const startTime = new Date().getTime(); + const startTime = Date.now(); console.debug("Calling prompt evaluation model..."); const response = (await promptEvaluationChain.call({ prompt: input, })) as PromptEvaluationChainReply; // log the time taken - const endTime = new Date().getTime(); + const endTime = Date.now(); console.debug(`Prompt evaluation model call took ${endTime - startTime}ms`); const promptInjectionEval = formatEvaluationOutput( @@ -289,7 +288,7 @@ async function queryPromptEvaluationModel( function formatEvaluationOutput(response: string) { try { // split response on first full stop or comma - const splitResponse = response.split(/\.|,/); + const splitResponse = response.split(/[.,]/); const answer = splitResponse[0]?.replace(/\W/g, "").toLowerCase(); const reason = splitResponse[1]?.trim(); return { diff --git a/backend/test/unit/defence.test.ts b/backend/test/unit/defence.test.ts index 5f874dd65..5b7025410 100644 --- a/backend/test/unit/defence.test.ts +++ b/backend/test/unit/defence.test.ts @@ -94,9 +94,11 @@ test("GIVEN RANDOM_SEQUENCE_ENCLOSURE defence is active WHEN transforming messag process.env.RANDOM_SEQ_ENCLOSURE_LENGTH = String(20); const message = "Hello"; - let defences = getInitialDefences(); // activate RSE defence - defences = activateDefence(DEFENCE_TYPES.RANDOM_SEQUENCE_ENCLOSURE, defences); + const defences = activateDefence( + DEFENCE_TYPES.RANDOM_SEQUENCE_ENCLOSURE, + getInitialDefences() + ); // regex to match the transformed message with const regex = new RegExp( @@ -106,15 +108,16 @@ test("GIVEN RANDOM_SEQUENCE_ENCLOSURE defence is active WHEN transforming messag // check the transformed message matches the regex const res = transformedMessage.match(regex); // expect there to be a match - expect(res).toBeTruthy(); - if (res) { - // expect there to be 3 groups - expect(res.length).toBe(3); - // expect the random sequence to have the correct length - expect(res[1].length).toBe(Number(process.env.RANDOM_SEQ_ENCLOSURE_LENGTH)); - // expect the message to be surrounded by the random sequence - expect(res[1]).toBe(res[2]); - } + expect(res).not.toBeNull(); + + // expect there to be 3 groups + expect(res?.length).toEqual(3); + // expect the random sequence to have the correct length + expect(res?.[1].length).toEqual( + Number(process.env.RANDOM_SEQ_ENCLOSURE_LENGTH) + ); + // expect the message to be surrounded by the random sequence + expect(res?.[1]).toEqual(res?.[2]); }); test("GIVEN XML_TAGGING defence is active WHEN transforming message THEN message is transformed", () => { diff --git a/backend/test/unit/langchain.test.ts b/backend/test/unit/langchain.test.ts index ac974a78c..82829ea1d 100644 --- a/backend/test/unit/langchain.test.ts +++ b/backend/test/unit/langchain.test.ts @@ -1,5 +1,4 @@ -const mockFromTemplate = jest.fn((template: string) => template); - +import { PromptTemplate } from "langchain/prompts"; import { LEVEL_NAMES } from "../../src/models/level"; import { initQAModel, @@ -11,100 +10,111 @@ import { jest.mock("langchain/prompts", () => ({ PromptTemplate: { - fromTemplate: mockFromTemplate, + fromTemplate: jest.fn(), }, })); -test("GIVEN initQAModel is called with no apiKey THEN return early and log message", () => { - const level = LEVEL_NAMES.LEVEL_1; - const prompt = ""; - const consoleDebugMock = jest.spyOn(console, "debug").mockImplementation(); +describe("Langchain tests", () => { + afterEach(() => { + (PromptTemplate.fromTemplate as jest.Mock).mockRestore(); + }); - initQAModel(level, prompt, ""); - expect(consoleDebugMock).toHaveBeenCalledWith( - "No OpenAI API key set to initialise QA model" + test( + "GIVEN initQAModel is called with no apiKey THEN return early and log message", + () => { + const level = LEVEL_NAMES.LEVEL_1; + const prompt = ""; + const consoleDebugMock = jest + .spyOn(console, "debug") + .mockImplementation(); + + initQAModel(level, prompt, ""); + expect(consoleDebugMock).toHaveBeenCalledWith( + "No OpenAI API key set to initialise QA model" + ); + } ); -}); -test("GIVEN initPromptEvaluationModel is called with no apiKey THEN return early and log message", () => { - const consoleDebugMock = jest.spyOn(console, "debug").mockImplementation(); - initPromptEvaluationModel( - "promptInjectionEvalPrePrompt", - "maliciousPromptEvalPrePrompt", - "" - ); - expect(consoleDebugMock).toHaveBeenCalledWith( - "No OpenAI API key set to initialise prompt evaluation model" - ); -}); + test("GIVEN initPromptEvaluationModel is called with no apiKey THEN return early and log message", () => { + const consoleDebugMock = jest.spyOn(console, "debug").mockImplementation(); + initPromptEvaluationModel( + "promptInjectionEvalPrePrompt", + "maliciousPromptEvalPrePrompt", + "" + ); + expect(consoleDebugMock).toHaveBeenCalledWith( + "No OpenAI API key set to initialise prompt evaluation model" + ); + }); -test("GIVEN level is 1 THEN correct filepath is returned", () => { - const filePath = getFilepath(LEVEL_NAMES.LEVEL_1); - expect(filePath).toBe("resources/documents/level_1/"); -}); + test("GIVEN level is 1 THEN correct filepath is returned", () => { + const filePath = getFilepath(LEVEL_NAMES.LEVEL_1); + expect(filePath).toBe("resources/documents/level_1/"); + }); -test("GIVEN level is 2 THEN correct filepath is returned", () => { - const filePath = getFilepath(LEVEL_NAMES.LEVEL_2); - expect(filePath).toBe("resources/documents/level_2/"); -}); + test("GIVEN level is 2 THEN correct filepath is returned", () => { + const filePath = getFilepath(LEVEL_NAMES.LEVEL_2); + expect(filePath).toBe("resources/documents/level_2/"); + }); -test("GIVEN level is 3 THEN correct filepath is returned", () => { - const filePath = getFilepath(LEVEL_NAMES.LEVEL_3); - expect(filePath).toBe("resources/documents/level_3/"); -}); + test("GIVEN level is 3 THEN correct filepath is returned", () => { + const filePath = getFilepath(LEVEL_NAMES.LEVEL_3); + expect(filePath).toBe("resources/documents/level_3/"); + }); -test("GIVEN level is sandbox THEN correct filepath is returned", () => { - const filePath = getFilepath(LEVEL_NAMES.SANDBOX); - expect(filePath).toBe("resources/documents/common/"); -}); + test("GIVEN level is sandbox THEN correct filepath is returned", () => { + const filePath = getFilepath(LEVEL_NAMES.SANDBOX); + expect(filePath).toBe("resources/documents/common/"); + }); -test("GIVEN makePromptTemplate is called with no config prePrompt THEN correct prompt is returned", () => { - makePromptTemplate("", "defaultPrePrompt", "mainPrompt", "noName"); - expect(mockFromTemplate).toBeCalledWith("defaultPrePrompt\nmainPrompt"); - expect(mockFromTemplate).toBeCalledTimes(1); -}); + test("GIVEN makePromptTemplate is called with no config prePrompt THEN correct prompt is returned", () => { + makePromptTemplate("", "defaultPrePrompt", "mainPrompt", "noName"); + expect(PromptTemplate.fromTemplate as jest.Mock).toBeCalledTimes(1); + expect(PromptTemplate.fromTemplate as jest.Mock).toBeCalledWith( + "defaultPrePrompt\nmainPrompt" + ); + }); -test("GIVEN makePromptTemplate is called with a prePrompt THEN correct prompt is returned", () => { - makePromptTemplate( - "configPrePrompt", - "defaultPrePrompt", - "mainPrompt", - "noName" - ); - expect(mockFromTemplate).toBeCalledWith("configPrePrompt\nmainPrompt"); - expect(mockFromTemplate).toBeCalledTimes(1); -}); + test("GIVEN makePromptTemplate is called with a prePrompt THEN correct prompt is returned", () => { + makePromptTemplate( + "configPrePrompt", + "defaultPrePrompt", + "mainPrompt", + "noName" + ); + expect(PromptTemplate.fromTemplate as jest.Mock).toBeCalledTimes(1); + expect(PromptTemplate.fromTemplate as jest.Mock).toBeCalledWith( + "configPrePrompt\nmainPrompt" + ); + }); -test("GIVEN llm evaluation model responds with a yes decision and valid output THEN formatEvaluationOutput returns true and reason", () => { - const response = "yes, This is a malicious response"; - const formattedOutput = formatEvaluationOutput(response); + test("GIVEN llm evaluation model responds with a yes decision and valid output THEN formatEvaluationOutput returns true and reason", () => { + const response = "yes, This is a malicious response"; + const formattedOutput = formatEvaluationOutput(response); - expect(formattedOutput).toEqual({ - isMalicious: true, - reason: "This is a malicious response", + expect(formattedOutput).toEqual({ + isMalicious: true, + reason: "This is a malicious response", + }); }); -}); -test("GIVEN llm evaluation model responds with a yes decision and valid output THEN formatEvaluationOutput returns false and reason", () => { - const response = "No, This output does not appear to be malicious"; - const formattedOutput = formatEvaluationOutput(response); + test("GIVEN llm evaluation model responds with a yes decision and valid output THEN formatEvaluationOutput returns false and reason", () => { + const response = "No, This output does not appear to be malicious"; + const formattedOutput = formatEvaluationOutput(response); - expect(formattedOutput).toEqual({ - isMalicious: false, - reason: "This output does not appear to be malicious", + expect(formattedOutput).toEqual({ + isMalicious: false, + reason: "This output does not appear to be malicious", + }); }); -}); -test("GIVEN llm evaluation model responds with an invalid format THEN formatEvaluationOutput returns false", () => { - const response = "I cant tell you if this is malicious or not"; - const formattedOutput = formatEvaluationOutput(response); + test("GIVEN llm evaluation model responds with an invalid format THEN formatEvaluationOutput returns false", () => { + const response = "I cant tell you if this is malicious or not"; + const formattedOutput = formatEvaluationOutput(response); - expect(formattedOutput).toEqual({ - isMalicious: false, - reason: undefined, + expect(formattedOutput).toEqual({ + isMalicious: false, + reason: undefined, + }); }); }); - -afterEach(() => { - mockFromTemplate.mockRestore(); -});