diff --git a/brokers/in-memory/broker.go b/brokers/in-memory/broker.go index 2467d07..312760f 100644 --- a/brokers/in-memory/broker.go +++ b/brokers/in-memory/broker.go @@ -27,13 +27,11 @@ func (r *Broker) Consume(ctx context.Context, work chan []byte, queue string) { ch, ok := r.queues[queue] r.mu.RUnlock() - // If the queue isn't found, make a queue. if !ok { ch = make(chan []byte, 100) r.mu.Lock() r.queues[queue] = ch r.mu.Unlock() - } for { diff --git a/chains.go b/chains.go index b1375d6..9748ad9 100644 --- a/chains.go +++ b/chains.go @@ -46,7 +46,7 @@ func NewChain(j []Job, opts ChainOpts) (Chain, error) { // Set the on success tasks as the i+1 task, // hence forming a "chain" of tasks. for i := 0; i < len(j)-1; i++ { - j[i].OnSuccess = &j[i+1] + j[i].OnSuccess = append(j[i].OnSuccess, &j[i+1]) } return Chain{Jobs: j, Opts: opts}, nil @@ -114,10 +114,10 @@ checkJobs: // to success. Otherwise update the current job and perform all the above checks. case StatusDone: c.PrevJobs = append(c.PrevJobs, currJob.ID) - if currJob.OnSuccessID == "" { + if len(currJob.OnSuccessIDs) == 0 { c.Status = StatusDone } else { - currJob, err = s.GetJob(ctx, currJob.OnSuccessID) + currJob, err = s.GetJob(ctx, currJob.OnSuccessIDs[0]) if err != nil { return ChainMessage{}, nil } diff --git a/jobs.go b/jobs.go index a4b78eb..063add1 100644 --- a/jobs.go +++ b/jobs.go @@ -21,10 +21,13 @@ const ( // It is the responsibility of the task handler to unmarshal (if required) the payload and process it in any manner. type Job struct { // If task is successful, the OnSuccess jobs are enqueued. - OnSuccess *Job + OnSuccess []*Job Task string Payload []byte + // If task fails, the OnError jobs are enqueued. + OnError []*Job + Opts JobOpts } @@ -42,15 +45,15 @@ type JobOpts struct { // Meta contains fields related to a job. These are updated when a task is consumed. type Meta struct { - ID string - OnSuccessID string - Status string - Queue string - Schedule string - MaxRetry uint32 - Retried uint32 - PrevErr string - ProcessedAt time.Time + ID string + OnSuccessIDs []string + Status string + Queue string + Schedule string + MaxRetry uint32 + Retried uint32 + PrevErr string + ProcessedAt time.Time // PrevJobResults contains any job result set by the previous job in a chain. // This will be nil if the previous job doesn't set the results on JobCtx. @@ -151,7 +154,7 @@ func (s *Server) enqueueWithMeta(ctx context.Context, t Job, meta Meta) (string, } // Set current jobs OnSuccess as next job - t.OnSuccess = &j + t.OnSuccess = append(t.OnSuccess, &j) // Set the next job's eta according to schedule j.Opts.ETA = sch.Next(t.Opts.ETA) } diff --git a/jobs_test.go b/jobs_test.go index 04410cd..911fcf9 100644 --- a/jobs_test.go +++ b/jobs_test.go @@ -233,6 +233,44 @@ func TestDeleteJob(t *testing.T) { } +func TestJobsOnError(t *testing.T) { + var ( + srv = newServer(t, taskName, MockHandler) + ) + + hasErrored := make(chan bool, 1) + + if err := srv.RegisterTask("error", func(b []byte, jc JobCtx) error { + t.Log("error task called") + hasErrored <- true + return nil + }, TaskOpts{ + Queue: "error_task", + Concurrency: 1, + }); err != nil { + t.Fatal(err) + } + + j := makeJob(t, taskName, true) + + errJob, _ := NewJob("error", []byte{}, JobOpts{ + Queue: "error_task", + }) + + j.OnError = append(j.OnError, &errJob) + + if _, err := srv.Enqueue(context.Background(), j); err != nil { + t.Fatalf("error enqueuing job: %v", err) + } + + go srv.Start(context.Background()) + + b := <-hasErrored + if !b { + t.Fatalf("error job didn't enqueue") + } +} + func makeJob(t *testing.T, taskName string, doErr bool) Job { j, err := json.Marshal(MockPayload{ShouldErr: doErr}) if err != nil { diff --git a/server.go b/server.go index 7b27afd..9933256 100644 --- a/server.go +++ b/server.go @@ -278,6 +278,7 @@ func (s *Server) process(ctx context.Context, w chan []byte) { s.log.Error("error unmarshalling task", "error", err) break } + // Fetch the registered task handler. task, err := s.getHandler(msg.Job.Task) if err != nil { @@ -365,6 +366,20 @@ func (s *Server) execJob(ctx context.Context, msg JobMessage, task Task) error { if task.opts.FailedCB != nil { task.opts.FailedCB(taskCtx, err) } + + // If there are jobs to enqueued after failure, enqueue them. + if msg.Job.OnError != nil { + // Extract OnErrorJob into a variable to get opts. + for _, j := range msg.Job.OnError { + nj := *j + meta := DefaultMeta(nj.Opts) + + if _, err = s.enqueueWithMeta(ctx, nj, meta); err != nil { + return fmt.Errorf("error enqueuing jobs after failure: %w", err) + } + } + } + // If we hit max retries, set the task status as failed. return s.statusFailed(ctx, msg) } @@ -376,19 +391,22 @@ func (s *Server) execJob(ctx context.Context, msg JobMessage, task Task) error { // If the task contains OnSuccess task (part of a chain), enqueue them. if msg.Job.OnSuccess != nil { - // Extract OnSuccessJob into a variable to get opts. - j := msg.Job.OnSuccess - nj := *j - meta := DefaultMeta(nj.Opts) - meta.PrevJobResult, err = s.GetResult(ctx, msg.ID) - if err != nil { - return fmt.Errorf("could not get result for id (%s) : %w", msg.ID, err) - } + for _, j := range msg.Job.OnSuccess { + // Extract OnSuccessJob into a variable to get opts. + nj := *j + meta := DefaultMeta(nj.Opts) + meta.PrevJobResult, err = s.GetResult(ctx, msg.ID) + if err != nil { + return err + } + + // Set the ID of the next job in the chain + onSuccessID, err := s.enqueueWithMeta(ctx, nj, meta) + if err != nil { + return err + } - // Set the ID of the next job in the chain - msg.OnSuccessID, err = s.enqueueWithMeta(ctx, nj, meta) - if err != nil { - return fmt.Errorf("could not enqueue job id (%s) : %w", msg.ID, err) + msg.OnSuccessIDs = append(msg.OnSuccessIDs, onSuccessID) } } diff --git a/server_test.go b/server_test.go index aeb9f3b..dfa9f86 100644 --- a/server_test.go +++ b/server_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "log/slog" + "os" "testing" "time" @@ -16,11 +17,13 @@ const ( ) func newServer(t *testing.T, taskName string, handler func([]byte, JobCtx) error) *Server { - lo := slog.Default().Handler() + lo := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + Level: slog.LevelError, + })) srv, err := NewServer(ServerOpts{ Broker: rb.New(), Results: rr.New(), - Logger: lo, + Logger: lo.Handler(), }) if err != nil { t.Fatal(err)