Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(extension): fix hf local model and remote cache paths #64

Merged
merged 6 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ To run extension, follow these steps.

```sh
# vsce currently only supports npm
# see https://github.com/microsoft/vscode-vsce/issues/517
# see https://github.com/microsoft/vscode-vsce/issues/421
$ npm install
```

Expand Down
2 changes: 2 additions & 0 deletions CONTRIBUTING.zh-cn.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@

```sh
# 因 vsce 原因,插件依赖请使用 npm 进行安装
# see https://github.com/microsoft/vscode-vsce/issues/517
# see https://github.com/microsoft/vscode-vsce/issues/421
$ npm install
```

Expand Down
2 changes: 1 addition & 1 deletion package.nls.zh-cn.json
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
"configuration.ollama.model.description": "模型名称",

"configuration.transformers.remoteHost.description": "加载模型的主机 URL,默认为 Hugging Face Hub。",
"configuration.transformers.remoteHost.markdownDescription": "加载模型的主机 URL,默认为 [Hugging Face Hub](https://huggingface.co)",
"configuration.transformers.remoteHost.markdownDescription": "加载模型的主机 URL,默认为 [Hugging Face Hub](https://huggingface.co),国内推荐走 https://hf-mirror.com 代理站点。",
"configuration.transformers.remotePathTemplate.description": "加载模型时填写并合并到 remoteHost 的路径模板。",
"configuration.transformers.remotePathTemplate.markdownDescription": "加载模型时填写并合并到 `#autodev.transformers.remoteHost#` 的路径模板。",
"configuration.transformers.allowLocalModels.markdownDescription": "是否允许加载本地文件,默认为 `true`。如果设置为 `false`,将跳过本地文件检查,尝试从远程主机加载模型。",
Expand Down
108 changes: 76 additions & 32 deletions pre-download-build.js
Original file line number Diff line number Diff line change
Expand Up @@ -28,40 +28,84 @@ const targetToLanceDb = {

const platforms = ["darwin", "linux", "win32"];
const architectures = ["x64", "arm64"];
let targets = platforms.flatMap((platform) =>
architectures.map((arch) => `${platform}-${arch}`),
);

console.log("[info] Building binaries with pkg...");
for (const target of targets) {
const targetDir = `bin/${target}`;
fs.mkdirSync(targetDir, { recursive: true });
console.log(`[info] Building ${target}...`);
// execSync(
// `npx pkg --no-bytecode --public-packages "*" --public pkgJson/${target} --out-path ${targetDir}`,
// );
function download(base) {
let targets = platforms.flatMap((platform) =>
architectures.map((arch) => `${platform}-${arch}`),
);

// Download and unzip prebuilt sqlite3 binary for the target
const downloadUrl = `https://github.com/TryGhost/node-sqlite3/releases/download/v5.1.7/sqlite3-v5.1.7-napi-v6-${
target === "win32-arm64" ? "win32-ia32" : target
}.tar.gz`;
execSync(`curl -L -o ${targetDir}/build.tar.gz ${downloadUrl}`);
execSync(`cd ${targetDir} && tar -xvzf build.tar.gz`);
fs.copyFileSync(
`${targetDir}/build/Release/node_sqlite3.node`,
`${targetDir}/node_sqlite3.node`,
);
fs.unlinkSync(`${targetDir}/build.tar.gz`);
fs.rmSync(`${targetDir}/build`, {
recursive: true,
force: true,
});
console.log("[info] Building binaries with pkg...");
for (const target of targets) {
const targetDir = `bin/${target}`;
fs.mkdirSync(targetDir, { recursive: true });
console.log(`[info] Building ${target}...`);
// execSync(
// `npx pkg --no-bytecode --public-packages "*" --public pkgJson/${target} --out-path ${targetDir}`,
// );

// Download and unzip prebuilt sqlite3 binary for the target
const downloadUrl = `${base}/releases/download/v5.1.7/sqlite3-v5.1.7-napi-v6-${
target === "win32-arm64" ? "win32-ia32" : target
}.tar.gz`;
execSync(`curl -L -o ${targetDir}/build.tar.gz ${downloadUrl}`);
execSync(`cd ${targetDir} && tar -xvzf build.tar.gz`);
fs.copyFileSync(
`${targetDir}/build/Release/node_sqlite3.node`,
`${targetDir}/node_sqlite3.node`,
);
fs.unlinkSync(`${targetDir}/build.tar.gz`);
fs.rmSync(`${targetDir}/build`, {
recursive: true,
force: true,
});
}

console.log("[info] Downloading prebuilt lancedb...");
for (const target of targets) {
if (targetToLanceDb[target]) {
console.log(`[info] Downloading ${target}...`);
execSync(`npm install -f ${targetToLanceDb[target]}@0.4.20 --no-save`);
}
}
}

console.log("[info] Downloading prebuilt lancedb...");
for (const target of targets) {
if (targetToLanceDb[target]) {
console.log(`[info] Downloading ${target}...`);
execSync(`npm install -f ${targetToLanceDb[target]}@0.4.20 --no-save`);
}
getSqliteRepoUrl().then(repoUrl => {
console.log("[info] Downloading sqlite3 from", repoUrl)
download(repoUrl);
})

function getSqliteRepoUrl() {
const github = 'https://github.com/TryGhost/node-sqlite3'
const gitee = 'https://gitee.com/Ypeng/node-sqlite3' // china proxy repo

// Like any
const promise = new Promise((resolve) => {
let counter = 0;

function fallback() {
counter++;

if (counter === 2) {
resolve(github);
}
}

function ping(url) {
return fetch(url).then(
res => {
if (res.ok) {
resolve(url)
} else {
fallback()
}
},
fallback
)
}

ping(github)
ping(gitee)
})

return promise
}
8 changes: 7 additions & 1 deletion src/base/common/configuration/configurationService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {
EventEmitter,
workspace,
type WorkspaceConfiguration,
Uri
} from 'vscode';

import { AUTODEV_CONFIG_PREFIX } from './configuration';
Expand All @@ -27,7 +28,7 @@ export class ConfigurationService {

constructor(
@inject(IExtensionContext)
extensionContext: IExtensionContext,
private readonly extensionContext: IExtensionContext,
) {
this._projectConfig = resolveProjectConfig(extensionContext.extension.packageJSON);
this._config = workspace.getConfiguration(AUTODEV_CONFIG_PREFIX);
Expand Down Expand Up @@ -94,6 +95,11 @@ export class ConfigurationService {
return path.join(base, '.autodev', ...paths);
}

extensionJoinPath(...paths: string[]): Uri {
const base = this.extensionContext.extensionUri;
return Uri.joinPath(base, ...paths);
}

dispose(): void {
Disposable.from(...this._disposables).dispose();
this._disposables.length = 0;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os from 'node:os';
import path from 'node:path';
import { fileURLToPath } from 'node:url';

import { chunkArray } from '@langchain/core/utils/chunk_array';
import {
Expand All @@ -18,8 +20,8 @@ interface HuggingFaceTransformersParams {
remoteHost: string;
remotePathTemplate: string;
allowLocalModels: boolean;
localModelPath: string;
onnxWasmNumThreads: 'auto' | number;
localModelPath: string | null;
onnxWasmNumThreads: 'auto' | number | null;
logLevel: 'verbose' | 'info' | 'warning' | 'error' | 'fatal';
}

Expand All @@ -28,8 +30,8 @@ const defaults: HuggingFaceTransformersParams = {
remoteHost: 'https://huggingface.co',
remotePathTemplate: '{model}/resolve/{revision}/',
allowLocalModels: true,
localModelPath: 'models',
onnxWasmNumThreads: 'auto',
localModelPath: null,
onnxWasmNumThreads: null,
logLevel: 'error',
};

Expand Down Expand Up @@ -156,23 +158,43 @@ export class HuggingFaceTransformersLanguageModelProvider implements ILanguageMo
env.remoteHost = valueGetter('remoteHost');
env.remotePathTemplate = valueGetter('remotePathTemplate');

// NOTE: Enabled or disable the internal embedding models
// The extension has a built-in embedding model, prevent users from overwriting the model paths
// If you must disable local adoption of an external, use a remote cache to overwrite the
const allowLocalModels = valueGetter('allowLocalModels');

if (allowLocalModels) {
env.allowLocalModels = true;
env.localModelPath = configService.joinPath('models');
const internalModelsUri = configService.extensionJoinPath('dist', 'models');
env.localModelPath = fileURLToPath(internalModelsUri.toString());
} else {
env.allowLocalModels = false;
}

env.cacheDir = configService.joinPath(valueGetter('localModelPath').replace('~', os.homedir()));
const localModelPath = valueGetter('localModelPath');

if (localModelPath) {
// ~/.autodev/models
if (localModelPath.startsWith('~')) {
env.cacheDir = path.join(localModelPath.replace('~', ''), os.homedir());
} else {
// Maybe cache into the current workspace
env.cacheDir = configService.joinPath(localModelPath);
}
} else {
// Default cache into home directory
env.cacheDir = path.join(os.homedir(), '.autodev', 'models');
}

const numThreads = valueGetter('onnxWasmNumThreads');

if (numThreads === 'auto') {
env.backends.onnx.wasm.numThreads = os.cpus().length;
} else {
} else if (typeof numThreads === 'number') {
env.backends.onnx.wasm.numThreads = numThreads;
} else {
// see https://github.com/microsoft/onnxruntime/issues/17274#issuecomment-1692587686
env.backends.onnx.wasm.numThreads = 1;
}

env.backends.onnx.logLevel = valueGetter('logLevel');
Expand Down
Loading