From 3a65b9d33f145e1338740eab81222b34874857ea Mon Sep 17 00:00:00 2001
From: Peter Marsh
Date: Thu, 15 Feb 2024 14:04:51 +0000
Subject: [PATCH] finish merge

---
 backend/src/document.ts                       |  5 +-
 backend/test/integration/openai.test.ts       |  9 ++--
 .../unit/controller/chatController.test.ts    |  1 -
 .../formatPromptEvaluation.test.ts            | 53 -------------------
 .../initialisePromptEvaluationModel.test.ts   | 14 +++--
 .../langchain.ts/initialiseQAModel.test.ts    |  2 +-
 6 files changed, 12 insertions(+), 72 deletions(-)
 delete mode 100644 backend/test/unit/langchain.ts/formatPromptEvaluation.test.ts

diff --git a/backend/src/document.ts b/backend/src/document.ts
index 9b18dffab..4ac0b04d7 100644
--- a/backend/src/document.ts
+++ b/backend/src/document.ts
@@ -103,12 +103,11 @@ async function initDocumentVectors() {
 		);
 
 		// embed and store the splits - will use env variable for API key
-		const embeddings = new OpenAIEmbeddings();
 		const docVector = await MemoryVectorStore.fromDocuments(
 			commonAndLevelDocuments,
-			embeddings
+			new OpenAIEmbeddings()
 		);
-		// store the document vectors for the level
+
 		docVectors.push({
 			level,
 			docVector,
diff --git a/backend/test/integration/openai.test.ts b/backend/test/integration/openai.test.ts
index cb321e3de..be1d5e5cf 100644
--- a/backend/test/integration/openai.test.ts
+++ b/backend/test/integration/openai.test.ts
@@ -19,17 +19,14 @@ jest.mock('openai', () => ({
 	})),
 }));
 
-// mock the queryPromptEvaluationModel function
+// mock the evaluatePrompt function
 jest.mock('@src/langchain', () => {
 	const originalModule = jest.requireActual('@src/langchain');
 
 	return {
 		...originalModule,
-		queryPromptEvaluationModel: () => {
-			return {
-				isMalicious: false,
-				reason: '',
-			};
+		evaluatePrompt: () => {
+			return false;
 		},
 	};
 });
diff --git a/backend/test/unit/controller/chatController.test.ts b/backend/test/unit/controller/chatController.test.ts
index dcfafda9c..f8faa6392 100644
--- a/backend/test/unit/controller/chatController.test.ts
+++ b/backend/test/unit/controller/chatController.test.ts
@@ -709,7 +709,6 @@ describe('handleChatToGPT unit tests', () => {
 			[...existingHistory, ...newTransformationChatMessages],
 			[],
 			mockChatModel,
-			'[pre message] hello bot [post message]',
 			LEVEL_NAMES.SANDBOX
 		);
 
diff --git a/backend/test/unit/langchain.ts/formatPromptEvaluation.test.ts b/backend/test/unit/langchain.ts/formatPromptEvaluation.test.ts
deleted file mode 100644
index 10da8ed0c..000000000
--- a/backend/test/unit/langchain.ts/formatPromptEvaluation.test.ts
+++ /dev/null
@@ -1,53 +0,0 @@
-import { test, expect, jest } from '@jest/globals';
-
-import { queryPromptEvaluationModel } from '@src/langchain';
-
-const mockPromptEvalChain = {
-	call: jest.fn<() => Promise<{ promptEvalOutput: string }>>(),
-};
-
-// mock chains
-jest.mock('langchain/chains', () => {
-	return {
-		LLMChain: jest.fn().mockImplementation(() => {
-			return mockPromptEvalChain;
-		}),
-	};
-});
-
-test('GIVEN prompt evaluation llm responds with a correctly formatted yes decision WHEN we query the llm THEN answers with is malicious', async () => {
-	mockPromptEvalChain.call.mockResolvedValue({
-		promptEvalOutput: 'yes.',
-	});
-	const formattedOutput = await queryPromptEvaluationModel('input', 'prompt');
-
-	expect(formattedOutput).toEqual({
-		isMalicious: true,
-	});
-});
-
-test('GIVEN prompt evaluation llm responds with a correctly formatted no decision WHEN we query the llm THEN answers with is not malicious', async () => {
-	mockPromptEvalChain.call.mockResolvedValue({
-		promptEvalOutput: 'no.',
-	});
-	const formattedOutput = await queryPromptEvaluationModel('input', 'prompt');
-
-	expect(formattedOutput).toEqual({
-		isMalicious: false,
-	});
-});
-
-test('GIVEN prompt evaluation llm responds with an incorrectly formatted decision WHEN we query the llm THEN answers with is not malicious and logs debug message', async () => {
-	const logSpy = jest.spyOn(console, 'debug');
-
-	mockPromptEvalChain.call.mockResolvedValue({
-		promptEvalOutput: 'Sure is!',
-	});
-	const formattedOutput = await queryPromptEvaluationModel('input', 'prompt');
-
-	expect(formattedOutput).toEqual({
-		isMalicious: false,
-	});
-	expect(logSpy).toHaveBeenCalled();
-	logSpy.mockRestore();
-});
diff --git a/backend/test/unit/langchain.ts/initialisePromptEvaluationModel.test.ts b/backend/test/unit/langchain.ts/initialisePromptEvaluationModel.test.ts
index 7d0a10696..07fc9f8cd 100644
--- a/backend/test/unit/langchain.ts/initialisePromptEvaluationModel.test.ts
+++ b/backend/test/unit/langchain.ts/initialisePromptEvaluationModel.test.ts
@@ -2,7 +2,7 @@ import { afterEach, test, jest, expect } from '@jest/globals';
 import { OpenAI } from 'langchain/llms/openai';
 import { PromptTemplate } from 'langchain/prompts';
 
-import { queryPromptEvaluationModel } from '@src/langchain';
+import { evaluatePrompt } from '@src/langchain';
 import {
 	promptEvalContextTemplate,
 	promptEvalPrompt,
@@ -43,7 +43,7 @@ afterEach(() => {
 });
 
 test('WHEN we query the prompt evaluation model THEN it is initialised', async () => {
-	await queryPromptEvaluationModel('some input', promptEvalPrompt);
+	await evaluatePrompt('some input', promptEvalPrompt);
 	expect(mockFromTemplate).toHaveBeenCalledTimes(1);
 	expect(mockFromTemplate).toHaveBeenCalledWith(
 		`${promptEvalPrompt}\n${promptEvalContextTemplate}`
@@ -53,11 +53,9 @@ test('WHEN we query the prompt evaluation model THEN it is initialised', async (
 
 test('GIVEN the prompt evaluation model is not initialised WHEN it is asked to evaluate an input it returns not malicious', async () => {
 	mockPromptEvalChain.call.mockResolvedValueOnce({ promptEvalOutput: '' });
-	const result = await queryPromptEvaluationModel('message', 'Prompt');
+	const result = await evaluatePrompt('message', 'Prompt');
 
-	expect(result).toEqual({
-		isMalicious: false,
-	});
+	expect(result).toEqual(false);
 });
 
 test('GIVEN the users api key supports GPT-4 WHEN the prompt evaluation model is initialised THEN it is initialised with GPT-4', async () => {
@@ -65,7 +63,7 @@ test('GIVEN the users api key supports GPT-4 WHEN the prompt evaluation model is
 
 	const prompt = 'this is a test prompt. ';
 
-	await queryPromptEvaluationModel('some input', prompt);
+	await evaluatePrompt('some input', prompt);
 
 	expect(OpenAI).toHaveBeenCalledWith({
 		modelName: 'gpt-4',
@@ -79,7 +77,7 @@ test('GIVEN the users api key does not support GPT-4 WHEN the prompt evaluation
 
 	const prompt = 'this is a test prompt. ';
 
-	await queryPromptEvaluationModel('some input', prompt);
+	await evaluatePrompt('some input', prompt);
 
 	expect(OpenAI).toHaveBeenCalledWith({
 		modelName: 'gpt-3.5-turbo',
diff --git a/backend/test/unit/langchain.ts/initialiseQAModel.test.ts b/backend/test/unit/langchain.ts/initialiseQAModel.test.ts
index 252d41c80..1c27c1b47 100644
--- a/backend/test/unit/langchain.ts/initialiseQAModel.test.ts
+++ b/backend/test/unit/langchain.ts/initialiseQAModel.test.ts
@@ -117,7 +117,7 @@ test('GIVEN the QA LLM WHEN a question is asked THEN it is initialised AND it an
 
 	expect(mockFromLLM).toHaveBeenCalledTimes(1);
 	expect(mockRetrievalQAChain.call).toHaveBeenCalledTimes(1);
-	expect(answer.reply).toEqual('The CEO is Bill.');
+	expect(answer).toEqual('The CEO is Bill.');
 });
 
 test('GIVEN the users api key supports GPT-4 WHEN the QA model is initialised THEN it is initialised with GPT-4', async () => {
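Note on the API shape these tests now exercise: evaluatePrompt is treated as resolving to a plain boolean, where the old queryPromptEvaluationModel resolved to an { isMalicious, reason } object. The patch only touches the tests, not backend/src/langchain.ts itself, so the sketch below is merely an illustration of that expected shape; the chain wiring, input variable names, context template and model choice are assumptions, not the repository's actual implementation.

// Minimal sketch only; not the code changed by this patch.
import { LLMChain } from 'langchain/chains';
import { OpenAI } from 'langchain/llms/openai';
import { PromptTemplate } from 'langchain/prompts';

async function evaluatePrompt(
	input: string,
	promptEvalPrompt: string
): Promise<boolean> {
	// stand-in for the promptEvalContextTemplate the tests import (assumed content)
	const promptEvalContextTemplate = 'User input: {prompt}';
	const template = PromptTemplate.fromTemplate(
		`${promptEvalPrompt}\n${promptEvalContextTemplate}`
	);
	const chain = new LLMChain({
		llm: new OpenAI({ modelName: 'gpt-3.5-turbo', temperature: 0 }),
		prompt: template,
		outputKey: 'promptEvalOutput',
	});
	const { promptEvalOutput } = await chain.call({ prompt: input });
	// a 'yes' decision marks the input as malicious; anything else counts as not malicious
	return String(promptEvalOutput).trim().toLowerCase().startsWith('yes');
}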