Skip to content

Commit

Permalink
Misc tidying (#837)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmarsh-scottlogic authored and chriswilty committed Apr 8, 2024
1 parent 1d1f713 commit 484d808
Show file tree
Hide file tree
Showing 26 changed files with 129 additions and 227 deletions.
35 changes: 18 additions & 17 deletions backend/src/controller/chatController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import { OpenAiChatRequest } from '@src/models/api/OpenAiChatRequest';
import { OpenAiClearRequest } from '@src/models/api/OpenAiClearRequest';
import { OpenAiGetHistoryRequest } from '@src/models/api/OpenAiGetHistoryRequest';
import {
ChatDefenceReport,
DefenceReport,
ChatHttpResponse,
ChatModel,
LevelHandlerResponse,
Expand All @@ -20,7 +20,7 @@ import {
import {
ChatMessage,
ChatInfoMessage,
chatInfoMessageType,
chatInfoMessageTypes,
} from '@src/models/chatMessage';
import { Defence } from '@src/models/defence';
import { EmailInfo } from '@src/models/email';
Expand All @@ -33,9 +33,7 @@ import {

import { handleChatError } from './handleError';

function combineChatDefenceReports(
reports: ChatDefenceReport[]
): ChatDefenceReport {
function combineDefenceReports(reports: DefenceReport[]): DefenceReport {
return {
blockedReason: reports
.filter((report) => report.blockedReason !== null)
Expand Down Expand Up @@ -100,17 +98,17 @@ async function handleChatWithoutDefenceDetection(
chatHistory: ChatMessage[],
defences: Defence[]
): Promise<LevelHandlerResponse> {
console.log(`User message: '${message}'`);

const updatedChatHistory = createNewUserMessages(message).reduce(
pushMessageToHistory,
chatHistory
);

// get the chatGPT reply
const openAiReply = await chatGptSendMessage(
updatedChatHistory,
defences,
chatModel,
message,
currentLevel
);

Expand Down Expand Up @@ -146,11 +144,16 @@ async function handleChatWithDefenceDetection(
defences
);

console.log(
`User message: '${
messageTransformation?.transformedMessageCombined ?? message
}'`
);

const openAiReplyPromise = chatGptSendMessage(
chatHistoryWithNewUserMessages,
defences,
chatModel,
messageTransformation?.transformedMessageCombined ?? message,
currentLevel
);

Expand All @@ -168,7 +171,7 @@ async function handleChatWithDefenceDetection(
const defenceReports = outputDefenceReport
? [inputDefenceReport, outputDefenceReport]
: [inputDefenceReport];
const combinedDefenceReport = combineChatDefenceReports(defenceReports);
const combinedDefenceReport = combineDefenceReports(defenceReports);

// if blocked, restore original chat history and add user message to chat history without completion
const updatedChatHistory = combinedDefenceReport.isBlocked
Expand Down Expand Up @@ -196,7 +199,6 @@ async function handleChatWithDefenceDetection(
}

async function handleChatToGPT(req: OpenAiChatRequest, res: Response) {
// set reply params
const initChatResponse: ChatHttpResponse = {
reply: '',
defenceReport: {
Expand Down Expand Up @@ -232,9 +234,6 @@ async function handleChatToGPT(req: OpenAiChatRequest, res: Response) {
);
return;
}
const totalSentEmails: EmailInfo[] = [
...req.session.levelState[currentLevel].sentEmails,
];

// use default model for levels, allow user to select in sandbox
const chatModel =
Expand Down Expand Up @@ -283,15 +282,18 @@ async function handleChatToGPT(req: OpenAiChatRequest, res: Response) {
}

let updatedChatHistory = levelResult.chatHistory;
totalSentEmails.push(...levelResult.chatResponse.sentEmails);

const totalSentEmails: EmailInfo[] = [
...req.session.levelState[currentLevel].sentEmails,
...levelResult.chatResponse.sentEmails,
];

const updatedChatResponse: ChatHttpResponse = {
...initChatResponse,
...levelResult.chatResponse,
};

if (updatedChatResponse.defenceReport.isBlocked) {
// chatReponse.reply is empty if blocked
updatedChatHistory = pushMessageToHistory(updatedChatHistory, {
chatMessageType: 'BOT_BLOCKED',
infoMessage:
Expand Down Expand Up @@ -326,7 +328,6 @@ async function handleChatToGPT(req: OpenAiChatRequest, res: Response) {
});
}

// update state
req.session.levelState[currentLevel].chatHistory = updatedChatHistory;
req.session.levelState[currentLevel].sentEmails = totalSentEmails;

Expand Down Expand Up @@ -376,7 +377,7 @@ function handleAddInfoToChatHistory(
if (
infoMessage &&
chatMessageType &&
chatInfoMessageType.includes(chatMessageType) &&
chatInfoMessageTypes.includes(chatMessageType) &&
level !== undefined &&
level >= LEVEL_NAMES.LEVEL_1
) {
Expand Down
21 changes: 6 additions & 15 deletions backend/src/defence.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { defaultDefences } from './defaultDefences';
import { queryPromptEvaluationModel } from './langchain';
import { evaluatePrompt } from './langchain';
import {
ChatDefenceReport,
DefenceReport,
MessageTransformation,
SingleDefenceReport,
TransformedChatMessage,
Expand All @@ -20,14 +20,12 @@ import {
} from './promptTemplates';

/**
 * Builds a new defence list in which the defence matching `id` has
 * `isActive` set to true. Entries with a different id are passed through
 * as-is; the input array itself is not modified.
 */
function activateDefence(id: DEFENCE_ID, defences: Defence[]) {
  return defences.map((candidate) =>
    candidate.id !== id ? candidate : { ...candidate, isActive: true }
  );
}

function deactivateDefence(id: DEFENCE_ID, defences: Defence[]) {
// return the updated list of defences
return defences.map((defence) =>
defence.id === id ? { ...defence, isActive: false } : defence
);
Expand All @@ -38,7 +36,6 @@ function configureDefence(
defences: Defence[],
config: DefenceConfigItem[]
): Defence[] {
// return the updated list of defences
return defences.map((defence) =>
defence.id === id ? { ...defence, config } : defence
);
Expand Down Expand Up @@ -95,7 +92,6 @@ function getFilterList(defences: Defence[], type: DEFENCE_ID) {
}
function getSystemRole(
defences: Defence[],
// by default, use sandbox
currentLevel: LEVEL_NAMES = LEVEL_NAMES.SANDBOX
) {
switch (currentLevel) {
Expand Down Expand Up @@ -183,14 +179,12 @@ function escapeXml(unsafe: string) {
});
}

/**
 * Reports whether `input` contains at least one substring that matches the
 * XML-tag pattern below (opening, closing or self-closing tag shapes).
 */
function containsXMLTags(input: string) {
  const xmlTagPattern = /<\/?[a-zA-Z][\w-]*(?:\b[^>]*\/\s*|[^>]*>|[?]>)/g;
  // With the global flag, match() yields null when nothing matches and a
  // non-empty array otherwise, so a null check is equivalent to length > 0.
  return input.match(xmlTagPattern) !== null;
}

// apply XML tagging defence to input message
function transformXmlTagging(
message: string,
defences: Defence[]
Expand All @@ -213,7 +207,6 @@ function generateRandomString(length: number) {
).join('');
}

// apply random sequence enclosure defence to input message
function transformRandomSequenceEnclosure(
message: string,
defences: Defence[]
Expand Down Expand Up @@ -250,7 +243,6 @@ function combineTransformedMessage(transformedMessage: TransformedChatMessage) {
);
}

//apply defence string transformations to original message
function transformMessage(
message: string,
defences: Defence[]
Expand Down Expand Up @@ -284,7 +276,6 @@ function transformMessage(
};
}

// detects triggered defences in original message and blocks the message if necessary
async function detectTriggeredInputDefences(
message: string,
defences: Defence[]
Expand All @@ -299,15 +290,14 @@ async function detectTriggeredInputDefences(
return combineDefenceReports(singleDefenceReports);
}

/**
 * Runs the output-side defence checks against a bot message and merges the
 * resulting reports via combineDefenceReports. The bot-output filter is the
 * only check performed here.
 */
function detectTriggeredOutputDefences(message: string, defences: Defence[]) {
  const outputFilterReport = detectFilterBotOutput(message, defences);
  return combineDefenceReports([outputFilterReport]);
}

function combineDefenceReports(
defenceReports: SingleDefenceReport[]
): ChatDefenceReport {
): DefenceReport {
const isBlocked = defenceReports.some((report) => report.blockedReason);
const blockedReason = isBlocked
? defenceReports
Expand Down Expand Up @@ -451,15 +441,16 @@ async function detectEvaluationLLM(
): Promise<SingleDefenceReport> {
const defence = DEFENCE_ID.PROMPT_EVALUATION_LLM;
// to save money and processing time, and to reduce risk of rate limiting, we only run if defence is active
// this means that, contrary to the other defences, the user won't get alerts when the defence is not active, i.e. "your last prompt would have been blocked by the prompt evaluation LLM"
if (isDefenceActive(DEFENCE_ID.PROMPT_EVALUATION_LLM, defences)) {
const promptEvalLLMPrompt = getPromptEvalPromptFromConfig(defences);

const evaluationResult = await queryPromptEvaluationModel(
const promptIsMalicious = await evaluatePrompt(
message,
promptEvalLLMPrompt
);

if (evaluationResult.isMalicious) {
if (promptIsMalicious) {
console.debug('LLM evaluation defence active and prompt is malicious.');

return {
Expand Down
5 changes: 2 additions & 3 deletions backend/src/document.ts
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,11 @@ async function initDocumentVectors() {
);

// embed and store the splits - will use env variable for API key
const embeddings = new OpenAIEmbeddings();
const docVector = await MemoryVectorStore.fromDocuments(
commonAndLevelDocuments,
embeddings
new OpenAIEmbeddings()
);
// store the document vectors for the level

docVectors.push({
level,
docVector,
Expand Down
Loading

0 comments on commit 484d808

Please sign in to comment.