Skip to content

Commit

Permalink
741 fix backend chat history message types (#803)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmarsh-scottlogic authored and chriswilty committed Apr 8, 2024
1 parent 6c0a831 commit 6789736
Show file tree
Hide file tree
Showing 27 changed files with 610 additions and 507 deletions.
97 changes: 53 additions & 44 deletions backend/src/controller/chatController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,23 @@ import {
detectTriggeredInputDefences,
detectTriggeredOutputDefences,
} from '@src/defence';
import { OpenAiAddHistoryRequest } from '@src/models/api/OpenAiAddHistoryRequest';
import { OpenAiAddInfoToChatHistoryRequest } from '@src/models/api/OpenAiAddInfoToChatHistoryRequest';
import { OpenAiChatRequest } from '@src/models/api/OpenAiChatRequest';
import { OpenAiClearRequest } from '@src/models/api/OpenAiClearRequest';
import { OpenAiGetHistoryRequest } from '@src/models/api/OpenAiGetHistoryRequest';
import {
CHAT_MESSAGE_TYPE,
ChatDefenceReport,
ChatHistoryMessage,
ChatHttpResponse,
ChatModel,
LevelHandlerResponse,
MessageTransformation,
defaultChatModel,
} from '@src/models/chat';
import {
ChatMessage,
ChatInfoMessage,
chatInfoMessageType,
} from '@src/models/chatMessage';
import { Defence } from '@src/models/defence';
import { EmailInfo } from '@src/models/email';
import { LEVEL_NAMES } from '@src/models/level';
Expand Down Expand Up @@ -46,38 +49,45 @@ function combineChatDefenceReports(

function createNewUserMessages(
message: string,
messageTransformation?: MessageTransformation
): ChatHistoryMessage[] {
messageTransformation?: MessageTransformation,
createAs: 'completion' | 'info' = 'completion'
): ChatMessage[] {
if (messageTransformation) {
return [
{
completion: null,
chatMessageType: CHAT_MESSAGE_TYPE.USER,
chatMessageType: 'USER',
infoMessage: message,
},
{
completion: null,
chatMessageType: CHAT_MESSAGE_TYPE.INFO,
chatMessageType: 'GENERIC_INFO',
infoMessage: messageTransformation.transformedMessageInfo,
},
{
completion: {
role: 'user',
content: messageTransformation.transformedMessageCombined,
},
chatMessageType: CHAT_MESSAGE_TYPE.USER_TRANSFORMED,
completion:
createAs === 'completion'
? {
role: 'user',
content: messageTransformation.transformedMessageCombined,
}
: undefined,
chatMessageType: 'USER_TRANSFORMED',
transformedMessage: messageTransformation.transformedMessage,
},
];
} else {
return [
{
completion: {
role: 'user',
content: message,
},
chatMessageType: CHAT_MESSAGE_TYPE.USER,
},
createAs === 'completion'
? {
completion: {
role: 'user',
content: message,
},
chatMessageType: 'USER',
}
: {
chatMessageType: 'USER',
infoMessage: message,
},
];
}
}
Expand All @@ -87,7 +97,7 @@ async function handleChatWithoutDefenceDetection(
chatResponse: ChatHttpResponse,
currentLevel: LEVEL_NAMES,
chatModel: ChatModel,
chatHistory: ChatHistoryMessage[],
chatHistory: ChatMessage[],
defences: Defence[]
): Promise<LevelHandlerResponse> {
const updatedChatHistory = createNewUserMessages(message).reduce(
Expand Down Expand Up @@ -122,7 +132,7 @@ async function handleChatWithDefenceDetection(
chatResponse: ChatHttpResponse,
currentLevel: LEVEL_NAMES,
chatModel: ChatModel,
chatHistory: ChatHistoryMessage[],
chatHistory: ChatMessage[],
defences: Defence[]
): Promise<LevelHandlerResponse> {
const messageTransformation = transformMessage(message, defences);
Expand Down Expand Up @@ -162,11 +172,10 @@ async function handleChatWithDefenceDetection(

// if blocked, restore original chat history and add user message to chat history without completion
const updatedChatHistory = combinedDefenceReport.isBlocked
? pushMessageToHistory(chatHistory, {
completion: null,
chatMessageType: CHAT_MESSAGE_TYPE.USER,
infoMessage: message,
})
? createNewUserMessages(message, messageTransformation, 'info').reduce(
pushMessageToHistory,
chatHistory
)
: openAiReply.chatHistory;

const updatedChatResponse: ChatHttpResponse = {
Expand Down Expand Up @@ -284,9 +293,10 @@ async function handleChatToGPT(req: OpenAiChatRequest, res: Response) {
if (updatedChatResponse.defenceReport.isBlocked) {
// chatReponse.reply is empty if blocked
updatedChatHistory = pushMessageToHistory(updatedChatHistory, {
completion: null,
chatMessageType: CHAT_MESSAGE_TYPE.BOT_BLOCKED,
infoMessage: updatedChatResponse.defenceReport.blockedReason,
chatMessageType: 'BOT_BLOCKED',
infoMessage:
updatedChatResponse.defenceReport.blockedReason ??
'block reason unknown',
});
} else if (updatedChatResponse.openAIErrorMessage) {
const errorMsg = simplifyOpenAIErrorMessage(
Expand All @@ -307,13 +317,12 @@ async function handleChatToGPT(req: OpenAiChatRequest, res: Response) {
handleChatError(res, updatedChatResponse, errorMsg, 500);
return;
} else {
// add bot message to chat history
updatedChatHistory = pushMessageToHistory(updatedChatHistory, {
completion: {
role: 'assistant',
content: updatedChatResponse.reply,
},
chatMessageType: CHAT_MESSAGE_TYPE.BOT,
chatMessageType: 'BOT',
});
}

Expand All @@ -339,13 +348,12 @@ function simplifyOpenAIErrorMessage(openAIErrorMessage: string) {
}

function addErrorToChatHistory(
chatHistory: ChatHistoryMessage[],
chatHistory: ChatMessage[],
errorMessage: string
): ChatHistoryMessage[] {
): ChatMessage[] {
console.error(errorMessage);
return pushMessageToHistory(chatHistory, {
completion: null,
chatMessageType: CHAT_MESSAGE_TYPE.ERROR_MSG,
chatMessageType: 'ERROR_MSG',
infoMessage: errorMessage,
});
}
Expand All @@ -360,23 +368,24 @@ function handleGetChatHistory(req: OpenAiGetHistoryRequest, res: Response) {
}
}

function handleAddToChatHistory(req: OpenAiAddHistoryRequest, res: Response) {
const infoMessage = req.body.message;
const chatMessageType = req.body.chatMessageType;
const level = req.body.level;
function handleAddInfoToChatHistory(
req: OpenAiAddInfoToChatHistoryRequest,
res: Response
) {
const { infoMessage, chatMessageType, level } = req.body;
if (
infoMessage &&
chatMessageType &&
chatInfoMessageType.includes(chatMessageType) &&
level !== undefined &&
level >= LEVEL_NAMES.LEVEL_1
) {
req.session.levelState[level].chatHistory = pushMessageToHistory(
req.session.levelState[level].chatHistory,
{
completion: null,
chatMessageType,
infoMessage,
}
} as ChatInfoMessage
);
res.send();
} else {
Expand All @@ -400,6 +409,6 @@ function handleClearChatHistory(req: OpenAiClearRequest, res: Response) {
export {
handleChatToGPT,
handleGetChatHistory,
handleAddToChatHistory,
handleAddInfoToChatHistory as handleAddInfoToChatHistory,
handleClearChatHistory,
};
16 changes: 0 additions & 16 deletions backend/src/models/api/OpenAiAddHistoryRequest.ts

This file was deleted.

16 changes: 16 additions & 0 deletions backend/src/models/api/OpenAiAddInfoToChatHistoryRequest.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import { Request } from 'express';

import { CHAT_INFO_MESSAGE_TYPES } from '@src/models/chatMessage';
import { LEVEL_NAMES } from '@src/models/level';

export type OpenAiAddInfoToChatHistoryRequest = Request<
never,
never,
{
chatMessageType?: CHAT_INFO_MESSAGE_TYPES;
infoMessage?: string;
level?: LEVEL_NAMES;
},
never,
never
>;
4 changes: 2 additions & 2 deletions backend/src/models/api/OpenAiGetHistoryRequest.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import { Request } from 'express';

import { ChatHistoryMessage } from '@src/models/chat';
import { ChatMessage } from '@src/models/chatMessage';

export type OpenAiGetHistoryRequest = Request<
never,
ChatHistoryMessage[] | string,
ChatMessage[] | string,
never,
{
level?: string;
Expand Down
39 changes: 8 additions & 31 deletions backend/src/models/chat.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import {
ChatCompletionMessage,
ChatCompletionAssistantMessageParam,
ChatCompletionMessageParam,
} from 'openai/resources/chat/completions';

import { ChatMessage } from './chatMessage';
import { DEFENCE_ID } from './defence';
import { EmailInfo } from './email';

Expand All @@ -16,21 +17,6 @@ enum CHAT_MODELS {
GPT_3_5_TURBO_16K_0613 = 'gpt-3.5-turbo-16k-0613',
}

enum CHAT_MESSAGE_TYPE {
BOT,
BOT_BLOCKED,
INFO,
USER,
USER_TRANSFORMED,
LEVEL_INFO,
DEFENCE_ALERTED,
DEFENCE_TRIGGERED,
SYSTEM,
FUNCTION_CALL,
ERROR_MSG,
RESET_LEVEL,
}

enum MODEL_CONFIG {
TEMPERATURE = 'temperature',
TOP_P = 'topP',
Expand Down Expand Up @@ -72,7 +58,7 @@ interface FunctionCallResponse {
interface ToolCallResponse {
functionCallReply?: FunctionCallResponse;
chatResponse?: ChatResponse;
chatHistory: ChatHistoryMessage[];
chatHistory: ChatMessage[];
}

interface ChatAnswer {
Expand All @@ -92,8 +78,8 @@ interface ChatResponse {
}

interface ChatGptReply {
chatHistory: ChatHistoryMessage[];
completion: ChatCompletionMessage | null;
chatHistory: ChatMessage[];
completion: ChatCompletionAssistantMessageParam | null;
openAIErrorMessage: string | null;
}

Expand Down Expand Up @@ -123,17 +109,9 @@ interface ChatHttpResponse {

interface LevelHandlerResponse {
chatResponse: ChatHttpResponse;
chatHistory: ChatHistoryMessage[];
}

interface ChatHistoryMessage {
completion: ChatCompletionMessageParam | null;
chatMessageType: CHAT_MESSAGE_TYPE;
infoMessage?: string | null;
transformedMessage?: TransformedChatMessage;
chatHistory: ChatMessage[];
}

// default settings for chat model
const defaultChatModel: ChatModel = {
id: CHAT_MODELS.GPT_3_5_TURBO,
configuration: {
Expand All @@ -154,11 +132,10 @@ export type {
ChatResponse,
LevelHandlerResponse,
ChatHttpResponse,
ChatHistoryMessage,
SingleDefenceReport,
TransformedChatMessage,
FunctionCallResponse,
ToolCallResponse,
MessageTransformation,
SingleDefenceReport,
};
export { CHAT_MODELS, CHAT_MESSAGE_TYPE, MODEL_CONFIG, defaultChatModel };
export { CHAT_MODELS, MODEL_CONFIG, defaultChatModel };
Loading

0 comments on commit 6789736

Please sign in to comment.