-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
run defence detection and chat completion concurrently #375
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,7 +10,9 @@ import { | |
} from "./defence"; | ||
import { | ||
CHAT_MESSAGE_TYPE, | ||
ChatHistoryMessage, | ||
ChatHttpResponse, | ||
ChatModel, | ||
ChatModelConfiguration, | ||
MODEL_CONFIG, | ||
defaultChatModel, | ||
|
@@ -154,149 +156,239 @@ router.post("/email/clear", (req: EmailClearRequest, res) => { | |
} | ||
}); | ||
|
||
function handleChatError( | ||
res: express.Response, | ||
chatResponse: ChatHttpResponse, | ||
blocked: boolean, | ||
errorMsg: string, | ||
statusCode = 500 | ||
) { | ||
console.error(errorMsg); | ||
chatResponse.reply = errorMsg; | ||
chatResponse.defenceInfo.isBlocked = blocked; | ||
if (blocked) { | ||
chatResponse.defenceInfo.blockedReason = errorMsg; | ||
} | ||
res.status(statusCode).send(chatResponse); | ||
} | ||
// Chat to ChatGPT | ||
router.post("/openai/chat", async (req: OpenAiChatRequest, res) => { | ||
// set reply params | ||
const chatResponse: ChatHttpResponse = { | ||
reply: "", | ||
defenceInfo: { | ||
blockedReason: "", | ||
isBlocked: false, | ||
alertedDefences: [], | ||
triggeredDefences: [], | ||
}, | ||
numLevelsCompleted: req.session.numLevelsCompleted, | ||
transformedMessage: "", | ||
wonLevel: false, | ||
}; | ||
|
||
const message = req.body.message; | ||
const currentLevel = req.body.currentLevel; | ||
|
||
// must have initialised openai | ||
if (!req.session.openAiApiKey) { | ||
res.statusCode = 401; | ||
chatResponse.defenceInfo.isBlocked = true; | ||
chatResponse.defenceInfo.blockedReason = | ||
"Please enter a valid OpenAI API key to chat to me!"; | ||
console.error(chatResponse.reply); | ||
} else if (message === undefined || currentLevel === undefined) { | ||
res.statusCode = 400; | ||
chatResponse.defenceInfo.isBlocked = true; | ||
chatResponse.defenceInfo.blockedReason = | ||
"Please send a message and current level to chat to me!"; | ||
} else { | ||
router.post( | ||
"/openai/chat", | ||
async (req: OpenAiChatRequest, res: express.Response) => { | ||
// set reply params | ||
const chatResponse: ChatHttpResponse = { | ||
reply: "", | ||
defenceInfo: { | ||
blockedReason: "", | ||
isBlocked: false, | ||
alertedDefences: [], | ||
triggeredDefences: [], | ||
}, | ||
numLevelsCompleted: 0, | ||
transformedMessage: "", | ||
wonLevel: false, | ||
}; | ||
const message = req.body.message; | ||
const currentLevel = req.body.currentLevel; | ||
|
||
// must have initialised openai | ||
if (message === undefined || currentLevel === undefined) { | ||
handleChatError( | ||
res, | ||
chatResponse, | ||
true, | ||
"Please send a message and current level to chat to me!", | ||
400 | ||
); | ||
return; | ||
} | ||
// set the transformed message to begin with | ||
chatResponse.transformedMessage = message; | ||
if (!req.session.openAiApiKey) { | ||
handleChatError( | ||
res, | ||
chatResponse, | ||
true, | ||
"Please enter a valid OpenAI API key to chat to me!", | ||
401 | ||
); | ||
return; | ||
} | ||
let numLevelsCompleted = req.session.numLevelsCompleted; | ||
|
||
if (message) { | ||
chatResponse.transformedMessage = message; | ||
// see if this message triggers any defences (only for level 3 and sandbox) | ||
if ( | ||
currentLevel === LEVEL_NAMES.LEVEL_3 || | ||
currentLevel === LEVEL_NAMES.SANDBOX | ||
) { | ||
chatResponse.defenceInfo = await detectTriggeredDefences( | ||
message, | ||
req.session.levelState[currentLevel].defences, | ||
req.session.openAiApiKey | ||
); | ||
// if message is blocked, add to chat history (not as completion) | ||
if (chatResponse.defenceInfo.isBlocked) { | ||
req.session.levelState[currentLevel].chatHistory.push({ | ||
completion: null, | ||
chatMessageType: CHAT_MESSAGE_TYPE.USER, | ||
infoMessage: message, | ||
}); | ||
// use default model for levels, allow user to select in sandbox | ||
const chatModel = | ||
currentLevel === LEVEL_NAMES.SANDBOX | ||
? req.session.chatModel | ||
: defaultChatModel; | ||
|
||
// record the history before chat completion called | ||
const chatHistoryBefore = [ | ||
...req.session.levelState[currentLevel].chatHistory, | ||
]; | ||
try { | ||
if (message) { | ||
// skip defence detection / blocking for levels 1 and 2- sets chatResponse obj | ||
if (currentLevel < LEVEL_NAMES.LEVEL_3) { | ||
await handleLowLevelChat(req, chatResponse, currentLevel, chatModel); | ||
} else { | ||
// apply the defence detection for level 3 and sandbox - sets chatResponse obj | ||
await handleHigherLevelChat( | ||
req, | ||
message, | ||
chatHistoryBefore, | ||
chatResponse, | ||
currentLevel, | ||
chatModel | ||
); | ||
} | ||
} | ||
// if blocked, send the response | ||
if (!chatResponse.defenceInfo.isBlocked) { | ||
// transform the message according to active defences | ||
chatResponse.transformedMessage = transformMessage( | ||
message, | ||
req.session.levelState[currentLevel].defences | ||
); | ||
// if message has been transformed then add the original to chat history and send transformed to chatGPT | ||
const messageIsTransformed = | ||
chatResponse.transformedMessage !== message; | ||
if (messageIsTransformed) { | ||
// if the reply was blocked then add it to the chat history | ||
if (chatResponse.defenceInfo.isBlocked) { | ||
req.session.levelState[currentLevel].chatHistory.push({ | ||
completion: null, | ||
chatMessageType: CHAT_MESSAGE_TYPE.USER, | ||
infoMessage: message, | ||
chatMessageType: CHAT_MESSAGE_TYPE.BOT_BLOCKED, | ||
infoMessage: chatResponse.defenceInfo.blockedReason, | ||
}); | ||
} | ||
// use default model for levels | ||
const chatModel = | ||
currentLevel === LEVEL_NAMES.SANDBOX | ||
? req.session.chatModel | ||
: defaultChatModel; | ||
|
||
// get the chatGPT reply | ||
try { | ||
const openAiReply = await chatGptSendMessage( | ||
req.session.levelState[currentLevel].chatHistory, | ||
req.session.levelState[currentLevel].defences, | ||
chatModel, | ||
chatResponse.transformedMessage, | ||
messageIsTransformed, | ||
req.session.openAiApiKey, | ||
req.session.levelState[currentLevel].sentEmails, | ||
currentLevel | ||
); | ||
|
||
if (openAiReply) { | ||
chatResponse.wonLevel = openAiReply.wonLevel; | ||
chatResponse.reply = openAiReply.completion.content ?? ""; | ||
|
||
// combine triggered defences | ||
chatResponse.defenceInfo.triggeredDefences = [ | ||
...chatResponse.defenceInfo.triggeredDefences, | ||
...openAiReply.defenceInfo.triggeredDefences, | ||
]; | ||
// combine blocked | ||
chatResponse.defenceInfo.isBlocked = | ||
openAiReply.defenceInfo.isBlocked; | ||
|
||
// combine blocked reason | ||
chatResponse.defenceInfo.blockedReason = | ||
openAiReply.defenceInfo.blockedReason; | ||
} | ||
} catch (error) { | ||
res.statusCode = 500; | ||
console.log(error); | ||
if (error instanceof Error) { | ||
chatResponse.reply = "Failed to get chatGPT reply"; | ||
// enable next level when user wins current level | ||
if (chatResponse.wonLevel) { | ||
console.debug("Win conditon met for level: ", currentLevel); | ||
numLevelsCompleted = Math.max(numLevelsCompleted, currentLevel + 1); | ||
req.session.numLevelsCompleted = numLevelsCompleted; | ||
chatResponse.numLevelsCompleted = numLevelsCompleted; | ||
} | ||
} | ||
} else { | ||
handleChatError(res, chatResponse, true, "Missing message"); | ||
return; | ||
} | ||
} catch (error) { | ||
handleChatError(res, chatResponse, false, "Failed to get chatGPT reply"); | ||
return; | ||
} | ||
// log and send the reply with defence info | ||
console.log(chatResponse); | ||
res.send(chatResponse); | ||
} | ||
); | ||
|
||
// if the reply was blocked then add it to the chat history | ||
if (chatResponse.defenceInfo.isBlocked) { | ||
req.session.levelState[currentLevel].chatHistory.push({ | ||
completion: null, | ||
chatMessageType: CHAT_MESSAGE_TYPE.BOT_BLOCKED, | ||
infoMessage: chatResponse.defenceInfo.blockedReason, | ||
}); | ||
} | ||
// handle the chat logic for level 1 and 2 with no defences applied | ||
async function handleLowLevelChat( | ||
req: OpenAiChatRequest, | ||
chatResponse: ChatHttpResponse, | ||
currentLevel: LEVEL_NAMES, | ||
chatModel: ChatModel | ||
) { | ||
// get the chatGPT reply | ||
const openAiReply = await chatGptSendMessage( | ||
req.session.levelState[currentLevel].chatHistory, | ||
req.session.levelState[currentLevel].defences, | ||
chatModel, | ||
chatResponse.transformedMessage, | ||
false, | ||
req.session.openAiApiKey ?? "", | ||
req.session.levelState[currentLevel].sentEmails, | ||
currentLevel | ||
); | ||
chatResponse.reply = openAiReply?.completion.content ?? ""; | ||
chatResponse.wonLevel = openAiReply?.wonLevel ?? false; | ||
|
||
if (openAiReply instanceof Error) { | ||
throw openAiReply; | ||
} | ||
} | ||
|
||
// enable next level when user wins current level | ||
if (chatResponse.wonLevel) { | ||
console.log("Win conditon met for level: ", currentLevel); | ||
numLevelsCompleted = Math.max(numLevelsCompleted, currentLevel + 1); | ||
req.session.numLevelsCompleted = numLevelsCompleted; | ||
chatResponse.numLevelsCompleted = numLevelsCompleted; | ||
} | ||
} else { | ||
res.statusCode = 400; | ||
chatResponse.reply = "Missing message"; | ||
console.error(chatResponse.reply); | ||
// handle the chat logic for high levels (with defence detection) | ||
async function handleHigherLevelChat( | ||
req: OpenAiChatRequest, | ||
message: string, | ||
chatHistoryBefore: ChatHistoryMessage[], | ||
chatResponse: ChatHttpResponse, | ||
currentLevel: LEVEL_NAMES, | ||
chatModel: ChatModel | ||
) { | ||
let openAiReply = null; | ||
|
||
// transform the message according to active defences | ||
chatResponse.transformedMessage = transformMessage( | ||
message, | ||
req.session.levelState[currentLevel].defences | ||
); | ||
// if message has been transformed then add the original to chat history and send transformed to chatGPT | ||
const messageIsTransformed = chatResponse.transformedMessage !== message; | ||
if (messageIsTransformed) { | ||
req.session.levelState[currentLevel].chatHistory.push({ | ||
completion: null, | ||
chatMessageType: CHAT_MESSAGE_TYPE.USER, | ||
infoMessage: message, | ||
}); | ||
} | ||
// detect defences on input message | ||
const triggeredDefencesPromise = detectTriggeredDefences( | ||
message, | ||
req.session.levelState[currentLevel].defences, | ||
req.session.openAiApiKey ?? "" | ||
).then((defenceInfo) => { | ||
chatResponse.defenceInfo = defenceInfo; | ||
}); | ||
|
||
// get the chatGPT reply | ||
try { | ||
const openAiReplyPromise = chatGptSendMessage( | ||
req.session.levelState[currentLevel].chatHistory, | ||
req.session.levelState[currentLevel].defences, | ||
chatModel, | ||
chatResponse.transformedMessage, | ||
messageIsTransformed, | ||
req.session.openAiApiKey ?? "", | ||
req.session.levelState[currentLevel].sentEmails, | ||
currentLevel | ||
); | ||
|
||
// run defence detection and chatGPT concurrently | ||
const [, openAiReplyResolved] = await Promise.all([ | ||
triggeredDefencesPromise, | ||
openAiReplyPromise, | ||
]); | ||
openAiReply = openAiReplyResolved; | ||
|
||
// if input message is blocked, restore the original chat history and add user message (not as completion) | ||
if (chatResponse.defenceInfo.isBlocked) { | ||
// set to null to stop message being returned to user | ||
openAiReply = null; | ||
|
||
// restore the original chat history | ||
req.session.levelState[currentLevel].chatHistory = chatHistoryBefore; | ||
|
||
req.session.levelState[currentLevel].chatHistory.push({ | ||
completion: null, | ||
chatMessageType: CHAT_MESSAGE_TYPE.USER, | ||
infoMessage: message, | ||
}); | ||
} | ||
|
||
if (openAiReply) { | ||
chatResponse.wonLevel = openAiReply.wonLevel; | ||
chatResponse.reply = openAiReply.completion.content ?? ""; | ||
|
||
// combine triggered defences | ||
chatResponse.defenceInfo.triggeredDefences = [ | ||
...chatResponse.defenceInfo.triggeredDefences, | ||
...openAiReply.defenceInfo.triggeredDefences, | ||
]; | ||
// combine blocked | ||
chatResponse.defenceInfo.isBlocked = openAiReply.defenceInfo.isBlocked; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In general, I'm a little concerned about this mutable chatResponse object being passed around, and modified as a side-effect, as it feels like a difficult-to-detect bug waiting to emerge, particularly in functions with high complexity such as these. I am having difficulty seeing all the places it can be modified, and in what order. I'll add a future issue to tackle this complexity, and introduce immutability. |
||
|
||
// combine blocked reason | ||
chatResponse.defenceInfo.blockedReason = | ||
openAiReply.defenceInfo.blockedReason; | ||
} | ||
} catch (error) { | ||
if (error instanceof Error) { | ||
throw error; | ||
} | ||
} | ||
// log and send the reply with defence info | ||
console.log(chatResponse); | ||
res.send(chatResponse); | ||
}); | ||
} | ||
|
||
// get the chat history | ||
router.get("/openai/history", (req, res) => { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Destructuring is neater: