diff --git a/handler.go b/handler.go index 783823a..aec375f 100644 --- a/handler.go +++ b/handler.go @@ -104,7 +104,11 @@ func (h *reflectFunc) fnArgs(msg *Message) ([]reflect.Value, error) { inType := h.ft.In(inStart + i) if inType.Kind() == reflect.Interface { - if !v.Type().Implements(inType) { + // If the input value is an untyped nil, simply create a new + // typed nil so that the subsequent .Call() works + if !v.IsValid() { + v = reflect.Zero(inType) + } else if !v.Type().Implements(inType) { hasWrongType = true break } diff --git a/memqueue/memqueue_test.go b/memqueue/memqueue_test.go index 5b3876e..01bc8c0 100644 --- a/memqueue/memqueue_test.go +++ b/memqueue/memqueue_test.go @@ -62,6 +62,51 @@ var _ = Describe("message with args", func() { }) }) +type testInterface interface { + TestFunction() int +} + +type testInterfaceImpl struct { + x int +} + +func (t *testInterfaceImpl) TestFunction() int { + return t.x +} + +var _ = Describe("message with interface args", func() { + ctx := context.Background() + ch := make(chan bool, 10) + + BeforeEach(func() { + q := memqueue.NewQueue(&taskq.QueueOptions{ + Name: "test", + Storage: taskq.NewLocalStorage(), + }) + expected := 7 + task := taskq.RegisterTask(&taskq.TaskOptions{ + Name: "test", + Handler: func(notNilArg testInterface, nilArg testInterface) { + Expect(notNilArg).ToNot(BeNil()) + Expect(nilArg).To(BeNil()) + Expect(notNilArg.TestFunction()).To(Equal(expected)) + ch <- true + }, + }) + notNilInput := testInterface(&testInterfaceImpl{x: expected}) + err := q.Add(task.WithArgs(ctx, notNilInput, nil)) + Expect(err).NotTo(HaveOccurred()) + + err = q.Close() + Expect(err).NotTo(HaveOccurred()) + }) + + It("handler is called with args", func() { + Expect(ch).To(Receive()) + Expect(ch).NotTo(Receive()) + }) +}) + var _ = Describe("context.Context", func() { ctx := context.Background() ch := make(chan bool, 10)