Skip to content

Commit

Permalink
keras fix layer sharing inputs and outputs names
Browse files Browse the repository at this point in the history
fix layer shring inputs and outputs
  • Loading branch information
Chizkiyahu committed Oct 7, 2023
1 parent 480261e commit 88eb579
Showing 1 changed file with 40 additions and 6 deletions.
46 changes: 40 additions & 6 deletions source/keras.js
Original file line number Diff line number Diff line change
Expand Up @@ -788,21 +788,54 @@ keras.Node = class {

const innerType = this.inner ? this.inner.type : null;
const innerMetadata = innerType ? metadata.type(innerType) : null;

// handle layer sharing
let inputIndex = 0;
let inputTypes = [];
const outputTypes = [];
if (this._type && this._type.inputs) {
for (let i = 0; i < this._type.inputs.length; i++) {
if (layer.inbound_nodes && layer.inbound_nodes.length > 1 && layer.inbound_nodes[0].length > i) {
for (let j = 0; j < layer.inbound_nodes.length; j++) {
const inputType = JSON.parse(JSON.stringify(this._type.inputs[i]));
inputType.name = j + ": "+ inputType.name;
inputTypes.push(inputType);
}
} else {
inputTypes.push(this._type.inputs[i]);
}
}
}
if (this._type && this._type.outputs) {
for (let i = 0; i < this._type.outputs.length; i++) {
if (layer.inbound_nodes && layer.inbound_nodes.length > 1 && layer.inbound_nodes[0].length > i) {
for (let j = 0; j < layer.inbound_nodes.length; j++) {
const outputType = JSON.parse(JSON.stringify(this._type.outputs[i]));
outputType.name = j + ": "+ outputType.name;
outputTypes.push(outputType);
}
} else {
outputTypes.push(this._type.inputs[i]);
}
}
}
const inbound_nodes_size = layer.inbound_nodes ? layer.inbound_nodes.length : 1;
outputs = Array(inbound_nodes_size).fill(outputs).flat();

while (inputs.length > 0) {
let list = false;
let name = null;
let visible = true;
if (!innerMetadata || inputIndex == 0) {
if (this._type && this._type.inputs && inputIndex < this._type.inputs.length) {
const input = this._type.inputs[inputIndex];
if (inputTypes && inputIndex < inputTypes.length) {
const input = inputTypes[inputIndex];
name = input.name;
if (type === 'BatchNormalization' && name === 'gamma' && config.scale === false) {
inputIndex++;
continue;
}
visible = input.visible == false ? false : true;
if (this._type.inputs[inputIndex].list) {
if (inputTypes[inputIndex].list) {
list = true;
}
}
Expand Down Expand Up @@ -832,7 +865,8 @@ keras.Node = class {
break;
}
}
const input = !list ? [ inputs.shift() ] : inputs.splice(0, inputs.length);
const size = !layer.inbound_nodes ? 1 : layer.inbound_nodes.length;
const input = !list ? [ inputs.shift() ] : inputs.splice(0, size);
const inputArguments = input.map((input) => {
if (input.name) {
return values.map(input.name, null, initializers[input.name]);
Expand Down Expand Up @@ -861,12 +895,12 @@ keras.Node = class {

for (let i = 0; i < outputs.length; i++) {
const output = outputs[i];
const outputName = (this._type && this._type.outputs && i < this._type.outputs.length && this._type.outputs[i] && this._type.outputs[i].name) ? this._type.outputs[i].name : i.toString();
const outputName = (outputTypes[i] && outputTypes[i].name) ? outputTypes[i].name : i.toString();
const argument = new keras.Argument(outputName, true, output.length === 0 ? [] : [ values.map(output) ]);
this._outputs.push(argument);
}

const inputTypes = new Map((this._type.inputs || []).map((input) => [ input.name, input.type ]));
inputTypes = new Map((inputTypes || []).map((input) => [ input.name, input.type ]));
for (const entry of Object.entries(args)) {
const name = entry[0];
const arg = entry[1];
Expand Down

0 comments on commit 88eb579

Please sign in to comment.