Skip to content

Commit

Permalink
fix layer shring inputs and outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
Chizkiyahu committed Oct 7, 2023
1 parent 52bcce9 commit 3c22c20
Showing 1 changed file with 25 additions and 5 deletions.
30 changes: 25 additions & 5 deletions source/keras.js
Original file line number Diff line number Diff line change
Expand Up @@ -755,27 +755,46 @@ keras.Node = class {

const innerType = this.inner ? this.inner.type : null;
const innerMetadata = innerType ? metadata.type(innerType) : null;

// handle layer sharing
let inputIndex = 0;
let inputTypes = [];
const outputTypes = [];
if (this._type && this._type.inputs) {
for (let i = 0; i < this._type.inputs.length; i++) {
if (i === 0 && layer.inbound_nodes) {
if (layer.inbound_nodes && layer.inbound_nodes.length > 1 && layer.inbound_nodes[0].length > i) {
for (let j = 0; j < layer.inbound_nodes.length; j++) {
const inputType = JSON.parse(JSON.stringify(this._type.inputs[i]));
inputType.name = inputType.name + " edge " + j;
inputType.name = j + ": "+ inputType.name;
inputTypes.push(inputType);
}
} else {
inputTypes.push(this._type.inputs[i]);
}
}
}
if (this._type && this._type.outputs) {
for (let i = 0; i < this._type.outputs.length; i++) {
if (layer.inbound_nodes && layer.inbound_nodes.length > 1 && layer.inbound_nodes[0].length > i) {
for (let j = 0; j < layer.inbound_nodes.length; j++) {
const outputType = JSON.parse(JSON.stringify(this._type.outputs[i]));
outputType.name = j + ": "+ outputType.name;
outputTypes.push(outputType);
}
} else {
outputTypes.push(this._type.inputs[i]);
}
}
}
const inbound_nodes_size = layer.inbound_nodes ? layer.inbound_nodes.length : 1;
outputs = Array(inbound_nodes_size).fill(outputs).flat();

while (inputs.length > 0) {
let list = false;
let name = null;
let visible = true;
if (!innerMetadata || inputIndex == 0) {
if (this._type && inputTypes && inputIndex < inputTypes.length) {
if (inputTypes && inputIndex < inputTypes.length) {
const input = inputTypes[inputIndex];
name = input.name;
if (type === 'BatchNormalization' && name === 'gamma' && config.scale === false) {
Expand Down Expand Up @@ -813,7 +832,8 @@ keras.Node = class {
break;
}
}
const input = !list ? [ inputs.shift() ] : inputs.splice(0, inputs.length);
const size = !layer.inbound_nodes ? 1 : layer.inbound_nodes.length;
const input = !list ? [ inputs.shift() ] : inputs.splice(0, size);
const inputArguments = input.map((input) => {
if (input.name) {
return value(input.name, null, initializers[input.name]);
Expand Down Expand Up @@ -842,7 +862,7 @@ keras.Node = class {

for (let i = 0; i < outputs.length; i++) {
const output = outputs[i];
const outputName = (this._type && this._type.outputs && i < this._type.outputs.length && this._type.outputs[i] && this._type.outputs[i].name) ? this._type.outputs[i].name : i.toString();
const outputName = (outputTypes[i] && outputTypes[i].name) ? outputTypes[i].name : i.toString();
const args = output.length === 0 ? [] : [ value(output) ];
const argument = new keras.Argument(outputName, true, args);
this._outputs.push(argument);
Expand Down

0 comments on commit 3c22c20

Please sign in to comment.