-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
403 additions
and
54 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
<!doctype html> | ||
<html> | ||
<head> | ||
<script type="text/javascript">LibAV = {base: "app/bundled/libavjs/dist"};</script> | ||
<script type="text/javascript" src="app/bundled/libavjs/dist/libav-4.6.6.0.1-behave.dbg.js"></script> | ||
<script type="module" src="app/infer/App.js"></script> | ||
<link rel="stylesheet" href="app/infer/App.css"> | ||
</head> | ||
<body> | ||
</body> | ||
</html> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
import { render } from "preact" | ||
|
||
import {Inferrer} from "./Inferrer.js" | ||
|
||
export function App(_props: {}) { | ||
return <div><Inferrer /></div> | ||
} | ||
|
||
render(<App />, document.body) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import {Upload} from "../lib/Upload.js" | ||
import {FileTree, FileTreeBranch, readFileSystemHandle, updateLeaf, convertAll} from "../lib/FileTree.js" | ||
import * as css from "./inferrer.module.css" | ||
import { JSX } from "preact" | ||
import {useState, useEffect} from 'preact/hooks' | ||
import {setBackend, Model, getModel, convert, getOutputFilename} from "./tfjs.js" | ||
|
||
const NR_WORKERS = 4 | ||
|
||
function fileFilter(file: File, extension: string): boolean { | ||
return !file.name.startsWith(".") && file.name.endsWith("." + extension) | ||
} | ||
|
||
|
||
export function Inferrer({}: {}): JSX.Element { | ||
const [files, setFiles] = useState<FileTreeBranch>(new Map()) | ||
const [state, setState] = useState<"uploading" | "converting" | "done">("uploading") | ||
const [model, setModel] = useState<Model | null>(null) | ||
const [tfBackend, setTfBackend] = useState<Parameters<typeof setBackend>[0]>("webgpu") | ||
|
||
function onBackendChange(event: JSX.TargetedEvent<HTMLSelectElement, Event>) { | ||
if (state !== "uploading") { | ||
return | ||
} | ||
setTfBackend((event.target as unknown as {value: Parameters<typeof setBackend>[0]}).value) | ||
} | ||
|
||
useEffect(() => { | ||
if (state !== "uploading") { | ||
return | ||
} | ||
setModel(null) | ||
setBackend(tfBackend) | ||
}, [tfBackend]) | ||
|
||
async function selectModel() { | ||
try { | ||
const modelDir = await window.showDirectoryPicker({id: "model"}) | ||
const newModel = await getModel(modelDir) | ||
setModel(newModel) | ||
} catch (e) { | ||
setModel(null) | ||
} | ||
} | ||
|
||
async function addFiles(fileSystemHandles: FileSystemHandle[]) { | ||
const newFiles = await readFileSystemHandle(fileSystemHandles, file => fileFilter(file, "MTS")) | ||
setFiles(files => new Map([...files, ...newFiles])) | ||
} | ||
function removeFile(path: string[]) { | ||
setFiles(files => updateLeaf(files, path, null)) | ||
} | ||
|
||
async function doConvertAll() { | ||
if (model === null) { | ||
return | ||
} | ||
setState("converting"); | ||
await convertAll( | ||
files, | ||
NR_WORKERS, | ||
(input, outputstream, onProgress) => convert( | ||
model, input, outputstream, onProgress), | ||
getOutputFilename, | ||
setFiles) | ||
setState("done") | ||
} | ||
|
||
return <> | ||
<h1>Video file convertor</h1> | ||
<div className={css.explanation}> | ||
This files converts video files to be used in BEHAVE. At the moment it can only convert MTS files, but it's easy to add additional types upon request. | ||
</div> | ||
<select disabled={state !== "uploading"} value={tfBackend} onChange={onBackendChange}> | ||
<option value="wasm">WASM</option> | ||
<option value="webgl">WebGL</option> | ||
<option value="webgpu">WebGPU</option> | ||
</select> | ||
<button disabled={state !== "uploading"} onClick={selectModel}>Select model</button> | ||
<button disabled={!(state==="uploading" && model !== null && files.size > 0)} | ||
onClick={doConvertAll} | ||
>Start conversion</button> | ||
<div className={css.files}> | ||
{files.size ? <FileTree {...{files, removeFile}} /> : "Add files to convert"} | ||
</div> | ||
{state === "uploading" && <Upload addFiles={addFiles} />} | ||
</> | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
h1 { | ||
|
||
} | ||
.explanation { | ||
padding-left: 1em; | ||
padding-right: 1em; | ||
padding-bottom: 2em; | ||
|
||
} | ||
.files { | ||
border: .2em solid hsl(240, 100%, 50%); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
import * as tf from '@tensorflow/tfjs' | ||
import "@tensorflow/tfjs-backend-wasm" | ||
import "@tensorflow/tfjs-backend-webgl" | ||
import "@tensorflow/tfjs-backend-webgpu" | ||
import type {FileTreeLeaf} from "../lib/FileTree.js" | ||
import {getNumberOfFrames, getFrames} from "../lib/video.js" | ||
|
||
|
||
export function getOutputFilename(inputfilename: string) { | ||
const parts = inputfilename.split(".") | ||
const baseparts = parts.length == 1 ? parts : parts.slice(0, -1) | ||
return [...baseparts, "csv"].join(".") | ||
} | ||
|
||
export async function setBackend(backend: "wasm" | "webgl" | "webgpu") { | ||
await tf.setBackend(backend) | ||
await tf.ready() | ||
console.log("TF backend is now", tf.backend()) | ||
} | ||
|
||
interface ModelData { | ||
weightsManifest: {paths: string[]}[] | ||
} | ||
|
||
export type Model = tf.GraphModel<string> | ||
|
||
export async function getModel( | ||
modelDirectory: FileSystemDirectoryHandle | ||
): Promise<Model> { | ||
const modelFile = await modelDirectory.getFileHandle("model.json").then( | ||
fh => fh.getFile()) | ||
const modelData = JSON.parse(await modelFile.text()) as ModelData | ||
const weightFiles = await Promise.all( | ||
modelData.weightsManifest[0].paths.map( | ||
name => modelDirectory.getFileHandle(name).then(fh => fh.getFile()))) | ||
const model = await tf.loadGraphModel( | ||
tf.io.browserFiles([modelFile, ...weightFiles])) | ||
return model | ||
} | ||
|
||
export async function convert( | ||
model: Model, | ||
file: File, | ||
outputstream: FileSystemWritableFileStream, | ||
onProgress: (progress: FileTreeLeaf["progress"]) => void, | ||
) { | ||
const numberOfFrames = await getNumberOfFrames(file) | ||
let framenr = 0; | ||
const textEncoder = new TextEncoder() | ||
for await (const imageData of getFrames(file, 640, 640)) { | ||
const [boxes, scores, classes] = await infer(model, imageData) | ||
for (let i = 0; i < scores.length; i++) { | ||
const box = boxes.slice(i * 4, (i + 1) * 4) | ||
const score = scores.at(i)! | ||
const klass = classes.at(i)! | ||
if (!Number.isInteger(klass)) { | ||
throw new Error(`Class is not an int? ${i} ${boxes}, ${scores} ${classes}`) | ||
} | ||
const line = `${framenr},${klass.toFixed(0)},${[...box].map(c => c.toFixed(4)).join(",")},${score.toFixed(2)}\n` | ||
await outputstream.write(textEncoder.encode(line)) | ||
} | ||
onProgress({"converting": Math.min(framenr / numberOfFrames, 1)}) | ||
framenr++ | ||
} | ||
} | ||
|
||
export function preprocess( | ||
imageData: ImageData, | ||
modelWidth: number, | ||
modelHeight: number | ||
): [tf.Tensor<tf.Rank>, number, number] { | ||
|
||
const img = tf.browser.fromPixels(imageData); | ||
|
||
const [h, w] = img.shape.slice(0, 2); // get source width and height | ||
const maxSize = Math.max(w, h); // get max size | ||
const imgPadded = img.pad([ | ||
[0, maxSize - h], // padding y [bottom only] | ||
[0, maxSize - w], // padding x [right only] | ||
[0, 0], | ||
]) as tf.Tensor3D; | ||
|
||
const xRatio = maxSize / w; // update xRatio | ||
const yRatio = maxSize / h; // update yRatio | ||
|
||
const image = tf.image | ||
.resizeBilinear( | ||
imgPadded, | ||
[modelWidth, modelHeight]) // resize frame | ||
.div(255.0) // normalize | ||
.expandDims(0); // add batch | ||
|
||
return [image, xRatio, yRatio] | ||
}; | ||
|
||
function getBoxesAndScoresAndClassesFromResult( | ||
inferResult: tf.Tensor<tf.Rank> | ||
): [tf.Tensor<tf.Rank>, tf.Tensor<tf.Rank>, tf.Tensor<tf.Rank>] { | ||
let transRes = inferResult.transpose([0, 2, 1]); // transpose result [b, det, n] => [b, n, det] | ||
const w = transRes.slice([0, 0, 2], [-1, -1, 1]); // get width | ||
const h = transRes.slice([0, 0, 3], [-1, -1, 1]); // get height | ||
const x1 = tf.sub( | ||
transRes.slice([0, 0, 0], [-1, -1, 1]), | ||
tf.div(w, 2) | ||
); // x1 | ||
const y1 = tf.sub( | ||
transRes.slice([0, 0, 1], [-1, -1, 1]), | ||
tf.div(h, 2) | ||
); // y1 | ||
const boxes = tf | ||
.concat( | ||
[ | ||
y1, | ||
x1, | ||
tf.add(y1, h), //y2 | ||
tf.add(x1, w), //x2 | ||
], | ||
2 | ||
) | ||
.squeeze(); | ||
const rawScores = transRes.slice([0, 0, 4], [-1, -1, 5]).squeeze([0]); // #6 only squeeze axis 0 to handle only 1 class models | ||
return [boxes, rawScores.max(1), rawScores.argMax(1)]; | ||
} | ||
|
||
export async function infer( | ||
model: Model, | ||
imageData: ImageData, | ||
): Promise<[Float32Array, Float32Array, Float32Array]> { | ||
const [img_tensor, _xRatio, _yRatio] = tf.tidy(() => preprocess(imageData, 640, 640)) | ||
const res = tf.tidy(() => model.execute(img_tensor) as tf.Tensor<tf.Rank>) | ||
const [boxes, scores, classes] = tf.tidy(() => getBoxesAndScoresAndClassesFromResult(res)) | ||
const nms = await tf.image.nonMaxSuppressionAsync( | ||
boxes as tf.Tensor2D, | ||
scores as tf.Tensor1D, | ||
500, | ||
0.45, | ||
0.5 | ||
); // NMS to filter boxes | ||
const boxes_nms = tf.tidy(() => boxes.div(640).gather(nms, 0)) | ||
const scores_nms = tf.tidy(() => scores.gather(nms, 0)) | ||
const classes_nms = tf.tidy(() => classes.gather(nms, 0)) | ||
|
||
const boxes_data = await boxes_nms.data() as Float32Array// indexing boxes by nms index | ||
const scores_data = await scores_nms.data() as Float32Array // indexing scores by nms index | ||
const classes_data = await classes_nms.data() as Float32Array // indexing classes by nms index | ||
tf.dispose([img_tensor, res, boxes, scores, classes, nms, boxes_nms, | ||
scores_nms, classes_nms]); | ||
return [boxes_data, scores_data, classes_data] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.