From b920c7003531dfaddaf0b3cc45d0452d90d27c77 Mon Sep 17 00:00:00 2001 From: Darren Shepherd Date: Thu, 21 Nov 2024 16:19:09 -0700 Subject: [PATCH] bug: don't create multiple runs when rerunning workflows --- apiclient/client.go | 10 +- apiclient/types/tasks.go | 26 ++ pkg/api/handlers/assistants.go | 12 +- pkg/api/handlers/tasks.go | 293 ++++++++++++++++++ pkg/api/handlers/threads.go | 30 +- pkg/api/request.go | 12 +- pkg/api/router/router.go | 13 + .../workflowexecution/workflowexecution.go | 2 +- .../handlers/workflowstep/invoke.go | 18 +- pkg/events/events.go | 81 +++-- pkg/invoke/invoker.go | 6 +- pkg/invoke/workflow.go | 15 +- pkg/storage/apis/otto.otto8.ai/v1/workflow.go | 33 +- .../apis/otto.otto8.ai/v1/workflowstep.go | 23 ++ 14 files changed, 508 insertions(+), 66 deletions(-) create mode 100644 apiclient/types/tasks.go create mode 100644 pkg/api/handlers/tasks.go diff --git a/apiclient/client.go b/apiclient/client.go index eedcb6de..9f6c1132 100644 --- a/apiclient/client.go +++ b/apiclient/client.go @@ -152,10 +152,11 @@ func toStream[T any](resp *http.Response) chan T { go func() { defer resp.Body.Close() defer close(ch) + var eventName string lines := bufio.NewScanner(resp.Body) for lines.Scan() { var obj T - if data, ok := strings.CutPrefix(lines.Text(), "data: "); ok { + if data, ok := strings.CutPrefix(lines.Text(), "data: "); ok && eventName == "" || eventName == "message" { if log.IsDebug() { log.Fields("data", data).Debugf("Received data") } @@ -169,6 +170,13 @@ func toStream[T any](resp *http.Response) chan T { ch <- obj } } + } else if event, ok := strings.CutPrefix(lines.Text(), "event: "); ok { + if log.IsDebug() { + log.Fields("event", event).Debugf("Received event") + } + eventName = event + } else if strings.TrimSpace(lines.Text()) == "" { + eventName = "" } } }() diff --git a/apiclient/types/tasks.go b/apiclient/types/tasks.go new file mode 100644 index 00000000..ed4251bf --- /dev/null +++ b/apiclient/types/tasks.go @@ -0,0 +1,26 @@ +package types + +type Task struct { + Metadata + TaskManifest +} + +type TaskList List[Task] + +type TaskManifest struct { + Name string `json:"name"` + Description string `json:"description"` + Steps []TaskStep `json:"steps"` +} + +type TaskStep struct { + ID string `json:"id,omitempty"` + If *TaskIf `json:"if,omitempty"` + Step string `json:"step,omitempty"` +} + +type TaskIf struct { + Condition string `json:"condition,omitempty"` + Steps []TaskStep `json:"steps,omitempty"` + Else []TaskStep `json:"else,omitempty"` +} diff --git a/pkg/api/handlers/assistants.go b/pkg/api/handlers/assistants.go index 5459193e..4170ec74 100644 --- a/pkg/api/handlers/assistants.go +++ b/pkg/api/handlers/assistants.go @@ -34,9 +34,9 @@ func NewAssistantHandler(invoker *invoke.Invoker, events *events.Emitter, gptScr } } -func getAgent(req api.Context, id string) (*v1.Agent, error) { +func getAssistant(req api.Context, id string) (*v1.Agent, error) { var agent v1.Agent - if err := alias.Get(req.Context(), req.Storage, &agent, req.Namespace(), id); err != nil { + if err := alias.Get(req.Context(), req.Storage, &agent, "", id); err != nil { return nil, err } return &agent, nil @@ -47,7 +47,7 @@ func (a *AssistantHandler) Invoke(req api.Context) error { id = req.PathValue("id") ) - agent, err := getAgent(req, id) + agent, err := getAssistant(req, id) if err != nil { return err } @@ -132,7 +132,7 @@ func getUserThread(req api.Context, agentID string) (*v1.Thread, error) { return &thread, nil } - agent, err := getAgent(req, agentID) + agent, err := getAssistant(req, agentID) if err != nil { return nil, err } @@ -373,7 +373,7 @@ func (a *AssistantHandler) AddTool(req api.Context) error { tool = req.PathValue("tool") ) - agent, err := getAgent(req, id) + agent, err := getAssistant(req, id) if err != nil { return err } @@ -430,7 +430,7 @@ func (a *AssistantHandler) Tools(req api.Context) error { id = req.PathValue("id") ) - agent, err := getAgent(req, id) + agent, err := getAssistant(req, id) if err != nil { return err } diff --git a/pkg/api/handlers/tasks.go b/pkg/api/handlers/tasks.go new file mode 100644 index 00000000..228aad69 --- /dev/null +++ b/pkg/api/handlers/tasks.go @@ -0,0 +1,293 @@ +package handlers + +import ( + "net/http" + "slices" + + "github.com/otto8-ai/otto8/apiclient/types" + "github.com/otto8-ai/otto8/pkg/api" + "github.com/otto8-ai/otto8/pkg/events" + "github.com/otto8-ai/otto8/pkg/invoke" + v1 "github.com/otto8-ai/otto8/pkg/storage/apis/otto.otto8.ai/v1" + "github.com/otto8-ai/otto8/pkg/system" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + kclient "sigs.k8s.io/controller-runtime/pkg/client" +) + +type TaskHandler struct { + invoker *invoke.Invoker + events *events.Emitter +} + +func NewTaskHandler(invoker *invoke.Invoker, events *events.Emitter) *TaskHandler { + return &TaskHandler{ + invoker: invoker, + events: events, + } +} + +func (t *TaskHandler) Events(req api.Context) error { + var ( + follow = req.URL.Query().Get("follow") == "true" + ) + + workflow, err := t.getTask(req) + if err != nil { + return err + } + + var thread v1.Thread + if err := req.Get(&thread, req.PathValue("thread_id")); kclient.IgnoreNotFound(err) != nil { + return err + } + + if thread.Spec.WorkflowName != workflow.Name { + return types.NewErrHttp(http.StatusForbidden, "thread does not belong to the task") + } + + _, events, err := t.events.Watch(req.Context(), req.Namespace(), events.WatchOptions{ + History: true, + MaxRuns: 100, + ThreadName: thread.Name, + Follow: true, + FollowWorkflowExecutions: follow, + }) + if err != nil { + return err + } + + return req.WriteEvents(events) +} + +func (t *TaskHandler) Run(req api.Context) error { + var ( + threadID = req.Request.URL.Query().Get("thread") + stepID = req.Request.URL.Query().Get("step") + ) + + workflow, err := t.getTask(req) + if err != nil { + return err + } + + resp, err := t.invoker.Workflow(req.Context(), req.Storage, workflow, "", invoke.WorkflowOptions{ + ThreadName: threadID, + StepID: stepID, + }) + if err != nil { + return err + } + + return req.WriteCreated(map[string]any{ + "threadID": resp.Thread.Name, + }) +} + +func (t *TaskHandler) Delete(req api.Context) error { + workflow, err := t.getTask(req) + if err != nil { + if apierrors.IsNotFound(err) { + return nil + } + return err + } + + return req.Delete(workflow) +} + +func (t *TaskHandler) Update(req api.Context) error { + workflow, err := t.getTask(req) + if err != nil { + return err + } + + _, manifest, err := t.getAssistantAndManifestFromRequest(req) + if err != nil { + return err + } + + workflow.Spec.Manifest = manifest + if err := req.Update(workflow); err != nil { + return err + } + + return req.Write(convertTask(*workflow)) +} + +func (t *TaskHandler) getAssistantAndManifestFromRequest(req api.Context) (*v1.Agent, types.WorkflowManifest, error) { + assistantID := req.PathValue("assistant_id") + + assistant, err := getAssistant(req, assistantID) + if err != nil { + return nil, types.WorkflowManifest{}, err + } + + thread, err := getUserThread(req, assistantID) + if err != nil { + return nil, types.WorkflowManifest{}, err + } + + var manifest types.TaskManifest + if err := req.Read(&manifest); err != nil { + return nil, types.WorkflowManifest{}, err + } + + if manifest.Name == "" { + manifest.Name = "New Task" + } + + return assistant, toWorkflowManifest(assistant, thread, manifest), nil +} + +func (t *TaskHandler) Create(req api.Context) error { + assistant, workflowManifest, err := t.getAssistantAndManifestFromRequest(req) + if err != nil { + return err + } + + workflow := v1.Workflow{ + ObjectMeta: metav1.ObjectMeta{ + GenerateName: system.WorkflowPrefix, + Namespace: req.Namespace(), + }, + Spec: v1.WorkflowSpec{ + AgentName: assistant.Name, + UserID: req.User.GetUID(), + Manifest: workflowManifest, + }, + } + + if err := req.Create(&workflow); err != nil { + return err + } + + return req.WriteCreated(convertTask(workflow)) +} + +func toWorkflowManifest(agent *v1.Agent, thread *v1.Thread, manifest types.TaskManifest) types.WorkflowManifest { + workflowManifest := types.WorkflowManifest{ + AgentManifest: agent.Spec.Manifest, + } + + for _, tool := range thread.Spec.Manifest.Tools { + if !slices.Contains(workflowManifest.Tools, tool) { + workflowManifest.Tools = append(workflowManifest.Tools, tool) + } + } + + workflowManifest.Steps = toWorkflowSteps(manifest.Steps) + workflowManifest.Name = manifest.Name + workflowManifest.Description = manifest.Description + return workflowManifest +} + +func toWorkflowSteps(steps []types.TaskStep) []types.Step { + workflowSteps := make([]types.Step, 0, len(steps)) + for _, step := range steps { + workflowSteps = append(workflowSteps, types.Step{ + ID: step.ID, + Step: step.Step, + If: toWorkflowIf(step.If), + }) + } + return workflowSteps +} + +func toWorkflowIf(ifStep *types.TaskIf) *types.If { + if ifStep == nil { + return nil + } + return &types.If{ + Condition: ifStep.Condition, + Steps: toWorkflowSteps(ifStep.Steps), + Else: toWorkflowSteps(ifStep.Else), + } +} + +func (t *TaskHandler) Get(req api.Context) error { + task, err := t.getTask(req) + if err != nil { + return err + } + + return req.Write(convertTask(*task)) +} + +func (t *TaskHandler) getTask(req api.Context) (*v1.Workflow, error) { + assistantID := req.PathValue("assistant_id") + + var workflow v1.Workflow + if err := req.Get(&workflow, req.PathValue("id")); err != nil { + return nil, err + } + + assistant, err := getAssistant(req, assistantID) + if err != nil { + return nil, err + } + + if workflow.Spec.AgentName != assistant.Name || workflow.Spec.UserID != req.User.GetUID() { + return nil, types.NewErrHttp(http.StatusForbidden, "task does not belong to the user") + } + + return &workflow, nil +} + +func (t *TaskHandler) List(req api.Context) error { + assistant, err := getAssistant(req, req.PathValue("assistant_id")) + if err != nil { + return err + } + + var workflows v1.WorkflowList + if err := req.List(&workflows, kclient.MatchingFields{ + "spec.agentName": assistant.Name, + "spec.userID": req.User.GetUID(), + }); err != nil { + return err + } + + taskList := types.TaskList{Items: make([]types.Task, 0, len(workflows.Items))} + + for _, workflow := range workflows.Items { + taskList.Items = append(taskList.Items, convertTask(workflow)) + } + + return req.Write(taskList) +} + +func convertTask(workflow v1.Workflow) types.Task { + task := types.Task{ + Metadata: MetadataFrom(&workflow), + TaskManifest: types.TaskManifest{ + Name: workflow.Spec.Manifest.Name, + Description: workflow.Spec.Manifest.Description, + }, + } + task.Steps = toTaskSteps(workflow.Spec.Manifest.Steps) + return task +} + +func toTaskSteps(steps []types.Step) []types.TaskStep { + taskSteps := make([]types.TaskStep, 0, len(steps)) + for _, step := range steps { + taskSteps = append(taskSteps, types.TaskStep{ + ID: step.ID, + Step: step.Step, + If: toIf(step.If), + }) + } + return taskSteps +} + +func toIf(ifStep *types.If) *types.TaskIf { + if ifStep == nil { + return nil + } + return &types.TaskIf{ + Condition: ifStep.Condition, + Steps: toTaskSteps(ifStep.Steps), + Else: toTaskSteps(ifStep.Else), + } +} diff --git a/pkg/api/handlers/threads.go b/pkg/api/handlers/threads.go index bdaea33e..f63b9662 100644 --- a/pkg/api/handlers/threads.go +++ b/pkg/api/handlers/threads.go @@ -59,13 +59,14 @@ func convertThread(thread v1.Thread) types.Thread { func (a *ThreadHandler) Events(req api.Context) error { var ( - id = req.PathValue("id") - follow = req.URL.Query().Get("follow") == "true" - runID = req.URL.Query().Get("runID") - maxRunString = req.URL.Query().Get("maxRuns") - maxRuns int - err error - waitForThread = req.URL.Query().Get("waitForThread") == "true" + id = req.PathValue("id") + follow = req.URL.Query().Get("follow") == "true" + followWorkflows = req.URL.Query().Get("followWorflows") == "true" + runID = req.URL.Query().Get("runID") + maxRunString = req.URL.Query().Get("maxRuns") + maxRuns int + err error + waitForThread = req.URL.Query().Get("waitForThread") == "true" ) if runID == "" { @@ -82,13 +83,14 @@ func (a *ThreadHandler) Events(req api.Context) error { } _, events, err := a.events.Watch(req.Context(), req.Namespace(), events.WatchOptions{ - Follow: follow, - History: runID == "", - LastRunName: strings.TrimSuffix(runID, ":after"), - MaxRuns: maxRuns, - After: strings.HasSuffix(runID, ":after"), - ThreadName: id, - WaitForThread: waitForThread, + Follow: follow, + FollowWorkflowExecutions: followWorkflows, + History: runID == "", + LastRunName: strings.TrimSuffix(runID, ":after"), + MaxRuns: maxRuns, + After: strings.HasSuffix(runID, ":after"), + ThreadName: id, + WaitForThread: waitForThread, }) if err != nil { return err diff --git a/pkg/api/request.go b/pkg/api/request.go index 953969d2..dd142381 100644 --- a/pkg/api/request.go +++ b/pkg/api/request.go @@ -136,7 +136,7 @@ func (r *Context) WriteDataEvent(obj any) error { } } if _, ok := obj.(EventClose); ok { - _, err := r.ResponseWriter.Write([]byte("event: close\ndata: \n\n")) + _, err := r.ResponseWriter.Write([]byte("event: close\ndata: {}\n\n")) return err } data, err := json.Marshal(obj) @@ -200,11 +200,13 @@ func Watch[T client.Object](r Context, list client.ObjectList) (<-chan T, error) return resp, nil } -func (r *Context) List(obj client.ObjectList) error { +func (r *Context) List(obj client.ObjectList, opts ...client.ListOption) error { namespace := r.Namespace() - return r.Storage.List(r.Request.Context(), obj, &client.ListOptions{ - Namespace: namespace, - }) + return r.Storage.List(r.Request.Context(), obj, slices.Concat([]client.ListOption{ + &client.ListOptions{ + Namespace: namespace, + }, + }, opts)...) } func (r *Context) Delete(obj client.Object) error { diff --git a/pkg/api/router/router.go b/pkg/api/router/router.go index 7c36469f..c5812f2e 100644 --- a/pkg/api/router/router.go +++ b/pkg/api/router/router.go @@ -13,6 +13,7 @@ func Router(services *services.Services) (http.Handler, error) { agents := handlers.NewAgentHandler(services.GPTClient, services.ServerURL) assistants := handlers.NewAssistantHandler(services.Invoker, services.Events, services.GPTClient) + tasks := handlers.NewTaskHandler(services.Invoker, services.Events) workflows := handlers.NewWorkflowHandler(services.GPTClient, services.ServerURL, services.Invoker) invoker := handlers.NewInvokeHandler(services.Invoker) threads := handlers.NewThreadHandler(services.GPTClient, services.Events) @@ -44,17 +45,29 @@ func Router(services *services.Services) (http.Handler, error) { mux.HandleFunc("DELETE /api/assistants/{id}/credentials/{cred_id}", assistants.DeleteCredential) mux.HandleFunc("GET /api/assistants/{id}/events", assistants.Events) mux.HandleFunc("POST /api/assistants/{id}/invoke", assistants.Invoke) + // Assistant tools mux.HandleFunc("GET /api/assistants/{id}/tools", assistants.Tools) mux.HandleFunc("DELETE /api/assistants/{id}/tools/{tool}", assistants.RemoveTool) mux.HandleFunc("PUT /api/assistants/{id}/tools/{tool}", assistants.AddTool) + // Assistant files mux.HandleFunc("GET /api/assistants/{id}/files", assistants.Files) mux.HandleFunc("GET /api/assistants/{id}/file/{file...}", assistants.GetFile) mux.HandleFunc("POST /api/assistants/{id}/files/{file...}", assistants.UploadFile) mux.HandleFunc("DELETE /api/assistants/{id}/files/{file...}", assistants.DeleteFile) + // Assistant knowledge files mux.HandleFunc("GET /api/assistants/{id}/knowledge", assistants.Knowledge) mux.HandleFunc("POST /api/assistants/{id}/knowledge/{file}", assistants.UploadKnowledge) mux.HandleFunc("DELETE /api/assistants/{id}/knowledge/{file...}", assistants.DeleteKnowledge) + // Tasks + mux.HandleFunc("GET /api/assistants/{assistant_id}/tasks", tasks.List) + mux.HandleFunc("GET /api/assistants/{assistant_id}/tasks/{id}", tasks.Get) + mux.HandleFunc("POST /api/assistants/{assistant_id}/tasks", tasks.Create) + mux.HandleFunc("PUT /api/assistants/{assistant_id}/tasks/{id}", tasks.Update) + mux.HandleFunc("DELETE /api/assistants/{assistant_id}/tasks/{id}", tasks.Delete) + mux.HandleFunc("POST /api/assistants/{assistant_id}/tasks/{id}/run", tasks.Run) + mux.HandleFunc("GET /api/assistants/{assistant_id}/tasks/{id}/threads/{thread_id}/events", tasks.Events) + // Agent files mux.HandleFunc("GET /api/agents/{id}/files", agents.ListFiles) mux.HandleFunc("POST /api/agents/{id}/files/{file}", agents.UploadFile) diff --git a/pkg/controller/handlers/workflowexecution/workflowexecution.go b/pkg/controller/handlers/workflowexecution/workflowexecution.go index e3d0a291..2e85ef5b 100644 --- a/pkg/controller/handlers/workflowexecution/workflowexecution.go +++ b/pkg/controller/handlers/workflowexecution/workflowexecution.go @@ -90,7 +90,7 @@ func (h *Handler) Run(req router.Request, _ router.Response) error { if newState.IsBlocked() { we.Status.State = newState we.Status.Error = output - return nil + return apply.New(req.Client).Apply(req.Ctx, req.Object, steps...) } if newState == types.WorkflowStateComplete { diff --git a/pkg/controller/handlers/workflowstep/invoke.go b/pkg/controller/handlers/workflowstep/invoke.go index 5cfb6f69..1c1a7555 100644 --- a/pkg/controller/handlers/workflowstep/invoke.go +++ b/pkg/controller/handlers/workflowstep/invoke.go @@ -3,9 +3,11 @@ package workflowstep import ( "github.com/gptscript-ai/go-gptscript" "github.com/otto8-ai/nah/pkg/router" + "github.com/otto8-ai/nah/pkg/uncached" "github.com/otto8-ai/otto8/apiclient/types" "github.com/otto8-ai/otto8/pkg/invoke" v1 "github.com/otto8-ai/otto8/pkg/storage/apis/otto.otto8.ai/v1" + "k8s.io/client-go/util/retry" ) func (h *Handler) RunInvoke(req router.Request, resp router.Response) error { @@ -38,11 +40,17 @@ func (h *Handler) RunInvoke(req router.Request, resp router.Response) error { } defer invokeResp.Close() - step.Status.ThreadName = invokeResp.Thread.Name - step.Status.RunNames = []string{invokeResp.Run.Name} - - // Ignored error updating - _ = client.Status().Update(ctx, step) + err = retry.RetryOnConflict(retry.DefaultBackoff, func() error { + if err := client.Get(ctx, router.Key(step.Namespace, step.Name), uncached.Get(step)); err != nil { + return err + } + step.Status.ThreadName = invokeResp.Thread.Name + step.Status.RunNames = []string{invokeResp.Run.Name} + return client.Status().Update(ctx, step) + }) + if err != nil { + return err + } run = *invokeResp.Run } else { diff --git a/pkg/events/events.go b/pkg/events/events.go index 66852670..c6655504 100644 --- a/pkg/events/events.go +++ b/pkg/events/events.go @@ -15,14 +15,18 @@ import ( "github.com/otto8-ai/nah/pkg/router" "github.com/otto8-ai/nah/pkg/typed" "github.com/otto8-ai/otto8/apiclient/types" + "github.com/otto8-ai/otto8/logger" "github.com/otto8-ai/otto8/pkg/gz" v1 "github.com/otto8-ai/otto8/pkg/storage/apis/otto.otto8.ai/v1" "github.com/otto8-ai/otto8/pkg/system" + "github.com/otto8-ai/otto8/pkg/wait" apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" kclient "sigs.k8s.io/controller-runtime/pkg/client" ) +var log = logger.Package() + type Emitter struct { client kclient.WithWatch liveStates map[string][]liveState @@ -47,15 +51,16 @@ type liveState struct { } type WatchOptions struct { - History bool - LastRunName string - MaxRuns int - After bool - ThreadName string - ThreadResourceVersion string - Follow bool - Run *v1.Run - WaitForThread bool + History bool + LastRunName string + MaxRuns int + After bool + ThreadName string + ThreadResourceVersion string + Follow bool + FollowWorkflowExecutions bool + Run *v1.Run + WaitForThread bool } type callFramePrintState struct { @@ -440,7 +445,7 @@ func (e *Emitter) streamEvents(ctx context.Context, run v1.Run, opts WatchOption } } - nextRun, err := e.findNextRun(ctx, run, opts.Follow) + nextRun, err := e.findNextRun(ctx, run, opts) if err != nil { return err } @@ -477,7 +482,38 @@ func (e *Emitter) getThreadID(ctx context.Context, namespace, runName, workflowN return "", fmt.Errorf("no thread found for run %s and workflow %s", runName, workflowName) } -func (e *Emitter) isWorkflowDone(ctx context.Context, run v1.Run) (chan struct{}, func(), error) { +func (e *Emitter) getNextWorkflowRun(ctx context.Context, run v1.Run) (*v1.Run, error) { + var runName string + _, err := wait.For(ctx, e.client, &v1.Thread{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: run.Namespace, + Name: run.Spec.ThreadName, + }, + }, func(thread *v1.Thread) bool { + if thread.Status.CurrentRunName != "" && thread.Status.CurrentRunName != run.Name { + runName = thread.Status.CurrentRunName + return true + } + if thread.Status.LastRunName != "" && thread.Status.LastRunName != run.Name { + runName = thread.Status.LastRunName + return true + } + return false + }, wait.Option{ + Timeout: 15 * time.Minute, + }) + if err != nil { + return nil, err + } + + var nextRun v1.Run + if err := e.client.Get(ctx, router.Key(run.Namespace, runName), &nextRun); err != nil { + return nil, err + } + return &nextRun, nil +} + +func (e *Emitter) isWorkflowDone(ctx context.Context, run v1.Run, opts WatchOptions) (<-chan *v1.Run, func(), error) { if run.Spec.WorkflowExecutionName == "" { return nil, func() {}, nil } @@ -488,7 +524,7 @@ func (e *Emitter) isWorkflowDone(ctx context.Context, run v1.Run) (chan struct{} return nil, nil, err } - result := make(chan struct{}) + result := make(chan *v1.Run, 1) cancel := func() { w.Stop() go func() { @@ -498,11 +534,19 @@ func (e *Emitter) isWorkflowDone(ctx context.Context, run v1.Run) (chan struct{} } go func() { + defer close(result) defer cancel() for event := range w.ResultChan() { if wfe, ok := event.Object.(*v1.WorkflowExecution); ok { if wfe.Status.State.IsTerminal() || wfe.Status.State.IsBlocked() { - close(result) + if opts.FollowWorkflowExecutions { + next, err := e.getNextWorkflowRun(ctx, run) + if err != nil { + log.Errorf("failed to get next workflow run for last run %q: %v", run.Name, err) + } else { + result <- next + } + } return } } @@ -512,7 +556,7 @@ func (e *Emitter) isWorkflowDone(ctx context.Context, run v1.Run) (chan struct{} return result, cancel, nil } -func (e *Emitter) findNextRun(ctx context.Context, run v1.Run, follow bool) (*v1.Run, error) { +func (e *Emitter) findNextRun(ctx context.Context, run v1.Run, opts WatchOptions) (*v1.Run, error) { var ( runs v1.RunList criteria = []kclient.ListOption{ @@ -521,8 +565,7 @@ func (e *Emitter) findNextRun(ctx context.Context, run v1.Run, follow bool) (*v1 } ) - if !follow { - // If this isn't a workflow we are done at this point if follow is requested + if !opts.Follow { return nil, nil } @@ -547,7 +590,7 @@ func (e *Emitter) findNextRun(ctx context.Context, run v1.Run, follow bool) (*v1 } }() - isWorkflowDone, cancel, err := e.isWorkflowDone(ctx, run) + isWorkflowDone, cancel, err := e.isWorkflowDone(ctx, run, opts) if err != nil { return nil, err } @@ -562,8 +605,8 @@ func (e *Emitter) findNextRun(ctx context.Context, run v1.Run, follow bool) (*v1 if run, ok := event.Object.(*v1.Run); ok { return run, nil } - case <-isWorkflowDone: - return nil, nil + case run := <-isWorkflowDone: + return run, nil } } diff --git a/pkg/invoke/invoker.go b/pkg/invoke/invoker.go index 64a8fb82..102ca52a 100644 --- a/pkg/invoke/invoker.go +++ b/pkg/invoke/invoker.go @@ -13,6 +13,7 @@ import ( "github.com/gptscript-ai/go-gptscript" "github.com/otto8-ai/nah/pkg/router" + "github.com/otto8-ai/nah/pkg/uncached" "github.com/otto8-ai/otto8/apiclient/types" "github.com/otto8-ai/otto8/logger" "github.com/otto8-ai/otto8/pkg/events" @@ -402,14 +403,15 @@ func (i *Invoker) createRun(ctx context.Context, c kclient.WithWatch, thread *v1 if !thread.Spec.SystemTask { err = retry.RetryOnConflict(retry.DefaultBackoff, func() error { - if err := c.Get(ctx, kclient.ObjectKeyFromObject(thread), thread); err != nil { + if err := c.Get(ctx, kclient.ObjectKeyFromObject(thread), uncached.Get(thread)); err != nil { return err } thread.Status.CurrentRunName = run.Name return c.Status().Update(ctx, thread) }) if err != nil { - return nil, err + // Don't return error it's not critical, and will mostly likely make caller loose track of this + log.Errorf("failed to update thread %q for run %q: %v", thread.Name, run.Name, err) } } diff --git a/pkg/invoke/workflow.go b/pkg/invoke/workflow.go index aaf4edd4..70900c53 100644 --- a/pkg/invoke/workflow.go +++ b/pkg/invoke/workflow.go @@ -162,29 +162,24 @@ func (i *Invoker) deleteSteps(ctx context.Context, c kclient.Client, thread v1.T steps v1.WorkflowStepList ) - if err := c.List(ctx, &steps, kclient.InNamespace(thread.Namespace)); err != nil { + if err := c.List(ctx, &steps, kclient.InNamespace(thread.Namespace), kclient.MatchingFields{ + "spec.workflowExecutionName": thread.Spec.WorkflowExecutionName, + }); err != nil { return err } if len(steps.Items) == 0 { - return types.NewErrNotFound("step not found: %s", stepID) + return nil } - var deleted bool for _, step := range steps.Items { - if step.Status.State == types.WorkflowStateError || - step.Spec.WorkflowExecutionName == thread.Spec.WorkflowExecutionName && stepMatches(step.Spec.Step.ID, stepID) { + if step.Status.State == types.WorkflowStateError || stepMatches(step.Spec.Step.ID, stepID) { if err := c.Delete(ctx, &step); kclient.IgnoreNotFound(err) != nil { return err } - deleted = true } } - if !deleted { - return types.NewErrNotFound("step not found: %s", stepID) - } - return nil } diff --git a/pkg/storage/apis/otto.otto8.ai/v1/workflow.go b/pkg/storage/apis/otto.otto8.ai/v1/workflow.go index 6420d0f2..3eee8bcb 100644 --- a/pkg/storage/apis/otto.otto8.ai/v1/workflow.go +++ b/pkg/storage/apis/otto.otto8.ai/v1/workflow.go @@ -1,13 +1,17 @@ package v1 import ( + "slices" + + "github.com/otto8-ai/nah/pkg/fields" "github.com/otto8-ai/otto8/apiclient/types" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) var ( - _ Aliasable = (*Workflow)(nil) - _ AliasScoped = (*Workflow)(nil) + _ Aliasable = (*Workflow)(nil) + _ AliasScoped = (*Workflow)(nil) + _ fields.Fields = (*Workflow)(nil) ) // +k8s:deepcopy-gen:interfaces=k8s.io/apimachinery/pkg/runtime.Object @@ -20,6 +24,27 @@ type Workflow struct { Status WorkflowStatus `json:"status,omitempty"` } +func (in *Workflow) Has(field string) (exists bool) { + return slices.Contains(in.FieldNames(), field) +} + +func (in *Workflow) Get(field string) (value string) { + switch field { + case "spec.agentName": + return in.Spec.AgentName + case "spec.userID": + return in.Spec.UserID + } + return "" +} + +func (in *Workflow) FieldNames() []string { + return []string{ + "spec.agentName", + "spec.userID", + } +} + func (in *Workflow) GetAliasName() string { return in.Spec.Manifest.Alias } @@ -37,7 +62,9 @@ func (in *Workflow) GetAliasScope() string { } type WorkflowSpec struct { - Manifest types.WorkflowManifest `json:"manifest,omitempty"` + AgentName string `json:"agentName,omitempty"` + UserID string `json:"userID,omitempty"` + Manifest types.WorkflowManifest `json:"manifest,omitempty"` } type WorkflowStatus struct { diff --git a/pkg/storage/apis/otto.otto8.ai/v1/workflowstep.go b/pkg/storage/apis/otto.otto8.ai/v1/workflowstep.go index 2962b9fc..ac2122a0 100644 --- a/pkg/storage/apis/otto.otto8.ai/v1/workflowstep.go +++ b/pkg/storage/apis/otto.otto8.ai/v1/workflowstep.go @@ -1,10 +1,15 @@ package v1 import ( + "slices" + + "github.com/otto8-ai/nah/pkg/fields" "github.com/otto8-ai/otto8/apiclient/types" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) +var _ fields.Fields = (*WorkflowStep)(nil) + // +k8s:deepcopy-gen:interfaces=k8s.io/apimachinery/pkg/runtime.Object type WorkflowStep struct { @@ -15,6 +20,24 @@ type WorkflowStep struct { Status WorkflowStepStatus `json:"status,omitempty"` } +func (in *WorkflowStep) Has(field string) (exists bool) { + return slices.Contains(in.FieldNames(), field) +} + +func (in *WorkflowStep) Get(field string) (value string) { + switch field { + case "spec.workflowExecutionName": + return in.Spec.WorkflowExecutionName + } + return "" +} + +func (in *WorkflowStep) FieldNames() []string { + return []string{ + "spec.workflowExecutionName", + } +} + func (in *WorkflowStep) IsGenerationInSync() bool { return in.Spec.WorkflowGeneration == in.Status.WorkflowGeneration }