Skip to content

Commit

Permalink
Integrate chat with completions and history
Browse files Browse the repository at this point in the history
  • Loading branch information
csaroff committed May 3, 2023
1 parent ce8fbde commit 2062153
Show file tree
Hide file tree
Showing 10 changed files with 504 additions and 479 deletions.
1 change: 1 addition & 0 deletions app/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
"match-sorter": "^6.3.1",
"react": "^18.2.0",
"react-dom": "^18.2.0",
"react-hotkeys-hook": "^4.4.0",
"react-markdown": "^8.0.6",
"react-responsive": "^9.0.2",
"react-router-dom": "^6.8.1",
Expand Down
104 changes: 54 additions & 50 deletions app/src/app.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ const DEFAULT_CONTEXTS = {
showParametersTable: false
}
},
chat:{
history: DEFAULT_HISTORY_STATE,
editor: DEFAULT_EDITOR_STATE,
modelsState: [],
parameters: DEFAULT_PARAMETERS_STATE
},
},
MODELS: [],
}
Expand Down Expand Up @@ -155,7 +161,7 @@ const APIContextWrapper = ({children}) => {

useEffect(() => {
const sse_request = new SSE("/api/notifications")

sse_request.addEventListener("notification", (event: any) => {
const parsedEvent = JSON.parse(event.data);
notificationSubscribers.current.forEach((callback) => {
Expand All @@ -180,15 +186,15 @@ const APIContextWrapper = ({children}) => {
notificationSubscribers.current = notificationSubscribers.current.filter((cb) => cb !== callback);
},
};

const Provider = {
setAPIKey: async (provider, apiKey) => (await fetch(`/api/provider/${provider}/api-key`, {method: "PUT", headers: {"Content-Type": "application/json"},
setAPIKey: async (provider, apiKey) => (await fetch(`/api/provider/${provider}/api-key`, {method: "PUT", headers: {"Content-Type": "application/json"},
body: JSON.stringify({apiKey: apiKey})}
)).json(),
getAll: async () => (await fetch("/api/providers")).json(),
getAllWithModels: async () => (await fetch("/api/providers-with-key-and-models")).json(),
};

const Inference = {
subscribeTextCompletion: (callback) => {
textCompletionSubscribers.current.push(callback);
Expand All @@ -205,14 +211,14 @@ const APIContextWrapper = ({children}) => {
},
chatCompletion: createChatCompletionRequest,
};

const [apiContext, _] = React.useState({
Model,
Notifications,
Provider,
Inference,
});

function createTextCompletionRequest({prompt, models}) {
const url = "/api/inference/text/stream";
const payload = {
Expand All @@ -221,30 +227,30 @@ const APIContextWrapper = ({children}) => {
};
return createCompletionRequest(url, payload, textCompletionSubscribers);
}

function createChatCompletionRequest(prompt, model) {
const url = "/api/inference/chat/stream";
const payload = {prompt, model};
return createCompletionRequest(url, payload, chatCompletionSubscribers);
}

function createCompletionRequest(url, payload, subscribers) {
pendingCompletionRequest.current = true;
let sse_request = null;

function beforeUnloadHandler() {
if (sse_request) sse_request.close();
}

window.addEventListener("beforeunload", beforeUnloadHandler);
const completionsBuffer = createCompletionsBuffer(payload.models);
let error_occured = false;
let request_complete = false;

sse_request = new SSE(url, {payload: JSON.stringify(payload)});

bindSSEEvents(sse_request, completionsBuffer, {error_occured, request_complete}, beforeUnloadHandler, subscribers);

return () => {
if (sse_request) sse_request.close();
};
Expand All @@ -257,52 +263,52 @@ const APIContextWrapper = ({children}) => {
});
return buffer;
}

function bindSSEEvents(sse_request, completionsBuffer, requestState, beforeUnloadHandler, subscribers) {
sse_request.onopen = async () => {
bulkWrite(completionsBuffer, requestState, subscribers);
};

sse_request.addEventListener("infer", (event) => {
let resp = JSON.parse(event.data);
completionsBuffer[resp.modelTag].push(resp);
});

sse_request.addEventListener("status", (event) => {
subscribers.current.forEach((callback) => callback({
event: "status",
data: JSON.parse(event.data)
}));
});

sse_request.addEventListener("error", (event) => {
requestState.error_occured = true;
try {
const message = JSON.parse(event.data);

subscribers.current.forEach((callback) => callback({
"event": "error",
"data": message.status
"data": message.status
}));
} catch (e) {
subscribers.current.forEach((callback) => callback({
"event": "error",
"data": "Unknown error"
}));
}

close_sse(sse_request, requestState, beforeUnloadHandler, subscribers);
});

sse_request.addEventListener("abort", () => {
requestState.error_occured = true;
close_sse(sse_request, requestState, beforeUnloadHandler, subscribers);
});

sse_request.addEventListener("readystatechange", (event) => {
if (event.readyState === 2) close_sse(sse_request, requestState, beforeUnloadHandler, subscribers);
});

sse_request.stream();
}

Expand All @@ -313,27 +319,27 @@ const APIContextWrapper = ({children}) => {
"meta": {error: requestState.error_occured},
}));
window.removeEventListener("beforeunload", beforeUnloadHandler);
}
}

function bulkWrite(completionsBuffer, requestState, subscribers) {
setTimeout(() => {
let newTokens = false;
let batchUpdate = {};

for (let modelTag in completionsBuffer) {
if (completionsBuffer[modelTag].length > 0) {
newTokens = true;
batchUpdate[modelTag] = completionsBuffer[modelTag].splice(0, completionsBuffer[modelTag].length);
}
}

if (newTokens) {
subscribers.current.forEach((callback) => callback({
event: "completion",
data: batchUpdate,
}));
}

if (!requestState.request_complete) bulkWrite(completionsBuffer, requestState, subscribers);
}, 20);
}
Expand All @@ -347,20 +353,19 @@ const APIContextWrapper = ({children}) => {

const PlaygroundContextWrapper = ({page, children}) => {
const apiContext = React.useContext(APIContext)

const [editorContext, _setEditorContext] = React.useState(DEFAULT_CONTEXTS.PAGES[page].editor);
const [parametersContext, _setParametersContext] = React.useState(DEFAULT_CONTEXTS.PAGES[page].parameters);
let [modelsStateContext, _setModelsStateContext] = React.useState(DEFAULT_CONTEXTS.PAGES[page].modelsState);
const [modelsContext, _setModelsContext] = React.useState(DEFAULT_CONTEXTS.MODELS);
const [historyContext, _setHistoryContext] = React.useState(DEFAULT_CONTEXTS.PAGES[page].history);
const [historyContext, setHistoryContext] = React.useState(DEFAULT_CONTEXTS.PAGES[page].history);

/* Temporary fix for models that have been purged remotely but are still cached locally */
for(const {name} of modelsStateContext) {
if (!modelsContext[name]) {
modelsStateContext = modelsStateContext.filter(({name: _name}) => _name !== name)
}
}

const editorContextRef = React.useRef(editorContext);
const historyContextRef = React.useRef(historyContext);

Expand Down Expand Up @@ -399,7 +404,7 @@ const PlaygroundContextWrapper = ({page, children}) => {
break;
}
}

apiContext.Notifications.subscribe(notificationCallback)

return () => {
Expand All @@ -410,9 +415,9 @@ const PlaygroundContextWrapper = ({page, children}) => {
const updateModelsData = async () => {
const json_params = await apiContext.Model.getAllEnabled()
const models = {};

const PAGE_MODELS_STATE = SETTINGS.pages[page].modelsState;

for (const [model_key, modelDetails] of Object.entries(json_params)) {
const existingModelEntry = (PAGE_MODELS_STATE.find((model) => model.name === model_key));

Expand Down Expand Up @@ -448,7 +453,7 @@ const PlaygroundContextWrapper = ({page, children}) => {
provider: modelDetails.provider,
}
}

const SERVER_SIDE_MODELS = Object.keys(json_params);
for (const {name} of PAGE_MODELS_STATE) {
if (!SERVER_SIDE_MODELS.includes(name)) {
Expand All @@ -465,7 +470,7 @@ const PlaygroundContextWrapper = ({page, children}) => {
const setEditorContext = (newEditorContext, immediate=false) => {
SETTINGS.pages[page].editor = {...SETTINGS.pages[page].editor, ...newEditorContext};

const _editor = {...SETTINGS.pages[page].editor, internalState: null };
const _editor = {...SETTINGS.pages[page].editor, internalState: null};

_setEditorContext(_editor);
if (immediate) {
Expand All @@ -485,14 +490,14 @@ const PlaygroundContextWrapper = ({page, children}) => {

const setModelsContext = (newModels) => {
SETTINGS.models = newModels;

debouncedSettingsSave()
_setModelsContext(newModels);
}

const setModelsStateContext = (newModelsState) => {
SETTINGS.pages[page].modelsState = newModelsState;

debouncedSettingsSave()
_setModelsStateContext(newModelsState);
}
Expand All @@ -503,7 +508,7 @@ const PlaygroundContextWrapper = ({page, children}) => {
show: (value === undefined || value === null) ? !SETTINGS.pages[page].history.show : value
}

_setHistoryContext(_newHistory);
setHistoryContext(_newHistory);

SETTINGS.pages[page].history = _newHistory;
debouncedSettingsSave()
Expand All @@ -518,7 +523,7 @@ const PlaygroundContextWrapper = ({page, children}) => {
const year = currentDate.getFullYear();
const month = String(currentDate.getMonth() + 1).padStart(2, '0');
const day = String(currentDate.getDate()).padStart(2, '0');

const newEntry = {
timestamp: currentDate.getTime(),
date: `${year}-${month}-${day}`,
Expand All @@ -536,7 +541,7 @@ const PlaygroundContextWrapper = ({page, children}) => {
current: newEntry
}

_setHistoryContext(_newHistory);
setHistoryContext(_newHistory);

//console.warn("Adding to history", _newHistory)
SETTINGS.pages[page].history = _newHistory;
Expand All @@ -549,7 +554,7 @@ const PlaygroundContextWrapper = ({page, children}) => {
entries: SETTINGS.pages[page].history.entries.filter((historyEntry) => historyEntry !== entry)
}

_setHistoryContext(_newHistory);
setHistoryContext(_newHistory);

SETTINGS.pages[page].history = _newHistory;
debouncedSettingsSave()
Expand All @@ -562,7 +567,7 @@ const PlaygroundContextWrapper = ({page, children}) => {
current: null
}

_setHistoryContext(_newHistory);
setHistoryContext(_newHistory);

SETTINGS.pages[page].history = _newHistory;
debouncedSettingsSave()
Expand All @@ -572,7 +577,7 @@ const PlaygroundContextWrapper = ({page, children}) => {
SETTINGS.pages[page].history.current = entry;
_setEditorContext(entry.editor);

_setHistoryContext(SETTINGS.pages[page].history);
setHistoryContext(SETTINGS.pages[page].history);
setParametersContext(entry.parameters);
setModelsStateContext(entry.modelsState);
}
Expand All @@ -583,8 +588,8 @@ const PlaygroundContextWrapper = ({page, children}) => {

return (
<HistoryContext.Provider value = {{
historyContext, selectHistoryItem,
addHistoryEntry, removeHistoryEntry, clearHistory, toggleShowHistory
historyContext, setHistoryContext, selectHistoryItem,
addHistoryEntry, removeHistoryEntry, clearHistory, toggleShowHistory,
}}>
<EditorContext.Provider value = {{editorContext, setEditorContext}}>
<ParametersContext.Provider value = {{parametersContext, setParametersContext}}>
Expand Down Expand Up @@ -630,10 +635,9 @@ function ProviderWithRoutes() {
path="/chat"
element={
<APIContextWrapper>
{/* <PlaygroundContextWrapper key = "chat" page = "chat"> */}
<PlaygroundContextWrapper key = "chat" page = "chat">
<Chat/>
{/* <Toaster /> */}
{/* </PlaygroundContextWrapper> */}
</PlaygroundContextWrapper>
</APIContextWrapper>
}
/>
Expand All @@ -657,4 +661,4 @@ export default function App() {
<ProviderWithRoutes />
</BrowserRouter>
)
}
}
16 changes: 14 additions & 2 deletions app/src/components/chat/avatar.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,19 @@
import { FC } from "react"

const Avatar: FC = () => {
return <div className="bg-red-500 rounded h-8 w-8"></div>
interface Props {
name: string
}

const Avatar: FC<Props> = ({name}) => {
const colors = ["pink", "purple", "red", "yellow", "blue", "gray", "green", "indigo"];
const hashCode = (s:string) => s.split('').reduce((a,b) => (((a << 5) - a) + b.charCodeAt(0))|0, 0);
const color = colors[hashCode(name) % colors.length];
return (
<div
className={`bg-${color}-500 rounded h-8 w-8 flex-shrink-0`}
title={name}
/>
)
}

export default Avatar
Loading

0 comments on commit 2062153

Please sign in to comment.