diff --git a/source/keras.js b/source/keras.js index b5d445dd6b1..506b581632e 100644 --- a/source/keras.js +++ b/source/keras.js @@ -756,20 +756,34 @@ keras.Node = class { const innerType = this.inner ? this.inner.type : null; const innerMetadata = innerType ? metadata.type(innerType) : null; let inputIndex = 0; + const typeInputs = []; + if (this._type && this._type.inputs) { + for (let i = 0; i < this._type.inputs.length; i++) { + if (i === 0 && layer.inbound_nodes) { + for (let j = 0; j < layer.inbound_nodes.length; j++) { + let inputType = JSON.parse(JSON.stringify(this._type.inputs[i])); + inputType.name = inputType.name + " edge " + j; + typeInputs.push(inputType); + } + } else { + typeInputs.push(this._type.inputs[i]); + } + } + } while (inputs.length > 0) { let list = false; let name = null; let visible = true; if (!innerMetadata || inputIndex == 0) { - if (this._type && this._type.inputs && inputIndex < this._type.inputs.length) { - const input = this._type.inputs[inputIndex]; + if (this._type && typeInputs && inputIndex < typeInputs.length) { + const input = typeInputs[inputIndex]; name = input.name; if (type === 'BatchNormalization' && name === 'gamma' && config.scale === false) { inputIndex++; continue; } visible = input.visible == false ? false : true; - if (this._type.inputs[inputIndex].list) { + if (typeInputs[inputIndex].list) { list = true; } } @@ -834,7 +848,7 @@ keras.Node = class { this._outputs.push(argument); } - const inputTypes = new Map((this._type.inputs || []).map((input) => [ input.name, input.type ])); + const inputTypes = new Map((typeInputs || []).map((input) => [ input.name, input.type ])); for (const entry of Object.entries(args)) { const name = entry[0]; const arg = entry[1];