707 refactor of chatgptchatcompletion (#758)
pmarsh-scottlogic authored and chriswilty committed Apr 8, 2024
1 parent fe97049 commit c324b37
Showing 9 changed files with 418 additions and 494 deletions.
14 changes: 10 additions & 4 deletions backend/src/controller/chatController.ts
@@ -23,7 +23,10 @@ import { Defence } from '@src/models/defence';
 import { EmailInfo } from '@src/models/email';
 import { LEVEL_NAMES } from '@src/models/level';
 import { chatGptSendMessage } from '@src/openai';
-import { pushMessageToHistory } from '@src/utils/chat';
+import {
+	pushMessageToHistory,
+	setSystemRoleInChatHistory,
+} from '@src/utils/chat';
 
 import { handleChatError } from './handleError';
 
@@ -233,9 +236,12 @@ async function handleChatToGPT(req: OpenAiChatRequest, res: Response)
 		? req.session.chatModel
 		: defaultChatModel;
 
-	const currentChatHistory = [
-		...req.session.levelState[currentLevel].chatHistory,
-	];
+	const currentChatHistory = setSystemRoleInChatHistory(
+		currentLevel,
+		req.session.levelState[currentLevel].defences,
+		req.session.levelState[currentLevel].chatHistory
+	);
+
 	const defences = [...req.session.levelState[currentLevel].defences];
 
 	let levelResult: LevelHandlerResponse;
54 changes: 5 additions & 49 deletions backend/src/openai.ts
@@ -3,14 +3,9 @@ import {
 	ChatCompletionMessageParam,
 	ChatCompletionTool,
 	ChatCompletionMessageToolCall,
-	ChatCompletionSystemMessageParam,
 } from 'openai/resources/chat/completions';
 
-import {
-	isDefenceActive,
-	getSystemRole,
-	getQAPromptFromConfig,
-} from './defence';
+import { isDefenceActive, getQAPromptFromConfig } from './defence';
 import { sendEmail } from './email';
 import { queryDocuments } from './langchain';
 import {
@@ -250,48 +245,11 @@ async function chatGptCallFunction(
 
 async function chatGptChatCompletion(
 	chatHistory: ChatHistoryMessage[],
-	defences: Defence[],
 	chatModel: ChatModel,
-	openai: OpenAI,
-	// default to sandbox
-	currentLevel: LEVEL_NAMES = LEVEL_NAMES.SANDBOX
-): Promise<ChatGptReply> {
+	openai: OpenAI
+) {
 	const updatedChatHistory = [...chatHistory];
 
-	// check if we need to set a system role
-	// system role is always active on levels
-	if (
-		currentLevel !== LEVEL_NAMES.SANDBOX ||
-		isDefenceActive(DEFENCE_ID.SYSTEM_ROLE, defences)
-	) {
-		const completionConfig: ChatCompletionSystemMessageParam = {
-			role: 'system',
-			content: getSystemRole(defences, currentLevel),
-		};
-
-		// check to see if there's already a system role
-		const systemRole = chatHistory.find(
-			(message) => message.completion?.role === 'system'
-		);
-		if (!systemRole) {
-			// add the system role to the start of the chat history
-			updatedChatHistory.unshift({
-				completion: completionConfig,
-				chatMessageType: CHAT_MESSAGE_TYPE.SYSTEM,
-			});
-		} else {
-			// replace with the latest system role
-			systemRole.completion = completionConfig;
-		}
-	} else {
-		// remove the system role from the chat history
-		while (
-			updatedChatHistory.length > 0 &&
-			updatedChatHistory[0].completion?.role === 'system'
-		) {
-			updatedChatHistory.shift();
-		}
-	}
 	console.debug('Talking to model: ', JSON.stringify(chatModel));
 
 	// get start time
@@ -376,7 +334,6 @@ async function performToolCalls(
 ): Promise<ToolCallResponse> {
 	for (const toolCall of toolCalls) {
 		// only tool type supported by openai is function
-
 		// eslint-disable-next-line @typescript-eslint/no-unnecessary-condition
 		if (toolCall.type === 'function') {
 			const functionCallReply = await chatGptCallFunction(
@@ -385,6 +342,7 @@
 				toolCall.function,
 				currentLevel
 			);
+
 			// return after getting function reply. may change when we support other tool types. We assume only one function call in toolCalls
 			return {
 				functionCallReply,
@@ -417,10 +375,8 @@ async function getFinalReplyAfterAllToolCalls(
 	do {
 		gptReply = await chatGptChatCompletion(
 			updatedChatHistory,
-			defences,
 			chatModel,
-			openai,
-			currentLevel
+			openai
 		);
 		updatedChatHistory = gptReply.chatHistory;
 
51 changes: 49 additions & 2 deletions backend/src/utils/chat.ts
@@ -1,4 +1,9 @@
-import { ChatHistoryMessage } from '@src/models/chat';
+import { ChatCompletionSystemMessageParam } from 'openai/resources/chat/completions';
+
+import { getSystemRole, isDefenceActive } from '@src/defence';
+import { CHAT_MESSAGE_TYPE, ChatHistoryMessage } from '@src/models/chat';
+import { DEFENCE_ID, Defence } from '@src/models/defence';
+import { LEVEL_NAMES } from '@src/models/level';
 
 function pushMessageToHistory(
 	chatHistory: ChatHistoryMessage[],
@@ -21,4 +26,46 @@ function pushMessageToHistory(
 	return updatedChatHistory;
 }
 
-export { pushMessageToHistory };
+function setSystemRoleInChatHistory(
+	currentLevel: LEVEL_NAMES,
+	defences: Defence[],
+	chatHistory: ChatHistoryMessage[]
+) {
+	const systemRoleNeededInChatHistory =
+		currentLevel !== LEVEL_NAMES.SANDBOX ||
+		isDefenceActive(DEFENCE_ID.SYSTEM_ROLE, defences);
+
+	if (systemRoleNeededInChatHistory) {
+		const completionConfig: ChatCompletionSystemMessageParam = {
+			role: 'system',
+			content: getSystemRole(defences, currentLevel),
+		};
+
+		const existingSystemRole = chatHistory.find(
+			(message) => message.completion?.role === 'system'
+		);
+		if (!existingSystemRole) {
+			return [
+				{
+					completion: completionConfig,
+					chatMessageType: CHAT_MESSAGE_TYPE.SYSTEM,
+				},
+				...chatHistory,
+			];
+		} else {
+			return chatHistory.map((message) => {
+				if (message.completion?.role === 'system') {
+					return { ...existingSystemRole, completion: completionConfig };
+				} else {
+					return message;
+				}
+			});
+		}
+	} else {
+		return chatHistory.filter(
+			(message) => message.completion?.role !== 'system'
+		);
+	}
+}
+
+export { pushMessageToHistory, setSystemRoleInChatHistory };
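
For orientation, a minimal usage sketch of the extracted helper (not part of this commit): only the signature and behaviour of setSystemRoleInChatHistory are taken from the diff above, while the wrapper function and variable names are hypothetical.

import { setSystemRoleInChatHistory } from '@src/utils/chat';
import { ChatHistoryMessage } from '@src/models/chat';
import { Defence } from '@src/models/defence';
import { LEVEL_NAMES } from '@src/models/level';

// Hypothetical caller, mirroring how chatController.ts now prepares history
// before talking to the model.
function prepareChatHistory(
	currentLevel: LEVEL_NAMES,
	defences: Defence[],
	chatHistory: ChatHistoryMessage[]
) {
	// Prepends, refreshes or strips the leading system message depending on
	// the level and on whether the SYSTEM_ROLE defence is active.
	return setSystemRoleInChatHistory(currentLevel, defences, chatHistory);
}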