diff --git a/README.md b/README.md index 0059fd9..b71a03d 100644 --- a/README.md +++ b/README.md @@ -24,12 +24,18 @@ func main() { Namespace: "abc", PartitionID: 0, PartitionTotal: 1, - TriggerFn: func(context.Context, *api.TriggerRequest) bool { + TriggerFn: func(context.Context, *api.TriggerRequest) *api.TriggerResponse { // Do something with your trigger here. - // Return true if the trigger was successful, false otherwise. - // Note, returning false will cause the job to be retried according to - // the Jobs configurable FailurePolicy. - return true + // Return SUCCESS if the trigger was successful, FAILED if the trigger + // failed and should be subject to the FailurePolicy, or UNDELIVERABLE if + // the job is currently undeliverable and should be moved to the staging + // queue. Use `cron.DeliverablePrefixes` elsewhere to mark jobs with the + // given prefixes as now deliverable. + return &api.TriggerResponse{ + Result: api.TriggerResponseResult_SUCCESS, + // Result: api.TriggerResponseResult_FAILED, + // Result: api.TriggerResponseResult_UNDELIVERABLE, + } }, }) if err != nil { @@ -43,11 +49,14 @@ func main() { meta, _ := anypb.New(wrapperspb.String("world")) tt := time.Now().Add(time.Second).Format(time.RFC3339) - cron.Add(context.TODO(), "my-job", &api.Job{ + err = cron.Add(context.TODO(), "my-job", &api.Job{ DueTime: &tt, Payload: payload, Metadata: meta, }) + if err != nil { + panic(err) + } } ``` @@ -73,12 +82,19 @@ A Job itself is made up of the following fields: Optional. - `FailurePolicy` Controls whether the Job should be retired if the trigger function returns false. `Drop` doesn't retry the job, `Constant `Constant` will - constantly retry the job trigger for a configurable internal, up to a configurable + constantly retry the job trigger for a configurable interval, up to a configurable maximum number of retries (which could be infinite). By default, Jobs have a `Constant` policy, with a 1s interval and 3 maximum retries. A job must have *at least* either a `Schedule` or a `DueTime` set. +### Undeliverable Jobs + +It can be the case that a job trigger hasn't actually _failed_, but instead is simply undeliverable at the current time. +In such cases, the trigger function can return `UNDELIVERABLE` to indicate that the job should be moved to the "staging queue" to be held until it can be delivered. +Staged jobs can be marked as deliverable again by calling `cron.DeliverablePrefixes` with the prefixes of those job names. +Jobs whose name match these prefixes will be re-enqueued for delivery. + ## Leadership The cron scheduler uses a partition key ownership model to ensure that only one partition instance of the scheduler is running at any given time. diff --git a/api/api.go b/api/api.go index 4b6c353..6e1cc9b 100644 --- a/api/api.go +++ b/api/api.go @@ -11,10 +11,10 @@ import ( // TriggerFunction is the type of the function that is called when a job is // triggered. -// Returning true will "tick" the job forward to the next scheduled time. -// Returning false will cause the job to be re-enqueued and triggered -// immediately. -type TriggerFunction func(context.Context, *TriggerRequest) bool +// The returne TriggerResponse will indicate whether the Job was successfully +// triggered, the trigger failed, or the Job need to be put into the staging +// queue. +type TriggerFunction func(context.Context, *TriggerRequest) *TriggerResponse // API is the interface for interacting with the cron instance. type API interface { @@ -33,6 +33,17 @@ type API interface { // List lists all jobs under a given job name prefix. List(ctx context.Context, prefix string) (*ListResponse, error) + + // DeliverablePrefixes registers the given Job name prefixes as being + // deliverable. Any Jobs that reside in the staging queue because they were + // undeliverable at the time of trigger but whose names match these prefixes + // will be immediately re-triggered. + // The returned CancelFunc should be called to unregister the prefixes, + // meaning these prefixes are no longer delivable by the caller. Duplicate + // Prefixes may be called together and will be pooled together, meaning that + // the prefix is still active if there is at least one DeliverablePrefixes + // call that has not been unregistered. + DeliverablePrefixes(ctx context.Context, prefixes ...string) (context.CancelFunc, error) } // Interface is a cron interface. It schedules and manages job which are stored diff --git a/api/trigger.pb.go b/api/trigger.pb.go index ea806fd..a83b56d 100644 --- a/api/trigger.pb.go +++ b/api/trigger.pb.go @@ -25,6 +25,64 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) +// TriggerResponseResult is indicates the state result from triggering the job +// by the consumer. +type TriggerResponseResult int32 + +const ( + // SUCCESS indicates that the job was successfully triggered and will be + // ticked forward according to the schedule. + TriggerResponseResult_SUCCESS TriggerResponseResult = 0 + // FAILED indicates that the job failed to trigger and is subject to the + // FailurePolicy. + TriggerResponseResult_FAILED TriggerResponseResult = 1 + // UNDELIVERABLE indicates that the job should be added to the staging queue + // as the Job was undeliverable. Once the Job name prefix is marked as + // deliverable, it will be immediately triggered. + TriggerResponseResult_UNDELIVERABLE TriggerResponseResult = 2 +) + +// Enum value maps for TriggerResponseResult. +var ( + TriggerResponseResult_name = map[int32]string{ + 0: "SUCCESS", + 1: "FAILED", + 2: "UNDELIVERABLE", + } + TriggerResponseResult_value = map[string]int32{ + "SUCCESS": 0, + "FAILED": 1, + "UNDELIVERABLE": 2, + } +) + +func (x TriggerResponseResult) Enum() *TriggerResponseResult { + p := new(TriggerResponseResult) + *p = x + return p +} + +func (x TriggerResponseResult) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (TriggerResponseResult) Descriptor() protoreflect.EnumDescriptor { + return file_proto_api_trigger_proto_enumTypes[0].Descriptor() +} + +func (TriggerResponseResult) Type() protoreflect.EnumType { + return &file_proto_api_trigger_proto_enumTypes[0] +} + +func (x TriggerResponseResult) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use TriggerResponseResult.Descriptor instead. +func (TriggerResponseResult) EnumDescriptor() ([]byte, []int) { + return file_proto_api_trigger_proto_rawDescGZIP(), []int{0} +} + // TriggerRequest is the request sent to the caller when a job is triggered. type TriggerRequest struct { state protoimpl.MessageState @@ -92,6 +150,57 @@ func (x *TriggerRequest) GetPayload() *anypb.Any { return nil } +// TriggerResponse is returned by the caller from a TriggerResponse. Signals +// whether the Job was successfully triggered, the trigger failed, or instead +// needs to be added to the staging queue due to impossible delivery. +type TriggerResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // result is the result given by the consumer when trigging the Job. + Result TriggerResponseResult `protobuf:"varint,1,opt,name=result,proto3,enum=api.TriggerResponseResult" json:"result,omitempty"` +} + +func (x *TriggerResponse) Reset() { + *x = TriggerResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_proto_api_trigger_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *TriggerResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TriggerResponse) ProtoMessage() {} + +func (x *TriggerResponse) ProtoReflect() protoreflect.Message { + mi := &file_proto_api_trigger_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TriggerResponse.ProtoReflect.Descriptor instead. +func (*TriggerResponse) Descriptor() ([]byte, []int) { + return file_proto_api_trigger_proto_rawDescGZIP(), []int{1} +} + +func (x *TriggerResponse) GetResult() TriggerResponseResult { + if x != nil { + return x.Result + } + return TriggerResponseResult_SUCCESS +} + var File_proto_api_trigger_proto protoreflect.FileDescriptor var file_proto_api_trigger_proto_rawDesc = []byte{ @@ -107,10 +216,18 @@ var file_proto_api_trigger_proto_rawDesc = []byte{ 0x74, 0x61, 0x12, 0x2e, 0x0a, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x41, 0x6e, 0x79, 0x52, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, - 0x61, 0x64, 0x42, 0x27, 0x5a, 0x25, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, - 0x2f, 0x64, 0x69, 0x61, 0x67, 0x72, 0x69, 0x64, 0x69, 0x6f, 0x2f, 0x67, 0x6f, 0x2d, 0x65, 0x74, - 0x63, 0x64, 0x2d, 0x63, 0x72, 0x6f, 0x6e, 0x2f, 0x61, 0x70, 0x69, 0x62, 0x06, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x33, + 0x61, 0x64, 0x22, 0x45, 0x0a, 0x0f, 0x54, 0x72, 0x69, 0x67, 0x67, 0x65, 0x72, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x32, 0x0a, 0x06, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1a, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x54, 0x72, 0x69, 0x67, + 0x67, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x52, 0x65, 0x73, 0x75, 0x6c, + 0x74, 0x52, 0x06, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x2a, 0x43, 0x0a, 0x15, 0x54, 0x72, 0x69, + 0x67, 0x67, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x52, 0x65, 0x73, 0x75, + 0x6c, 0x74, 0x12, 0x0b, 0x0a, 0x07, 0x53, 0x55, 0x43, 0x43, 0x45, 0x53, 0x53, 0x10, 0x00, 0x12, + 0x0a, 0x0a, 0x06, 0x46, 0x41, 0x49, 0x4c, 0x45, 0x44, 0x10, 0x01, 0x12, 0x11, 0x0a, 0x0d, 0x55, + 0x4e, 0x44, 0x45, 0x4c, 0x49, 0x56, 0x45, 0x52, 0x41, 0x42, 0x4c, 0x45, 0x10, 0x02, 0x42, 0x27, + 0x5a, 0x25, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x64, 0x69, 0x61, + 0x67, 0x72, 0x69, 0x64, 0x69, 0x6f, 0x2f, 0x67, 0x6f, 0x2d, 0x65, 0x74, 0x63, 0x64, 0x2d, 0x63, + 0x72, 0x6f, 0x6e, 0x2f, 0x61, 0x70, 0x69, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -125,19 +242,23 @@ func file_proto_api_trigger_proto_rawDescGZIP() []byte { return file_proto_api_trigger_proto_rawDescData } -var file_proto_api_trigger_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_proto_api_trigger_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_proto_api_trigger_proto_msgTypes = make([]protoimpl.MessageInfo, 2) var file_proto_api_trigger_proto_goTypes = []interface{}{ - (*TriggerRequest)(nil), // 0: api.TriggerRequest - (*anypb.Any)(nil), // 1: google.protobuf.Any + (TriggerResponseResult)(0), // 0: api.TriggerResponseResult + (*TriggerRequest)(nil), // 1: api.TriggerRequest + (*TriggerResponse)(nil), // 2: api.TriggerResponse + (*anypb.Any)(nil), // 3: google.protobuf.Any } var file_proto_api_trigger_proto_depIdxs = []int32{ - 1, // 0: api.TriggerRequest.metadata:type_name -> google.protobuf.Any - 1, // 1: api.TriggerRequest.payload:type_name -> google.protobuf.Any - 2, // [2:2] is the sub-list for method output_type - 2, // [2:2] is the sub-list for method input_type - 2, // [2:2] is the sub-list for extension type_name - 2, // [2:2] is the sub-list for extension extendee - 0, // [0:2] is the sub-list for field type_name + 3, // 0: api.TriggerRequest.metadata:type_name -> google.protobuf.Any + 3, // 1: api.TriggerRequest.payload:type_name -> google.protobuf.Any + 0, // 2: api.TriggerResponse.result:type_name -> api.TriggerResponseResult + 3, // [3:3] is the sub-list for method output_type + 3, // [3:3] is the sub-list for method input_type + 3, // [3:3] is the sub-list for extension type_name + 3, // [3:3] is the sub-list for extension extendee + 0, // [0:3] is the sub-list for field type_name } func init() { file_proto_api_trigger_proto_init() } @@ -158,19 +279,32 @@ func file_proto_api_trigger_proto_init() { return nil } } + file_proto_api_trigger_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*TriggerResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_proto_api_trigger_proto_rawDesc, - NumEnums: 0, - NumMessages: 1, + NumEnums: 1, + NumMessages: 2, NumExtensions: 0, NumServices: 0, }, GoTypes: file_proto_api_trigger_proto_goTypes, DependencyIndexes: file_proto_api_trigger_proto_depIdxs, + EnumInfos: file_proto_api_trigger_proto_enumTypes, MessageInfos: file_proto_api_trigger_proto_msgTypes, }.Build() File_proto_api_trigger_proto = out.File diff --git a/cron/cron_test.go b/cron/cron_test.go index 7841cd7..75ff376 100644 --- a/cron/cron_test.go +++ b/cron/cron_test.go @@ -7,205 +7,24 @@ package cron import ( "context" - "strconv" - "sync" "sync/atomic" "testing" "time" - "github.com/dapr/kit/ptr" "github.com/go-logr/logr" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - clientv3 "go.etcd.io/etcd/client/v3" - "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/types/known/anypb" - "google.golang.org/protobuf/types/known/durationpb" - "google.golang.org/protobuf/types/known/timestamppb" - "google.golang.org/protobuf/types/known/wrapperspb" "github.com/diagridio/go-etcd-cron/api" - "github.com/diagridio/go-etcd-cron/internal/api/stored" - "github.com/diagridio/go-etcd-cron/internal/client" - "github.com/diagridio/go-etcd-cron/tests" + "github.com/diagridio/go-etcd-cron/tests/framework/etcd" ) -func Test_retry(t *testing.T) { - t.Parallel() - - var ok bool - var lock sync.Mutex - helper := testCronWithOptions(t, testCronOptions{ - total: 1, - triggerFn: func(*api.TriggerRequest) bool { - lock.Lock() - defer lock.Unlock() - return ok - }, - }) - - job := &api.Job{ - DueTime: ptr.Of(time.Now().Format(time.RFC3339)), - } - require.NoError(t, helper.api.Add(helper.ctx, "yoyo", job)) - - assert.EventuallyWithT(t, func(c *assert.CollectT) { - assert.Greater(c, helper.triggered.Load(), int64(1)) - }, 5*time.Second, 10*time.Millisecond) - lock.Lock() - triggered := helper.triggered.Load() - triggered++ - ok = true - lock.Unlock() - assert.EventuallyWithT(t, func(c *assert.CollectT) { - assert.Equal(c, triggered, helper.triggered.Load()) - }, time.Second*10, time.Millisecond*10) - <-time.After(3 * time.Second) - assert.Equal(t, triggered, helper.triggered.Load()) -} - -func Test_payload(t *testing.T) { - t.Parallel() - - gotCh := make(chan *api.TriggerRequest, 1) - helper := testCronWithOptions(t, testCronOptions{ - total: 1, - gotCh: gotCh, - }) - - payload, err := anypb.New(wrapperspb.String("hello")) - require.NoError(t, err) - meta, err := anypb.New(wrapperspb.String("world")) - require.NoError(t, err) - job := &api.Job{ - DueTime: ptr.Of(time.Now().Format(time.RFC3339)), - Payload: payload, - Metadata: meta, - } - require.NoError(t, helper.api.Add(helper.ctx, "yoyo", job)) - - select { - case got := <-gotCh: - assert.Equal(t, "yoyo", got.GetName()) - var gotPayload wrapperspb.StringValue - require.NoError(t, got.GetPayload().UnmarshalTo(&gotPayload)) - assert.Equal(t, "hello", gotPayload.GetValue()) - var gotMeta wrapperspb.StringValue - require.NoError(t, got.GetMetadata().UnmarshalTo(&gotMeta)) - assert.Equal(t, "world", gotMeta.GetValue()) - case <-time.After(5 * time.Second): - t.Fatal("timeout waiting for trigger") - } -} - -func Test_remove(t *testing.T) { - t.Parallel() - - helper := testCron(t, 1) - - job := &api.Job{ - DueTime: ptr.Of(time.Now().Add(time.Second * 2).Format(time.RFC3339)), - } - require.NoError(t, helper.api.Add(helper.ctx, "def", job)) - require.NoError(t, helper.api.Delete(helper.ctx, "def")) - - <-time.After(3 * time.Second) - - assert.Equal(t, int64(0), helper.triggered.Load()) -} - -func Test_upsert(t *testing.T) { - t.Parallel() - - helper := testCron(t, 1) - - job := &api.Job{ - DueTime: ptr.Of(time.Now().Add(time.Hour).Format(time.RFC3339)), - } - require.NoError(t, helper.api.Add(helper.ctx, "def", job)) - job = &api.Job{ - DueTime: ptr.Of(time.Now().Add(time.Second).Format(time.RFC3339)), - } - require.NoError(t, helper.api.Add(helper.ctx, "def", job)) - - assert.Eventually(t, func() bool { - return helper.triggered.Load() == 1 - }, 5*time.Second, 1*time.Second) - - resp, err := helper.client.Get(context.Background(), "abc/jobs/def") - require.NoError(t, err) - assert.Empty(t, resp.Kvs) -} - -func Test_patition(t *testing.T) { - t.Parallel() - - helper := testCron(t, 100) - - for i := range 100 { - job := &api.Job{ - DueTime: ptr.Of(time.Now().Add(time.Second).Format(time.RFC3339)), - } - require.NoError(t, helper.allCrons[i].Add(helper.ctx, "test-"+strconv.Itoa(i), job)) - } - - assert.Eventually(t, func() bool { - return helper.triggered.Load() == 100 - }, 5*time.Second, 1*time.Second) - - resp, err := helper.client.Get(context.Background(), "abc/jobs", clientv3.WithPrefix()) - require.NoError(t, err) - assert.Empty(t, resp.Kvs) -} - -func Test_oneshot(t *testing.T) { - t.Parallel() - - helper := testCron(t, 1) - - job := &api.Job{ - DueTime: ptr.Of(time.Now().Add(time.Second).Format(time.RFC3339)), - } - - require.NoError(t, helper.api.Add(helper.ctx, "def", job)) - - assert.Eventually(t, func() bool { - return helper.triggered.Load() == 1 - }, 5*time.Second, 1*time.Second) - - resp, err := helper.client.Get(context.Background(), "abc/jobs/def") - require.NoError(t, err) - assert.Empty(t, resp.Kvs) -} - -func Test_repeat(t *testing.T) { - t.Parallel() - - helper := testCron(t, 1) - - job := &api.Job{ - Schedule: ptr.Of("@every 10ms"), - Repeats: ptr.Of(uint32(3)), - } - - require.NoError(t, helper.api.Add(helper.ctx, "def", job)) - - assert.EventuallyWithT(t, func(c *assert.CollectT) { - assert.Equal(c, int64(3), helper.triggered.Load()) - }, 5*time.Second, 1*time.Second) - - resp, err := helper.client.Get(context.Background(), "abc/jobs/def") - require.NoError(t, err) - assert.Empty(t, resp.Kvs) -} - func Test_Run(t *testing.T) { t.Parallel() t.Run("Running multiple times should error", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCDBareClient(t) + client := etcd.EmbeddedBareClient(t) var triggered atomic.Int64 cronI, err := New(Options{ Log: logr.Discard(), @@ -213,9 +32,9 @@ func Test_Run(t *testing.T) { Namespace: "abc", PartitionID: 0, PartitionTotal: 1, - TriggerFn: func(context.Context, *api.TriggerRequest) bool { + TriggerFn: func(context.Context, *api.TriggerRequest) *api.TriggerResponse { triggered.Add(1) - return true + return &api.TriggerResponse{Result: api.TriggerResponseResult_SUCCESS} }, }) require.NoError(t, err) @@ -255,550 +74,3 @@ func Test_Run(t *testing.T) { } }) } - -func Test_zeroDueTime(t *testing.T) { - t.Parallel() - - helper := testCron(t, 1) - - require.NoError(t, helper.api.Add(helper.ctx, "yoyo", &api.Job{ - Schedule: ptr.Of("@every 1h"), - DueTime: ptr.Of("0s"), - })) - assert.Eventually(t, func() bool { - return helper.triggered.Load() == 1 - }, 3*time.Second, time.Millisecond*10) - - require.NoError(t, helper.api.Add(helper.ctx, "yoyo2", &api.Job{ - Schedule: ptr.Of("@every 1h"), - DueTime: ptr.Of("1s"), - })) - assert.Eventually(t, func() bool { - return helper.triggered.Load() == 2 - }, 3*time.Second, time.Millisecond*10) - - require.NoError(t, helper.api.Add(helper.ctx, "yoyo3", &api.Job{ - Schedule: ptr.Of("@every 1h"), - })) - <-time.After(2 * time.Second) - assert.Equal(t, int64(2), helper.triggered.Load()) -} - -func Test_parallel(t *testing.T) { - t.Parallel() - - for _, test := range []struct { - name string - total uint32 - }{ - {"1 queue", 1}, - {"multi queue", 50}, - } { - total := test.total - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - releaseCh := make(chan struct{}) - var waiting atomic.Int32 - var done atomic.Int32 - helper := testCronWithOptions(t, testCronOptions{ - total: total, - triggerFn: func(*api.TriggerRequest) bool { - waiting.Add(1) - <-releaseCh - done.Add(1) - return true - }, - }) - - for i := range 100 { - require.NoError(t, helper.api.Add(helper.ctx, strconv.Itoa(i), &api.Job{ - DueTime: ptr.Of("0s"), - })) - } - - assert.EventuallyWithT(t, func(c *assert.CollectT) { - assert.Equal(c, int32(100), waiting.Load()) - }, 5*time.Second, 10*time.Millisecond) - close(releaseCh) - assert.EventuallyWithT(t, func(c *assert.CollectT) { - assert.Equal(c, int32(100), done.Load()) - }, 5*time.Second, 10*time.Millisecond) - }) - } -} - -func Test_schedule(t *testing.T) { - t.Parallel() - - t.Run("if no counter, job should not be deleted and no counter created", func(t *testing.T) { - t.Parallel() - - client := tests.EmbeddedETCDBareClient(t) - - now := time.Now().UTC() - jobBytes1, err := proto.Marshal(&stored.Job{ - Begin: &stored.Job_DueTime{DueTime: timestamppb.New(now.Add(time.Hour))}, - PartitionId: 123, - Job: &api.Job{DueTime: ptr.Of(now.Add(time.Hour).Format(time.RFC3339))}, - }) - require.NoError(t, err) - _, err = client.Put(context.Background(), "abc/jobs/1", string(jobBytes1)) - require.NoError(t, err) - - jobBytes2, err := proto.Marshal(&stored.Job{ - Begin: &stored.Job_DueTime{DueTime: timestamppb.New(now)}, - PartitionId: 123, - Job: &api.Job{DueTime: ptr.Of(now.Format(time.RFC3339))}, - }) - require.NoError(t, err) - _, err = client.Put(context.Background(), "abc/jobs/2", string(jobBytes2)) - require.NoError(t, err) - - resp, err := client.Get(context.Background(), "abc/jobs", clientv3.WithPrefix()) - require.NoError(t, err) - assert.Len(t, resp.Kvs, 2) - - cron := testCronWithOptions(t, testCronOptions{ - total: 1, - client: client, - }) - - assert.EventuallyWithT(t, func(c *assert.CollectT) { - assert.Equal(c, int64(1), cron.triggered.Load()) - }, 5*time.Second, 10*time.Millisecond) - - assert.EventuallyWithT(t, func(c *assert.CollectT) { - resp, err = client.Get(context.Background(), "abc/jobs", clientv3.WithPrefix()) - require.NoError(t, err) - assert.Len(c, resp.Kvs, 1) - }, 5*time.Second, 10*time.Millisecond) - - cron.closeCron() - - resp, err = client.Get(context.Background(), "abc/jobs/1") - require.NoError(t, err) - require.Len(t, resp.Kvs, 1) - assert.Equal(t, string(jobBytes1), string(resp.Kvs[0].Value)) - - resp, err = client.Get(context.Background(), "abc/counters", clientv3.WithPrefix()) - require.NoError(t, err) - require.Empty(t, resp.Kvs) - - assert.Equal(t, int64(1), cron.triggered.Load()) - }) - - t.Run("if schedule is not done, job and counter should not be deleted", func(t *testing.T) { - t.Parallel() - - client := tests.EmbeddedETCDBareClient(t) - - future := time.Now().UTC().Add(time.Hour) - jobBytes, err := proto.Marshal(&stored.Job{ - Begin: &stored.Job_DueTime{ - DueTime: timestamppb.New(future), - }, - PartitionId: 123, - Job: &api.Job{ - DueTime: ptr.Of(future.Format(time.RFC3339)), - }, - }) - require.NoError(t, err) - counterBytes, err := proto.Marshal(&stored.Counter{ - LastTrigger: nil, - Count: 0, - JobPartitionId: 123, - }) - require.NoError(t, err) - - _, err = client.Put(context.Background(), "abc/jobs/1", string(jobBytes)) - require.NoError(t, err) - _, err = client.Put(context.Background(), "abc/counters/1", string(counterBytes)) - require.NoError(t, err) - - now := time.Now().UTC() - jobBytes2, err := proto.Marshal(&stored.Job{ - Begin: &stored.Job_DueTime{DueTime: timestamppb.New(now)}, - Job: &api.Job{DueTime: ptr.Of(now.Format(time.RFC3339))}, - }) - require.NoError(t, err) - _, err = client.Put(context.Background(), "abc/jobs/2", string(jobBytes2)) - require.NoError(t, err) - - cron := testCronWithOptions(t, testCronOptions{ - total: 1, - client: client, - }) - - assert.EventuallyWithT(t, func(c *assert.CollectT) { - assert.Equal(c, int64(1), cron.triggered.Load()) - }, 5*time.Second, 10*time.Millisecond) - - resp, err := client.Get(context.Background(), "abc/jobs/1") - require.NoError(t, err) - require.Len(t, resp.Kvs, 1) - assert.Equal(t, string(jobBytes), string(resp.Kvs[0].Value)) - - resp, err = client.Get(context.Background(), "abc/counters/1") - require.NoError(t, err) - require.Len(t, resp.Kvs, 1) - assert.Equal(t, string(counterBytes), string(resp.Kvs[0].Value)) - - resp, err = client.Get(context.Background(), "abc/jobs", clientv3.WithPrefix()) - require.NoError(t, err) - assert.Len(t, resp.Kvs, 1) - }) - - t.Run("if schedule is done, expect job and counter to be deleted", func(t *testing.T) { - t.Parallel() - - client := tests.EmbeddedETCDBareClient(t) - - now := time.Now().UTC() - jobBytes, err := proto.Marshal(&stored.Job{ - Begin: &stored.Job_DueTime{ - DueTime: timestamppb.New(now), - }, - PartitionId: 123, - Job: &api.Job{ - DueTime: ptr.Of(now.Format(time.RFC3339)), - }, - }) - require.NoError(t, err) - counterBytes, err := proto.Marshal(&stored.Counter{ - LastTrigger: timestamppb.New(now), - Count: 1, - JobPartitionId: 123, - }) - require.NoError(t, err) - - _, err = client.Put(context.Background(), "abc/jobs/1", string(jobBytes)) - require.NoError(t, err) - _, err = client.Put(context.Background(), "abc/counters/1", string(counterBytes)) - require.NoError(t, err) - - cron := testCronWithOptions(t, testCronOptions{ - total: 1, - client: client, - }) - - assert.EventuallyWithT(t, func(c *assert.CollectT) { - resp, err := client.Get(context.Background(), "abc/jobs/1") - require.NoError(t, err) - assert.Empty(c, resp.Kvs) - resp, err = client.Get(context.Background(), "abc/counters/1") - require.NoError(t, err) - assert.Empty(c, resp.Kvs) - }, 5*time.Second, 10*time.Millisecond) - - assert.Equal(t, int64(0), cron.triggered.Load()) - }) -} - -func Test_jobWithSpace(t *testing.T) { - t.Parallel() - - cron := testCronWithOptions(t, testCronOptions{ - total: 1, - client: tests.EmbeddedETCDBareClient(t), - }) - - require.NoError(t, cron.api.Add(context.Background(), "hello world", &api.Job{ - DueTime: ptr.Of(time.Now().Add(2).Format(time.RFC3339)), - })) - resp, err := cron.api.Get(context.Background(), "hello world") - require.NoError(t, err) - assert.NotNil(t, resp) - - assert.EventuallyWithT(t, func(c *assert.CollectT) { - assert.Equal(c, int64(1), cron.triggered.Load()) - resp, err = cron.api.Get(context.Background(), "hello world") - assert.NoError(c, err) - assert.Nil(c, resp) - }, time.Second*10, time.Millisecond*10) - - require.NoError(t, cron.api.Add(context.Background(), "another hello world", &api.Job{ - Schedule: ptr.Of("@every 1s"), - })) - resp, err = cron.api.Get(context.Background(), "another hello world") - require.NoError(t, err) - assert.NotNil(t, resp) - listresp, err := cron.api.List(context.Background(), "") - require.NoError(t, err) - assert.Len(t, listresp.GetJobs(), 1) - require.NoError(t, cron.api.Delete(context.Background(), "another hello world")) - resp, err = cron.api.Get(context.Background(), "another hello world") - require.NoError(t, err) - assert.Nil(t, resp) - listresp, err = cron.api.List(context.Background(), "") - require.NoError(t, err) - assert.Empty(t, listresp.GetJobs()) -} - -func Test_FailurePolicy(t *testing.T) { - t.Parallel() - - t.Run("default policy should retry 3 times with a 1sec interval", func(t *testing.T) { - t.Parallel() - - gotCh := make(chan *api.TriggerRequest, 1) - var got atomic.Uint32 - cron := testCronWithOptions(t, testCronOptions{ - total: 1, - client: tests.EmbeddedETCDBareClient(t), - triggerFn: func(*api.TriggerRequest) bool { - assert.GreaterOrEqual(t, uint32(8), got.Add(1)) - return false - }, - gotCh: gotCh, - }) - - require.NoError(t, cron.api.Add(context.Background(), "test", &api.Job{ - DueTime: ptr.Of(time.Now().Format(time.RFC3339)), - Schedule: ptr.Of("@every 1s"), - Repeats: ptr.Of(uint32(2)), - })) - - for range 8 { - resp, err := cron.api.Get(context.Background(), "test") - require.NoError(t, err) - assert.NotNil(t, resp) - select { - case <-gotCh: - case <-time.After(time.Second * 3): - assert.Fail(t, "timeout waiting for trigger") - } - } - - assert.EventuallyWithT(t, func(c *assert.CollectT) { - resp, err := cron.api.Get(context.Background(), "test") - assert.NoError(c, err) - assert.Nil(c, resp) - }, time.Second*5, time.Millisecond*10) - }) - - t.Run("drop policy should not retry triggering", func(t *testing.T) { - t.Parallel() - - gotCh := make(chan *api.TriggerRequest, 1) - var got atomic.Uint32 - cron := testCronWithOptions(t, testCronOptions{ - total: 1, - client: tests.EmbeddedETCDBareClient(t), - triggerFn: func(*api.TriggerRequest) bool { - assert.GreaterOrEqual(t, uint32(2), got.Add(1)) - return false - }, - gotCh: gotCh, - }) - - require.NoError(t, cron.api.Add(context.Background(), "test", &api.Job{ - DueTime: ptr.Of(time.Now().Format(time.RFC3339)), - Schedule: ptr.Of("@every 1s"), - Repeats: ptr.Of(uint32(2)), - FailurePolicy: &api.FailurePolicy{ - Policy: new(api.FailurePolicy_Drop), - }, - })) - - for range 2 { - resp, err := cron.api.Get(context.Background(), "test") - require.NoError(t, err) - assert.NotNil(t, resp) - select { - case <-gotCh: - case <-time.After(time.Second * 3): - assert.Fail(t, "timeout waiting for trigger") - } - } - - assert.EventuallyWithT(t, func(c *assert.CollectT) { - resp, err := cron.api.Get(context.Background(), "test") - assert.NoError(c, err) - assert.Nil(c, resp) - }, time.Second*5, time.Millisecond*10) - }) - - t.Run("constant policy should only retry when it fails ", func(t *testing.T) { - t.Parallel() - - gotCh := make(chan *api.TriggerRequest, 1) - var got atomic.Uint32 - cron := testCronWithOptions(t, testCronOptions{ - total: 1, - client: tests.EmbeddedETCDBareClient(t), - triggerFn: func(*api.TriggerRequest) bool { - assert.GreaterOrEqual(t, uint32(5), got.Add(1)) - return got.Load() == 3 - }, - gotCh: gotCh, - }) - - require.NoError(t, cron.api.Add(context.Background(), "test", &api.Job{ - DueTime: ptr.Of(time.Now().Format(time.RFC3339)), - Schedule: ptr.Of("@every 1s"), - Repeats: ptr.Of(uint32(3)), - FailurePolicy: &api.FailurePolicy{ - Policy: &api.FailurePolicy_Constant{ - Constant: &api.FailurePolicyConstant{ - Interval: durationpb.New(time.Millisecond), MaxRetries: ptr.Of(uint32(1)), - }, - }, - }, - })) - - for range 5 { - resp, err := cron.api.Get(context.Background(), "test") - require.NoError(t, err) - assert.NotNil(t, resp) - select { - case <-gotCh: - case <-time.After(time.Second * 3): - assert.Fail(t, "timeout waiting for trigger") - } - } - - assert.EventuallyWithT(t, func(c *assert.CollectT) { - resp, err := cron.api.Get(context.Background(), "test") - assert.NoError(c, err) - assert.Nil(c, resp) - }, time.Second*5, time.Millisecond*10) - }) - - t.Run("constant policy can retry forever until it succeeds", func(t *testing.T) { - t.Parallel() - - gotCh := make(chan *api.TriggerRequest, 1) - var got atomic.Uint32 - cron := testCronWithOptions(t, testCronOptions{ - total: 1, - client: tests.EmbeddedETCDBareClient(t), - triggerFn: func(*api.TriggerRequest) bool { - assert.GreaterOrEqual(t, uint32(100), got.Add(1)) - return got.Load() == 100 - }, - gotCh: gotCh, - }) - - require.NoError(t, cron.api.Add(context.Background(), "test", &api.Job{ - DueTime: ptr.Of(time.Now().Format(time.RFC3339)), - FailurePolicy: &api.FailurePolicy{ - Policy: &api.FailurePolicy_Constant{ - Constant: &api.FailurePolicyConstant{ - Interval: durationpb.New(time.Millisecond), - }, - }, - }, - })) - - for range 100 { - resp, err := cron.api.Get(context.Background(), "test") - require.NoError(t, err) - assert.NotNil(t, resp) - select { - case <-gotCh: - case <-time.After(time.Second * 3): - assert.Fail(t, "timeout waiting for trigger") - } - } - - assert.EventuallyWithT(t, func(c *assert.CollectT) { - resp, err := cron.api.Get(context.Background(), "test") - assert.NoError(c, err) - assert.Nil(c, resp) - }, time.Second*5, time.Millisecond*10) - }) -} - -type testCronOptions struct { - total uint32 - gotCh chan *api.TriggerRequest - triggerFn func(*api.TriggerRequest) bool - client *clientv3.Client -} - -type helper struct { - ctx context.Context - closeCron func() - client client.Interface - api api.Interface - allCrons []api.Interface - triggered *atomic.Int64 -} - -func testCron(t *testing.T, total uint32) *helper { - t.Helper() - return testCronWithOptions(t, testCronOptions{ - total: total, - }) -} - -func testCronWithOptions(t *testing.T, opts testCronOptions) *helper { - t.Helper() - - require.Positive(t, opts.total) - cl := opts.client - if cl == nil { - cl = tests.EmbeddedETCDBareClient(t) - } - - var triggered atomic.Int64 - var a api.Interface - allCrns := make([]api.Interface, opts.total) - for i := range opts.total { - c, err := New(Options{ - Log: logr.Discard(), - Client: cl, - Namespace: "abc", - PartitionID: i, - PartitionTotal: opts.total, - TriggerFn: func(_ context.Context, req *api.TriggerRequest) bool { - defer func() { triggered.Add(1) }() - if opts.gotCh != nil { - opts.gotCh <- req - } - if opts.triggerFn != nil { - return opts.triggerFn(req) - } - return true - }, - - CounterGarbageCollectionInterval: ptr.Of(time.Millisecond * 300), - }) - require.NoError(t, err) - allCrns[i] = c - if i == 0 { - a = c - } - } - - errCh := make(chan error, opts.total) - ctx, cancel := context.WithCancel(context.Background()) - - closeOnce := sync.OnceFunc(func() { - cancel() - for range opts.total { - select { - case err := <-errCh: - require.NoError(t, err) - case <-time.After(10 * time.Second): - t.Fatal("timeout waiting for cron to stop") - } - } - }) - t.Cleanup(closeOnce) - for i := range opts.total { - go func(i uint32) { - errCh <- allCrns[i].Run(ctx) - }(i) - } - - return &helper{ - ctx: ctx, - client: client.New(client.Options{Client: cl, Log: logr.Discard()}), - api: a, - allCrons: allCrns, - triggered: &triggered, - closeCron: closeOnce, - } -} diff --git a/go.mod b/go.mod index c196f17..9a8c77f 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/diagridio/go-etcd-cron go 1.23.1 require ( - github.com/dapr/kit v0.13.1-0.20240924041040-2d6ff15a9744 + github.com/dapr/kit v0.13.1-0.20241007143932-bc3a4f0fb4e0 github.com/go-logr/logr v1.4.2 github.com/go-logr/zapr v1.3.0 github.com/stretchr/testify v1.9.0 diff --git a/go.sum b/go.sum index 0d113bf..4a31d45 100644 --- a/go.sum +++ b/go.sum @@ -18,8 +18,8 @@ github.com/coreos/go-semver v0.3.1 h1:yi21YpKnrx1gt5R+la8n5WgS0kCrsPp33dmEyHReZr github.com/coreos/go-semver v0.3.1/go.mod h1:irMmmIw/7yzSRPWryHsK7EYSg09caPQL03VsM8rvUec= github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= -github.com/dapr/kit v0.13.1-0.20240924041040-2d6ff15a9744 h1:GZxwr7os1PAnVt/q1FVBJBMaudbRJ7fkZthhmPwBDvI= -github.com/dapr/kit v0.13.1-0.20240924041040-2d6ff15a9744/go.mod h1:Hz1W2LmWfA4UX/12MdA+brsf+np6f/1dJt6C6F63cjI= +github.com/dapr/kit v0.13.1-0.20241007143932-bc3a4f0fb4e0 h1:Ny3AwwG1a6ICFz4nbtRbeYut8K5EvXQtNpohpM/S18o= +github.com/dapr/kit v0.13.1-0.20241007143932-bc3a4f0fb4e0/go.mod h1:Hz1W2LmWfA4UX/12MdA+brsf+np6f/1dJt6C6F63cjI= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= diff --git a/gomod2nix.toml b/gomod2nix.toml index 2cd33b8..7540717 100644 --- a/gomod2nix.toml +++ b/gomod2nix.toml @@ -17,8 +17,8 @@ schema = 3 version = "v22.5.0" hash = "sha256-E2zXikbmIQImghstLUWuey1YgA0Folu3F+fi5k4hCxA=" [mod."github.com/dapr/kit"] - version = "v0.13.1-0.20240924041040-2d6ff15a9744" - hash = "sha256-6caMARadWFCKapU90Gj32mLrmGKm0TesgE4h+UhLdvs=" + version = "v0.13.1-0.20241007143932-bc3a4f0fb4e0" + hash = "sha256-bilnPOXyWH7HQ7vFS+GNnSsZ2HfLHeqcfwJTnIFOtCo=" [mod."github.com/davecgh/go-spew"] version = "v1.1.2-0.20180830191138-d8f796af33cc" hash = "sha256-fV9oI51xjHdOmEx6+dlq7Ku2Ag+m/bmbzPo6A4Y74qc=" diff --git a/internal/api/api.go b/internal/api/api.go index 1a961c4..94cd93d 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -209,3 +209,28 @@ func (a *api) List(ctx context.Context, prefix string) (*cronapi.ListResponse, e Jobs: jobs, }, nil } + +// RegisterDeliverablePrefixes registers the given Job name prefixes as being +// deliverable. Calling the returned CancelFunc will de-register those +// prefixes as being deliverable. +func (a *api) DeliverablePrefixes(ctx context.Context, prefixes ...string) (context.CancelFunc, error) { + select { + case <-a.readyCh: + case <-a.closeCh: + return nil, errAPIClosed + case <-ctx.Done(): + return nil, context.Cause(ctx) + } + + if len(prefixes) == 0 { + return nil, errors.New("no prefixes provided") + } + + for _, prefix := range prefixes { + if err := a.validator.JobName(prefix); err != nil { + return nil, err + } + } + + return a.queue.DeliverablePrefixes(prefixes...), nil +} diff --git a/internal/api/api_test.go b/internal/api/api_test.go index 94e238d..d185b14 100644 --- a/internal/api/api_test.go +++ b/internal/api/api_test.go @@ -3,6 +3,7 @@ Copyright (c) 2024 Diagrid Inc. Licensed under the MIT License. */ +//nolint:dupl package api import ( @@ -22,7 +23,7 @@ import ( "github.com/diagridio/go-etcd-cron/internal/key" "github.com/diagridio/go-etcd-cron/internal/queue" "github.com/diagridio/go-etcd-cron/internal/scheduler" - "github.com/diagridio/go-etcd-cron/tests" + "github.com/diagridio/go-etcd-cron/tests/framework/etcd" ) var errCancel = errors.New("custom cancel") @@ -254,6 +255,39 @@ func Test_List(t *testing.T) { }) } +func Test_DeliverablePrefixes(t *testing.T) { + t.Parallel() + + t.Run("returns context error if cron not ready in time", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancelCause(context.Background()) + cancel(errCancel) + dcancel, err := newAPINotReady(t).DeliverablePrefixes(ctx, "helloworld") + assert.Equal(t, errCancel, err) + assert.Nil(t, dcancel) + }) + + t.Run("returns closed error if cron is closed", func(t *testing.T) { + t.Parallel() + + api := newAPINotReady(t) + close(api.closeCh) + cancel, err := api.DeliverablePrefixes(context.Background(), "hello world") + assert.Equal(t, errors.New("api is closed"), err) + assert.Nil(t, cancel) + }) + + t.Run("invalid name should error", func(t *testing.T) { + t.Parallel() + + api := newAPI(t) + cancel, err := api.DeliverablePrefixes(context.Background(), "./.") + require.Error(t, err) + assert.Nil(t, cancel) + }) +} + func newAPI(t *testing.T) *api { t.Helper() api := newAPINotReady(t) @@ -264,7 +298,7 @@ func newAPI(t *testing.T) *api { func newAPINotReady(t *testing.T) *api { t.Helper() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) collector, err := garbage.New(garbage.Options{ Log: logr.Discard(), @@ -280,8 +314,10 @@ func newAPINotReady(t *testing.T) *api { Client: client, Key: key, SchedulerBuilder: schedulerBuilder, - TriggerFn: func(context.Context, *cronapi.TriggerRequest) bool { return true }, - Collector: collector, + TriggerFn: func(context.Context, *cronapi.TriggerRequest) *cronapi.TriggerResponse { + return &cronapi.TriggerResponse{Result: cronapi.TriggerResponseResult_SUCCESS} + }, + Collector: collector, }) ctx, cancel := context.WithCancel(context.Background()) diff --git a/internal/api/serve_test.go b/internal/api/serve_test.go index ddd4c2e..ce77ff6 100644 --- a/internal/api/serve_test.go +++ b/internal/api/serve_test.go @@ -16,7 +16,7 @@ import ( clientv3 "go.etcd.io/etcd/client/v3" "github.com/diagridio/go-etcd-cron/api" - "github.com/diagridio/go-etcd-cron/tests/cron" + "github.com/diagridio/go-etcd-cron/tests/framework/cron" ) func Test_Add(t *testing.T) { diff --git a/internal/counter/counter.go b/internal/counter/counter.go index 93f09a8..a5dd79b 100644 --- a/internal/counter/counter.go +++ b/internal/counter/counter.go @@ -47,11 +47,22 @@ type Options struct { Collector garbage.Interface } -// Counter is a counter, tracking state of a scheduled job as it is triggered +// Interface is a counter, tracking state of a scheduled job as it is triggered // over time. Returns, if necessary, the time the job should be triggered next. // Counter handles the deletion of the associated job if it has expired and // adding the counter object to the garbage collector. -type Counter struct { +type Interface interface { + ScheduledTime() time.Time + Key() string + JobName() string + TriggerRequest() *api.TriggerRequest + TriggerSuccess(ctx context.Context) (bool, error) + TriggerFailed(ctx context.Context) (bool, error) +} + +// counter is the implementation of the counter interface. +type counter struct { + name string jobKey string counterKey string client client.Interface @@ -64,7 +75,7 @@ type Counter struct { triggerRequest *api.TriggerRequest } -func New(ctx context.Context, opts Options) (*Counter, bool, error) { +func New(ctx context.Context, opts Options) (Interface, bool, error) { counterKey := opts.Key.CounterKey(opts.Name) jobKey := opts.Key.JobKey(opts.Name) @@ -78,7 +89,8 @@ func New(ctx context.Context, opts Options) (*Counter, bool, error) { } if res.Count == 0 { - c := &Counter{ + c := &counter{ + name: opts.Name, jobKey: jobKey, counterKey: counterKey, client: opts.Client, @@ -119,7 +131,7 @@ func New(ctx context.Context, opts Options) (*Counter, bool, error) { } } - c := &Counter{ + c := &counter{ counterKey: counterKey, jobKey: jobKey, client: opts.Client, @@ -144,23 +156,29 @@ func New(ctx context.Context, opts Options) (*Counter, bool, error) { // ScheduledTime is the time at which the job is scheduled to be triggered // next. Implements the kit events queueable item. -func (c *Counter) ScheduledTime() time.Time { +func (c *counter) ScheduledTime() time.Time { return c.next } -// Key returns the name of the job. Implements the kit events queueable item. -func (c *Counter) Key() string { +// Key returns the Etcd key of the job. Implements the kit events queueable +// item. +func (c *counter) Key() string { return c.jobKey } +// JobName returns the consumer name of the job. +func (c *counter) JobName() string { + return c.name +} + // TriggerRequest is the trigger request representation for the job. -func (c *Counter) TriggerRequest() *api.TriggerRequest { +func (c *counter) TriggerRequest() *api.TriggerRequest { return c.triggerRequest } // TriggerSuccess updates the counter state given what the next trigger time // was. Returns true if the job will be triggered again. -func (c *Counter) TriggerSuccess(ctx context.Context) (bool, error) { +func (c *counter) TriggerSuccess(ctx context.Context) (bool, error) { // Update the last trigger time as the next trigger time, and increment the // counter. // Set attempts to 0 as this trigger was successful. @@ -191,7 +209,7 @@ func (c *Counter) TriggerSuccess(ctx context.Context) (bool, error) { // Returns true if the job failure policy indicates that the job should be // tried again. Returns false if the job should not be attempted again and was // deleted. -func (c *Counter) TriggerFailed(ctx context.Context) (bool, error) { +func (c *counter) TriggerFailed(ctx context.Context) (bool, error) { // Increment the attempts counter as this count tick failed. c.count.Attempts++ @@ -217,7 +235,7 @@ func (c *Counter) TriggerFailed(ctx context.Context) (bool, error) { // policyTryAgain returns true if the failure policy indicates this job should // be tried again at this tick. -func (c *Counter) policyTryAgain() bool { +func (c *counter) policyTryAgain() bool { fp := c.job.GetJob().GetFailurePolicy() if fp == nil { c.count.LastTrigger = timestamppb.New(c.next) @@ -253,7 +271,7 @@ func (c *Counter) policyTryAgain() bool { // tickNext updates the next trigger time, and deletes the counter record if // needed. -func (c *Counter) tickNext() (bool, error) { +func (c *counter) tickNext() (bool, error) { if c.updateNext() { return true, nil } @@ -271,7 +289,7 @@ func (c *Counter) tickNext() (bool, error) { // updateNext updates the counter's next trigger time. // Returns false if the job and counter should be deleted because it has // expired. -func (c *Counter) updateNext() bool { +func (c *counter) updateNext() bool { // If job completed repeats, delete the counter. if c.job.GetJob().Repeats != nil && (c.count.GetCount() >= c.job.GetJob().GetRepeats()) { return false diff --git a/internal/counter/counter_test.go b/internal/counter/counter_test.go index 018a99d..5022a3a 100644 --- a/internal/counter/counter_test.go +++ b/internal/counter/counter_test.go @@ -26,7 +26,7 @@ import ( "github.com/diagridio/go-etcd-cron/internal/grave" "github.com/diagridio/go-etcd-cron/internal/key" "github.com/diagridio/go-etcd-cron/internal/scheduler" - "github.com/diagridio/go-etcd-cron/tests" + "github.com/diagridio/go-etcd-cron/tests/framework/etcd" ) func Test_New(t *testing.T) { @@ -35,7 +35,7 @@ func Test_New(t *testing.T) { t.Run("New pops the job key on the collector", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) now := time.Now().UTC() @@ -107,7 +107,7 @@ func Test_New(t *testing.T) { t.Run("if the counter already exists and partition ID matches, expect counter be kept the same", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) now := time.Now().UTC() @@ -191,7 +191,7 @@ func Test_New(t *testing.T) { t.Run("if the counter already exists but partition ID doesn't match, expect counter to be written with new value", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) now := time.Now().UTC() @@ -282,7 +282,7 @@ func Test_New(t *testing.T) { t.Run("if the counter already exists and partition ID matches but is expired, expect both job and counter to be deleted", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) now := time.Now().UTC() @@ -363,7 +363,7 @@ func Test_New(t *testing.T) { t.Run("if the counter doesn't exist, create new counter and update next but don't write, return true", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) now := time.Now().UTC() @@ -441,7 +441,7 @@ func Test_TriggerSuccess(t *testing.T) { t.Run("if tick next is true, expect job be kept and counter to incremented", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) now := time.Now().UTC() @@ -454,14 +454,14 @@ func Test_TriggerSuccess(t *testing.T) { Schedule: ptr.Of("@every 1s"), }, } - counter := &stored.Counter{LastTrigger: nil, JobPartitionId: 123} + scounter := &stored.Counter{LastTrigger: nil, JobPartitionId: 123} sched, err := scheduler.NewBuilder().Schedule(job) require.NoError(t, err) jobBytes, err := proto.Marshal(job) require.NoError(t, err) - counterBytes, err := proto.Marshal(counter) + counterBytes, err := proto.Marshal(scounter) require.NoError(t, err) _, err = client.Put(context.Background(), "abc/jobs/1", string(jobBytes)) @@ -479,12 +479,12 @@ func Test_TriggerSuccess(t *testing.T) { }() yard := grave.New() - c := &Counter{ + c := &counter{ yard: yard, client: client, collector: collector, job: job, - count: counter, + count: scounter, schedule: sched, jobKey: "abc/jobs/1", counterKey: "abc/counters/1", @@ -526,7 +526,7 @@ func Test_TriggerSuccess(t *testing.T) { t.Run("if tick next is false, expect job and counter to be deleted", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) now := time.Now().UTC() @@ -538,7 +538,7 @@ func Test_TriggerSuccess(t *testing.T) { DueTime: ptr.Of(now.Format(time.RFC3339)), }, } - counter := &stored.Counter{ + scounter := &stored.Counter{ LastTrigger: nil, JobPartitionId: 123, Count: 0, @@ -549,7 +549,7 @@ func Test_TriggerSuccess(t *testing.T) { jobBytes, err := proto.Marshal(job) require.NoError(t, err) - counterBytes, err := proto.Marshal(counter) + counterBytes, err := proto.Marshal(scounter) require.NoError(t, err) _, err = client.Put(context.Background(), "abc/jobs/1", string(jobBytes)) @@ -567,13 +567,13 @@ func Test_TriggerSuccess(t *testing.T) { }() yard := grave.New() - c := &Counter{ + c := &counter{ yard: yard, client: client, collector: collector, job: job, next: now, - count: counter, + count: scounter, schedule: sched, jobKey: "abc/jobs/1", counterKey: "abc/counters/1", @@ -605,7 +605,7 @@ func Test_TriggerSuccess(t *testing.T) { t.Run("The number of attempts on the counter should always be reset to 0 when Trigger is called", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) now := time.Now().UTC() @@ -618,14 +618,14 @@ func Test_TriggerSuccess(t *testing.T) { Schedule: ptr.Of("@every 1s"), }, } - counter := &stored.Counter{LastTrigger: nil, JobPartitionId: 123, Attempts: 456} + scounter := &stored.Counter{LastTrigger: nil, JobPartitionId: 123, Attempts: 456} sched, err := scheduler.NewBuilder().Schedule(job) require.NoError(t, err) jobBytes, err := proto.Marshal(job) require.NoError(t, err) - counterBytes, err := proto.Marshal(counter) + counterBytes, err := proto.Marshal(scounter) require.NoError(t, err) _, err = client.Put(context.Background(), "abc/jobs/1", string(jobBytes)) @@ -643,12 +643,12 @@ func Test_TriggerSuccess(t *testing.T) { }() yard := grave.New() - c := &Counter{ + c := &counter{ yard: yard, client: client, collector: collector, job: job, - count: counter, + count: scounter, schedule: sched, jobKey: "abc/jobs/1", counterKey: "abc/counters/1", @@ -695,7 +695,7 @@ func Test_tickNext(t *testing.T) { t.Run("if the updateNext returns true, expect no delete", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) now := time.Now().UTC() @@ -707,14 +707,14 @@ func Test_tickNext(t *testing.T) { DueTime: ptr.Of(now.Format(time.RFC3339)), }, } - counter := &stored.Counter{LastTrigger: nil, JobPartitionId: 123} + scounter := &stored.Counter{LastTrigger: nil, JobPartitionId: 123} sched, err := scheduler.NewBuilder().Schedule(job) require.NoError(t, err) jobBytes, err := proto.Marshal(job) require.NoError(t, err) - counterBytes, err := proto.Marshal(counter) + counterBytes, err := proto.Marshal(scounter) require.NoError(t, err) _, err = client.Put(context.Background(), "abc/jobs/1", string(jobBytes)) @@ -732,12 +732,12 @@ func Test_tickNext(t *testing.T) { }() yard := grave.New() - c := &Counter{ + c := &counter{ yard: yard, client: client, collector: collector, job: job, - count: counter, + count: scounter, schedule: sched, jobKey: "abc/jobs/1", counterKey: "abc/counters/1", @@ -771,7 +771,7 @@ func Test_tickNext(t *testing.T) { t.Run("if the updateNext returns false, expect delete", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) now := time.Now().UTC() @@ -783,7 +783,7 @@ func Test_tickNext(t *testing.T) { DueTime: ptr.Of(now.Format(time.RFC3339)), }, } - counter := &stored.Counter{ + scounter := &stored.Counter{ LastTrigger: timestamppb.New(now), JobPartitionId: 123, Count: 1, @@ -794,7 +794,7 @@ func Test_tickNext(t *testing.T) { jobBytes, err := proto.Marshal(job) require.NoError(t, err) - counterBytes, err := proto.Marshal(counter) + counterBytes, err := proto.Marshal(scounter) require.NoError(t, err) _, err = client.Put(context.Background(), "abc/jobs/1", string(jobBytes)) @@ -812,12 +812,12 @@ func Test_tickNext(t *testing.T) { }() yard := grave.New() - c := &Counter{ + c := &counter{ yard: yard, client: client, collector: collector, job: job, - count: counter, + count: scounter, schedule: sched, jobKey: "abc/jobs/1", counterKey: "abc/counters/1", @@ -888,12 +888,12 @@ func Test_updateNext(t *testing.T) { require.NoError(t, err) tests := map[string]struct { - counter *Counter + counter *counter exp bool expNext time.Time }{ "if the number of counts is the same as repeats return false": { - counter: &Counter{ + counter: &counter{ schedule: repeats, job: &stored.Job{Job: &api.Job{ Repeats: ptr.Of(uint32(4)), @@ -905,7 +905,7 @@ func Test_updateNext(t *testing.T) { exp: false, }, "if the number of counts is more than repeats return false (should never happen)": { - counter: &Counter{ + counter: &counter{ schedule: repeats, job: &stored.Job{Job: &api.Job{ Repeats: ptr.Of(uint32(4)), @@ -915,7 +915,7 @@ func Test_updateNext(t *testing.T) { exp: false, }, "if the last trigger time if the same as the expiry, expect false": { - counter: &Counter{ + counter: &counter{ schedule: expires, job: &stored.Job{Job: &api.Job{ Repeats: ptr.Of(uint32(4)), @@ -928,7 +928,7 @@ func Test_updateNext(t *testing.T) { exp: false, }, "if the count is equal to total, return false": { - counter: &Counter{ + counter: &counter{ schedule: expires, job: &stored.Job{Job: &api.Job{ Repeats: ptr.Of(uint32(4)), @@ -941,7 +941,7 @@ func Test_updateNext(t *testing.T) { exp: false, }, "if under the number of counts, but job is past expiry time, return false": { - counter: &Counter{ + counter: &counter{ schedule: expires, job: &stored.Job{ Expiration: timestamppb.New(now.Add(-5 * time.Second)), @@ -955,7 +955,7 @@ func Test_updateNext(t *testing.T) { exp: false, }, "if time is past the trigger time but no triggered yet for one shot, return true and set trigger time": { - counter: &Counter{ + counter: &counter{ schedule: oneshot, job: &stored.Job{Job: new(api.Job)}, count: &stored.Counter{ @@ -967,7 +967,7 @@ func Test_updateNext(t *testing.T) { expNext: now, }, "if oneshot trigger but has already been triggered, expect false": { - counter: &Counter{ + counter: &counter{ schedule: oneshot, job: &stored.Job{Job: new(api.Job)}, count: &stored.Counter{ @@ -1640,7 +1640,7 @@ func Test_TriggerFailed(t *testing.T) { sched, err := scheduler.NewBuilder().Schedule(job) require.NoError(t, err) - counter := &Counter{ + counter := &counter{ jobKey: "abc/jobs/1", counterKey: "abc/counters/1", client: client, @@ -1709,7 +1709,7 @@ func Test_TriggerFailureSuccess(t *testing.T) { sched, err := scheduler.NewBuilder().Schedule(job) require.NoError(t, err) - counter := &Counter{ + counter := &counter{ jobKey: "abc/jobs/1", counterKey: "abc/counters/1", client: client, diff --git a/internal/counter/fake/fake.go b/internal/counter/fake/fake.go new file mode 100644 index 0000000..63cb564 --- /dev/null +++ b/internal/counter/fake/fake.go @@ -0,0 +1,93 @@ +/* +Copyright (c) 2024 Diagrid Inc. +Licensed under the MIT License. +*/ + +package fake + +import ( + "context" + "time" + + "github.com/diagridio/go-etcd-cron/api" +) + +type Fake struct { + scheduledTimeFn func() time.Time + key string + jobName string + triggerRequestFn func() *api.TriggerRequest + triggerSuccessFn func(context.Context) (bool, error) + triggerFailedFn func(context.Context) (bool, error) +} + +func New() *Fake { + return &Fake{ + scheduledTimeFn: time.Now, + key: "key", + jobName: "job", + triggerRequestFn: func() *api.TriggerRequest { + return &api.TriggerRequest{} + }, + triggerSuccessFn: func(context.Context) (bool, error) { + return true, nil + }, + triggerFailedFn: func(context.Context) (bool, error) { + return false, nil + }, + } +} + +func (f *Fake) WithScheduledTime(fn func() time.Time) *Fake { + f.scheduledTimeFn = fn + return f +} + +func (f *Fake) WithKey(key string) *Fake { + f.key = key + return f +} + +func (f *Fake) WithJobName(jobName string) *Fake { + f.jobName = jobName + return f +} + +func (f *Fake) WithTriggerRequest(fn func() *api.TriggerRequest) *Fake { + f.triggerRequestFn = fn + return f +} + +func (f *Fake) WithTriggerSuccess(fn func(context.Context) (bool, error)) *Fake { + f.triggerSuccessFn = fn + return f +} + +func (f *Fake) WithTriggerFailed(fn func(context.Context) (bool, error)) *Fake { + f.triggerFailedFn = fn + return f +} + +func (f *Fake) ScheduledTime() time.Time { + return f.scheduledTimeFn() +} + +func (f *Fake) Key() string { + return f.key +} + +func (f *Fake) JobName() string { + return f.jobName +} + +func (f *Fake) TriggerRequest() *api.TriggerRequest { + return f.triggerRequestFn() +} + +func (f *Fake) TriggerSuccess(ctx context.Context) (bool, error) { + return f.triggerSuccessFn(ctx) +} + +func (f *Fake) TriggerFailed(ctx context.Context) (bool, error) { + return f.triggerFailedFn(ctx) +} diff --git a/internal/counter/fake/fake_test.go b/internal/counter/fake/fake_test.go new file mode 100644 index 0000000..cc7f743 --- /dev/null +++ b/internal/counter/fake/fake_test.go @@ -0,0 +1,16 @@ +/* +Copyright (c) 2024 Diagrid Inc. +Licensed under the MIT License. +*/ + +package fake + +import ( + "testing" + + "github.com/diagridio/go-etcd-cron/internal/counter" +) + +func Test_Fake(*testing.T) { + var _ counter.Interface = New() +} diff --git a/internal/garbage/collector_test.go b/internal/garbage/collector_test.go index c3d0abf..3fb7c97 100644 --- a/internal/garbage/collector_test.go +++ b/internal/garbage/collector_test.go @@ -17,7 +17,7 @@ import ( clientv3 "go.etcd.io/etcd/client/v3" clocktesting "k8s.io/utils/clock/testing" - "github.com/diagridio/go-etcd-cron/tests" + "github.com/diagridio/go-etcd-cron/tests/framework/etcd" ) func Test_New(t *testing.T) { @@ -113,7 +113,7 @@ func Test_Run(t *testing.T) { t.Run("closing the collector should result in the remaining keys to be deleted", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) coll, err := New(Options{ Client: client, }) @@ -157,7 +157,7 @@ func Test_Run(t *testing.T) { t.Run("reaching max garbage limit (500k) should cause all keys to be deleted", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) coll, err := New(Options{ Client: client, }) @@ -209,7 +209,7 @@ func Test_Run(t *testing.T) { t.Run("if ticks past 180 seconds, then should delete all garbage keys", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) clock := clocktesting.NewFakeClock(time.Now()) coll, err := New(Options{ Client: client, @@ -257,7 +257,7 @@ func Test_Run(t *testing.T) { t.Run("if ticks past custom 60s, then should delete all garbage keys", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) clock := clocktesting.NewFakeClock(time.Now()) coll, err := New(Options{ Client: client, @@ -410,7 +410,7 @@ func Test_collect(t *testing.T) { t.Run("if there are keys to delete, expect them to be deleted", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) coll, err := New(Options{ Client: client, }) @@ -444,7 +444,7 @@ func Test_collect(t *testing.T) { t.Run("should not delete other keys which are not marked for deletion", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) coll, err := New(Options{ Client: client, }) diff --git a/internal/informer/informer_test.go b/internal/informer/informer_test.go index d4fd36a..aa1583a 100644 --- a/internal/informer/informer_test.go +++ b/internal/informer/informer_test.go @@ -24,7 +24,7 @@ import ( "github.com/diagridio/go-etcd-cron/internal/grave" "github.com/diagridio/go-etcd-cron/internal/key" "github.com/diagridio/go-etcd-cron/internal/partitioner" - "github.com/diagridio/go-etcd-cron/tests" + "github.com/diagridio/go-etcd-cron/tests/framework/etcd" ) func Test_Run(t *testing.T) { @@ -39,7 +39,7 @@ func Test_Run(t *testing.T) { t.Run("No keys in the db should return no events after ready", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) collector, err := garbage.New(garbage.Options{Client: client}) require.NoError(t, err) i := New(Options{ @@ -82,7 +82,7 @@ func Test_Run(t *testing.T) { t.Run("keys in the db should be returned after ready, filtered by partition", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) collector, err := garbage.New(garbage.Options{Client: client}) require.NoError(t, err) @@ -155,7 +155,7 @@ func Test_Run(t *testing.T) { t.Run("keys added to the db after Ready should be synced, filtering by partition", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) collector, err := garbage.New(garbage.Options{Client: client}) require.NoError(t, err) diff --git a/internal/key/key_test.go b/internal/key/key_test.go index 34e947d..e77f19d 100644 --- a/internal/key/key_test.go +++ b/internal/key/key_test.go @@ -42,14 +42,13 @@ func Test_JobKey(t *testing.T) { } for _, test := range tests { - testInLoop := test t.Run(test.namespace+"/"+test.jobName, func(t *testing.T) { t.Parallel() key := New(Options{ - Namespace: testInLoop.namespace, - PartitionID: testInLoop.partitionID, + Namespace: test.namespace, + PartitionID: test.partitionID, }) - assert.Equal(t, testInLoop.expJobKey, key.JobKey(testInLoop.jobName)) + assert.Equal(t, test.expJobKey, key.JobKey(test.jobName)) }) } } @@ -84,14 +83,13 @@ func Test_CounterKey(t *testing.T) { } for _, test := range tests { - testInLoop := test t.Run(test.namespace+"/"+test.jobName, func(t *testing.T) { t.Parallel() key := New(Options{ - Namespace: testInLoop.namespace, - PartitionID: testInLoop.partitionID, + Namespace: test.namespace, + PartitionID: test.partitionID, }) - assert.Equal(t, testInLoop.expCounterKey, key.CounterKey(testInLoop.jobName)) + assert.Equal(t, test.expCounterKey, key.CounterKey(test.jobName)) }) } } @@ -118,14 +116,13 @@ func Test_LeadershipNamespace(t *testing.T) { } for _, test := range tests { - testInLoop := test t.Run(test.namespace, func(t *testing.T) { t.Parallel() key := New(Options{ - Namespace: testInLoop.namespace, + Namespace: test.namespace, PartitionID: 123, }) - assert.Equal(t, testInLoop.expLeadershipNS, key.LeadershipNamespace()) + assert.Equal(t, test.expLeadershipNS, key.LeadershipNamespace()) }) } } @@ -156,14 +153,13 @@ func Test_LeadershipKey(t *testing.T) { } for _, test := range tests { - testInLoop := test t.Run(fmt.Sprintf("%s/%d", test.namespace, test.partitionID), func(t *testing.T) { t.Parallel() key := New(Options{ - Namespace: testInLoop.namespace, - PartitionID: testInLoop.partitionID, + Namespace: test.namespace, + PartitionID: test.partitionID, }) - assert.Equal(t, testInLoop.expLeadershipKey, key.LeadershipKey()) + assert.Equal(t, test.expLeadershipKey, key.LeadershipKey()) }) } } @@ -190,14 +186,13 @@ func Test_JobNamespace(t *testing.T) { } for _, test := range tests { - testInLoop := test t.Run(test.namespace, func(t *testing.T) { t.Parallel() key := New(Options{ - Namespace: testInLoop.namespace, + Namespace: test.namespace, PartitionID: 123, }) - assert.Equal(t, testInLoop.expJobNS, key.JobNamespace()) + assert.Equal(t, test.expJobNS, key.JobNamespace()) }) } } @@ -224,14 +219,13 @@ func Test_JobName(t *testing.T) { } for _, test := range tests { - testInLoop := test t.Run(test.key, func(t *testing.T) { t.Parallel() key := New(Options{ Namespace: "/123", PartitionID: 123, }) - assert.Equal(t, testInLoop.expJobName, key.JobName([]byte(testInLoop.key))) + assert.Equal(t, test.expJobName, key.JobName([]byte(test.key))) }) } } diff --git a/internal/leadership/leadership_test.go b/internal/leadership/leadership_test.go index 8ec6f02..8406a79 100644 --- a/internal/leadership/leadership_test.go +++ b/internal/leadership/leadership_test.go @@ -17,7 +17,7 @@ import ( "github.com/diagridio/go-etcd-cron/internal/client" "github.com/diagridio/go-etcd-cron/internal/key" - "github.com/diagridio/go-etcd-cron/tests" + "github.com/diagridio/go-etcd-cron/tests/framework/etcd" ) //nolint:gocyclo @@ -27,7 +27,7 @@ func Test_Run(t *testing.T) { t.Run("Leadership should become leader and become ready", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) l := New(Options{ Client: client, PartitionTotal: 10, @@ -56,7 +56,7 @@ func Test_Run(t *testing.T) { t.Run("Running leadership multiple times should error", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) l := New(Options{ Client: client, PartitionTotal: 10, @@ -86,7 +86,7 @@ func Test_Run(t *testing.T) { t.Run("Closing the leadership should delete the accosted partition leader key", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) l := New(Options{ Client: client, PartitionTotal: 10, @@ -134,7 +134,7 @@ func Test_Run(t *testing.T) { t.Run("Closing the leadership should not delete the other partition keys", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) l := New(Options{ Client: client, PartitionTotal: 10, @@ -196,7 +196,7 @@ func Test_Run(t *testing.T) { t.Run("An existing key will gate becoming ready until deleted", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) l := New(Options{ Client: client, PartitionTotal: 10, @@ -248,7 +248,7 @@ func Test_Run(t *testing.T) { t.Run("Leadership will gate until all partition keys have the same total", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) l := New(Options{ Client: client, PartitionTotal: 10, @@ -314,7 +314,7 @@ func Test_Run(t *testing.T) { t.Run("Leadership of different partition IDs should all become leader", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) l1 := New(Options{ Client: client, PartitionTotal: 3, @@ -376,7 +376,7 @@ func Test_Run(t *testing.T) { t.Run("Two leaders of the same partition should make one passive unil the other is closed", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) l1 := New(Options{ Client: client, PartitionTotal: 1, @@ -452,7 +452,7 @@ func Test_checkLeadershipKeys(t *testing.T) { t.Run("if no leadership keys, return error", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) l := New(Options{ Client: client, PartitionTotal: 10, @@ -470,7 +470,7 @@ func Test_checkLeadershipKeys(t *testing.T) { t.Run("if all keys have the same partition total, return true", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) l := New(Options{ Client: client, PartitionTotal: 10, @@ -493,7 +493,7 @@ func Test_checkLeadershipKeys(t *testing.T) { t.Run("if some keys have the same partition total, return true", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) l := New(Options{ Client: client, PartitionTotal: 10, @@ -518,7 +518,7 @@ func Test_checkLeadershipKeys(t *testing.T) { t.Run("if some keys have the same partition total but this partition doesn't exist, return error", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) l := New(Options{ Client: client, PartitionTotal: 10, @@ -541,7 +541,7 @@ func Test_checkLeadershipKeys(t *testing.T) { t.Run("if some keys have the same partition total but some don't, return error", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) l := New(Options{ Client: client, PartitionTotal: 10, @@ -572,7 +572,7 @@ func Test_attemptPartitionLeadership(t *testing.T) { t.Run("no previous leader, expect to become leader", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) l := New(Options{ Client: client, PartitionTotal: 10, @@ -599,7 +599,7 @@ func Test_attemptPartitionLeadership(t *testing.T) { t.Run("previous leader, expect not to become leader", func(t *testing.T) { t.Parallel() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) l := New(Options{ Client: client, PartitionTotal: 10, diff --git a/internal/queue/queue.go b/internal/queue/queue.go index ab3bb9b..1c31984 100644 --- a/internal/queue/queue.go +++ b/internal/queue/queue.go @@ -9,6 +9,7 @@ import ( "context" "errors" "fmt" + "strings" "sync" "sync/atomic" @@ -66,19 +67,37 @@ type Queue struct { key *key.Key schedBuilder *scheduler.Builder - collector garbage.Interface - cache concurrency.Map[string, struct{}] yard *grave.Yard + collector garbage.Interface + + // counter cache tracks counters which are active. Used to ensure there is no + // race condition whereby a Delete operation on a job which was mid execution + // on the same scheduler instance, would not see that job as not deleted from + // the in-memory. Used to back out of an execution if it is no longer in that + // cache (deleted). + cache concurrency.Map[string, struct{}] // lock prevents an informed schedule from overwriting a job as it is being // triggered, i.e. prevent a PUT and mid-trigger race condition. lock concurrency.MutexMap[string] wg sync.WaitGroup - queue *queue.Processor[string, *counter.Counter] + queue *queue.Processor[string, counter.Interface] + + // staged are the counters that have been staged for later triggering as the + // consumer has signalled that the job is current undeliverable. When the + // consumer signals a prefix has become deliverable, counters in that prefix + // will be enqueued. Indexed by the counters Job Name. + staged map[string]counter.Interface + stagedLock sync.Mutex + + // activeConsumerPrefixes tracks the job name prefixes which are currently + // deliverable. Since consumer may indicate the same prefix is deliverable + // multiple times due to pooling, we track the length, and remove the prefix + // when the length is 0. + deliverablePrefixes map[string]*atomic.Int32 running atomic.Bool readyCh chan struct{} - errCh chan error } func New(opts Options) *Queue { @@ -88,18 +107,19 @@ func New(opts Options) *Queue { } return &Queue{ - log: opts.Log.WithName("queue"), - client: opts.Client, - key: opts.Key, - triggerFn: opts.TriggerFn, - collector: opts.Collector, - schedBuilder: opts.SchedulerBuilder, - cache: concurrency.NewMap[string, struct{}](), - yard: opts.Yard, - clock: cl, - lock: concurrency.NewMutexMap[string](), - readyCh: make(chan struct{}), - errCh: make(chan error, 10), + log: opts.Log.WithName("queue"), + client: opts.Client, + key: opts.Key, + triggerFn: opts.TriggerFn, + collector: opts.Collector, + schedBuilder: opts.SchedulerBuilder, + cache: concurrency.NewMap[string, struct{}](), + yard: opts.Yard, + clock: cl, + deliverablePrefixes: make(map[string]*atomic.Int32), + staged: make(map[string]counter.Interface), + lock: concurrency.NewMutexMap[string](), + readyCh: make(chan struct{}), } } @@ -110,8 +130,8 @@ func (q *Queue) Run(ctx context.Context) error { return errors.New("queue is already running") } - q.queue = queue.NewProcessor[string, *counter.Counter]( - func(counter *counter.Counter) { + q.queue = queue.NewProcessor[string, counter.Interface]( + func(counter counter.Interface) { q.lock.RLock(counter.Key()) _, ok := q.cache.Load(counter.Key()) if !ok || ctx.Err() != nil { @@ -122,52 +142,49 @@ func (q *Queue) Run(ctx context.Context) error { q.wg.Add(1) go func() { defer q.wg.Done() - if q.handleTrigger(ctx, counter) { - q.lock.RUnlock(counter.Key()) - } else { + if !q.handleTrigger(ctx, counter) { q.cache.Delete(counter.Key()) q.lock.DeleteRUnlock(counter.Key()) + return } + + q.lock.RUnlock(counter.Key()) }() }, ).WithClock(q.clock) close(q.readyCh) - var err error - select { - case <-ctx.Done(): - case err = <-q.errCh: - if errors.Is(err, queue.ErrProcessorStopped) { - err = nil - } - } + <-ctx.Done() - return errors.Join(q.queue.Close(), err) + return q.queue.Close() } -func (q *Queue) Delete(ctx context.Context, name string) error { +func (q *Queue) Delete(ctx context.Context, jobName string) error { select { case <-ctx.Done(): return ctx.Err() case <-q.readyCh: } - key := q.key.JobKey(name) + key := q.key.JobKey(jobName) q.lock.Lock(key) defer q.lock.DeleteUnlock(key) + q.stagedLock.Lock() + delete(q.staged, jobName) + q.stagedLock.Unlock() + if _, err := q.client.Delete(ctx, key); err != nil { return err } - if _, ok := q.cache.Load(key); !ok { - return nil + if _, ok := q.cache.LoadAndDelete(key); ok { + q.queue.Dequeue(key) } - q.cache.Delete(key) - return q.queue.Dequeue(key) + return nil } func (q *Queue) DeletePrefixes(ctx context.Context, prefixes ...string) error { @@ -189,8 +206,16 @@ func (q *Queue) DeletePrefixes(ctx context.Context, prefixes ...string) error { } for _, kv := range resp.PrevKvs { - errs = append(errs, q.cacheDelete(string(kv.Key))) + q.cacheDelete(string(kv.Key)) } + + q.stagedLock.Lock() + for jobName := range q.staged { + if strings.HasPrefix(jobName, prefix) { + delete(q.staged, jobName) + } + } + q.stagedLock.Unlock() } return errors.Join(errs...) @@ -208,7 +233,7 @@ func (q *Queue) HandleInformerEvent(ctx context.Context, e *informer.Event) erro return ctx.Err() } - if err := q.scheduleEvent(ctx, e); !errors.Is(err, queue.ErrProcessorStopped) { + if err := q.scheduleEvent(ctx, e); err != nil { return err } @@ -217,63 +242,89 @@ func (q *Queue) HandleInformerEvent(ctx context.Context, e *informer.Event) erro func (q *Queue) scheduleEvent(ctx context.Context, e *informer.Event) error { q.lock.Lock(string(e.Key)) + + jobName := q.key.JobName(e.Key) + + q.stagedLock.Lock() + delete(q.staged, jobName) + q.stagedLock.Unlock() + if e.IsPut { defer q.lock.Unlock(string(e.Key)) - return q.schedule(ctx, q.key.JobName(e.Key), e.Job) + return q.schedule(ctx, jobName, e.Job) } defer q.lock.DeleteUnlock(string(e.Key)) q.cache.Delete(string(e.Key)) - q.collector.Push(q.key.CounterKey(q.key.JobName(e.Key))) - return q.queue.Dequeue(string(e.Key)) + q.collector.Push(q.key.CounterKey(jobName)) + q.queue.Dequeue(string(e.Key)) + + return nil } -func (q *Queue) cacheDelete(jobKey string) error { +func (q *Queue) cacheDelete(jobKey string) { q.lock.Lock(jobKey) defer q.lock.DeleteUnlock(jobKey) - if _, ok := q.cache.Load(jobKey); ok { - return nil + if _, ok := q.cache.Load(jobKey); !ok { + return } q.cache.Delete(jobKey) - return q.queue.Dequeue(jobKey) + q.queue.Dequeue(jobKey) } // handleTrigger handles triggering a schedule job. // Returns true if the job is being re-enqueued, false otherwise. -func (q *Queue) handleTrigger(ctx context.Context, counter *counter.Counter) bool { - if !q.triggerFn(ctx, counter.TriggerRequest()) { - ok, err := counter.TriggerFailed(ctx) +func (q *Queue) handleTrigger(ctx context.Context, counter counter.Interface) bool { + result := q.triggerFn(ctx, counter.TriggerRequest()).GetResult() + if ctx.Err() != nil { + return false + } + + switch result { + // Job was successfully triggered. Re-enqueue if the Job has more triggers + // according to the schedule. + case api.TriggerResponseResult_SUCCESS: + ok, err := counter.TriggerSuccess(ctx) if err != nil { - q.log.Error(err, "failure failing job for next retry trigger", "name", counter.Key()) + q.log.Error(err, "failure marking job for next trigger", "name", counter.Key()) } - return q.enqueueCounter(ctx, counter, ok) - } - - ok, err := counter.TriggerSuccess(ctx) - if err != nil { - q.log.Error(err, "failure marking job for next trigger", "name", counter.Key()) - } + if ok { + q.queue.Enqueue(counter) + } - return q.enqueueCounter(ctx, counter, ok) -} + return ok -// enqueueCounter enqueues the job to the queue at this count tick. -func (q *Queue) enqueueCounter(ctx context.Context, counter *counter.Counter, ok bool) bool { - if ok && ctx.Err() == nil { - if err := q.queue.Enqueue(counter); err != nil { - select { - case <-ctx.Done(): - case q.errCh <- err: - } + // The Job failed to trigger. Re-enqueue if the Job has more trigger + // attempts according to FailurePolicy, or the Job has more triggers + // according to the schedule. + case api.TriggerResponseResult_FAILED: + ok, err := counter.TriggerFailed(ctx) + if err != nil { + q.log.Error(err, "failure failing job for next retry trigger", "name", counter.Key()) } + if ok { + q.queue.Enqueue(counter) + } + return ok + + // The Job was undeliverable so will be moved to the staging queue where it + // will stay until it become deliverable. Due to a race, if the job is in + // fact now deliverable, we need to re-enqueue immediately, else simply + // keep it in staging until the prefix is deliverable. + case api.TriggerResponseResult_UNDELIVERABLE: + if !q.stage(counter) { + q.queue.Enqueue(counter) + } return true - } - return false + default: + q.log.Error(errors.New("unknown trigger response result"), "unknown trigger response result", "name", counter.Key(), "result", result) + return false + } } // schedule schedules a job to it's next scheduled time. @@ -303,5 +354,6 @@ func (q *Queue) schedule(ctx context.Context, name string, job *stored.Job) erro } q.cache.Store(counter.Key(), struct{}{}) - return q.queue.Enqueue(counter) + q.queue.Enqueue(counter) + return nil } diff --git a/internal/queue/queue_test.go b/internal/queue/queue_test.go index a81c819..4ebc405 100644 --- a/internal/queue/queue_test.go +++ b/internal/queue/queue_test.go @@ -26,18 +26,18 @@ import ( "github.com/diagridio/go-etcd-cron/internal/informer" "github.com/diagridio/go-etcd-cron/internal/key" "github.com/diagridio/go-etcd-cron/internal/scheduler" - "github.com/diagridio/go-etcd-cron/tests" + "github.com/diagridio/go-etcd-cron/tests/framework/etcd" ) func Test_delete_race(t *testing.T) { t.Parallel() triggered := make([]atomic.Int64, 20) - queue := newQueue(t, func(_ context.Context, req *api.TriggerRequest) bool { + queue := newQueue(t, func(_ context.Context, req *api.TriggerRequest) *api.TriggerResponse { i, err := strconv.Atoi(req.GetName()) require.NoError(t, err) triggered[i].Add(1) - return true + return &api.TriggerResponse{Result: api.TriggerResponseResult_SUCCESS} }) jobKeys := make([]string, 20) @@ -87,7 +87,7 @@ func Test_delete_race(t *testing.T) { func newQueue(t *testing.T, triggerFn api.TriggerFunction) *Queue { t.Helper() - client := tests.EmbeddedETCD(t) + client := etcd.Embedded(t) collector, err := garbage.New(garbage.Options{Client: client}) require.NoError(t, err) diff --git a/internal/queue/staging.go b/internal/queue/staging.go new file mode 100644 index 0000000..5ef160c --- /dev/null +++ b/internal/queue/staging.go @@ -0,0 +1,80 @@ +/* +Copyright (c) 2024 Diagrid Inc. +Licensed under the MIT License. +*/ + +package queue + +import ( + "context" + "strings" + "sync/atomic" + + "github.com/diagridio/go-etcd-cron/internal/counter" +) + +// DeliverablePrefixes adds the job name prefixes that can currently be +// delivered by the consumer. When the returned `CancelFunc` is called, the +// prefixes registered are released indicating that these prefixes can no +// longer be delivered. Multiple of the same prefix can be added and are +// tracked as a pool, meaning the prefix is still active if at least one +// instance is still registered. +func (q *Queue) DeliverablePrefixes(prefixes ...string) context.CancelFunc { + q.stagedLock.Lock() + defer q.stagedLock.Unlock() + + var toEnqueue []counter.Interface + for _, prefix := range prefixes { + if _, ok := q.deliverablePrefixes[prefix]; !ok { + q.deliverablePrefixes[prefix] = new(atomic.Int32) + + for jobName, stage := range q.staged { + if strings.HasPrefix(jobName, prefix) { + toEnqueue = append(toEnqueue, stage) + delete(q.staged, jobName) + } + } + } + + q.deliverablePrefixes[prefix].Add(1) + } + + for _, counter := range toEnqueue { + q.queue.Enqueue(counter) + } + + return func() { + q.stagedLock.Lock() + defer q.stagedLock.Unlock() + + for _, prefix := range prefixes { + if i, ok := q.deliverablePrefixes[prefix]; ok { + if i.Add(-1) <= 0 { + delete(q.deliverablePrefixes, prefix) + } + } + } + } +} + +// stage adds the counter (job) to the staging queue. Accounting for race +// conditions, returns false if the counter can actually be delivered now based +// on the current deliverable prefixes and should be immediately re-queued at +// the current count. +func (q *Queue) stage(counter counter.Interface) bool { + q.stagedLock.Lock() + defer q.stagedLock.Unlock() + + jobName := counter.JobName() + + // Check if the job is actually now deliverable. + for prefix := range q.deliverablePrefixes { + if strings.HasPrefix(jobName, prefix) { + return false + } + } + + q.staged[jobName] = counter + + return true +} diff --git a/internal/queue/staging_test.go b/internal/queue/staging_test.go new file mode 100644 index 0000000..98f5571 --- /dev/null +++ b/internal/queue/staging_test.go @@ -0,0 +1,199 @@ +/* +Copyright (c) 2024 Diagrid Inc. +Licensed under the MIT License. +*/ + +package queue + +import ( + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/dapr/kit/events/queue" + "github.com/stretchr/testify/assert" + + "github.com/diagridio/go-etcd-cron/internal/counter" + "github.com/diagridio/go-etcd-cron/internal/counter/fake" +) + +func Test_DeliverablePrefixes(t *testing.T) { + t.Parallel() + + t.Run("registering empty prefixes should add nothing", func(t *testing.T) { + t.Parallel() + + q := &Queue{deliverablePrefixes: make(map[string]*atomic.Int32)} + assert.Empty(t, q.deliverablePrefixes) + + cancel := q.DeliverablePrefixes() + assert.Empty(t, q.deliverablePrefixes) + cancel() + assert.Empty(t, q.deliverablePrefixes) + }) + + t.Run("registering and cancelling should add then remove the prefix", func(t *testing.T) { + t.Parallel() + + q := &Queue{deliverablePrefixes: make(map[string]*atomic.Int32)} + assert.Empty(t, q.deliverablePrefixes) + + cancel := q.DeliverablePrefixes("abc") + assert.Len(t, q.deliverablePrefixes, 1) + cancel() + assert.Empty(t, q.deliverablePrefixes) + }) + + t.Run("multiple: registering and cancelling should add then remove the prefix", func(t *testing.T) { + t.Parallel() + + q := &Queue{deliverablePrefixes: make(map[string]*atomic.Int32)} + assert.Empty(t, q.deliverablePrefixes) + + cancel1 := q.DeliverablePrefixes("abc") + assert.Len(t, q.deliverablePrefixes, 1) + cancel2 := q.DeliverablePrefixes("abc") + assert.Len(t, q.deliverablePrefixes, 1) + + cancel1() + assert.Len(t, q.deliverablePrefixes, 1) + cancel2() + assert.Empty(t, q.deliverablePrefixes) + }) + + t.Run("multiple with diff prefixes: registering and cancelling should add then remove the prefix", func(t *testing.T) { + t.Parallel() + + q := &Queue{deliverablePrefixes: make(map[string]*atomic.Int32)} + assert.Empty(t, q.deliverablePrefixes) + + cancel1 := q.DeliverablePrefixes("abc") + assert.Len(t, q.deliverablePrefixes, 1) + cancel2 := q.DeliverablePrefixes("abc") + assert.Len(t, q.deliverablePrefixes, 1) + cancel3 := q.DeliverablePrefixes("def") + assert.Len(t, q.deliverablePrefixes, 2) + cancel4 := q.DeliverablePrefixes("def") + assert.Len(t, q.deliverablePrefixes, 2) + + cancel1() + assert.Len(t, q.deliverablePrefixes, 2) + cancel4() + assert.Len(t, q.deliverablePrefixes, 2) + cancel2() + assert.Len(t, q.deliverablePrefixes, 1) + cancel3() + assert.Empty(t, q.deliverablePrefixes) + }) + + t.Run("staged counters should be enqueued if they match an added prefix", func(t *testing.T) { + t.Parallel() + + var lock sync.Mutex + var triggered []string + q := &Queue{ + deliverablePrefixes: make(map[string]*atomic.Int32), + staged: make(map[string]counter.Interface), + queue: queue.NewProcessor[string, counter.Interface]( + func(counter counter.Interface) { + lock.Lock() + defer lock.Unlock() + triggered = append(triggered, counter.JobName()) + }, + ), + } + + counter1 := fake.New().WithJobName("abc123").WithKey("abc123") + counter2 := fake.New().WithJobName("abc234").WithKey("abc234") + counter3 := fake.New().WithJobName("def123").WithKey("def123") + counter4 := fake.New().WithJobName("def234").WithKey("def234") + counter5 := fake.New().WithJobName("xyz123").WithKey("xyz123") + counter6 := fake.New().WithJobName("xyz234").WithKey("xyz234") + q.staged = map[string]counter.Interface{ + "abc123": counter1, "abc234": counter2, + "def123": counter3, "def234": counter4, + "xyz123": counter5, "xyz234": counter6, + } + + cancel := q.DeliverablePrefixes("abc", "xyz") + t.Cleanup(cancel) + assert.Equal(t, map[string]counter.Interface{"def123": counter3, "def234": counter4}, q.staged) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + lock.Lock() + defer lock.Unlock() + assert.ElementsMatch(c, []string{"abc123", "abc234", "xyz123", "xyz234"}, triggered) + }, time.Second*10, time.Millisecond*10) + + cancel = q.DeliverablePrefixes("d") + t.Cleanup(cancel) + assert.Empty(t, q.staged) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + lock.Lock() + defer lock.Unlock() + assert.ElementsMatch(c, []string{"abc123", "abc234", "xyz123", "xyz234", "def123", "def234"}, triggered) + }, time.Second*10, time.Millisecond*10) + }) +} + +func Test_stage(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + jobName string + deliverablePrefixes []string + expStaged bool + }{ + "no deliverable prefixes, should stage": { + jobName: "abc123", + deliverablePrefixes: []string{}, + expStaged: true, + }, + "deliverable prefixes but different, should stage": { + jobName: "abc123", + deliverablePrefixes: []string{"def", "cba"}, + expStaged: true, + }, + "deliverable prefixes and matches, should not stage": { + jobName: "abc123", + deliverablePrefixes: []string{"abc123"}, + expStaged: false, + }, + "multiple deliverable prefixes and matches, should not stage": { + jobName: "abc123", + deliverablePrefixes: []string{"def", "abc123", "cba"}, + expStaged: false, + }, + "multiple deliverable prefixes and matches on prefix, should not stage": { + jobName: "abc123", + deliverablePrefixes: []string{"def", "cba", "abc"}, + expStaged: false, + }, + "multiple deliverable prefixes and not matches on prefix, should stage": { + jobName: "abc123", + deliverablePrefixes: []string{"def", "cba", "abc1234"}, + expStaged: true, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + + q := &Queue{ + deliverablePrefixes: make(map[string]*atomic.Int32), + staged: make(map[string]counter.Interface), + } + + for _, prefix := range test.deliverablePrefixes { + q.deliverablePrefixes[prefix] = new(atomic.Int32) + q.deliverablePrefixes[prefix].Add(1) + } + + got := q.stage(fake.New().WithJobName(test.jobName)) + + assert.Equal(t, test.expStaged, got) + assert.Equal(t, test.expStaged, len(q.staged) == 1) + }) + } +} diff --git a/proto/api/trigger.proto b/proto/api/trigger.proto index c11b0ad..e956fc4 100644 --- a/proto/api/trigger.proto +++ b/proto/api/trigger.proto @@ -22,3 +22,28 @@ message TriggerRequest { // payload is the job payload. google.protobuf.Any payload = 3; } + +// TriggerResponseResult is indicates the state result from triggering the job +// by the consumer. +enum TriggerResponseResult { + // SUCCESS indicates that the job was successfully triggered and will be + // ticked forward according to the schedule. + SUCCESS = 0; + + // FAILED indicates that the job failed to trigger and is subject to the + // FailurePolicy. + FAILED = 1; + + // UNDELIVERABLE indicates that the job should be added to the staging queue + // as the Job was undeliverable. Once the Job name prefix is marked as + // deliverable, it will be immediately triggered. + UNDELIVERABLE = 2; +} + +// TriggerResponse is returned by the caller from a TriggerResponse. Signals +// whether the Job was successfully triggered, the trigger failed, or instead +// needs to be added to the staging queue due to impossible delivery. +message TriggerResponse { + // result is the result given by the consumer when trigging the Job. + TriggerResponseResult result = 1; +} diff --git a/tests/cron/cluster.go b/tests/framework/cron/cluster.go similarity index 84% rename from tests/cron/cluster.go rename to tests/framework/cron/cluster.go index 8e4eb6f..44e985a 100644 --- a/tests/cron/cluster.go +++ b/tests/framework/cron/cluster.go @@ -8,7 +8,7 @@ package cron import ( "testing" - "github.com/diagridio/go-etcd-cron/tests" + "github.com/diagridio/go-etcd-cron/tests/framework/etcd" ) type Cluster struct { @@ -18,7 +18,7 @@ type Cluster struct { func TripplePartition(t *testing.T) *Cluster { t.Helper() - client := tests.EmbeddedETCDBareClient(t) + client := etcd.EmbeddedBareClient(t) cr1 := newCron(t, client, 3, 0) cr2 := newCron(t, client, 3, 1) cr3 := newCron(t, client, 3, 2) diff --git a/tests/cron/cron.go b/tests/framework/cron/cron.go similarity index 82% rename from tests/cron/cron.go rename to tests/framework/cron/cron.go index aa55554..9581a5d 100644 --- a/tests/cron/cron.go +++ b/tests/framework/cron/cron.go @@ -31,12 +31,15 @@ func newCron(t *testing.T, client *clientv3.Client, total, id uint32) *Cron { var calls atomic.Int64 cron, err := cron.New(cron.Options{ - Log: logr.Discard(), - Client: client, - Namespace: "abc", - PartitionID: id, - PartitionTotal: total, - TriggerFn: func(context.Context, *api.TriggerRequest) bool { calls.Add(1); return true }, + Log: logr.Discard(), + Client: client, + Namespace: "abc", + PartitionID: id, + PartitionTotal: total, + TriggerFn: func(context.Context, *api.TriggerRequest) *api.TriggerResponse { + calls.Add(1) + return &api.TriggerResponse{Result: api.TriggerResponseResult_SUCCESS} + }, CounterGarbageCollectionInterval: ptr.Of(time.Millisecond * 300), }) require.NoError(t, err) diff --git a/tests/framework/cron/integration/integration.go b/tests/framework/cron/integration/integration.go new file mode 100644 index 0000000..88b5b8d --- /dev/null +++ b/tests/framework/cron/integration/integration.go @@ -0,0 +1,141 @@ +/* +Copyright (c) 2024 Diagrid Inc. +Licensed under the MIT License. +*/ + +package integration + +import ( + "context" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/dapr/kit/ptr" + "github.com/go-logr/logr" + "github.com/stretchr/testify/require" + clientv3 "go.etcd.io/etcd/client/v3" + + "github.com/diagridio/go-etcd-cron/api" + "github.com/diagridio/go-etcd-cron/cron" + "github.com/diagridio/go-etcd-cron/internal/client" + "github.com/diagridio/go-etcd-cron/tests/framework/etcd" +) + +type Options struct { + PartitionTotal uint32 + GotCh chan *api.TriggerRequest + TriggerFn func(*api.TriggerRequest) *api.TriggerResponse + Client *clientv3.Client +} + +type Integration struct { + ctx context.Context + closeCron func() + client client.Interface + api api.Interface + allCrons []api.Interface + triggered *atomic.Int64 +} + +func NewBase(t *testing.T, partitionTotal uint32) *Integration { + t.Helper() + return New(t, Options{ + PartitionTotal: partitionTotal, + }) +} + +func New(t *testing.T, opts Options) *Integration { + t.Helper() + + require.Positive(t, opts.PartitionTotal) + cl := opts.Client + if cl == nil { + cl = etcd.EmbeddedBareClient(t) + } + + var triggered atomic.Int64 + var a api.Interface + allCrns := make([]api.Interface, opts.PartitionTotal) + for i := range opts.PartitionTotal { + c, err := cron.New(cron.Options{ + Log: logr.Discard(), + Client: cl, + Namespace: "abc", + PartitionID: i, + PartitionTotal: opts.PartitionTotal, + TriggerFn: func(_ context.Context, req *api.TriggerRequest) *api.TriggerResponse { + defer triggered.Add(1) + if opts.GotCh != nil { + opts.GotCh <- req + } + if opts.TriggerFn != nil { + return opts.TriggerFn(req) + } + return &api.TriggerResponse{Result: api.TriggerResponseResult_SUCCESS} + }, + + CounterGarbageCollectionInterval: ptr.Of(time.Millisecond * 300), + }) + require.NoError(t, err) + allCrns[i] = c + if i == 0 { + a = c + } + } + + errCh := make(chan error, opts.PartitionTotal) + ctx, cancel := context.WithCancel(context.Background()) + + closeOnce := sync.OnceFunc(func() { + cancel() + for range opts.PartitionTotal { + select { + case err := <-errCh: + require.NoError(t, err) + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for cron to stop") + } + } + }) + t.Cleanup(closeOnce) + for i := range opts.PartitionTotal { + go func(i uint32) { + errCh <- allCrns[i].Run(ctx) + }(i) + } + + return &Integration{ + ctx: ctx, + client: client.New(client.Options{Client: cl, Log: logr.Discard()}), + api: a, + allCrons: allCrns, + triggered: &triggered, + closeCron: closeOnce, + } +} + +func (i *Integration) Context() context.Context { + return i.ctx +} + +func (i *Integration) Client() client.Interface { + return i.client +} + +func (i *Integration) API() api.Interface { + return i.api +} + +func (i *Integration) AllCrons() []api.Interface { + return i.allCrons +} + +func (i *Integration) Triggered() int { + return int(i.triggered.Load()) +} + +func (i *Integration) Close() { + i.closeCron() +} diff --git a/tests/cron/single.go b/tests/framework/cron/single.go similarity index 70% rename from tests/cron/single.go rename to tests/framework/cron/single.go index 4240268..d45bdee 100644 --- a/tests/cron/single.go +++ b/tests/framework/cron/single.go @@ -8,12 +8,12 @@ package cron import ( "testing" - "github.com/diagridio/go-etcd-cron/tests" + "github.com/diagridio/go-etcd-cron/tests/framework/etcd" ) func SinglePartition(t *testing.T) *Cron { t.Helper() - return newCron(t, tests.EmbeddedETCDBareClient(t), 1, 0) + return newCron(t, etcd.EmbeddedBareClient(t), 1, 0) } func SinglePartitionRun(t *testing.T) *Cron { diff --git a/tests/tests.go b/tests/framework/etcd/etcd.go similarity index 86% rename from tests/tests.go rename to tests/framework/etcd/etcd.go index aa817bf..7d9eb8f 100644 --- a/tests/tests.go +++ b/tests/framework/etcd/etcd.go @@ -3,7 +3,7 @@ Copyright (c) 2024 Diagrid Inc. Licensed under the MIT License. */ -package tests +package etcd import ( "net/url" @@ -19,15 +19,15 @@ import ( "github.com/diagridio/go-etcd-cron/internal/client" ) -func EmbeddedETCD(t *testing.T) client.Interface { +func Embedded(t *testing.T) client.Interface { t.Helper() return client.New(client.Options{ Log: logr.Discard(), - Client: EmbeddedETCDBareClient(t), + Client: EmbeddedBareClient(t), }) } -func EmbeddedETCDBareClient(t *testing.T) *clientv3.Client { +func EmbeddedBareClient(t *testing.T) *clientv3.Client { t.Helper() cfg := embed.NewConfig() diff --git a/tests/fake/fake.go b/tests/framework/fake/fake.go similarity index 81% rename from tests/fake/fake.go rename to tests/framework/fake/fake.go index 239c1d2..8261167 100644 --- a/tests/fake/fake.go +++ b/tests/framework/fake/fake.go @@ -19,6 +19,8 @@ type Fake struct { delFn func(ctx context.Context, name string) error delPFn func(ctx context.Context, prefixes ...string) error listFn func(ctx context.Context, prefix string) (*api.ListResponse, error) + + deliverablePrefixesFn func(ctx context.Context, prefixes ...string) (context.CancelFunc, error) } func New() *Fake { @@ -42,6 +44,9 @@ func New() *Fake { listFn: func(context.Context, string) (*api.ListResponse, error) { return nil, nil }, + deliverablePrefixesFn: func(context.Context, ...string) (context.CancelFunc, error) { + return func() {}, nil + }, } } @@ -75,6 +80,11 @@ func (f *Fake) WithList(fn func(context.Context, string) (*api.ListResponse, err return f } +func (f *Fake) WithDeliverablePrefixes(fn func(context.Context, ...string) (context.CancelFunc, error)) *Fake { + f.deliverablePrefixesFn = fn + return f +} + func (f *Fake) Run(ctx context.Context) error { return f.runFn(ctx) } @@ -98,3 +108,7 @@ func (f *Fake) DeletePrefixes(ctx context.Context, prefixes ...string) error { func (f *Fake) List(ctx context.Context, prefix string) (*api.ListResponse, error) { return f.listFn(ctx, prefix) } + +func (f *Fake) DeliverablePrefixes(ctx context.Context, prefixes ...string) (context.CancelFunc, error) { + return f.deliverablePrefixesFn(ctx, prefixes...) +} diff --git a/tests/fake/fake_test.go b/tests/framework/fake/fake_test.go similarity index 80% rename from tests/fake/fake_test.go rename to tests/framework/fake/fake_test.go index 1321430..48be5c4 100644 --- a/tests/fake/fake_test.go +++ b/tests/framework/fake/fake_test.go @@ -9,7 +9,7 @@ import ( "testing" "github.com/diagridio/go-etcd-cron/api" - "github.com/diagridio/go-etcd-cron/tests/fake" + "github.com/diagridio/go-etcd-cron/tests/framework/fake" ) func Test_Fake(t *testing.T) { diff --git a/tests/suite/failurepolicy_test.go b/tests/suite/failurepolicy_test.go new file mode 100644 index 0000000..6735e39 --- /dev/null +++ b/tests/suite/failurepolicy_test.go @@ -0,0 +1,203 @@ +/* +Copyright (c) 2024 Diagrid Inc. +Licensed under the MIT License. +*/ + +package suite + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/dapr/kit/ptr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/durationpb" + + "github.com/diagridio/go-etcd-cron/api" + "github.com/diagridio/go-etcd-cron/tests/framework/cron/integration" + "github.com/diagridio/go-etcd-cron/tests/framework/etcd" +) + +func Test_FailurePolicy(t *testing.T) { + t.Parallel() + + t.Run("default policy should retry 3 times with a 1sec delay", func(t *testing.T) { + t.Parallel() + + gotCh := make(chan *api.TriggerRequest, 1) + var got atomic.Uint32 + cron := integration.New(t, integration.Options{ + PartitionTotal: 1, + Client: etcd.EmbeddedBareClient(t), + TriggerFn: func(*api.TriggerRequest) *api.TriggerResponse { + assert.GreaterOrEqual(t, uint32(8), got.Add(1)) + return &api.TriggerResponse{Result: api.TriggerResponseResult_FAILED} + }, + GotCh: gotCh, + }) + + require.NoError(t, cron.API().Add(context.Background(), "test", &api.Job{ + DueTime: ptr.Of(time.Now().Format(time.RFC3339)), + Schedule: ptr.Of("@every 1s"), + Repeats: ptr.Of(uint32(2)), + })) + + for range 8 { + resp, err := cron.API().Get(context.Background(), "test") + require.NoError(t, err) + assert.NotNil(t, resp) + select { + case <-gotCh: + case <-time.After(time.Second * 3): + assert.Fail(t, "timeout waiting for trigger") + } + } + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + resp, err := cron.API().Get(context.Background(), "test") + assert.NoError(c, err) + assert.Nil(c, resp) + }, time.Second*5, time.Millisecond*10) + }) + + t.Run("drop policy should not retry triggering", func(t *testing.T) { + t.Parallel() + + gotCh := make(chan *api.TriggerRequest, 1) + var got atomic.Uint32 + cron := integration.New(t, integration.Options{ + PartitionTotal: 1, + Client: etcd.EmbeddedBareClient(t), + TriggerFn: func(*api.TriggerRequest) *api.TriggerResponse { + assert.GreaterOrEqual(t, uint32(2), got.Add(1)) + return &api.TriggerResponse{Result: api.TriggerResponseResult_FAILED} + }, + GotCh: gotCh, + }) + + require.NoError(t, cron.API().Add(context.Background(), "test", &api.Job{ + DueTime: ptr.Of(time.Now().Format(time.RFC3339)), + Schedule: ptr.Of("@every 1s"), + Repeats: ptr.Of(uint32(2)), + FailurePolicy: &api.FailurePolicy{ + Policy: new(api.FailurePolicy_Drop), + }, + })) + + for range 2 { + resp, err := cron.API().Get(context.Background(), "test") + require.NoError(t, err) + assert.NotNil(t, resp) + select { + case <-gotCh: + case <-time.After(time.Second * 3): + assert.Fail(t, "timeout waiting for trigger") + } + } + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + resp, err := cron.API().Get(context.Background(), "test") + assert.NoError(c, err) + assert.Nil(c, resp) + }, time.Second*5, time.Millisecond*10) + }) + + t.Run("constant policy should only retry when it fails ", func(t *testing.T) { + t.Parallel() + + gotCh := make(chan *api.TriggerRequest, 1) + var got atomic.Uint32 + cron := integration.New(t, integration.Options{ + PartitionTotal: 1, + Client: etcd.EmbeddedBareClient(t), + TriggerFn: func(*api.TriggerRequest) *api.TriggerResponse { + assert.GreaterOrEqual(t, uint32(5), got.Add(1)) + if got.Load() == 3 { + return &api.TriggerResponse{Result: api.TriggerResponseResult_SUCCESS} + } + return &api.TriggerResponse{Result: api.TriggerResponseResult_FAILED} + }, + GotCh: gotCh, + }) + + require.NoError(t, cron.API().Add(context.Background(), "test", &api.Job{ + DueTime: ptr.Of(time.Now().Format(time.RFC3339)), + Schedule: ptr.Of("@every 1s"), + Repeats: ptr.Of(uint32(3)), + FailurePolicy: &api.FailurePolicy{ + Policy: &api.FailurePolicy_Constant{ + Constant: &api.FailurePolicyConstant{ + Interval: durationpb.New(time.Millisecond), MaxRetries: ptr.Of(uint32(1)), + }, + }, + }, + })) + + for range 5 { + resp, err := cron.API().Get(context.Background(), "test") + require.NoError(t, err) + assert.NotNil(t, resp) + select { + case <-gotCh: + case <-time.After(time.Second * 3): + assert.Fail(t, "timeout waiting for trigger") + } + } + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + resp, err := cron.API().Get(context.Background(), "test") + assert.NoError(c, err) + assert.Nil(c, resp) + }, time.Second*5, time.Millisecond*10) + }) + + t.Run("constant policy can retry forever until it succeeds", func(t *testing.T) { + t.Parallel() + + gotCh := make(chan *api.TriggerRequest, 1) + var got atomic.Uint32 + cron := integration.New(t, integration.Options{ + PartitionTotal: 1, + Client: etcd.EmbeddedBareClient(t), + TriggerFn: func(*api.TriggerRequest) *api.TriggerResponse { + assert.GreaterOrEqual(t, uint32(100), got.Add(1)) + if got.Load() == 100 { + return &api.TriggerResponse{Result: api.TriggerResponseResult_SUCCESS} + } + return &api.TriggerResponse{Result: api.TriggerResponseResult_FAILED} + }, + GotCh: gotCh, + }) + + require.NoError(t, cron.API().Add(context.Background(), "test", &api.Job{ + DueTime: ptr.Of(time.Now().Format(time.RFC3339)), + FailurePolicy: &api.FailurePolicy{ + Policy: &api.FailurePolicy_Constant{ + Constant: &api.FailurePolicyConstant{ + Interval: durationpb.New(time.Millisecond), + }, + }, + }, + })) + + for range 100 { + resp, err := cron.API().Get(context.Background(), "test") + require.NoError(t, err) + assert.NotNil(t, resp) + select { + case <-gotCh: + case <-time.After(time.Second * 3): + assert.Fail(t, "timeout waiting for trigger") + } + } + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + resp, err := cron.API().Get(context.Background(), "test") + assert.NoError(c, err) + assert.Nil(c, resp) + }, time.Second*5, time.Millisecond*10) + }) +} diff --git a/tests/suite/jobwithspace_test.go b/tests/suite/jobwithspace_test.go new file mode 100644 index 0000000..c0cb7f0 --- /dev/null +++ b/tests/suite/jobwithspace_test.go @@ -0,0 +1,56 @@ +/* +Copyright (c) 2024 Diagrid Inc. +Licensed under the MIT License. +*/ + +package suite + +import ( + "context" + "testing" + "time" + + "github.com/dapr/kit/ptr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/diagridio/go-etcd-cron/api" + "github.com/diagridio/go-etcd-cron/tests/framework/cron/integration" +) + +func Test_jobWithSpace(t *testing.T) { + t.Parallel() + + cron := integration.NewBase(t, 1) + + require.NoError(t, cron.API().Add(context.Background(), "hello world", &api.Job{ + DueTime: ptr.Of(time.Now().Add(2).Format(time.RFC3339)), + })) + resp, err := cron.API().Get(context.Background(), "hello world") + require.NoError(t, err) + assert.NotNil(t, resp) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, 1, cron.Triggered()) + resp, err = cron.API().Get(context.Background(), "hello world") + assert.NoError(c, err) + assert.Nil(c, resp) + }, time.Second*10, time.Millisecond*10) + + require.NoError(t, cron.API().Add(context.Background(), "another hello world", &api.Job{ + Schedule: ptr.Of("@every 1s"), + })) + resp, err = cron.API().Get(context.Background(), "another hello world") + require.NoError(t, err) + assert.NotNil(t, resp) + listresp, err := cron.API().List(context.Background(), "") + require.NoError(t, err) + assert.Len(t, listresp.GetJobs(), 1) + require.NoError(t, cron.API().Delete(context.Background(), "another hello world")) + resp, err = cron.API().Get(context.Background(), "another hello world") + require.NoError(t, err) + assert.Nil(t, resp) + listresp, err = cron.API().List(context.Background(), "") + require.NoError(t, err) + assert.Empty(t, listresp.GetJobs()) +} diff --git a/tests/suite/oneshot_test.go b/tests/suite/oneshot_test.go new file mode 100644 index 0000000..c7d5a20 --- /dev/null +++ b/tests/suite/oneshot_test.go @@ -0,0 +1,39 @@ +/* +Copyright (c) 2024 Diagrid Inc. +Licensed under the MIT License. +*/ + +package suite + +import ( + "context" + "testing" + "time" + + "github.com/dapr/kit/ptr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/diagridio/go-etcd-cron/api" + "github.com/diagridio/go-etcd-cron/tests/framework/cron/integration" +) + +func Test_oneshot(t *testing.T) { + t.Parallel() + + cron := integration.NewBase(t, 1) + + job := &api.Job{ + DueTime: ptr.Of(time.Now().Add(time.Second).Format(time.RFC3339)), + } + + require.NoError(t, cron.API().Add(cron.Context(), "def", job)) + + assert.Eventually(t, func() bool { + return cron.Triggered() == 1 + }, 5*time.Second, 1*time.Second) + + resp, err := cron.Client().Get(context.Background(), "abc/jobs/def") + require.NoError(t, err) + assert.Empty(t, resp.Kvs) +} diff --git a/tests/suite/parallel_test.go b/tests/suite/parallel_test.go new file mode 100644 index 0000000..482d3c6 --- /dev/null +++ b/tests/suite/parallel_test.go @@ -0,0 +1,64 @@ +/* +Copyright (c) 2024 Diagrid Inc. +Licensed under the MIT License. +*/ + +package suite + +import ( + "strconv" + "sync/atomic" + "testing" + "time" + + "github.com/dapr/kit/ptr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/diagridio/go-etcd-cron/api" + "github.com/diagridio/go-etcd-cron/tests/framework/cron/integration" +) + +func Test_parallel(t *testing.T) { + t.Parallel() + + for _, test := range []struct { + name string + total uint32 + }{ + {"1 queue", 1}, + {"multi queue", 50}, + } { + total := test.total + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + releaseCh := make(chan struct{}) + var waiting atomic.Int32 + var done atomic.Int32 + cron := integration.New(t, integration.Options{ + PartitionTotal: total, + TriggerFn: func(*api.TriggerRequest) *api.TriggerResponse { + waiting.Add(1) + <-releaseCh + done.Add(1) + return &api.TriggerResponse{Result: api.TriggerResponseResult_SUCCESS} + }, + }) + + for i := range 100 { + require.NoError(t, cron.API().Add(cron.Context(), strconv.Itoa(i), &api.Job{ + DueTime: ptr.Of("0s"), + })) + } + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, int32(100), waiting.Load()) + }, 5*time.Second, 10*time.Millisecond) + close(releaseCh) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, int32(100), done.Load()) + }, 5*time.Second, 10*time.Millisecond) + }) + } +} diff --git a/tests/suite/partition_test.go b/tests/suite/partition_test.go new file mode 100644 index 0000000..2e5368b --- /dev/null +++ b/tests/suite/partition_test.go @@ -0,0 +1,42 @@ +/* +Copyright (c) 2024 Diagrid Inc. +Licensed under the MIT License. +*/ + +package suite + +import ( + "context" + "strconv" + "testing" + "time" + + "github.com/dapr/kit/ptr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + clientv3 "go.etcd.io/etcd/client/v3" + + "github.com/diagridio/go-etcd-cron/api" + "github.com/diagridio/go-etcd-cron/tests/framework/cron/integration" +) + +func Test_patition(t *testing.T) { + t.Parallel() + + cron := integration.NewBase(t, 100) + + for i := range 100 { + job := &api.Job{ + DueTime: ptr.Of(time.Now().Add(time.Second).Format(time.RFC3339)), + } + require.NoError(t, cron.AllCrons()[i].Add(cron.Context(), "test-"+strconv.Itoa(i), job)) + } + + assert.Eventually(t, func() bool { + return cron.Triggered() == 100 + }, 5*time.Second, 1*time.Second) + + resp, err := cron.Client().Get(context.Background(), "abc/jobs", clientv3.WithPrefix()) + require.NoError(t, err) + assert.Empty(t, resp.Kvs) +} diff --git a/tests/suite/payload_test.go b/tests/suite/payload_test.go new file mode 100644 index 0000000..feb1d1a --- /dev/null +++ b/tests/suite/payload_test.go @@ -0,0 +1,54 @@ +/* +Copyright (c) 2024 Diagrid Inc. +Licensed under the MIT License. +*/ + +package suite + +import ( + "testing" + "time" + + "github.com/dapr/kit/ptr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/wrapperspb" + + "github.com/diagridio/go-etcd-cron/api" + "github.com/diagridio/go-etcd-cron/tests/framework/cron/integration" +) + +func Test_payload(t *testing.T) { + t.Parallel() + + gotCh := make(chan *api.TriggerRequest, 1) + cron := integration.New(t, integration.Options{ + PartitionTotal: 1, + GotCh: gotCh, + }) + + payload, err := anypb.New(wrapperspb.String("hello")) + require.NoError(t, err) + meta, err := anypb.New(wrapperspb.String("world")) + require.NoError(t, err) + job := &api.Job{ + DueTime: ptr.Of(time.Now().Format(time.RFC3339)), + Payload: payload, + Metadata: meta, + } + require.NoError(t, cron.API().Add(cron.Context(), "yoyo", job)) + + select { + case got := <-gotCh: + assert.Equal(t, "yoyo", got.GetName()) + var gotPayload wrapperspb.StringValue + require.NoError(t, got.GetPayload().UnmarshalTo(&gotPayload)) + assert.Equal(t, "hello", gotPayload.GetValue()) + var gotMeta wrapperspb.StringValue + require.NoError(t, got.GetMetadata().UnmarshalTo(&gotMeta)) + assert.Equal(t, "world", gotMeta.GetValue()) + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for trigger") + } +} diff --git a/tests/suite/remove_test.go b/tests/suite/remove_test.go new file mode 100644 index 0000000..65b27fd --- /dev/null +++ b/tests/suite/remove_test.go @@ -0,0 +1,34 @@ +/* +Copyright (c) 2024 Diagrid Inc. +Licensed under the MIT License. +*/ + +package suite + +import ( + "testing" + "time" + + "github.com/dapr/kit/ptr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/diagridio/go-etcd-cron/api" + "github.com/diagridio/go-etcd-cron/tests/framework/cron/integration" +) + +func Test_remove(t *testing.T) { + t.Parallel() + + cron := integration.NewBase(t, 1) + + job := &api.Job{ + DueTime: ptr.Of(time.Now().Add(time.Second * 2).Format(time.RFC3339)), + } + require.NoError(t, cron.API().Add(cron.Context(), "def", job)) + require.NoError(t, cron.API().Delete(cron.Context(), "def")) + + <-time.After(3 * time.Second) + + assert.Equal(t, 0, cron.Triggered()) +} diff --git a/tests/suite/repeat_test.go b/tests/suite/repeat_test.go new file mode 100644 index 0000000..0e60230 --- /dev/null +++ b/tests/suite/repeat_test.go @@ -0,0 +1,40 @@ +/* +Copyright (c) 2024 Diagrid Inc. +Licensed under the MIT License. +*/ + +package suite + +import ( + "context" + "testing" + "time" + + "github.com/dapr/kit/ptr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/diagridio/go-etcd-cron/api" + "github.com/diagridio/go-etcd-cron/tests/framework/cron/integration" +) + +func Test_repeat(t *testing.T) { + t.Parallel() + + cron := integration.NewBase(t, 1) + + job := &api.Job{ + Schedule: ptr.Of("@every 10ms"), + Repeats: ptr.Of(uint32(3)), + } + + require.NoError(t, cron.API().Add(cron.Context(), "def", job)) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, 3, cron.Triggered()) + }, 5*time.Second, 1*time.Second) + + resp, err := cron.Client().Get(context.Background(), "abc/jobs/def") + require.NoError(t, err) + assert.Empty(t, resp.Kvs) +} diff --git a/tests/suite/retry_test.go b/tests/suite/retry_test.go new file mode 100644 index 0000000..542a0a2 --- /dev/null +++ b/tests/suite/retry_test.go @@ -0,0 +1,53 @@ +/* +Copyright (c) 2024 Diagrid Inc. +Licensed under the MIT License. +*/ + +package suite + +import ( + "sync" + "testing" + "time" + + "github.com/dapr/kit/ptr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/diagridio/go-etcd-cron/api" + "github.com/diagridio/go-etcd-cron/tests/framework/cron/integration" +) + +func Test_retry(t *testing.T) { + t.Parallel() + + ok := api.TriggerResponseResult_FAILED + var lock sync.Mutex + cron := integration.New(t, integration.Options{ + PartitionTotal: 1, + TriggerFn: func(*api.TriggerRequest) *api.TriggerResponse { + lock.Lock() + defer lock.Unlock() + return &api.TriggerResponse{Result: ok} + }, + }) + + job := &api.Job{ + DueTime: ptr.Of(time.Now().Format(time.RFC3339)), + } + require.NoError(t, cron.API().Add(cron.Context(), "yoyo", job)) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Greater(c, cron.Triggered(), 1) + }, 5*time.Second, 10*time.Millisecond) + lock.Lock() + triggered := cron.Triggered() + triggered++ + ok = api.TriggerResponseResult_SUCCESS + lock.Unlock() + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, triggered, cron.Triggered()) + }, time.Second*10, time.Millisecond*10) + <-time.After(3 * time.Second) + assert.Equal(t, triggered, cron.Triggered()) +} diff --git a/tests/suite/schedule_test.go b/tests/suite/schedule_test.go new file mode 100644 index 0000000..20497a0 --- /dev/null +++ b/tests/suite/schedule_test.go @@ -0,0 +1,191 @@ +/* +Copyright (c) 2024 Diagrid Inc. +Licensed under the MIT License. +*/ + +package suite + +import ( + "context" + "testing" + "time" + + "github.com/dapr/kit/ptr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + clientv3 "go.etcd.io/etcd/client/v3" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/diagridio/go-etcd-cron/api" + "github.com/diagridio/go-etcd-cron/internal/api/stored" + "github.com/diagridio/go-etcd-cron/tests/framework/cron/integration" + "github.com/diagridio/go-etcd-cron/tests/framework/etcd" +) + +func Test_schedule(t *testing.T) { + t.Parallel() + + t.Run("if no counter, job should not be deleted and no counter created", func(t *testing.T) { + t.Parallel() + + client := etcd.EmbeddedBareClient(t) + + now := time.Now().UTC() + jobBytes1, err := proto.Marshal(&stored.Job{ + Begin: &stored.Job_DueTime{DueTime: timestamppb.New(now.Add(time.Hour))}, + PartitionId: 123, + Job: &api.Job{DueTime: ptr.Of(now.Add(time.Hour).Format(time.RFC3339))}, + }) + require.NoError(t, err) + _, err = client.Put(context.Background(), "abc/jobs/1", string(jobBytes1)) + require.NoError(t, err) + + jobBytes2, err := proto.Marshal(&stored.Job{ + Begin: &stored.Job_DueTime{DueTime: timestamppb.New(now)}, + PartitionId: 123, + Job: &api.Job{DueTime: ptr.Of(now.Format(time.RFC3339))}, + }) + require.NoError(t, err) + _, err = client.Put(context.Background(), "abc/jobs/2", string(jobBytes2)) + require.NoError(t, err) + + resp, err := client.Get(context.Background(), "abc/jobs", clientv3.WithPrefix()) + require.NoError(t, err) + assert.Len(t, resp.Kvs, 2) + + cron := integration.New(t, integration.Options{ + PartitionTotal: 1, + Client: client, + }) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, 1, cron.Triggered()) + }, 5*time.Second, 10*time.Millisecond) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + resp, err = client.Get(context.Background(), "abc/jobs", clientv3.WithPrefix()) + require.NoError(t, err) + assert.Len(c, resp.Kvs, 1) + }, 5*time.Second, 10*time.Millisecond) + + cron.Close() + + resp, err = client.Get(context.Background(), "abc/jobs/1") + require.NoError(t, err) + require.Len(t, resp.Kvs, 1) + assert.Equal(t, string(jobBytes1), string(resp.Kvs[0].Value)) + + resp, err = client.Get(context.Background(), "abc/counters", clientv3.WithPrefix()) + require.NoError(t, err) + require.Empty(t, resp.Kvs) + + assert.Equal(t, 1, cron.Triggered()) + }) + + t.Run("if schedule is not done, job and counter should not be deleted", func(t *testing.T) { + t.Parallel() + + client := etcd.EmbeddedBareClient(t) + + future := time.Now().UTC().Add(time.Hour) + jobBytes, err := proto.Marshal(&stored.Job{ + Begin: &stored.Job_DueTime{ + DueTime: timestamppb.New(future), + }, + PartitionId: 123, + Job: &api.Job{ + DueTime: ptr.Of(future.Format(time.RFC3339)), + }, + }) + require.NoError(t, err) + counterBytes, err := proto.Marshal(&stored.Counter{ + LastTrigger: nil, + Count: 0, + JobPartitionId: 123, + }) + require.NoError(t, err) + + _, err = client.Put(context.Background(), "abc/jobs/1", string(jobBytes)) + require.NoError(t, err) + _, err = client.Put(context.Background(), "abc/counters/1", string(counterBytes)) + require.NoError(t, err) + + now := time.Now().UTC() + jobBytes2, err := proto.Marshal(&stored.Job{ + Begin: &stored.Job_DueTime{DueTime: timestamppb.New(now)}, + Job: &api.Job{DueTime: ptr.Of(now.Format(time.RFC3339))}, + }) + require.NoError(t, err) + _, err = client.Put(context.Background(), "abc/jobs/2", string(jobBytes2)) + require.NoError(t, err) + + cron := integration.New(t, integration.Options{ + PartitionTotal: 1, + Client: client, + }) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, 1, cron.Triggered()) + }, 5*time.Second, 10*time.Millisecond) + + resp, err := client.Get(context.Background(), "abc/jobs/1") + require.NoError(t, err) + require.Len(t, resp.Kvs, 1) + assert.Equal(t, string(jobBytes), string(resp.Kvs[0].Value)) + + resp, err = client.Get(context.Background(), "abc/counters/1") + require.NoError(t, err) + require.Len(t, resp.Kvs, 1) + assert.Equal(t, string(counterBytes), string(resp.Kvs[0].Value)) + + resp, err = client.Get(context.Background(), "abc/jobs", clientv3.WithPrefix()) + require.NoError(t, err) + assert.Len(t, resp.Kvs, 1) + }) + + t.Run("if schedule is done, expect job and counter to be deleted", func(t *testing.T) { + t.Parallel() + + client := etcd.EmbeddedBareClient(t) + + now := time.Now().UTC() + jobBytes, err := proto.Marshal(&stored.Job{ + Begin: &stored.Job_DueTime{ + DueTime: timestamppb.New(now), + }, + PartitionId: 123, + Job: &api.Job{ + DueTime: ptr.Of(now.Format(time.RFC3339)), + }, + }) + require.NoError(t, err) + counterBytes, err := proto.Marshal(&stored.Counter{ + LastTrigger: timestamppb.New(now), + Count: 1, + JobPartitionId: 123, + }) + require.NoError(t, err) + + _, err = client.Put(context.Background(), "abc/jobs/1", string(jobBytes)) + require.NoError(t, err) + _, err = client.Put(context.Background(), "abc/counters/1", string(counterBytes)) + require.NoError(t, err) + + cron := integration.New(t, integration.Options{ + PartitionTotal: 1, + Client: client, + }) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + resp, err := client.Get(context.Background(), "abc/jobs/1") + require.NoError(t, err) + assert.Empty(c, resp.Kvs) + resp, err = client.Get(context.Background(), "abc/counters/1") + require.NoError(t, err) + assert.Empty(c, resp.Kvs) + }, 5*time.Second, 10*time.Millisecond) + + assert.Equal(t, 0, cron.Triggered()) + }) +} diff --git a/tests/suite/undeliverable_test.go b/tests/suite/undeliverable_test.go new file mode 100644 index 0000000..70b19e1 --- /dev/null +++ b/tests/suite/undeliverable_test.go @@ -0,0 +1,637 @@ +/* +Copyright (c) 2024 Diagrid Inc. +Licensed under the MIT License. +*/ + +package suite + +import ( + "context" + "strconv" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/dapr/kit/ptr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/diagridio/go-etcd-cron/api" + "github.com/diagridio/go-etcd-cron/internal/api/stored" + "github.com/diagridio/go-etcd-cron/tests/framework/cron/integration" + "github.com/diagridio/go-etcd-cron/tests/framework/etcd" +) + +func Test_undeliverable(t *testing.T) { + t.Parallel() + + t.Run("single: jobs which are marked as undeliverable, should be triggered when their prefix is registered", func(t *testing.T) { + t.Parallel() + + var got []string + var lock sync.Mutex + ret := api.TriggerResponseResult_UNDELIVERABLE + cron := integration.New(t, integration.Options{ + PartitionTotal: 1, + TriggerFn: func(req *api.TriggerRequest) *api.TriggerResponse { + lock.Lock() + defer lock.Unlock() + got = append(got, req.GetName()) + return &api.TriggerResponse{Result: ret} + }, + }) + + job := &api.Job{ + Schedule: ptr.Of("@every 1h"), + DueTime: ptr.Of(time.Now().Format(time.RFC3339)), + } + names := make([]string, 100) + for i := range 100 { + names[i] = "abc" + strconv.Itoa(i) + require.NoError(t, cron.API().Add(cron.Context(), names[i], job)) + } + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, 100, cron.Triggered()) + }, time.Second*10, time.Millisecond*10) + assert.ElementsMatch(t, names, got) + + lock.Lock() + ret = api.TriggerResponseResult_SUCCESS + lock.Unlock() + + cancel, err := cron.API().DeliverablePrefixes(cron.Context(), "abc") + require.NoError(t, err) + t.Cleanup(cancel) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, 200, cron.Triggered()) + }, time.Second*10, time.Millisecond*10) + //nolint:makezero + assert.ElementsMatch(t, append(names, names...), got) + }) + + t.Run("multiple: jobs which are marked as undeliverable, should be triggered when their prefix is registered", func(t *testing.T) { + t.Parallel() + + var got []string + var lock sync.Mutex + ret := api.TriggerResponseResult_UNDELIVERABLE + cron := integration.New(t, integration.Options{ + PartitionTotal: 4, + TriggerFn: func(req *api.TriggerRequest) *api.TriggerResponse { + lock.Lock() + defer lock.Unlock() + got = append(got, req.GetName()) + return &api.TriggerResponse{Result: ret} + }, + }) + + job := &api.Job{ + Schedule: ptr.Of("@every 1h"), + DueTime: ptr.Of(time.Now().Format(time.RFC3339)), + } + names := make([]string, 100) + for i := range 100 { + names[i] = "abc" + strconv.Itoa(i) + require.NoError(t, cron.API().Add(cron.Context(), names[i], job)) + } + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, 100, cron.Triggered()) + }, time.Second*10, time.Millisecond*10) + assert.ElementsMatch(t, names, got) + + lock.Lock() + ret = api.TriggerResponseResult_SUCCESS + lock.Unlock() + + for _, api := range cron.AllCrons() { + cancel, err := api.DeliverablePrefixes(cron.Context(), "abc") + require.NoError(t, err) + t.Cleanup(cancel) + } + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, 200, cron.Triggered()) + }, time.Second*10, time.Millisecond*10) + //nolint:makezero + assert.ElementsMatch(t, append(names, names...), got) + }) + + t.Run("single: some jobs should be re-enqueued based on the prefix", func(t *testing.T) { + t.Parallel() + + var got []string + var lock sync.Mutex + ret := api.TriggerResponseResult_UNDELIVERABLE + cron := integration.New(t, integration.Options{ + PartitionTotal: 1, + TriggerFn: func(req *api.TriggerRequest) *api.TriggerResponse { + lock.Lock() + defer lock.Unlock() + got = append(got, req.GetName()) + return &api.TriggerResponse{Result: ret} + }, + }) + + job := &api.Job{ + Schedule: ptr.Of("@every 1h"), + DueTime: ptr.Of(time.Now().Format(time.RFC3339)), + } + require.NoError(t, cron.API().Add(cron.Context(), "abc1", job)) + require.NoError(t, cron.API().Add(cron.Context(), "abc2", job)) + require.NoError(t, cron.API().Add(cron.Context(), "def3", job)) + require.NoError(t, cron.API().Add(cron.Context(), "def4", job)) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, 4, cron.Triggered()) + }, time.Second*10, time.Millisecond*10) + assert.ElementsMatch(t, []string{"abc1", "abc2", "def3", "def4"}, got) + + lock.Lock() + ret = api.TriggerResponseResult_SUCCESS + lock.Unlock() + + cancel, err := cron.API().DeliverablePrefixes(cron.Context(), "abc") + require.NoError(t, err) + t.Cleanup(cancel) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, 6, cron.Triggered()) + }, time.Second*10, time.Millisecond*10) + assert.ElementsMatch(t, []string{ + "abc1", "abc2", "def3", "def4", + "abc1", "abc2", + }, got) + }) + + t.Run("multiple: some jobs should be re-enqueued based on the prefix", func(t *testing.T) { + t.Parallel() + + var got []string + var lock sync.Mutex + ret := api.TriggerResponseResult_UNDELIVERABLE + cron := integration.New(t, integration.Options{ + PartitionTotal: 4, + TriggerFn: func(req *api.TriggerRequest) *api.TriggerResponse { + lock.Lock() + defer lock.Unlock() + got = append(got, req.GetName()) + return &api.TriggerResponse{Result: ret} + }, + }) + + job := &api.Job{ + Schedule: ptr.Of("@every 1h"), + DueTime: ptr.Of(time.Now().Format(time.RFC3339)), + } + require.NoError(t, cron.API().Add(cron.Context(), "abc1", job)) + require.NoError(t, cron.API().Add(cron.Context(), "abc2", job)) + require.NoError(t, cron.API().Add(cron.Context(), "def3", job)) + require.NoError(t, cron.API().Add(cron.Context(), "def4", job)) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, 4, cron.Triggered()) + }, time.Second*10, time.Millisecond*10) + assert.ElementsMatch(t, []string{"abc1", "abc2", "def3", "def4"}, got) + + lock.Lock() + ret = api.TriggerResponseResult_SUCCESS + lock.Unlock() + + for _, api := range cron.AllCrons() { + cancel, err := api.DeliverablePrefixes(cron.Context(), "abc") + require.NoError(t, err) + t.Cleanup(cancel) + } + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, 6, cron.Triggered()) + }, time.Second*10, time.Millisecond*10) + assert.ElementsMatch(t, []string{ + "abc1", "abc2", "def3", "def4", + "abc1", "abc2", + }, got) + }) + + t.Run("should redeliver immediately if prefix added during trigger", func(t *testing.T) { + t.Parallel() + + var inTrigger atomic.Uint32 + cntCh := make(chan struct{}) + var ret atomic.Value + ret.Store(api.TriggerResponseResult_UNDELIVERABLE) + cron := integration.New(t, integration.Options{ + PartitionTotal: 1, + TriggerFn: func(*api.TriggerRequest) *api.TriggerResponse { + inTrigger.Add(1) + <-cntCh + return &api.TriggerResponse{Result: ret.Load().(api.TriggerResponseResult)} + }, + }) + + require.NoError(t, cron.API().Add(cron.Context(), "abc1", &api.Job{ + DueTime: ptr.Of(time.Now().Format(time.RFC3339)), + })) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, uint32(1), inTrigger.Load()) + }, time.Second*10, time.Millisecond*10) + + resp, err := cron.API().Get(cron.Context(), "abc1") + require.NoError(t, err) + assert.NotNil(t, resp) + + cancel, err := cron.API().DeliverablePrefixes(cron.Context(), "abc") + require.NoError(t, err) + t.Cleanup(cancel) + cntCh <- struct{}{} + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.GreaterOrEqual(c, inTrigger.Load(), uint32(2)) + }, time.Second*10, time.Millisecond*10) + + ret.Store(api.TriggerResponseResult_SUCCESS) + cntCh <- struct{}{} + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + resp, err := cron.API().Get(cron.Context(), "abc1") + require.NoError(t, err) + assert.Nil(c, resp) + }, time.Second*10, time.Millisecond*10) + }) + + t.Run("ignore prefix if return SUCCESS", func(t *testing.T) { + t.Parallel() + + cron := integration.New(t, integration.Options{ + PartitionTotal: 1, + TriggerFn: func(*api.TriggerRequest) *api.TriggerResponse { + return &api.TriggerResponse{Result: api.TriggerResponseResult_SUCCESS} + }, + }) + + cancel, err := cron.API().DeliverablePrefixes(cron.Context(), "abc") + require.NoError(t, err) + t.Cleanup(cancel) + + require.NoError(t, cron.API().Add(cron.Context(), "def1", &api.Job{ + DueTime: ptr.Of(time.Now().Format(time.RFC3339)), + Repeats: ptr.Of(uint32(2)), + Schedule: ptr.Of("@every 1s"), + })) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, 2, cron.Triggered()) + }, time.Second*10, time.Millisecond*10) + + resp, err := cron.API().Get(cron.Context(), "def1") + require.NoError(t, err) + assert.Nil(t, resp) + }) + + t.Run("ignore prefix if return FAILURE", func(t *testing.T) { + t.Parallel() + + cron := integration.New(t, integration.Options{ + PartitionTotal: 1, + TriggerFn: func(*api.TriggerRequest) *api.TriggerResponse { + return &api.TriggerResponse{Result: api.TriggerResponseResult_FAILED} + }, + }) + + cancel, err := cron.API().DeliverablePrefixes(cron.Context(), "abc") + require.NoError(t, err) + t.Cleanup(cancel) + + require.NoError(t, cron.API().Add(cron.Context(), "def1", &api.Job{ + DueTime: ptr.Of(time.Now().Format(time.RFC3339)), + Repeats: ptr.Of(uint32(2)), + Schedule: ptr.Of("@every 1s"), + FailurePolicy: &api.FailurePolicy{Policy: new(api.FailurePolicy_Drop)}, + })) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, 2, cron.Triggered()) + resp, err := cron.API().Get(cron.Context(), "def1") + require.NoError(t, err) + assert.Nil(c, resp) + }, time.Second*10, time.Millisecond*10) + }) + + t.Run("single: load jobs from db which are undeliverable should be re-tried when deliverable", func(t *testing.T) { + t.Parallel() + + client := etcd.EmbeddedBareClient(t) + + jobBytes, err := proto.Marshal(&stored.Job{ + Begin: &stored.Job_DueTime{DueTime: timestamppb.New(time.Now())}, + PartitionId: 123, + Job: &api.Job{DueTime: ptr.Of(time.Now().Format(time.RFC3339))}, + }) + require.NoError(t, err) + _, err = client.Put(context.Background(), "abc/jobs/helloworld", string(jobBytes)) + require.NoError(t, err) + + var inTrigger atomic.Uint32 + cntCh := make(chan struct{}) + var ret atomic.Value + ret.Store(api.TriggerResponseResult_UNDELIVERABLE) + cron := integration.New(t, integration.Options{ + PartitionTotal: 1, + Client: client, + TriggerFn: func(*api.TriggerRequest) *api.TriggerResponse { + inTrigger.Add(1) + <-cntCh + return &api.TriggerResponse{Result: ret.Load().(api.TriggerResponseResult)} + }, + }) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, uint32(1), inTrigger.Load()) + }, time.Second*10, time.Millisecond*10) + cntCh <- struct{}{} + <-time.After(time.Second) + assert.Equal(t, uint32(1), inTrigger.Load()) + + cancel, err := cron.API().DeliverablePrefixes(cron.Context(), "hello") + require.NoError(t, err) + t.Cleanup(cancel) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, uint32(2), inTrigger.Load()) + }, time.Second*10, time.Millisecond*10) + ret.Store(api.TriggerResponseResult_SUCCESS) + cntCh <- struct{}{} + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, 2, cron.Triggered()) + }, time.Second*10, time.Millisecond*10) + }) + + t.Run("multiple: load jobs from db which are undeliverable should be re-tried when deliverable", func(t *testing.T) { + t.Parallel() + + client := etcd.EmbeddedBareClient(t) + + jobBytes, err := proto.Marshal(&stored.Job{ + Begin: &stored.Job_DueTime{DueTime: timestamppb.New(time.Now())}, + PartitionId: 123, + Job: &api.Job{DueTime: ptr.Of(time.Now().Format(time.RFC3339))}, + }) + require.NoError(t, err) + _, err = client.Put(context.Background(), "abc/jobs/helloworld", string(jobBytes)) + require.NoError(t, err) + + var inTrigger atomic.Uint32 + cntCh := make(chan struct{}) + var ret atomic.Value + ret.Store(api.TriggerResponseResult_UNDELIVERABLE) + cron := integration.New(t, integration.Options{ + PartitionTotal: 5, + Client: client, + TriggerFn: func(*api.TriggerRequest) *api.TriggerResponse { + inTrigger.Add(1) + <-cntCh + return &api.TriggerResponse{Result: ret.Load().(api.TriggerResponseResult)} + }, + }) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, uint32(1), inTrigger.Load()) + }, time.Second*10, time.Millisecond*10) + cntCh <- struct{}{} + <-time.After(time.Second) + assert.Equal(t, uint32(1), inTrigger.Load()) + + for _, api := range cron.AllCrons() { + cancel, err := api.DeliverablePrefixes(cron.Context(), "hello") + require.NoError(t, err) + t.Cleanup(cancel) + } + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, uint32(2), inTrigger.Load()) + }, time.Second*10, time.Millisecond*10) + ret.Store(api.TriggerResponseResult_SUCCESS) + + cntCh <- struct{}{} + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, 2, cron.Triggered()) + }, time.Second*10, time.Millisecond*10) + }) + + t.Run("keep delivering undeliverable until cancel called", func(t *testing.T) { + t.Parallel() + + var inTrigger atomic.Uint32 + cntCh := make(chan struct{}) + var ret atomic.Value + ret.Store(api.TriggerResponseResult_UNDELIVERABLE) + cron := integration.New(t, integration.Options{ + PartitionTotal: 1, + TriggerFn: func(*api.TriggerRequest) *api.TriggerResponse { + inTrigger.Add(1) + <-cntCh + return &api.TriggerResponse{Result: ret.Load().(api.TriggerResponseResult)} + }, + }) + + cancel, err := cron.API().DeliverablePrefixes(cron.Context(), "abc") + require.NoError(t, err) + + require.NoError(t, cron.API().Add(cron.Context(), "abc1", &api.Job{ + Schedule: ptr.Of("@every 1h"), + DueTime: ptr.Of(time.Now().Format(time.RFC3339)), + })) + + for i := range uint32(10) { + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, i+1, inTrigger.Load()) + }, time.Second*10, time.Millisecond*10) + cntCh <- struct{}{} + } + trigger := inTrigger.Load() + cancel() + <-time.After(time.Second) + assert.Equal(t, trigger, inTrigger.Load()) + }) + + t.Run("Deleting a staged job should not be triggered once it has been marked as deliverable", func(t *testing.T) { + t.Parallel() + + var triggered []string + var lock sync.Mutex + var i int + cron := integration.New(t, integration.Options{ + PartitionTotal: 1, + TriggerFn: func(req *api.TriggerRequest) *api.TriggerResponse { + lock.Lock() + defer lock.Unlock() + i++ + triggered = append(triggered, req.GetName()) + if len(triggered) <= 2 { + return &api.TriggerResponse{Result: api.TriggerResponseResult_UNDELIVERABLE} + } + return &api.TriggerResponse{Result: api.TriggerResponseResult_SUCCESS} + }, + }) + + require.NoError(t, cron.API().Add(cron.Context(), "abc1", &api.Job{ + Schedule: ptr.Of("@every 1s"), + DueTime: ptr.Of(time.Now().Format(time.RFC3339)), + })) + require.NoError(t, cron.API().Add(cron.Context(), "xyz1", &api.Job{ + Schedule: ptr.Of("@every 1s"), + DueTime: ptr.Of(time.Now().Format(time.RFC3339)), + })) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + lock.Lock() + defer lock.Unlock() + assert.Equal(c, []string{"abc1", "xyz1"}, triggered) + }, time.Second*10, time.Millisecond*10) + + require.NoError(t, cron.API().Delete(cron.Context(), "abc1")) + + cancel, err := cron.API().DeliverablePrefixes(cron.Context(), "abc") + require.NoError(t, err) + t.Cleanup(cancel) + time.Sleep(time.Second * 2) + assert.Equal(t, []string{"abc1", "xyz1"}, triggered) + }) + + t.Run("Deleting prefixes staged jobs should not be triggered once it has been marked as deliverable", func(t *testing.T) { + t.Parallel() + + var triggered []string + var lock sync.Mutex + cron := integration.New(t, integration.Options{ + PartitionTotal: 1, + TriggerFn: func(req *api.TriggerRequest) *api.TriggerResponse { + lock.Lock() + defer lock.Unlock() + triggered = append(triggered, req.GetName()) + if len(triggered) < 4 { + return &api.TriggerResponse{Result: api.TriggerResponseResult_UNDELIVERABLE} + } + return &api.TriggerResponse{Result: api.TriggerResponseResult_SUCCESS} + }, + }) + + require.NoError(t, cron.API().Add(cron.Context(), "abc1", &api.Job{ + Schedule: ptr.Of("@every 1s"), + DueTime: ptr.Of(time.Now().Format(time.RFC3339)), + })) + require.NoError(t, cron.API().Add(cron.Context(), "def1", &api.Job{ + Schedule: ptr.Of("@every 1s"), + DueTime: ptr.Of(time.Now().Format(time.RFC3339)), + })) + require.NoError(t, cron.API().Add(cron.Context(), "xyz1", &api.Job{ + Schedule: ptr.Of("@every 1s"), + DueTime: ptr.Of(time.Now().Format(time.RFC3339)), + })) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + lock.Lock() + defer lock.Unlock() + assert.Equal(c, []string{"abc1", "def1", "xyz1"}, triggered) + }, time.Second*10, time.Millisecond*10) + + require.NoError(t, cron.API().DeletePrefixes(cron.Context(), "abc", "def")) + + cancel, err := cron.API().DeliverablePrefixes(cron.Context(), "abc", "def") + require.NoError(t, err) + t.Cleanup(cancel) + time.Sleep(time.Second * 2) + assert.Equal(t, []string{"abc1", "def1", "xyz1"}, triggered) + }) + + t.Run("Re-scheduling the job should not trigger the old staged job when prefix is added", func(t *testing.T) { + t.Parallel() + + var ret atomic.Value + var triggered atomic.Uint32 + ret.Store(api.TriggerResponseResult_UNDELIVERABLE) + + cron := integration.New(t, integration.Options{ + PartitionTotal: 1, + TriggerFn: func(*api.TriggerRequest) *api.TriggerResponse { + triggered.Add(1) + return &api.TriggerResponse{Result: ret.Load().(api.TriggerResponseResult)} + }, + }) + + dueTime := ptr.Of(time.Now().Format(time.RFC3339)) + require.NoError(t, cron.API().Add(cron.Context(), "abc1", &api.Job{DueTime: dueTime})) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, uint32(1), triggered.Load()) + }, time.Second*10, time.Millisecond*10) + + ret.Store(api.TriggerResponseResult_SUCCESS) + require.NoError(t, cron.API().Add(cron.Context(), "abc1", &api.Job{DueTime: dueTime})) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, uint32(2), triggered.Load()) + }, time.Second*10, time.Millisecond*10) + + cancel, err := cron.API().DeliverablePrefixes(cron.Context(), "abc") + require.NoError(t, err) + t.Cleanup(cancel) + time.Sleep(time.Second * 2) + assert.Equal(t, uint32(2), triggered.Load()) + }) + + t.Run("Re-scheduling the job after multiple puts should not trigger the old staged job when prefix is added", func(t *testing.T) { + t.Parallel() + + var ret atomic.Value + var triggered atomic.Uint32 + ret.Store(api.TriggerResponseResult_UNDELIVERABLE) + + cron := integration.New(t, integration.Options{ + PartitionTotal: 1, + TriggerFn: func(*api.TriggerRequest) *api.TriggerResponse { + triggered.Add(1) + return &api.TriggerResponse{Result: ret.Load().(api.TriggerResponseResult)} + }, + }) + + dueTime := ptr.Of(time.Now().Format(time.RFC3339)) + require.NoError(t, cron.API().Add(cron.Context(), "abc1", &api.Job{DueTime: dueTime})) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, uint32(1), triggered.Load()) + }, time.Second*10, time.Millisecond*10) + + ret.Store(api.TriggerResponseResult_SUCCESS) + cancel, err := cron.API().DeliverablePrefixes(cron.Context(), "abc") + require.NoError(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, uint32(2), triggered.Load()) + }, time.Second*10, time.Millisecond*10) + resp, err := cron.API().Get(cron.Context(), "abc1") + require.NoError(t, err) + assert.Nil(t, resp) + cancel() + + ret.Store(api.TriggerResponseResult_UNDELIVERABLE) + require.NoError(t, cron.API().Add(cron.Context(), "abc1", &api.Job{DueTime: dueTime})) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, uint32(3), triggered.Load()) + }, time.Second*10, time.Millisecond*10) + + ret.Store(api.TriggerResponseResult_SUCCESS) + require.NoError(t, cron.API().Add(cron.Context(), "abc1", &api.Job{DueTime: dueTime})) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, uint32(4), triggered.Load()) + }, time.Second*10, time.Millisecond*10) + resp, err = cron.API().Get(cron.Context(), "abc1") + require.NoError(t, err) + assert.Nil(t, resp) + + cancel, err = cron.API().DeliverablePrefixes(cron.Context(), "abc") + require.NoError(t, err) + t.Cleanup(cancel) + time.Sleep(time.Second * 2) + assert.Equal(t, uint32(4), triggered.Load()) + }) +} diff --git a/tests/suite/upsert_test.go b/tests/suite/upsert_test.go new file mode 100644 index 0000000..390f9a6 --- /dev/null +++ b/tests/suite/upsert_test.go @@ -0,0 +1,42 @@ +/* +Copyright (c) 2024 Diagrid Inc. +Licensed under the MIT License. +*/ + +package suite + +import ( + "context" + "testing" + "time" + + "github.com/dapr/kit/ptr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/diagridio/go-etcd-cron/api" + "github.com/diagridio/go-etcd-cron/tests/framework/cron/integration" +) + +func Test_upsert(t *testing.T) { + t.Parallel() + + cron := integration.NewBase(t, 1) + + job := &api.Job{ + DueTime: ptr.Of(time.Now().Add(time.Hour).Format(time.RFC3339)), + } + require.NoError(t, cron.API().Add(cron.Context(), "def", job)) + job = &api.Job{ + DueTime: ptr.Of(time.Now().Add(time.Second).Format(time.RFC3339)), + } + require.NoError(t, cron.API().Add(cron.Context(), "def", job)) + + assert.Eventually(t, func() bool { + return cron.Triggered() == 1 + }, 5*time.Second, 1*time.Second) + + resp, err := cron.Client().Get(context.Background(), "abc/jobs/def") + require.NoError(t, err) + assert.Empty(t, resp.Kvs) +} diff --git a/tests/suite/zeroduetime_test.go b/tests/suite/zeroduetime_test.go new file mode 100644 index 0000000..6938d9d --- /dev/null +++ b/tests/suite/zeroduetime_test.go @@ -0,0 +1,46 @@ +/* +Copyright (c) 2024 Diagrid Inc. +Licensed under the MIT License. +*/ + +package suite + +import ( + "testing" + "time" + + "github.com/dapr/kit/ptr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/diagridio/go-etcd-cron/api" + "github.com/diagridio/go-etcd-cron/tests/framework/cron/integration" +) + +func Test_zeroDueTime(t *testing.T) { + t.Parallel() + + cron := integration.NewBase(t, 1) + + require.NoError(t, cron.API().Add(cron.Context(), "yoyo", &api.Job{ + Schedule: ptr.Of("@every 1h"), + DueTime: ptr.Of("0s"), + })) + assert.Eventually(t, func() bool { + return cron.Triggered() == 1 + }, 3*time.Second, time.Millisecond*10) + + require.NoError(t, cron.API().Add(cron.Context(), "yoyo2", &api.Job{ + Schedule: ptr.Of("@every 1h"), + DueTime: ptr.Of("1s"), + })) + assert.Eventually(t, func() bool { + return cron.Triggered() == 2 + }, 3*time.Second, time.Millisecond*10) + + require.NoError(t, cron.API().Add(cron.Context(), "yoyo3", &api.Job{ + Schedule: ptr.Of("@every 1h"), + })) + <-time.After(2 * time.Second) + assert.Equal(t, 2, cron.Triggered()) +}