Skip to content

Commit

Permalink
Add CallWrapper interface for type-safe calls
Browse files Browse the repository at this point in the history
Add a new exported interface `CallWrapper` which allow users to use
`InOrder` and `After` with generated type-safe mock types.
  • Loading branch information
EstebanOlmedo committed Sep 11, 2023
1 parent 9b18c60 commit afa3d42
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 10 deletions.
13 changes: 9 additions & 4 deletions gomock/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ type Call struct {
actions []func([]any) []any
}

// CallWrapper is an interface for retrieving a *Call.
type CallWrapper interface {
GetCall() *Call
}

// newCall creates a *Call. It requires the method type in order to support
// unexported methods.
func newCall(t TestHelper, receiver any, method string, methodType reflect.Type, args ...any) *Call {
Expand Down Expand Up @@ -79,8 +84,8 @@ func newCall(t TestHelper, receiver any, method string, methodType reflect.Type,
args: mArgs, origin: origin, minCalls: 1, maxCalls: 1, actions: actions}
}

// GetCall returns the current `*Call` instance, this is needed to fulfill the
// interface that `InOrder` and `After` receive as parameter.
// GetCall returns the current `*Call` instance, this is needed to implement
// the CallWrapper interface.
func (c *Call) GetCall() *Call {
return c
}
Expand Down Expand Up @@ -294,7 +299,7 @@ func (c *Call) isPreReq(other *Call) bool {
}

// After declares that the call may only match after preReq has been exhausted.
func (c *Call) After(prq interface{GetCall() *Call}) *Call {
func (c *Call) After(prq CallWrapper) *Call {
preReq := prq.GetCall()
c.t.Helper()

Expand Down Expand Up @@ -442,7 +447,7 @@ func (c *Call) call() []func([]any) []any {
}

// InOrder declares that the given calls should occur in order.
func InOrder(calls ...interface{GetCall() *Call}) {
func InOrder(calls ...CallWrapper) {
for i := 1; i < len(calls); i++ {
calls[i].GetCall().After(calls[i-1])
}
Expand Down
8 changes: 4 additions & 4 deletions mockgen/internal/tests/typed_after_in_order/mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions mockgen/mockgen.go
Original file line number Diff line number Diff line change
Expand Up @@ -700,15 +700,15 @@ func (g *generator) GenerateMockReturnCallMethod(intf *model.Interface, m *model
g.out()
g.p("}")

g.p("// Call rewrite *gomock.Call.GetCall")
g.p("// GetCall is needed to implement gomock.CallWrapper")
g.p("func (%s *%sCall%s) GetCall() *gomock.Call {", idRecv, recvStructName, shortTp)
g.in()
g.p("return %s.Call", idRecv)
g.out()
g.p("}")

g.p("// After rewrite *gomock.Call.After")
g.p("func (%s *%sCall%s) After(prq interface{ GetCall() *gomock.Call }) *gomock.Call {", idRecv, recvStructName, shortTp)
g.p("func (%s *%sCall%s) After(prq gomock.CallWrapper) *gomock.Call {", idRecv, recvStructName, shortTp)
g.in()
g.p("return %s.Call.After(prq)", idRecv)
g.out()
Expand Down

0 comments on commit afa3d42

Please sign in to comment.