From 8e6732402867d147cdecb43b8382750bbe5c12ab Mon Sep 17 00:00:00 2001 From: Donnie Adams Date: Wed, 16 Oct 2024 08:38:35 -0400 Subject: [PATCH] feat: add authorization for the API layer (#204) Signed-off-by: Donnie Adams --- pkg/api/authn/anonymous.go | 22 ++++ pkg/api/authn/authn.go | 29 +++++ pkg/api/authn/noauth.go | 21 ++++ pkg/api/authz/authz.go | 92 ++++++++++++++++ pkg/api/handlers/agent.go | 9 +- pkg/api/handlers/webhooks.go | 9 +- pkg/api/handlers/workflows.go | 9 +- pkg/api/request.go | 13 ++- pkg/api/router/router.go | 167 ++++++++++++++--------------- pkg/api/{ => server}/server.go | 63 +++++------ pkg/gateway/client/auth.go | 7 +- pkg/gateway/server/authprovider.go | 31 +++--- pkg/gateway/server/middleware.go | 21 ---- pkg/gateway/server/router.go | 91 ++++++++++++++++ pkg/gateway/server/routes.go | 92 ---------------- pkg/gateway/server/token.go | 97 ++++++----------- pkg/services/config.go | 48 ++------- 17 files changed, 459 insertions(+), 362 deletions(-) create mode 100644 pkg/api/authn/anonymous.go create mode 100644 pkg/api/authn/authn.go create mode 100644 pkg/api/authn/noauth.go create mode 100644 pkg/api/authz/authz.go rename pkg/api/{ => server}/server.go (53%) create mode 100644 pkg/gateway/server/router.go delete mode 100644 pkg/gateway/server/routes.go diff --git a/pkg/api/authn/anonymous.go b/pkg/api/authn/anonymous.go new file mode 100644 index 000000000..918e994f0 --- /dev/null +++ b/pkg/api/authn/anonymous.go @@ -0,0 +1,22 @@ +package authn + +import ( + "net/http" + + "github.com/otto8-ai/otto8/pkg/api/authz" + "k8s.io/apiserver/pkg/authentication/authenticator" + "k8s.io/apiserver/pkg/authentication/user" +) + +type Anonymous struct { +} + +func (n Anonymous) AuthenticateRequest(*http.Request) (*authenticator.Response, bool, error) { + return &authenticator.Response{ + User: &user.DefaultInfo{ + UID: "anonymous", + Name: "anonymous", + Groups: []string{authz.UnauthenticatedGroup}, + }, + }, true, nil +} diff --git a/pkg/api/authn/authn.go b/pkg/api/authn/authn.go new file mode 100644 index 000000000..52ed63c82 --- /dev/null +++ b/pkg/api/authn/authn.go @@ -0,0 +1,29 @@ +package authn + +import ( + "net/http" + + "k8s.io/apiserver/pkg/authentication/authenticator" + "k8s.io/apiserver/pkg/authentication/user" +) + +type Authenticator struct { + authenticator authenticator.Request +} + +func NewAuthenticator(authenticator authenticator.Request) *Authenticator { + return &Authenticator{ + authenticator: authenticator, + } +} + +func (a *Authenticator) Authenticate(req *http.Request) (user.Info, error) { + resp, ok, err := a.authenticator.AuthenticateRequest(req) + if err != nil { + return nil, err + } + if !ok { + panic("authentication should always succeed") + } + return resp.User, nil +} diff --git a/pkg/api/authn/noauth.go b/pkg/api/authn/noauth.go new file mode 100644 index 000000000..cda5a158b --- /dev/null +++ b/pkg/api/authn/noauth.go @@ -0,0 +1,21 @@ +package authn + +import ( + "net/http" + + "github.com/otto8-ai/otto8/pkg/api/authz" + "k8s.io/apiserver/pkg/authentication/authenticator" + "k8s.io/apiserver/pkg/authentication/user" +) + +type NoAuth struct { +} + +func (n NoAuth) AuthenticateRequest(*http.Request) (*authenticator.Response, bool, error) { + return &authenticator.Response{ + User: &user.DefaultInfo{ + Name: "nobody", + Groups: []string{authz.AdminGroup, authz.AuthenticatedGroup}, + }, + }, true, nil +} diff --git a/pkg/api/authz/authz.go b/pkg/api/authz/authz.go new file mode 100644 index 000000000..2b3013cc8 --- /dev/null +++ b/pkg/api/authz/authz.go @@ -0,0 +1,92 @@ +package authz + +import ( + "net/http" + "slices" + + "k8s.io/apiserver/pkg/authentication/user" +) + +const ( + AdminGroup = "admin" + AuthenticatedGroup = "authenticated" + UnauthenticatedGroup = "unauthenticated" + + // anyGroup is an internal group that allows access to any group + anyGroup = "*" +) + +type Authorizer struct { + rules []rule +} + +func NewAuthorizer() *Authorizer { + return &Authorizer{ + rules: defaultRules(), + } +} + +func (a *Authorizer) Authorize(req *http.Request, user user.Info) bool { + userGroups := user.GetGroups() + for _, r := range a.rules { + if r.group == anyGroup || slices.Contains(userGroups, r.group) { + if _, pattern := r.mux.Handler(req); pattern != "" { + return true + } + } + } + + return false +} + +type rule struct { + group string + mux *http.ServeMux +} + +func defaultRules() []rule { + var ( + rules []rule + f = (*fake)(nil) + ) + + // Build admin mux, admins can assess any URL + adminMux := http.NewServeMux() + adminMux.Handle("/", f) + + rules = append(rules, rule{ + group: AdminGroup, + mux: adminMux, + }) + + // Build mux that anyone can access + anyMux := http.NewServeMux() + anyMux.Handle("POST /api/webhooks/{id}", f) + + anyMux.Handle("GET /api/token-request/{id}", f) + anyMux.Handle("POST /api/token-request", f) + anyMux.Handle("GET /api/token-request/{id}/{service}", f) + + anyMux.Handle("GET /api/auth-providers", f) + anyMux.Handle("GET /api/auth-providers/{slug}", f) + + anyMux.Handle("GET /api/oauth/start/{id}/{service}", f) + anyMux.Handle("/api/oauth/redirect/{service}", f) + + anyMux.Handle("GET /api/app-oauth/authorize/{id}", f) + anyMux.Handle("GET /api/app-oauth/refresh/{id}", f) + anyMux.Handle("GET /api/app-oauth/callback/{id}", f) + anyMux.Handle("GET /api/app-oauth/get-token", f) + + rules = append(rules, rule{ + group: anyGroup, + mux: anyMux, + }) + + return rules +} + +// fake is a fake handler that does nothing +type fake struct{} + +func (f *fake) ServeHTTP(http.ResponseWriter, *http.Request) {} diff --git a/pkg/api/handlers/agent.go b/pkg/api/handlers/agent.go index 2964821ed..057f10a58 100644 --- a/pkg/api/handlers/agent.go +++ b/pkg/api/handlers/agent.go @@ -8,6 +8,7 @@ import ( "github.com/gptscript-ai/go-gptscript" "github.com/otto8-ai/otto8/apiclient/types" "github.com/otto8-ai/otto8/pkg/api" + "github.com/otto8-ai/otto8/pkg/api/server" "github.com/otto8-ai/otto8/pkg/render" v1 "github.com/otto8-ai/otto8/pkg/storage/apis/otto.gptscript.ai/v1" "github.com/otto8-ai/otto8/pkg/system" @@ -48,7 +49,7 @@ func (a *AgentHandler) Update(req api.Context) error { return err } - return req.Write(convertAgent(agent, api.GetURLPrefix(req))) + return req.Write(convertAgent(agent, server.GetURLPrefix(req))) } func (a *AgentHandler) Delete(req api.Context) error { @@ -84,7 +85,7 @@ func (a *AgentHandler) Create(req api.Context) error { } req.WriteHeader(http.StatusCreated) - return req.Write(convertAgent(agent, api.GetURLPrefix(req))) + return req.Write(convertAgent(agent, server.GetURLPrefix(req))) } func convertAgent(agent v1.Agent, prefix string) *types.Agent { @@ -109,7 +110,7 @@ func (a *AgentHandler) ByID(req api.Context) error { return err } - return req.Write(convertAgent(agent, api.GetURLPrefix(req))) + return req.Write(convertAgent(agent, server.GetURLPrefix(req))) } func (a *AgentHandler) List(req api.Context) error { @@ -120,7 +121,7 @@ func (a *AgentHandler) List(req api.Context) error { var resp types.AgentList for _, agent := range agentList.Items { - resp.Items = append(resp.Items, *convertAgent(agent, api.GetURLPrefix(req))) + resp.Items = append(resp.Items, *convertAgent(agent, server.GetURLPrefix(req))) } return req.Write(resp) diff --git a/pkg/api/handlers/webhooks.go b/pkg/api/handlers/webhooks.go index 38004486d..8dda0413e 100644 --- a/pkg/api/handlers/webhooks.go +++ b/pkg/api/handlers/webhooks.go @@ -12,6 +12,7 @@ import ( "github.com/otto8-ai/otto8/apiclient/types" "github.com/otto8-ai/otto8/pkg/api" + "github.com/otto8-ai/otto8/pkg/api/server" v1 "github.com/otto8-ai/otto8/pkg/storage/apis/otto.gptscript.ai/v1" "github.com/otto8-ai/otto8/pkg/system" apierrors "k8s.io/apimachinery/pkg/api/errors" @@ -53,7 +54,7 @@ func (a *WebhookHandler) Update(req api.Context) error { return err } - return req.Write(convertWebhook(wh, api.GetURLPrefix(req))) + return req.Write(convertWebhook(wh, server.GetURLPrefix(req))) } func (a *WebhookHandler) Delete(req api.Context) error { @@ -98,7 +99,7 @@ func (a *WebhookHandler) Create(req api.Context) error { } req.WriteHeader(http.StatusCreated) - return req.Write(convertWebhook(wh, api.GetURLPrefix(req))) + return req.Write(convertWebhook(wh, server.GetURLPrefix(req))) } func convertWebhook(webhook v1.Webhook, urlPrefix string) *types.Webhook { @@ -125,7 +126,7 @@ func (a *WebhookHandler) ByID(req api.Context) error { return err } - return req.Write(convertWebhook(wh, api.GetURLPrefix(req))) + return req.Write(convertWebhook(wh, server.GetURLPrefix(req))) } func (a *WebhookHandler) List(req api.Context) error { @@ -136,7 +137,7 @@ func (a *WebhookHandler) List(req api.Context) error { var resp types.WebhookList for _, wh := range webhookList.Items { - resp.Items = append(resp.Items, *convertWebhook(wh, api.GetURLPrefix(req))) + resp.Items = append(resp.Items, *convertWebhook(wh, server.GetURLPrefix(req))) } return req.Write(resp) diff --git a/pkg/api/handlers/workflows.go b/pkg/api/handlers/workflows.go index 1a72f7ef3..617e46cce 100644 --- a/pkg/api/handlers/workflows.go +++ b/pkg/api/handlers/workflows.go @@ -8,6 +8,7 @@ import ( "github.com/gptscript-ai/go-gptscript" "github.com/otto8-ai/otto8/apiclient/types" "github.com/otto8-ai/otto8/pkg/api" + "github.com/otto8-ai/otto8/pkg/api/server" "github.com/otto8-ai/otto8/pkg/controller/handlers/workflow" "github.com/otto8-ai/otto8/pkg/render" v1 "github.com/otto8-ai/otto8/pkg/storage/apis/otto.gptscript.ai/v1" @@ -50,7 +51,7 @@ func (a *WorkflowHandler) Update(req api.Context) error { return err } - return req.Write(convertWorkflow(wf, api.GetURLPrefix(req))) + return req.Write(convertWorkflow(wf, server.GetURLPrefix(req))) } func (a *WorkflowHandler) Delete(req api.Context) error { @@ -87,7 +88,7 @@ func (a *WorkflowHandler) Create(req api.Context) error { } req.WriteHeader(http.StatusCreated) - return req.Write(convertWorkflow(workflow, api.GetURLPrefix(req))) + return req.Write(convertWorkflow(workflow, server.GetURLPrefix(req))) } func convertWorkflow(workflow v1.Workflow, prefix string) *types.Workflow { @@ -112,7 +113,7 @@ func (a *WorkflowHandler) ByID(req api.Context) error { return err } - return req.Write(convertWorkflow(workflow, api.GetURLPrefix(req))) + return req.Write(convertWorkflow(workflow, server.GetURLPrefix(req))) } func (a *WorkflowHandler) List(req api.Context) error { @@ -123,7 +124,7 @@ func (a *WorkflowHandler) List(req api.Context) error { var resp types.WorkflowList for _, workflow := range workflowList.Items { - resp.Items = append(resp.Items, *convertWorkflow(workflow, api.GetURLPrefix(req))) + resp.Items = append(resp.Items, *convertWorkflow(workflow, server.GetURLPrefix(req))) } return req.Write(resp) diff --git a/pkg/api/request.go b/pkg/api/request.go index 86481ac41..07da1853f 100644 --- a/pkg/api/request.go +++ b/pkg/api/request.go @@ -10,9 +10,9 @@ import ( "strings" "time" - "github.com/acorn-io/baaah/pkg/router" "github.com/gptscript-ai/go-gptscript" "github.com/otto8-ai/otto8/apiclient/types" + "github.com/otto8-ai/otto8/pkg/api/authz" "github.com/otto8-ai/otto8/pkg/storage" apierrors "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/api/meta" @@ -31,6 +31,11 @@ type Context struct { User user.Info } +type ( + HandlerFunc func(Context) error + Middleware func(HandlerFunc) HandlerFunc +) + func (r *Context) IsStreamRequested() bool { return r.Accepts("text/event-stream") } @@ -195,7 +200,7 @@ func (r *Context) Delete(obj client.Object) error { func (r *Context) Get(obj client.Object, name string) error { namespace := r.Namespace() - err := r.Storage.Get(r.Request.Context(), router.Key(namespace, name), obj) + err := r.Storage.Get(r.Request.Context(), client.ObjectKey{Namespace: namespace, Name: name}, obj) if apierrors.IsNotFound(err) { gvk, _ := r.Storage.GroupVersionKindFor(obj) return types.NewErrHttp(http.StatusNotFound, fmt.Sprintf("%s %s not found", strings.ToLower(gvk.Kind), name)) @@ -216,11 +221,11 @@ func (r *Context) Namespace() string { } func (r *Context) UserIsAdmin() bool { - return slices.Contains(r.User.GetGroups(), "admin") + return slices.Contains(r.User.GetGroups(), authz.AdminGroup) } func (r *Context) UserIsAuthenticated() bool { - return slices.Contains(r.User.GetGroups(), "system:authenticated") + return slices.Contains(r.User.GetGroups(), authz.AuthenticatedGroup) } func (r *Context) UserID() uint { diff --git a/pkg/api/router/router.go b/pkg/api/router/router.go index 147147143..ceb993d3d 100644 --- a/pkg/api/router/router.go +++ b/pkg/api/router/router.go @@ -9,8 +9,7 @@ import ( ) func Router(services *services.Services) (http.Handler, error) { - w := services.APIServer.Wrap - mux := http.NewServeMux() + mux := services.APIServer agents := handlers.NewAgentHandler(services.WorkspaceClient, "directory") workflows := handlers.NewWorkflowHandler(services.WorkspaceClient, "directory") @@ -22,118 +21,118 @@ func Router(services *services.Services) (http.Handler, error) { cronJobs := handlers.NewCronJobHandler() // Agents - mux.Handle("GET /api/agents", w(agents.List)) - mux.Handle("GET /api/agents/{id}", w(agents.ByID)) - mux.Handle("GET /api/agents/{id}/script", w(agents.Script)) - mux.Handle("GET /api/agents/{id}/script.gpt", w(agents.Script)) - mux.Handle("GET /api/agents/{id}/script/tool.gpt", w(agents.Script)) - mux.Handle("POST /api/agents", w(agents.Create)) - mux.Handle("PUT /api/agents/{id}", w(agents.Update)) - mux.Handle("DELETE /api/agents/{id}", w(agents.Delete)) + mux.HandleFunc("GET /api/agents", agents.List) + mux.HandleFunc("GET /api/agents/{id}", agents.ByID) + mux.HandleFunc("GET /api/agents/{id}/script", agents.Script) + mux.HandleFunc("GET /api/agents/{id}/script.gpt", agents.Script) + mux.HandleFunc("GET /api/agents/{id}/script/tool.gpt", agents.Script) + mux.HandleFunc("POST /api/agents", agents.Create) + mux.HandleFunc("PUT /api/agents/{id}", agents.Update) + mux.HandleFunc("DELETE /api/agents/{id}", agents.Delete) // Agent files - mux.Handle("GET /api/agents/{id}/files", w(agents.Files)) - mux.Handle("POST /api/agents/{id}/files/{file}", w(agents.UploadFile)) - mux.Handle("DELETE /api/agents/{id}/files/{file}", w(agents.DeleteFile)) + mux.HandleFunc("GET /api/agents/{id}/files", agents.Files) + mux.HandleFunc("POST /api/agents/{id}/files/{file}", agents.UploadFile) + mux.HandleFunc("DELETE /api/agents/{id}/files/{file}", agents.DeleteFile) // Agent knowledge files - mux.Handle("GET /api/agents/{id}/knowledge", w(agents.Knowledge)) - mux.Handle("POST /api/agents/{id}/knowledge/{file}", w(agents.UploadKnowledge)) - mux.Handle("DELETE /api/agents/{id}/knowledge/{file...}", w(agents.DeleteKnowledge)) + mux.HandleFunc("GET /api/agents/{id}/knowledge", agents.Knowledge) + mux.HandleFunc("POST /api/agents/{id}/knowledge/{file}", agents.UploadKnowledge) + mux.HandleFunc("DELETE /api/agents/{id}/knowledge/{file...}", agents.DeleteKnowledge) - mux.Handle("POST /api/agents/{agent_id}/remote-knowledge-sources", w(agents.CreateRemoteKnowledgeSource)) - mux.Handle("GET /api/agents/{agent_id}/remote-knowledge-sources", w(agents.GetRemoteKnowledgeSources)) - mux.Handle("PATCH /api/agents/{agent_id}/remote-knowledge-sources/{id}", w(agents.ReSyncRemoteKnowledgeSource)) - mux.Handle("PUT /api/agents/{agent_id}/remote-knowledge-sources/{id}", w(agents.UpdateRemoteKnowledgeSource)) - mux.Handle("DELETE /api/agents/{agent_id}/remote-knowledge-sources/{id}", w(agents.DeleteRemoteKnowledgeSource)) + mux.HandleFunc("POST /api/agents/{agent_id}/remote-knowledge-sources", agents.CreateRemoteKnowledgeSource) + mux.HandleFunc("GET /api/agents/{agent_id}/remote-knowledge-sources", agents.GetRemoteKnowledgeSources) + mux.HandleFunc("PATCH /api/agents/{agent_id}/remote-knowledge-sources/{id}", agents.ReSyncRemoteKnowledgeSource) + mux.HandleFunc("PUT /api/agents/{agent_id}/remote-knowledge-sources/{id}", agents.UpdateRemoteKnowledgeSource) + mux.HandleFunc("DELETE /api/agents/{agent_id}/remote-knowledge-sources/{id}", agents.DeleteRemoteKnowledgeSource) // Workflows - mux.Handle("GET /api/workflows", w(workflows.List)) - mux.Handle("GET /api/workflows/{id}", w(workflows.ByID)) - mux.Handle("GET /api/workflows/{id}/script", w(workflows.Script)) - mux.Handle("GET /api/workflows/{id}/script.gpt", w(workflows.Script)) - mux.Handle("GET /api/workflows/{id}/script/tool.gpt", w(workflows.Script)) - mux.Handle("POST /api/workflows", w(workflows.Create)) - mux.Handle("PUT /api/workflows/{id}", w(workflows.Update)) - mux.Handle("DELETE /api/workflows/{id}", w(workflows.Delete)) + mux.HandleFunc("GET /api/workflows", workflows.List) + mux.HandleFunc("GET /api/workflows/{id}", workflows.ByID) + mux.HandleFunc("GET /api/workflows/{id}/script", workflows.Script) + mux.HandleFunc("GET /api/workflows/{id}/script.gpt", workflows.Script) + mux.HandleFunc("GET /api/workflows/{id}/script/tool.gpt", workflows.Script) + mux.HandleFunc("POST /api/workflows", workflows.Create) + mux.HandleFunc("PUT /api/workflows/{id}", workflows.Update) + mux.HandleFunc("DELETE /api/workflows/{id}", workflows.Delete) // Workflow files - mux.Handle("GET /api/workflows/{id}/files", w(workflows.Files)) - mux.Handle("POST /api/workflows/{id}/files/{file}", w(workflows.UploadFile)) - mux.Handle("DELETE /api/workflows/{id}/files/{file}", w(workflows.DeleteFile)) + mux.HandleFunc("GET /api/workflows/{id}/files", workflows.Files) + mux.HandleFunc("POST /api/workflows/{id}/files/{file}", workflows.UploadFile) + mux.HandleFunc("DELETE /api/workflows/{id}/files/{file}", workflows.DeleteFile) // Invoker - mux.Handle("POST /api/invoke/{id}", w(invoker.Invoke)) - mux.Handle("POST /api/invoke/{id}/threads/{thread}", w(invoker.Invoke)) + mux.HandleFunc("POST /api/invoke/{id}", invoker.Invoke) + mux.HandleFunc("POST /api/invoke/{id}/threads/{thread}", invoker.Invoke) // Threads - mux.Handle("GET /api/threads", w(threads.List)) - mux.Handle("GET /api/threads/{id}", w(threads.ByID)) - mux.Handle("GET /api/threads/{id}/events", w(threads.Events)) - mux.Handle("DELETE /api/threads/{id}", w(threads.Delete)) - mux.Handle("PUT /api/threads/{id}", w(threads.Update)) - mux.Handle("GET /api/agents/{agent}/threads", w(threads.List)) + mux.HandleFunc("GET /api/threads", threads.List) + mux.HandleFunc("GET /api/threads/{id}", threads.ByID) + mux.HandleFunc("GET /api/threads/{id}/events", threads.Events) + mux.HandleFunc("DELETE /api/threads/{id}", threads.Delete) + mux.HandleFunc("PUT /api/threads/{id}", threads.Update) + mux.HandleFunc("GET /api/agents/{agent}/threads", threads.List) // Thread files - mux.Handle("GET /api/threads/{id}/files", w(threads.Files)) - mux.Handle("POST /api/threads/{id}/files/{file}", w(threads.UploadFile)) - mux.Handle("DELETE /api/threads/{id}/files/{file}", w(threads.DeleteFile)) + mux.HandleFunc("GET /api/threads/{id}/files", threads.Files) + mux.HandleFunc("POST /api/threads/{id}/files/{file}", threads.UploadFile) + mux.HandleFunc("DELETE /api/threads/{id}/files/{file}", threads.DeleteFile) // Thread knowledge files - mux.Handle("GET /api/threads/{id}/knowledge", w(threads.Knowledge)) - mux.Handle("POST /api/threads/{id}/knowledge/{file}", w(threads.UploadKnowledge)) - mux.Handle("DELETE /api/threads/{id}/knowledge/{file...}", w(threads.DeleteKnowledge)) + mux.HandleFunc("GET /api/threads/{id}/knowledge", threads.Knowledge) + mux.HandleFunc("POST /api/threads/{id}/knowledge/{file}", threads.UploadKnowledge) + mux.HandleFunc("DELETE /api/threads/{id}/knowledge/{file...}", threads.DeleteKnowledge) // ToolRefs - mux.Handle("GET /api/toolreferences", w(toolRefs.List)) - mux.Handle("GET /api/toolreferences/{id}", w(toolRefs.ByID)) - mux.Handle("POST /api/toolreferences", w(toolRefs.Create)) - mux.Handle("DELETE /api/toolreferences/{id}", w(toolRefs.Delete)) - mux.Handle("PUT /api/toolreferences/{id}", w(toolRefs.Update)) + mux.HandleFunc("GET /api/toolreferences", toolRefs.List) + mux.HandleFunc("GET /api/toolreferences/{id}", toolRefs.ByID) + mux.HandleFunc("POST /api/toolreferences", toolRefs.Create) + mux.HandleFunc("DELETE /api/toolreferences/{id}", toolRefs.Delete) + mux.HandleFunc("PUT /api/toolreferences/{id}", toolRefs.Update) // Runs - mux.Handle("GET /api/runs", w(runs.List)) - mux.Handle("GET /api/runs/{id}", w(runs.ByID)) - mux.Handle("DELETE /api/runs/{id}", w(runs.Delete)) - mux.Handle("GET /api/runs/{id}/debug", w(runs.Debug)) - mux.Handle("GET /api/runs/{id}/events", w(runs.Events)) - mux.Handle("GET /api/threads/{thread}/runs", w(runs.List)) - mux.Handle("GET /api/agents/{agent}/runs", w(runs.List)) - mux.Handle("GET /api/agents/{agent}/threads/{thread}/runs", w(runs.List)) - mux.Handle("GET /api/workflows/{workflow}/runs", w(runs.List)) - mux.Handle("GET /api/workflows/{workflow}/threads/{thread}/runs", w(runs.List)) + mux.HandleFunc("GET /api/runs", runs.List) + mux.HandleFunc("GET /api/runs/{id}", runs.ByID) + mux.HandleFunc("DELETE /api/runs/{id}", runs.Delete) + mux.HandleFunc("GET /api/runs/{id}/debug", runs.Debug) + mux.HandleFunc("GET /api/runs/{id}/events", runs.Events) + mux.HandleFunc("GET /api/threads/{thread}/runs", runs.List) + mux.HandleFunc("GET /api/agents/{agent}/runs", runs.List) + mux.HandleFunc("GET /api/agents/{agent}/threads/{thread}/runs", runs.List) + mux.HandleFunc("GET /api/workflows/{workflow}/runs", runs.List) + mux.HandleFunc("GET /api/workflows/{workflow}/threads/{thread}/runs", runs.List) // Credentials - mux.Handle("GET /api/threads/{context}/credentials", w(handlers.ListCredentials)) - mux.Handle("GET /api/agents/{context}/credentials", w(handlers.ListCredentials)) - mux.Handle("GET /api/workflows/{context}/credentials", w(handlers.ListCredentials)) - mux.Handle("GET /api/credentials", w(handlers.ListCredentials)) - mux.Handle("DELETE /api/threads/{context}/credentials/{id}", w(handlers.DeleteCredential)) - mux.Handle("DELETE /api/agents/{context}/credentials/{id}", w(handlers.DeleteCredential)) - mux.Handle("DELETE /api/workflows/{context}/credentials/{id}", w(handlers.DeleteCredential)) - mux.Handle("DELETE /api/credentials/{id}", w(handlers.DeleteCredential)) + mux.HandleFunc("GET /api/threads/{context}/credentials", handlers.ListCredentials) + mux.HandleFunc("GET /api/agents/{context}/credentials", handlers.ListCredentials) + mux.HandleFunc("GET /api/workflows/{context}/credentials", handlers.ListCredentials) + mux.HandleFunc("GET /api/credentials", handlers.ListCredentials) + mux.HandleFunc("DELETE /api/threads/{context}/credentials/{id}", handlers.DeleteCredential) + mux.HandleFunc("DELETE /api/agents/{context}/credentials/{id}", handlers.DeleteCredential) + mux.HandleFunc("DELETE /api/workflows/{context}/credentials/{id}", handlers.DeleteCredential) + mux.HandleFunc("DELETE /api/credentials/{id}", handlers.DeleteCredential) // Webhooks - mux.Handle("POST /api/webhooks", w(webhooks.Create)) - mux.Handle("GET /api/webhooks", w(webhooks.List)) - mux.Handle("GET /api/webhooks/{id}", w(webhooks.ByID)) - mux.Handle("DELETE /api/webhooks/{id}", w(webhooks.Delete)) - mux.Handle("PUT /api/webhooks/{id}", w(webhooks.Update)) - mux.Handle("POST /api/webhooks/{id}", w(webhooks.Execute)) + mux.HandleFunc("POST /api/webhooks", webhooks.Create) + mux.HandleFunc("GET /api/webhooks", webhooks.List) + mux.HandleFunc("GET /api/webhooks/{id}", webhooks.ByID) + mux.HandleFunc("DELETE /api/webhooks/{id}", webhooks.Delete) + mux.HandleFunc("PUT /api/webhooks/{id}", webhooks.Update) + mux.HandleFunc("POST /api/webhooks/{id}", webhooks.Execute) // CronJobs - mux.Handle("POST /api/cronjobs", w(cronJobs.Create)) - mux.Handle("GET /api/cronjobs", w(cronJobs.List)) - mux.Handle("GET /api/cronjobs/{id}", w(cronJobs.ByID)) - mux.Handle("DELETE /api/cronjobs/{id}", w(cronJobs.Delete)) - mux.Handle("PUT /api/cronjobs/{id}", w(cronJobs.Update)) - mux.Handle("POST /api/cronjobs/{id}", w(cronJobs.Execute)) + mux.HandleFunc("POST /api/cronjobs", cronJobs.Create) + mux.HandleFunc("GET /api/cronjobs", cronJobs.List) + mux.HandleFunc("GET /api/cronjobs/{id}", cronJobs.ByID) + mux.HandleFunc("DELETE /api/cronjobs/{id}", cronJobs.Delete) + mux.HandleFunc("PUT /api/cronjobs/{id}", cronJobs.Update) + mux.HandleFunc("POST /api/cronjobs/{id}", cronJobs.Execute) // Gateway APIs - services.GatewayServer.AddRoutes(w, mux) + services.GatewayServer.AddRoutes(services.APIServer) // UI - mux.Handle("/", services.ProxyServer.Wrap(ui.Handler(services.DevUIPort))) + services.APIServer.HTTPHandle("/", services.ProxyServer.Wrap(ui.Handler(services.DevUIPort))) - return mux, nil + return services.APIServer, nil } diff --git a/pkg/api/server.go b/pkg/api/server/server.go similarity index 53% rename from pkg/api/server.go rename to pkg/api/server/server.go index c919009df..952821701 100644 --- a/pkg/api/server.go +++ b/pkg/api/server/server.go @@ -1,4 +1,4 @@ -package api +package server import ( "errors" @@ -6,56 +6,59 @@ import ( "github.com/gptscript-ai/go-gptscript" "github.com/otto8-ai/otto8/apiclient/types" - "github.com/otto8-ai/otto8/pkg/gateway/client" - "github.com/otto8-ai/otto8/pkg/jwt" + "github.com/otto8-ai/otto8/pkg/api" + "github.com/otto8-ai/otto8/pkg/api/authn" + "github.com/otto8-ai/otto8/pkg/api/authz" "github.com/otto8-ai/otto8/pkg/storage" apierrors "k8s.io/apimachinery/pkg/api/errors" - "k8s.io/apiserver/pkg/authentication/authenticator" - "k8s.io/apiserver/pkg/authentication/user" ) type Server struct { storageClient storage.Client gptClient *gptscript.GPTScript - gatewayClient *client.Client - tokenService *jwt.TokenService - authenticator authenticator.Request + authenticator *authn.Authenticator + authorizer *authz.Authorizer + + mux *http.ServeMux } -func NewServer(storageClient storage.Client, gptClient *gptscript.GPTScript, gatewayClient *client.Client, tokenService *jwt.TokenService, authn authenticator.Request) *Server { +func NewServer(storageClient storage.Client, gptClient *gptscript.GPTScript, authn *authn.Authenticator, authz *authz.Authorizer) *Server { return &Server{ storageClient: storageClient, gptClient: gptClient, - gatewayClient: gatewayClient, - tokenService: tokenService, authenticator: authn, + authorizer: authz, + + mux: http.NewServeMux(), } } -type ( - HandlerFunc func(Context) error - Middleware func(HandlerFunc) HandlerFunc -) +func (s *Server) HandleFunc(pattern string, f api.HandlerFunc) { + s.mux.HandleFunc(pattern, s.wrap(f)) +} -func (s *Server) getUser(req *http.Request) (user.Info, error) { - resp, ok, err := s.authenticator.AuthenticateRequest(req) - if err != nil { - return nil, err - } - if !ok { - panic("authentication should always succeed") - } - return resp.User, nil +func (s *Server) HTTPHandle(pattern string, f http.Handler) { + s.mux.Handle(pattern, f) } -func (s *Server) Wrap(f HandlerFunc) http.Handler { - return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - user, err := s.getUser(req) +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + s.mux.ServeHTTP(w, r) +} + +func (s *Server) wrap(f api.HandlerFunc) http.HandlerFunc { + return func(rw http.ResponseWriter, req *http.Request) { + user, err := s.authenticator.Authenticate(req) if err != nil { http.Error(rw, err.Error(), http.StatusUnauthorized) return } - err = f(Context{ + + if !s.authorizer.Authorize(req, user) { + http.Error(rw, "forbidden", http.StatusForbidden) + return + } + + err = f(api.Context{ ResponseWriter: rw, Request: req, GPTClient: s.gptClient, @@ -70,10 +73,10 @@ func (s *Server) Wrap(f HandlerFunc) http.Handler { } else if err != nil { http.Error(rw, err.Error(), http.StatusInternalServerError) } - }) + } } -func GetURLPrefix(req Context) string { +func GetURLPrefix(req api.Context) string { if req.Request.TLS == nil { return "http://" + req.Request.Host } diff --git a/pkg/gateway/client/auth.go b/pkg/gateway/client/auth.go index 3b8f1436a..1755aa0a5 100644 --- a/pkg/gateway/client/auth.go +++ b/pkg/gateway/client/auth.go @@ -6,6 +6,7 @@ import ( "slices" types2 "github.com/otto8-ai/otto8/apiclient/types" + "github.com/otto8-ai/otto8/pkg/api/authz" "github.com/otto8-ai/otto8/pkg/gateway/types" "k8s.io/apiserver/pkg/authentication/authenticator" "k8s.io/apiserver/pkg/authentication/user" @@ -41,15 +42,15 @@ func (u UserDecorator) AuthenticateRequest(req *http.Request) (*authenticator.Re } groups := resp.User.GetGroups() - if gatewayUser.Role == types2.RoleAdmin && !slices.Contains(groups, "admin") { - groups = append(groups, "admin") + if gatewayUser.Role == types2.RoleAdmin && !slices.Contains(groups, authz.AdminGroup) { + groups = append(groups, authz.AdminGroup) } resp.User = &user.DefaultInfo{ Name: gatewayUser.Username, UID: fmt.Sprintf("%d", gatewayUser.ID), Extra: resp.User.GetExtra(), - Groups: append(groups, "system:authenticated"), + Groups: append(groups, authz.AuthenticatedGroup), } return resp, true, nil } diff --git a/pkg/gateway/server/authprovider.go b/pkg/gateway/server/authprovider.go index f88ba6f7e..85966a1c7 100644 --- a/pkg/gateway/server/authprovider.go +++ b/pkg/gateway/server/authprovider.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" + types2 "github.com/otto8-ai/otto8/apiclient/types" "github.com/otto8-ai/otto8/pkg/api" kcontext "github.com/otto8-ai/otto8/pkg/gateway/context" ktime "github.com/otto8-ai/otto8/pkg/gateway/time" @@ -86,13 +87,10 @@ func (s *Server) updateAuthProvider(apiContext api.Context) error { return nil } -func (s *Server) getAuthProviders(w http.ResponseWriter, r *http.Request) { - logger := kcontext.GetLogger(r.Context()) +func (s *Server) getAuthProviders(apiContext api.Context) error { var authProviders []types.AuthProvider - if err := s.db.WithContext(r.Context()).Find(&authProviders).Error; err != nil { - logger.DebugContext(r.Context(), "failed to query auth providers", "error", err) - writeError(r.Context(), logger, w, http.StatusInternalServerError, err) - return + if err := s.db.WithContext(apiContext.Context()).Find(&authProviders).Error; err != nil { + return types2.NewErrHttp(http.StatusInternalServerError, err.Error()) } resp := make([]authProviderResponse, len(authProviders)) @@ -104,31 +102,30 @@ func (s *Server) getAuthProviders(w http.ResponseWriter, r *http.Request) { } } - writeResponse(r.Context(), logger, w, resp) + return apiContext.Write(resp) } -func (s *Server) getAuthProvider(w http.ResponseWriter, r *http.Request) { - logger := kcontext.GetLogger(r.Context()) - slug := r.PathValue("slug") +func (s *Server) getAuthProvider(apiContext api.Context) error { + slug := apiContext.PathValue("slug") if slug == "" { - writeError(r.Context(), logger, w, http.StatusBadRequest, errors.New("id path parameter is required")) - return + return types2.NewErrHttp(http.StatusBadRequest, "id path parameter is required") } oauthProvider := new(types.AuthProvider) - if err := s.db.WithContext(r.Context()).Where("slug = ?", slug).Find(oauthProvider).Error; err != nil { + if err := s.db.WithContext(apiContext.Context()).Where("slug = ?", slug).Find(oauthProvider).Error; err != nil { status := http.StatusInternalServerError if errors.Is(err, gorm.ErrRecordNotFound) { status = http.StatusNotFound } - logger.DebugContext(r.Context(), "failed to query auth providers", "error", err) - writeError(r.Context(), logger, w, status, fmt.Errorf("failed to query auth provider: %v", err)) - return + return types2.NewErrHttp(status, fmt.Sprintf("failed to query auth provider: %v", err)) } oauthProvider.ClientSecret = "" - writeResponse(r.Context(), logger, w, authProviderResponse{AuthProvider: *oauthProvider, RedirectURL: oauthProvider.RedirectURL(s.baseURL)}) + return apiContext.Write(authProviderResponse{ + AuthProvider: *oauthProvider, + RedirectURL: oauthProvider.RedirectURL(s.baseURL), + }) } func (s *Server) deleteAuthProvider(apiContext api.Context) error { diff --git a/pkg/gateway/server/middleware.go b/pkg/gateway/server/middleware.go index 0ddaef028..a361f0808 100644 --- a/pkg/gateway/server/middleware.go +++ b/pkg/gateway/server/middleware.go @@ -6,33 +6,12 @@ import ( "runtime/debug" "time" - types2 "github.com/otto8-ai/otto8/apiclient/types" "github.com/otto8-ai/otto8/pkg/api" "github.com/otto8-ai/otto8/pkg/gateway/context" "github.com/otto8-ai/otto8/pkg/gateway/log" "github.com/otto8-ai/otto8/pkg/gateway/types" ) -func (s *Server) auth(mustBeAdmin bool) api.Middleware { - return func(next api.HandlerFunc) api.HandlerFunc { - return func(apiContext api.Context) error { - if !apiContext.UserIsAuthenticated() { - return types2.NewErrHttp(http.StatusUnauthorized, "unauthenticated") - } - if mustBeAdmin && !apiContext.UserIsAdmin() { - return types2.NewErrHttp(http.StatusForbidden, "must be admin") - } - return next(apiContext) - } - } -} - -func (s *Server) authFunc(role types2.Role) api.Middleware { - return func(next api.HandlerFunc) api.HandlerFunc { - return s.auth(role.HasRole(types2.RoleAdmin))(next) - } -} - func (s *Server) monitor(next api.HandlerFunc) api.HandlerFunc { return func(apiContext api.Context) error { logger := context.GetLogger(apiContext.Context()) diff --git a/pkg/gateway/server/router.go b/pkg/gateway/server/router.go new file mode 100644 index 000000000..707272cd0 --- /dev/null +++ b/pkg/gateway/server/router.go @@ -0,0 +1,91 @@ +package server + +import ( + _ "embed" + "net/http" + "net/http/httputil" + + "github.com/otto8-ai/otto8/pkg/api" + "github.com/otto8-ai/otto8/pkg/api/server" + kcontext "github.com/otto8-ai/otto8/pkg/gateway/context" + "github.com/otto8-ai/otto8/pkg/gateway/types" +) + +func (s *Server) AddRoutes(mux *server.Server) { + wrap := func(h api.HandlerFunc) api.HandlerFunc { + return apply(h, addRequestID, addLogger, logRequest, contentType("application/json")) + } + // All the routes served by the API will start with `/api` + mux.HandleFunc("GET /api/me", wrap(s.getCurrentUser)) + mux.HandleFunc("GET /api/users", wrap(s.getUsers)) + mux.HandleFunc("GET /api/users/{username}", wrap(s.getUser)) + mux.HandleFunc("PATCH /api/users/{username}", wrap(s.updateUser)) + mux.HandleFunc("DELETE /api/users/{username}", wrap(s.deleteUser)) + + mux.HandleFunc("POST /api/token-request", s.tokenRequest) + mux.HandleFunc("GET /api/token-request/{id}", s.checkForToken) + mux.HandleFunc("GET /api/token-request/{id}/{service}", s.redirectForTokenRequest) + + mux.HandleFunc("GET /api/tokens", wrap(s.getTokens)) + mux.HandleFunc("DELETE /api/tokens/{id}", wrap(s.deleteToken)) + mux.HandleFunc("POST /api/tokens", wrap(s.newToken)) + + mux.HTTPHandle("GET /api/supported-auth-types", http.HandlerFunc(func(writer http.ResponseWriter, r *http.Request) { + writeResponse(r.Context(), kcontext.GetLogger(r.Context()), writer, types.SupportedAuthTypeConfigs()) + })) + mux.HTTPHandle("GET /api/supported-oauth-app-types", http.HandlerFunc(func(writer http.ResponseWriter, r *http.Request) { + writeResponse(r.Context(), kcontext.GetLogger(r.Context()), writer, types.SupportedOAuthAppTypeConfigs()) + })) + + mux.HandleFunc("POST /api/auth-providers", wrap(s.createAuthProvider)) + mux.HandleFunc("PATCH /api/auth-providers/{slug}", wrap(s.updateAuthProvider)) + mux.HandleFunc("DELETE /api/auth-providers/{slug}", wrap(s.deleteAuthProvider)) + mux.HandleFunc("GET /api/auth-providers", s.getAuthProviders) + mux.HandleFunc("GET /api/auth-providers/{slug}", s.getAuthProvider) + mux.HandleFunc("POST /api/auth-providers/{slug}/disable", wrap(s.disableAuthProvider)) + mux.HandleFunc("POST /api/auth-providers/{slug}/enable", wrap(s.enableAuthProvider)) + + mux.HandleFunc("POST /api/llm-providers", wrap(s.createLLMProvider)) + mux.HandleFunc("PATCH /api/llm-providers/{slug}", wrap(s.updateLLMProvider)) + mux.HandleFunc("DELETE /api/llm-providers/{slug}", wrap(s.deleteLLMProvider)) + mux.HandleFunc("GET /api/llm-providers", wrap(s.getLLMProviders)) + mux.HandleFunc("GET /api/llm-providers/{slug}", wrap(s.getLLMProvider)) + mux.HandleFunc("POST /api/llm-providers/{slug}/disable", wrap(s.disableLLMProvider)) + mux.HandleFunc("POST /api/llm-providers/{slug}/enable", wrap(s.enableLLMProvider)) + + mux.HandleFunc("POST /api/models", wrap(s.createModel)) + mux.HandleFunc("PATCH /api/models/{id}", wrap(s.updateModel)) + mux.HandleFunc("DELETE /api/models/{id}", wrap(s.deleteModel)) + mux.HandleFunc("GET /api/models", wrap(s.getModels)) + mux.HandleFunc("GET /api/models/{id}", wrap(s.getModel)) + mux.HandleFunc("POST /api/models/{id}/disable", wrap(s.disableModel)) + mux.HandleFunc("POST /api/models/{id}/enable", wrap(s.enableModel)) + + mux.HandleFunc("GET /api/oauth/start/{id}/{service}", wrap(s.oauth)) + mux.HandleFunc("/api/oauth/redirect/{service}", wrap(s.redirect)) + + // CRUD routes for OAuth Apps (integrations with other service such as Microsoft 365) + mux.HandleFunc("GET /api/oauth-apps", wrap(s.listOAuthApps)) + mux.HandleFunc("GET /api/oauth-apps/{id}", wrap(s.oauthAppByID)) + mux.HandleFunc("POST /api/oauth-apps", wrap(s.createOAuthApp)) + mux.HandleFunc("PATCH /api/oauth-apps/{id}", wrap(s.updateOAuthApp)) + mux.HandleFunc("DELETE /api/oauth-apps/{id}", wrap(s.deleteOAuthApp)) + + // Routes for OAuth authorization code flow + mux.HandleFunc("GET /api/app-oauth/authorize/{id}", wrap(s.authorizeOAuthApp)) + mux.HandleFunc("GET /api/app-oauth/refresh/{id}", wrap(s.refreshOAuthApp)) + mux.HandleFunc("GET /api/app-oauth/callback/{id}", wrap(s.callbackOAuthApp)) + + // Route for credential tools to get their OAuth tokens + mux.HandleFunc("GET /api/app-oauth/get-token", wrap(s.getTokenOAuthApp)) + + // Handle the proxy to the LLM provider. + mux.HandleFunc("/llm/{provider}/{path...}", apply(httpToApiHandlerFunc(&httputil.ReverseProxy{ + Rewrite: s.proxyToProvider, + ErrorHandler: s.proxyError, + }), addRequestID, addLogger, logRequest, s.monitor)) + mux.HandleFunc("/llm/{provider}", apply(httpToApiHandlerFunc(&httputil.ReverseProxy{ + Rewrite: s.proxyToProvider, + ErrorHandler: s.proxyError, + }), addRequestID, addLogger, logRequest, s.monitor)) +} diff --git a/pkg/gateway/server/routes.go b/pkg/gateway/server/routes.go deleted file mode 100644 index 0dd0c27a2..000000000 --- a/pkg/gateway/server/routes.go +++ /dev/null @@ -1,92 +0,0 @@ -package server - -import ( - _ "embed" - "net/http" - "net/http/httputil" - - types2 "github.com/otto8-ai/otto8/apiclient/types" - "github.com/otto8-ai/otto8/pkg/api" - kcontext "github.com/otto8-ai/otto8/pkg/gateway/context" - "github.com/otto8-ai/otto8/pkg/gateway/types" -) - -func (s *Server) AddRoutes(w func(api.HandlerFunc) http.Handler, mux *http.ServeMux) { - wrap := func(h api.HandlerFunc) http.Handler { - return w(apply(h, addRequestID, addLogger, logRequest, contentType("application/json"))) - } - // All the routes served by the API will start with `/api` - mux.Handle("GET /api/me", wrap(s.authFunc(types2.RoleBasic)(s.getCurrentUser))) - mux.Handle("GET /api/users", wrap(s.authFunc(types2.RoleAdmin)(s.getUsers))) - mux.Handle("GET /api/users/{username}", wrap(s.authFunc(types2.RoleAdmin)(s.getUser))) - // Any user can update their own username, admins can update any user - mux.Handle("PATCH /api/users/{username}", wrap(s.authFunc(types2.RoleBasic)(s.updateUser))) - mux.Handle("DELETE /api/users/{username}", wrap(s.authFunc(types2.RoleAdmin)(s.deleteUser))) - - mux.HandleFunc("POST /api/token-request", s.tokenRequest) - mux.HandleFunc("GET /api/token-request/{id}", s.checkForToken) - mux.HandleFunc("GET /api/token-request/{id}/{service}", s.redirectForTokenRequest) - - mux.Handle("GET /api/tokens", wrap(s.authFunc(types2.RoleBasic)(s.getTokens))) - mux.Handle("DELETE /api/tokens/{id}", wrap(s.authFunc(types2.RoleBasic)(s.deleteToken))) - mux.Handle("POST /api/tokens", wrap(s.authFunc(types2.RoleBasic)(s.newToken))) - - mux.HandleFunc("GET /api/supported-auth-types", func(writer http.ResponseWriter, r *http.Request) { - writeResponse(r.Context(), kcontext.GetLogger(r.Context()), writer, types.SupportedAuthTypeConfigs()) - }) - mux.HandleFunc("GET /api/supported-oauth-app-types", func(writer http.ResponseWriter, r *http.Request) { - writeResponse(r.Context(), kcontext.GetLogger(r.Context()), writer, types.SupportedOAuthAppTypeConfigs()) - }) - - mux.Handle("POST /api/auth-providers", wrap(s.authFunc(types2.RoleAdmin)(s.createAuthProvider))) - mux.Handle("PATCH /api/auth-providers/{slug}", wrap(s.authFunc(types2.RoleAdmin)(s.updateAuthProvider))) - mux.Handle("DELETE /api/auth-providers/{slug}", wrap(s.authFunc(types2.RoleAdmin)(s.deleteAuthProvider))) - mux.HandleFunc("GET /api/auth-providers", s.getAuthProviders) - mux.HandleFunc("GET /api/auth-providers/{slug}", s.getAuthProvider) - mux.Handle("POST /api/auth-providers/{slug}/disable", wrap(s.authFunc(types2.RoleAdmin)(s.disableAuthProvider))) - mux.Handle("POST /api/auth-providers/{slug}/enable", wrap(s.authFunc(types2.RoleAdmin)(s.enableAuthProvider))) - - mux.Handle("POST /api/llm-providers", wrap(s.authFunc(types2.RoleAdmin)(s.createLLMProvider))) - mux.Handle("PATCH /api/llm-providers/{slug}", wrap(s.authFunc(types2.RoleAdmin)(s.updateLLMProvider))) - mux.Handle("DELETE /api/llm-providers/{slug}", wrap(s.authFunc(types2.RoleAdmin)(s.deleteLLMProvider))) - mux.Handle("GET /api/llm-providers", wrap(s.authFunc(types2.RoleBasic)(s.getLLMProviders))) - mux.Handle("GET /api/llm-providers/{slug}", wrap(s.authFunc(types2.RoleBasic)(s.getLLMProvider))) - mux.Handle("POST /api/llm-providers/{slug}/disable", wrap(s.authFunc(types2.RoleAdmin)(s.disableLLMProvider))) - mux.Handle("POST /api/llm-providers/{slug}/enable", wrap(s.authFunc(types2.RoleAdmin)(s.enableLLMProvider))) - - mux.Handle("POST /api/models", wrap(s.authFunc(types2.RoleAdmin)(s.createModel))) - mux.Handle("PATCH /api/models/{id}", wrap(s.authFunc(types2.RoleAdmin)(s.updateModel))) - mux.Handle("DELETE /api/models/{id}", wrap(s.authFunc(types2.RoleAdmin)(s.deleteModel))) - mux.Handle("GET /api/models", wrap(s.authFunc(types2.RoleBasic)(s.getModels))) - mux.Handle("GET /api/models/{id}", wrap(s.authFunc(types2.RoleBasic)(s.getModel))) - mux.Handle("POST /api/models/{id}/disable", wrap(s.authFunc(types2.RoleAdmin)(s.disableModel))) - mux.Handle("POST /api/models/{id}/enable", wrap(s.authFunc(types2.RoleAdmin)(s.enableModel))) - - mux.Handle("GET /api/oauth/start/{id}/{service}", wrap(s.oauth)) - mux.Handle("/api/oauth/redirect/{service}", wrap(s.redirect)) - - // CRUD routes for OAuth Apps (integrations with other service such as Microsoft 365) - mux.Handle("GET /api/oauth-apps", wrap(s.authFunc(types2.RoleBasic)(s.listOAuthApps))) - mux.Handle("GET /api/oauth-apps/{id}", wrap(s.authFunc(types2.RoleBasic)(s.oauthAppByID))) - mux.Handle("POST /api/oauth-apps", wrap(s.authFunc(types2.RoleAdmin)(s.createOAuthApp))) - mux.Handle("PATCH /api/oauth-apps/{id}", wrap(s.authFunc(types2.RoleAdmin)(s.updateOAuthApp))) - mux.Handle("DELETE /api/oauth-apps/{id}", wrap(s.authFunc(types2.RoleAdmin)(s.deleteOAuthApp))) - - // Routes for OAuth authorization code flow - mux.Handle("GET /api/app-oauth/authorize/{id}", wrap(s.authorizeOAuthApp)) - mux.Handle("GET /api/app-oauth/refresh/{id}", wrap(s.refreshOAuthApp)) - mux.Handle("GET /api/app-oauth/callback/{id}", wrap(s.callbackOAuthApp)) - - // Route for credential tools to get their OAuth tokens - mux.Handle("GET /api/app-oauth/get-token", wrap(s.getTokenOAuthApp)) - - // Handle the proxy to the LLM provider. - mux.Handle("/api/llm/{provider}/{path...}", w(s.auth(false)(apply(httpToApiHandlerFunc(&httputil.ReverseProxy{ - Rewrite: s.proxyToProvider, - ErrorHandler: s.proxyError, - }), addRequestID, addLogger, logRequest, s.monitor)))) - mux.Handle("/api/llm/{provider}", w(s.auth(false)(apply(httpToApiHandlerFunc(&httputil.ReverseProxy{ - Rewrite: s.proxyToProvider, - ErrorHandler: s.proxyError, - }), addRequestID, addLogger, logRequest, s.monitor)))) -} diff --git a/pkg/gateway/server/token.go b/pkg/gateway/server/token.go index 21287527a..e196b8360 100644 --- a/pkg/gateway/server/token.go +++ b/pkg/gateway/server/token.go @@ -38,24 +38,18 @@ type refreshTokenResponse struct { } func (s *Server) getTokens(apiContext api.Context) error { - logger := kcontext.GetLogger(apiContext.Context()) - var tokens []types.AuthToken if err := s.db.WithContext(apiContext.Context()).Where("user_id = ?", apiContext.UserID()).Find(&tokens).Error; err != nil { - writeError(apiContext.Context(), logger, apiContext.ResponseWriter, http.StatusInternalServerError, fmt.Errorf("error getting tokens: %v", err)) - return nil + return types2.NewErrHttp(http.StatusInternalServerError, fmt.Sprintf("error getting tokens: %v", err)) } - writeResponse(apiContext.Context(), logger, apiContext.ResponseWriter, tokens) - return nil + return apiContext.Write(tokens) } func (s *Server) deleteToken(apiContext api.Context) error { - logger := kcontext.GetLogger(apiContext.Context()) id := apiContext.PathValue("id") if id == "" { - writeError(apiContext.Context(), logger, apiContext.ResponseWriter, http.StatusBadRequest, fmt.Errorf("id path parameter is required")) - return nil + return types2.NewErrHttp(http.StatusBadRequest, "id path parameter is required") } if err := s.db.WithContext(apiContext.Context()).Where("user_id = ? AND id = ?", apiContext.UserID(), id).Delete(new(types.AuthToken)).Error; err != nil { @@ -64,12 +58,10 @@ func (s *Server) deleteToken(apiContext api.Context) error { status = http.StatusNotFound err = fmt.Errorf("not found") } - writeError(apiContext.Context(), logger, apiContext.ResponseWriter, status, fmt.Errorf("error deleting token: %v", err)) - return nil + return types2.NewErrHttp(status, fmt.Sprintf("error deleting token: %v", err)) } - writeResponse(apiContext.Context(), logger, apiContext.ResponseWriter, map[string]any{"deleted": true}) - return nil + return apiContext.Write(map[string]any{"deleted": true}) } type createTokenRequest struct { @@ -77,12 +69,10 @@ type createTokenRequest struct { } func (s *Server) newToken(apiContext api.Context) error { - logger := kcontext.GetLogger(apiContext.Context()) authProviderID := apiContext.AuthProviderID() userID := apiContext.UserID() if authProviderID <= 0 || userID <= 0 { - writeError(apiContext.Context(), logger, apiContext.ResponseWriter, http.StatusForbidden, fmt.Errorf("forbidden")) - return nil + return types2.NewErrHttp(http.StatusForbidden, "forbidden") } var customExpiration time.Duration @@ -90,23 +80,20 @@ func (s *Server) newToken(apiContext api.Context) error { request := new(createTokenRequest) err := apiContext.Read(request) if err != nil { - writeError(apiContext.Context(), logger, apiContext.ResponseWriter, http.StatusBadRequest, fmt.Errorf("invalid create create token request body: %v", err)) - return nil + return types2.NewErrHttp(http.StatusBadRequest, fmt.Sprintf("invalid create create token request body: %v", err)) } if request.ExpiresIn != "" { customExpiration, err = ktime.ParseDuration(request.ExpiresIn) if err != nil { - writeError(apiContext.Context(), logger, apiContext.ResponseWriter, http.StatusBadRequest, fmt.Errorf("invalid expiresIn duration: %v", err)) - return nil + return types2.NewErrHttp(http.StatusBadRequest, fmt.Sprintf("invalid expiresIn duration: %v", err)) } } } randBytes := make([]byte, randomTokenLength+tokenIDLength) if _, err := rand.Read(randBytes); err != nil { - writeError(apiContext.Context(), logger, apiContext.ResponseWriter, http.StatusInternalServerError, fmt.Errorf("error refreshing token: %v", err)) - return nil + return types2.NewErrHttp(http.StatusInternalServerError, fmt.Sprintf("error refreshing token: %v", err)) } id := randBytes[:tokenIDLength] @@ -131,24 +118,19 @@ func (s *Server) newToken(apiContext api.Context) error { tkn.AuthProviderID = provider.ID return tx.Create(tkn).Error }); err != nil { - writeError(apiContext.Context(), logger, apiContext.ResponseWriter, http.StatusInternalServerError, fmt.Errorf("error refreshing token: %v", err)) - return nil + return types2.NewErrHttp(http.StatusInternalServerError, fmt.Sprintf("error refreshing token: %v", err)) } - writeResponse(apiContext.Context(), logger, apiContext.ResponseWriter, refreshTokenResponse{ + return apiContext.Write(refreshTokenResponse{ Token: publicToken(id, token), ExpiresAt: tkn.ExpiresAt, }) - - return nil } -func (s *Server) tokenRequest(w http.ResponseWriter, r *http.Request) { - logger := kcontext.GetLogger(r.Context()) +func (s *Server) tokenRequest(apiContext api.Context) error { reqObj := new(tokenRequestRequest) - if err := json.NewDecoder(r.Body).Decode(reqObj); err != nil { - writeError(r.Context(), logger, w, http.StatusBadRequest, fmt.Errorf("invalid token request body: %v", err)) - return + if err := json.NewDecoder(apiContext.Request.Body).Decode(reqObj); err != nil { + return types2.NewErrHttp(http.StatusBadRequest, fmt.Sprintf("invalid token request body: %v", err)) } tokenReq := &types.TokenRequest{ @@ -157,7 +139,7 @@ func (s *Server) tokenRequest(w http.ResponseWriter, r *http.Request) { } oauthProvider := new(types.AuthProvider) - if err := s.db.WithContext(r.Context()).Transaction(func(tx *gorm.DB) error { + if err := s.db.WithContext(apiContext.Context()).Transaction(func(tx *gorm.DB) error { if reqObj.ServiceName != "" { // Ensure the OAuth provider exists, if one was provided. if err := tx.Where("service_name = ?", reqObj.ServiceName).Where("disabled IS NULL OR disabled != ?", true).First(oauthProvider).Error; err != nil { @@ -167,31 +149,25 @@ func (s *Server) tokenRequest(w http.ResponseWriter, r *http.Request) { return tx.Create(tokenReq).Error }); err != nil { - logger.DebugContext(r.Context(), "failed to create token", "error", err) if errors.Is(err, gorm.ErrDuplicatedKey) { - writeError(r.Context(), logger, w, http.StatusConflict, fmt.Errorf("token request already exists")) - } else { - writeError(r.Context(), logger, w, http.StatusInternalServerError, err) + return types2.NewErrHttp(http.StatusConflict, "token request already exists") } - return + return types2.NewErrHttp(http.StatusInternalServerError, err.Error()) } if reqObj.ServiceName != "" { - writeResponse(r.Context(), logger, w, map[string]any{"token-path": fmt.Sprintf("%s/oauth/start/%s/%s", s.baseURL, reqObj.ID, oauthProvider.Slug)}) - return + return apiContext.Write(map[string]any{"token-path": fmt.Sprintf("%s/oauth/start/%s/%s", s.baseURL, reqObj.ID, oauthProvider.Slug)}) } - - writeResponse(r.Context(), logger, w, map[string]any{"token-path": fmt.Sprintf("%s/login?id=%s", s.uiURL, reqObj.ID)}) + return apiContext.Write(map[string]any{"token-path": fmt.Sprintf("%s/login?id=%s", s.uiURL, reqObj.ID)}) } -func (s *Server) redirectForTokenRequest(w http.ResponseWriter, r *http.Request) { - logger := kcontext.GetLogger(r.Context()) - id := r.PathValue("id") - service := r.PathValue("service") +func (s *Server) redirectForTokenRequest(apiContext api.Context) error { + id := apiContext.PathValue("id") + service := apiContext.PathValue("service") oauthProvider := new(types.AuthProvider) tokenReq := new(types.TokenRequest) - if err := s.db.WithContext(r.Context()).Transaction(func(tx *gorm.DB) error { + if err := s.db.WithContext(apiContext.Context()).Transaction(func(tx *gorm.DB) error { // Ensure the OAuth provider exists, if one was provided. if err := tx.Where("slug = ?", service).Where("disabled IS NULL OR disabled != ?", true).First(oauthProvider).Error; err != nil { return fmt.Errorf("failed to find oauth provider %q: %v", service, err) @@ -199,23 +175,19 @@ func (s *Server) redirectForTokenRequest(w http.ResponseWriter, r *http.Request) return tx.Where("id = ?", id).First(tokenReq).Error }); err != nil { - logger.DebugContext(r.Context(), "failed to create token", "error", err) if errors.Is(err, gorm.ErrRecordNotFound) { - writeError(r.Context(), logger, w, http.StatusNotFound, fmt.Errorf("token or service not found")) - } else { - writeError(r.Context(), logger, w, http.StatusInternalServerError, err) + return types2.NewErrNotFound("token or service not found") } - return + return types2.NewErrHttp(http.StatusInternalServerError, err.Error()) } - writeResponse(r.Context(), logger, w, map[string]any{"token-path": fmt.Sprintf("%s/oauth/start/%s/%s", s.baseURL, tokenReq.ID, oauthProvider.Slug)}) + return apiContext.Write(map[string]any{"token-path": fmt.Sprintf("%s/oauth/start/%s/%s", s.baseURL, tokenReq.ID, oauthProvider.Slug)}) } -func (s *Server) checkForToken(w http.ResponseWriter, r *http.Request) { - logger := kcontext.GetLogger(r.Context()) +func (s *Server) checkForToken(apiContext api.Context) error { tr := new(types.TokenRequest) - if err := s.db.WithContext(r.Context()).Transaction(func(tx *gorm.DB) error { - if err := tx.Where("id = ?", r.PathValue("id")).First(tr).Error; err != nil { + if err := s.db.WithContext(apiContext.Context()).Transaction(func(tx *gorm.DB) error { + if err := tx.Where("id = ?", apiContext.PathValue("id")).First(tr).Error; err != nil { return err } @@ -224,16 +196,17 @@ func (s *Server) checkForToken(w http.ResponseWriter, r *http.Request) { } return nil }); err != nil || tr.ID == "" { - logger.DebugContext(r.Context(), "failed to check token retrieved", "error", err) - writeError(r.Context(), logger, w, http.StatusNotFound, fmt.Errorf("not found")) - return + return types2.NewErrNotFound("not found") } if tr.Error != "" { - writeResponse(r.Context(), logger, w, map[string]any{"error": tr.Error}) + return apiContext.Write(map[string]any{"error": tr.Error}) } - writeResponse(r.Context(), logger, w, refreshTokenResponse{Token: tr.Token, ExpiresAt: tr.ExpiresAt}) + return apiContext.Write(refreshTokenResponse{ + Token: tr.Token, + ExpiresAt: tr.ExpiresAt, + }) } func (s *Server) createState(ctx context.Context, id string) (string, string, error) { diff --git a/pkg/services/config.go b/pkg/services/config.go index 77f7d20e7..5604adf20 100644 --- a/pkg/services/config.go +++ b/pkg/services/config.go @@ -3,7 +3,6 @@ package services import ( "context" "fmt" - "net/http" "os" "path/filepath" @@ -14,11 +13,13 @@ import ( "github.com/gptscript-ai/go-gptscript" "github.com/gptscript-ai/gptscript/pkg/sdkserver" "github.com/otto8-ai/otto8/pkg/aihelper" - "github.com/otto8-ai/otto8/pkg/api" + "github.com/otto8-ai/otto8/pkg/api/authn" + "github.com/otto8-ai/otto8/pkg/api/authz" + "github.com/otto8-ai/otto8/pkg/api/server" "github.com/otto8-ai/otto8/pkg/events" "github.com/otto8-ai/otto8/pkg/gateway/client" "github.com/otto8-ai/otto8/pkg/gateway/db" - "github.com/otto8-ai/otto8/pkg/gateway/server" + gserver "github.com/otto8-ai/otto8/pkg/gateway/server" "github.com/otto8-ai/otto8/pkg/invoke" "github.com/otto8-ai/otto8/pkg/jwt" "github.com/otto8-ai/otto8/pkg/proxy" @@ -31,15 +32,13 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apiserver/pkg/authentication/authenticator" "k8s.io/apiserver/pkg/authentication/request/union" - "k8s.io/apiserver/pkg/authentication/user" - // Setup baaah logging _ "github.com/acorn-io/baaah/pkg/logrus" ) type ( AuthConfig proxy.Config - GatewayConfig server.Options + GatewayConfig gserver.Options ) type Config struct { @@ -63,12 +62,12 @@ type Services struct { GPTClient *gptscript.GPTScript Invoker *invoke.Invoker TokenServer *jwt.TokenService - APIServer *api.Server + APIServer *server.Server WorkspaceClient *wclient.Client AIHelper *aihelper.AIHelper Started chan struct{} ProxyServer *proxy.Proxy - GatewayServer *server.Server + GatewayServer *gserver.Server } func newGPTScript(ctx context.Context) (*gptscript.GPTScript, error) { @@ -92,31 +91,6 @@ func newGPTScript(ctx context.Context) (*gptscript.GPTScript, error) { }) } -type noAuth struct { -} - -func (n noAuth) AuthenticateRequest(*http.Request) (*authenticator.Response, bool, error) { - return &authenticator.Response{ - User: &user.DefaultInfo{ - Name: "nobody", - Groups: []string{"admin", "system:authenticated"}, - }, - }, true, nil -} - -type anonymous struct { -} - -func (n anonymous) AuthenticateRequest(*http.Request) (*authenticator.Response, bool, error) { - return &authenticator.Response{ - User: &user.DefaultInfo{ - UID: "anonymous", - Name: "anonymous", - Groups: []string{"system:unauthenticated"}, - }, - }, true, nil -} - func New(ctx context.Context, config Config) (*Services, error) { system.SetBinToSelf() @@ -152,7 +126,7 @@ func New(ctx context.Context, config Config) (*Services, error) { return nil, err } - gatewayServer, err := server.New(ctx, gatewayDB, config.AuthAdminEmails, server.Options(config.GatewayConfig)) + gatewayServer, err := gserver.New(ctx, gatewayDB, config.AuthAdminEmails, gserver.Options(config.GatewayConfig)) if err != nil { return nil, err } @@ -180,7 +154,7 @@ func New(ctx context.Context, config Config) (*Services, error) { // Add gateway user info authenticators = client.NewUserDecorator(authenticators, gatewayClient) // Add anonymous user authenticator - authenticators = union.New(authenticators, anonymous{}) + authenticators = union.New(authenticators, authn.Anonymous{}) } else { // "Authentication Disabled" flow @@ -188,7 +162,7 @@ func New(ctx context.Context, config Config) (*Services, error) { authenticators = client.NewUserDecorator(authenticators, gatewayClient) // Add no auth authenticator - authenticators = union.New(authenticators, noAuth{}) + authenticators = union.New(authenticators, authn.NoAuth{}) } var ( @@ -207,7 +181,7 @@ func New(ctx context.Context, config Config) (*Services, error) { StorageClient: storageClient, Router: r, GPTClient: c, - APIServer: api.NewServer(storageClient, c, gatewayClient, tokenServer, authenticators), + APIServer: server.NewServer(storageClient, c, authn.NewAuthenticator(authenticators), authz.NewAuthorizer()), TokenServer: tokenServer, WorkspaceClient: workspaceClient, Invoker: invoke.NewInvoker(storageClient, c, tokenServer, workspaceClient, events),