Commit

v0.3.1
jgw96 committed Oct 25, 2024
1 parent 5c0415f commit 1ffe92a
Showing 7 changed files with 101 additions and 23 deletions.
16 changes: 14 additions & 2 deletions README.md
@@ -18,7 +18,7 @@ the code will attempt to choose an NPU first, then a GPU and finally the CPU if

| Function Name | Parameter | Type | Default Value | Supported Hardware |
|-----------------------|----------------|------------------------|---------------|--------------------|
- | transcribeAudioFile | audioFile      | Blob                   | -             | NPU / GPU / CPU    |
+ | transcribeAudioFile | audioFile      | Blob                   | -             | GPU / CPU          |
| | model | string | "Xenova/whisper-tiny"| |
| | timestamps | boolean | false | |
| | language | string | "en-US" | |
@@ -28,10 +28,12 @@ the code will attempt to choose an NPU first, then a GPU and finally the CPU if
| | model | string | "Xenova/distilbart-cnn-6-6"| |
| ocr | image | Blob | - | GPU / CPU |
| | model | string | "Xenova/trocr-small-printed"| |
| image-classification | image | Blob | - | NPU / GPU / CPU |
| | model | string | "Xenova/resnet-50"| |

## Technical Details

- The Web AI Toolkit utilizes the [transformers.js project](https://huggingface.co/docs/transformers.js/index) to run AI workloads. All AI processing is performed locally on the device, ensuring data privacy and reducing latency. AI workloads are run using the [WebNN API](https://learn.microsoft.com/en-us/windows/ai/directml/webnn-overview) when available, otherwise falling back to the WebGPU API. Both of these APIs are used to "hardware accelerate" the AI inferences, with WebNN targeting NPUs and GPUs, and WebGPU strictly targeting GPUs.
+ The Web AI Toolkit utilizes the [transformers.js project](https://huggingface.co/docs/transformers.js/index) to run AI workloads. All AI processing is performed locally on the device, ensuring data privacy and reducing latency. AI workloads are run using the [WebNN API](https://learn.microsoft.com/en-us/windows/ai/directml/webnn-overview) when available, otherwise falling back to the WebGPU API, or even to the CPU with WebAssembly. Choosing the correct hardware to target is handled by the library.
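The fallback order described above can be sketched as a plain function (illustrative only; the function name and capability flags here are hypothetical, not part of the toolkit's API):

```javascript
// Illustrative sketch of the documented device-selection order:
// WebNN first, then WebGPU, then WebAssembly on the CPU.
// The real selection logic lives inside the library and may differ.
function pickDevice({ hasWebNN, hasWebGPU }) {
  if (hasWebNN) return "webnn";
  if (hasWebGPU) return "webgpu";
  return "wasm";
}
```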

## Usage

@@ -77,6 +79,16 @@ const text = await ocr(image);
console.log(text);
```

### Image Classification

```javascript
import { classifyImage } from 'web-ai-toolkit';

const image = ...; // Your image Blob
const result = await classifyImage(image);
console.log(result);
```
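transformers.js image-classification pipelines typically resolve to an array of `{ label, score }` entries, so a caller might extract the top prediction with a helper like this (hypothetical, assuming that output shape; not part of the toolkit):

```javascript
// Assumes the classification result is an array like
// [{ label: "tabby", score: 0.72 }, ...] — the usual transformers.js shape.
// `topLabel` is a local helper, not a toolkit export.
function topLabel(results) {
  return results.reduce((best, r) => (r.score > best.score ? r : best)).label;
}
```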

## Contribution

We welcome contributions to the Web AI Toolkit. Please fork the repository and submit a pull request with your changes. For major changes, please open an issue first to discuss what you would like to change.
28 changes: 14 additions & 14 deletions package-lock.json

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions package.json
@@ -1,6 +1,6 @@
{
"name": "web-ai-toolkit",
- "version": "0.2.1",
+ "version": "0.3.1",
"repository": "https://github.com/jgw96/web-ai-toolkit",
"keywords": [
"ai",
@@ -38,7 +38,7 @@
"vitest": "^2.1.2"
},
"dependencies": {
- "@huggingface/transformers": "^3.0.0-alpha.16",
+ "@huggingface/transformers": "^3.0.0-alpha.22",
"@xenova/transformers": "^2.17.2"
}
}
11 changes: 11 additions & 0 deletions src/index.ts
@@ -42,3 +42,14 @@ export async function ocr(image: Blob, model: string = "Xenova/trocr-small-print
return err;
}
}

export async function classifyImage(image: Blob, model: string = "Xenova/resnet-50") {
try {
const { runClassifier } = await import("./services/image-classification/image-classification");
return runClassifier(image, model);
}
catch (err) {
console.error(err);
return err;
}
}
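Note that `classifyImage`, like the other wrappers in `index.ts`, catches errors and returns them rather than rethrowing. A caller that prefers throw semantics can restore them with a small helper (illustrative only; not part of the library):

```javascript
// Because the wrapper returns the caught error as a value, callers can
// distinguish success from failure with an instanceof check.
// `unwrap` is a hypothetical caller-side helper.
function unwrap(result) {
  if (result instanceof Error) throw result;
  return result;
}
```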
42 changes: 42 additions & 0 deletions src/services/image-classification/image-classification.ts
@@ -0,0 +1,42 @@
import { pipeline, env } from '@huggingface/transformers';
import { webGPUCheck } from '../../utils';

let classifier: any = undefined;

export async function runClassifier(image: Blob | string, model: string = "onnx-community/mobilenetv4s-webnn") {
    // Lazily create the pipeline on first use.
    if (!classifier) {
        await loadClassifier(model);
    }

    // The pipeline accepts a URL string; convert a Blob into an object URL.
    if (typeof image !== "string") {
        image = URL.createObjectURL(image);
    }

    return classifier(image);
}

async function loadClassifier(model: string): Promise<void> {
    if (classifier) {
        return;
    }

    env.allowLocalModels = false;
    env.useBrowserCache = false;

    classifier = await pipeline("image-classification", model || "Xenova/resnet-50", {
        // Prefer the NPU via WebNN when available, then WebGPU, then WASM on the CPU.
        device: (navigator as any).ml ? "webnn-npu" : await webGPUCheck() ? "webgpu" : "wasm"
    });
}
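The device string above also depends on `webGPUCheck` from `../../utils`, which this diff does not show. A minimal feature check might look like the following sketch (an assumption; the real implementation may differ, for example by requesting an adapter with `navigator.gpu.requestAdapter()`):

```javascript
// Hypothetical stand-in for the imported webGPUCheck: reports whether the
// WebGPU API object is exposed in the current environment. Always resolves
// to a boolean, including in non-browser environments.
async function webGPUCheck() {
  return typeof navigator !== "undefined" && "gpu" in navigator && !!navigator.gpu;
}
```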
9 changes: 4 additions & 5 deletions src/services/speech-recognition/recognition.ts
@@ -10,7 +10,7 @@ export function doLocalWhisper(audioFile: Blob, model: string = "Xenova/whisper-
if (!transcriber) {
await loadTranscriber(model || 'Xenova/whisper-tiny', false, 'en');
}

const fileReader = new FileReader();
fileReader.onloadend = async () => {
const audioCTX = new AudioContext({
@@ -58,9 +58,11 @@ export async function loadTranscriber(model: string = "Xenova/whisper-tiny", tim
// @ts-ignore
return_timestamps: timestamps,
language,
- device: (navigator as any).ml ? "webnn" : await webGPUCheck() ? "webgpu" : "wasm"
+ // @ts-ignore
+ device: await webGPUCheck() ? "webgpu" : "wasm"
});


resolve();
}
else {
@@ -127,8 +129,6 @@
// Update tokens of last chunk
last.tokens = [...item[0].output_token_ids];

console.log("callback_function", item, last)

// Merge text chunks
// TODO optimise so we don't have to decode all chunks every time
// @ts-ignore
@@ -138,7 +138,6 @@
force_full_sequences: false,
});

console.log("callback_function", data);

self.postMessage({
type: 'transcribe-interim',
Expand Down
14 changes: 14 additions & 0 deletions test.html
@@ -22,6 +22,11 @@
<button id="image-to-text-button">Test Image to Text</button>
</div>

<div id="image-classify-block">
<input type="file" id="image-classify-file" accept="image/*" />
<button id="image-classify-button">Test Image Classification</button>
</div>


<script type="module">
document.querySelector("#summarize_button").addEventListener("click", async () => {
@@ -56,6 +61,15 @@
console.log(text);
URL.revokeObjectURL(file);
});

document.querySelector("#image-classify-button").addEventListener("click", async () => {
const { classifyImage } = await import("/dist/index.js");

const file = document.querySelector("#image-classify-file").files[0];
// Keep a reference to the object URL so it (not the File) is revoked afterwards.
const url = URL.createObjectURL(file);
const result = await classifyImage(url);
console.log(result);
URL.revokeObjectURL(url);
});
</script>
</body>

