diff --git a/src/extension/background/AI/AIModelManager.ts b/src/extension/background/AI/AIModelManager.ts index 82f9ee0..588572f 100644 --- a/src/extension/background/AI/AIModelManager.ts +++ b/src/extension/background/AI/AIModelManager.ts @@ -9,7 +9,7 @@ import { } from '#Shared/Prefs' import { nanoid } from 'nanoid' import { EventEmitter } from 'events' -import AIModelId, { AIModelIdProvider } from '#Shared/AIModelId' +import AIModelId from '#Shared/AIModelId' export enum TaskType { Install = 'install', @@ -231,11 +231,6 @@ class AIModelManagerImpl extends EventEmitter { return this.#queueTask(TaskType.Update, async (): Promise => { try { console.log(`Updating model ${modelId}`) - // Only AiBrow models can be updated - if (modelId.provider !== AIModelIdProvider.AiBrow) { - console.log(`Model ${modelId} cannot is not updatable`) - return false - } if (!await AIModelFileSystem.hasModelInstalled(modelId)) { throw new Error(`Model ${modelId} not installed`) diff --git a/src/native/main/APIHandler/ModelDownloadAPIHandler.ts b/src/native/main/APIHandler/ModelDownloadAPIHandler.ts index 951f177..52b87ea 100644 --- a/src/native/main/APIHandler/ModelDownloadAPIHandler.ts +++ b/src/native/main/APIHandler/ModelDownloadAPIHandler.ts @@ -19,6 +19,7 @@ import config from '#Shared/Config' import { importLlama } from '#R/Llama' import { getNonEmptyString } from '#Shared/API/Untrusted/UntrustedParser' import AIModelId, { AIModelIdProvider } from '#Shared/AIModelId' +import AIModelAssetId from '#Shared/AIModelAssetId' class ModelDownloadAPIHandler { /* **************************************************************************/ @@ -37,9 +38,26 @@ class ModelDownloadAPIHandler { #handleDownloadAsset = async (channel: IPCInflightChannel) => { const asset = channel.payload.asset as AIModelAsset - - if (await fs.exists(AIModelFileSystem.getAssetPath(asset.id))) { - return + const assetPath = AIModelFileSystem.getAssetPath(asset.id) + + // Normally we don't update assets that are on disk, rather we change the endpoint + // and name of them with a new version. There are some special cases though + if (await fs.exists(assetPath)) { + if ( + AIModelAssetId.isProvider(asset.id, AIModelIdProvider.HuggingFace) && + asset.id.endsWith('.gguf') + ) { + // Huggingface guff assets can be updated but need to be handled differently + const { readGgufFileInfo } = await importLlama() + const remoteModelInfo = await readGgufFileInfo(asset.url) + const localModelInfo = await readGgufFileInfo(assetPath) + if (remoteModelInfo.version === localModelInfo.version) { + return + } + } else { + // Asset doesn't update, so if it's on disk it's up to date + return + } } let loadedSize = 0 @@ -62,7 +80,6 @@ class ModelDownloadAPIHandler { } }) const modelPath = await downloader.download() - const assetPath = AIModelFileSystem.getAssetPath(asset.id) await fs.ensureDir(path.dirname(assetPath)) await fs.move(modelPath, assetPath, { overwrite: true }) }) @@ -78,7 +95,6 @@ class ModelDownloadAPIHandler { reportProgress() }) await finished(reader.pipe(writer)) - const assetPath = AIModelFileSystem.getAssetPath(asset.id) await fs.ensureDir(path.dirname(assetPath)) await fs.move(file.path, assetPath, { overwrite: true }) }) @@ -99,8 +115,10 @@ class ModelDownloadAPIHandler { const repo = getNonEmptyString(channel.payload.repo, undefined) const model = getNonEmptyString(channel.payload.model, undefined) if (!owner || !repo || !model) { throw new Error('Huggingface params invalid') } - if (!model.endsWith('.gguf')) { throw new Error('Huggingface model must be in .gguf format') } - const ggufUrl = `https://huggingface.co/${owner}/${repo}/resolve/main/${model}` + + const modelId = new AIModelId({ provider: AIModelIdProvider.HuggingFace, owner, repo, model }) + const ggufAssetId = AIModelAssetId.generate(modelId.provider, modelId.owner, modelId.repo, modelId.model) + const ggufUrl = AIModelAssetId.getRemoteUrl(ggufAssetId) // Fetch the gguf metadata const { @@ -133,8 +151,6 @@ class ModelDownloadAPIHandler { const modelTokenizer = modelInfo.metadata.tokenizer.ggml // Generate the manifest - const modelId = new AIModelId({ provider: AIModelIdProvider.HuggingFace, owner, repo, model }) - const modelAssetId = [modelId.provider, modelId.owner, modelId.repo, modelId.model].join('/') const manifest: AIModelManifest = { id: modelId.toString(), name: `HuggingFace: ${path.basename(modelId.model, path.extname(modelId.model))}`, @@ -148,9 +164,9 @@ class ModelDownloadAPIHandler { : modelMetadata.general.license ? `https://www.google.com/search?q=${encodeURIComponent(modelMetadata.general.license)}` : '', - model: modelAssetId, + model: ggufAssetId, assets: [{ - id: modelAssetId, + id: ggufAssetId, url: ggufUrl, size: fileSize }], diff --git a/src/shared/AIModelAssetId.ts b/src/shared/AIModelAssetId.ts new file mode 100644 index 0000000..72dde6c --- /dev/null +++ b/src/shared/AIModelAssetId.ts @@ -0,0 +1,54 @@ +import { AIModelIdProvider } from './AIModelId' + +const kAssetIdDivider = '/' + +class AIModelAssetId { + /** + * Generates an asset id for a given provider + * @param provider: the model provider + * @param owner: the model owner + * @param repo: the model repo + * @param model: the model file + * @returns the local id of the model + */ + static generate (provider: AIModelIdProvider, owner: string, repo: string, model: string) { + switch (provider) { + case AIModelIdProvider.HuggingFace: return [provider, owner, repo, model].join(kAssetIdDivider) + default: throw new Error(`Unable to generate asset id for provider ${provider}`) + } + } + + /** + * Checks if the given asset id is for the given provider + * @param assetId: the id of the asset + * @param provider: the provider to check for + * @return true if the asset is for the given provider + */ + static isProvider (assetId: string, provider: AIModelIdProvider) { + switch (provider) { + case AIModelIdProvider.HuggingFace: return assetId.startsWith(`${provider}${kAssetIdDivider}`) + default: throw new Error(`Unable to check asset id for provider ${provider}`) + } + } + + /** + * Generates a remote url from a given asset id + * @param assetId: the id of the asset + * @returns the remote url of the asset + */ + static getRemoteUrl (assetId: string) { + const parts = assetId.split(kAssetIdDivider) + switch (parts[0]) { + case AIModelIdProvider.HuggingFace: { + const owner = parts[1] + const repo = parts[2] + const model = parts[3] + if (!model.endsWith('.gguf')) { throw new Error('Huggingface model must be in .gguf format') } + return `https://huggingface.co/${owner}/${repo}/resolve/main/${model}` + } + default: throw new Error(`Unable to get remote url for asset ${assetId}`) + } + } +} + +export default AIModelAssetId