From 3c22c20ca661b76eb8fbfed65221839b53f02c77 Mon Sep 17 00:00:00 2001 From: Chizkiyahu Date: Sat, 7 Oct 2023 21:21:36 +0300 Subject: [PATCH] fix layer shring inputs and outputs --- source/keras.js | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/source/keras.js b/source/keras.js index b864520d4e5..f03afda5056 100644 --- a/source/keras.js +++ b/source/keras.js @@ -755,14 +755,17 @@ keras.Node = class { const innerType = this.inner ? this.inner.type : null; const innerMetadata = innerType ? metadata.type(innerType) : null; + + // handle layer sharing let inputIndex = 0; let inputTypes = []; + const outputTypes = []; if (this._type && this._type.inputs) { for (let i = 0; i < this._type.inputs.length; i++) { - if (i === 0 && layer.inbound_nodes) { + if (layer.inbound_nodes && layer.inbound_nodes.length > 1 && layer.inbound_nodes[0].length > i) { for (let j = 0; j < layer.inbound_nodes.length; j++) { const inputType = JSON.parse(JSON.stringify(this._type.inputs[i])); - inputType.name = inputType.name + " edge " + j; + inputType.name = j + ": "+ inputType.name; inputTypes.push(inputType); } } else { @@ -770,12 +773,28 @@ keras.Node = class { } } } + if (this._type && this._type.outputs) { + for (let i = 0; i < this._type.outputs.length; i++) { + if (layer.inbound_nodes && layer.inbound_nodes.length > 1 && layer.inbound_nodes[0].length > i) { + for (let j = 0; j < layer.inbound_nodes.length; j++) { + const outputType = JSON.parse(JSON.stringify(this._type.outputs[i])); + outputType.name = j + ": "+ outputType.name; + outputTypes.push(outputType); + } + } else { + outputTypes.push(this._type.inputs[i]); + } + } + } + const inbound_nodes_size = layer.inbound_nodes ? layer.inbound_nodes.length : 1; + outputs = Array(inbound_nodes_size).fill(outputs).flat(); + while (inputs.length > 0) { let list = false; let name = null; let visible = true; if (!innerMetadata || inputIndex == 0) { - if (this._type && inputTypes && inputIndex < inputTypes.length) { + if (inputTypes && inputIndex < inputTypes.length) { const input = inputTypes[inputIndex]; name = input.name; if (type === 'BatchNormalization' && name === 'gamma' && config.scale === false) { @@ -813,7 +832,8 @@ keras.Node = class { break; } } - const input = !list ? [ inputs.shift() ] : inputs.splice(0, inputs.length); + const size = !layer.inbound_nodes ? 1 : layer.inbound_nodes.length; + const input = !list ? [ inputs.shift() ] : inputs.splice(0, size); const inputArguments = input.map((input) => { if (input.name) { return value(input.name, null, initializers[input.name]); @@ -842,7 +862,7 @@ keras.Node = class { for (let i = 0; i < outputs.length; i++) { const output = outputs[i]; - const outputName = (this._type && this._type.outputs && i < this._type.outputs.length && this._type.outputs[i] && this._type.outputs[i].name) ? this._type.outputs[i].name : i.toString(); + const outputName = (outputTypes[i] && outputTypes[i].name) ? outputTypes[i].name : i.toString(); const args = output.length === 0 ? [] : [ value(output) ]; const argument = new keras.Argument(outputName, true, args); this._outputs.push(argument);