Skip to content

Commit

Permalink
keras fix layer sharing inputs names
Browse files Browse the repository at this point in the history
  • Loading branch information
Chizkiyahu committed Oct 6, 2023
1 parent d83c4c0 commit 52bcce9
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions source/keras.js
Original file line number Diff line number Diff line change
Expand Up @@ -756,20 +756,34 @@ keras.Node = class {
const innerType = this.inner ? this.inner.type : null;
const innerMetadata = innerType ? metadata.type(innerType) : null;
let inputIndex = 0;
let inputTypes = [];
if (this._type && this._type.inputs) {
for (let i = 0; i < this._type.inputs.length; i++) {
if (i === 0 && layer.inbound_nodes) {
for (let j = 0; j < layer.inbound_nodes.length; j++) {
const inputType = JSON.parse(JSON.stringify(this._type.inputs[i]));
inputType.name = inputType.name + " edge " + j;
inputTypes.push(inputType);
}
} else {
inputTypes.push(this._type.inputs[i]);
}
}
}
while (inputs.length > 0) {
let list = false;
let name = null;
let visible = true;
if (!innerMetadata || inputIndex == 0) {
if (this._type && this._type.inputs && inputIndex < this._type.inputs.length) {
const input = this._type.inputs[inputIndex];
if (this._type && inputTypes && inputIndex < inputTypes.length) {
const input = inputTypes[inputIndex];
name = input.name;
if (type === 'BatchNormalization' && name === 'gamma' && config.scale === false) {
inputIndex++;
continue;
}
visible = input.visible == false ? false : true;
if (this._type.inputs[inputIndex].list) {
if (inputTypes[inputIndex].list) {
list = true;
}
}
Expand Down Expand Up @@ -834,7 +848,7 @@ keras.Node = class {
this._outputs.push(argument);
}

const inputTypes = new Map((this._type.inputs || []).map((input) => [ input.name, input.type ]));
inputTypes = new Map((inputTypes || []).map((input) => [ input.name, input.type ]));
for (const entry of Object.entries(args)) {
const name = entry[0];
const arg = entry[1];
Expand Down

0 comments on commit 52bcce9

Please sign in to comment.