Skip to content

Commit

Permalink
add function to switch between downloaded models (#70)
Browse files Browse the repository at this point in the history
* add dropdown element for select model

Signed-off-by: cbh778899 <[email protected]>

* implement model switch & hide reset everytime

Signed-off-by: cbh778899 <[email protected]>

* update style

Signed-off-by: cbh778899 <[email protected]>

* fix llama_reset_everytime issue so we can normally reset now

Signed-off-by: cbh778899 <[email protected]>

---------

Signed-off-by: cbh778899 <[email protected]>
  • Loading branch information
cbh778899 authored Oct 21, 2024
1 parent cee3213 commit c277541
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 31 deletions.
5 changes: 3 additions & 2 deletions preloader/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ const {
abortCompletion,
setClient,
downloadModel,
updateModelSettings
updateModelSettings,
formator
} = require("./node-llama-cpp-preloader.js")

const {
Expand All @@ -14,7 +15,7 @@ const {

contextBridge.exposeInMainWorld('node-llama-cpp', {
loadModel, chatCompletions, updateModelSettings,
abortCompletion, setClient, downloadModel
abortCompletion, setClient, downloadModel, formator
})

contextBridge.exposeInMainWorld('file-handler', {
Expand Down
63 changes: 43 additions & 20 deletions preloader/node-llama-cpp-preloader.js
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,18 @@ async function loadModel(model_name = '') {
})
}



/**
* @typedef Message
* @property {"user"|"assistant"|"system"} role Sender
* @property {String} content Message content
*/

/**
* Set the session, basically reset the history and return a static string 'fake-client'
* @param {String} client a fake client; every caller uses the same client, and switching between clients simply resets the chat history
* @param {Message[]} history the history to load
* @returns {String}
*/
async function setClient(client, history = []) {
Expand All @@ -66,20 +75,20 @@ async function setClient(client, history = []) {
* @param {any} message The latest user message, either a string or anything that can be converted to a string. If it is an array of {content:String} items, the last user content is retrieved
* @returns {String} the retrieved message
*/
// Extract the text of the latest user message from a string or a message array.
function findMessageText(message) {
    if(typeof message === "string") return message;
    else if(typeof message === "object") {
        if(Array.isArray(message)) {
            // Pop from the end looking for the most recent user message.
            // NOTE(review): the loop condition re-checks `message.length` AFTER
            // reassigning `message` to the popped element; plain {role, content}
            // objects have no `.length`, so for such arrays only the LAST element
            // is ever inspected before falling through to the stringify below —
            // confirm this short-circuit is intended.
            while(message.length) {
                message = message.pop();
                if(typeof message === 'object' && message.role && message.role === 'user' && message.content) {
                    return message.content;
                }
            }
        }
    }
    // Fallback: coerce whatever `message` currently holds to a string
    // (for an exhausted/no-match array this is the last popped element).
    return `${message}`
}
// function findMessageText(message) {
// if(typeof message === "string") return message;
// else if(typeof message === "object") {
// if(Array.isArray(message)) {
// while(message.length) {
// message = message.pop();
// if(typeof message === 'object' && message.role && message.role === 'user' && message.content) {
// return message.content;
// }
// }
// }
// }
// return `${message}`
// }

let model_settings = {};
function updateModelSettings(settings) {
Expand All @@ -99,15 +108,12 @@ function updateModelSettings(settings) {
* @returns {Promise<String>} the response text
*/
async function chatCompletions(latest_message, cb=null) {
const {max_tokens, top_p, temperature, llama_reset_everytime} = model_settings;
const {max_tokens, top_p, temperature} = model_settings;
if(!llama_session) {
cb && cb('> **ERROR: MODEL NOT LOADED**', true);
return '';
}
latest_message = findMessageText(latest_message);
if(llama_reset_everytime) {
setClient(null, llama_session.getChatHistory().filter(({type})=>type === 'system'))
}
// latest_message = findMessageText(latest_message);}


stop_signal = new AbortController();
Expand Down Expand Up @@ -210,11 +216,28 @@ function downloadModel(url, cb=null) {
})
}

/**
 * Format the conversation for the Llama session: when the
 * `llama_reset_everytime` model setting is on, reset the chat history
 * keeping only system messages, then return the text of the most recent
 * user message.
 * @param {Message[]} messages full conversation history
 * @returns {String} content of the latest user message, or '' if there is none
 */
function formator(messages) {
    const user_messages = messages.filter(e=>e.role === 'user');
    const system_messages = messages.filter(e=>e.role === 'system');

    const {llama_reset_everytime} = model_settings;
    if(llama_reset_everytime) {
        // fire-and-forget: setClient is async, but callers only need the text —
        // the reset completes before the next completion request is issued
        setClient(null, system_messages)
    }

    // Guard against a history with no user entries: the previous
    // `user_messages.pop().content` threw a TypeError on undefined.
    const latest = user_messages.pop();
    return latest ? latest.content : '';
}

module.exports = {
loadModel,
chatCompletions,
abortCompletion,
setClient,
downloadModel,
updateModelSettings
updateModelSettings,
formator
}
41 changes: 33 additions & 8 deletions src/components/settings/LlamaSettings.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ import TextComponent from "./components/TextComponent";
import useIDB from "../../utils/idb";
import { getPlatformSettings, updatePlatformSettings } from "../../utils/general_settings";
import { DEFAULT_LLAMA_CPP_MODEL_URL } from "../../utils/types";
import DropdownComponent from "./components/DropdownComponent";

export default function LlamaSettings({ trigger, enabled, updateEnabled, openDownloadProtector, updateState }) {

const [model_download_link, setModelDownloadLink] = useState('');
const [reset_everytime, setResetEveryTime] = useState(false);
const [downloaded_models, setDownloadedModels] = useState([])
const idb = useIDB();

async function saveSettings() {
Expand Down Expand Up @@ -39,14 +41,23 @@ export default function LlamaSettings({ trigger, enabled, updateEnabled, openDow
platform: 'Llama'
}
await idb.insert("downloaded-models", stored_model)
setDownloadedModels([...downloaded_models, { title: stored_model['model-name'], value: stored_model.url }])
}
}
)
}
// load model using the model name retrieved
await window['node-llama-cpp'].loadModel(stored_model['model-name'])
await openDownloadProtector(
'Loading model...',
`Loading model ${stored_model['model-name']}`,
async callback => {
callback(100, false);
// load model using the model name retrieved
await window['node-llama-cpp'].loadModel(stored_model['model-name'])
updateState();
callback(100, true)
}
)
}
updateState();
}

useEffect(()=>{
Expand All @@ -55,9 +66,18 @@ export default function LlamaSettings({ trigger, enabled, updateEnabled, openDow
}, [trigger])

useEffect(()=>{
const { llama_model_url, llama_reset_everytime } = getPlatformSettings();
setModelDownloadLink(llama_model_url || DEFAULT_LLAMA_CPP_MODEL_URL);
setResetEveryTime(llama_reset_everytime);
(async function() {
const { llama_model_url, llama_reset_everytime } = getPlatformSettings();
setModelDownloadLink(llama_model_url || DEFAULT_LLAMA_CPP_MODEL_URL);
setResetEveryTime(llama_reset_everytime);

const models = await idb.getAll("downloaded-models", {
where: [{'platform': 'Llama'}],
select: ["model-name", "url"]
});
setDownloadedModels(models.map(e=>{return { title: e['model-name'], value: e.url }}))
})()
// eslint-disable-next-line
}, [])

return (
Expand All @@ -67,9 +87,14 @@ export default function LlamaSettings({ trigger, enabled, updateEnabled, openDow
description={"Llama.cpp Engine is powerful and it allows you to run your own .gguf model using GPU on your local machine."}
value={enabled} cb={updateEnabled}
/>
<DropdownComponent
title={"Select a downloaded model to use"}
description={"If there's no models, you might want to download a new one by enter url below."}
value={downloaded_models} cb={setModelDownloadLink}
/>
<TextComponent
title={"Set the link of model you want to use"}
description={"Only models with extension .gguf can be used. Please make sure it can be run on your own machine."}
title={"Or enter the link of model you want to use"}
description={"Only models with extension .gguf can be used. Please make sure it can be run on your own machine. If the model of entered url is not downloaded, it will download automatically when you save settings."}
placeholder={"Default model is Microsoft Phi-3"}
value={model_download_link} cb={setModelDownloadLink}
/>
Expand Down
25 changes: 25 additions & 0 deletions src/components/settings/components/DropdownComponent.jsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/**
 * Generic labelled <select> settings component.
 * @param {Object} props
 * @param {Function} [props.cb] invoked with the selected option's value on change
 * @param {(String|{title:String,value:String})[]} props.value option list;
 *        plain strings serve as both label and value
 * @param {String} props.title heading shown above the select
 * @param {String} [props.description] optional helper text under the heading
 */
export default function DropdownComponent({ cb, value, title, description }) {
    return (
        <div className="component">
            <div className="title">{title}</div>
            { description && <div className="description">{description}</div> }
            <select
                className="main-part"
                onClick={evt=>evt.preventDefault()}
                onChange={evt=>cb && cb(evt.target.value)}
            >
                { value.map((e, i)=>{
                    // use distinct names instead of shadowing the `title`/`value` props
                    const option_title = typeof e === "object" ? e.title : e;
                    const option_value = typeof e === "object" ? e.value : e;
                    return <option key={`option-${i}`} value={option_value}>{ option_title }</option>
                }) }
            </select>
        </div>
    )
}
6 changes: 6 additions & 0 deletions src/styles/settings.css
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,12 @@
margin-left: var(--elem-margin);
}

/* Rounded white styling for the model-selection <select> added to the settings page */
.setting-page > .setting-section > .component > select.main-part {
border-radius: 7px;
background-color: white;
padding: 0px 7px;
}

dialog:has(.download-protector) {
border: none;
border-radius: 15px;
Expand Down
3 changes: 2 additions & 1 deletion src/utils/workers/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ export function getCompletionFunctions(platform = null) {
completions: window['node-llama-cpp'].chatCompletions,
abort: window['node-llama-cpp'].abortCompletion,
platform: 'Llama',
initClient: window['node-llama-cpp'].setClient
initClient: window['node-llama-cpp'].setClient,
formator: window['node-llama-cpp'].formator
}
case "Wllama":
return {
Expand Down

0 comments on commit c277541

Please sign in to comment.