Skip to content

Commit

Permalink
596 fix problems with integration langchain test (#804)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmarsh-scottlogic authored Feb 6, 2024
1 parent f5856e2 commit 2647f6e
Show file tree
Hide file tree
Showing 9 changed files with 394 additions and 436 deletions.
60 changes: 48 additions & 12 deletions backend/src/document.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,14 @@ import { CSVLoader } from 'langchain/document_loaders/fs/csv';
import { DirectoryLoader } from 'langchain/document_loaders/fs/directory';
import { PDFLoader } from 'langchain/document_loaders/fs/pdf';
import { TextLoader } from 'langchain/document_loaders/fs/text';
import { OpenAIEmbeddings } from 'langchain/embeddings/openai';
import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter';
import { MemoryVectorStore } from 'langchain/vectorstores/memory';
import * as fs from 'node:fs';

import { DocumentMeta } from './models/document';
import { DocumentMeta, DocumentsVector } from './models/document';
import { LEVEL_NAMES } from './models/level';

async function getCommonDocuments() {
const commonDocsFilePath = getFilepath('common');
return await getDocuments(commonDocsFilePath);
}

async function getLevelDocuments(level: LEVEL_NAMES) {
const levelDocsFilePath = getFilepath(level);
return await getDocuments(levelDocsFilePath);
}

// load the documents from filesystem
async function getDocuments(filePath: string) {
console.debug(`Loading documents from: ${filePath}`);
Expand Down Expand Up @@ -84,4 +76,48 @@ function getDocumentMetas(folder: string) {
return documentMetas;
}

export { getCommonDocuments, getLevelDocuments, getSandboxDocumentMetas };
// module-private store for the vectorised documents of every level,
// wrapped in a closure so callers can only reach it via get/set
const documentVectors = (() => {
	let docs: DocumentsVector[] = [];
	return {
		get() {
			return docs;
		},
		set(newDocs: DocumentsVector[]) {
			docs = newDocs;
		},
	};
})();
// read-only accessor exported to consumers of the vector store
const getDocumentVectors = documentVectors.get;

// create and store the document vectors for each level
async function initDocumentVectors() {
	const commonDocuments = await getDocuments(getFilepath('common'));

	// LEVEL_NAMES is a numeric enum, so Object.values yields both the string
	// names and the numeric values; keep only the numbers
	const levelValues = Object.values(LEVEL_NAMES)
		.filter((value) => !isNaN(Number(value)))
		.map((value) => Number(value));

	// will use env variable for API key.
	// NOTE(review): constructed once and reused — assumes OpenAIEmbeddings only
	// wraps API config and holds no per-level state; confirm against langchain docs
	const embeddings = new OpenAIEmbeddings();

	// vectorise all levels in parallel; Promise.all preserves the level order
	const docVectors: DocumentsVector[] = await Promise.all(
		levelValues.map(async (level) => {
			const commonAndLevelDocuments = commonDocuments.concat(
				await getDocuments(getFilepath(level))
			);
			// embed and store the splits for this level
			const docVector = await MemoryVectorStore.fromDocuments(
				commonAndLevelDocuments,
				embeddings
			);
			return { level, docVector };
		})
	);

	documentVectors.set(docVectors);
	console.debug(
		`Initialised document vectors for each level. count=${docVectors.length}`
	);
}

export { getSandboxDocumentMetas, initDocumentVectors, getDocumentVectors };
59 changes: 6 additions & 53 deletions backend/src/langchain.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import { RetrievalQAChain, LLMChain } from 'langchain/chains';
import { ChatOpenAI } from 'langchain/chat_models/openai';
import { OpenAIEmbeddings } from 'langchain/embeddings/openai';
import { OpenAI } from 'langchain/llms/openai';
import { PromptTemplate } from 'langchain/prompts';
import { MemoryVectorStore } from 'langchain/vectorstores/memory';

import { getCommonDocuments, getLevelDocuments } from './document';
import { getDocumentVectors } from './document';
import { CHAT_MODELS, ChatAnswer } from './models/chat';
import { DocumentsVector } from './models/document';
import { PromptEvaluationChainReply, QaChainReply } from './models/langchain';
import { LEVEL_NAMES } from './models/level';
import { getOpenAIKey, getValidOpenAIModelsList } from './openai';
Expand All @@ -18,18 +15,6 @@ import {
qAPrompt,
} from './promptTemplates';

// module-private store for the vectorised documents of every level;
// the backing array is mutated in place so the reference stays stable
const vectorisedDocuments = (() => {
	const docs: DocumentsVector[] = [];
	return {
		get() {
			return docs;
		},
		set(newDocs: DocumentsVector[]) {
			docs.length = 0;
			docs.push(...newDocs);
		},
	};
})();

// choose between the provided preprompt and the default preprompt and prepend it to the main prompt and return the PromptTemplate
function makePromptTemplate(
configPrompt: string,
Expand All @@ -46,36 +31,6 @@ function makePromptTemplate(
return PromptTemplate.fromTemplate(fullPrompt);
}

// create and store the document vectors for each level
async function initDocumentVectors() {
	const commonDocuments = await getCommonDocuments();

	// LEVEL_NAMES is a numeric enum, so Object.values yields both the string
	// names and the numeric values; keep only the numbers
	const levelValues = Object.values(LEVEL_NAMES)
		.filter((value) => !isNaN(Number(value)))
		.map((value) => Number(value));

	// will use env variable for API key.
	// NOTE(review): constructed once and reused — assumes OpenAIEmbeddings only
	// wraps API config and holds no per-level state; confirm against langchain docs
	const embeddings = new OpenAIEmbeddings();

	// vectorise all levels in parallel; Promise.all preserves the level order
	const docVectors: DocumentsVector[] = await Promise.all(
		levelValues.map(async (level) => {
			const allDocuments = commonDocuments.concat(
				await getLevelDocuments(level)
			);
			// embed and store the splits for this level
			const docVector = await MemoryVectorStore.fromDocuments(
				allDocuments,
				embeddings
			);
			return { level, docVector };
		})
	);

	vectorisedDocuments.set(docVectors);
	console.debug(
		`Initialised document vectors for each level. count=${docVectors.length}`
	);
}

function getChatModel() {
return getValidOpenAIModelsList().includes(CHAT_MODELS.GPT_4)
? CHAT_MODELS.GPT_4
Expand All @@ -84,7 +39,7 @@ function getChatModel() {

function initQAModel(level: LEVEL_NAMES, Prompt: string) {
const openAIApiKey = getOpenAIKey();
const documentVectors = vectorisedDocuments.get()[level].docVector;
const documentVectors = getDocumentVectors()[level].docVector;
// use gpt-4 if avaliable to apiKey
const modelName = getChatModel();

Expand Down Expand Up @@ -196,17 +151,15 @@ async function queryPromptEvaluationModel(

function formatEvaluationOutput(response: string) {
// remove all non-alphanumeric characters
try {
const cleanResponse = response.replace(/\W/g, '').toLowerCase();
const cleanResponse = response.replace(/\W/g, '').toLowerCase();
if (cleanResponse === 'yes' || cleanResponse === 'no') {
return { isMalicious: cleanResponse === 'yes' };
} catch (error) {
// in case the model does not respond in the format we have asked
console.error(error);
} else {
console.debug(
`Did not get a valid response from the prompt evaluation model. Original response: ${response}`
);
return { isMalicious: false };
}
}

export { queryDocuments, queryPromptEvaluationModel, initDocumentVectors };
export { queryDocuments, queryPromptEvaluationModel };
2 changes: 1 addition & 1 deletion backend/src/server.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { env, exit } from 'node:process';

import app from './app';
import { initDocumentVectors } from './langchain';
import { initDocumentVectors } from './document';
import { getValidModelsFromOpenAI } from './openai';
// by default runs on port 3001
const port = env.PORT ?? String(3001);
Expand Down
Loading

0 comments on commit 2647f6e

Please sign in to comment.