diff --git a/command_check.go b/command_check.go index 8d3672e4..6052fc72 100644 --- a/command_check.go +++ b/command_check.go @@ -21,6 +21,12 @@ import ( "github.com/looplab/eventhorizon/uuid" ) +// IsZeroer is used to check if a type is zero-valued, and in that case +// is not allowed to be used in a command. See CheckCommand +type IsZeroer interface { + IsZero() bool +} + // CommandFieldError is returned by Dispatch when a field is incorrect. type CommandFieldError struct { Field string @@ -47,7 +53,15 @@ func CheckCommand(cmd Command) error { continue // Optional field. } - if isZero(rv.Field(i)) { + var zero bool + switch foo := rv.Field(i).Interface().(type) { + case IsZeroer: + zero = foo.IsZero() + default: + zero = isZero(rv.Field(i)) + } + + if zero { return CommandFieldError{field.Name} } } diff --git a/command_check_test.go b/command_check_test.go index d3a27708..d18cf6b8 100644 --- a/command_check_test.go +++ b/command_check_test.go @@ -111,6 +111,26 @@ func TestCheckCommand(t *testing.T) { if err == nil || err.Error() != "missing field: StringArray" { t.Error("there should be a missing field error:", err) } + + // IsZero, fail on zeroable int + err = CheckCommand(&TestCommandZeroableInt{ + TestID: uuid.New(), + TestZeroableInt: 0, + TestInt: 0, + }) + if err == nil || err.Error() != "missing field: TestZeroableInt"{ + t.Error("there should be a missing field error:", err) + } + + // IsZero, do not fail on plain int + err = CheckCommand(&TestCommandZeroableInt{ + TestID: uuid.New(), + TestZeroableInt: 1, + TestInt: 0, + }) + if err != nil { + t.Error("there should not be an error:", err) + } } type TestCommandFields struct { @@ -288,3 +308,24 @@ var _ = Command(TestCommandArray{}) func (t TestCommandArray) AggregateID() uuid.UUID { return t.TestID } func (t TestCommandArray) AggregateType() AggregateType { return AggregateType("Test") } func (t TestCommandArray) CommandType() CommandType { return CommandType("TestCommandArray") } + +type ZeroableInt int + +type TestCommandZeroableInt struct { + TestID uuid.UUID + TestZeroableInt ZeroableInt + TestInt int +} + +var _ = Command(TestCommandZeroableInt{}) + + +func (t TestCommandZeroableInt) AggregateID() uuid.UUID { return t.TestID } +func (t TestCommandZeroableInt) AggregateType() AggregateType { return TestAggregateType } +func (t TestCommandZeroableInt) CommandType() CommandType { + return CommandType("TestCommandZeroableInt") +} + +func (z ZeroableInt) IsZero () bool { + return z == 0 +}