diff --git a/backend/src/controller/levelController.ts b/backend/src/controller/levelController.ts index e7251537f..4b841ff73 100644 --- a/backend/src/controller/levelController.ts +++ b/backend/src/controller/levelController.ts @@ -1,7 +1,7 @@ import { Response } from 'express'; import { LevelGetRequest } from '@src/models/api/LevelGetRequest'; -import { isValidLevel } from '@src/models/level'; +import { LEVEL_NAMES, isValidLevel } from '@src/models/level'; function handleLoadLevel(req: LevelGetRequest, res: Response) { const { level } = req.query; @@ -20,6 +20,8 @@ function handleLoadLevel(req: LevelGetRequest, res: Response) { emails: req.session.levelState[level].sentEmails, chatHistory: req.session.levelState[level].chatHistory, defences: req.session.levelState[level].defences, + chatModel: + level === LEVEL_NAMES.SANDBOX ? req.session.chatModel : undefined, }); } diff --git a/backend/src/controller/modelController.ts b/backend/src/controller/modelController.ts index 08301113d..fd6bf2b7c 100644 --- a/backend/src/controller/modelController.ts +++ b/backend/src/controller/modelController.ts @@ -1,6 +1,5 @@ import { Response } from 'express'; -import { OpenAIGetModelRequest } from '@src/models/api/OpenAIGetModelRequest'; import { OpenAiConfigureModelRequest } from '@src/models/api/OpenAiConfigureModelRequest'; import { OpenAiSetModelRequest } from '@src/models/api/OpenAiSetModelRequest'; import { MODEL_CONFIG } from '@src/models/chat'; @@ -32,8 +31,4 @@ function handleConfigureModel(req: OpenAiConfigureModelRequest, res: Response) { } } -function handleGetModel(req: OpenAIGetModelRequest, res: Response) { - res.send(req.session.chatModel); -} - -export { handleSetModel, handleConfigureModel, handleGetModel }; +export { handleSetModel, handleConfigureModel }; diff --git a/backend/src/controller/startController.ts b/backend/src/controller/startController.ts index 8ce03d812..64f862ff6 100644 --- a/backend/src/controller/startController.ts +++ b/backend/src/controller/startController.ts @@ -39,6 +39,8 @@ function handleStart(req: StartGetRequest, res: Response) { defences: req.session.levelState[level].defences, availableModels: getValidOpenAIModels(), systemRoles, + chatModel: + level === LEVEL_NAMES.SANDBOX ? req.session.chatModel : undefined, } as StartResponse); } diff --git a/backend/src/models/api/LevelGetRequest.ts b/backend/src/models/api/LevelGetRequest.ts index 70326d143..1d414aa2f 100644 --- a/backend/src/models/api/LevelGetRequest.ts +++ b/backend/src/models/api/LevelGetRequest.ts @@ -1,5 +1,6 @@ import { Request } from 'express'; +import { ChatModel } from '@src/models/chat'; import { ChatMessage } from '@src/models/chatMessage'; import { Defence } from '@src/models/defence'; import { EmailInfo } from '@src/models/email'; @@ -11,6 +12,7 @@ export type LevelGetRequest = Request< emails: EmailInfo[]; chatHistory: ChatMessage[]; defences: Defence[]; + chatModel?: ChatModel; }, never, { diff --git a/backend/src/models/api/OpenAIGetModelRequest.ts b/backend/src/models/api/OpenAIGetModelRequest.ts deleted file mode 100644 index 3005b2472..000000000 --- a/backend/src/models/api/OpenAIGetModelRequest.ts +++ /dev/null @@ -1,13 +0,0 @@ -import { Request } from 'express'; - -import { ChatModel } from '@src/models/chat'; -import { LEVEL_NAMES } from '@src/models/level'; - -export type OpenAIGetModelRequest = Request< - never, - ChatModel | string, - never, - { - level?: LEVEL_NAMES; - } ->; diff --git a/backend/src/models/api/StartGetRequest.ts b/backend/src/models/api/StartGetRequest.ts index 344c12adf..d3ba89143 100644 --- a/backend/src/models/api/StartGetRequest.ts +++ b/backend/src/models/api/StartGetRequest.ts @@ -1,5 +1,6 @@ import { Request } from 'express'; +import { ChatModel } from '@src/models/chat'; import { ChatMessage } from '@src/models/chatMessage'; import { Defence } from '@src/models/defence'; import { EmailInfo } from '@src/models/email'; @@ -14,6 +15,7 @@ export type StartResponse = { level: LEVEL_NAMES; systemRole: string; }[]; + chatModel?: ChatModel; }; export type StartGetRequest = Request< diff --git a/backend/src/sessionRoutes.ts b/backend/src/sessionRoutes.ts index ef01632dc..5821f0a4b 100644 --- a/backend/src/sessionRoutes.ts +++ b/backend/src/sessionRoutes.ts @@ -18,7 +18,6 @@ import { handleClearEmails } from './controller/emailController'; import { handleLoadLevel } from './controller/levelController'; import { handleConfigureModel, - handleGetModel, handleSetModel, } from './controller/modelController'; import { handleResetProgress } from './controller/resetController'; @@ -102,7 +101,6 @@ router.post('/openai/addInfoToHistory', handleAddInfoToChatHistory); router.post('/openai/clear', handleClearChatHistory); // model configurations -router.get('/openai/model', handleGetModel); router.post('/openai/model', handleSetModel); router.post('/openai/model/configure', handleConfigureModel); diff --git a/frontend/src/components/ControlPanel/ControlPanel.tsx b/frontend/src/components/ControlPanel/ControlPanel.tsx index 3f4d8a493..9719bcd25 100644 --- a/frontend/src/components/ControlPanel/ControlPanel.tsx +++ b/frontend/src/components/ControlPanel/ControlPanel.tsx @@ -2,6 +2,7 @@ import { DEFENCES_HIDDEN_LEVEL3_IDS, MODEL_DEFENCES } from '@src/Defences'; import DefenceBox from '@src/components/DefenceBox/DefenceBox'; import DocumentViewButton from '@src/components/DocumentViewer/DocumentViewButton'; import ModelBox from '@src/components/ModelBox/ModelBox'; +import { ChatModel } from '@src/models/chat'; import { DEFENCE_ID, DefenceConfigItem, @@ -15,6 +16,8 @@ import './ControlPanel.css'; function ControlPanel({ currentLevel, defences, + chatModel, + setChatModelId, chatModelOptions, toggleDefence, resetDefenceConfiguration, @@ -24,6 +27,8 @@ function ControlPanel({ }: { currentLevel: LEVEL_NAMES; defences: Defence[]; + chatModel?: ChatModel; + setChatModelId: (modelId: string) => void; chatModelOptions: string[]; toggleDefence: (defence: Defence) => void; resetDefenceConfiguration: ( @@ -94,6 +99,8 @@ function ControlPanel({ {/* only show model box in sandbox mode */} {showConfigurations && ( diff --git a/frontend/src/components/MainComponent/MainBody.tsx b/frontend/src/components/MainComponent/MainBody.tsx index 150c1f142..70e937bc0 100644 --- a/frontend/src/components/MainComponent/MainBody.tsx +++ b/frontend/src/components/MainComponent/MainBody.tsx @@ -1,7 +1,7 @@ import ChatBox from '@src/components/ChatBox/ChatBox'; import ControlPanel from '@src/components/ControlPanel/ControlPanel'; import EmailBox from '@src/components/EmailBox/EmailBox'; -import { ChatMessage } from '@src/models/chat'; +import { ChatMessage, ChatModel } from '@src/models/chat'; import { DEFENCE_ID, DefenceConfigItem, @@ -18,7 +18,9 @@ function MainBody({ defences, emails, messages, + chatModel, chatModels, + setChatModelId, addChatMessage, addInfoMessage, addSentEmails, @@ -34,6 +36,8 @@ function MainBody({ defences: Defence[]; emails: EmailInfo[]; messages: ChatMessage[]; + chatModel?: ChatModel; + setChatModelId: (modelId: string) => void; chatModels: string[]; addChatMessage: (message: ChatMessage) => void; addInfoMessage: (message: string) => void; @@ -59,6 +63,8 @@ function MainBody({ ([]); const [chatModels, setChatModels] = useState([]); const [systemRoles, setSystemRoles] = useState([]); + const [chatModel, setChatModel] = useState(undefined); const isFirstRender = useRef(true); @@ -81,11 +82,23 @@ function MainComponent({ async function loadBackendData() { try { - const { availableModels, defences, emails, chatHistory, systemRoles } = - await startService.start(currentLevel); + const { + availableModels, + defences, + emails, + chatHistory, + systemRoles, + chatModel, + } = await startService.start(currentLevel); setChatModels(availableModels); setSystemRoles(systemRoles); - processBackendLevelData(currentLevel, emails, chatHistory, defences); + processBackendLevelData( + currentLevel, + emails, + chatHistory, + defences, + chatModel + ); } catch (err) { console.warn(err); setMessages([ @@ -152,17 +165,17 @@ function MainComponent({ // for going switching level without clearing progress async function setNewLevel(newLevel: LEVEL_NAMES) { - const { emails, chatHistory, defences } = await levelService.loadLevel( - newLevel - ); - processBackendLevelData(newLevel, emails, chatHistory, defences); + const { emails, chatHistory, defences, chatModel } = + await levelService.loadLevel(newLevel); + processBackendLevelData(newLevel, emails, chatHistory, defences, chatModel); } function processBackendLevelData( level: LEVEL_NAMES, emails: EmailInfo[], chatHistory: ChatMessage[], - defences: Defence[] + defences: Defence[], + chatModel?: ChatModel ) { setEmails(emails); @@ -172,6 +185,9 @@ function MainComponent({ : setMessages(chatHistory); setDefences(defences); + + // we will only update the chatModel if it is defined in the backend response. It will only defined for the sandbox level. + setChatModel(chatModel); setMainBodyKey(MainBodyKey + 1); } @@ -289,6 +305,16 @@ function MainComponent({ setMessages((messages: ChatMessage[]) => [resetMessage, ...messages]); } + function setChatModelId(modelId: string) { + if (!chatModel) { + console.error( + 'You are trying to change the id of the chatModel but it has not been loaded yet' + ); + return; + } + setChatModel({ ...chatModel, id: modelId }); + } + return (
void; chatModelOptions: string[]; addInfoMessage: (message: string) => void; }) { return (
- +
); } diff --git a/frontend/src/components/ModelBox/ModelConfiguration.tsx b/frontend/src/components/ModelBox/ModelConfiguration.tsx index 76d83aa3d..149e29335 100644 --- a/frontend/src/components/ModelBox/ModelConfiguration.tsx +++ b/frontend/src/components/ModelBox/ModelConfiguration.tsx @@ -1,53 +1,60 @@ import { useEffect, useState } from 'react'; -import { CustomChatModelConfiguration, MODEL_CONFIG } from '@src/models/chat'; +import { + ChatModel, + CustomChatModelConfiguration, + MODEL_CONFIG, +} from '@src/models/chat'; import { chatService } from '@src/service'; import ModelConfigurationSlider from './ModelConfigurationSlider'; import './ModelConfiguration.css'; +const DEFAULT_CONFIGS: CustomChatModelConfiguration[] = [ + { + id: MODEL_CONFIG.TEMPERATURE, + name: 'Model Temperature', + info: 'Controls the randomness of the model. Lower means more deterministic, higher means more surprising. Default is 1.', + value: 1, + min: 0, + max: 2, + }, + { + id: MODEL_CONFIG.TOP_P, + name: 'Top P', + info: 'Controls how many different words or phrases the language model considers when it’s trying to predict the next word. Default is 1. ', + value: 1, + min: 0, + max: 1, + }, + { + id: MODEL_CONFIG.PRESENCE_PENALTY, + name: 'Presence Penalty', + info: 'Controls diversity of text generation. Higher presence penalty increases likelihood of using new words. Default is 0.', + value: 0, + min: 0, + max: 2, + }, + { + id: MODEL_CONFIG.FREQUENCY_PENALTY, + name: 'Frequency Penalty', + info: 'Controls diversity of text generation. Higher frequency penalty decreases likelihood of using the same words. Default is 0.', + value: 0, + min: 0, + max: 2, + }, +]; + function ModelConfiguration({ + chatModel, addInfoMessage, }: { + chatModel?: ChatModel; addInfoMessage: (message: string) => void; }) { - const [customChatModelConfigs, setCustomChatModel] = useState< - CustomChatModelConfiguration[] - >([ - { - id: MODEL_CONFIG.TEMPERATURE, - name: 'Model Temperature', - info: 'Controls the randomness of the model. Lower means more deterministic, higher means more surprising. Default is 1.', - value: 1, - min: 0, - max: 2, - }, - { - id: MODEL_CONFIG.TOP_P, - name: 'Top P', - info: 'Controls how many different words or phrases the language model considers when it’s trying to predict the next word. Default is 1. ', - value: 1, - min: 0, - max: 1, - }, - { - id: MODEL_CONFIG.PRESENCE_PENALTY, - name: 'Presence Penalty', - info: 'Controls diversity of text generation. Higher presence penalty increases likelihood of using new words. Default is 0.', - value: 0, - min: 0, - max: 2, - }, - { - id: MODEL_CONFIG.FREQUENCY_PENALTY, - name: 'Frequency Penalty', - info: 'Controls diversity of text generation. Higher frequency penalty decreases likelihood of using the same words. Default is 0.', - value: 0, - min: 0, - max: 2, - }, - ]); + const [customChatModelConfigs, setCustomChatModel] = + useState(DEFAULT_CONFIGS); function setCustomChatModelByID(id: MODEL_CONFIG, value: number) { setCustomChatModel((prev) => { @@ -79,25 +86,18 @@ function ModelConfiguration({ }); } - // get model configs on mount useEffect(() => { - chatService - .getGptModel() - .then((model) => { - // apply the currently set values - const newCustomChatModelConfigs = customChatModelConfigs.map( - (config) => { - const newConfig = { ...config }; - newConfig.value = model.configuration[config.id]; - return newConfig; - } - ); - setCustomChatModel(newCustomChatModelConfigs); - }) - .catch((err) => { - console.log(err); - }); - }, []); + if (!chatModel) { + // chatModel is undefined if this is the first time that the user has switched to the sandbox level + // and the change level request has not yet resolved successfully + return; + } + const newCustomChatModelConfigs = customChatModelConfigs.map((config) => ({ + ...config, + value: chatModel.configuration[config.id], + })); + setCustomChatModel(newCustomChatModelConfigs); + }, [chatModel]); return (
diff --git a/frontend/src/components/ModelBox/ModelSelection.tsx b/frontend/src/components/ModelBox/ModelSelection.tsx index 8740548ef..019a6418c 100644 --- a/frontend/src/components/ModelBox/ModelSelection.tsx +++ b/frontend/src/components/ModelBox/ModelSelection.tsx @@ -2,22 +2,27 @@ import { useEffect, useState } from 'react'; import LoadingButton from '@src/components/ThemedButtons/LoadingButton'; +import { ChatModel } from '@src/models/chat'; import { chatService } from '@src/service'; import './ModelSelection.css'; // return a drop down menu with the models function ModelSelection({ + chatModel, + setChatModelId, chatModelOptions, addInfoMessage, }: { + chatModel?: ChatModel; + setChatModelId: (modelId: string) => void; chatModelOptions: string[]; addInfoMessage: (message: string) => void; }) { // model currently selected in the dropdown - const [selectedModel, setSelectedModel] = useState(null); - // model in use by the app - const [modelInUse, setModelInUse] = useState(null); + const [selectedModel, setSelectedModel] = useState( + undefined + ); const [errorChangingModel, setErrorChangingModel] = useState(false); @@ -32,9 +37,9 @@ function ModelSelection({ const modelUpdated = await chatService.setGptModel(currentSelectedModel); setIsSettingModel(false); if (modelUpdated) { - setModelInUse(currentSelectedModel); setErrorChangingModel(false); addInfoMessage(`changed model to ${currentSelectedModel}`); + setChatModelId(currentSelectedModel); } else { setErrorChangingModel(true); } @@ -43,67 +48,58 @@ function ModelSelection({ // get the model useEffect(() => { - chatService - .getGptModel() - .then((model) => { - setModelInUse(model.id); - // default the dropdown selection to the model in use - setSelectedModel(model.id); - }) - .catch((err) => { - console.log(err); - }); - }, []); + setSelectedModel(chatModel?.id); + }, [chatModel]); // return a drop down menu with the models return (
Select Model -
-
- - void submitSelectedModel()} - isLoading={isSettingModel} - loadingTooltip="Changing model..." - > - Choose - -
-
+ {chatModel ? ( + <> +
+
+ + void submitSelectedModel()} + isLoading={isSettingModel} + loadingTooltip="Changing model..." + > + Choose + +
+
-
- {errorChangingModel ? ( -

- Error: Could not change model. You are still chatting to: - {modelInUse} -

- ) : ( -

- {modelInUse ? ( - <> - You are chatting to model: {modelInUse} - +

+ {errorChangingModel ? ( +

+ Error: Could not change model. You are still chatting to: + {chatModel.id} +

) : ( - 'You are not connected to a model.' +

+ You are chatting to model: {chatModel.id} +

)} -

- )} -
+
+ + ) : ( +

Loading chatModel...

+ )}
); diff --git a/frontend/src/models/combined.ts b/frontend/src/models/combined.ts index c77672b20..d70589842 100644 --- a/frontend/src/models/combined.ts +++ b/frontend/src/models/combined.ts @@ -1,4 +1,4 @@ -import { ChatMessageDTO } from './chat'; +import { ChatMessageDTO, ChatModel } from './chat'; import { DefenceDTO } from './defence'; import { EmailInfo } from './email'; import { LevelSystemRole } from './level'; @@ -9,12 +9,14 @@ type StartReponse = { defences?: DefenceDTO[]; availableModels: string[]; systemRoles: LevelSystemRole[]; + chatModel?: ChatModel; }; type LoadLevelResponse = { emails: EmailInfo[]; chatHistory: ChatMessageDTO[]; defences?: DefenceDTO[]; + chatModel?: ChatModel; }; export type { StartReponse, LoadLevelResponse }; diff --git a/frontend/src/service/chatService.ts b/frontend/src/service/chatService.ts index 2ed55c7b0..97aaaeb08 100644 --- a/frontend/src/service/chatService.ts +++ b/frontend/src/service/chatService.ts @@ -2,7 +2,6 @@ import { CHAT_MESSAGE_TYPE, ChatMessageDTO, ChatMessage, - ChatModel, ChatResponse, MODEL_CONFIG, } from '@src/models/chat'; @@ -87,11 +86,6 @@ async function configureGptModel( return response.status === 200; } -async function getGptModel(): Promise { - const response = await sendRequest(`${PATH}model`, { method: 'GET' }); - return (await response.json()) as ChatModel; -} - async function addInfoMessageToChatHistory( message: string, chatMessageType: CHAT_MESSAGE_TYPE, @@ -113,7 +107,6 @@ export { clearChat, sendMessage, configureGptModel, - getGptModel, setGptModel, addInfoMessageToChatHistory, getChatMessagesFromDTOResponse, diff --git a/frontend/src/service/levelService.ts b/frontend/src/service/levelService.ts index d376d4d98..4b8619759 100644 --- a/frontend/src/service/levelService.ts +++ b/frontend/src/service/levelService.ts @@ -10,13 +10,14 @@ async function loadLevel(level: number) { const response = await sendRequest(`${PATH}?level=${level}`, { method: 'GET', }); - const { defences, emails, chatHistory } = + const { defences, emails, chatHistory, chatModel } = (await response.json()) as LoadLevelResponse; return { emails, chatHistory: getChatMessagesFromDTOResponse(chatHistory), defences: defences ? getDefencesFromDTOs(defences) : [], + chatModel, }; } diff --git a/frontend/src/service/startService.ts b/frontend/src/service/startService.ts index 089764d15..55e5fb718 100644 --- a/frontend/src/service/startService.ts +++ b/frontend/src/service/startService.ts @@ -10,8 +10,14 @@ async function start(level: number) { const response = await sendRequest(`${PATH}?level=${level}`, { method: 'GET', }); - const { availableModels, defences, emails, chatHistory, systemRoles } = - (await response.json()) as StartReponse; + const { + availableModels, + defences, + emails, + chatHistory, + systemRoles, + chatModel, + } = (await response.json()) as StartReponse; return { emails, @@ -19,6 +25,7 @@ async function start(level: number) { defences: defences ? getDefencesFromDTOs(defences) : [], availableModels, systemRoles, + chatModel, }; }