diff --git a/backend/jest.config.js b/backend/jest.config.js index e576c72de..98200bb0b 100644 --- a/backend/jest.config.js +++ b/backend/jest.config.js @@ -3,4 +3,5 @@ module.exports = { modulePathIgnorePatterns: ["build", "coverage", "node_modules"], preset: "ts-jest", testEnvironment: "node", + silent: true }; diff --git a/backend/src/langchain.ts b/backend/src/langchain.ts index 6a6d7d6bf..d847e2c9d 100644 --- a/backend/src/langchain.ts +++ b/backend/src/langchain.ts @@ -64,8 +64,8 @@ async function getDocuments(filePath: string) { chunkSize: 1000, chunkOverlap: 0, }); - const splitDocs = await textSplitter.splitDocuments(docs); - return splitDocs; + + return await textSplitter.splitDocuments(docs); } // choose between the provided preprompt and the default preprompt and prepend it to the main prompt and return the PromptTemplate @@ -81,8 +81,7 @@ function makePromptTemplate( } const fullPrompt = `${configPrePrompt}\n${mainPrompt}`; console.debug(`${templateNameForLogging}: ${fullPrompt}`); - const template: PromptTemplate = PromptTemplate.fromTemplate(fullPrompt); - return template; + return PromptTemplate.fromTemplate(fullPrompt); } // create and store the document vectors for each level @@ -147,7 +146,7 @@ function initQAModel( // initialise the prompt evaluation model function initPromptEvaluationModel( configPromptInjectionEvalPrePrompt: string, - conficMaliciousPromptEvalPrePrompt: string, + configMaliciousPromptEvalPrePrompt: string, openAiApiKey: string ) { if (!openAiApiKey) { @@ -176,7 +175,7 @@ function initPromptEvaluationModel( // create chain to detect malicious prompts const maliciousPromptEvalTemplate = makePromptTemplate( - conficMaliciousPromptEvalPrePrompt, + configMaliciousPromptEvalPrePrompt, maliciousPromptEvalPrePrompt, maliciousPromptEvalMainPrompt, "Malicious input eval prompt template" @@ -289,7 +288,7 @@ async function queryPromptEvaluationModel( function formatEvaluationOutput(response: string) { try { // split response on first full stop or comma - const splitResponse = response.split(/\.|,/); + const splitResponse = response.split(/[.,]/); const answer = splitResponse[0]?.replace(/\W/g, "").toLowerCase(); const reason = splitResponse[1]?.trim(); return { diff --git a/backend/test/unit/defence.test.ts b/backend/test/unit/defence.test.ts index 5f874dd65..7887bb5cb 100644 --- a/backend/test/unit/defence.test.ts +++ b/backend/test/unit/defence.test.ts @@ -94,9 +94,11 @@ test("GIVEN RANDOM_SEQUENCE_ENCLOSURE defence is active WHEN transforming messag process.env.RANDOM_SEQ_ENCLOSURE_LENGTH = String(20); const message = "Hello"; - let defences = getInitialDefences(); // activate RSE defence - defences = activateDefence(DEFENCE_TYPES.RANDOM_SEQUENCE_ENCLOSURE, defences); + const defences = activateDefence( + DEFENCE_TYPES.RANDOM_SEQUENCE_ENCLOSURE, + getInitialDefences() + ); // regex to match the transformed message with const regex = new RegExp( @@ -106,15 +108,15 @@ test("GIVEN RANDOM_SEQUENCE_ENCLOSURE defence is active WHEN transforming messag // check the transformed message matches the regex const res = transformedMessage.match(regex); // expect there to be a match - expect(res).toBeTruthy(); - if (res) { - // expect there to be 3 groups - expect(res.length).toBe(3); - // expect the random sequence to have the correct length - expect(res[1].length).toBe(Number(process.env.RANDOM_SEQ_ENCLOSURE_LENGTH)); - // expect the message to be surrounded by the random sequence - expect(res[1]).toBe(res[2]); - } + expect(res).not.toBeNull(); + + // expect there to be 3 
groups + expect(res?.length).toEqual(3); + // expect the random sequence to have the correct length + expect(res?.[1].length).toEqual(Number(process.env.RANDOM_SEQ_ENCLOSURE_LENGTH)); + // expect the message to be surrounded by the random sequence + expect(res?.[1]).toEqual(res?.[2]); + }); test("GIVEN XML_TAGGING defence is active WHEN transforming message THEN message is transformed", () => { diff --git a/backend/test/unit/langchain.test.ts b/backend/test/unit/langchain.test.ts index ac974a78c..e0598804e 100644 --- a/backend/test/unit/langchain.test.ts +++ b/backend/test/unit/langchain.test.ts @@ -1,5 +1,4 @@ -const mockFromTemplate = jest.fn((template: string) => template); - +import { PromptTemplate } from "langchain/prompts"; import { LEVEL_NAMES } from "../../src/models/level"; import { initQAModel, @@ -11,100 +10,103 @@ import { jest.mock("langchain/prompts", () => ({ PromptTemplate: { - fromTemplate: mockFromTemplate, + fromTemplate: jest.fn(), }, })); -test("GIVEN initQAModel is called with no apiKey THEN return early and log message", () => { - const level = LEVEL_NAMES.LEVEL_1; - const prompt = ""; - const consoleDebugMock = jest.spyOn(console, "debug").mockImplementation(); - - initQAModel(level, prompt, ""); - expect(consoleDebugMock).toHaveBeenCalledWith( - "No OpenAI API key set to initialise QA model" - ); -}); +describe("Langchain tests", () => { + afterEach(() => { + (PromptTemplate.fromTemplate as jest.Mock).mockRestore(); + }); -test("GIVEN initPromptEvaluationModel is called with no apiKey THEN return early and log message", () => { - const consoleDebugMock = jest.spyOn(console, "debug").mockImplementation(); - initPromptEvaluationModel( - "promptInjectionEvalPrePrompt", - "maliciousPromptEvalPrePrompt", - "" - ); - expect(consoleDebugMock).toHaveBeenCalledWith( - "No OpenAI API key set to initialise prompt evaluation model" - ); -}); + test("GIVEN initQAModel is called with no apiKey TH" + + "EN return early and log message", () => { + const level = LEVEL_NAMES.LEVEL_1; + const prompt = ""; + const consoleDebugMock = jest.spyOn(console, "debug").mockImplementation(); -test("GIVEN level is 1 THEN correct filepath is returned", () => { - const filePath = getFilepath(LEVEL_NAMES.LEVEL_1); - expect(filePath).toBe("resources/documents/level_1/"); -}); + initQAModel(level, prompt, ""); + expect(consoleDebugMock).toHaveBeenCalledWith( + "No OpenAI API key set to initialise QA model" + ); + }); -test("GIVEN level is 2 THEN correct filepath is returned", () => { - const filePath = getFilepath(LEVEL_NAMES.LEVEL_2); - expect(filePath).toBe("resources/documents/level_2/"); -}); + test("GIVEN initPromptEvaluationModel is called with no apiKey THEN return early and log message", () => { + const consoleDebugMock = jest.spyOn(console, "debug").mockImplementation(); + initPromptEvaluationModel( + "promptInjectionEvalPrePrompt", + "maliciousPromptEvalPrePrompt", + "" + ); + expect(consoleDebugMock).toHaveBeenCalledWith( + "No OpenAI API key set to initialise prompt evaluation model" + ); + }); -test("GIVEN level is 3 THEN correct filepath is returned", () => { - const filePath = getFilepath(LEVEL_NAMES.LEVEL_3); - expect(filePath).toBe("resources/documents/level_3/"); -}); + test("GIVEN level is 1 THEN correct filepath is returned", () => { + const filePath = getFilepath(LEVEL_NAMES.LEVEL_1); + expect(filePath).toBe("resources/documents/level_1/"); + }); -test("GIVEN level is sandbox THEN correct filepath is returned", () => { - const filePath = getFilepath(LEVEL_NAMES.SANDBOX); - 
expect(filePath).toBe("resources/documents/common/"); -}); + test("GIVEN level is 2 THEN correct filepath is returned", () => { + const filePath = getFilepath(LEVEL_NAMES.LEVEL_2); + expect(filePath).toBe("resources/documents/level_2/"); + }); -test("GIVEN makePromptTemplate is called with no config prePrompt THEN correct prompt is returned", () => { - makePromptTemplate("", "defaultPrePrompt", "mainPrompt", "noName"); - expect(mockFromTemplate).toBeCalledWith("defaultPrePrompt\nmainPrompt"); - expect(mockFromTemplate).toBeCalledTimes(1); -}); + test("GIVEN level is 3 THEN correct filepath is returned", () => { + const filePath = getFilepath(LEVEL_NAMES.LEVEL_3); + expect(filePath).toBe("resources/documents/level_3/"); + }); -test("GIVEN makePromptTemplate is called with a prePrompt THEN correct prompt is returned", () => { - makePromptTemplate( - "configPrePrompt", - "defaultPrePrompt", - "mainPrompt", - "noName" - ); - expect(mockFromTemplate).toBeCalledWith("configPrePrompt\nmainPrompt"); - expect(mockFromTemplate).toBeCalledTimes(1); -}); + test("GIVEN level is sandbox THEN correct filepath is returned", () => { + const filePath = getFilepath(LEVEL_NAMES.SANDBOX); + expect(filePath).toBe("resources/documents/common/"); + }); -test("GIVEN llm evaluation model responds with a yes decision and valid output THEN formatEvaluationOutput returns true and reason", () => { - const response = "yes, This is a malicious response"; - const formattedOutput = formatEvaluationOutput(response); + test("GIVEN makePromptTemplate is called with no config prePrompt THEN correct prompt is returned", () => { + makePromptTemplate("", "defaultPrePrompt", "mainPrompt", "noName"); + expect((PromptTemplate.fromTemplate as jest.Mock)).toBeCalledTimes(1); + expect((PromptTemplate.fromTemplate as jest.Mock)).toBeCalledWith("defaultPrePrompt\nmainPrompt"); + }); - expect(formattedOutput).toEqual({ - isMalicious: true, - reason: "This is a malicious response", + test("GIVEN makePromptTemplate is called with a prePrompt THEN correct prompt is returned", () => { + makePromptTemplate( + "configPrePrompt", + "defaultPrePrompt", + "mainPrompt", + "noName" + ); + expect((PromptTemplate.fromTemplate as jest.Mock)).toBeCalledTimes(1); + expect((PromptTemplate.fromTemplate as jest.Mock)).toBeCalledWith("configPrePrompt\nmainPrompt"); }); -}); -test("GIVEN llm evaluation model responds with a yes decision and valid output THEN formatEvaluationOutput returns false and reason", () => { - const response = "No, This output does not appear to be malicious"; - const formattedOutput = formatEvaluationOutput(response); + test("GIVEN llm evaluation model responds with a yes decision and valid output THEN formatEvaluationOutput returns true and reason", () => { + const response = "yes, This is a malicious response"; + const formattedOutput = formatEvaluationOutput(response); - expect(formattedOutput).toEqual({ - isMalicious: false, - reason: "This output does not appear to be malicious", + expect(formattedOutput).toEqual({ + isMalicious: true, + reason: "This is a malicious response", + }); }); -}); -test("GIVEN llm evaluation model responds with an invalid format THEN formatEvaluationOutput returns false", () => { - const response = "I cant tell you if this is malicious or not"; - const formattedOutput = formatEvaluationOutput(response); + test("GIVEN llm evaluation model responds with a yes decision and valid output THEN formatEvaluationOutput returns false and reason", () => { + const response = "No, This output does not appear to be 
malicious"; + const formattedOutput = formatEvaluationOutput(response); - expect(formattedOutput).toEqual({ - isMalicious: false, - reason: undefined, + expect(formattedOutput).toEqual({ + isMalicious: false, + reason: "This output does not appear to be malicious", + }); }); -}); -afterEach(() => { - mockFromTemplate.mockRestore(); + test("GIVEN llm evaluation model responds with an invalid format THEN formatEvaluationOutput returns false", () => { + const response = "I cant tell you if this is malicious or not"; + const formattedOutput = formatEvaluationOutput(response); + + expect(formattedOutput).toEqual({ + isMalicious: false, + reason: undefined, + }); + }); });