Skip to content

Commit

Permalink
Merge pull request #655 from inexorabletash/bugfix-lstm-gru-variables
Browse files Browse the repository at this point in the history
gru/lstm: Fix missing/incorrect variables and options use
  • Loading branch information
huningxin authored Apr 24, 2024
2 parents 2d06961 + 1661bf9 commit 0df0296
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions index.bs
Original file line number Diff line number Diff line change
Expand Up @@ -3081,6 +3081,8 @@ partial interface MLGraphBuilder {
1. If its [=MLOperand/rank=] is not 3, then [=exception/throw=] a {{TypeError}}.
1. If |options|.{{MLGruOptions/activations}} [=map/exists=] and its [=list/size=] is not 2, then [=exception/throw=] a {{TypeError}}.
1. If |steps| is not equal to |input|'s [=MLOperand/shape=][0], then [=exception/throw=] a {{TypeError}}.
1. Let |batchSize| be |input|'s [=MLOperand/shape=][1].
1. Let |numDirections| be 2 if |options|.{{MLGruOptions/direction}} is {{MLRecurrentNetworkDirection/"both"}}, or 1 otherwise.
1. *Calculate the output shape:*
1. Let |desc0| be a new {{MLOperandDescriptor}}.
1. Set |desc0|.{{MLOperandDescriptor/dimensions}} to the [=/list=] « |numDirections|, |batchSize|, |hiddenSize| ».
Expand All @@ -3103,7 +3105,7 @@ partial interface MLGraphBuilder {
1. If |options|.{{MLGruOptions/bias}} [=map/exists=], then add it to |operator|'s [=operator/inputs=].
1. If |options|.{{MLGruOptions/recurrentBias}} [=map/exists=], then add it to |operator|'s [=operator/inputs=].
1. If |options|.{{MLGruOptions/initialHiddenState}} [=map/exists=], then add it to |operator|'s [=operator/inputs=].
1. Add |options|.{{MLGruOptions/activations}} to |operator|'s [=operator/activation functions=].
1. If |options|.{{MLGruOptions/activations}} [=map/exists=], then add its [=list/items=] to |operator|'s [=operator/activation functions=].
1. Set |operator|'s [=operator/output=] to |output|.
1. Return |output|.
</details>
Expand Down Expand Up @@ -3258,7 +3260,7 @@ partial interface MLGraphBuilder {
1. Set |operator|'s [=operator/inputs=] to |input|, |weight|, |recurrentWeight|, and |hiddenState|.
1. If |options|.{{MLGruCellOptions/bias}} [=map/exists=], then add it to |operator|'s [=operator/inputs=].
1. If |options|.{{MLGruCellOptions/recurrentBias}} [=map/exists=], then add it to |operator|'s [=operator/inputs=].
1. Add |options|.{{MLGruCellOptions/activations}} to |operator|'s [=operator/activation functions=].
1. If |options|.{{MLGruCellOptions/activations}} [=map/exists=], then add its [=list/items=] to |operator|'s [=operator/activation functions=].
1. Set |operator|'s [=operator/output=] to |output|.
1. Return |output|.
</details>
Expand Down Expand Up @@ -3961,7 +3963,7 @@ partial interface MLGraphBuilder {
</summary>
1. If [=MLGraphBuilder/validating operand=] with [=this=] and any of |input|, |weight|, |recurrentWeight|, |options|.{{MLLstmOptions/bias}} (if it [=map/exists=]), |options|.{{MLLstmOptions/recurrentBias}} (if it [=map/exists=]), |options|.{{MLLstmOptions/peepholeWeight}} (if it [=map/exists=]), |options|.{{MLLstmOptions/initialHiddenState}} (if it [=map/exists=]), and |options|.{{MLLstmOptions/initialCellState}} (if it [=map/exists=]) returns false, then [=exception/throw=] a {{TypeError}}.
1. If |options|.{{MLLstmOptions/activations}} [=map/exists=], and [=MLGraphBuilder/validating activation=] with [=this=] and any [=list/item=] in it returns false, then [=exception/throw=] a {{TypeError}}.
1. Let |numDirections| be 1 if |options|.{{MLLstmOptions/direction}} is {{MLRecurrentNetworkDirection/"forward"}}, or otherwise let it be 2.
1. Let |numDirections| be 2 if |options|.{{MLLstmOptions/direction}} is {{MLRecurrentNetworkDirection/"both"}}, or 1 otherwise.
1. If the [=MLOperand/rank=] of any of |input|, |weight| or |recurrentWeight| is not 3, then [=exception/throw=] a {{TypeError}}.
1. If |input|'s [=MLOperand/shape=][0] is not equal to |steps|, then [=exception/throw=] a {{TypeError}}.
1. Let |batchSize| be |input|'s [=MLOperand/shape=][1].
Expand Down Expand Up @@ -4014,6 +4016,7 @@ partial interface MLGraphBuilder {
1. If |options|.{{MLLstmOptions/peepholeWeight}} [=map/exists=], then add it to |operator|'s [=operator/inputs=].
1. If |options|.{{MLLstmOptions/initialHiddenState}} [=map/exists=], then add it to |operator|'s [=operator/inputs=].
1. If |options|.{{MLLstmOptions/initialCellState}} [=map/exists=], then add it to |operator|'s [=operator/inputs=].
1. If |options|.{{MLLstmOptions/activations}} [=map/exists=], then add its [=list/items=] to |operator|'s [=operator/activation functions=].
1. Set |operator|'s [=operator/output=] to |output|.
1. Return |output|.
</details>
Expand Down Expand Up @@ -4194,6 +4197,7 @@ partial interface MLGraphBuilder {
1. If |options|.{{MLLstmCellOptions/bias}} [=map/exists=], then add it to |operator|'s [=operator/inputs=].
1. If |options|.{{MLLstmCellOptions/recurrentBias}} [=map/exists=], then add it to |operator|'s [=operator/inputs=].
1. If |options|.{{MLLstmCellOptions/peepholeWeight}} [=map/exists=], then add it to |operator|'s [=operator/inputs=].
1. If |options|.{{MLLstmCellOptions/activations}} [=map/exists=], then add its [=list/items=] to |operator|'s [=operator/activation functions=].
1. Set |operator|'s [=operator/output=] to |output|.
1. Return |output|.
</details>
Expand Down

0 comments on commit 0df0296

Please sign in to comment.