Skip to content

Commit

Permalink
828 streamline chat model configuration info message network call (#875)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmarsh-scottlogic authored Mar 28, 2024
1 parent 3c0cefb commit 57d2027
Show file tree
Hide file tree
Showing 12 changed files with 277 additions and 67 deletions.
53 changes: 45 additions & 8 deletions backend/src/controller/modelController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@ import { Response } from 'express';

import { OpenAiConfigureModelRequest } from '@src/models/api/OpenAiConfigureModelRequest';
import { OpenAiSetModelRequest } from '@src/models/api/OpenAiSetModelRequest';
import { MODEL_CONFIG } from '@src/models/chat';
import { MODEL_CONFIG_ID, modelConfigIds } from '@src/models/chat';
import { ChatInfoMessage } from '@src/models/chatMessage';
import { LEVEL_NAMES } from '@src/models/level';
import { pushMessageToHistory } from '@src/utils/chat';

import { sendErrorResponse } from './handleError';

function handleSetModel(req: OpenAiSetModelRequest, res: Response) {
const { model } = req.body;
Expand All @@ -19,16 +24,48 @@ function handleSetModel(req: OpenAiSetModelRequest, res: Response) {
}

function handleConfigureModel(req: OpenAiConfigureModelRequest, res: Response) {
const configId = req.body.configId as MODEL_CONFIG | undefined;
const configId = req.body.configId as MODEL_CONFIG_ID | undefined;
const value = req.body.value;
const maxValue = configId === MODEL_CONFIG.TOP_P ? 1 : 2;

if (configId && value !== undefined && value >= 0 && value <= maxValue) {
req.session.chatModel.configuration[configId] = value;
res.status(200).send();
} else {
res.status(400).send();
if (!configId) {
sendErrorResponse(res, 400, 'Missing configId');
return;
}

if (!modelConfigIds.includes(configId)) {
sendErrorResponse(res, 400, 'Invalid configId');
return;
}

if (!Number.isFinite(value) || value === undefined) {
sendErrorResponse(res, 400, 'Missing or invalid value');
return;
}

const maxValue = configId === 'topP' ? 1 : 2;

if (value < 0 || value > maxValue) {
sendErrorResponse(
res,
400,
`Value should be between 0 and ${maxValue} for ${configId}`
);
return;
}

req.session.chatModel.configuration[configId] = value;

const chatInfoMessage = {
infoMessage: `changed ${configId} to ${value}`,
chatMessageType: 'GENERIC_INFO',
} as ChatInfoMessage;
req.session.levelState[LEVEL_NAMES.SANDBOX].chatHistory =
pushMessageToHistory(
req.session.levelState[LEVEL_NAMES.SANDBOX].chatHistory,
chatInfoMessage
);

res.status(200).send({ chatInfoMessage });
}

export { handleSetModel, handleConfigureModel };
4 changes: 3 additions & 1 deletion backend/src/models/api/OpenAiConfigureModelRequest.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import { Request } from 'express';

import { ChatMessage } from '@src/models/chatMessage';

export type OpenAiConfigureModelRequest = Request<
never,
never,
null | { chatInfoMessage: ChatMessage },
{
configId?: string;
value?: number;
Expand Down
4 changes: 2 additions & 2 deletions backend/src/models/api/OpenAiSetModelRequest.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import { Request } from 'express';

import { CHAT_MODELS, ChatModelConfiguration } from '@src/models/chat';
import { CHAT_MODELS, ChatModelConfigurations } from '@src/models/chat';

export type OpenAiSetModelRequest = Request<
never,
never,
{
model?: CHAT_MODELS;
configuration?: ChatModelConfiguration;
configuration?: ChatModelConfigurations;
},
never
>;
36 changes: 18 additions & 18 deletions backend/src/models/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,23 @@ enum CHAT_MODELS {
GPT_3_5_TURBO_16K_0613 = 'gpt-3.5-turbo-16k-0613',
}

enum MODEL_CONFIG {
TEMPERATURE = 'temperature',
TOP_P = 'topP',
FREQUENCY_PENALTY = 'frequencyPenalty',
PRESENCE_PENALTY = 'presencePenalty',
}

interface ChatModel {
type ChatModel = {
id: CHAT_MODELS;
configuration: ChatModelConfiguration;
}
configuration: ChatModelConfigurations;
};

interface ChatModelConfiguration {
temperature: number;
topP: number;
frequencyPenalty: number;
presencePenalty: number;
}
const modelConfigIds = [
'temperature',
'topP',
'frequencyPenalty',
'presencePenalty',
] as const;

type MODEL_CONFIG_ID = (typeof modelConfigIds)[number];

type ChatModelConfigurations = {
[key in MODEL_CONFIG_ID]: number;
};

interface DefenceReport {
blockedReason: string | null;
Expand Down Expand Up @@ -121,7 +120,7 @@ export type {
ChatGptReply,
ChatMalicious,
ChatModel,
ChatModelConfiguration,
ChatModelConfigurations,
ChatResponse,
LevelHandlerResponse,
ChatHttpResponse,
Expand All @@ -130,5 +129,6 @@ export type {
ToolCallResponse,
MessageTransformation,
SingleDefenceReport,
MODEL_CONFIG_ID,
};
export { CHAT_MODELS, MODEL_CONFIG, defaultChatModel };
export { CHAT_MODELS, defaultChatModel, modelConfigIds };
156 changes: 156 additions & 0 deletions backend/test/unit/controller/modelController.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import { expect, jest, test, describe } from '@jest/globals';
import { Response } from 'express';

import { handleConfigureModel } from '@src/controller/modelController';
import { OpenAiConfigureModelRequest } from '@src/models/api/OpenAiConfigureModelRequest';
import { modelConfigIds } from '@src/models/chat';
import { ChatMessage } from '@src/models/chatMessage';
import { LEVEL_NAMES, LevelState } from '@src/models/level';

function responseMock() {
return {
send: jest.fn(),
status: jest.fn().mockReturnThis(),
} as unknown as Response;
}

describe('handleConfigureModel', () => {
test('WHEN passed sensible parameters THEN configures model AND adds info message to chat history AND responds with info message', () => {
const req = {
body: {
configId: 'topP',
value: 0.5,
},
session: {
chatModel: {
configuration: {
temperature: 0.0,
topP: 0.0,
presencePenalty: 0.0,
frequencyPenalty: 0.0,
},
},
levelState: [{}, {}, {}, { chatHistory: [] } as unknown as LevelState],
},
} as OpenAiConfigureModelRequest;
const res = responseMock();

handleConfigureModel(req, res);

expect(res.status).toHaveBeenCalledWith(200);
expect(req.session.chatModel.configuration.topP).toBe(0.5);

const expectedInfoMessage = {
infoMessage: 'changed topP to 0.5',
chatMessageType: 'GENERIC_INFO',
} as ChatMessage;
expect(
req.session.levelState[LEVEL_NAMES.SANDBOX].chatHistory.at(-1)
).toEqual(expectedInfoMessage);
expect(res.send).toHaveBeenCalledWith({
chatInfoMessage: expectedInfoMessage,
});
});

test('WHEN missing configId THEN does not configure model', () => {
const req = {
body: {
value: 0.5,
},
} as OpenAiConfigureModelRequest;
const res = responseMock();

handleConfigureModel(req, res);

expect(res.status).toHaveBeenCalledWith(400);
expect(res.send).toHaveBeenCalledWith('Missing configId');
});

test('WHEN configId is invalid THEN does not configure model', () => {
const req = {
body: {
configId: 'invalid config id',
value: 0.5,
},
} as OpenAiConfigureModelRequest;
const res = responseMock();

handleConfigureModel(req, res);

expect(res.status).toHaveBeenCalledWith(400);
expect(res.send).toHaveBeenCalledWith('Invalid configId');
});

test('WHEN value is missing THEN does not configure model', () => {
const req = {
body: {
configId: 'topP',
},
} as unknown as OpenAiConfigureModelRequest;
const res = responseMock();

handleConfigureModel(req, res);

expect(res.status).toHaveBeenCalledWith(400);
expect(res.send).toHaveBeenCalledWith('Missing or invalid value');
});

test('WHEN value is invalid THEN does not configure model', () => {
const req = {
body: {
configId: 'topP',
value: 'invalid value',
},
} as unknown as OpenAiConfigureModelRequest;
const res = responseMock();

handleConfigureModel(req, res);

expect(res.status).toHaveBeenCalledWith(400);
expect(res.send).toHaveBeenCalledWith('Missing or invalid value');
});

test.each(modelConfigIds)(
'GIVEN configId is %s WHEN value is below range THEN does not configure model',
(configId) => {
const req = {
body: {
configId,
value: -1,
},
} as unknown as OpenAiConfigureModelRequest;
const res = responseMock();

handleConfigureModel(req, res);

const expectedMaxValue = configId === 'topP' ? 1 : 2;

expect(res.status).toHaveBeenCalledWith(400);
expect(res.send).toHaveBeenCalledWith(
`Value should be between 0 and ${expectedMaxValue} for ${configId}`
);
}
);

test.each(modelConfigIds)(
'GIVEN configId is %s WHEN value is above range THEN does not configure model',
(configId) => {
const expectedMaxValue = configId === 'topP' ? 1 : 2;

const req = {
body: {
configId,
value: expectedMaxValue + 1,
},
} as unknown as OpenAiConfigureModelRequest;
const res = responseMock();

handleConfigureModel(req, res);

expect(res.status).toHaveBeenCalledWith(400);
expect(res.send).toHaveBeenCalledWith(
`Value should be between 0 and ${expectedMaxValue} for ${configId}`
);
}
);
});
5 changes: 4 additions & 1 deletion frontend/src/components/ControlPanel/ControlPanel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import DefenceBox from '@src/components/DefenceBox/DefenceBox';
import DocumentViewButton from '@src/components/DocumentViewer/DocumentViewButton';
import ModelBox from '@src/components/ModelBox/ModelBox';
import DetailElement from '@src/components/ThemedButtons/DetailElement';
import { ChatModel } from '@src/models/chat';
import { ChatMessage, ChatModel } from '@src/models/chat';
import {
DEFENCE_ID,
DefenceConfigItem,
Expand All @@ -25,6 +25,7 @@ function ControlPanel({
setDefenceConfiguration,
openDocumentViewer,
addInfoMessage,
addChatMessage,
}: {
currentLevel: LEVEL_NAMES;
defences: Defence[];
Expand All @@ -42,6 +43,7 @@ function ControlPanel({
) => Promise<boolean>;
openDocumentViewer: () => void;
addInfoMessage: (message: string) => void;
addChatMessage: (chatMessage: ChatMessage) => void;
}) {
const configurableDefences =
currentLevel === LEVEL_NAMES.SANDBOX
Expand Down Expand Up @@ -100,6 +102,7 @@ function ControlPanel({
setChatModelId={setChatModelId}
chatModelOptions={chatModelOptions}
addInfoMessage={addInfoMessage}
addChatMessage={addChatMessage}
/>
)}
</DetailElement>
Expand Down
1 change: 1 addition & 0 deletions frontend/src/components/MainComponent/MainBody.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ function MainBody({
setDefenceConfiguration={setDefenceConfiguration}
openDocumentViewer={openDocumentViewer}
addInfoMessage={addInfoMessage}
addChatMessage={addChatMessage}
/>
</div>
<div className="centre-area">
Expand Down
6 changes: 4 additions & 2 deletions frontend/src/components/ModelBox/ModelBox.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { ChatModel } from '@src/models/chat';
import { ChatMessage, ChatModel } from '@src/models/chat';

import ModelConfiguration from './ModelConfiguration';
import ModelSelection from './ModelSelection';
Expand All @@ -10,11 +10,13 @@ function ModelBox({
setChatModelId,
chatModelOptions,
addInfoMessage,
addChatMessage,
}: {
chatModel?: ChatModel;
setChatModelId: (modelId: string) => void;
chatModelOptions: string[];
addInfoMessage: (message: string) => void;
addChatMessage: (chatMessage: ChatMessage) => void;
}) {
return (
<div className="model-box">
Expand All @@ -26,7 +28,7 @@ function ModelBox({
/>
<ModelConfiguration
chatModel={chatModel}
addInfoMessage={addInfoMessage}
addChatMessage={addChatMessage}
/>
</div>
);
Expand Down
Loading

0 comments on commit 57d2027

Please sign in to comment.