Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added ArgMax operator #209

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions ops/opset13/argmax.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package opset13

import (
"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

const (
MinArgMaxInputs = 1
MaxArgMaxInputs = 1
)

// ArgMax represents the ONNX argmax operator.
type ArgMax struct {
axis int
keepDims bool
selectLastIndex bool
}

// newArgMax creates a new argmax operator.
func newArgMax() ops.Operator {
return &ArgMax{
keepDims: true,
selectLastIndex: false,
}
}

type ArgMaxAttribute string

const (
axis = "axis"
keepDims = "keepdims"
selectLastIndex = "select_last_index"
)

// Init initializes the argmax operator.
func (a *ArgMax) Init(n *onnx.NodeProto) error {
attributes := n.GetAttribute()
for _, attr := range attributes {
switch attr.GetName() {
case axis:
a.axis = int(attr.GetI())
case keepDims:
a.keepDims = ops.Int64ToBool(attr.GetI())
case selectLastIndex:
a.selectLastIndex = ops.Int64ToBool(attr.GetI())

// We have no way yet to perform argmax and keeping the
// last index as max in case of duplicates, so if this
// attribute is true, we raise an unsupported error.
if a.selectLastIndex {
return ops.ErrUnsupportedAttribute(attr.GetName(), a)
}
default:
return ops.ErrInvalidAttribute(attr.GetName(), a)
}
}

return nil
}

// Apply applies the argmax operator.
func (a *ArgMax) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
axis := ops.ConvertNegativeAxis(a.axis, len(inputs[0].Shape()))

reduced, err := tensor.Argmax(inputs[0], axis)
if err != nil {
return nil, err
}

// Keep the reduced dimension, i.e. if the reduced axis was '1', and
// the original shape was (2, 4, 5), the reduced shape would be (2, 5).
// If keepDims is true, that shape should be (2, 1, 5).
if a.keepDims {
newShape := inputs[0].Shape()
newShape[axis] = 1

if err := reduced.Reshape(newShape...); err != nil {
return nil, err
}
}

// The tensor.Argmax function returns data of type int, but according to
// the ONNX standard this operator should return int64.
backing, ok := reduced.Data().([]int)
if !ok {
return nil, ops.ErrTypeAssert("int", reduced.Dtype())
}

backing2 := make([]int64, len(backing))
for i := range backing {
backing2[i] = int64(backing[i])
}

reduced = tensor.New(tensor.WithShape(reduced.Shape()...), tensor.WithBacking(backing2))

return []tensor.Tensor{reduced}, nil
}

// ValidateInputs validates the inputs that will be given to Apply for this operator.
func (a *ArgMax) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
return ops.ValidateInputs(a, inputs)
}

// GetMinInputs returns the minimum number of input tensors this operator expects.
func (a *ArgMax) GetMinInputs() int {
return MinArgMaxInputs
}

// GetMaxInputs returns the maximum number of input tensors this operator expects.
func (a *ArgMax) GetMaxInputs() int {
return MaxArgMaxInputs
}

// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes
// for the corresponding input tensor.
func (a *ArgMax) GetInputTypeConstraints() [][]tensor.Dtype {
return [][]tensor.Dtype{
{tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64},
}
}

// String implements the stringer interface, and can be used to format errors or messages.
func (a *ArgMax) String() string {
return "argmax operator"
}
134 changes: 134 additions & 0 deletions ops/opset13/argmax_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package opset13

import (
"testing"

"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"github.com/stretchr/testify/assert"
"gorgonia.org/tensor"
)

func TestArgMaxInit(t *testing.T) {
a := &ArgMax{}

err := a.Init(
&onnx.NodeProto{
Attribute: []*onnx.AttributeProto{
{Name: "axis", I: 2},
{Name: "keepdims", I: 0},
{Name: "select_last_index", I: 0},
},
},
)
assert.Nil(t, err)

assert.Equal(t, 2, a.axis)
assert.Equal(t, false, a.keepDims)
assert.Equal(t, false, a.selectLastIndex)
}

func TestArgMax(t *testing.T) {
tests := []struct {
argmax *ArgMax
backing []float32
shape []int
expectedShape tensor.Shape
expectedData []int64
}{
{
&ArgMax{axis: 0, keepDims: true},
[]float32{0, 1, 2, 3},
[]int{2, 2},
[]int{1, 2},
[]int64{1, 1},
},
{
&ArgMax{axis: -1, keepDims: true},
[]float32{0, 1, 2, 3},
[]int{2, 2},
[]int{2, 1},
[]int64{1, 1},
},
}

for _, test := range tests {
inputs := []tensor.Tensor{
ops.TensorWithBackingFixture(test.backing, test.shape...),
}

res, err := test.argmax.Apply(inputs)
assert.Nil(t, err)

assert.Equal(t, test.expectedShape, res[0].Shape())
assert.Equal(t, test.expectedData, res[0].Data())
}
}

func TestInputValidationArgMax(t *testing.T) {
tests := []struct {
inputs []tensor.Tensor
err error
}{
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]uint32{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]uint64{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int32{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int64{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float32{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float64{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float32{1, 2}, 2),
ops.TensorWithBackingFixture([]float32{1, 2}, 2),
},
ops.ErrInvalidInputCount(2, &ArgMax{}),
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int{1, 2}, 2),
},
ops.ErrInvalidInputType(0, "int", &ArgMax{}),
},
}

for _, test := range tests {
argmax := &ArgMax{}
validated, err := argmax.ValidateInputs(test.inputs)

assert.Equal(t, test.err, err)

if test.err == nil {
assert.Equal(t, test.inputs, validated)
}
}
}
2 changes: 1 addition & 1 deletion ops/opset13/expand.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func newExpand() ops.Operator {
}

// Init initializes the expand operator.
func (f *Expand) Init(n *onnx.NodeProto) error {
func (f *Expand) Init(*onnx.NodeProto) error {
return nil
}

Expand Down
1 change: 1 addition & 0 deletions ops/opset13/opset13.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ var operators13 = map[string]func() ops.Operator{
"Acosh": newAcosh,
"Add": newAdd,
"And": newAnd,
"ArgMax": newArgMax,
"Asin": newAsin,
"Asinh": newAsinh,
"Atan": newAtan,
Expand Down
20 changes: 13 additions & 7 deletions ops/opset13/opset13_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,11 @@ func TestGetOperator(t *testing.T) {
nil,
},
{
"Atan",
newAtan(),
nil,
},
{
"Atanh",
newAtanh(),
"ArgMax",
newArgMax(),
nil,
},

{
"Asin",
newAsin(),
Expand All @@ -58,6 +54,16 @@ func TestGetOperator(t *testing.T) {
newAsinh(),
nil,
},
{
"Atan",
newAtan(),
nil,
},
{
"Atanh",
newAtanh(),
nil,
},
{
"Cast",
newCast(),
Expand Down
7 changes: 1 addition & 6 deletions ops/opset13/reduce_max.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,7 @@ func (r *ReduceMax) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {

axes := make([]int, len(r.axes))
for i, axis := range r.axes {
// Convert negative dimensions.
if axis < 0 {
axis = len(input.Shape()) + axis
}

axes[i] = axis
axes[i] = ops.ConvertNegativeAxis(axis, len(input.Shape()))
}

out, err := input.Max(axes...)
Expand Down
8 changes: 1 addition & 7 deletions ops/opset13/reduce_min.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,7 @@ func (r *ReduceMin) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {

axes := make([]int, len(r.axes))
for i, axis := range r.axes {
// Convert negative dimensions to positive dimensions as Go does not support
// negative dimension indexing like Python does.
if axis < 0 {
axis = len(input.Shape()) + axis
}

axes[i] = axis
axes[i] = ops.ConvertNegativeAxis(axis, len(input.Shape()))
}

out, err := input.Min(axes...)
Expand Down
17 changes: 17 additions & 0 deletions ops/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,20 @@ func PairwiseAssign(t1, t2 tensor.Tensor) (err error) {

return nil
}

// Converts a negative axis to the corresponding axis such that it can be used as index.
// For example, if axis is -1, this represents the last dimension. Go does not support
// negative indexing (as opposed to Python, on which ONNX is heavily dependent), so we
// have to convert the negative axis to the positive axis it represents, which is dependent
// on the rank (number of dimensions) of the tensor.
// Example 1: if rank is 3, and axis is -1, the corresponding positive axis is 2.
// Example 2: if rank is 4, and axis is -1, the corresponding positive axis is 3.
// Example 3: if rank is 4, and axis is -3, the corresponding positive axis is 1.
// Example 4: if rank is 3, and axis is 2, the function does nothing.
func ConvertNegativeAxis(axis, rank int) int {
if axis < 0 {
axis = rank + axis
}

return axis
}
17 changes: 17 additions & 0 deletions ops_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,15 @@ var ignoredTests = []string{
"test_prelu_broadcast_expanded", // Unsupported operator CastLike
"test_prelu_example_expanded", // Unsupported operator CastLike
"test_constant_pad_negative_axes", // Unsupported operator Pad

"test_argmax_keepdims_random_select_last_index", // Unsupported attribute
"test_argmax_keepdims_example_select_last_index", // Unsupported attribute
"test_argmax_no_keepdims_example_select_last_index", // Unsupported attribute
"test_argmax_no_keepdims_random_select_last_index", // Unsupported attribute
"test_argmax_default_axis_example_select_last_index", // Unsupported attribute
"test_argmax_default_axis_random_select_last_index", // Unsupported attribute
"test_argmax_negative_axis_keepdims_example_select_last_index", // Unsupported attribute
"test_argmax_negative_axis_keepdims_random_select_last_index", // Unsupported attribute
}

type ONNXTestCase struct {
Expand Down Expand Up @@ -354,6 +363,14 @@ var expectedTests = []string{
"test_and_bcast4v2d",
"test_and_bcast4v3d",
"test_and_bcast4v4d",
"test_argmax_default_axis_example",
"test_argmax_default_axis_random",
"test_argmax_keepdims_example",
"test_argmax_keepdims_random",
"test_argmax_negative_axis_keepdims_example",
"test_argmax_negative_axis_keepdims_random",
"test_argmax_no_keepdims_example",
"test_argmax_no_keepdims_random",
"test_asin",
"test_asin_example",
"test_asinh",
Expand Down
Loading