Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add "must" functions for tersely asserting function success #2

Merged
merged 2 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,56 @@ func ErrorIsAll(expected ...error) ErrorCheck {
}
}
}

// Must can be used on a (value, error) pair to either get the value or
// immediately fail the test if the error is non-nil. The T parameter is
// curried, rather than passed as a third argument, so that (value, error)
// function return values can be passed to Must directly, without assigning them
// to intermediate variables.
//
// See also Must0, Must2, and Must3 for working with functions of other coarity.
//
// bytes := expect.Must(io.ReadAll(reader))(t)
func Must[V any](value V, err error) func(T) V {
return func(t T) V {
t.Helper()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
return value
}
}

// Must0 is similar to Must but for functions returning just an error, without a
// value.
func Must0(err error) func(T) {
Comment on lines +156 to +158
Copy link
Contributor Author

@jonathansharman jonathansharman Oct 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function wouldn't have to be curried since a unary function call can be passed to another function directly even when there are other parameters. However, I opted to follow the pattern of the other functions for consistency.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely took me a minute to really understand why we're returning an unary function of T. But now I get it. I kinda long for a type-safe arg spread operator. this is a reasonably un-awkward solution.

return func(t T) {
t.Helper()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
}
}

// Must2 is similar to Must but for functions returning two values and an error.
func Must2[V1 any, V2 any](value1 V1, value2 V2, err error) func(T) (V1, V2) {
return func(t T) (V1, V2) {
t.Helper()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
return value1, value2
}
}

// Must3 is similar to Must but for functions returning three values and an
// error.
func Must3[V1 any, V2 any, V3 any](value1 V1, value2 V2, value3 V3, err error) func(T) (V1, V2, V3) {
Comment on lines +178 to +180
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to draw the line somewhere, and I stopped at three return values (plus the error). Anything over one or maybe two non-error return values is arguably already excessive.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems reasonable - though we could go-generate MustN if we had to

return func(t T) (V1, V2, V3) {
t.Helper()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
return value1, value2, value3
}
}
141 changes: 141 additions & 0 deletions errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,146 @@ func TestErrors(t *testing.T) {
}
}

func TestMust(t *testing.T) {
type testCase struct {
f func() (bool, error)
expectedValue bool
expectedFatalCalls int32
}
run := func(name string, testCase testCase) {
t.Helper()
t.Run(name, func(t *testing.T) {
t.Helper()
tMock := newTMock()
value := Must(testCase.f())(tMock)
Equal(t, value, testCase.expectedValue)
Equal(t, testCase.expectedFatalCalls, tMock.FatalfCalled)
})
}

run("Error", testCase{
f: func() (bool, error) {
return false, ErrTest
},
expectedValue: false,
expectedFatalCalls: 1,
})
run("Success", testCase{
f: func() (bool, error) {
return true, nil
},
expectedValue: true,
expectedFatalCalls: 0,
})
}

func TestMust0(t *testing.T) {
type testCase struct {
f func() error
expectedFatalCalls int32
}
run := func(name string, testCase testCase) {
t.Helper()
t.Run(name, func(t *testing.T) {
t.Helper()
tMock := newTMock()
Must0(testCase.f())(tMock)
Equal(t, testCase.expectedFatalCalls, tMock.FatalfCalled)
})
}

run("Error", testCase{
f: func() error {
return ErrTest
},
expectedFatalCalls: 1,
})
run("Success", testCase{
f: func() error {
return nil
},
expectedFatalCalls: 0,
})
}

func TestMust2(t *testing.T) {
type testCase struct {
f func() (bool, bool, error)
expectedValue1 bool
expectedValue2 bool
expectedFatalCalls int32
}
run := func(name string, testCase testCase) {
t.Helper()
t.Run(name, func(t *testing.T) {
t.Helper()
tMock := newTMock()
value1, value2 := Must2(testCase.f())(tMock)
Equal(t, value1, testCase.expectedValue1)
Equal(t, value2, testCase.expectedValue2)
Equal(t, testCase.expectedFatalCalls, tMock.FatalfCalled)
})
}

run("Error", testCase{
f: func() (bool, bool, error) {
return false, false, ErrTest
},
expectedValue1: false,
expectedValue2: false,
expectedFatalCalls: 1,
})
run("Success", testCase{
f: func() (bool, bool, error) {
return true, true, nil
},
expectedValue1: true,
expectedValue2: true,
expectedFatalCalls: 0,
})
}

func TestMust3(t *testing.T) {
type testCase struct {
f func() (bool, bool, bool, error)
expectedValue1 bool
expectedValue2 bool
expectedValue3 bool
expectedFatalCalls int32
}
run := func(name string, testCase testCase) {
t.Helper()
t.Run(name, func(t *testing.T) {
t.Helper()
tMock := newTMock()
value1, value2, value3 := Must3(testCase.f())(tMock)
Equal(t, value1, testCase.expectedValue1)
Equal(t, value2, testCase.expectedValue2)
Equal(t, value3, testCase.expectedValue3)
Equal(t, testCase.expectedFatalCalls, tMock.FatalfCalled)
})
}

run("Error", testCase{
f: func() (bool, bool, bool, error) {
return false, false, false, ErrTest
},
expectedValue1: false,
expectedValue2: false,
expectedValue3: false,
expectedFatalCalls: 1,
})
run("Success", testCase{
f: func() (bool, bool, bool, error) {
return true, true, true, nil
},
expectedValue1: true,
expectedValue2: true,
expectedValue3: true,
expectedFatalCalls: 0,
})
}

type testErrorA struct{}

func (e testErrorA) Error() string {
Expand Down Expand Up @@ -199,5 +339,6 @@ func newTMock() *TMock {
return &TMock{
HelperStub: func() {},
ErrorfStub: func(format string, args ...any) {},
FatalfStub: func(format string, args ...any) {},
}
}
1 change: 1 addition & 0 deletions testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ package expect
type T interface {
Helper()
Errorf(format string, args ...any)
Fatalf(format string, args ...any)
}
30 changes: 30 additions & 0 deletions testing_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,59 @@ package expect

import (
"sync/atomic"
"testing"
)

// TMock is a mock implementation of the T
// interface.
type TMock struct {
T *testing.T
HelperStub func()
HelperCalled int32
ErrorfStub func(format string, args ...any)
ErrorfCalled int32
FatalfStub func(format string, args ...any)
FatalfCalled int32
}

// Verify that *TMock implements T.
var _ T = &TMock{}

// Helper is a stub for the T.Helper
// method that records the number of times it has been called.
func (m *TMock) Helper() {
atomic.AddInt32(&m.HelperCalled, 1)
if m.HelperStub == nil {
if m.T != nil {
m.T.Error("HelperStub is nil")
}
panic("Helper unimplemented")
}
m.HelperStub()
}

// Errorf is a stub for the T.Errorf
// method that records the number of times it has been called.
func (m *TMock) Errorf(format string, args ...any) {
atomic.AddInt32(&m.ErrorfCalled, 1)
if m.ErrorfStub == nil {
if m.T != nil {
m.T.Error("ErrorfStub is nil")
}
panic("Errorf unimplemented")
}
m.ErrorfStub(format, args...)
}

// Fatalf is a stub for the T.Fatalf
// method that records the number of times it has been called.
func (m *TMock) Fatalf(format string, args ...any) {
atomic.AddInt32(&m.FatalfCalled, 1)
if m.FatalfStub == nil {
if m.T != nil {
m.T.Error("FatalfStub is nil")
}
panic("Fatalf unimplemented")
}
m.FatalfStub(format, args...)
}
Loading