From e613122c01de6d24de30301ec24d7d7ee57a3368 Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Wed, 29 Nov 2023 13:53:19 +0100 Subject: [PATCH 01/14] WIP on RNN --- ops/opset13/rnn.go | 157 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 157 insertions(+) create mode 100644 ops/opset13/rnn.go diff --git a/ops/opset13/rnn.go b/ops/opset13/rnn.go new file mode 100644 index 0000000..a0235e7 --- /dev/null +++ b/ops/opset13/rnn.go @@ -0,0 +1,157 @@ +package opset13 + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +const ( + MinRNNInputs = 3 + MaxRNNInputs = 6 +) + +type RNNDirection string + +const ( + Forward RNNDirection = "forward" + Reverse RNNDirection = "reverse" + Bidirectional RNNDirection = "bidirectional" +) + +// RNN represents the ONNX rnn operator. +type RNN struct { + activationAlpha []float32 + activationBeta []float32 + activations []string + clip float32 + direction RNNDirection + hiddenSize int +} + +// newRNN creates a new rnn operator. +func newRNN() ops.Operator { + return &RNN{} +} + +// Init initializes the rnn operator. +func (r *RNN) Init(attributes []*onnx.AttributeProto) error { + for _, attr := range attributes { + switch attr.GetName() { + case "activation_alpha": + r.activationAlpha = attr.GetFloats() + case "activation_beta": + r.activationBeta = attr.GetFloats() + case "activations": + activations := []string{} + for _, activation := range attr.GetStrings() { + activations = append(activations, string(activation)) + } + + r.activations = activations + case "clip": + r.clip = attr.GetF() + case "direction": + r.direction = RNNDirection(attr.GetS()) + case "hidden_size": + r.hiddenSize = int(attr.GetI()) + default: + return ops.ErrInvalidAttribute(attr.GetName(), r) + } + } + + return nil +} + +// Apply applies the rnn operator. +func (r *RNN) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + X := inputs[0] + W := inputs[1] + R := inputs[2] + B := inputs[3] + + if inputs[4] != nil { + return nil, ops.ErrUnsupportedInput("sequence lens", r) + } + + initialH := inputs[5] + seqLength := X.Shape()[0] + batchSize := X.Shape()[1] + + Wz, Wr, Wh, err := r.getForwardWeights(W) + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (r *RNN) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(a, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (r *RNN) GetMinInputs() int { + return MinRNNInputs +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (r *RNN) GetMaxInputs() int { + return MaxRNNInputs +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (r *RNN) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{ + {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, + {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, + } +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (r *RNN) String() string { + return "rnn operator" +} + +// getForwardWeights returns the weights for the gate. 
+func (r *RNN) getForwardWeights(W tensor.Tensor) (Wz, Wr, Wh tensor.Tensor, err error) { + n, err := r.extractWeights(W) + if err != nil { + return nil, nil, nil, err + } + + return n[0], n[1], n[2], nil +} + +// extractWeights extracts 1 or 2 weight tensors from node W. +// W contains all 2 weight tensors concatenated on top of each other in the following order: +// +// forward weights: [Wi, Wbi] +// recurrent weights: [Ri, Rbi] +// +// W will have a shape of (num_directions, 2 * hidden_size, ...) and we extract the +// by slicing over the '2 * hidden_size' dimension. +func (r *RNN) extractWeights(W tensor.Tensor) ([]tensor.Tensor, error) { + nWeightMatrices := 1 + if r.direction == Bidirectional { + nWeightMatrices = 2 + } + + dirSlice := ops.NewSlicer(0) + weights := make([]tensor.Tensor, nWeightMatrices) + + for i := 0; i < nWeightMatrices; i++ { + slice := ops.NewSlicer(i*r.hiddenSize, (i+1)*r.hiddenSize) + + w, err := W.Slice(dirSlice, slice, nil) + if err != nil { + return nil, err + } + + weights[i] = w + } + + return weights, nil +} From 699204066ba7531ce3e0e89732e15aeb0e9fed8e Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Wed, 29 Nov 2023 15:29:43 +0100 Subject: [PATCH 02/14] WIP on RNN --- ops/opset13/rnn.go | 66 +++++++++++++++++++++++++++++++++------------- 1 file changed, 48 insertions(+), 18 deletions(-) diff --git a/ops/opset13/rnn.go b/ops/opset13/rnn.go index a0235e7..cbd8828 100644 --- a/ops/opset13/rnn.go +++ b/ops/opset13/rnn.go @@ -31,7 +31,10 @@ type RNN struct { // newRNN creates a new rnn operator. func newRNN() ops.Operator { - return &RNN{} + return &RNN{ + activations: []string{"tanh"}, + direction: Forward, + } } // Init initializes the rnn operator. @@ -53,6 +56,9 @@ func (r *RNN) Init(attributes []*onnx.AttributeProto) error { r.clip = attr.GetF() case "direction": r.direction = RNNDirection(attr.GetS()) + if r.direction != Forward { + return ops.ErrUnsupportedAttribute(attr.GetName(), r) + } case "hidden_size": r.hiddenSize = int(attr.GetI()) default: @@ -65,22 +71,30 @@ func (r *RNN) Init(attributes []*onnx.AttributeProto) error { // Apply applies the rnn operator. func (r *RNN) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - X := inputs[0] - W := inputs[1] - R := inputs[2] - B := inputs[3] - if inputs[4] != nil { return nil, ops.ErrUnsupportedInput("sequence lens", r) } + X := inputs[0] + + Wi, err := r.getWeights(inputs[1]) + if err != nil { + return nil, err + } + + Ri, err := r.getWeights(inputs[2]) + if err != nil { + return nil, err + } + initialH := inputs[5] + seqLength := X.Shape()[0] batchSize := X.Shape()[1] - Wz, Wr, Wh, err := r.getForwardWeights(W) - if err != nil { - return nil, err + B := inputs[3] + if B != nil { + // TODO: bias stuff } return []tensor.Tensor{out}, nil @@ -115,24 +129,40 @@ func (r *RNN) String() string { return "rnn operator" } -// getForwardWeights returns the weights for the gate. -func (r *RNN) getForwardWeights(W tensor.Tensor) (Wz, Wr, Wh tensor.Tensor, err error) { - n, err := r.extractWeights(W) +// getWeights returns the weights from a concatenated weight tensor. The result is +// a single weight matrix. W has shape (num_directions, hidden_size, ...). 
+// We do not support bidirectional layers, so we can simply index the first element +// of W to get the weights for either the input or the recurrence +func (r *RNN) getWeights(X tensor.Tensor) (tensor.Tensor, error) { + weights, err := X.Slice(ops.NewSlicer(0), nil, nil) if err != nil { - return nil, nil, nil, err + return nil, err + } + + return weights, nil +} + +// getRecurrentWeights returns the weights for the recurrence. This consists of a single +// weight matrix. W has shape (num_directions, hidden_size, ...). +// We do not support bidirectional layers, so we can simply index the first element +// of W to get the weights for the input gate. +func (r *RNN) getInputWeights(W tensor.Tensor) (tensor.Tensor, error) { + Wi, err := W.Slice(ops.NewSlicer(0), nil, nil) + if err != nil { + return nil, err } - return n[0], n[1], n[2], nil + return Wi, nil } -// extractWeights extracts 1 or 2 weight tensors from node W. -// W contains all 2 weight tensors concatenated on top of each other in the following order: +// extractWeights extracts 1-2 weight tensors from tensor W. +// W contains all 1-2 weight tensors concatenated on top of each other in the following order: // // forward weights: [Wi, Wbi] // recurrent weights: [Ri, Rbi] // -// W will have a shape of (num_directions, 2 * hidden_size, ...) and we extract the -// by slicing over the '2 * hidden_size' dimension. +// W will have a shape of (num_directions, (1 or 2) * hidden_size, ...) and we extract the +// by slicing over the '(1 or 2) * hidden_size' dimension. func (r *RNN) extractWeights(W tensor.Tensor) ([]tensor.Tensor, error) { nWeightMatrices := 1 if r.direction == Bidirectional { From b938a25a070e4c0cc5141a0facd3f5a200639813 Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Thu, 30 Nov 2023 15:09:02 +0100 Subject: [PATCH 03/14] Working RNN version --- ops/activation.go | 19 ++++- ops/opset13/opset13.go | 1 + ops/opset13/opset13_test.go | 25 +++--- ops/opset13/relu.go | 19 +---- ops/opset13/rnn.go | 166 +++++++++++++++++++++++++++--------- ops_test.go | 2 +- 6 files changed, 165 insertions(+), 67 deletions(-) diff --git a/ops/activation.go b/ops/activation.go index 89aca38..ebca570 100644 --- a/ops/activation.go +++ b/ops/activation.go @@ -1,6 +1,8 @@ package ops -import "gorgonia.org/tensor" +import ( + "gorgonia.org/tensor" +) // Activation is an activation function. type Activation func(n tensor.Tensor) (tensor.Tensor, error) @@ -34,3 +36,18 @@ func Sigmoid(X tensor.Tensor) (tensor.Tensor, error) { return tensor.Div(typedOne, numeratorX) } + +// ReLU performs the ReLU operation on a tensor. 
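+// It computes max(0, x) elementwise: values greater than zero are kept as-is,
+// implemented below by multiplying X with the same-typed mask (X > 0).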
+func ReLU(X tensor.Tensor) (tensor.Tensor, error) { + typedZero, err := GetValueAsTensorType(0.0, X.Dtype()) + if err != nil { + return nil, err + } + + comparison, err := tensor.Gt(X, typedZero, tensor.AsSameType()) + if err != nil { + return nil, err + } + + return tensor.Mul(X, comparison) +} diff --git a/ops/opset13/opset13.go b/ops/opset13/opset13.go index 6d00326..92d7853 100644 --- a/ops/opset13/opset13.go +++ b/ops/opset13/opset13.go @@ -30,6 +30,7 @@ var operators13 = map[string]func() ops.Operator{ "PRelu": newPRelu, "Relu": newRelu, "Reshape": newReshape, + "RNN": newRNN, "Scaler": newScaler, "Shape": newShape, "Sigmoid": newSigmoid, diff --git a/ops/opset13/opset13_test.go b/ops/opset13/opset13_test.go index 48c33ed..109ac55 100644 --- a/ops/opset13/opset13_test.go +++ b/ops/opset13/opset13_test.go @@ -33,16 +33,16 @@ func TestGetOperator(t *testing.T) { newAdd(), nil, }, - { - "Atan", - newAtan(), - nil, - }, - { - "Atanh", - newAtanh(), - nil, - }, + { + "Atan", + newAtan(), + nil, + }, + { + "Atanh", + newAtanh(), + nil, + }, { "Asin", newAsin(), @@ -133,6 +133,11 @@ func TestGetOperator(t *testing.T) { newReshape(), nil, }, + { + "RNN", + newRNN(), + nil, + }, { "Scaler", newScaler(), diff --git a/ops/opset13/relu.go b/ops/opset13/relu.go index 8a169c2..8c8ed8d 100644 --- a/ops/opset13/relu.go +++ b/ops/opset13/relu.go @@ -21,24 +21,9 @@ func (r *Relu) Init(_ []*onnx.AttributeProto) error { // Apply applies the relu operator. func (r *Relu) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - t := inputs[0] + out, err := ops.ReLU(inputs[0]) - typedZero, err := ops.GetValueAsTensorType(0.0, t.Dtype()) - if err != nil { - return nil, err - } - - comparison, err := tensor.Gt(t, typedZero, tensor.AsSameType()) - if err != nil { - return nil, err - } - - out, err := tensor.Mul(t, comparison) - if err != nil { - return nil, err - } - - return []tensor.Tensor{out}, nil + return []tensor.Tensor{out}, err } // ValidateInputs validates the inputs that will be given to Apply for this operator. diff --git a/ops/opset13/rnn.go b/ops/opset13/rnn.go index cbd8828..affbf38 100644 --- a/ops/opset13/rnn.go +++ b/ops/opset13/rnn.go @@ -19,6 +19,12 @@ const ( Bidirectional RNNDirection = "bidirectional" ) +var RNNActivations = map[string]ops.Activation{ + "Tanh": ops.Tanh, + "Sigmoid": ops.Tanh, + "ReLU": ops.ReLU, +} + // RNN represents the ONNX rnn operator. type RNN struct { activationAlpha []float32 @@ -87,22 +93,79 @@ func (r *RNN) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return nil, err } - initialH := inputs[5] + B := inputs[3] + if B == nil { + B = r.getDefaultB() + } + + Wbi, Rbi, err := r.getBiases(B) + if err != nil { + return nil, err + } seqLength := X.Shape()[0] batchSize := X.Shape()[1] - B := inputs[3] - if B != nil { - // TODO: bias stuff + prevH := inputs[5] + if prevH == nil { + prevH = r.getInitialH(batchSize) + } + + // Extract the shape of the hidden dimensions without the bidirectional dimension, as + // we do not support bidirectional RNN yet. 
+ if err = prevH.Reshape(prevH.Shape().Clone()[1:]...); err != nil { + return nil, err } - return []tensor.Tensor{out}, nil + outputs := []tensor.Tensor{} + + for t := 0; t < seqLength; t++ { + Xt, err := X.Slice(ops.NewSlicer(t, t+1), nil, nil) + if err != nil { + return nil, err + } + + prevH, err = r.layerCalculation(Xt, prevH, Wi, Ri, Wbi, Rbi, ops.Tanh) + if err != nil { + return nil, err + } + + outputs = append(outputs, prevH) + } + + var Y tensor.Tensor + if len(outputs) > 1 { + Y, err = tensor.Concat(0, outputs[0], outputs[1:]...) + if err != nil { + return nil, err + } + } else { + Y = outputs[0] + } + + // Reshape the output so it adds the num_directions as specified by onnx. + err = Y.Reshape(seqLength, 1, batchSize, r.hiddenSize) + if err != nil { + return nil, err + } + + Yh, ok := prevH.Clone().(tensor.Tensor) + if !ok { + return nil, ops.ErrTypeAssert("tensor.Tensor", prevH.Clone()) + } + + // Reshape the output so it adds the num_directions as specified by onnx. + err = Yh.Reshape(1, batchSize, r.hiddenSize) + if err != nil { + return nil, err + } + + return []tensor.Tensor{Y, Yh}, nil } // ValidateInputs validates the inputs that will be given to Apply for this operator. func (r *RNN) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(a, inputs) + return ops.ValidateInputs(r, inputs) } // GetMinInputs returns the minimum number of input tensors this operator expects. @@ -119,8 +182,12 @@ func (r *RNN) GetMaxInputs() int { // for the corresponding input tensor. func (r *RNN) GetInputTypeConstraints() [][]tensor.Dtype { return [][]tensor.Dtype{ - {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, - {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Int32}, + {tensor.Float32, tensor.Float64}, } } @@ -132,7 +199,7 @@ func (r *RNN) String() string { // getWeights returns the weights from a concatenated weight tensor. The result is // a single weight matrix. W has shape (num_directions, hidden_size, ...). // We do not support bidirectional layers, so we can simply index the first element -// of W to get the weights for either the input or the recurrence +// of W to get the weights for either the input or the recurrence. func (r *RNN) getWeights(X tensor.Tensor) (tensor.Tensor, error) { weights, err := X.Slice(ops.NewSlicer(0), nil, nil) if err != nil { @@ -142,46 +209,69 @@ func (r *RNN) getWeights(X tensor.Tensor) (tensor.Tensor, error) { return weights, nil } -// getRecurrentWeights returns the weights for the recurrence. This consists of a single -// weight matrix. W has shape (num_directions, hidden_size, ...). -// We do not support bidirectional layers, so we can simply index the first element -// of W to get the weights for the input gate. 
-func (r *RNN) getInputWeights(W tensor.Tensor) (tensor.Tensor, error) { - Wi, err := W.Slice(ops.NewSlicer(0), nil, nil) +func (r *RNN) getBiases(B tensor.Tensor) (tensor.Tensor, tensor.Tensor, error) { + Wbi, err := B.Slice(ops.NewSlicer(0), ops.NewSlicer(0, r.hiddenSize)) if err != nil { - return nil, err + return nil, nil, err } - return Wi, nil + nBiasMatrices := 2 + + Rbi, err := B.Slice(ops.NewSlicer(0), ops.NewSlicer(r.hiddenSize, nBiasMatrices*r.hiddenSize)) + if err != nil { + return nil, nil, err + } + + return Wbi, Rbi, nil } -// extractWeights extracts 1-2 weight tensors from tensor W. -// W contains all 1-2 weight tensors concatenated on top of each other in the following order: +// getDefaultB returns the default bias tensor if no bias tensor is provided. +// The bias tensor for RNN consists of two concatenated bias tensors, one for +// the input calculation and one for the hidden calculation. It has shape: // -// forward weights: [Wi, Wbi] -// recurrent weights: [Ri, Rbi] +// (num_directions, 2*hiddenSize). // -// W will have a shape of (num_directions, (1 or 2) * hidden_size, ...) and we extract the -// by slicing over the '(1 or 2) * hidden_size' dimension. -func (r *RNN) extractWeights(W tensor.Tensor) ([]tensor.Tensor, error) { - nWeightMatrices := 1 - if r.direction == Bidirectional { - nWeightMatrices = 2 - } +// By default all values are 0. +func (r *RNN) getDefaultB() tensor.Tensor { + nBiasMatrices := 2 + + return tensor.New( + tensor.WithShape(1, nBiasMatrices*r.hiddenSize), + tensor.WithBacking(ops.Zeros(nBiasMatrices*r.hiddenSize)), + ) +} - dirSlice := ops.NewSlicer(0) - weights := make([]tensor.Tensor, nWeightMatrices) +// getInitialH can be used to construct an initial hidden tensor when it is not +// specified by the inputs of the operator. In this case it is assumed to be 0. +// It has shape (num_directions, batch_size, hidden_size). 
+func (r *RNN) getInitialH(batchSize int) tensor.Tensor { + hiddenFloats := ops.Zeros(batchSize * r.hiddenSize) - for i := 0; i < nWeightMatrices; i++ { - slice := ops.NewSlicer(i*r.hiddenSize, (i+1)*r.hiddenSize) + return tensor.New( + tensor.WithShape(1, batchSize, r.hiddenSize), + tensor.WithBacking(hiddenFloats), + ) +} - w, err := W.Slice(dirSlice, slice, nil) - if err != nil { - return nil, err - } +func (r *RNN) layerCalculation( + Xt, H, Wi, Ri, Wbi, Rbi tensor.Tensor, activation ops.Activation, +) (tensor.Tensor, error) { + gemm := &Gemm{transA: false, transB: true, alpha: 1.0, beta: 1.0} - weights[i] = w + inputCalc, err := gemm.Apply([]tensor.Tensor{Xt, Wi, Wbi}) + if err != nil { + return nil, err } - return weights, nil + hiddenCalc, err := gemm.Apply([]tensor.Tensor{H, Ri, Rbi}) + if err != nil { + return nil, err + } + + result, err := tensor.Add(inputCalc[0], hiddenCalc[0]) + if err != nil { + return nil, err + } + + return activation(result) } diff --git a/ops_test.go b/ops_test.go index b7986e7..2400c5e 100644 --- a/ops_test.go +++ b/ops_test.go @@ -129,7 +129,6 @@ func TestOps(t *testing.T) { assert.Nil(t, err) for _, test := range tests { - fmt.Println(test.inputs) t.Run(test.name, func(t *testing.T) { outputs, err := test.model.Run(test.inputs) assert.Nil(t, err) @@ -367,6 +366,7 @@ var expectedTests = []string{ "test_reshape_reordered_last_dims", "test_reshape_zero_and_negative_dim", "test_reshape_zero_dim", + "test_rnn_seq_length", "test_shape", "test_sin", "test_sin_example", From 70ac7714451d3b0c628a0aa0d0a9cde1719503fb Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Fri, 1 Dec 2023 12:25:07 +0100 Subject: [PATCH 04/14] Added tests for RNN --- ops/fixtures.go | 14 ++ ops/opset13/rnn.go | 77 ++++++---- ops/opset13/rnn_test.go | 332 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 392 insertions(+), 31 deletions(-) create mode 100644 ops/opset13/rnn_test.go diff --git a/ops/fixtures.go b/ops/fixtures.go index ba0d215..d2f993d 100644 --- a/ops/fixtures.go +++ b/ops/fixtures.go @@ -1,6 +1,8 @@ package ops import ( + "math/rand" + "gorgonia.org/tensor" ) @@ -16,6 +18,18 @@ func Float32TensorFixture(shp ...int) tensor.Tensor { ) } +func RandomFloat32TensorFixture(shp ...int) tensor.Tensor { + rands := make([]float32, NElements(shp...)) + for i := 0; i < NElements(shp...); i++ { + rands[i] = rand.Float32() + } + + return tensor.New( + tensor.WithShape(shp...), + tensor.WithBacking(rands), + ) +} + // TensorWithBackingFixture returns a gorgonia node with a tensor using the given backing. func TensorWithBackingFixture(b interface{}, shp ...int) tensor.Tensor { return tensor.New(tensor.WithShape(shp...), tensor.WithBacking(b)) diff --git a/ops/opset13/rnn.go b/ops/opset13/rnn.go index affbf38..687d5a7 100644 --- a/ops/opset13/rnn.go +++ b/ops/opset13/rnn.go @@ -19,10 +19,11 @@ const ( Bidirectional RNNDirection = "bidirectional" ) +// These activations are supported in the RNN calculation. var RNNActivations = map[string]ops.Activation{ - "Tanh": ops.Tanh, - "Sigmoid": ops.Tanh, - "ReLU": ops.ReLU, + "tanh": ops.Tanh, + "sigmoid": ops.Sigmoid, + "relu": ops.ReLU, } // RNN represents the ONNX rnn operator. 
@@ -30,7 +31,6 @@ type RNN struct { activationAlpha []float32 activationBeta []float32 activations []string - clip float32 direction RNNDirection hiddenSize int } @@ -59,7 +59,7 @@ func (r *RNN) Init(attributes []*onnx.AttributeProto) error { r.activations = activations case "clip": - r.clip = attr.GetF() + return ops.ErrUnsupportedAttribute(attr.GetName(), r) case "direction": r.direction = RNNDirection(attr.GetS()) if r.direction != Forward { @@ -106,57 +106,59 @@ func (r *RNN) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { seqLength := X.Shape()[0] batchSize := X.Shape()[1] - prevH := inputs[5] - if prevH == nil { - prevH = r.getInitialH(batchSize) + Ht := inputs[5] + if Ht == nil { + Ht = r.getInitialH(batchSize) } - // Extract the shape of the hidden dimensions without the bidirectional dimension, as - // we do not support bidirectional RNN yet. - if err = prevH.Reshape(prevH.Shape().Clone()[1:]...); err != nil { + // Reshape the hidden tensor without the bidirectional dimension, as + // we do not support bidirectional RNN yet. This is the first dimension. + if err = Ht.Reshape(Ht.Shape().Clone()[1:]...); err != nil { return nil, err } + activation := RNNActivations[r.activations[0]] + if activation == nil { + return nil, ops.ErrUnsupportedAttribute("activations", r) + } + outputs := []tensor.Tensor{} + // Loop over all timesteps of the input, applying the RNN calculation to every + // timesteps while updating the hidden tensor. for t := 0; t < seqLength; t++ { Xt, err := X.Slice(ops.NewSlicer(t, t+1), nil, nil) if err != nil { return nil, err } - prevH, err = r.layerCalculation(Xt, prevH, Wi, Ri, Wbi, Rbi, ops.Tanh) + Ht, err = r.layerCalculation(Xt, Ht, Wi, Ri, Wbi, Rbi, RNNActivations[r.activations[0]]) if err != nil { return nil, err } - outputs = append(outputs, prevH) + outputs = append(outputs, Ht) } - var Y tensor.Tensor + Y := outputs[0] if len(outputs) > 1 { - Y, err = tensor.Concat(0, outputs[0], outputs[1:]...) + Y, err = tensor.Concat(0, Y, outputs[1:]...) if err != nil { return nil, err } - } else { - Y = outputs[0] } - // Reshape the output so it adds the num_directions as specified by onnx. - err = Y.Reshape(seqLength, 1, batchSize, r.hiddenSize) - if err != nil { - return nil, err + Yh, ok := Ht.Clone().(tensor.Tensor) + if !ok { + return nil, ops.ErrTypeAssert("tensor.Tensor", Ht.Clone()) } - Yh, ok := prevH.Clone().(tensor.Tensor) - if !ok { - return nil, ops.ErrTypeAssert("tensor.Tensor", prevH.Clone()) + // Reshape the outputs so it adds the num_directions as specified by onnx. + if err = Y.Reshape(seqLength, 1, batchSize, r.hiddenSize); err != nil { + return nil, err } - // Reshape the output so it adds the num_directions as specified by onnx. - err = Yh.Reshape(1, batchSize, r.hiddenSize) - if err != nil { + if err = Yh.Reshape(1, batchSize, r.hiddenSize); err != nil { return nil, err } @@ -209,6 +211,11 @@ func (r *RNN) getWeights(X tensor.Tensor) (tensor.Tensor, error) { return weights, nil } +// getBiases splits an input bias tensor B into its subparts. The B input for the +// RNN operator consists of two biases, Wbi and Rbi. These biases are concatenated +// in the second dimension, where B has shape (num_directions, 2 * hiddenSize). +// This function slices the B tensor to return 2 bias tensors. We disregard the +// num_directions axis as we do not support the bidirectional direction. 
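+// For example, with hiddenSize = 4, B has shape (num_directions, 8):
+// Wbi is the slice B[0, 0:4] and Rbi is the slice B[0, 4:8].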
func (r *RNN) getBiases(B tensor.Tensor) (tensor.Tensor, tensor.Tensor, error) { Wbi, err := B.Slice(ops.NewSlicer(0), ops.NewSlicer(0, r.hiddenSize)) if err != nil { @@ -231,7 +238,8 @@ func (r *RNN) getBiases(B tensor.Tensor) (tensor.Tensor, tensor.Tensor, error) { // // (num_directions, 2*hiddenSize). // -// By default all values are 0. +// By default all values are 0. Note that we do not support the bidirectional +// option so the first dim always has size 1. func (r *RNN) getDefaultB() tensor.Tensor { nBiasMatrices := 2 @@ -244,15 +252,22 @@ func (r *RNN) getDefaultB() tensor.Tensor { // getInitialH can be used to construct an initial hidden tensor when it is not // specified by the inputs of the operator. In this case it is assumed to be 0. // It has shape (num_directions, batch_size, hidden_size). +// As we do not support the birectional option, the num_directions dim size is +// always 1. func (r *RNN) getInitialH(batchSize int) tensor.Tensor { - hiddenFloats := ops.Zeros(batchSize * r.hiddenSize) - return tensor.New( tensor.WithShape(1, batchSize, r.hiddenSize), - tensor.WithBacking(hiddenFloats), + tensor.WithBacking(ops.Zeros(batchSize*r.hiddenSize)), ) } +// layerCalculation performs the actual RNN calculation. By ONNX definition +// this is: +// +// Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi) +// +// We achieve this by two Gemm operations, adding them together and finally +// putting them through an activation function. func (r *RNN) layerCalculation( Xt, H, Wi, Ri, Wbi, Rbi tensor.Tensor, activation ops.Activation, ) (tensor.Tensor, error) { diff --git a/ops/opset13/rnn_test.go b/ops/opset13/rnn_test.go new file mode 100644 index 0000000..5d56ed0 --- /dev/null +++ b/ops/opset13/rnn_test.go @@ -0,0 +1,332 @@ +package opset13 + +import ( + "math/rand" + "testing" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestRNNInit(t *testing.T) { + rnn := &RNN{} + err := rnn.Init(RNNOnnxAttributeProtoFixture()) + + assert.Nil(t, err) + assert.Equal(t, []float32{1.0}, rnn.activationAlpha) + assert.Equal(t, []float32{2.0}, rnn.activationBeta) + assert.Equal(t, []string{"sigmoid"}, rnn.activations) + assert.Equal(t, RNNDirection("forward"), rnn.direction) + assert.Equal(t, 5, rnn.hiddenSize) +} + +func TestRNNInitUnsupportedAttr(t *testing.T) { + rnn := RNN{} + err := rnn.Init([]*onnx.AttributeProto{{Name: "clip"}}) + assert.Equal(t, err, ops.ErrUnsupportedAttribute("clip", &rnn)) +} + +func TestRNNInitUnknownAttr(t *testing.T) { + rnn := RNN{} + err := rnn.Init([]*onnx.AttributeProto{{Name: "unknown"}}) + assert.Equal(t, err, ops.ErrInvalidAttribute("unknown", &rnn)) +} + +func TestRNN(t *testing.T) { + tests := []struct { + rnn *RNN + inputs ops.InputFixture + expected []float32 + err error + }{ + { + &RNN{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"tanh"}, + direction: Forward, + hiddenSize: 4, + }, + rnnInput0, + []float32{0.78036773, 0.97858655, 0.94110376, 0.90722954}, + nil, + }, + { + &RNN{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"sigmoid"}, + direction: Forward, + hiddenSize: 4, + }, + rnnInput0, + []float32{0.82048327, 0.922734, 0.89050114, 0.8620579}, + nil, + }, + { + &RNN{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"relu"}, + direction: Forward, + hiddenSize: 4, + }, + rnnInput0, + []float32{1.0667435, 2.328037, 
1.7986122, 1.545068}, + nil, + }, + { + &RNN{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"tanh"}, + direction: Forward, + hiddenSize: 10, + }, + rnnInput1, + []float32{0.99996024, 0.9999855, 0.99998087, 0.9999288, 0.9997511, 0.99918234, 0.99999964, 0.9999981, 0.9997658, 0.9999618, 0.9998762, 0.9999353, 0.9999194, 0.9999428, 0.9997284, 0.9982606, 0.999999, 0.9999897, 0.99964744, 0.9998234, 0.99997497, 0.9999893, 0.9999906, 0.9999812, 0.99983937, 0.99967873, 0.9999998, 0.9999965, 0.9999516, 0.9999541}, + nil, + }, + { + &RNN{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"tanh"}, + direction: Forward, + hiddenSize: 4, + }, + rnnInputNoB, + // Same values as first test, but B is initialized automatically. + []float32{0.78036773, 0.97858655, 0.94110376, 0.90722954}, + nil, + }, + { + &RNN{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"tanh"}, + direction: Forward, + hiddenSize: 4, + }, + rnnInputNoBNoH, + // Same values as first test, but B and H are initialized automatically. + []float32{0.78036773, 0.97858655, 0.94110376, 0.90722954}, + nil, + }, + } + + for _, test := range tests { + inputs := test.inputs() + res, err := test.rnn.Apply(inputs) + assert.Equal(t, test.err, err) + + if err == nil { + assert.Equal(t, test.expected, res[1].Data()) + } + } +} + +func TestInputValidationRNN(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + expected []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + nil, + nil, + nil, + }, + nil, + }, + { + []tensor.Tensor{ops.TensorWithBackingFixture([]float32{1, 2}, 2)}, + nil, + ops.ErrInvalidOptionalInputCount(1, &RNN{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(0, "int", &RNN{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(1, "int", &RNN{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(2, "int", &RNN{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(3, "int", &RNN{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + 
ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(4, "float32", &RNN{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(5, "int", &RNN{}), + }, + } + + for _, test := range tests { + rnn := &RNN{} + validated, err := rnn.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + if test.expected != nil { + assert.Equal(t, test.expected, validated) + } else { + assert.Equal(t, test.inputs, validated) + } + } + } +} + +func rnnInput0() []tensor.Tensor { + rand.Seed(13) + + return []tensor.Tensor{ + // Input X: (sequence_length, batch_size, input_size). + ops.RandomFloat32TensorFixture(2, 1, 3), + // Input W: (num_directions, hidden_size, input_size). + ops.RandomFloat32TensorFixture(1, 4, 3), + // Input R: (num_directions, hidden_size, hidden_size). + ops.RandomFloat32TensorFixture(1, 4, 4), + // Input B: (num_directions, 2 * hidden_size) + ops.TensorWithBackingFixture(ops.Zeros(ops.NElements(1, 8)), 1, 8), + // Input sequence_lens: not supported + nil, + // Input initial_h: (num_directions, batch_size, hidden_size) + ops.TensorWithBackingFixture(ops.Zeros(ops.NElements(1, 1, 4)), 1, 1, 4), + } +} + +func rnnInput1() []tensor.Tensor { + rand.Seed(13) + + return []tensor.Tensor{ + // Input X: (sequence_length, batch_size, input_size). + ops.RandomFloat32TensorFixture(10, 3, 4), + // Input W: (num_directions, hidden_size, input_size). + ops.RandomFloat32TensorFixture(1, 10, 4), + // Input R: (num_directions, hidden_size, hidden_size). + ops.RandomFloat32TensorFixture(1, 10, 10), + // Input B: (num_directions, 2 * hidden_size) + ops.TensorWithBackingFixture(ops.Zeros(ops.NElements(1, 20)), 1, 20), + // Input sequence_lens: not supported + nil, + // Input initial_h: (num_directions, batch_size, hidden_size) + ops.TensorWithBackingFixture(ops.Zeros(ops.NElements(1, 3, 10)), 1, 3, 10), + } +} + +func rnnInputNoB() []tensor.Tensor { + rand.Seed(13) + + return []tensor.Tensor{ + // Input X: (sequence_length, batch_size, input_size). + ops.RandomFloat32TensorFixture(2, 1, 3), + // Input W: (num_directions, hidden_size, input_size). + ops.RandomFloat32TensorFixture(1, 4, 3), + // Input R: (num_directions, hidden_size, hidden_size). + ops.RandomFloat32TensorFixture(1, 4, 4), + // Input B: not provided. + nil, + // Input sequence_lens: not supported + nil, + // Input initial_h: (num_directions, batch_size, hidden_size) + ops.TensorWithBackingFixture(ops.Zeros(ops.NElements(1, 1, 4)), 1, 1, 4), + } +} + +func rnnInputNoBNoH() []tensor.Tensor { + rand.Seed(13) + + return []tensor.Tensor{ + // Input X: (sequence_length, batch_size, input_size). + ops.RandomFloat32TensorFixture(2, 1, 3), + // Input W: (num_directions, hidden_size, input_size). + ops.RandomFloat32TensorFixture(1, 4, 3), + // Input R: (num_directions, hidden_size, hidden_size). + ops.RandomFloat32TensorFixture(1, 4, 4), + // Input B: not provided. 
+ nil, + // Input sequence_lens: not supported + nil, + // Input initial_h: (num_directions, batch_size, hidden_size) + nil, + } +} + +func RNNOnnxAttributeProtoFixture() []*onnx.AttributeProto { + return []*onnx.AttributeProto{ + {Name: "activation_alpha", Floats: []float32{1.0}}, + {Name: "activation_beta", Floats: []float32{2.0}}, + {Name: "activations", Strings: [][]byte{[]byte("sigmoid")}}, + {Name: "direction", S: []byte("forward")}, + {Name: "hidden_size", I: 5}, + } +} From cb74f725023bb663589a018cf40cfdfcecfa98b7 Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Wed, 6 Dec 2023 17:45:37 +0100 Subject: [PATCH 05/14] Working version of LSTM operator --- model.go | 2 +- ops/operator.go | 5 +- ops/opset13/abs.go | 2 +- ops/opset13/acos.go | 2 +- ops/opset13/acosh.go | 2 +- ops/opset13/add.go | 2 +- ops/opset13/and.go | 2 +- ops/opset13/asin.go | 2 +- ops/opset13/asinh.go | 2 +- ops/opset13/atan.go | 2 +- ops/opset13/atanh.go | 2 +- ops/opset13/cast.go | 4 +- ops/opset13/cast_test.go | 4 +- ops/opset13/concat.go | 4 +- ops/opset13/concat_test.go | 4 +- ops/opset13/constant.go | 3 +- ops/opset13/constant_of_shape.go | 4 +- ops/opset13/constant_of_shape_test.go | 16 +- ops/opset13/constant_test.go | 6 +- ops/opset13/conv.go | 4 +- ops/opset13/conv_test.go | 28 +- ops/opset13/cos.go | 2 +- ops/opset13/cosh.go | 2 +- ops/opset13/div.go | 2 +- ops/opset13/gather.go | 4 +- ops/opset13/gather_test.go | 14 +- ops/opset13/gemm.go | 4 +- ops/opset13/gemm_test.go | 18 +- ops/opset13/gru.go | 6 +- ops/opset13/gru_test.go | 14 +- ops/opset13/lstm.go | 461 ++++++++++++++++++++++++++ ops/opset13/matmul.go | 2 +- ops/opset13/mul.go | 2 +- ops/opset13/not.go | 2 +- ops/opset13/opset13.go | 1 + ops/opset13/opset13_test.go | 5 + ops/opset13/or.go | 2 +- ops/opset13/prelu.go | 2 +- ops/opset13/relu.go | 2 +- ops/opset13/reshape.go | 2 +- ops/opset13/rnn.go | 5 +- ops/opset13/rnn_test.go | 22 +- ops/opset13/scaler.go | 3 +- ops/opset13/scaler_test.go | 16 +- ops/opset13/shape.go | 2 +- ops/opset13/sigmoid.go | 2 +- ops/opset13/sin.go | 2 +- ops/opset13/sinh.go | 2 +- ops/opset13/slice.go | 2 +- ops/opset13/softmax.go | 4 +- ops/opset13/squeeze.go | 2 +- ops/opset13/sub.go | 2 +- ops/opset13/tan.go | 2 +- ops/opset13/tanh.go | 2 +- ops/opset13/transpose.go | 4 +- ops/opset13/transpose_test.go | 14 +- ops/opset13/unsqueeze.go | 2 +- ops/opset13/xor.go | 2 +- ops/validate_inputs_test.go | 2 +- ops_test.go | 10 +- 60 files changed, 628 insertions(+), 123 deletions(-) create mode 100644 ops/opset13/lstm.go diff --git a/model.go b/model.go index 383eee7..9c9fcd1 100644 --- a/model.go +++ b/model.go @@ -187,7 +187,7 @@ func (m *Model) Run(inputs Tensors) (Tensors, error) { // applyOp applies the operation to the graph. func (m *Model) applyOp(op ops.Operator, n *onnx.NodeProto, tensors Tensors) error { - if err := op.Init(n.GetAttribute()); err != nil { + if err := op.Init(n); err != nil { return err } diff --git a/ops/operator.go b/ops/operator.go index 15926a4..7f26e4d 100644 --- a/ops/operator.go +++ b/ops/operator.go @@ -10,10 +10,11 @@ type Operator interface { // String should return a simple string describing the operator String() string - // Init should initialize the operator based on the given attributes. How these + // Init should initialize the operator based on the given node. + // This node contains attributes, which outputs are expected and more. 
How these // attributes influence the operator is defined by the ONNX standard, and can be // found in https://github.com/onnx/onnx/blob/main/docs/Operators.md - Init([]*onnx.AttributeProto) error + Init(*onnx.NodeProto) error // Apply should apply the operator to the list of input tensors. It should return a // list with output tensors, the result of the operator. diff --git a/ops/opset13/abs.go b/ops/opset13/abs.go index 6a0572b..482d80e 100644 --- a/ops/opset13/abs.go +++ b/ops/opset13/abs.go @@ -20,7 +20,7 @@ func newAbs() ops.Operator { } // Init initializes the abs operator. -func (a *Abs) Init([]*onnx.AttributeProto) error { +func (a *Abs) Init(*onnx.NodeProto) error { return nil } diff --git a/ops/opset13/acos.go b/ops/opset13/acos.go index 9e766c0..139c1ed 100644 --- a/ops/opset13/acos.go +++ b/ops/opset13/acos.go @@ -17,7 +17,7 @@ func newAcos() ops.Operator { } // Init initializes the acos operator. -func (c *Acos) Init(_ []*onnx.AttributeProto) error { +func (c *Acos) Init(*onnx.NodeProto) error { return nil } diff --git a/ops/opset13/acosh.go b/ops/opset13/acosh.go index e4a0e13..0e1404c 100644 --- a/ops/opset13/acosh.go +++ b/ops/opset13/acosh.go @@ -17,7 +17,7 @@ func newAcosh() ops.Operator { } // Init initializes the acosh operator. -func (c *Acosh) Init(_ []*onnx.AttributeProto) error { +func (c *Acosh) Init(*onnx.NodeProto) error { return nil } diff --git a/ops/opset13/add.go b/ops/opset13/add.go index baddd5f..cf5b566 100644 --- a/ops/opset13/add.go +++ b/ops/opset13/add.go @@ -20,7 +20,7 @@ func newAdd() ops.Operator { } // Init initializes the add operator. -func (a *Add) Init(_ []*onnx.AttributeProto) error { +func (a *Add) Init(*onnx.NodeProto) error { return nil } diff --git a/ops/opset13/and.go b/ops/opset13/and.go index 7afef33..68b2a22 100644 --- a/ops/opset13/and.go +++ b/ops/opset13/and.go @@ -20,7 +20,7 @@ func newAnd() ops.Operator { } // Init initializes the and operator. -func (a *And) Init(_ []*onnx.AttributeProto) error { +func (a *And) Init(*onnx.NodeProto) error { return nil } diff --git a/ops/opset13/asin.go b/ops/opset13/asin.go index 9a1cb2e..0dae65f 100644 --- a/ops/opset13/asin.go +++ b/ops/opset13/asin.go @@ -17,7 +17,7 @@ func newAsin() ops.Operator { } // Init initializes the asin operator. -func (s *Asin) Init(_ []*onnx.AttributeProto) error { +func (s *Asin) Init(*onnx.NodeProto) error { return nil } diff --git a/ops/opset13/asinh.go b/ops/opset13/asinh.go index 1cd9eec..8490711 100644 --- a/ops/opset13/asinh.go +++ b/ops/opset13/asinh.go @@ -17,7 +17,7 @@ func newAsinh() ops.Operator { } // Init initializes the asinh operator. -func (a *Asinh) Init(_ []*onnx.AttributeProto) error { +func (a *Asinh) Init(*onnx.NodeProto) error { return nil } diff --git a/ops/opset13/atan.go b/ops/opset13/atan.go index 2abdaaa..d373d65 100644 --- a/ops/opset13/atan.go +++ b/ops/opset13/atan.go @@ -17,7 +17,7 @@ func newAtan() ops.Operator { } // Init initializes the atan operator. -func (a *Atan) Init(_ []*onnx.AttributeProto) error { +func (a *Atan) Init(*onnx.NodeProto) error { return nil } diff --git a/ops/opset13/atanh.go b/ops/opset13/atanh.go index 6d10b95..f60b6d1 100644 --- a/ops/opset13/atanh.go +++ b/ops/opset13/atanh.go @@ -17,7 +17,7 @@ func newAtanh() ops.Operator { } // Init initializes the atanh operator. 
-func (a *Atanh) Init(_ []*onnx.AttributeProto) error { +func (a *Atanh) Init(*onnx.NodeProto) error { return nil } diff --git a/ops/opset13/cast.go b/ops/opset13/cast.go index ac79c6e..8a8a552 100644 --- a/ops/opset13/cast.go +++ b/ops/opset13/cast.go @@ -22,7 +22,9 @@ func newCast() ops.Operator { } // Init initializes the cast operator. -func (c *Cast) Init(attributes []*onnx.AttributeProto) error { +func (c *Cast) Init(n *onnx.NodeProto) error { + attributes := n.GetAttribute() + if len(attributes) != 1 { return ops.ErrInvalidAttributeCount(1, len(attributes), c) } diff --git a/ops/opset13/cast_test.go b/ops/opset13/cast_test.go index 8c32b14..74d4648 100644 --- a/ops/opset13/cast_test.go +++ b/ops/opset13/cast_test.go @@ -12,7 +12,7 @@ import ( func TestCastInit(t *testing.T) { c := &Cast{} - err := c.Init([]*onnx.AttributeProto{{Name: "to", I: 1}}) + err := c.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "to", I: 1}}}) assert.Nil(t, err) assert.Equal(t, int32(1), c.to) } @@ -63,7 +63,7 @@ func TestCast(t *testing.T) { } for _, test := range tests { - _ = test.cast.Init([]*onnx.AttributeProto{{Name: "to", I: test.to}}) + _ = test.cast.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "to", I: test.to}}}) inputs := []tensor.Tensor{ops.TensorWithBackingFixture(test.backing, test.shape...)} res, err := test.cast.Apply(inputs) diff --git a/ops/opset13/concat.go b/ops/opset13/concat.go index 154839d..a7a24a0 100644 --- a/ops/opset13/concat.go +++ b/ops/opset13/concat.go @@ -23,7 +23,9 @@ func newConcat() ops.Operator { } // Init initializes the concat operator. -func (c *Concat) Init(attributes []*onnx.AttributeProto) error { +func (c *Concat) Init(n *onnx.NodeProto) error { + attributes := n.GetAttribute() + if len(attributes) != 1 { return ops.ErrInvalidAttributeCount(1, len(attributes), c) } diff --git a/ops/opset13/concat_test.go b/ops/opset13/concat_test.go index 3f83843..4256220 100644 --- a/ops/opset13/concat_test.go +++ b/ops/opset13/concat_test.go @@ -11,7 +11,7 @@ import ( func TestConcatInit(t *testing.T) { concat := &Concat{} - err := concat.Init([]*onnx.AttributeProto{{Name: "axis", I: 3}}) + err := concat.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "axis", I: 3}}}) assert.Nil(t, err) assert.Equal(t, 3, concat.axis) @@ -19,7 +19,7 @@ func TestConcatInit(t *testing.T) { func TestConcatInitFail(t *testing.T) { concat := &Concat{} - err := concat.Init([]*onnx.AttributeProto{}) + err := concat.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{}}) expected := ops.ErrInvalidAttributeCount(1, 0, concat) assert.Equal(t, expected, err) diff --git a/ops/opset13/constant.go b/ops/opset13/constant.go index 758dd21..d0c1261 100644 --- a/ops/opset13/constant.go +++ b/ops/opset13/constant.go @@ -18,7 +18,8 @@ func newConstant() ops.Operator { // Init initializes the constant operator. It supports all constant types except // `sparse_value`, `value_string`, and `value_strings`. -func (c *Constant) Init(attributes []*onnx.AttributeProto) error { +func (c *Constant) Init(n *onnx.NodeProto) error { + attributes := n.GetAttribute() if len(attributes) != 1 { return ops.ErrInvalidAttributeCount(1, len(attributes), c) } diff --git a/ops/opset13/constant_of_shape.go b/ops/opset13/constant_of_shape.go index a33d864..9511108 100644 --- a/ops/opset13/constant_of_shape.go +++ b/ops/opset13/constant_of_shape.go @@ -24,7 +24,9 @@ func newConstantOfShape() ops.Operator { } // Init initializes the constant of shape operator. 
-func (c *ConstantOfShape) Init(attributes []*onnx.AttributeProto) error { +func (c *ConstantOfShape) Init(n *onnx.NodeProto) error { + attributes := n.GetAttribute() + if len(attributes) > 1 { return ops.ErrInvalidAttributeCount(1, len(attributes), c) } diff --git a/ops/opset13/constant_of_shape_test.go b/ops/opset13/constant_of_shape_test.go index f41030d..68066e7 100644 --- a/ops/opset13/constant_of_shape_test.go +++ b/ops/opset13/constant_of_shape_test.go @@ -87,11 +87,11 @@ func TestConstantOfShape(t *testing.T) { tp := TensorProtoFromNumber(test.input) assert.NotNil(t, tp) - attr := []*onnx.AttributeProto{{Name: "value", T: tp}} + node := &onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "value", T: tp}}} // Create operator op := ConstantOfShape{} - err := op.Init(attr) + err := op.Init(node) assert.NoError(t, err) assert.Equal(t, test.input, op.value.Data()) @@ -110,7 +110,7 @@ func TestConstantOfShapeEmptyInit(t *testing.T) { op := &ConstantOfShape{} // No init value given - err := op.Init([]*onnx.AttributeProto{}) + err := op.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{}}) assert.NoError(t, err) assert.Equal(t, float32(0.0), op.value.Data()) @@ -130,10 +130,10 @@ func TestIncorrectInput(t *testing.T) { Dims: []int64{3}, Int32Data: []int32{1, 2, 3}, } - attr := []*onnx.AttributeProto{{Name: "value", T: tp}} + node := &onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "value", T: tp}}} op := &ConstantOfShape{} - err := op.Init(attr) + err := op.Init(node) assert.NotNil(t, err) assert.Equal( t, @@ -144,7 +144,7 @@ func TestIncorrectInput(t *testing.T) { func TestNegativeShapeNotAllowed(t *testing.T) { op := &ConstantOfShape{} - _ = op.Init([]*onnx.AttributeProto{}) + _ = op.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{}}) shape := []int64{1, -1} @@ -160,7 +160,7 @@ func TestNegativeShapeNotAllowed(t *testing.T) { func TestEmptyTensorNotAllowed(t *testing.T) { op := &ConstantOfShape{} - _ = op.Init([]*onnx.AttributeProto{}) + _ = op.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{}}) shape := []int64{0} @@ -176,7 +176,7 @@ func TestEmptyTensorNotAllowed(t *testing.T) { func TestScalarShapeInput(t *testing.T) { op := &ConstantOfShape{} - _ = op.Init([]*onnx.AttributeProto{}) + _ = op.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{}}) shape := []int64{6} input := tensor.New(tensor.WithBacking(shape)) diff --git a/ops/opset13/constant_test.go b/ops/opset13/constant_test.go index c546627..ffebccf 100644 --- a/ops/opset13/constant_test.go +++ b/ops/opset13/constant_test.go @@ -60,7 +60,7 @@ func TestConstantInit(t *testing.T) { for _, test := range tests { constant := &Constant{} - err := constant.Init(test.initAttr) + err := constant.Init(&onnx.NodeProto{Attribute: test.initAttr}) assert.Equal(t, test.err, err) @@ -104,7 +104,7 @@ func TestConstant(t *testing.T) { } for _, test := range tests { - _ = test.constant.Init(test.initAttr) + _ = test.constant.Init(&onnx.NodeProto{Attribute: test.initAttr}) res, err := test.constant.Apply([]tensor.Tensor{}) assert.Nil(t, err) @@ -114,7 +114,7 @@ func TestConstant(t *testing.T) { func TestConstantSingleIntShapeTensor(t *testing.T) { constant := &Constant{} - err := constant.Init([]*onnx.AttributeProto{{Name: "value_ints", Ints: []int64{2}}}) + err := constant.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "value_ints", Ints: []int64{2}}}}) assert.Nil(t, err) assert.False(t, constant.value.IsScalar()) diff --git a/ops/opset13/conv.go b/ops/opset13/conv.go index 01a82d1..801a5e9 100644 
--- a/ops/opset13/conv.go +++ b/ops/opset13/conv.go @@ -46,10 +46,10 @@ func newConv() ops.Operator { } // Init initializes the conv operator. -func (c *Conv) Init(attributes []*onnx.AttributeProto) error { +func (c *Conv) Init(n *onnx.NodeProto) error { var err error - for _, attr := range attributes { + for _, attr := range n.GetAttribute() { switch attr.GetName() { case "auto_pad": c.autoPad = AutoPadSetting(attr.GetS()) diff --git a/ops/opset13/conv_test.go b/ops/opset13/conv_test.go index 14d975e..8da4b87 100644 --- a/ops/opset13/conv_test.go +++ b/ops/opset13/conv_test.go @@ -11,7 +11,7 @@ import ( func TestConvInit(t *testing.T) { c := &Conv{} - err := c.Init(Conv2DOnnxAttributeProtoFixture()) + err := c.Init(Conv2DOnnxNodeProtoFixture()) assert.Nil(t, err) @@ -26,7 +26,7 @@ func TestConvInit(t *testing.T) { func TestConvInitUnsupported(t *testing.T) { c := &Conv{} - err := c.Init(ConvUnsupportedOnnxAttributeProtoFixture()) + err := c.Init(ConvUnsupportedOnnxNodeProtoFixture()) assert.Equal( t, @@ -671,18 +671,22 @@ func TestAddBias(t *testing.T) { } } -func Conv2DOnnxAttributeProtoFixture() []*onnx.AttributeProto { - return []*onnx.AttributeProto{ - {Name: "auto_pad", S: []byte("VALID")}, - {Name: "dilations", Ints: []int64{1, 1}}, - {Name: "kernel_shape", Ints: []int64{2, 2}}, - {Name: "pads", Ints: []int64{1, 2}}, - {Name: "strides", Ints: []int64{1, 1}}, +func Conv2DOnnxNodeProtoFixture() *onnx.NodeProto { + return &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "auto_pad", S: []byte("VALID")}, + {Name: "dilations", Ints: []int64{1, 1}}, + {Name: "kernel_shape", Ints: []int64{2, 2}}, + {Name: "pads", Ints: []int64{1, 2}}, + {Name: "strides", Ints: []int64{1, 1}}, + }, } } -func ConvUnsupportedOnnxAttributeProtoFixture() []*onnx.AttributeProto { - return []*onnx.AttributeProto{ - {Name: "group", I: 2}, +func ConvUnsupportedOnnxNodeProtoFixture() *onnx.NodeProto { + return &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "group", I: 2}, + }, } } diff --git a/ops/opset13/cos.go b/ops/opset13/cos.go index b71110e..ad01f82 100644 --- a/ops/opset13/cos.go +++ b/ops/opset13/cos.go @@ -17,7 +17,7 @@ func newCos() ops.Operator { } // Init initializes the cos operator. -func (c *Cos) Init(_ []*onnx.AttributeProto) error { +func (c *Cos) Init(*onnx.NodeProto) error { return nil } diff --git a/ops/opset13/cosh.go b/ops/opset13/cosh.go index afc974a..cddb129 100644 --- a/ops/opset13/cosh.go +++ b/ops/opset13/cosh.go @@ -17,7 +17,7 @@ func newCosh() ops.Operator { } // Init initializes the cosh operator. -func (c *Cosh) Init(_ []*onnx.AttributeProto) error { +func (c *Cosh) Init(*onnx.NodeProto) error { return nil } diff --git a/ops/opset13/div.go b/ops/opset13/div.go index fe92c7d..e918e7f 100644 --- a/ops/opset13/div.go +++ b/ops/opset13/div.go @@ -20,7 +20,7 @@ func newDiv() ops.Operator { } // Init initializes the div operator. -func (d *Div) Init(_ []*onnx.AttributeProto) error { +func (d *Div) Init(*onnx.NodeProto) error { return nil } diff --git a/ops/opset13/gather.go b/ops/opset13/gather.go index 5668779..e6e7f3f 100644 --- a/ops/opset13/gather.go +++ b/ops/opset13/gather.go @@ -24,7 +24,9 @@ func newGather() ops.Operator { } // Init initializes the gather operator. 
-func (g *Gather) Init(attributes []*onnx.AttributeProto) error { +func (g *Gather) Init(n *onnx.NodeProto) error { + attributes := n.GetAttribute() + if len(attributes) == 1 { attr := attributes[0] diff --git a/ops/opset13/gather_test.go b/ops/opset13/gather_test.go index fb82347..dbbc8c3 100644 --- a/ops/opset13/gather_test.go +++ b/ops/opset13/gather_test.go @@ -9,8 +9,10 @@ import ( "gorgonia.org/tensor" ) -func makeAxisProto(n int) []*onnx.AttributeProto { - return []*onnx.AttributeProto{{Name: "axis", I: int64(n)}} +func makeAxisProto(n int) *onnx.NodeProto { + return &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{{Name: "axis", I: int64(n)}}, + } } func TestGatherInit(t *testing.T) { @@ -23,20 +25,20 @@ func TestGatherInit(t *testing.T) { func TestGatherInitDefault(t *testing.T) { op := Gather{} - err := op.Init([]*onnx.AttributeProto{}) + err := op.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{}}) assert.NoError(t, err) assert.Equal(t, op.axis, 0) } func TestGatherInitTooManyAttrs(t *testing.T) { op := Gather{} - err := op.Init([]*onnx.AttributeProto{{Name: "axis"}, {Name: "default"}}) + err := op.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "axis"}, {Name: "default"}}}) assert.EqualError(t, err, "gather operator attribute error: invalid count 2 expected 1") } func TestGatherInitInvalidAttrName(t *testing.T) { op := Gather{} - err := op.Init([]*onnx.AttributeProto{{Name: "axes"}}) // should be axis + err := op.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "axes"}}}) // should be axis assert.EqualError(t, err, "gather operator attribute error: invalid attribute axes") } @@ -201,7 +203,7 @@ func TestGather(t *testing.T) { func TestCombinedWithOtherOp(t *testing.T) { concat := &Concat{} - err := concat.Init([]*onnx.AttributeProto{{Name: "axis", I: 0}}) + err := concat.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "axis", I: 0}}}) assert.NoError(t, err) data0 := tensor.New(tensor.WithBacking([]int64{1}), tensor.WithShape(1)) diff --git a/ops/opset13/gemm.go b/ops/opset13/gemm.go index 9268043..2db2a44 100644 --- a/ops/opset13/gemm.go +++ b/ops/opset13/gemm.go @@ -30,8 +30,8 @@ func newGemm() ops.Operator { } // Init initializes the Gemm operator based on the ModelProto attributes. 
-func (g *Gemm) Init(attributes []*onnx.AttributeProto) error { - for _, attr := range attributes { +func (g *Gemm) Init(n *onnx.NodeProto) error { + for _, attr := range n.GetAttribute() { switch attr.GetName() { case "alpha": g.alpha = attr.GetF() diff --git a/ops/opset13/gemm_test.go b/ops/opset13/gemm_test.go index e0067cd..37255d4 100644 --- a/ops/opset13/gemm_test.go +++ b/ops/opset13/gemm_test.go @@ -11,7 +11,7 @@ import ( func TestGemmInit(t *testing.T) { gemm := Gemm{} - err := gemm.Init(GemmOnnxAttributeProtoFixture()) + err := gemm.Init(GemmOnnxNodeProtoFixture()) assert.Nil(t, err) assert.Equal(t, float32(10.0), gemm.alpha) @@ -22,7 +22,7 @@ func TestGemmInit(t *testing.T) { func TestGemmInitFail(t *testing.T) { gemm := &Gemm{} - err := gemm.Init([]*onnx.AttributeProto{{Name: "unknownAttribute"}}) + err := gemm.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "unknownAttribute"}}}) expected := ops.ErrInvalidAttribute("unknownAttribute", gemm) assert.Equal(t, expected, err) @@ -172,11 +172,13 @@ func TestInputValidationGemm(t *testing.T) { } } -func GemmOnnxAttributeProtoFixture() []*onnx.AttributeProto { - return []*onnx.AttributeProto{ - {Name: "alpha", F: 10.0}, - {Name: "beta", F: 0.98}, - {Name: "transA", I: 1}, - {Name: "transB", I: 1}, +func GemmOnnxNodeProtoFixture() *onnx.NodeProto { + return &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "alpha", F: 10.0}, + {Name: "beta", F: 0.98}, + {Name: "transA", I: 1}, + {Name: "transB", I: 1}, + }, } } diff --git a/ops/opset13/gru.go b/ops/opset13/gru.go index 80c8802..8003993 100644 --- a/ops/opset13/gru.go +++ b/ops/opset13/gru.go @@ -30,9 +30,11 @@ func newGRU() ops.Operator { // Init initializes the gru operator. Currently, our GRU operator does not support all // attributes as specified by the ONNX operator. The basic functionality is working and // the other attributes can be added later on. 
-func (g *GRU) Init(attributes []*onnx.AttributeProto) error { +func (g *GRU) Init(n *onnx.NodeProto) error { + attributes := n.GetAttribute() for _, attr := range attributes { switch attr.GetName() { + // nolint as these attributes are operator specific case "hidden_size": g.hiddenSize = int(attr.GetI()) case "linear_before_reset": @@ -321,7 +323,7 @@ func (g *GRU) extractWeights(W tensor.Tensor) ([]tensor.Tensor, error) { dirSlice := ops.NewSlicer(0) weights := make([]tensor.Tensor, nWeightMatrices) - for i := 0; i < 3; i++ { + for i := 0; i < nWeightMatrices; i++ { slice := ops.NewSlicer(i*g.hiddenSize, (i+1)*g.hiddenSize) w, err := W.Slice(dirSlice, slice, nil) diff --git a/ops/opset13/gru_test.go b/ops/opset13/gru_test.go index 39d419f..05e8fcc 100644 --- a/ops/opset13/gru_test.go +++ b/ops/opset13/gru_test.go @@ -11,7 +11,7 @@ import ( func TestGruInit(t *testing.T) { gru := &GRU{} - err := gru.Init(GRUOnnxAttributeProtoFixture()) + err := gru.Init(GRUOnnxNodeProtoFixture()) assert.Nil(t, err) assert.Equal(t, true, gru.linearBeforeReset) @@ -51,7 +51,7 @@ func TestGruInitUnkownAttr(t *testing.T) { } for _, test := range tests { - err := gru.Init(test.attr) + err := gru.Init(&onnx.NodeProto{Attribute: test.attr}) assert.Equal(t, test.err, err) } } @@ -268,9 +268,11 @@ func gruInputNoBNoH() []tensor.Tensor { return inputs } -func GRUOnnxAttributeProtoFixture() []*onnx.AttributeProto { - return []*onnx.AttributeProto{ - {Name: "linear_before_reset", I: 1}, - {Name: "hidden_size", I: 5}, +func GRUOnnxNodeProtoFixture() *onnx.NodeProto { + return &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "linear_before_reset", I: 1}, + {Name: "hidden_size", I: 5}, + }, } } diff --git a/ops/opset13/lstm.go b/ops/opset13/lstm.go new file mode 100644 index 0000000..ecabc7a --- /dev/null +++ b/ops/opset13/lstm.go @@ -0,0 +1,461 @@ +package opset13 + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +const ( + MinLSTMInputs = 3 + MaxLSTMInputs = 8 +) + +// These activations are supported in the LSTM calculation. +var LSTMActivations = map[string]ops.Activation{ + "tanh": ops.Tanh, + "sigmoid": ops.Sigmoid, + "relu": ops.ReLU, +} + +// LSTM represents the ONNX lstm operator. +type LSTM struct { + activationAlpha []float32 + activationBeta []float32 + activations []string + direction RNNDirection + hiddenSize int + inputForget bool + + outputs []string +} + +// newLSTM creates a new lstm operator. +func newLSTM() ops.Operator { + return &LSTM{ + activations: []string{"sigmoid", "tanh", "tanh"}, + direction: Forward, + inputForget: false, + outputs: []string{"Y", "Y_h", "Y_c"}, + } +} + +// Init initializes the lstm operator. 
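+// The clip attribute and directions other than 'forward' are not supported yet;
+// Init returns an unsupported-attribute error when they are set.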
+func (l *LSTM) Init(n *onnx.NodeProto) error { + for _, attr := range n.GetAttribute() { + switch attr.GetName() { + case "activation_alpha": + l.activationAlpha = attr.GetFloats() + case "activation_beta": + l.activationBeta = attr.GetFloats() + case "activations": + activations := []string{} + for _, activation := range attr.GetStrings() { + activations = append(activations, string(activation)) + } + + l.activations = activations + case "clip": + return ops.ErrUnsupportedAttribute(attr.GetName(), l) + case "direction": + l.direction = RNNDirection(attr.GetS()) + if l.direction != Forward { + return ops.ErrUnsupportedAttribute(attr.GetName(), l) + } + // nolint as these attributes are operator specific + case "hidden_size": + l.hiddenSize = int(attr.GetI()) + case "input_forget": + l.inputForget = attr.GetI() == 1 + default: + return ops.ErrInvalidAttribute(attr.GetName(), l) + } + } + + l.outputs = n.GetOutput() + + return nil +} + +// Apply applies the lstm operator. +func (l *LSTM) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + if inputs[4] != nil { + return nil, ops.ErrUnsupportedInput("sequence lens", l) + } + + X := inputs[0] + + Wi, Wo, Wf, Wc, err := l.getWeights(inputs[1]) + if err != nil { + return nil, err + } + + Ri, Ro, Rf, Rc, err := l.getWeights(inputs[2]) + if err != nil { + return nil, err + } + + B := inputs[3] + if B == nil { + nBiasMatrices := 8 + B = l.getZeroTensor(1, nBiasMatrices*l.hiddenSize) + } + + Wbi, Wbo, Wbf, Wbc, Rbi, Rbo, Rbf, Rbc, err := l.getBiases(B) + if err != nil { + return nil, err + } + + seqLength := X.Shape()[0] + batchSize := X.Shape()[1] + + Ht := inputs[5] + if Ht == nil { + Ht = l.getZeroTensor(1, batchSize, l.hiddenSize) + } + + Ct := inputs[6] + if Ct == nil { + Ct = l.getZeroTensor(1, batchSize, l.hiddenSize) + } + + var Pi, Po, Pf tensor.Tensor + + P := inputs[7] + if P != nil { + Pi, Po, Pf, err = l.getPeepholes(P) + if err != nil { + return nil, err + } + } + + // Reshape the hidden and cell tensor without the bidirectional dimension, as + // we do not support bidirectional yet. This is the dimension at + // index 0. + if err = Ht.Reshape(Ht.Shape().Clone()[1:]...); err != nil { + return nil, err + } + + if err = Ct.Reshape(Ct.Shape().Clone()[1:]...); err != nil { + return nil, err + } + + fActivation := LSTMActivations[l.activations[0]] + if fActivation == nil { + return nil, ops.ErrUnsupportedAttribute("activations", l) + } + + gActivation := LSTMActivations[l.activations[1]] + if gActivation == nil { + return nil, ops.ErrUnsupportedAttribute("activations", l) + } + + hActivation := LSTMActivations[l.activations[2]] + if hActivation == nil { + return nil, ops.ErrUnsupportedAttribute("activations", l) + } + + outputs := []tensor.Tensor{} + + // Loop over all timesteps of the input, applying the LSTM calculation to every + // timesteps while updating the hidden tensor. 
+ for t := 0; t < seqLength; t++ { + Xt, err := X.Slice(ops.NewSlicer(t, t+1), nil, nil) + if err != nil { + return nil, err + } + + it, err := l.gateCalculation(Xt, Wi, Wbi, Ht, Ri, Rbi, Pi, Ct, fActivation) + if err != nil { + return nil, err + } + + ft, err := l.gateCalculation(Xt, Wf, Wbf, Ht, Rf, Rbf, Pf, Ct, fActivation) + if err != nil { + return nil, err + } + + ct, err := l.gateCalculation(Xt, Wc, Wbc, Ht, Rc, Rbc, nil, nil, gActivation) + if err != nil { + return nil, err + } + + Ct, err = l.cellCalculation(ft, it, ct, Ct) + if err != nil { + return nil, err + } + + ot, err := l.gateCalculation(Xt, Wo, Wbo, Ht, Ro, Rbo, Po, Ct, fActivation) + if err != nil { + return nil, err + } + + Ht, err = l.hiddenCalculation(ot, Ct, hActivation) + if err != nil { + return nil, err + } + + outputs = append(outputs, Ht) + } + + Y := outputs[0] + if len(outputs) > 1 { + Y, err = tensor.Concat(0, Y, outputs[1:]...) + if err != nil { + return nil, err + } + } + + Yh, ok := Ht.Clone().(tensor.Tensor) + if !ok { + return nil, ops.ErrTypeAssert("tensor.Tensor", Ht.Clone()) + } + + Yc, ok := Ct.Clone().(tensor.Tensor) + if !ok { + return nil, ops.ErrTypeAssert("tensor.Tensor", Ct.Clone()) + } + + // Reshape the outputs so it adds the num_directions as specified by onnx. + // The output shape as specified by ONNX is: + // (sequence_length, num_directions, batch_size, hidden_size) + // 'num_directions' is only '2' if the LSTMDirection is 'bidirectional'. + // We do not support this, so for this implementation it should always be '1'. + // Here, we reshape our output to include this 'num_directions' dimension. + if err = Y.Reshape(seqLength, 1, batchSize, l.hiddenSize); err != nil { + return nil, err + } + + if err = Yh.Reshape(1, batchSize, l.hiddenSize); err != nil { + return nil, err + } + + if err = Yc.Reshape(1, batchSize, l.hiddenSize); err != nil { + return nil, err + } + + outputMap := map[string]tensor.Tensor{ + "Y": Y, "Y_h": Yh, "Y_c": Yc, + } + + result := []tensor.Tensor{} + for _, outputName := range l.outputs { + result = append(result, outputMap[outputName]) + } + + return result, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (l *LSTM) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(l, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (l *LSTM) GetMinInputs() int { + return MinLSTMInputs +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (l *LSTM) GetMaxInputs() int { + return MaxLSTMInputs +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (l *LSTM) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{ + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Int32}, + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + } +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (l *LSTM) String() string { + return "lstm operator" +} + +// gateCalculation performs a standard gate calculation for an LSTM gate defined as: +// +// o = f(Xt*(W^T) + Wb + H*(R^T) + Rb + P (.) 
C)
+//
+// Where:
+//   - 'f()' is an activation function
+//   - 'Xt' is the input tensor
+//   - 'W' is the input weight
+//   - 'Wb' is the input bias
+//   - 'H' is the hidden tensor
+//   - 'R' is the hidden weight tensor
+//   - 'Rb' is the hidden bias
+//   - 'P' are peephole weights (optional, can be nil)
+//   - 'C' is the cell state
+//   - '(.)' is element-wise multiplication
+//
+// 'o' is the result tensor that is returned.
+// This calculation can be used for the forget gate, input gate, cell gate
+// and output gate calculations.
+func (l *LSTM) gateCalculation(
+	Xt, W, Wb, H, R, Rb, P, C tensor.Tensor, activation ops.Activation,
+) (tensor.Tensor, error) {
+	gemm := &Gemm{transA: false, transB: true, alpha: 1.0, beta: 1.0}
+
+	inputCalc, err := gemm.Apply([]tensor.Tensor{Xt, W, Wb})
+	if err != nil {
+		return nil, err
+	}
+
+	hiddenCalc, err := gemm.Apply([]tensor.Tensor{H, R, Rb})
+	if err != nil {
+		return nil, err
+	}
+
+	output, err := tensor.Add(inputCalc[0], hiddenCalc[0])
+	if err != nil {
+		return nil, err
+	}
+
+	if P != nil {
+		C, broadcastedP, err := ops.UnidirectionalBroadcast(C, P)
+		if err != nil {
+			return nil, err
+		}
+
+		peepholeActivation, err := tensor.Mul(broadcastedP, C)
+		if err != nil {
+			return nil, err
+		}
+
+		output, err = tensor.Add(output, peepholeActivation)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	return activation(output)
+}
+
+// cellCalculation performs the calculation of the LSTM cell update defined by:
+//
+//	Ct = ft (.) Ct-1 + it (.) ct
+//
+// Where 'ft' is the forget gate activation at time t, (.) denotes element-wise
+// multiplication, 'Ct-1' denotes the cell state at time t-1, 'it' denotes the input
+// gate activation at time t and 'ct' denotes the cell state activation at time t (which
+// is not the same as Ct or Ct-1).
+func (l *LSTM) cellCalculation(ft, it, ct, Ct tensor.Tensor) (tensor.Tensor, error) {
+	cellForget, err := tensor.Mul(ft, Ct)
+	if err != nil {
+		return nil, err
+	}
+
+	cellInput, err := tensor.Mul(it, ct)
+	if err != nil {
+		return nil, err
+	}
+
+	return tensor.Add(cellForget, cellInput)
+}
+
+// hiddenCalculation performs the calculation of the new LSTM hidden state defined by:
+//
+//	Ht = ot (.) h(Ct)
+//
+// Where Ht is the new hidden state at time t, 'ot' is the output at time t, (.) denotes
+// element-wise multiplication, 'h()' denotes an activation function and 'Ct' denotes the
+// cell state at time t.
+func (l *LSTM) hiddenCalculation(ot, Ct tensor.Tensor, activation ops.Activation) (tensor.Tensor, error) {
+	cellActivated, err := activation(Ct)
+	if err != nil {
+		return nil, err
+	}
+
+	return tensor.Mul(ot, cellActivated)
+}
+
+// getWeights splits tensor W into 4 weight matrices.
+func (l *LSTM) getWeights(W tensor.Tensor) (Wi, Wo, Wf, Wc tensor.Tensor, err error) {
+	nWeightMatrices := 4
+	nWeightDimensions := 3
+
+	weights, err := l.extractMatrices(W, nWeightMatrices, nWeightDimensions)
+	if err != nil {
+		return nil, nil, nil, nil, err
+	}
+
+	return weights[0], weights[1], weights[2], weights[3], nil
+}
+
+// getBiases splits tensor B into 8 bias matrices.
+func (l *LSTM) getBiases(B tensor.Tensor) (Wbi, Wbo, Wbf, Wbc, Rbi, Rbo, Rbf, Rbc tensor.Tensor, err error) {
+	nBiasMatrices := 8
+	nBiasDimensions := 2
+
+	b, err := l.extractMatrices(B, nBiasMatrices, nBiasDimensions)
+	if err != nil {
+		return nil, nil, nil, nil, nil, nil, nil, nil, err
+	}
+
+	return b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7], nil
+}
+
+// getPeepholes splits tensor P into 3 peephole matrices.
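For reference, the gate, cell and hidden formulas documented above can also be followed in a minimal, dependency-free sketch. This is an illustration only, not the gonnx implementation: it handles a single batch element with plain float64 slices, uses the default sigmoid/tanh activations, skips peepholes, and the helper names (lstmStep, dot) are invented for this example.

package main

import (
	"fmt"
	"math"
)

func sigmoid(x float64) float64 { return 1.0 / (1.0 + math.Exp(-x)) }

// dot computes the inner product of one weight row with a vector.
func dot(w, x []float64) float64 {
	s := 0.0
	for i := range w {
		s += w[i] * x[i]
	}

	return s
}

// lstmStep computes a single hidden unit for one timestep, mirroring the
// gate/cell/hidden comments above (peepholes omitted):
//
//	it = sigmoid(Wi.x + Ri.h + Wbi + Rbi)
//	ft = sigmoid(Wf.x + Rf.h + Wbf + Rbf)
//	ct = tanh(Wc.x + Rc.h + Wbc + Rbc)
//	Ct = ft*Cprev + it*ct
//	ot = sigmoid(Wo.x + Ro.h + Wbo + Rbo)
//	Ht = ot*tanh(Ct)
//
// b holds the biases in the ONNX order [Wbi, Wbo, Wbf, Wbc, Rbi, Rbo, Rbf, Rbc].
func lstmStep(x, h []float64, cPrev float64, Wi, Wo, Wf, Wc, Ri, Ro, Rf, Rc []float64, b [8]float64) (hNew, cNew float64) {
	it := sigmoid(dot(Wi, x) + dot(Ri, h) + b[0] + b[4])
	ft := sigmoid(dot(Wf, x) + dot(Rf, h) + b[2] + b[6])
	ct := math.Tanh(dot(Wc, x) + dot(Rc, h) + b[3] + b[7])
	cNew = ft*cPrev + it*ct
	ot := sigmoid(dot(Wo, x) + dot(Ro, h) + b[1] + b[5])
	hNew = ot * math.Tanh(cNew)

	return hNew, cNew
}

func main() {
	x := []float64{0.5, -0.25, 1.0} // one input vector (input_size = 3)
	h := []float64{0.0}             // previous hidden state (hidden_size = 1)
	w := []float64{0.1, 0.2, -0.1}  // shared toy weights for every gate
	r := []float64{0.3}
	var b [8]float64 // zero biases, as the operator assumes when B is nil

	hNew, cNew := lstmStep(x, h, 0.0, w, w, w, w, r, r, r, r, b)
	fmt.Printf("Ht=%.4f Ct=%.4f\n", hNew, cNew)
}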
+func (l *LSTM) getPeepholes(P tensor.Tensor) (Pi, Po, Pf tensor.Tensor, err error) { + nPeepholeMatrices := 3 + nPeepholeDimensions := 2 + + p, err := l.extractMatrices(P, nPeepholeMatrices, nPeepholeDimensions) + if err != nil { + return nil, nil, nil, err + } + + return p[0], p[1], p[2], nil +} + +// extractMatrices extracts 4 tensors from tensor M. +// M contains all matrices concatenated on top of each other in the following order: +// +// forward weights: [Wi, Wo, Wf, Wc] +// recurrent weights: [Ri, Ro, Rf, Rc] +// biases: [Wbi, Wbo, Wbf, Wbc, Rbi, Rbo, Rbf, Rbc] +// +// M is assumed to have a shape of (num_directions, nMatrices * hidden_size, ...) and we extract the +// by slicing over the 'nMatrices * hidden_size' dimension. +func (l *LSTM) extractMatrices(M tensor.Tensor, nMatrices, nDimensions int) ([]tensor.Tensor, error) { + dirSlice := ops.NewSlicer(0) + matrices := make([]tensor.Tensor, nMatrices) + + for i := 0; i < nMatrices; i++ { + hiddenSlice := ops.NewSlicer(i*l.hiddenSize, (i+1)*l.hiddenSize) + + allSlices := make([]tensor.Slice, nDimensions) + allSlices[0] = dirSlice + allSlices[1] = hiddenSlice + + for i := 2; i < nDimensions; i++ { + allSlices[i] = nil + } + + m, err := M.Slice(allSlices...) + if err != nil { + return nil, err + } + + matrices[i] = m + } + + return matrices, nil +} + +// getZeroTensor returns a tensor filled with zeros with the given shape. +func (l *LSTM) getZeroTensor(shape ...int) tensor.Tensor { + return tensor.New( + tensor.WithShape(shape...), + tensor.WithBacking(ops.Zeros(ops.NElements(shape...))), + ) +} diff --git a/ops/opset13/matmul.go b/ops/opset13/matmul.go index 2b5969a..1212233 100644 --- a/ops/opset13/matmul.go +++ b/ops/opset13/matmul.go @@ -20,7 +20,7 @@ func newMatMul() ops.Operator { } // Init initializes the matmul operator. -func (m *MatMul) Init(_ []*onnx.AttributeProto) error { +func (m *MatMul) Init(*onnx.NodeProto) error { return nil } diff --git a/ops/opset13/mul.go b/ops/opset13/mul.go index f06c672..3d4db10 100644 --- a/ops/opset13/mul.go +++ b/ops/opset13/mul.go @@ -20,7 +20,7 @@ func newMul() ops.Operator { } // Init initializes the mul operator. -func (m *Mul) Init(_ []*onnx.AttributeProto) error { +func (m *Mul) Init(*onnx.NodeProto) error { return nil } diff --git a/ops/opset13/not.go b/ops/opset13/not.go index 8d2ff6e..ba69c56 100644 --- a/ops/opset13/not.go +++ b/ops/opset13/not.go @@ -15,7 +15,7 @@ func newNot() ops.Operator { } // Init initializes the not operator. -func (n *Not) Init(_ []*onnx.AttributeProto) error { +func (n *Not) Init(*onnx.NodeProto) error { return nil } diff --git a/ops/opset13/opset13.go b/ops/opset13/opset13.go index 554ced8..b44addd 100644 --- a/ops/opset13/opset13.go +++ b/ops/opset13/opset13.go @@ -25,6 +25,7 @@ var operators13 = map[string]func() ops.Operator{ "Gather": newGather, "Gemm": newGemm, "GRU": newGRU, + "LSTM": newLSTM, "MatMul": newMatMul, "Mul": newMul, "Not": newNot, diff --git a/ops/opset13/opset13_test.go b/ops/opset13/opset13_test.go index 95cff24..493a440 100644 --- a/ops/opset13/opset13_test.go +++ b/ops/opset13/opset13_test.go @@ -113,6 +113,11 @@ func TestGetOperator(t *testing.T) { newGRU(), nil, }, + { + "LSTM", + newLSTM(), + nil, + }, { "MatMul", newMatMul(), diff --git a/ops/opset13/or.go b/ops/opset13/or.go index 797965f..f660891 100644 --- a/ops/opset13/or.go +++ b/ops/opset13/or.go @@ -20,7 +20,7 @@ func newOr() ops.Operator { } // Init initializes the or operator. 
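The extractMatrices comment in the lstm.go diff above describes how ONNX packs the per-gate weight, recurrent and bias matrices on top of each other along the 'nMatrices * hidden_size' dimension. A small sketch of the same index arithmetic on a flat slice, assuming a row-major backing array; the helper name splitPacked is invented for this illustration and is not part of gonnx.

package main

import "fmt"

// splitPacked splits a packed buffer of nMatrices blocks, each hiddenSize*cols
// values long, into separate matrices. This mirrors how the extractMatrices
// helper slices the packed W, R, B and P tensors per gate.
func splitPacked(packed []float64, nMatrices, hiddenSize, cols int) [][]float64 {
	out := make([][]float64, nMatrices)
	for i := 0; i < nMatrices; i++ {
		start := i * hiddenSize * cols
		end := (i + 1) * hiddenSize * cols
		out[i] = packed[start:end]
	}

	return out
}

func main() {
	hiddenSize, cols := 2, 3

	// A packed W of shape (1, 4*hidden_size, input_size) flattened row-major;
	// the blocks are [Wi, Wo, Wf, Wc].
	packed := make([]float64, 4*hiddenSize*cols)
	for i := range packed {
		packed[i] = float64(i)
	}

	blocks := splitPacked(packed, 4, hiddenSize, cols)
	fmt.Println("Wi:", blocks[0])
	fmt.Println("Wo:", blocks[1])
	fmt.Println("Wf:", blocks[2])
	fmt.Println("Wc:", blocks[3])
}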
-func (o *Or) Init(_ []*onnx.AttributeProto) error { +func (o *Or) Init(*onnx.NodeProto) error { return nil } diff --git a/ops/opset13/prelu.go b/ops/opset13/prelu.go index 95fa94b..bfdc5d2 100644 --- a/ops/opset13/prelu.go +++ b/ops/opset13/prelu.go @@ -20,7 +20,7 @@ func newPRelu() ops.Operator { } // Init initializes the prelu operator. -func (op *PRelu) Init(_ []*onnx.AttributeProto) error { +func (op *PRelu) Init(*onnx.NodeProto) error { return nil } diff --git a/ops/opset13/relu.go b/ops/opset13/relu.go index 8c8ed8d..702e1dd 100644 --- a/ops/opset13/relu.go +++ b/ops/opset13/relu.go @@ -15,7 +15,7 @@ func newRelu() ops.Operator { } // Init initializes the relu operator. -func (r *Relu) Init(_ []*onnx.AttributeProto) error { +func (r *Relu) Init(*onnx.NodeProto) error { return nil } diff --git a/ops/opset13/reshape.go b/ops/opset13/reshape.go index 140d24d..a2a8f59 100644 --- a/ops/opset13/reshape.go +++ b/ops/opset13/reshape.go @@ -20,7 +20,7 @@ func newReshape() ops.Operator { } // Init initializes the reshape operator. -func (r *Reshape) Init(_ []*onnx.AttributeProto) error { +func (r *Reshape) Init(*onnx.NodeProto) error { return nil } diff --git a/ops/opset13/rnn.go b/ops/opset13/rnn.go index d1dfe12..0bbbbf9 100644 --- a/ops/opset13/rnn.go +++ b/ops/opset13/rnn.go @@ -47,8 +47,8 @@ func newRNN() ops.Operator { } // Init initializes the rnn operator. -func (r *RNN) Init(attributes []*onnx.AttributeProto) error { - for _, attr := range attributes { +func (r *RNN) Init(n *onnx.NodeProto) error { + for _, attr := range n.GetAttribute() { switch attr.GetName() { case "activation_alpha": r.activationAlpha = attr.GetFloats() @@ -68,6 +68,7 @@ func (r *RNN) Init(attributes []*onnx.AttributeProto) error { if r.direction != Forward { return ops.ErrUnsupportedAttribute(attr.GetName(), r) } + // nolint as these attributes are operator specific case "hidden_size": r.hiddenSize = int(attr.GetI()) default: diff --git a/ops/opset13/rnn_test.go b/ops/opset13/rnn_test.go index 5d56ed0..a333240 100644 --- a/ops/opset13/rnn_test.go +++ b/ops/opset13/rnn_test.go @@ -12,7 +12,7 @@ import ( func TestRNNInit(t *testing.T) { rnn := &RNN{} - err := rnn.Init(RNNOnnxAttributeProtoFixture()) + err := rnn.Init(RNNOnnxNodeProtoFixture()) assert.Nil(t, err) assert.Equal(t, []float32{1.0}, rnn.activationAlpha) @@ -24,13 +24,13 @@ func TestRNNInit(t *testing.T) { func TestRNNInitUnsupportedAttr(t *testing.T) { rnn := RNN{} - err := rnn.Init([]*onnx.AttributeProto{{Name: "clip"}}) + err := rnn.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "clip"}}}) assert.Equal(t, err, ops.ErrUnsupportedAttribute("clip", &rnn)) } func TestRNNInitUnknownAttr(t *testing.T) { rnn := RNN{} - err := rnn.Init([]*onnx.AttributeProto{{Name: "unknown"}}) + err := rnn.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "unknown"}}}) assert.Equal(t, err, ops.ErrInvalidAttribute("unknown", &rnn)) } @@ -321,12 +321,14 @@ func rnnInputNoBNoH() []tensor.Tensor { } } -func RNNOnnxAttributeProtoFixture() []*onnx.AttributeProto { - return []*onnx.AttributeProto{ - {Name: "activation_alpha", Floats: []float32{1.0}}, - {Name: "activation_beta", Floats: []float32{2.0}}, - {Name: "activations", Strings: [][]byte{[]byte("sigmoid")}}, - {Name: "direction", S: []byte("forward")}, - {Name: "hidden_size", I: 5}, +func RNNOnnxNodeProtoFixture() *onnx.NodeProto { + return &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "activation_alpha", Floats: []float32{1.0}}, + {Name: "activation_beta", Floats: []float32{2.0}}, 
+ {Name: "activations", Strings: [][]byte{[]byte("sigmoid")}}, + {Name: "direction", S: []byte("forward")}, + {Name: "hidden_size", I: 5}, + }, } } diff --git a/ops/opset13/scaler.go b/ops/opset13/scaler.go index b53ec35..c5eb53b 100644 --- a/ops/opset13/scaler.go +++ b/ops/opset13/scaler.go @@ -24,7 +24,8 @@ func newScaler() ops.Operator { } // Init initializes the scaler operator. -func (s *Scaler) Init(attributes []*onnx.AttributeProto) error { +func (s *Scaler) Init(n *onnx.NodeProto) error { + attributes := n.GetAttribute() if len(attributes) != ScalerExpectedAttributes { return ops.ErrInvalidAttributeCount(ScalerExpectedAttributes, len(attributes), s) } diff --git a/ops/opset13/scaler_test.go b/ops/opset13/scaler_test.go index b49d5aa..fd2bc1d 100644 --- a/ops/opset13/scaler_test.go +++ b/ops/opset13/scaler_test.go @@ -11,7 +11,7 @@ import ( func TestScalerInit(t *testing.T) { scaler := &Scaler{} - err := scaler.Init(ScalerOnnxAttributeProtoFixture()) + err := scaler.Init(ScalerOnnxNodeProtoFixture()) assert.Nil(t, err) assert.Equal(t, []float32{1.5, 2.5, 3.5}, scaler.offset.Data()) @@ -20,7 +20,7 @@ func TestScalerInit(t *testing.T) { func TestScalerInitFailWrongAttribute(t *testing.T) { scaler := &Scaler{} - err := scaler.Init([]*onnx.AttributeProto{{Name: "unknownAttribute"}, {Name: "Another"}}) + err := scaler.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "unknownAttribute"}, {Name: "Another"}}}) expected := ops.ErrInvalidAttribute("unknownAttribute", scaler) assert.Equal(t, expected, err) @@ -28,7 +28,7 @@ func TestScalerInitFailWrongAttribute(t *testing.T) { func TestScalerInitFailAttrCount(t *testing.T) { scaler := &Scaler{} - err := scaler.Init([]*onnx.AttributeProto{}) + err := scaler.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{}}) expected := ops.ErrInvalidAttributeCount(2, 0, scaler) assert.Equal(t, expected, err) @@ -128,9 +128,11 @@ func TestInputValidationScaler(t *testing.T) { } } -func ScalerOnnxAttributeProtoFixture() []*onnx.AttributeProto { - return []*onnx.AttributeProto{ - {Name: "offset", Floats: []float32{1.5, 2.5, 3.5}}, - {Name: "scale", Floats: []float32{0.5, 1.0, 2.0}}, +func ScalerOnnxNodeProtoFixture() *onnx.NodeProto { + return &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "offset", Floats: []float32{1.5, 2.5, 3.5}}, + {Name: "scale", Floats: []float32{0.5, 1.0, 2.0}}, + }, } } diff --git a/ops/opset13/shape.go b/ops/opset13/shape.go index 82a434c..bb99709 100644 --- a/ops/opset13/shape.go +++ b/ops/opset13/shape.go @@ -20,7 +20,7 @@ func newShape() ops.Operator { } // Init initializes the shape operator. -func (s *Shape) Init(_ []*onnx.AttributeProto) error { +func (s *Shape) Init(*onnx.NodeProto) error { return nil } diff --git a/ops/opset13/sigmoid.go b/ops/opset13/sigmoid.go index 6171e99..b8bc077 100644 --- a/ops/opset13/sigmoid.go +++ b/ops/opset13/sigmoid.go @@ -15,7 +15,7 @@ func newSigmoid() ops.Operator { } // Init initializes the sigmoid operator. -func (s *Sigmoid) Init(_ []*onnx.AttributeProto) error { +func (s *Sigmoid) Init(*onnx.NodeProto) error { return nil } diff --git a/ops/opset13/sin.go b/ops/opset13/sin.go index 50b371d..ff61a71 100644 --- a/ops/opset13/sin.go +++ b/ops/opset13/sin.go @@ -17,7 +17,7 @@ func newSin() ops.Operator { } // Init initializes the sin operator. 
-func (s *Sin) Init(_ []*onnx.AttributeProto) error { +func (s *Sin) Init(*onnx.NodeProto) error { return nil } diff --git a/ops/opset13/sinh.go b/ops/opset13/sinh.go index f5e7b24..19d81e7 100644 --- a/ops/opset13/sinh.go +++ b/ops/opset13/sinh.go @@ -17,7 +17,7 @@ func newSinh() ops.Operator { } // Init initializes the sinh operator. -func (s *Sinh) Init(_ []*onnx.AttributeProto) error { +func (s *Sinh) Init(*onnx.NodeProto) error { return nil } diff --git a/ops/opset13/slice.go b/ops/opset13/slice.go index 5aeeb97..d7589f5 100644 --- a/ops/opset13/slice.go +++ b/ops/opset13/slice.go @@ -20,7 +20,7 @@ func newSlice() ops.Operator { } // Init initializes the slice operator. -func (s *Slice) Init(_ []*onnx.AttributeProto) error { +func (s *Slice) Init(*onnx.NodeProto) error { return nil } diff --git a/ops/opset13/softmax.go b/ops/opset13/softmax.go index 4348381..8a2c0c0 100644 --- a/ops/opset13/softmax.go +++ b/ops/opset13/softmax.go @@ -20,8 +20,10 @@ func newSoftmax() ops.Operator { } // Init initializes the softmax operator. -func (s *Softmax) Init(attributes []*onnx.AttributeProto) error { +func (s *Softmax) Init(n *onnx.NodeProto) error { + attributes := n.GetAttribute() nAttributes := len(attributes) + if nAttributes > 1 { return ops.ErrInvalidAttributeCount(1, nAttributes, s) } diff --git a/ops/opset13/squeeze.go b/ops/opset13/squeeze.go index 47b5e1f..d4c9055 100644 --- a/ops/opset13/squeeze.go +++ b/ops/opset13/squeeze.go @@ -20,7 +20,7 @@ func newSqueeze() ops.Operator { } // Init initializes the squeeze operator. -func (s *Squeeze) Init(_ []*onnx.AttributeProto) error { +func (s *Squeeze) Init(*onnx.NodeProto) error { return nil } diff --git a/ops/opset13/sub.go b/ops/opset13/sub.go index c8b7bef..9c59508 100644 --- a/ops/opset13/sub.go +++ b/ops/opset13/sub.go @@ -20,7 +20,7 @@ func newSub() ops.Operator { } // Init initializes the sub operator. -func (s *Sub) Init(_ []*onnx.AttributeProto) error { +func (s *Sub) Init(*onnx.NodeProto) error { return nil } diff --git a/ops/opset13/tan.go b/ops/opset13/tan.go index 4a03d09..a7b4a3b 100644 --- a/ops/opset13/tan.go +++ b/ops/opset13/tan.go @@ -17,7 +17,7 @@ func newTan() ops.Operator { } // Init initializes the tan operator. -func (t *Tan) Init(_ []*onnx.AttributeProto) error { +func (t *Tan) Init(*onnx.NodeProto) error { return nil } diff --git a/ops/opset13/tanh.go b/ops/opset13/tanh.go index 4939941..b435fb9 100644 --- a/ops/opset13/tanh.go +++ b/ops/opset13/tanh.go @@ -15,7 +15,7 @@ func newTanh() ops.Operator { } // Init initializes the sigmoid operator. -func (t *Tanh) Init(_ []*onnx.AttributeProto) error { +func (t *Tanh) Init(*onnx.NodeProto) error { return nil } diff --git a/ops/opset13/transpose.go b/ops/opset13/transpose.go index 08090f4..c89aa67 100644 --- a/ops/opset13/transpose.go +++ b/ops/opset13/transpose.go @@ -22,7 +22,9 @@ func newTranspose() ops.Operator { } // Init initializes the transpose operator. 
-func (t *Transpose) Init(attributes []*onnx.AttributeProto) error { +func (t *Transpose) Init(n *onnx.NodeProto) error { + attributes := n.GetAttribute() + if len(attributes) != 1 { return ops.ErrInvalidAttributeCount(1, len(attributes), t) } diff --git a/ops/opset13/transpose_test.go b/ops/opset13/transpose_test.go index 005b2b6..c766024 100644 --- a/ops/opset13/transpose_test.go +++ b/ops/opset13/transpose_test.go @@ -11,7 +11,7 @@ import ( func TestTransposeInit(t *testing.T) { trans := &Transpose{} - err := trans.Init(TransposeOnnxAttributeProtoFixture()) + err := trans.Init(TransposeOnnxNodeProtoFixture()) assert.Nil(t, err) assert.Equal(t, []int{1, 0}, trans.perm) @@ -19,7 +19,7 @@ func TestTransposeInit(t *testing.T) { func TestTransposeInitFailWrongAttribute(t *testing.T) { trans := &Transpose{} - err := trans.Init([]*onnx.AttributeProto{{Name: "unknownAttribute"}}) + err := trans.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "unknownAttribute"}}}) expected := ops.ErrInvalidAttribute("unknownAttribute", trans) assert.Equal(t, expected, err) @@ -27,7 +27,7 @@ func TestTransposeInitFailWrongAttribute(t *testing.T) { func TestTransposeInitFailAttrCount(t *testing.T) { trans := &Transpose{} - err := trans.Init([]*onnx.AttributeProto{}) + err := trans.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{}}) expected := ops.ErrInvalidAttributeCount(1, 0, trans) assert.Equal(t, expected, err) @@ -104,8 +104,10 @@ func TestInputValidationTranspose(t *testing.T) { } } -func TransposeOnnxAttributeProtoFixture() []*onnx.AttributeProto { - return []*onnx.AttributeProto{ - {Name: "perm", Ints: []int64{1, 0}}, +func TransposeOnnxNodeProtoFixture() *onnx.NodeProto { + return &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "perm", Ints: []int64{1, 0}}, + }, } } diff --git a/ops/opset13/unsqueeze.go b/ops/opset13/unsqueeze.go index 6547541..b7d4530 100644 --- a/ops/opset13/unsqueeze.go +++ b/ops/opset13/unsqueeze.go @@ -22,7 +22,7 @@ func newUnsqueeze() ops.Operator { } // Init initializes the unsqueeze operator. -func (u *Unsqueeze) Init(_ []*onnx.AttributeProto) error { +func (u *Unsqueeze) Init(*onnx.NodeProto) error { return nil } diff --git a/ops/opset13/xor.go b/ops/opset13/xor.go index 01f6b8c..f668a69 100644 --- a/ops/opset13/xor.go +++ b/ops/opset13/xor.go @@ -20,7 +20,7 @@ func newXor() ops.Operator { } // Init initializes the xor operator. -func (x *Xor) Init(_ []*onnx.AttributeProto) error { +func (x *Xor) Init(*onnx.NodeProto) error { return nil } diff --git a/ops/validate_inputs_test.go b/ops/validate_inputs_test.go index ea91fbc..1edde37 100644 --- a/ops/validate_inputs_test.go +++ b/ops/validate_inputs_test.go @@ -155,7 +155,7 @@ type MockOp struct { inputTypeConstraints [][]tensor.Dtype } -func (m *MockOp) Init(_ []*onnx.AttributeProto) error { +func (m *MockOp) Init(*onnx.NodeProto) error { return nil } diff --git a/ops_test.go b/ops_test.go index 001d518..2410ad3 100644 --- a/ops_test.go +++ b/ops_test.go @@ -26,10 +26,8 @@ import ( var ignoredTests = []string{ "test_add_uint8", // Opset14 "test_div_uint8", // Opset14 - "test_gru_defaults", // Opset14 "test_gru_batchwise", // Opset14 - "test_gru_seq_length", // Opset14 - "test_gru_with_initial_bias", // Opset14 + "test_lstm_batchwise", // Opset14 "test_mul_uint8", // Opset14 "test_sub_uint8", // Opset14 "test_shape_clip_end", // Opset15 @@ -48,6 +46,7 @@ var ignoredTests = []string{ "test_gemm_alpha", // For gemm in opset 11. "test_gemm_default_no_bias", // For gemm in opset 11. 
"test_gemm_default_scalar_bias", // For gemm in opset 11. + "test_lstm_with_peepholes", // Sequence lens attribute is not supported yet. "test_relu_expanded_ver18", // CastLike operator not implemented yet. "test_softmax_default_axis_expanded_ver18", // ReduceMax operator not implemented yet. "test_softmax_axis_1_expanded_ver18", // ReduceMax operator not implemented yet. @@ -350,6 +349,11 @@ var expectedTests = []string{ "test_gemm_default_zero_bias", "test_gemm_beta", "test_gemm_transposeB", + "test_gru_defaults", + "test_gru_seq_length", + "test_gru_with_initial_bias", + "test_lstm_defaults", + "test_lstm_with_initial_bias", "test_matmul_4d", "test_matmul_3d", "test_matmul_2d", From dd95f7857c59f69dc7abff3d6134fe273941a2cf Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Thu, 7 Dec 2023 12:17:28 +0100 Subject: [PATCH 06/14] Reusable attrs and tests for LSTM --- ops/activation.go | 8 + ops/opset13/gru.go | 53 ++++-- ops/opset13/gru_test.go | 66 ++++--- ops/opset13/lstm.go | 36 ++-- ops/opset13/lstm_test.go | 386 +++++++++++++++++++++++++++++++++++++++ ops/opset13/rnn.go | 43 ++--- ops/opset13/rnn_test.go | 14 +- ops/recurrent_utils.go | 22 +++ 8 files changed, 534 insertions(+), 94 deletions(-) create mode 100644 ops/opset13/lstm_test.go create mode 100644 ops/recurrent_utils.go diff --git a/ops/activation.go b/ops/activation.go index ebca570..e90192d 100644 --- a/ops/activation.go +++ b/ops/activation.go @@ -4,6 +4,14 @@ import ( "gorgonia.org/tensor" ) +// Activations maps strings to the activation function. This is +// used by operators like LSTM, GRU and RNN. +var Activations = map[string]Activation{ + "tanh": Tanh, + "sigmoid": Sigmoid, + "relu": ReLU, +} + // Activation is an activation function. type Activation func(n tensor.Tensor) (tensor.Tensor, error) diff --git a/ops/opset13/gru.go b/ops/opset13/gru.go index 8003993..7ecf35b 100644 --- a/ops/opset13/gru.go +++ b/ops/opset13/gru.go @@ -14,17 +14,21 @@ const ( // GRU represents the ONNX gru operator. It only supports a simple forward gru // operation with default activations. type GRU struct { - // Number of neurons in the hidden state. - hiddenSize int - - // When computing the output of the hidden gate, apply the linear - // transformation before multiplying by the output of the reset gate. + activationAlpha []float32 + activationBeta []float32 + activations []string + direction ops.SequenceProcessDirection + hiddenSize int linearBeforeReset bool } // newGRU creates a new gru operator. func newGRU() ops.Operator { - return &GRU{} + return &GRU{ + activations: []string{"sigmoid", "tanh"}, + direction: ops.Forward, + linearBeforeReset: false, + } } // Init initializes the gru operator. 
Currently, our GRU operator does not support all @@ -34,8 +38,25 @@ func (g *GRU) Init(n *onnx.NodeProto) error { attributes := n.GetAttribute() for _, attr := range attributes { switch attr.GetName() { - // nolint as these attributes are operator specific - case "hidden_size": + case ops.ActivationAlphaAttr: + g.activationAlpha = attr.GetFloats() + case ops.ActivationBetaAttr: + g.activationBeta = attr.GetFloats() + case ops.ActivationsAttr: + activations := []string{} + for _, activation := range attr.GetStrings() { + activations = append(activations, string(activation)) + } + + g.activations = activations + case ops.ClipAttr: + return ops.ErrUnsupportedAttribute(attr.GetName(), g) + case ops.DirectionAttr: + g.direction = ops.SequenceProcessDirection(attr.GetS()) + if g.direction != ops.Forward { + return ops.ErrUnsupportedAttribute(attr.GetName(), g) + } + case ops.HiddenSizeAttr: g.hiddenSize = int(attr.GetI()) case "linear_before_reset": g.linearBeforeReset = ops.Int64ToBool(attr.GetI()) @@ -101,6 +122,16 @@ func (g *GRU) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return nil, err } + fActivation := ops.Activations[g.activations[0]] + if fActivation == nil { + return nil, ops.ErrUnsupportedAttribute("activations", g) + } + + gActivation := ops.Activations[g.activations[1]] + if gActivation == nil { + return nil, ops.ErrUnsupportedAttribute("activations", g) + } + outputs := []tensor.Tensor{} for i := 0; i < seqLength; i++ { @@ -109,17 +140,17 @@ func (g *GRU) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return nil, err } - zt, err := g.gateCalculation(Xt, prevH, Wz, Rz, Wbz, Rbz, ops.Sigmoid) + zt, err := g.gateCalculation(Xt, prevH, Wz, Rz, Wbz, Rbz, fActivation) if err != nil { return nil, err } - rt, err := g.gateCalculation(Xt, prevH, Wr, Rr, Wbr, Rbr, ops.Sigmoid) + rt, err := g.gateCalculation(Xt, prevH, Wr, Rr, Wbr, Rbr, fActivation) if err != nil { return nil, err } - ht, err := g.htCalculation(Xt, prevH, rt, Wh, Rh, Wbh, Rbh, ops.Tanh) + ht, err := g.htCalculation(Xt, prevH, rt, Wh, Rh, Wbh, Rbh, gActivation) if err != nil { return nil, err } diff --git a/ops/opset13/gru_test.go b/ops/opset13/gru_test.go index 05e8fcc..44140f9 100644 --- a/ops/opset13/gru_test.go +++ b/ops/opset13/gru_test.go @@ -14,8 +14,12 @@ func TestGruInit(t *testing.T) { err := gru.Init(GRUOnnxNodeProtoFixture()) assert.Nil(t, err) - assert.Equal(t, true, gru.linearBeforeReset) + assert.Equal(t, []float32{1.0}, gru.activationAlpha) + assert.Equal(t, []float32{2.0}, gru.activationBeta) + assert.Equal(t, []string{"sigmoid", "tanh"}, gru.activations) + assert.Equal(t, gru.direction, ops.Forward) assert.Equal(t, 5, gru.hiddenSize) + assert.Equal(t, true, gru.linearBeforeReset) } func TestGruInitUnkownAttr(t *testing.T) { @@ -24,25 +28,9 @@ func TestGruInitUnkownAttr(t *testing.T) { attr []*onnx.AttributeProto err error }{ - { - []*onnx.AttributeProto{{Name: "activation_alpha"}}, - ops.ErrInvalidAttribute("activation_alpha", &gru), - }, - { - []*onnx.AttributeProto{{Name: "activation_beta"}}, - ops.ErrInvalidAttribute("activation_beta", &gru), - }, - { - []*onnx.AttributeProto{{Name: "direction"}}, - ops.ErrInvalidAttribute("direction", &gru), - }, { []*onnx.AttributeProto{{Name: "clip"}}, - ops.ErrInvalidAttribute("clip", &gru), - }, - { - []*onnx.AttributeProto{{Name: "activation"}}, - ops.ErrInvalidAttribute("activation", &gru), + ops.ErrUnsupportedAttribute("clip", &gru), }, { []*onnx.AttributeProto{{Name: "unknown"}}, @@ -64,25 +52,53 @@ func TestGru(t *testing.T) { err error 
}{ { - &GRU{4, true}, + &GRU{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"sigmoid", "tanh"}, + direction: ops.Forward, + hiddenSize: 4, + linearBeforeReset: true, + }, gruInput0, []float32{6.6936556e-03, 8.3446503e-07, 0.0000000e+00, 0.0000000e+00}, nil, }, { - &GRU{4, false}, + &GRU{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"sigmoid", "tanh"}, + direction: ops.Forward, + hiddenSize: 4, + linearBeforeReset: false, + }, gruInput0, []float32{6.6936556e-03, 8.3446503e-07, 0.0000000e+00, 0.0000000e+00}, nil, }, { - &GRU{4, false}, + &GRU{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"sigmoid", "tanh"}, + direction: ops.Forward, + hiddenSize: 4, + linearBeforeReset: false, + }, gruInput1, []float32{0.44905475, 0.4406946, 0.43368173, 0.42782417}, nil, }, { - &GRU{4, false}, + &GRU{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"sigmoid", "tanh"}, + direction: ops.Forward, + hiddenSize: 4, + linearBeforeReset: false, + }, gruInputNoBNoH, []float32{0.24553154, 0.24553154, 0.24553154, 0.24553154}, nil, @@ -271,8 +287,12 @@ func gruInputNoBNoH() []tensor.Tensor { func GRUOnnxNodeProtoFixture() *onnx.NodeProto { return &onnx.NodeProto{ Attribute: []*onnx.AttributeProto{ - {Name: "linear_before_reset", I: 1}, + {Name: "activation_alpha", Floats: []float32{1.0}}, + {Name: "activation_beta", Floats: []float32{2.0}}, + {Name: "activations", Strings: [][]byte{[]byte("sigmoid"), []byte("tanh")}}, + {Name: "direction", S: []byte("forward")}, {Name: "hidden_size", I: 5}, + {Name: "linear_before_reset", I: 1}, }, } } diff --git a/ops/opset13/lstm.go b/ops/opset13/lstm.go index ecabc7a..7911a03 100644 --- a/ops/opset13/lstm.go +++ b/ops/opset13/lstm.go @@ -11,19 +11,12 @@ const ( MaxLSTMInputs = 8 ) -// These activations are supported in the LSTM calculation. -var LSTMActivations = map[string]ops.Activation{ - "tanh": ops.Tanh, - "sigmoid": ops.Sigmoid, - "relu": ops.ReLU, -} - // LSTM represents the ONNX lstm operator. 
type LSTM struct { activationAlpha []float32 activationBeta []float32 activations []string - direction RNNDirection + direction ops.SequenceProcessDirection hiddenSize int inputForget bool @@ -34,7 +27,7 @@ type LSTM struct { func newLSTM() ops.Operator { return &LSTM{ activations: []string{"sigmoid", "tanh", "tanh"}, - direction: Forward, + direction: ops.Forward, inputForget: false, outputs: []string{"Y", "Y_h", "Y_c"}, } @@ -44,26 +37,25 @@ func newLSTM() ops.Operator { func (l *LSTM) Init(n *onnx.NodeProto) error { for _, attr := range n.GetAttribute() { switch attr.GetName() { - case "activation_alpha": + case ops.ActivationAlphaAttr: l.activationAlpha = attr.GetFloats() - case "activation_beta": + case ops.ActivationBetaAttr: l.activationBeta = attr.GetFloats() - case "activations": + case ops.ActivationsAttr: activations := []string{} for _, activation := range attr.GetStrings() { activations = append(activations, string(activation)) } l.activations = activations - case "clip": + case ops.ClipAttr: return ops.ErrUnsupportedAttribute(attr.GetName(), l) - case "direction": - l.direction = RNNDirection(attr.GetS()) - if l.direction != Forward { + case ops.DirectionAttr: + l.direction = ops.SequenceProcessDirection(attr.GetS()) + if l.direction != ops.Forward { return ops.ErrUnsupportedAttribute(attr.GetName(), l) } - // nolint as these attributes are operator specific - case "hidden_size": + case ops.HiddenSizeAttr: l.hiddenSize = int(attr.GetI()) case "input_forget": l.inputForget = attr.GetI() == 1 @@ -80,7 +72,7 @@ func (l *LSTM) Init(n *onnx.NodeProto) error { // Apply applies the lstm operator. func (l *LSTM) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { if inputs[4] != nil { - return nil, ops.ErrUnsupportedInput("sequence lens", l) + return nil, ops.ErrUnsupportedInput("sequence_lens", l) } X := inputs[0] @@ -140,17 +132,17 @@ func (l *LSTM) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return nil, err } - fActivation := LSTMActivations[l.activations[0]] + fActivation := ops.Activations[l.activations[0]] if fActivation == nil { return nil, ops.ErrUnsupportedAttribute("activations", l) } - gActivation := LSTMActivations[l.activations[1]] + gActivation := ops.Activations[l.activations[1]] if gActivation == nil { return nil, ops.ErrUnsupportedAttribute("activations", l) } - hActivation := LSTMActivations[l.activations[2]] + hActivation := ops.Activations[l.activations[2]] if hActivation == nil { return nil, ops.ErrUnsupportedAttribute("activations", l) } diff --git a/ops/opset13/lstm_test.go b/ops/opset13/lstm_test.go new file mode 100644 index 0000000..83bfc86 --- /dev/null +++ b/ops/opset13/lstm_test.go @@ -0,0 +1,386 @@ +package opset13 + +import ( + "math/rand" + "testing" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestLSTMInit(t *testing.T) { + lstm := &LSTM{} + err := lstm.Init(LSTMOnnxNodeProtoFixture()) + + assert.Nil(t, err) + assert.Equal(t, []float32{1.0}, lstm.activationAlpha) + assert.Equal(t, []float32{2.0}, lstm.activationBeta) + assert.Equal(t, []string{"sigmoid", "tanh", "relu"}, lstm.activations) + assert.Equal(t, ops.Forward, lstm.direction) + assert.Equal(t, 5, lstm.hiddenSize) + assert.Equal(t, false, lstm.inputForget) + assert.Equal(t, []string{"Y", "Y_h"}, lstm.outputs) +} + +func TestLSTMInitUnkownAttr(t *testing.T) { + lstm := LSTM{} + tests := []struct { + attr []*onnx.AttributeProto + err error + }{ + { + 
[]*onnx.AttributeProto{{Name: "clip"}}, + ops.ErrUnsupportedAttribute("clip", &lstm), + }, + { + []*onnx.AttributeProto{{Name: "unknown"}}, + ops.ErrInvalidAttribute("unknown", &lstm), + }, + } + + for _, test := range tests { + err := lstm.Init(&onnx.NodeProto{Attribute: test.attr}) + assert.Equal(t, test.err, err) + } +} + +func TestLSTM(t *testing.T) { + tests := []struct { + lstm *LSTM + inputs ops.InputFixture + expected []float32 + err error + }{ + { + &LSTM{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"sigmoid", "tanh", "tanh"}, + direction: ops.Forward, + hiddenSize: 4, + outputs: []string{"Y", "Y_h", "Y_c"}, + }, + lstmInput0, + []float32{0.9159305, 0.9356764, 0.87070554, 0.84180677}, + nil, + }, + { + &LSTM{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"sigmoid", "tanh", "relu"}, + direction: ops.Forward, + hiddenSize: 4, + outputs: []string{"Y", "Y_h", "Y_c"}, + }, + lstmInput0, + []float32{1.7530097, 1.7829735, 1.6231446, 1.5197954}, + nil, + }, + { + &LSTM{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"sigmoid", "tanh", "relu"}, + direction: ops.Forward, + hiddenSize: 4, + outputs: []string{"Y", "Y_h", "Y_c"}, + }, + lstmInput1, + []float32{10.598255, 10.547241, 10.214846, 10.267471}, + nil, + }, + { + &LSTM{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"sigmoid", "tanh", "relu"}, + direction: ops.Forward, + hiddenSize: 4, + outputs: []string{"Y", "Y_h", "Y_c"}, + }, + lstmInputNoBNoH, + []float32{8.276371, 8.291079, 8.161418, 7.7900877}, + nil, + }, + { + &LSTM{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"sigmoid", "tanh", "tanh"}, + direction: ops.Forward, + hiddenSize: 4, + outputs: []string{"Y", "Y_h", "Y_c"}, + }, + lstmInputPeepholes, + []float32{0.99891853, 0.99994266, 0.9995524, 0.99171203}, + nil, + }, + } + + for _, test := range tests { + inputs := test.inputs() + res, err := test.lstm.Apply(inputs) + assert.Equal(t, test.err, err) + + if err == nil { + assert.Equal(t, test.expected, res[1].Data()) + } + } +} + +func TestInputValidationLSTM(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + expected []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + nil, + nil, + nil, + nil, + nil, + }, + nil, + }, + { + []tensor.Tensor{ops.TensorWithBackingFixture([]float32{1, 2}, 2)}, + nil, + ops.ErrInvalidOptionalInputCount(1, &LSTM{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(1, "int", 
&LSTM{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(0, "int", &LSTM{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(2, "int", &LSTM{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(3, "int", &LSTM{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(4, "float32", &LSTM{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(5, "int", &LSTM{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(6, "int", &LSTM{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(7, "int", &LSTM{}), + }, + } + + for _, test := range tests { + lstm := &LSTM{} + validated, err := lstm.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + if test.expected != nil { + assert.Equal(t, test.expected, validated) + } else { + assert.Equal(t, test.inputs, validated) + } + } + } +} + +func lstmInput0() []tensor.Tensor { + rand.Seed(10) + + return []tensor.Tensor{ + // Input X: (sequence_length, batch_size, input_size). + ops.RandomFloat32TensorFixture(2, 1, 3), + // Input W: (num_directions, 4 * hidden_size, input_size). + ops.RandomFloat32TensorFixture(1, 16, 3), + // Input R: (num_directions, 4 * hidden_size, hidden_size). + ops.RandomFloat32TensorFixture(1, 16, 4), + // Input B: (num_directions, 8 * hidden_size). + ops.RandomFloat32TensorFixture(1, 32), + // Input sequence_lens: not supported. + nil, + // Input initial_h: (num_directions, batch_size, hidden_size). + ops.TensorWithBackingFixture(ops.Zeros(4), 1, 1, 4), + // Input initial_c: (num_directions, batch_size, hidden_size). 
+ ops.TensorWithBackingFixture(ops.Zeros(4), 1, 1, 4), + // Input P: peephole weights. + nil, + } +} + +func lstmInput1() []tensor.Tensor { + rand.Seed(11) + + return []tensor.Tensor{ + // Input X: (sequence_length, batch_size, input_size). + ops.RandomFloat32TensorFixture(10, 1, 3), + // Input W: (num_directions, 4 * hidden_size, input_size). + ops.RandomFloat32TensorFixture(1, 16, 3), + // Input R: (num_directions, 4 * hidden_size, hidden_size). + ops.RandomFloat32TensorFixture(1, 16, 4), + // Input B: (num_directions, 8 * hidden_size). + ops.RandomFloat32TensorFixture(1, 32), + // Input sequence_lens: not supported. + nil, + // Input initial_h: (num_directions, batch_size, hidden_size). + ops.RandomFloat32TensorFixture(1, 1, 4), + // Input initial_c: (num_directions, batch_size, hidden_size). + ops.RandomFloat32TensorFixture(1, 1, 4), + // Input P: peephole weights. + nil, + } +} + +func lstmInputNoBNoH() []tensor.Tensor { + rand.Seed(12) + + return []tensor.Tensor{ + // Input X: (sequence_length, batch_size, input_size). + ops.RandomFloat32TensorFixture(10, 1, 3), + // Input W: (num_directions, 4 * hidden_size, input_size). + ops.RandomFloat32TensorFixture(1, 16, 3), + // Input R: (num_directions, 4 * hidden_size, hidden_size). + ops.RandomFloat32TensorFixture(1, 16, 4), + // Input B. + nil, + // Input sequence_lens: not supported. + nil, + // Input initial_h. + nil, + // Input initial_c. + nil, + // Input P: peephole weights. + nil, + } +} + +func lstmInputPeepholes() []tensor.Tensor { + rand.Seed(13) + + return []tensor.Tensor{ + // Input X: (sequence_length, batch_size, input_size). + ops.RandomFloat32TensorFixture(10, 1, 3), + // Input W: (num_directions, 4 * hidden_size, input_size). + ops.RandomFloat32TensorFixture(1, 16, 3), + // Input R: (num_directions, 4 * hidden_size, hidden_size). + ops.RandomFloat32TensorFixture(1, 16, 4), + // Input B. + nil, + // Input sequence_lens: not supported. + nil, + // Input initial_h. + nil, + // Input initial_c. + nil, + // Input P: (num_directions, 3 * hidden_size). + ops.RandomFloat32TensorFixture(1, 12), + } +} + +func LSTMOnnxNodeProtoFixture() *onnx.NodeProto { + return &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "activation_alpha", Floats: []float32{1.0}}, + {Name: "activation_beta", Floats: []float32{2.0}}, + {Name: "activations", Strings: [][]byte{[]byte("sigmoid"), []byte("tanh"), []byte("relu")}}, + {Name: "direction", S: []byte("forward")}, + {Name: "hidden_size", I: 5}, + {Name: "input_forget", I: 0}, + }, + Output: []string{"Y", "Y_h"}, + } +} diff --git a/ops/opset13/rnn.go b/ops/opset13/rnn.go index 0bbbbf9..5dc9566 100644 --- a/ops/opset13/rnn.go +++ b/ops/opset13/rnn.go @@ -11,30 +11,12 @@ const ( MaxRNNInputs = 6 ) -// RNNDirection is the direction of the RNN. RNNs process sequences. We can process -// those forward (from first to last), in reverse (from last to first) or -// bidirectional (which is both forward and reverse added together). -type RNNDirection string - -const ( - Forward RNNDirection = "forward" - Reverse RNNDirection = "reverse" - Bidirectional RNNDirection = "bidirectional" -) - -// These activations are supported in the RNN calculation. -var RNNActivations = map[string]ops.Activation{ - "tanh": ops.Tanh, - "sigmoid": ops.Sigmoid, - "relu": ops.ReLU, -} - // RNN represents the ONNX rnn operator. 
type RNN struct { activationAlpha []float32 activationBeta []float32 activations []string - direction RNNDirection + direction ops.SequenceProcessDirection hiddenSize int } @@ -42,7 +24,7 @@ type RNN struct { func newRNN() ops.Operator { return &RNN{ activations: []string{"tanh"}, - direction: Forward, + direction: ops.Forward, } } @@ -50,26 +32,25 @@ func newRNN() ops.Operator { func (r *RNN) Init(n *onnx.NodeProto) error { for _, attr := range n.GetAttribute() { switch attr.GetName() { - case "activation_alpha": + case ops.ActivationAlphaAttr: r.activationAlpha = attr.GetFloats() - case "activation_beta": + case ops.ActivationBetaAttr: r.activationBeta = attr.GetFloats() - case "activations": + case ops.ActivationsAttr: activations := []string{} for _, activation := range attr.GetStrings() { activations = append(activations, string(activation)) } r.activations = activations - case "clip": + case ops.ClipAttr: return ops.ErrUnsupportedAttribute(attr.GetName(), r) - case "direction": - r.direction = RNNDirection(attr.GetS()) - if r.direction != Forward { + case ops.DirectionAttr: + r.direction = ops.SequenceProcessDirection(attr.GetS()) + if r.direction != ops.Forward { return ops.ErrUnsupportedAttribute(attr.GetName(), r) } - // nolint as these attributes are operator specific - case "hidden_size": + case ops.HiddenSizeAttr: r.hiddenSize = int(attr.GetI()) default: return ops.ErrInvalidAttribute(attr.GetName(), r) @@ -122,7 +103,7 @@ func (r *RNN) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return nil, err } - activation := RNNActivations[r.activations[0]] + activation := ops.Activations[r.activations[0]] if activation == nil { return nil, ops.ErrUnsupportedAttribute("activations", r) } @@ -137,7 +118,7 @@ func (r *RNN) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return nil, err } - Ht, err = r.layerCalculation(Xt, Ht, Wi, Ri, Wbi, Rbi, RNNActivations[r.activations[0]]) + Ht, err = r.layerCalculation(Xt, Ht, Wi, Ri, Wbi, Rbi, activation) if err != nil { return nil, err } diff --git a/ops/opset13/rnn_test.go b/ops/opset13/rnn_test.go index a333240..a987ddd 100644 --- a/ops/opset13/rnn_test.go +++ b/ops/opset13/rnn_test.go @@ -18,7 +18,7 @@ func TestRNNInit(t *testing.T) { assert.Equal(t, []float32{1.0}, rnn.activationAlpha) assert.Equal(t, []float32{2.0}, rnn.activationBeta) assert.Equal(t, []string{"sigmoid"}, rnn.activations) - assert.Equal(t, RNNDirection("forward"), rnn.direction) + assert.Equal(t, ops.SequenceProcessDirection("forward"), rnn.direction) assert.Equal(t, 5, rnn.hiddenSize) } @@ -46,7 +46,7 @@ func TestRNN(t *testing.T) { activationAlpha: []float32{}, activationBeta: []float32{}, activations: []string{"tanh"}, - direction: Forward, + direction: ops.Forward, hiddenSize: 4, }, rnnInput0, @@ -58,7 +58,7 @@ func TestRNN(t *testing.T) { activationAlpha: []float32{}, activationBeta: []float32{}, activations: []string{"sigmoid"}, - direction: Forward, + direction: ops.Forward, hiddenSize: 4, }, rnnInput0, @@ -70,7 +70,7 @@ func TestRNN(t *testing.T) { activationAlpha: []float32{}, activationBeta: []float32{}, activations: []string{"relu"}, - direction: Forward, + direction: ops.Forward, hiddenSize: 4, }, rnnInput0, @@ -82,7 +82,7 @@ func TestRNN(t *testing.T) { activationAlpha: []float32{}, activationBeta: []float32{}, activations: []string{"tanh"}, - direction: Forward, + direction: ops.Forward, hiddenSize: 10, }, rnnInput1, @@ -94,7 +94,7 @@ func TestRNN(t *testing.T) { activationAlpha: []float32{}, activationBeta: []float32{}, activations: 
[]string{"tanh"}, - direction: Forward, + direction: ops.Forward, hiddenSize: 4, }, rnnInputNoB, @@ -107,7 +107,7 @@ func TestRNN(t *testing.T) { activationAlpha: []float32{}, activationBeta: []float32{}, activations: []string{"tanh"}, - direction: Forward, + direction: ops.Forward, hiddenSize: 4, }, rnnInputNoBNoH, diff --git a/ops/recurrent_utils.go b/ops/recurrent_utils.go new file mode 100644 index 0000000..1d7e4db --- /dev/null +++ b/ops/recurrent_utils.go @@ -0,0 +1,22 @@ +package ops + +// SequenceProcessDirection is the direction in which a sequential input is processed. +// We can process sequential inputs forward (from first to last), in reverse (from +// last to first) or bidirectional (which is both forward and reverse added together). +type SequenceProcessDirection string + +const ( + Forward SequenceProcessDirection = "forward" + Reverse SequenceProcessDirection = "reverse" + Bidirectional SequenceProcessDirection = "bidirectional" +) + +// These constants define attributes that are applicable to GRU, LSTM and RNN operators. +const ( + ActivationAlphaAttr = "activation_alpha" + ActivationBetaAttr = "activation_beta" + ActivationsAttr = "activations" + ClipAttr = "clip" + DirectionAttr = "direction" + HiddenSizeAttr = "hidden_size" +) From 07612a7363eb2582e1d7c825ad522a456291abbe Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Thu, 7 Dec 2023 13:24:42 +0100 Subject: [PATCH 07/14] Refactored recurrent operators to share code --- ops/opset13/gru.go | 140 ++++++----------------------------------- ops/opset13/lstm.go | 60 +++--------------- ops/opset13/rnn.go | 101 ++++++++++------------------- ops/recurrent_utils.go | 51 +++++++++++++++ 4 files changed, 112 insertions(+), 240 deletions(-) diff --git a/ops/opset13/gru.go b/ops/opset13/gru.go index 7ecf35b..52a4652 100644 --- a/ops/opset13/gru.go +++ b/ops/opset13/gru.go @@ -70,31 +70,28 @@ func (g *GRU) Init(n *onnx.NodeProto) error { // Apply applies the gru operator. 
func (g *GRU) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - X := inputs[0] - W := inputs[1] - R := inputs[2] - B := inputs[3] - if inputs[4] != nil { return nil, ops.ErrUnsupportedInput("sequence lens", g) } - initialH := inputs[5] + X := inputs[0] seqLength := X.Shape()[0] batchSize := X.Shape()[1] - Wz, Wr, Wh, err := g.getForwardWeights(W) + Wz, Wr, Wh, err := g.getWeights(inputs[1]) if err != nil { return nil, err } - Rz, Rr, Rh, err := g.getRecurrentWeights(R) + Rz, Rr, Rh, err := g.getWeights(inputs[2]) if err != nil { return nil, err } + B := inputs[3] if B == nil { - B = g.initialB() + nBiasMatrices := 6 + B = ops.ZeroTensor(1, nBiasMatrices*g.hiddenSize) } Wbz, Wbr, Wbh, Rbz, Rbr, Rbh, err := g.getBiases(B) @@ -102,15 +99,9 @@ func (g *GRU) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return nil, err } - var prevH tensor.Tensor - if initialH == nil { - prevH = g.initialH(batchSize) - } else { - var ok bool - prevH, ok = initialH.Clone().(tensor.Tensor) - if !ok { - return nil, ops.ErrTypeAssert("tensor.Tensor", initialH.Clone()) - } + prevH := inputs[5] + if prevH == nil { + prevH = ops.ZeroTensor(1, batchSize, g.hiddenSize) } // Extract the shape of the hidden dimensions without the bidirectional dimension, as @@ -292,7 +283,7 @@ func (g *GRU) htCalculation( } func (g *GRU) hiddenCalculation(zt, ht, prevH tensor.Tensor) (tensor.Tensor, error) { - temp1, err := tensor.Sub(onesTensor(zt), zt) + temp1, err := tensor.Sub(ops.OnesTensor(zt), zt) if err != nil { return nil, err } @@ -310,119 +301,28 @@ func (g *GRU) hiddenCalculation(zt, ht, prevH tensor.Tensor) (tensor.Tensor, err return tensor.Add(temp2, temp3) } -// getForwardWeights returns the weights for the gate. -func (g *GRU) getForwardWeights(W tensor.Tensor) (Wz, Wr, Wh tensor.Tensor, err error) { - n, err := g.extractWeights(W) - if err != nil { - return nil, nil, nil, err - } - - return n[0], n[1], n[2], nil -} +// getWeights splits tensor W into 3 weight matrices. +func (g *GRU) getWeights(W tensor.Tensor) (Wz, Wr, Wh tensor.Tensor, err error) { + nWeightMatrices := 3 + nWeightDimensions := 3 -// getRecurrentWeights returns recurrent weights. -func (g *GRU) getRecurrentWeights(R tensor.Tensor) (Rz, Rr, Rh tensor.Tensor, err error) { - recurrentWeights, err := g.extractWeights(R) + weights, err := ops.ExtractMatrices(W, nWeightMatrices, nWeightDimensions, g.hiddenSize) if err != nil { return nil, nil, nil, err } - return recurrentWeights[0], recurrentWeights[1], recurrentWeights[2], nil + return weights[0], weights[1], weights[2], nil } // getBiases returns the biases from the Bias node as specified by the ONNX standard. func (g *GRU) getBiases(B tensor.Tensor) (Wbz, Wbr, Wbh, Rbz, Rbr, Rbh tensor.Tensor, err error) { - biases, err := g.extractBiases(B) + nBiasMatrices := 6 + nBiasDimensions := 2 + + biases, err := ops.ExtractMatrices(B, nBiasMatrices, nBiasDimensions, g.hiddenSize) if err != nil { return nil, nil, nil, nil, nil, nil, err } return biases[0], biases[1], biases[2], biases[3], biases[4], biases[5], nil } - -// extractWeights extracts 3 weight tensors from node W. -// W contains all 3 weight tensors concatenated on top of each other in the following order: -// -// forward weights: [Wz, Wr, Wh] -// recurrent weights: [Rz, Rr, Rh] -// -// W will have a shape of (num_directions, 3 * hidden_size, ...) and we extract the -// by slicing over the '3 * hidden_size' dimension. 
-func (g *GRU) extractWeights(W tensor.Tensor) ([]tensor.Tensor, error) { - const nWeightMatrices = 3 - - dirSlice := ops.NewSlicer(0) - weights := make([]tensor.Tensor, nWeightMatrices) - - for i := 0; i < nWeightMatrices; i++ { - slice := ops.NewSlicer(i*g.hiddenSize, (i+1)*g.hiddenSize) - - w, err := W.Slice(dirSlice, slice, nil) - if err != nil { - return nil, err - } - - weights[i] = w - } - - return weights, nil -} - -// extractBiases extracts the 6 bias tensors from tensor B. -// B contains all 6 bias tensors concatenated on top of each other in the following order: -// -// [Wbz, Wbr, Wbh, Rbz, Rbr, Rbh] -// -// B has a shape of (num_directions, 6 * hidden_size) and every individual bias tensor should have -// shape (hidden_size). We extract the biases by slicing over the '6 * hidden_size' dimension. -func (g *GRU) extractBiases(B tensor.Tensor) ([]tensor.Tensor, error) { - const nWeightMatrices = 6 - - dirSlice := ops.NewSlicer(0) - biases := make([]tensor.Tensor, nWeightMatrices) - - for i := 0; i < nWeightMatrices; i++ { - slice := ops.NewSlicer(i*g.hiddenSize, (i+1)*g.hiddenSize) - - w, err := B.Slice(dirSlice, slice) - if err != nil { - return nil, err - } - - biases[i] = w - } - - return biases, nil -} - -// initialB returns the initialB tensor. If the biases are not specified by the inputs -// of the gru operator this tensor can be used as the biases tensor. By default biases -// are all 0. -func (g *GRU) initialB() tensor.Tensor { - const nWeightMatrices = 6 - - return tensor.New( - tensor.WithShape(1, nWeightMatrices*g.hiddenSize), - tensor.WithBacking(ops.Zeros(nWeightMatrices*g.hiddenSize)), - ) -} - -// initialH can be used for initialH when it is not specified by the inputs of the operator. -// if it is not specified by the inputs assumed to be 0. It has shape -// (num_directions, batch_size, hidden_size). -func (g *GRU) initialH(batchSize int) tensor.Tensor { - hiddenFloats := ops.Zeros(batchSize * g.hiddenSize) - - return tensor.New( - tensor.WithShape(1, batchSize, g.hiddenSize), - tensor.WithBacking(hiddenFloats), - ) -} - -// onesTensor returns a new tensor with the same shape as the given tensor intialized with all ones. 
-func onesTensor(t tensor.Tensor) tensor.Tensor { - return tensor.New( - tensor.WithShape(t.Shape()...), - tensor.WithBacking(ops.Ones(ops.NElements(t.Shape()...))), - ) -} diff --git a/ops/opset13/lstm.go b/ops/opset13/lstm.go index 7911a03..3b88d55 100644 --- a/ops/opset13/lstm.go +++ b/ops/opset13/lstm.go @@ -76,6 +76,8 @@ func (l *LSTM) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { } X := inputs[0] + seqLength := X.Shape()[0] + batchSize := X.Shape()[1] Wi, Wo, Wf, Wc, err := l.getWeights(inputs[1]) if err != nil { @@ -90,7 +92,7 @@ func (l *LSTM) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { B := inputs[3] if B == nil { nBiasMatrices := 8 - B = l.getZeroTensor(1, nBiasMatrices*l.hiddenSize) + B = ops.ZeroTensor(1, nBiasMatrices*l.hiddenSize) } Wbi, Wbo, Wbf, Wbc, Rbi, Rbo, Rbf, Rbc, err := l.getBiases(B) @@ -98,17 +100,14 @@ func (l *LSTM) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return nil, err } - seqLength := X.Shape()[0] - batchSize := X.Shape()[1] - Ht := inputs[5] if Ht == nil { - Ht = l.getZeroTensor(1, batchSize, l.hiddenSize) + Ht = ops.ZeroTensor(1, batchSize, l.hiddenSize) } Ct := inputs[6] if Ct == nil { - Ct = l.getZeroTensor(1, batchSize, l.hiddenSize) + Ct = ops.ZeroTensor(1, batchSize, l.hiddenSize) } var Pi, Po, Pf tensor.Tensor @@ -375,7 +374,7 @@ func (l *LSTM) getWeights(W tensor.Tensor) (Wi, Wo, Wf, Wh tensor.Tensor, err er nWeightMatrices := 4 nWeightDimensions := 3 - weights, err := l.extractMatrices(W, nWeightMatrices, nWeightDimensions) + weights, err := ops.ExtractMatrices(W, nWeightMatrices, nWeightDimensions, l.hiddenSize) if err != nil { return nil, nil, nil, nil, err } @@ -388,7 +387,7 @@ func (l *LSTM) getBiases(B tensor.Tensor) (Wbi, Wbo, Wbf, Wbc, Rbi, Rbo, Rbf, Rb nBiasMatrices := 8 nBiasDimensions := 2 - b, err := l.extractMatrices(B, nBiasMatrices, nBiasDimensions) + b, err := ops.ExtractMatrices(B, nBiasMatrices, nBiasDimensions, l.hiddenSize) if err != nil { return nil, nil, nil, nil, nil, nil, nil, nil, err } @@ -401,53 +400,10 @@ func (l *LSTM) getPeepholes(P tensor.Tensor) (Pi, Po, Pf tensor.Tensor, err erro nPeepholeMatrices := 3 nPeepholeDimensions := 2 - p, err := l.extractMatrices(P, nPeepholeMatrices, nPeepholeDimensions) + p, err := ops.ExtractMatrices(P, nPeepholeMatrices, nPeepholeDimensions, l.hiddenSize) if err != nil { return nil, nil, nil, err } return p[0], p[1], p[2], nil } - -// extractMatrices extracts 4 tensors from tensor M. -// M contains all matrices concatenated on top of each other in the following order: -// -// forward weights: [Wi, Wo, Wf, Wc] -// recurrent weights: [Ri, Ro, Rf, Rc] -// biases: [Wbi, Wbo, Wbf, Wbc, Rbi, Rbo, Rbf, Rbc] -// -// M is assumed to have a shape of (num_directions, nMatrices * hidden_size, ...) and we extract the -// by slicing over the 'nMatrices * hidden_size' dimension. -func (l *LSTM) extractMatrices(M tensor.Tensor, nMatrices, nDimensions int) ([]tensor.Tensor, error) { - dirSlice := ops.NewSlicer(0) - matrices := make([]tensor.Tensor, nMatrices) - - for i := 0; i < nMatrices; i++ { - hiddenSlice := ops.NewSlicer(i*l.hiddenSize, (i+1)*l.hiddenSize) - - allSlices := make([]tensor.Slice, nDimensions) - allSlices[0] = dirSlice - allSlices[1] = hiddenSlice - - for i := 2; i < nDimensions; i++ { - allSlices[i] = nil - } - - m, err := M.Slice(allSlices...) - if err != nil { - return nil, err - } - - matrices[i] = m - } - - return matrices, nil -} - -// getZeroTensor returns a tensor filled with zeros with the given shape. 
-func (l *LSTM) getZeroTensor(shape ...int) tensor.Tensor { - return tensor.New( - tensor.WithShape(shape...), - tensor.WithBacking(ops.Zeros(ops.NElements(shape...))), - ) -} diff --git a/ops/opset13/rnn.go b/ops/opset13/rnn.go index 5dc9566..a3db3eb 100644 --- a/ops/opset13/rnn.go +++ b/ops/opset13/rnn.go @@ -67,6 +67,8 @@ func (r *RNN) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { } X := inputs[0] + seqLength := X.Shape()[0] + batchSize := X.Shape()[1] Wi, err := r.getWeights(inputs[1]) if err != nil { @@ -80,7 +82,8 @@ func (r *RNN) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { B := inputs[3] if B == nil { - B = r.getDefaultB() + nBiasMatrices := 2 + B = ops.ZeroTensor(1, nBiasMatrices*r.hiddenSize) } Wbi, Rbi, err := r.getBiases(B) @@ -88,12 +91,9 @@ func (r *RNN) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return nil, err } - seqLength := X.Shape()[0] - batchSize := X.Shape()[1] - Ht := inputs[5] if Ht == nil { - Ht = r.getInitialH(batchSize) + Ht = ops.ZeroTensor(1, batchSize, r.hiddenSize) } // Reshape the hidden tensor without the bidirectional dimension, as @@ -189,69 +189,6 @@ func (r *RNN) String() string { return "rnn operator" } -// getWeights returns the weights from a concatenated weight tensor. The result is -// a single weight matrix. W has shape (num_directions, hidden_size, ...). -// We do not support bidirectional layers, so we can simply index the first element -// of W to get the weights for either the input or the recurrence. -func (r *RNN) getWeights(X tensor.Tensor) (tensor.Tensor, error) { - weights, err := X.Slice(ops.NewSlicer(0), nil, nil) - if err != nil { - return nil, err - } - - return weights, nil -} - -// getBiases splits an input bias tensor B into its subparts. The B input for the -// RNN operator consists of two biases, Wbi and Rbi. These biases are concatenated -// in the second dimension, where B has shape (num_directions, 2 * hiddenSize). -// This function slices the B tensor to return 2 bias tensors. We disregard the -// num_directions axis as we do not support the bidirectional direction. -func (r *RNN) getBiases(B tensor.Tensor) (tensor.Tensor, tensor.Tensor, error) { - Wbi, err := B.Slice(ops.NewSlicer(0), ops.NewSlicer(0, r.hiddenSize)) - if err != nil { - return nil, nil, err - } - - nBiasMatrices := 2 - - Rbi, err := B.Slice(ops.NewSlicer(0), ops.NewSlicer(r.hiddenSize, nBiasMatrices*r.hiddenSize)) - if err != nil { - return nil, nil, err - } - - return Wbi, Rbi, nil -} - -// getDefaultB returns the default bias tensor if no bias tensor is provided. -// The bias tensor for RNN consists of two concatenated bias tensors, one for -// the input calculation and one for the hidden calculation. It has shape: -// -// (num_directions, 2*hiddenSize). -// -// By default all values are 0. Note that we do not support the bidirectional -// option so the first dim always has size 1. -func (r *RNN) getDefaultB() tensor.Tensor { - nBiasMatrices := 2 - - return tensor.New( - tensor.WithShape(1, nBiasMatrices*r.hiddenSize), - tensor.WithBacking(ops.Zeros(nBiasMatrices*r.hiddenSize)), - ) -} - -// getInitialH can be used to construct an initial hidden tensor when it is not -// specified by the inputs of the operator. In this case it is assumed to be 0. -// It has shape (num_directions, batch_size, hidden_size). -// As we do not support the birectional option, the num_directions dim size is -// always 1. 
-func (r *RNN) getInitialH(batchSize int) tensor.Tensor {
-	return tensor.New(
-		tensor.WithShape(1, batchSize, r.hiddenSize),
-		tensor.WithBacking(ops.Zeros(batchSize*r.hiddenSize)),
-	)
-}
-
 // layerCalculation performs the actual RNN calculation. By ONNX definition
 // this is:
 //
@@ -281,3 +218,31 @@ func (r *RNN) layerCalculation(
 
 	return activation(result)
 }
+
+// getWeights returns the weights from a concatenated weight tensor. The result is
+// a single weight matrix. W has shape (num_directions, hidden_size, ...).
+// This function extracts 1 weight matrix from tensor W.
+func (r *RNN) getWeights(W tensor.Tensor) (tensor.Tensor, error) {
+	nWeightMatrices := 1
+	nWeightDimensions := 3
+
+	weights, err := ops.ExtractMatrices(W, nWeightMatrices, nWeightDimensions, r.hiddenSize)
+	if err != nil {
+		return nil, err
+	}
+
+	return weights[0], nil
+}
+
+// getBiases splits tensor B into 2 bias matrices.
+func (r *RNN) getBiases(B tensor.Tensor) (Wbi, Rbi tensor.Tensor, err error) {
+	nBiasMatrices := 2
+	nBiasDimensions := 2
+
+	b, err := ops.ExtractMatrices(B, nBiasMatrices, nBiasDimensions, r.hiddenSize)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	return b[0], b[1], nil
+}
diff --git a/ops/recurrent_utils.go b/ops/recurrent_utils.go
index 1d7e4db..385b2a1 100644
--- a/ops/recurrent_utils.go
+++ b/ops/recurrent_utils.go
@@ -1,5 +1,9 @@
 package ops
 
+import (
+	"gorgonia.org/tensor"
+)
+
 // SequenceProcessDirection is the direction in which a sequential input is processed.
 // We can process sequential inputs forward (from first to last), in reverse (from
 // last to first) or bidirectional (which is both forward and reverse added together).
@@ -20,3 +24,50 @@ const (
 	DirectionAttr = "direction"
 	HiddenSizeAttr = "hidden_size"
 )
+
+// ExtractMatrices extracts a given number of matrices from tensor M.
+// M contains concatenated matrices along a certain dimension.
+// M is assumed to have a shape of (num_directions, nMatrices * hidden_size, ...) and we extract the matrices
+// by slicing over the 'nMatrices * hidden_size' dimension.
+// This function is specific to the recurrent operators RNN, GRU and LSTM.
+func ExtractMatrices(M tensor.Tensor, nMatrices, nDimensions, hiddenSize int) ([]tensor.Tensor, error) {
+	dirSlice := NewSlicer(0)
+	matrices := make([]tensor.Tensor, nMatrices)
+
+	for i := 0; i < nMatrices; i++ {
+		hiddenSlice := NewSlicer(i*hiddenSize, (i+1)*hiddenSize)
+
+		allSlices := make([]tensor.Slice, nDimensions)
+		allSlices[0] = dirSlice
+		allSlices[1] = hiddenSlice
+
+		for i := 2; i < nDimensions; i++ {
+			allSlices[i] = nil
+		}
+
+		m, err := M.Slice(allSlices...)
+		if err != nil {
+			return nil, err
+		}
+
+		matrices[i] = m
+	}
+
+	return matrices, nil
+}
+
+// ZeroTensor returns a tensor filled with zeros with the given shape.
+func ZeroTensor(shape ...int) tensor.Tensor {
+	return tensor.New(
+		tensor.WithShape(shape...),
+		tensor.WithBacking(Zeros(NElements(shape...))),
+	)
+}
+
+// OnesTensor returns a new tensor with the same shape as the given tensor initialized with all ones.
+func OnesTensor(t tensor.Tensor) tensor.Tensor { + return tensor.New( + tensor.WithShape(t.Shape()...), + tensor.WithBacking(Ones(NElements(t.Shape()...))), + ) +} From 18cd49eba827731fc3c5a7ead4c1e136fd7ec5b1 Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Sun, 10 Dec 2023 19:58:53 +0100 Subject: [PATCH 08/14] WIP on batch norm --- ops/opset13/batch_normalization.go | 89 ++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 ops/opset13/batch_normalization.go diff --git a/ops/opset13/batch_normalization.go b/ops/opset13/batch_normalization.go new file mode 100644 index 0000000..d7970ee --- /dev/null +++ b/ops/opset13/batch_normalization.go @@ -0,0 +1,89 @@ +package opset13 + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +const ( + MinBatchNormalizationInputs = 5 + MaxBatchNormalizationInputs = 5 +) + +// BatchNormalization represents the ONNX batchNormalization operator. +type BatchNormalization struct { + epsilon float32 + momentum float32 + + outputs []string +} + +// newBatchNormalization creates a new batchNormalization operator. +func newBatchNormalization() ops.Operator { + return &BatchNormalization{ + epsilon: 1e-5, + momentum: 0.9, + } +} + +// Init initializes the batchNormalization operator. +func (b *BatchNormalization) Init(n *onnx.NodeProto) error { + for _, attr := range n.GetAttribute() { + switch attr.GetName() { + case "epsilon": + b.epsilon = attr.GetF() + case "momentum": + b.momentum = attr.GetF() + default: + return ops.ErrInvalidAttribute(attr.GetName(), b) + } + } + + b.outputs = n.GetOutput() + + return nil +} + +// Apply applies the batchNormalization operator. +func (b *BatchNormalization) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + X := inputs[0] + scale := inputs[1] + B := inputs[2] + mean := inputs[3] + variance := inputs[4] + + return []tensor.Tensor{}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (b *BatchNormalization) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(b, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (b *BatchNormalization) GetMinInputs() int { + return MinBatchNormalizationInputs +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (b *BatchNormalization) GetMaxInputs() int { + return MaxBatchNormalizationInputs +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (b *BatchNormalization) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{ + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + } +} + +// String implements the stringer interface, and can be used to format errors or messages. 
+func (b *BatchNormalization) String() string { + return "batchNormalization operator" +} From e79e9c8d6d9b2bf0b992aab1604cc74f2018f2fc Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Tue, 9 Jan 2024 14:51:19 +0100 Subject: [PATCH 09/14] Added BatchNormalization implementation --- ops/opset13/batch_normalization.go | 22 ++-- ops/opset13/batch_normalization_test.go | 146 ++++++++++++++++++++++++ ops/opset13/opset13.go | 101 ++++++++-------- ops/opset13/opset13_test.go | 5 + ops_test.go | 40 ++++--- 5 files changed, 240 insertions(+), 74 deletions(-) create mode 100644 ops/opset13/batch_normalization_test.go diff --git a/ops/opset13/batch_normalization.go b/ops/opset13/batch_normalization.go index 8c05a21..8861cf9 100644 --- a/ops/opset13/batch_normalization.go +++ b/ops/opset13/batch_normalization.go @@ -1,6 +1,8 @@ package opset13 import ( + "fmt" + "github.com/advancedclimatesystems/gonnx/onnx" "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" @@ -16,8 +18,6 @@ type BatchNormalization struct { epsilon float32 momentum float32 testMode bool - - outputs []string } // newBatchNormalization creates a new batchNormalization operator. @@ -47,8 +47,6 @@ func (b *BatchNormalization) Init(n *onnx.NodeProto) error { b.testMode = true } - b.outputs = n.GetOutput() - return nil } @@ -60,7 +58,17 @@ func (b *BatchNormalization) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, err mean := inputs[3] variance := inputs[4] - return []tensor.Tensor{}, nil + // We only support test mode, as this is by far the most common for inference models. + if !b.testMode { + return nil, ops.ErrUnsupportedAttribute("momentum", b) + } + + out, err := b.testModeCalculation(X, scale, B, mean, variance) + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil } // ValidateInputs validates the inputs that will be given to Apply for this operator. 
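For reference, the test-mode path implements the standard BatchNormalization inference formula, y = scale * (x - mean) / sqrt(variance + epsilon) + bias, applied per channel; the reshapeTensors step below mainly gives the per-channel parameters a broadcastable shape before that element-wise computation. A minimal stand-alone sketch of the formula on plain float32 slices (the helper name and plain-slice signature are illustrative only, not part of gonnx):

package main

import "math"

// batchNormChannel normalizes the values of a single channel using its running
// statistics, mirroring the per-channel test-mode formula of the operator.
func batchNormChannel(x []float32, scale, bias, mean, variance, epsilon float32) []float32 {
	out := make([]float32, len(x))
	denom := float32(math.Sqrt(float64(variance + epsilon)))

	for i, v := range x {
		out[i] = scale*(v-mean)/denom + bias
	}

	return out
}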
@@ -160,6 +168,7 @@ func (b *BatchNormalization) reshapeTensors(X, scale, bias, mean, variance tenso } func (b *BatchNormalization) testModeCalculation(X, scale, bias, mean, variance tensor.Tensor) (tensor.Tensor, error) { + fmt.Println("joe") newScale, newBias, newMean, newVariance, err := b.reshapeTensors(X, scale, bias, mean, variance) if err != nil { return nil, err @@ -217,6 +226,3 @@ func (b *BatchNormalization) testModeCalculation(X, scale, bias, mean, variance return outputs[0], nil } - -func (b *BatchNormalization) trainModeCalculation(X, scale, bias, mean, variance tensor.Tensor) (y, saved_mean, saved_var, output_mean, output_var tensor.Tensor, err error) { -} diff --git a/ops/opset13/batch_normalization_test.go b/ops/opset13/batch_normalization_test.go new file mode 100644 index 0000000..d00b69c --- /dev/null +++ b/ops/opset13/batch_normalization_test.go @@ -0,0 +1,146 @@ +package opset13 + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestBatchNormalizationInit(t *testing.T) { + b := &BatchNormalization{} + + err := b.Init( + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "epsilon", F: 0.001}, + }, + }, + ) + assert.Nil(t, err) + + assert.Equal(t, float32(0.001), b.epsilon) + assert.True(t, b.testMode) +} + +func TestBatchNormalizationInitTrainingMode(t *testing.T) { + b := &BatchNormalization{} + + err := b.Init( + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "epsilon", F: 0.001}, + {Name: "momentum", F: 0.99}, + }, + }, + ) + assert.Nil(t, err) + + assert.Equal(t, float32(0.001), b.epsilon) + assert.Equal(t, float32(0.99), b.momentum) + assert.False(t, b.testMode) +} + +func TestBatchNormalization(t *testing.T) { + tests := []struct { + batchNormalization *BatchNormalization + backings [][]float32 + shapes [][]int + expected []float32 + }{ + { + &BatchNormalization{ + epsilon: 1e5, + momentum: 0.9, + testMode: true, + }, + [][]float32{ + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, + {0.2, 0.3, 0.4}, + {0.1, -0.1, 0.2}, + {4, 8, 12}, + {1, 2, 3}, + }, + [][]int{ + {2, 3, 2, 2}, + {3}, + {3}, + {3}, + {3}, + }, + []float32{0.097470194, 0.098102644, 0.098735094, 0.09936755, -0.103794694, -0.10284603, -0.10189735, -0.10094868, 0.19494043, 0.19620533, 0.19747022, 0.19873512, 0.10505962, 0.10569207, 0.10632452, 0.10695698, -0.09241061, -0.091461934, -0.09051326, -0.08956459, 0.21011914, 0.21138403, 0.21264893, 0.21391381}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backings[0], test.shapes[0]...), + ops.TensorWithBackingFixture(test.backings[1], test.shapes[1]...), + ops.TensorWithBackingFixture(test.backings[2], test.shapes[2]...), + ops.TensorWithBackingFixture(test.backings[3], test.shapes[3]...), + ops.TensorWithBackingFixture(test.backings[4], test.shapes[4]...), + } + + res, err := test.batchNormalization.Apply(inputs) + assert.Nil(t, err) + + assert.Equal(t, test.expected, res[0].Data()) + } +} + +func TestInputValidationBatchNormalization(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{3, 4}, 2), + ops.TensorWithBackingFixture([]float32{3, 4}, 2), + ops.TensorWithBackingFixture([]float32{3, 4}, 2), + ops.TensorWithBackingFixture([]float32{3, 
4}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float64{1, 2}, 2), + ops.TensorWithBackingFixture([]float64{3, 4}, 2), + ops.TensorWithBackingFixture([]float64{3, 4}, 2), + ops.TensorWithBackingFixture([]float64{3, 4}, 2), + ops.TensorWithBackingFixture([]float64{3, 4}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputCount(1, &BatchNormalization{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int{3, 4}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + ops.ErrInvalidInputType(1, "int", &BatchNormalization{}), + }, + } + + for _, test := range tests { + batchNormalization := &BatchNormalization{} + validated, err := batchNormalization.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + assert.Equal(t, test.inputs, validated) + } + } +} diff --git a/ops/opset13/opset13.go b/ops/opset13/opset13.go index 65b7a85..eba2ae2 100644 --- a/ops/opset13/opset13.go +++ b/ops/opset13/opset13.go @@ -5,56 +5,57 @@ import ( ) var operators13 = map[string]func() ops.Operator{ - "Abs": newAbs, - "Acos": newAcos, - "Acosh": newAcosh, - "Add": newAdd, - "And": newAnd, - "Asin": newAsin, - "Asinh": newAsinh, - "Atan": newAtan, - "Atanh": newAtanh, - "Cast": newCast, - "Concat": newConcat, - "Constant": newConstant, - "ConstantOfShape": newConstantOfShape, - "Conv": newConv, - "Cos": newCos, - "Cosh": newCosh, - "Div": newDiv, - "Equal": newEqual, - "Flatten": newFlatten, - "Gather": newGather, - "Gemm": newGemm, - "Greater": newGreater, - "GreaterOrEqual": newGreaterOrEqual, - "GRU": newGRU, - "Less": newLess, - "LessOrEqual": newLessOrEqual, - "LinearRegressor": newLinearRegressor, - "LSTM": newLSTM, - "MatMul": newMatMul, - "Mul": newMul, - "Not": newNot, - "Or": newOr, - "PRelu": newPRelu, - "Relu": newRelu, - "Reshape": newReshape, - "RNN": newRNN, - "Scaler": newScaler, - "Shape": newShape, - "Sigmoid": newSigmoid, - "Sin": newSin, - "Sinh": newSinh, - "Slice": newSlice, - "Softmax": newSoftmax, - "Squeeze": newSqueeze, - "Sub": newSub, - "Tan": newTan, - "Tanh": newTanh, - "Transpose": newTranspose, - "Unsqueeze": newUnsqueeze, - "Xor": newXor, + "Abs": newAbs, + "Acos": newAcos, + "Acosh": newAcosh, + "Add": newAdd, + "And": newAnd, + "Asin": newAsin, + "Asinh": newAsinh, + "Atan": newAtan, + "Atanh": newAtanh, + "BatchNormalization": newBatchNormalization, + "Cast": newCast, + "Concat": newConcat, + "Constant": newConstant, + "ConstantOfShape": newConstantOfShape, + "Conv": newConv, + "Cos": newCos, + "Cosh": newCosh, + "Div": newDiv, + "Equal": newEqual, + "Flatten": newFlatten, + "Gather": newGather, + "Gemm": newGemm, + "Greater": newGreater, + "GreaterOrEqual": newGreaterOrEqual, + "GRU": newGRU, + "Less": newLess, + "LessOrEqual": newLessOrEqual, + "LinearRegressor": newLinearRegressor, + "LSTM": newLSTM, + "MatMul": newMatMul, + "Mul": newMul, + "Not": newNot, + "Or": newOr, + "PRelu": newPRelu, + "Relu": newRelu, + "Reshape": newReshape, + "RNN": newRNN, + "Scaler": newScaler, + "Shape": newShape, + "Sigmoid": newSigmoid, + "Sin": newSin, + "Sinh": newSinh, + "Slice": newSlice, + "Softmax": newSoftmax, + "Squeeze": newSqueeze, + "Sub": newSub, + "Tan": newTan, + "Tanh": newTanh, + "Transpose": newTranspose, + "Unsqueeze": newUnsqueeze, + "Xor": newXor, } // GetOperator maps 
strings as found in the ModelProto to Operators from opset 13. diff --git a/ops/opset13/opset13_test.go b/ops/opset13/opset13_test.go index a91ec3d..e590941 100644 --- a/ops/opset13/opset13_test.go +++ b/ops/opset13/opset13_test.go @@ -58,6 +58,11 @@ func TestGetOperator(t *testing.T) { newAsinh(), nil, }, + { + "BatchNormalization", + newBatchNormalization(), + nil, + }, { "Cast", newCast(), diff --git a/ops_test.go b/ops_test.go index 0126321..a405939 100644 --- a/ops_test.go +++ b/ops_test.go @@ -24,22 +24,24 @@ import ( // Another reason is that some tests require an opset version higher than we have currently // implemented, or lower, which we also haven't implemented yet. var ignoredTests = []string{ - "test_add_uint8", // Opset14 - "test_div_uint8", // Opset14 - "test_gru_batchwise", // Opset14 - "test_lstm_batchwise", // Opset14 - "test_mul_uint8", // Opset14 - "test_sub_uint8", // Opset14 - "test_shape_clip_end", // Opset15 - "test_shape_clip_start", // Opset15 - "test_shape_end_1", // Opset15 - "test_shape_end_negative_1", // Opset15 - "test_shape_example", // Opset15 - "test_shape_start_1", // Opset15 - "test_shape_start_1_end_2", // Opset15 - "test_shape_start_1_end_negative_1", // Opset15 - "test_shape_start_negative_1", // Opset15 - "test_reshape_allowzero_reordered", // Opset14 + "test_add_uint8", // Opset14 + "test_batchnorm_epsilon_training_mode", // Opset14 + "test_batchnorm_example_training_mode", // Opset14 + "test_div_uint8", // Opset14 + "test_gru_batchwise", // Opset14 + "test_lstm_batchwise", // Opset14 + "test_mul_uint8", // Opset14 + "test_sub_uint8", // Opset14 + "test_shape_clip_end", // Opset15 + "test_shape_clip_start", // Opset15 + "test_shape_end_1", // Opset15 + "test_shape_end_negative_1", // Opset15 + "test_shape_example", // Opset15 + "test_shape_start_1", // Opset15 + "test_shape_start_1_end_2", // Opset15 + "test_shape_start_1_end_negative_1", // Opset15 + "test_shape_start_negative_1", // Opset15 + "test_reshape_allowzero_reordered", // Opset14 "test_constant_pad", // Pad is not implemented yet. "test_constant_pad_axes", // Pad is not implemented yet. 
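With the operators13 map shown above, supporting BatchNormalization amounts to a single registration entry: resolving an ONNX node type comes down to a map lookup followed by a constructor call. A rough sketch of that lookup pattern (the helper below is illustrative; the actual GetOperator implementation may differ):

package opset13

import (
	"fmt"

	"github.com/advancedclimatesystems/gonnx/ops"
)

// resolveOperator looks up the constructor registered for an ONNX operator type
// and instantiates it. Sketch only; error handling in gonnx may look different.
func resolveOperator(opType string) (ops.Operator, error) {
	newOp, ok := operators13[opType]
	if !ok {
		return nil, fmt.Errorf("operator %q is not supported in opset 13", opType)
	}

	return newOp(), nil
}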
@@ -193,6 +195,10 @@ func shouldRunTest(folder, opFilter string) bool { } } + if opFilter == "test_batchnormalization" { + opFilter = "test_batchnorm" + } + if strings.Contains(folder, opFilter) { remaining := strings.ReplaceAll(folder, opFilter, "") if len(remaining) == 0 || remaining[:1] == "_" { @@ -312,6 +318,8 @@ var expectedTests = []string{ "test_atan_example", "test_atanh", "test_atanh_example", + "test_batchnorm_epsilon", + "test_batchnorm_example", "test_cast_DOUBLE_to_FLOAT", "test_cast_FLOAT_to_DOUBLE", "test_concat_1d_axis_0", From 5465c2c6dca45b0c701e30c101e941a74a12cf42 Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Tue, 9 Jan 2024 14:55:30 +0100 Subject: [PATCH 10/14] Comment and remove print --- ops/opset13/batch_normalization.go | 3 --- ops_test.go | 2 ++ 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/ops/opset13/batch_normalization.go b/ops/opset13/batch_normalization.go index 8861cf9..c548fb5 100644 --- a/ops/opset13/batch_normalization.go +++ b/ops/opset13/batch_normalization.go @@ -1,8 +1,6 @@ package opset13 import ( - "fmt" - "github.com/advancedclimatesystems/gonnx/onnx" "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" @@ -168,7 +166,6 @@ func (b *BatchNormalization) reshapeTensors(X, scale, bias, mean, variance tenso } func (b *BatchNormalization) testModeCalculation(X, scale, bias, mean, variance tensor.Tensor) (tensor.Tensor, error) { - fmt.Println("joe") newScale, newBias, newMean, newVariance, err := b.reshapeTensors(X, scale, bias, mean, variance) if err != nil { return nil, err diff --git a/ops_test.go b/ops_test.go index a405939..66f4bdf 100644 --- a/ops_test.go +++ b/ops_test.go @@ -195,6 +195,8 @@ func shouldRunTest(folder, opFilter string) bool { } } + // For some reason ONNX decided to not let these testcases match the operator name. + // Here we manually replace the filter with the name ONNX uses for this test case. if opFilter == "test_batchnormalization" { opFilter = "test_batchnorm" } From 18ca63b466b06231df26e049084d11d72d16355d Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Tue, 9 Jan 2024 14:58:18 +0100 Subject: [PATCH 11/14] Fix lint --- ops/opset13/batch_normalization.go | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/ops/opset13/batch_normalization.go b/ops/opset13/batch_normalization.go index c548fb5..6caadbd 100644 --- a/ops/opset13/batch_normalization.go +++ b/ops/opset13/batch_normalization.go @@ -7,8 +7,10 @@ import ( ) const ( - MinBatchNormalizationInputs = 5 - MaxBatchNormalizationInputs = 5 + MinBatchNormalizationInputs = 5 + MaxBatchNormalizationInputs = 5 + BatchNormalizationDefaultEpsilon = 1e-5 + BatchNormalizationDefaultMomentum = 0.9 ) // BatchNormalization represents the ONNX batchNormalization operator. @@ -21,14 +23,15 @@ type BatchNormalization struct { // newBatchNormalization creates a new batchNormalization operator. func newBatchNormalization() ops.Operator { return &BatchNormalization{ - epsilon: 1e-5, - momentum: 0.9, + epsilon: BatchNormalizationDefaultEpsilon, + momentum: BatchNormalizationDefaultMomentum, } } // Init initializes the batchNormalization operator. 
func (b *BatchNormalization) Init(n *onnx.NodeProto) error { hasMomentum := false + for _, attr := range n.GetAttribute() { switch attr.GetName() { case "epsilon": @@ -102,7 +105,9 @@ func (b *BatchNormalization) String() string { } func (b *BatchNormalization) reshapeTensors(X, scale, bias, mean, variance tensor.Tensor) (newScale, newBias, newMean, newVariance tensor.Tensor, err error) { - nSpatialDims := len(X.Shape()) - 2 + nNonSpatialDims := 2 + + nSpatialDims := len(X.Shape()) - nNonSpatialDims if nSpatialDims <= 0 { return scale, bias, mean, variance, nil } From faf43dff421dd439f0219833cb92e13ff73a1ec9 Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Tue, 9 Jan 2024 15:04:43 +0100 Subject: [PATCH 12/14] Ignore newly added tests --- ops_test.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/ops_test.go b/ops_test.go index 66f4bdf..5a16fc2 100644 --- a/ops_test.go +++ b/ops_test.go @@ -72,6 +72,16 @@ var ignoredTests = []string{ "test_equal_string", // Unsupported datatype String. "test_equal_string_broadcast", // Unsupported datatype String. + "test_cast_INT4_to_INT8", // Unsupported datatype INT4. + "test_cast_INT4_to_FLOAT", // Unsupported datatype INT4. + "test_cast_FLOAT_to_INT4", // Unsupported datatype INT4. + "test_cast_FLOAT_to_UINT4", // Unsupported datatype UINT4. + "test_cast_INT4_to_FLOAT16", // Unsupported datatype INT4/FLOAT16. + "test_cast_FLOAT16_to_UINT4", // Unsupported datatype FLOAT16. + "test_cast_FLOAT16_to_INT4", // Unsupported datatype FLOAT16. + "test_cast_UINT4_to_UINT8", // Unsupported datatype UINT4. + "test_cast_UINT4_to_FLOAT", // Unsupported datatype UINT4. + "test_cast_UINT4_to_FLOAT16", // Unsupported datatype UINT4. "test_cast_FLOAT_to_STRING", // Unsupported datatype STRING. "test_cast_STRING_to_FLOAT", // Unsupported datatype STRING. "test_cast_DOUBLE_to_FLOAT16", // Unsupported datatype FLOAT16. @@ -175,6 +185,7 @@ func getTestCasesForOp(opName string) ([]*ONNXTestCase, error) { for _, testFolder := range testFolders { if shouldRunTest(testFolder, opFilter) { + fmt.Println(testFolder) testcase, err := getTestCase(fmt.Sprintf("./test_data/%v", testFolder)) if err != nil { return nil, err From b3d02687e2f73a615096b3eefc8a972ea2d53357 Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Tue, 9 Jan 2024 15:09:43 +0100 Subject: [PATCH 13/14] Remove print statement --- ops_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/ops_test.go b/ops_test.go index 5a16fc2..64d9fb7 100644 --- a/ops_test.go +++ b/ops_test.go @@ -185,7 +185,6 @@ func getTestCasesForOp(opName string) ([]*ONNXTestCase, error) { for _, testFolder := range testFolders { if shouldRunTest(testFolder, opFilter) { - fmt.Println(testFolder) testcase, err := getTestCase(fmt.Sprintf("./test_data/%v", testFolder)) if err != nil { return nil, err From fe10c126f5f40436434d3974468b68af5f5ea807 Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Mon, 21 Oct 2024 12:13:46 +0200 Subject: [PATCH 14/14] Error earlier --- ops/opset13/batch_normalization.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ops/opset13/batch_normalization.go b/ops/opset13/batch_normalization.go index 6caadbd..50c9e19 100644 --- a/ops/opset13/batch_normalization.go +++ b/ops/opset13/batch_normalization.go @@ -48,6 +48,11 @@ func (b *BatchNormalization) Init(n *onnx.NodeProto) error { b.testMode = true } + // We only support test mode, as this is by far the most common for inference models. 
+ if !b.testMode { + return ops.ErrUnsupportedAttribute("momentum", b) + } + return nil } @@ -59,11 +64,6 @@ func (b *BatchNormalization) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, err mean := inputs[3] variance := inputs[4] - // We only support test mode, as this is by far the most common for inference models. - if !b.testMode { - return nil, ops.ErrUnsupportedAttribute("momentum", b) - } - out, err := b.testModeCalculation(X, scale, B, mean, variance) if err != nil { return nil, err