From 385babae741456b413a1eae74ec40ba97635100e Mon Sep 17 00:00:00 2001 From: Antony Ho Date: Tue, 25 Jun 2024 15:47:00 +0200 Subject: [PATCH] Add UnexpectedCallsWereMade function to support checking unexpected command call --- client_test.go | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ cluster_test.go | 30 ++++++++++++++++++++++++++++++ expect.go | 4 ++++ mock.go | 48 +++++++++++++++++++++++++++++------------------- 4 files changed, 111 insertions(+), 19 deletions(-) diff --git a/client_test.go b/client_test.go index f1774b8..650dbb7 100644 --- a/client_test.go +++ b/client_test.go @@ -38,6 +38,12 @@ var _ = Describe("Client", func() { pipe = client.TxPipeline() }) + AfterEach(func() { + hasUnexpectedCall, unexpectedCalls := clientMock.UnexpectedCallsWereMade() + Expect(hasUnexpectedCall).To(BeFalse()) + Expect(unexpectedCalls).To(BeNil()) + }) + It("tx pipeline order", func() { get := pipe.Get(ctx, "key1") hashGet := pipe.HGet(ctx, "hash_key", "hash_field") @@ -88,6 +94,12 @@ var _ = Describe("Client", func() { pipe = client.Pipeline() }) + AfterEach(func() { + hasUnexpectedCall, unexpectedCalls := clientMock.UnexpectedCallsWereMade() + Expect(hasUnexpectedCall).To(BeFalse()) + Expect(unexpectedCalls).To(BeNil()) + }) + It("pipeline order", func() { clientMock.MatchExpectationsInOrder(true) @@ -136,6 +148,12 @@ var _ = Describe("Client", func() { clientMock.ExpectSet("key2", "2", 1*time.Second).SetVal("OK") }) + AfterEach(func() { + hasUnexpectedCall, unexpectedCalls := clientMock.UnexpectedCallsWereMade() + Expect(hasUnexpectedCall).To(BeTrue()) + Expect(unexpectedCalls).ShouldNot(BeNil()) + }) + It("watch error", func() { clientMock.MatchExpectationsInOrder(false) txf := func(tx *redis.Tx) error { @@ -220,6 +238,10 @@ var _ = Describe("Client", func() { getSet := client.GetSet(ctx, "key", "0") Expect(getSet.Err()).NotTo(HaveOccurred()) Expect(getSet.Val()).To(Equal("1")) + + hasUnexpectedCall, unexpectedCalls := clientMock.UnexpectedCallsWereMade() + Expect(hasUnexpectedCall).To(BeFalse()) + Expect(unexpectedCalls).To(BeNil()) }) It("surplus", func() { @@ -234,6 +256,10 @@ var _ = Describe("Client", func() { _ = client.Get(ctx, "key") Expect(clientMock.ExpectationsWereMet()).To(HaveOccurred()) + hasUnexpectedCall, unexpectedCalls := clientMock.UnexpectedCallsWereMade() + Expect(hasUnexpectedCall).To(BeFalse()) + Expect(unexpectedCalls).To(BeNil()) + _ = client.GetSet(ctx, "key", "0") }) @@ -247,6 +273,10 @@ var _ = Describe("Client", func() { get := client.HGet(ctx, "key", "field") Expect(get.Err()).To(HaveOccurred()) Expect(get.Val()).To(Equal("")) + + hasUnexpectedCall, unexpectedCalls := clientMock.UnexpectedCallsWereMade() + Expect(hasUnexpectedCall).To(BeTrue()) + Expect(unexpectedCalls).NotTo(BeNil()) }) }) @@ -260,6 +290,12 @@ var _ = Describe("Client", func() { clientMock.ExpectGetSet("key", "0").SetVal("1") }) + AfterEach(func() { + hasUnexpectedCall, unexpectedCalls := clientMock.UnexpectedCallsWereMade() + Expect(hasUnexpectedCall).To(BeFalse()) + Expect(unexpectedCalls).To(BeNil()) + }) + It("ordinary", func() { get := client.Get(ctx, "key") Expect(get.Err()).NotTo(HaveOccurred()) @@ -277,6 +313,12 @@ var _ = Describe("Client", func() { Describe("work other match", func() { + AfterEach(func() { + hasUnexpectedCall, unexpectedCalls := clientMock.UnexpectedCallsWereMade() + Expect(hasUnexpectedCall).To(BeFalse()) + Expect(unexpectedCalls).To(BeNil()) + }) + It("regexp match", func() { clientMock.Regexp().ExpectSet("key", `^order_id_[0-9]{10}$`, 1*time.Second).SetVal("OK") clientMock.Regexp().ExpectSet("key2", `^order_id_[0-9]{4}\-[0-9]{2}\-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}.+$`, 1*time.Second).SetVal("OK") @@ -320,6 +362,12 @@ var _ = Describe("Client", func() { Describe("work error", func() { + AfterEach(func() { + hasUnexpectedCall, unexpectedCalls := clientMock.UnexpectedCallsWereMade() + Expect(hasUnexpectedCall).To(BeFalse()) + Expect(unexpectedCalls).To(BeNil()) + }) + It("set error", func() { clientMock.ExpectGet("key").SetErr(errors.New("set error")) diff --git a/cluster_test.go b/cluster_test.go index dfc1afb..8449cb3 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -49,6 +49,10 @@ var _ = Describe("Cluster", func() { getSet := client.GetSet(ctx, "key", "0") Expect(getSet.Err()).NotTo(HaveOccurred()) Expect(getSet.Val()).To(Equal("1")) + + hasUnexpectedCall, unexpectedCalls := clusterMock.UnexpectedCallsWereMade() + Expect(hasUnexpectedCall).To(BeFalse()) + Expect(unexpectedCalls).To(BeNil()) }) It("surplus", func() { @@ -63,6 +67,10 @@ var _ = Describe("Cluster", func() { _ = client.Get(ctx, "key") Expect(clusterMock.ExpectationsWereMet()).To(HaveOccurred()) + hasUnexpectedCall, unexpectedCalls := clusterMock.UnexpectedCallsWereMade() + Expect(hasUnexpectedCall).To(BeFalse()) + Expect(unexpectedCalls).To(BeNil()) + _ = client.GetSet(ctx, "key", "0") }) @@ -76,6 +84,10 @@ var _ = Describe("Cluster", func() { get := client.HGet(ctx, "key", "field") Expect(get.Err()).To(HaveOccurred()) Expect(get.Val()).To(Equal("")) + + hasUnexpectedCall, unexpectedCalls := clusterMock.UnexpectedCallsWereMade() + Expect(hasUnexpectedCall).To(BeTrue()) + Expect(unexpectedCalls).NotTo(BeNil()) }) }) @@ -89,6 +101,12 @@ var _ = Describe("Cluster", func() { clusterMock.ExpectGetSet("key", "0").SetVal("1") }) + AfterEach(func() { + hasUnexpectedCall, unexpectedCalls := clusterMock.UnexpectedCallsWereMade() + Expect(hasUnexpectedCall).To(BeFalse()) + Expect(unexpectedCalls).To(BeNil()) + }) + It("ordinary", func() { get := client.Get(ctx, "key") Expect(get.Err()).NotTo(HaveOccurred()) @@ -106,6 +124,12 @@ var _ = Describe("Cluster", func() { Describe("work other match", func() { + AfterEach(func() { + hasUnexpectedCall, unexpectedCalls := clusterMock.UnexpectedCallsWereMade() + Expect(hasUnexpectedCall).To(BeFalse()) + Expect(unexpectedCalls).To(BeNil()) + }) + It("regexp match", func() { clusterMock.Regexp().ExpectSet("key", `^order_id_[0-9]{10}$`, 1*time.Second).SetVal("OK") clusterMock.Regexp().ExpectSet("key2", `^order_id_[0-9]{4}\-[0-9]{2}\-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}.+$`, 1*time.Second).SetVal("OK") @@ -149,6 +173,12 @@ var _ = Describe("Cluster", func() { Describe("work error", func() { + AfterEach(func() { + hasUnexpectedCall, unexpectedCalls := clusterMock.UnexpectedCallsWereMade() + Expect(hasUnexpectedCall).To(BeFalse()) + Expect(unexpectedCalls).To(BeNil()) + }) + It("set error", func() { clusterMock.ExpectGet("key").SetErr(errors.New("set error")) diff --git a/expect.go b/expect.go index 8e177cb..c7c5b97 100644 --- a/expect.go +++ b/expect.go @@ -24,6 +24,10 @@ type baseMock interface { // were met in order. If any of them was not met - an error is returned. ExpectationsWereMet() error + // UnexpectedCallsWereMade returns any unexpected calls which were made. + // If any unexpected call was made, a list of unexpected call redis.Cmder is returned. + UnexpectedCallsWereMade() (bool, []redis.Cmder) + // MatchExpectationsInOrder gives an option whether to match all expectations in the order they were set or not. MatchExpectationsInOrder(b bool) diff --git a/mock.go b/mock.go index 9b862fd..3c6a3b9 100644 --- a/mock.go +++ b/mock.go @@ -23,9 +23,10 @@ type mock struct { parent *mock - factory mockCmdable - client redis.Cmdable - expected []expectation + factory mockCmdable + client redis.Cmdable + expected []expectation + unexpected []redis.Cmder strictOrder bool @@ -180,6 +181,7 @@ func (m *mock) process(cmd redis.Cmder) (err error) { } err = fmt.Errorf(msg, cmd.Args()) cmd.SetErr(err) + m.unexpected = append(m.unexpected, cmd) return err } @@ -362,6 +364,7 @@ func (m *mock) ClearExpect() { return } m.expected = nil + m.unexpected = nil } func (m *mock) Regexp() *mock { @@ -402,6 +405,13 @@ func (m *mock) ExpectationsWereMet() error { return nil } +func (m *mock) UnexpectedCallsWereMade() (bool, []redis.Cmder) { + if m.parent != nil { + return m.parent.UnexpectedCallsWereMade() + } + return len(m.unexpected) > 0, m.unexpected +} + func (m *mock) MatchExpectationsInOrder(b bool) { if m.parent != nil { m.MatchExpectationsInOrder(b) @@ -2862,29 +2872,29 @@ func (m *mock) ExpectTSMRangeWithArgs(fromTimestamp int, toTimestamp int, filter } func (m *mock) ExpectTSMRevRange(fromTimestamp int, toTimestamp int, filterExpr []string) *ExpectedMapStringSliceInterface { - e := &ExpectedMapStringSliceInterface{} - e.cmd = m.factory.TSMRevRange(m.ctx, fromTimestamp, toTimestamp, filterExpr) - m.pushExpect(e) - return e + e := &ExpectedMapStringSliceInterface{} + e.cmd = m.factory.TSMRevRange(m.ctx, fromTimestamp, toTimestamp, filterExpr) + m.pushExpect(e) + return e } func (m *mock) ExpectTSMRevRangeWithArgs(fromTimestamp int, toTimestamp int, filterExpr []string, options *redis.TSMRevRangeOptions) *ExpectedMapStringSliceInterface { - e := &ExpectedMapStringSliceInterface{} - e.cmd = m.factory.TSMRevRangeWithArgs(m.ctx, fromTimestamp, toTimestamp, filterExpr, options) - m.pushExpect(e) - return e + e := &ExpectedMapStringSliceInterface{} + e.cmd = m.factory.TSMRevRangeWithArgs(m.ctx, fromTimestamp, toTimestamp, filterExpr, options) + m.pushExpect(e) + return e } func (m *mock) ExpectTSMGet(filters []string) *ExpectedMapStringSliceInterface { - e := &ExpectedMapStringSliceInterface{} - e.cmd = m.factory.TSMGet(m.ctx, filters) - m.pushExpect(e) - return e + e := &ExpectedMapStringSliceInterface{} + e.cmd = m.factory.TSMGet(m.ctx, filters) + m.pushExpect(e) + return e } func (m *mock) ExpectTSMGetWithArgs(filters []string, options *redis.TSMGetOptions) *ExpectedMapStringSliceInterface { - e := &ExpectedMapStringSliceInterface{} - e.cmd = m.factory.TSMGetWithArgs(m.ctx, filters, options) - m.pushExpect(e) - return e + e := &ExpectedMapStringSliceInterface{} + e.cmd = m.factory.TSMGetWithArgs(m.ctx, filters, options) + m.pushExpect(e) + return e }