
Add BatchNormalization operator #186

Open · wants to merge 19 commits into develop
230 changes: 230 additions & 0 deletions ops/opset13/batch_normalization.go
@@ -0,0 +1,230 @@
package opset13

import (
"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

const (
MinBatchNormalizationInputs = 5
MaxBatchNormalizationInputs = 5
BatchNormalizationDefaultEpsilon = 1e-5
BatchNormalizationDefaultMomentum = 0.9
)

// BatchNormalization represents the ONNX batchNormalization operator.
type BatchNormalization struct {
epsilon float32
momentum float32
testMode bool
}

// newBatchNormalization creates a new batchNormalization operator.
func newBatchNormalization() ops.Operator {
return &BatchNormalization{
epsilon: BatchNormalizationDefaultEpsilon,
momentum: BatchNormalizationDefaultMomentum,
}
}

// Init initializes the batchNormalization operator.
func (b *BatchNormalization) Init(n *onnx.NodeProto) error {
hasMomentum := false

for _, attr := range n.GetAttribute() {
switch attr.GetName() {
case "epsilon":
b.epsilon = attr.GetF()
case "momentum":
hasMomentum = true
b.momentum = attr.GetF()
default:
return ops.ErrInvalidAttribute(attr.GetName(), b)
}
}

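// The momentum attribute is only meaningful during training; if it was not
// set explicitly, we assume the model is used for inference (test mode).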
if !hasMomentum {
b.testMode = true
}

// We only support test mode, as this is by far the most common case for inference models.
if !b.testMode {
return ops.ErrUnsupportedAttribute("momentum", b)
}

return nil
}

// Apply applies the batchNormalization operator.
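// It returns only the first ONNX output, Y; the optional training-mode
// outputs (running mean and variance) are not produced.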
func (b *BatchNormalization) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
X := inputs[0]
scale := inputs[1]
B := inputs[2]
mean := inputs[3]
variance := inputs[4]

out, err := b.testModeCalculation(X, scale, B, mean, variance)
if err != nil {
return nil, err
}

return []tensor.Tensor{out}, nil
}

// ValidateInputs validates the inputs that will be given to Apply for this operator.
func (b *BatchNormalization) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
return ops.ValidateInputs(b, inputs)
}

// GetMinInputs returns the minimum number of input tensors this operator expects.
func (b *BatchNormalization) GetMinInputs() int {
return MinBatchNormalizationInputs
}

// GetMaxInputs returns the maximum number of input tensors this operator expects.
func (b *BatchNormalization) GetMaxInputs() int {
return MaxBatchNormalizationInputs
}

// GetInputTypeConstraints returns a list in which each element represents
// the set of allowed tensor dtypes for the corresponding input tensor.
func (b *BatchNormalization) GetInputTypeConstraints() [][]tensor.Dtype {
return [][]tensor.Dtype{
{tensor.Float32, tensor.Float64},
{tensor.Float32, tensor.Float64},
{tensor.Float32, tensor.Float64},
{tensor.Float32, tensor.Float64},
{tensor.Float32, tensor.Float64},
}
}

// String implements the fmt.Stringer interface, and can be used to format errors or messages.
func (b *BatchNormalization) String() string {
return "batchNormalization operator"
}

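// reshapeTensors reshapes the `scale`, `bias`, `mean` and `variance` tensors
// from shape (C,) to (C, 1, 1, ...), so they broadcast against `X` in the
// normalization computation. Inputs without spatial dimensions are returned
// unchanged.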
func (b *BatchNormalization) reshapeTensors(X, scale, bias, mean, variance tensor.Tensor) (newScale, newBias, newMean, newVariance tensor.Tensor, err error) {
nNonSpatialDims := 2

nSpatialDims := len(X.Shape()) - nNonSpatialDims
if nSpatialDims <= 0 {
return scale, bias, mean, variance, nil
}

// The new shape for the `scale`, `bias`, `mean` and `variance` tensors should
// be (C, 1, 1, ...), so they can be broadcast to match the shape of `X`.
newShape := make([]int, 1+nSpatialDims)

// Here we set the channel dimension. The channel dimension is the same
// for all `X`, `scale`, `bias`, `mean` and `variance` tensors.
newShape[0] = scale.Shape()[0]

// Set all the remaining dimensions to 1 to allow for broadcasting.
for i := 1; i < len(newShape); i++ {
newShape[i] = 1
}

// Now we create new tensors for all the input tensors (except `X`) and reshape
// them.
newScale, ok := scale.Clone().(tensor.Tensor)
if !ok {
return nil, nil, nil, nil, ops.ErrTypeAssert("tensor.Tensor", scale.Clone())
}

newBias, ok = bias.Clone().(tensor.Tensor)
if !ok {
return nil, nil, nil, nil, ops.ErrTypeAssert("tensor.Tensor", bias.Clone())
}

newMean, ok = mean.Clone().(tensor.Tensor)
if !ok {
return nil, nil, nil, nil, ops.ErrTypeAssert("tensor.Tensor", mean.Clone())
}

newVariance, ok = variance.Clone().(tensor.Tensor)
if !ok {
return nil, nil, nil, nil, ops.ErrTypeAssert("tensor.Tensor", variance.Clone())
}

err = newScale.Reshape(newShape...)
if err != nil {
return nil, nil, nil, nil, err
}

err = newBias.Reshape(newShape...)
if err != nil {
return nil, nil, nil, nil, err
}

err = newMean.Reshape(newShape...)
if err != nil {
return nil, nil, nil, nil, err
}

err = newVariance.Reshape(newShape...)
if err != nil {
return nil, nil, nil, nil, err
}

return
}

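// testModeCalculation computes the inference-time batch normalization output:
//
//	Y = (X - mean) * scale / sqrt(variance + epsilon) + bias
//
// using the stored running mean and variance instead of batch statistics.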
func (b *BatchNormalization) testModeCalculation(X, scale, bias, mean, variance tensor.Tensor) (tensor.Tensor, error) {
newScale, newBias, newMean, newVariance, err := b.reshapeTensors(X, scale, bias, mean, variance)
if err != nil {
return nil, err
}

numerator, err := ops.ApplyBinaryOperation(
X,
newMean,
ops.Sub,
ops.UnidirectionalBroadcasting,
)
if err != nil {
return nil, err
}

numerator, err = ops.ApplyBinaryOperation(
numerator[0],
newScale,
ops.Mul,
ops.UnidirectionalBroadcasting,
)
if err != nil {
return nil, err
}

denominator, err := tensor.Add(newVariance, b.epsilon)
if err != nil {
return nil, err
}

denominator, err = tensor.Sqrt(denominator)
if err != nil {
return nil, err
}

outputs, err := ops.ApplyBinaryOperation(
numerator[0],
denominator,
ops.Div,
ops.UnidirectionalBroadcasting,
)
if err != nil {
return nil, err
}

outputs, err = ops.ApplyBinaryOperation(
outputs[0],
newBias,
ops.Add,
ops.UnidirectionalBroadcasting,
)
if err != nil {
return nil, err
}

return outputs[0], nil
}
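
A minimal usage sketch of the new operator (not part of the diff; `X`, `scale`, `B`, `mean` and `variance` stand for pre-built tensors like the fixtures in the tests below):

	op := newBatchNormalization()
	if err := op.Init(&onnx.NodeProto{
		Attribute: []*onnx.AttributeProto{{Name: "epsilon", F: 1e-5}},
	}); err != nil {
		// handle init error
	}

	inputs, err := op.ValidateInputs([]tensor.Tensor{X, scale, B, mean, variance})
	if err != nil {
		// handle validation error
	}

	outputs, err := op.Apply(inputs) // outputs[0] holds the normalized tensor Y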
146 changes: 146 additions & 0 deletions ops/opset13/batch_normalization_test.go
@@ -0,0 +1,146 @@
package opset13

import (
"testing"

"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"github.com/stretchr/testify/assert"
"gorgonia.org/tensor"
)

func TestBatchNormalizationInit(t *testing.T) {
b := &BatchNormalization{}

err := b.Init(
&onnx.NodeProto{
Attribute: []*onnx.AttributeProto{
{Name: "epsilon", F: 0.001},
},
},
)
assert.Nil(t, err)

assert.Equal(t, float32(0.001), b.epsilon)
assert.True(t, b.testMode)
}

func TestBatchNormalizationInitTrainingMode(t *testing.T) {
b := &BatchNormalization{}

err := b.Init(
&onnx.NodeProto{
Attribute: []*onnx.AttributeProto{
{Name: "epsilon", F: 0.001},
{Name: "momentum", F: 0.99},
},
},
)
assert.Equal(t, ops.ErrUnsupportedAttribute("momentum", b), err)

assert.Equal(t, float32(0.001), b.epsilon)
assert.Equal(t, float32(0.99), b.momentum)
assert.False(t, b.testMode)
}

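// TestBatchNormalization verifies Apply against expected values that follow
// Y = (X - mean) * scale / sqrt(variance + epsilon) + bias. Note that the
// fixture uses a large epsilon (1e5), which the expected values account for.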
func TestBatchNormalization(t *testing.T) {
tests := []struct {
batchNormalization *BatchNormalization
backings [][]float32
shapes [][]int
expected []float32
}{
{
&BatchNormalization{
epsilon: 1e5,
momentum: 0.9,
testMode: true,
},
[][]float32{
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
{0.2, 0.3, 0.4},
{0.1, -0.1, 0.2},
{4, 8, 12},
{1, 2, 3},
},
[][]int{
{2, 3, 2, 2},
{3},
{3},
{3},
{3},
},
[]float32{0.097470194, 0.098102644, 0.098735094, 0.09936755, -0.103794694, -0.10284603, -0.10189735, -0.10094868, 0.19494043, 0.19620533, 0.19747022, 0.19873512, 0.10505962, 0.10569207, 0.10632452, 0.10695698, -0.09241061, -0.091461934, -0.09051326, -0.08956459, 0.21011914, 0.21138403, 0.21264893, 0.21391381},
},
}

for _, test := range tests {
inputs := []tensor.Tensor{
ops.TensorWithBackingFixture(test.backings[0], test.shapes[0]...),
ops.TensorWithBackingFixture(test.backings[1], test.shapes[1]...),
ops.TensorWithBackingFixture(test.backings[2], test.shapes[2]...),
ops.TensorWithBackingFixture(test.backings[3], test.shapes[3]...),
ops.TensorWithBackingFixture(test.backings[4], test.shapes[4]...),
}

res, err := test.batchNormalization.Apply(inputs)
assert.Nil(t, err)

assert.Equal(t, test.expected, res[0].Data())
}
}

func TestInputValidationBatchNormalization(t *testing.T) {
tests := []struct {
inputs []tensor.Tensor
err error
}{
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float32{1, 2}, 2),
ops.TensorWithBackingFixture([]float32{3, 4}, 2),
ops.TensorWithBackingFixture([]float32{3, 4}, 2),
ops.TensorWithBackingFixture([]float32{3, 4}, 2),
ops.TensorWithBackingFixture([]float32{3, 4}, 2),
},
nil,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float64{1, 2}, 2),
ops.TensorWithBackingFixture([]float64{3, 4}, 2),
ops.TensorWithBackingFixture([]float64{3, 4}, 2),
ops.TensorWithBackingFixture([]float64{3, 4}, 2),
ops.TensorWithBackingFixture([]float64{3, 4}, 2),
},
nil,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int{1, 2}, 2),
},
ops.ErrInvalidInputCount(1, &BatchNormalization{}),
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float32{1, 2}, 2),
ops.TensorWithBackingFixture([]int{3, 4}, 2),
ops.TensorWithBackingFixture([]float32{1, 2}, 2),
ops.TensorWithBackingFixture([]float32{1, 2}, 2),
ops.TensorWithBackingFixture([]float32{1, 2}, 2),
},
ops.ErrInvalidInputType(1, "int", &BatchNormalization{}),
},
}

for _, test := range tests {
batchNormalization := &BatchNormalization{}
validated, err := batchNormalization.ValidateInputs(test.inputs)

assert.Equal(t, test.err, err)

if test.err == nil {
assert.Equal(t, test.inputs, validated)
}
}
}