diff --git a/README.md b/README.md index c633545..16e111c 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ the code will attempt to choose an NPU first, then a GPU and finally the CPU if | Function Name | Parameter | Type | Default Value | Supported Hardware | |-----------------------|----------------|------------------------|---------------|--------------------| -| transcribeAudioFile | audioFile | Blob | - | NPU / GPU / CPU | +| transcribeAudioFile | audioFile | Blob | - | GPU / CPU | | | model | string | "Xenova/whisper-tiny"| | | | timestamps | boolean | false | | | | language | string | "en-US" | | @@ -28,10 +28,12 @@ the code will attempt to choose an NPU first, then a GPU and finally the CPU if | | model | string | "Xenova/distilbart-cnn-6-6"| | | ocr | image | Blob | - | GPU / CPU | | | model | string | "Xenova/trocr-small-printed"| | +| image-classification | image | Blob | - | NPU / GPU / CPU | +| | model | string | "Xenova/resnet-50"| | ## Technical Details -The Web AI Toolkit utilizes the [transformers.js project](https://huggingface.co/docs/transformers.js/index) to run AI workloads. All AI processing is performed locally on the device, ensuring data privacy and reducing latency. AI workloads are run using the [WebNN API](https://learn.microsoft.com/en-us/windows/ai/directml/webnn-overview) when available, otherwise falling back to the WebGPU API. Both of these APIs are used to "hardware accelerate" the AI inferences, with WebNN targeting NPUs and GPUs, and WebGPU strictly targeting GPUs. +The Web AI Toolkit utilizes the [transformers.js project](https://huggingface.co/docs/transformers.js/index) to run AI workloads. All AI processing is performed locally on the device, ensuring data privacy and reducing latency. AI workloads are run using the [WebNN API](https://learn.microsoft.com/en-us/windows/ai/directml/webnn-overview) when available, otherwise falling back to the WebGPU API, or even to the CPU with WebAssembly. Choosing the correct hardware to target is handled by the library. ## Usage @@ -77,6 +79,16 @@ const text = await ocr(image); console.log(text); ``` +### Image Classification + +```javascript +import { classifyImage } from 'web-ai-toolkit'; + +const image = ...; // Your image Blob +const text = await classifyImage(image); +console.log(text); +``` + ## Contribution We welcome contributions to the Web AI Toolkit. Please fork the repository and submit a pull request with your changes. For major changes, please open an issue first to discuss what you would like to change. diff --git a/package-lock.json b/package-lock.json index aff5750..cf97e39 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,15 +1,15 @@ { "name": "web-ai-toolkit", - "version": "0.2.0", + "version": "0.2.1", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "web-ai-toolkit", - "version": "0.2.0", + "version": "0.2.1", "license": "ISC", "dependencies": { - "@huggingface/transformers": "^3.0.0-alpha.16", + "@huggingface/transformers": "^3.0.0-alpha.22", "@xenova/transformers": "^2.17.2" }, "devDependencies": { @@ -600,14 +600,14 @@ } }, "node_modules/@huggingface/transformers": { - "version": "3.0.0-alpha.16", - "resolved": "https://registry.npmjs.org/@huggingface/transformers/-/transformers-3.0.0-alpha.16.tgz", - "integrity": "sha512-EhuGTi5U06EVHDm1Q43LZWnsmHmlYdjhLyIJgLTq/1Mxa9CWhtGgpTEYNAkjdPmCE7TFaHKf2zj4yzauLNYmlA==", + "version": "3.0.0-alpha.22", + "resolved": "https://registry.npmjs.org/@huggingface/transformers/-/transformers-3.0.0-alpha.22.tgz", + "integrity": "sha512-3+Py/aOD6moJsby7pBtJsFwdnevtFxk/sD8WWuVqPMr7kASsvrbpIkyCJdzwzkM5LsdFD1QSD+dss7rDlD5v1g==", "license": "Apache-2.0", "dependencies": { "@huggingface/jinja": "^0.3.0", "onnxruntime-node": "1.19.2", - "onnxruntime-web": "1.20.0-dev.20240908-de7a02beef", + "onnxruntime-web": "1.20.0-dev.20241016-2b8fc5529b", "sharp": "^0.33.5" } }, @@ -649,23 +649,23 @@ } }, "node_modules/@huggingface/transformers/node_modules/onnxruntime-web": { - "version": "1.20.0-dev.20240908-de7a02beef", - "resolved": "https://registry.npmjs.org/onnxruntime-web/-/onnxruntime-web-1.20.0-dev.20240908-de7a02beef.tgz", - "integrity": "sha512-4HpjPz06XlM/zD8yfjDqF2Iu7thHTxeZLtLwO4LHTETxdUM2hzwruUuR7d8ZasU3Gxtmi+YfCddfssFaxfObrQ==", + "version": "1.20.0-dev.20241016-2b8fc5529b", + "resolved": "https://registry.npmjs.org/onnxruntime-web/-/onnxruntime-web-1.20.0-dev.20241016-2b8fc5529b.tgz", + "integrity": "sha512-1XovqtgqeEFtupuyzdDQo7Tqj4GRyNHzOoXjapCEo4rfH3JrXok5VtqucWfRXHPsOI5qoNxMQ9VE+drDIp6woQ==", "license": "MIT", "dependencies": { "flatbuffers": "^1.12.0", "guid-typescript": "^1.0.9", "long": "^5.2.3", - "onnxruntime-common": "1.20.0-dev.20240827-5d54dc1462", + "onnxruntime-common": "1.20.0-dev.20241016-2b8fc5529b", "platform": "^1.3.6", "protobufjs": "^7.2.4" } }, "node_modules/@huggingface/transformers/node_modules/onnxruntime-web/node_modules/onnxruntime-common": { - "version": "1.20.0-dev.20240827-5d54dc1462", - "resolved": "https://registry.npmjs.org/onnxruntime-common/-/onnxruntime-common-1.20.0-dev.20240827-5d54dc1462.tgz", - "integrity": "sha512-oR+xPRD64OI+w9nRLXQi9rEXYZ5W9BhqVi688sUTreU9J6pK182JYblmHjvapCg+Tta6MbkAsr3T1NZHM3tB1g==", + "version": "1.20.0-dev.20241016-2b8fc5529b", + "resolved": "https://registry.npmjs.org/onnxruntime-common/-/onnxruntime-common-1.20.0-dev.20241016-2b8fc5529b.tgz", + "integrity": "sha512-KZK8b6zCYGZFjd4ANze0pqBnqnFTS3GIVeclQpa2qseDpXrCQJfkWBixRcrZShNhm3LpFOZ8qJYFC5/qsJK9WQ==", "license": "MIT" }, "node_modules/@huggingface/transformers/node_modules/protobufjs": { diff --git a/package.json b/package.json index 45c5137..8ac9377 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "web-ai-toolkit", - "version": "0.2.1", + "version": "0.3.1", "repository": "https://github.com/jgw96/web-ai-toolkit", "keywords": [ "ai", @@ -38,7 +38,7 @@ "vitest": "^2.1.2" }, "dependencies": { - "@huggingface/transformers": "^3.0.0-alpha.16", + "@huggingface/transformers": "^3.0.0-alpha.22", "@xenova/transformers": "^2.17.2" } } diff --git a/src/index.ts b/src/index.ts index 7d0453f..3d996ed 100644 --- a/src/index.ts +++ b/src/index.ts @@ -42,3 +42,14 @@ export async function ocr(image: Blob, model: string = "Xenova/trocr-small-print return err; } } + +export async function classifyImage(image: Blob, model: string = "Xenova/resnet-50") { + try { + const { runClassifier } = await import("./services/image-classification/image-classification"); + return runClassifier(image, model); + } + catch (err) { + console.error(err); + return err; + } +} \ No newline at end of file diff --git a/src/services/image-classification/image-classification.ts b/src/services/image-classification/image-classification.ts new file mode 100644 index 0000000..f2c152b --- /dev/null +++ b/src/services/image-classification/image-classification.ts @@ -0,0 +1,42 @@ +import { pipeline, env } from '@huggingface/transformers'; +import { webGPUCheck } from '../../utils'; + +let classifier: any = undefined; + +export async function runClassifier(image: Blob | string, model: string = "onnx-community/mobilenetv4s-webnn") { + return new Promise(async (resolve, reject) => { + try { + if (!classifier) { + await loadClassifier(model); + }; + + if (typeof image !== "string") { + image = URL.createObjectURL(image); + } + + const out = await classifier(image); + resolve(out); + } + catch (err) { + reject(err); + } + }); +} + +async function loadClassifier(model: string): Promise { + return new Promise(async (resolve) => { + if (!classifier) { + env.allowLocalModels = false; + env.useBrowserCache = false; + + classifier = await pipeline("image-classification", model || "Xenova/resnet-50", { + device: (navigator as any).ml ? "webnn-npu" : await webGPUCheck() ? "webgpu" : "wasm" + }); + console.log("loaded classifier", classifier) + resolve(); + } + else { + resolve(); + } + }); +} \ No newline at end of file diff --git a/src/services/speech-recognition/recognition.ts b/src/services/speech-recognition/recognition.ts index b00daa1..6a1c2fc 100644 --- a/src/services/speech-recognition/recognition.ts +++ b/src/services/speech-recognition/recognition.ts @@ -10,7 +10,7 @@ export function doLocalWhisper(audioFile: Blob, model: string = "Xenova/whisper- if (!transcriber) { await loadTranscriber(model || 'Xenova/whisper-tiny', false, 'en'); } - + const fileReader = new FileReader(); fileReader.onloadend = async () => { const audioCTX = new AudioContext({ @@ -58,9 +58,11 @@ export async function loadTranscriber(model: string = "Xenova/whisper-tiny", tim // @ts-ignore return_timestamps: timestamps, language, - device: (navigator as any).ml ? "webnn" : await webGPUCheck() ? "webgpu" : "wasm" + // @ts-ignore + device: await webGPUCheck() ? "webgpu" : "wasm" }); + resolve(); } else { @@ -127,8 +129,6 @@ function callback_function(item: any) { // Update tokens of last chunk last.tokens = [...item[0].output_token_ids]; - console.log("callback_function", item, last) - // Merge text chunks // TODO optimise so we don't have to decode all chunks every time // @ts-ignore @@ -138,7 +138,6 @@ function callback_function(item: any) { force_full_sequences: false, }); - console.log("callback_function", data); self.postMessage({ type: 'transcribe-interim', diff --git a/test.html b/test.html index 0655743..eedafaa 100644 --- a/test.html +++ b/test.html @@ -22,6 +22,11 @@ +
+ + +
+