diff --git a/BirdNET_GLOBAL_6K_V2.4_Model_TFJS/app.py b/BirdNET_GLOBAL_6K_V2.4_Model_TFJS/app.py deleted file mode 100755 index 4a502788..00000000 --- a/BirdNET_GLOBAL_6K_V2.4_Model_TFJS/app.py +++ /dev/null @@ -1,10 +0,0 @@ -from flask import Flask, render_template - -app = Flask(__name__) - -@app.route('/') -def index(): - return render_template('index.html') - -if __name__ == "__main__": - app.run(debug=False) \ No newline at end of file diff --git a/BirdNET_GLOBAL_6K_V2.4_Model_TFJS/static/main.js b/BirdNET_GLOBAL_6K_V2.4_Model_TFJS/static/main.js deleted file mode 100644 index 86e7dff8..00000000 --- a/BirdNET_GLOBAL_6K_V2.4_Model_TFJS/static/main.js +++ /dev/null @@ -1,181 +0,0 @@ -// Simple status function to print messages to the page -function setStatus(status, new_line = true) { - document.getElementById('status').innerHTML += status; - if (new_line) { - document.getElementById('status').innerHTML += '
'; - } - console.log(status); -} - -// Define custom layer for computing mel spectrograms -class MelSpecLayerSimple extends tf.layers.Layer { - constructor(config) { - super(config); - - // Initialize parameters - this.sampleRate = config.sampleRate; - this.specShape = config.specShape; - this.frameStep = config.frameStep; - this.frameLength = config.frameLength; - this.fmin = config.fmin; - this.fmax = config.fmax; - this.melFilterbank = tf.tensor2d(config.melFilterbank); - } - - build(inputShape) { - // Initialize trainable weights, for example: - this.magScale = this.addWeight( - 'magnitude_scaling', - [], - 'float32', - tf.initializers.constant({ value: 1.23 }) - ); - - super.build(inputShape); - } - - // Compute the output shape of the layer - computeOutputShape(inputShape) { - return [inputShape[0], this.specShape[0], this.specShape[1], 1]; - } - - // Define the layer's forward pass - call(input) { - return tf.tidy(() => { - - // inputs is a tensor representing the input data - input = input[0].squeeze() - - // Normalize values between -1 and 1 - input = tf.sub(input, tf.min(input, -1, true)); - input = tf.div(input, tf.max(input, -1, true).add(0.000001)); - input = tf.sub(input, 0.5); - input = tf.mul(input, 2.0); - - // Perform STFT - let spec = tf.signal.stft( - input, - this.frameLength, - this.frameStep, - this.frameLength, - tf.signal.hannWindow, - ); - - // Cast from complex to float - spec = tf.cast(spec, 'float32'); - - // Apply mel filter bank - spec = tf.matMul(spec, this.melFilterbank); - - // Convert to power spectrogram - spec = spec.pow(2.0); - - // Apply nonlinearity - spec = spec.pow(tf.div(1.0, tf.add(1.0, tf.exp(this.magScale.read())))); - - // Flip the spectrogram - spec = tf.reverse(spec, -1); - - // Swap axes to fit input shape - spec = tf.transpose(spec) - - // Adding the channel dimension - spec = spec.expandDims(-1); - - // Adding batch dimension - spec = spec.expandDims(0); - - return spec; - }); - } - - // Optionally, include the `className` method to provide a machine-readable name for the layer - static get className() { - return 'MelSpecLayerSimple'; - } -} - -// Register the custom layer with TensorFlow.js -tf.serialization.registerClass(MelSpecLayerSimple); - -// Main function -async function run() { - await tf.ready(); // Make sure TensorFlow.js is fully loaded - - // Load model with custom layer (this can take a while the first time) - setStatus('Loading model...', false); - const model = await tf.loadLayersModel('static/model/model.json', {custom_objects: {'MelSpecLayerSimple': MelSpecLayerSimple}}); - setStatus('Done!') - - // Load labels - setStatus('Loading labels...', false); - const label_data = await fetch('static/model/labels.json'); - const labels = await label_data.json(); - setStatus('Done!') - - // Load the audio file - setStatus('Loading audio file...', false); - const response = await fetch('static/sample.wav'); - const arrayBuffer = await response.arrayBuffer(); - - // Decode the audio data - const audioCtx = new (window.AudioContext || window.webkitAudioContext)(); - const audioBuffer = await new Promise((resolve, reject) => { - audioCtx.decodeAudioData(arrayBuffer, resolve, reject); - }); - - // Read audio data - const audioData = audioBuffer.getChannelData(0); // Get data for the first channel - - // TODO: Resample and split audio data into chunks of 3 seconds - // to match the input shape of the model (1, 144000) - // For now, let's use a sample that is already 3 seconds long and sampled at 48 kHz - - // Create a tensor from the audio data - // Only works for batch size 1 for now - let input = tf.tensor(audioData).reshape([1, 144000]); - setStatus('Done!') - - // Run the prediction - setStatus('Running prediction...', false); - const prediction = model.predict(input); - setStatus('Done!') - - // Print top 3 probabilities and labels - setStatus('Results:'); - const probs = await prediction.data(); - const probs_sorted = probs.slice().sort().reverse(); - for (let i = 0; i < 3; i++) { - const index = probs.indexOf(probs_sorted[i]); - setStatus(labels[index] + ': ' + probs_sorted[i]); - } - - // Load metadata model - setStatus('
Loading metadata model...', false); - const metadata_model = await tf.loadGraphModel('static/model/mdata/model.json'); - setStatus('Done!') - - // Dummy location and week - const lat = 52.5; - const lon = 13.4; - const week = 42; - let mdata_input = tf.tensor([lat, lon, week]).expandDims(0); - - // Run the prediction - setStatus('Running mdata prediction...', false); - const mdata_prediction = metadata_model.predict(mdata_input); - setStatus('Done!') - - // Print top 10 probabilities and labels (labels are the same as for the audio model) - setStatus('Most common species @ (' + lat + '/' + lon + ') in week ' + week + ':'); - const mdata_probs = await mdata_prediction.data(); - const mdata_probs_sorted = mdata_probs.slice().sort().reverse(); - for (let i = 0; i < 10; i++) { - const index = mdata_probs.indexOf(mdata_probs_sorted[i]); - setStatus(labels[index] + ': ' + mdata_probs_sorted[i]); - } - -} - -// Run the function above after the page is fully loaded -window.addEventListener('load', run); diff --git a/BirdNET_GLOBAL_6K_V2.4_Model_TFJS/static/sample.wav b/BirdNET_GLOBAL_6K_V2.4_Model_TFJS/static/sample.wav deleted file mode 100755 index 002cbb01..00000000 Binary files a/BirdNET_GLOBAL_6K_V2.4_Model_TFJS/static/sample.wav and /dev/null differ diff --git a/BirdNET_GLOBAL_6K_V2.4_Model_TFJS/templates/index.html b/BirdNET_GLOBAL_6K_V2.4_Model_TFJS/templates/index.html deleted file mode 100644 index 3c125413..00000000 --- a/BirdNET_GLOBAL_6K_V2.4_Model_TFJS/templates/index.html +++ /dev/null @@ -1,16 +0,0 @@ - - - - BirdNET TFJS Example - - -

BirdNET TFJS Example

- -

- - - - - - - \ No newline at end of file diff --git a/js/worker.js b/js/worker.js index fc8f8f0e..a21fb509 100644 --- a/js/worker.js +++ b/js/worker.js @@ -298,38 +298,32 @@ break; case "get-valid-species": {getValidSpecies(); break; } - case "get-locations": {getLocations({ - db: STATE.db, - file: args.file -}); + case "get-locations": { getLocations({ db: STATE.db, file: args.file }); break; } - case "get-valid-files-list": {await getFiles(args.files); + case "get-valid-files-list": { await getFiles(args.files); break; } - case "insert-manual-record": {const count = await onInsertManualRecord(args); + case "insert-manual-record": { await onInsertManualRecord(args); break; } - case "load-model": {SEEN_LABELS = false; -SEEN_MODEL_READY = false; -UI.postMessage({ - event: "spawning" -}); -if (args.model === "v3") { - BATCH_SIZE = 1; - sampleRate = 48_000; -} else { - BATCH_SIZE = parseInt(args.batchSize); - sampleRate = 24_000; -} -setAudioContext(sampleRate); -memoryDB = undefined; -BACKEND = args.backend; -STATE.update({ - model: args.model -}); -predictWorkers.length && terminateWorkers(); -spawnWorkers(args.model, args.list, BATCH_SIZE, args.threads); + case "load-model": { + SEEN_LABELS = false; + SEEN_MODEL_READY = false; + UI.postMessage({ + event: "spawning" + }); + sampleRate = args.model === "v2.4" ? 48_000 : 24_000; + + BATCH_SIZE = parseInt(args.batchSize); + setAudioContext(sampleRate); + memoryDB = undefined; + BACKEND = args.backend; + STATE.update({ + model: args.model + }); + predictWorkers.length && terminateWorkers(); + spawnWorkers(args.model, args.list, BATCH_SIZE, args.threads); break; } case "post": {await uploadOpus(args); @@ -339,7 +333,7 @@ break; break; } case "save": {console.log("file save requested"); -await saveAudio(args.file, args.start, args.end, args.filename, args.metadata); + await saveAudio(args.file, args.start, args.end, args.filename, args.metadata); break; } case "save2db": {await onSave2DiskDB(args); @@ -354,23 +348,23 @@ break; case "update-file-start": {await onUpdateFileStart(args); break; } - case "update-list": {UI.postMessage({ - event: "show-spinner" -}); -SEEN_LIST_UPDATE = false; -predictWorkers.forEach((worker) => worker.postMessage({ - message: "list", - list: args.list, - lat: STATE.lat, - lon: STATE.lon, - week: -1, - threshold: STATE.speciesThreshold -})); + case "update-list": { + UI.postMessage({ event: "show-spinner" }); + SEEN_LIST_UPDATE = false; + predictWorkers.forEach((worker) => worker.postMessage({ + message: "list", + list: args.list, + lat: STATE.lat, + lon: STATE.lon, + week: -1, + threshold: STATE.speciesThreshold + })); break; } - case "update-state": {TEMP = args.temp || TEMP; -appPath = args.path || appPath; -STATE.update(args); + case "update-state": { + TEMP = args.temp || TEMP; + appPath = args.path || appPath; + STATE.update(args); break; } default: {UI.postMessage("Worker communication lines open"); @@ -1218,7 +1212,7 @@ const getPredictBuffers = async ({ worker = workerInstance; } // Create array with 0's (short segment of silence that will trigger the finalChunk flag - const myArray = new Float32Array(new Array(chunkLength).fill(0)); + const myArray = new Float32Array(Array.from({length: chunkLength}).fill(0)); feedChunksToModel(myArray, chunkStart, file, end, worker); readStream.resume(); } @@ -2518,8 +2512,8 @@ const getChartTotals = ({ const getRate = (species) => { return new Promise(function (resolve, reject) { - const calls = new Array(52).fill(0); - const total = new Array(52).fill(0); + const calls = Array.from({length: 52}).fill(0); + const total = Array.from({length: 52}).fill(0); // Add Location filter const locationFilter = filterLocation(); @@ -2734,7 +2728,7 @@ async function onChartRequest(args) { const count = rows[i].count; // stack years if (!(year in results)) { - results[year] = new Array(dataPoints).fill(0); + results[year] = Array.from({length: dataPoints}).fill(0); } if (aggregation === 'Week') { results[year][parseInt(week) - 1] = count; diff --git a/package.json b/package.json index 583b0d0a..3d4d0d6c 100644 --- a/package.json +++ b/package.json @@ -35,6 +35,8 @@ "!*fixed_roll*${/*}", "!Help/example.mp3", "!venv${/*}", + "!test${/*}", + "!custom_tfjs${/*}", "!*git*${/*}", "!package-lock.json", "!poetry.lock*",