Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add function to switch between downloaded models #70

Merged
merged 4 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions preloader/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ const {
abortCompletion,
setClient,
downloadModel,
updateModelSettings
updateModelSettings,
formator
} = require("./node-llama-cpp-preloader.js")

const {
Expand All @@ -14,7 +15,7 @@ const {

contextBridge.exposeInMainWorld('node-llama-cpp', {
loadModel, chatCompletions, updateModelSettings,
abortCompletion, setClient, downloadModel
abortCompletion, setClient, downloadModel, formator
})

contextBridge.exposeInMainWorld('file-handler', {
Expand Down
63 changes: 43 additions & 20 deletions preloader/node-llama-cpp-preloader.js
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,18 @@ async function loadModel(model_name = '') {
})
}



/**
* @typedef Message
* @property {"user"|"assistant"|"system"} role Sender
* @property {String} content Message content
*/

/**
* Set the session, basically reset the history and return a static string 'fake-client'
* @param {String} client a fake client, everything using the same client and switch between clients just simply reset the chat history
* @param {Message[]} history the history to load
* @returns {String}
*/
async function setClient(client, history = []) {
Expand All @@ -66,20 +75,20 @@ async function setClient(client, history = []) {
* @param {any} message The latest user message, either string or any can be converted to string. If is array in the format of {content:String}, it will retrieve the last content
* @returns {String} the retrieved message
*/
function findMessageText(message) {
if(typeof message === "string") return message;
else if(typeof message === "object") {
if(Array.isArray(message)) {
while(message.length) {
message = message.pop();
if(typeof message === 'object' && message.role && message.role === 'user' && message.content) {
return message.content;
}
}
}
}
return `${message}`
}
// function findMessageText(message) {
// if(typeof message === "string") return message;
// else if(typeof message === "object") {
// if(Array.isArray(message)) {
// while(message.length) {
// message = message.pop();
// if(typeof message === 'object' && message.role && message.role === 'user' && message.content) {
// return message.content;
// }
// }
// }
// }
// return `${message}`
// }

let model_settings = {};
function updateModelSettings(settings) {
Expand All @@ -99,15 +108,12 @@ function updateModelSettings(settings) {
* @returns {Promise<String>} the response text
*/
async function chatCompletions(latest_message, cb=null) {
const {max_tokens, top_p, temperature, llama_reset_everytime} = model_settings;
const {max_tokens, top_p, temperature} = model_settings;
if(!llama_session) {
cb && cb('> **ERROR: MODEL NOT LOADED**', true);
return '';
}
latest_message = findMessageText(latest_message);
if(llama_reset_everytime) {
setClient(null, llama_session.getChatHistory().filter(({type})=>type === 'system'))
}
// latest_message = findMessageText(latest_message);


stop_signal = new AbortController();
Expand Down Expand Up @@ -210,11 +216,28 @@ function downloadModel(url, cb=null) {
})
}

/**
* format messages, reset history if needed
* @param {Message[]} messages
*/
/**
 * Format the conversation for the llama session: extract the latest user
 * message to use as the prompt and, when the `llama_reset_everytime`
 * setting is enabled, reset the chat history keeping only system messages.
 * @param {Message[]} messages full conversation history
 * @returns {String} content of the most recent user message, or '' when
 * the conversation contains no user message yet
 */
function formator(messages) {
    const user_messages = messages.filter(message=>message.role === 'user');
    const system_messages = messages.filter(message=>message.role === 'system');

    const {llama_reset_everytime} = model_settings;
    if(llama_reset_everytime) {
        // fire-and-forget: setClient is async, but callers of formator only
        // need the returned prompt string, so the history reset is not awaited
        setClient(null, system_messages)
    }

    // guard: messages may contain no user entry; avoid reading
    // `.content` of undefined (previously threw a TypeError)
    const last_user_message = user_messages[user_messages.length - 1];
    return last_user_message ? last_user_message.content : '';
}

module.exports = {
loadModel,
chatCompletions,
abortCompletion,
setClient,
downloadModel,
updateModelSettings
updateModelSettings,
formator
}
41 changes: 33 additions & 8 deletions src/components/settings/LlamaSettings.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ import TextComponent from "./components/TextComponent";
import useIDB from "../../utils/idb";
import { getPlatformSettings, updatePlatformSettings } from "../../utils/general_settings";
import { DEFAULT_LLAMA_CPP_MODEL_URL } from "../../utils/types";
import DropdownComponent from "./components/DropdownComponent";

export default function LlamaSettings({ trigger, enabled, updateEnabled, openDownloadProtector, updateState }) {

const [model_download_link, setModelDownloadLink] = useState('');
const [reset_everytime, setResetEveryTime] = useState(false);
const [downloaded_models, setDownloadedModels] = useState([])
const idb = useIDB();

async function saveSettings() {
Expand Down Expand Up @@ -39,14 +41,23 @@ export default function LlamaSettings({ trigger, enabled, updateEnabled, openDow
platform: 'Llama'
}
await idb.insert("downloaded-models", stored_model)
setDownloadedModels([...downloaded_models, { title: stored_model['model-name'], value: stored_model.url }])
}
}
)
}
// load model using the model name retrieved
await window['node-llama-cpp'].loadModel(stored_model['model-name'])
await openDownloadProtector(
'Loading model...',
`Loading model ${stored_model['model-name']}`,
async callback => {
callback(100, false);
// load model using the model name retrieved
await window['node-llama-cpp'].loadModel(stored_model['model-name'])
updateState();
callback(100, true)
}
)
}
updateState();
}

useEffect(()=>{
Expand All @@ -55,9 +66,18 @@ export default function LlamaSettings({ trigger, enabled, updateEnabled, openDow
}, [trigger])

useEffect(()=>{
const { llama_model_url, llama_reset_everytime } = getPlatformSettings();
setModelDownloadLink(llama_model_url || DEFAULT_LLAMA_CPP_MODEL_URL);
setResetEveryTime(llama_reset_everytime);
(async function() {
const { llama_model_url, llama_reset_everytime } = getPlatformSettings();
setModelDownloadLink(llama_model_url || DEFAULT_LLAMA_CPP_MODEL_URL);
setResetEveryTime(llama_reset_everytime);

const models = await idb.getAll("downloaded-models", {
where: [{'platform': 'Llama'}],
select: ["model-name", "url"]
});
setDownloadedModels(models.map(e=>{return { title: e['model-name'], value: e.url }}))
})()
// eslint-disable-next-line
}, [])

return (
Expand All @@ -67,9 +87,14 @@ export default function LlamaSettings({ trigger, enabled, updateEnabled, openDow
description={"Llama.cpp Engine is powerful and it allows you to run your own .gguf model using GPU on your local machine."}
value={enabled} cb={updateEnabled}
/>
<DropdownComponent
title={"Select a downloaded model to use"}
description={"If there's no models, you might want to download a new one by enter url below."}
value={downloaded_models} cb={setModelDownloadLink}
/>
<TextComponent
title={"Set the link of model you want to use"}
description={"Only models with extension .gguf can be used. Please make sure it can be run on your own machine."}
title={"Or enter the link of model you want to use"}
description={"Only models with extension .gguf can be used. Please make sure it can be run on your own machine. If the model of entered url is not downloaded, it will download automatically when you save settings."}
placeholder={"Default model is Microsoft Phi-3"}
value={model_download_link} cb={setModelDownloadLink}
/>
Expand Down
25 changes: 25 additions & 0 deletions src/components/settings/components/DropdownComponent.jsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/**
 * Labelled dropdown for the settings pages.
 * @param {Object} props
 * @param {Function} [props.cb] invoked with the selected option's value on change
 * @param {(String|{title: String, value: String})[]} props.value options to render;
 * plain strings are used as both label and value
 * @param {String} props.title heading text shown above the select
 * @param {String} [props.description] optional helper text under the title
 */
export default function DropdownComponent({ cb, value, title, description }) {
    // Normalize every entry to an <option>; object entries carry separate
    // label/value, string entries serve as both.
    const options = value.map((entry, index) => {
        const is_object = typeof entry === "object";
        const option_title = is_object ? entry.title : entry;
        const option_value = is_object ? entry.value : entry;
        return <option key={`option-${index}`} value={option_value}>{ option_title }</option>;
    });

    return (
        <div className="component">
            <div className="title">{title}</div>
            { description && <div className="description">{description}</div> }
            <select
                className="main-part"
                onClick={evt=>evt.preventDefault()}
                onChange={evt=>cb && cb(evt.target.value)}
            >
                { options }
            </select>
        </div>
    )
}
6 changes: 6 additions & 0 deletions src/styles/settings.css
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,12 @@
margin-left: var(--elem-margin);
}

/* Style the settings dropdown (DropdownComponent's <select>) to match the
   rounded, white look of the other inputs on the settings page. */
.setting-page > .setting-section > .component > select.main-part {
border-radius: 7px;
background-color: white;
padding: 0px 7px;
}

dialog:has(.download-protector) {
border: none;
border-radius: 15px;
Expand Down
3 changes: 2 additions & 1 deletion src/utils/workers/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ export function getCompletionFunctions(platform = null) {
completions: window['node-llama-cpp'].chatCompletions,
abort: window['node-llama-cpp'].abortCompletion,
platform: 'Llama',
initClient: window['node-llama-cpp'].setClient
initClient: window['node-llama-cpp'].setClient,
formator: window['node-llama-cpp'].formator
}
case "Wllama":
return {
Expand Down