Skip to content

Commit

Permalink
Add support for updating huggingface gguf files
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas101 committed Dec 4, 2024
1 parent 0416fae commit 749d1df
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 17 deletions.
7 changes: 1 addition & 6 deletions src/extension/background/AI/AIModelManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import {
} from '#Shared/Prefs'
import { nanoid } from 'nanoid'
import { EventEmitter } from 'events'
import AIModelId, { AIModelIdProvider } from '#Shared/AIModelId'
import AIModelId from '#Shared/AIModelId'

export enum TaskType {
Install = 'install',
Expand Down Expand Up @@ -231,11 +231,6 @@ class AIModelManagerImpl extends EventEmitter {
return this.#queueTask(TaskType.Update, async (): Promise<boolean> => {
try {
console.log(`Updating model ${modelId}`)
// Only AiBrow models can be updated
if (modelId.provider !== AIModelIdProvider.AiBrow) {
console.log(`Model ${modelId} cannot is not updatable`)
return false
}

if (!await AIModelFileSystem.hasModelInstalled(modelId)) {
throw new Error(`Model ${modelId} not installed`)
Expand Down
38 changes: 27 additions & 11 deletions src/native/main/APIHandler/ModelDownloadAPIHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import config from '#Shared/Config'
import { importLlama } from '#R/Llama'
import { getNonEmptyString } from '#Shared/API/Untrusted/UntrustedParser'
import AIModelId, { AIModelIdProvider } from '#Shared/AIModelId'
import AIModelAssetId from '#Shared/AIModelAssetId'

class ModelDownloadAPIHandler {
/* **************************************************************************/
Expand All @@ -37,9 +38,26 @@ class ModelDownloadAPIHandler {

#handleDownloadAsset = async (channel: IPCInflightChannel) => {
const asset = channel.payload.asset as AIModelAsset

if (await fs.exists(AIModelFileSystem.getAssetPath(asset.id))) {
return
const assetPath = AIModelFileSystem.getAssetPath(asset.id)

// Normally an asset that is already on disk is never updated in place; instead a new
// version is published under a different endpoint and name. There are some special cases though
if (await fs.exists(assetPath)) {
if (
AIModelAssetId.isProvider(asset.id, AIModelIdProvider.HuggingFace) &&
asset.id.endsWith('.gguf')
) {
// Huggingface gguf assets can be updated but need to be handled differently
const { readGgufFileInfo } = await importLlama()
const remoteModelInfo = await readGgufFileInfo(asset.url)
const localModelInfo = await readGgufFileInfo(assetPath)
if (remoteModelInfo.version === localModelInfo.version) {
return
}
} else {
// Asset doesn't update, so if it's on disk it's up to date
return
}
}

let loadedSize = 0
Expand All @@ -62,7 +80,6 @@ class ModelDownloadAPIHandler {
}
})
const modelPath = await downloader.download()
const assetPath = AIModelFileSystem.getAssetPath(asset.id)
await fs.ensureDir(path.dirname(assetPath))
await fs.move(modelPath, assetPath, { overwrite: true })
})
Expand All @@ -78,7 +95,6 @@ class ModelDownloadAPIHandler {
reportProgress()
})
await finished(reader.pipe(writer))
const assetPath = AIModelFileSystem.getAssetPath(asset.id)
await fs.ensureDir(path.dirname(assetPath))
await fs.move(file.path, assetPath, { overwrite: true })
})
Expand All @@ -99,8 +115,10 @@ class ModelDownloadAPIHandler {
const repo = getNonEmptyString(channel.payload.repo, undefined)
const model = getNonEmptyString(channel.payload.model, undefined)
if (!owner || !repo || !model) { throw new Error('Huggingface params invalid') }
if (!model.endsWith('.gguf')) { throw new Error('Huggingface model must be in .gguf format') }
const ggufUrl = `https://huggingface.co/${owner}/${repo}/resolve/main/${model}`

const modelId = new AIModelId({ provider: AIModelIdProvider.HuggingFace, owner, repo, model })
const ggufAssetId = AIModelAssetId.generate(modelId.provider, modelId.owner, modelId.repo, modelId.model)
const ggufUrl = AIModelAssetId.getRemoteUrl(ggufAssetId)

// Fetch the gguf metadata
const {
Expand Down Expand Up @@ -133,8 +151,6 @@ class ModelDownloadAPIHandler {
const modelTokenizer = modelInfo.metadata.tokenizer.ggml

// Generate the manifest
const modelId = new AIModelId({ provider: AIModelIdProvider.HuggingFace, owner, repo, model })
const modelAssetId = [modelId.provider, modelId.owner, modelId.repo, modelId.model].join('/')
const manifest: AIModelManifest = {
id: modelId.toString(),
name: `HuggingFace: ${path.basename(modelId.model, path.extname(modelId.model))}`,
Expand All @@ -148,9 +164,9 @@ class ModelDownloadAPIHandler {
: modelMetadata.general.license
? `https://www.google.com/search?q=${encodeURIComponent(modelMetadata.general.license)}`
: '',
model: modelAssetId,
model: ggufAssetId,
assets: [{
id: modelAssetId,
id: ggufAssetId,
url: ggufUrl,
size: fileSize
}],
Expand Down
54 changes: 54 additions & 0 deletions src/shared/AIModelAssetId.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import { AIModelIdProvider } from './AIModelId'

const kAssetIdDivider = '/'

/**
 * Static helpers for building and interpreting asset ids. An asset id is the
 * provider, owner, repo and model file joined with the asset id divider.
 * Only the HuggingFace provider is currently supported; every method throws
 * for any other provider.
 */
class AIModelAssetId {
  /**
   * Generates an asset id for a given provider
   * @param provider: the model provider
   * @param owner: the model owner
   * @param repo: the model repo
   * @param model: the model file
   * @returns the local id of the model
   * @throws if asset ids cannot be generated for the provider
   */
  static generate (provider: AIModelIdProvider, owner: string, repo: string, model: string): string {
    switch (provider) {
      case AIModelIdProvider.HuggingFace: return [provider, owner, repo, model].join(kAssetIdDivider)
      default: throw new Error(`Unable to generate asset id for provider ${provider}`)
    }
  }

  /**
   * Checks if the given asset id is for the given provider
   * @param assetId: the id of the asset
   * @param provider: the provider to check for
   * @return true if the asset is for the given provider
   * @throws if the check is not supported for the provider
   */
  static isProvider (assetId: string, provider: AIModelIdProvider): boolean {
    switch (provider) {
      case AIModelIdProvider.HuggingFace: return assetId.startsWith(`${provider}${kAssetIdDivider}`)
      default: throw new Error(`Unable to check asset id for provider ${provider}`)
    }
  }

  /**
   * Generates a remote url from a given asset id
   * @param assetId: the id of the asset
   * @returns the remote url of the asset
   * @throws if the provider is not supported, the owner or repo segment is
   * missing, or the model is not a .gguf file
   */
  static getRemoteUrl (assetId: string): string {
    // generate() joins the model file without escaping the divider, so everything
    // after the repo segment is treated as the model path. This keeps ids whose
    // model path itself contains the divider round-trippable.
    const [provider, owner, repo, ...modelParts] = assetId.split(kAssetIdDivider)
    switch (provider) {
      case AIModelIdProvider.HuggingFace: {
        if (!owner || !repo) {
          throw new Error(`Unable to get remote url for asset ${assetId}`)
        }
        // A missing model segment yields an empty string here, so it is rejected
        // by the extension check rather than crashing on an undefined value
        const model = modelParts.join(kAssetIdDivider)
        if (!model.endsWith('.gguf')) { throw new Error('Huggingface model must be in .gguf format') }
        return `https://huggingface.co/${owner}/${repo}/resolve/main/${model}`
      }
      default: throw new Error(`Unable to get remote url for asset ${assetId}`)
    }
  }
}

export default AIModelAssetId

0 comments on commit 749d1df

Please sign in to comment.