From 06d898fc60fabbe44a9cc09567a3ca82aea0abb9 Mon Sep 17 00:00:00 2001 From: Vedant Radhakrishnan Date: Fri, 1 Sep 2023 15:20:19 -0700 Subject: [PATCH 1/2] Code change to support public 'SetPrincipal' method, with a security check to ensure single use --- server/auth/auth.go | 50 ++++++++++++++++++++++++++++++++++++++++ server/auth/auth_test.go | 48 ++++++++++++++++++++++++++++++++++++++ server/decorators.go | 25 ++++++++++++++------ 3 files changed, 116 insertions(+), 7 deletions(-) diff --git a/server/auth/auth.go b/server/auth/auth.go index 4acfe32..d6ea77d 100644 --- a/server/auth/auth.go +++ b/server/auth/auth.go @@ -8,11 +8,61 @@ import ( "io/ioutil" "net/http" "strings" + "sync" "time" + "github.com/gorilla/context" "github.com/pinterest/knox" ) +type contextKey int + +const ( + principalCtxKey contextKey = iota +) + +type PrincipalContext interface { + SetCurrentPrincipal(principal knox.Principal) + GetCurrentPrincipal() knox.Principal +} + +type principalContext struct { + principalCtxKey contextKey + request *http.Request + once sync.Once + invocationCount int +} + +func NewPrincipalContext(request *http.Request) PrincipalContext { + return &principalContext{ + principalCtxKey, + request, + sync.Once{}, + 0, + } +} + +func (ctx *principalContext) GetCurrentPrincipal() knox.Principal { + if rv := context.Get(ctx.request, principalCtxKey); rv != nil { + return rv.(knox.Principal) + } + return nil +} + +func (ctx *principalContext) SetCurrentPrincipal(principal knox.Principal) { + if ctx.invocationCount > 0 { + panic("SetPrincipal was called more than once during the lifetime of the HTTP request") + } + ctx.setPrincipalInner(ctx.request, principal) +} + +func (ctx *principalContext) setPrincipalInner(httpRequest *http.Request, principal knox.Principal) { + ctx.once.Do(func() { + context.Set(httpRequest, ctx.principalCtxKey, principal) + ctx.invocationCount += 1 + }) +} + // Provider is used for authenticating requests via the authentication decorator. type Provider interface { Name() string diff --git a/server/auth/auth_test.go b/server/auth/auth_test.go index 1773857..4828ce7 100644 --- a/server/auth/auth_test.go +++ b/server/auth/auth_test.go @@ -5,6 +5,7 @@ import ( "crypto/x509" "encoding/base64" "net/http" + "reflect" "strings" "time" @@ -13,6 +14,53 @@ import ( "github.com/pinterest/knox" ) +func TestPrincipalContext(t *testing.T) { + originalPrincipal := NewUser("test", []string{"returntrue"}).(user) + newPrincipal := NewUser("hacker", []string{"returntrue"}).(user) + + req, err := http.NewRequest("GET", "http://localhost/", nil) + if err != nil { + t.Fatal(err.Error()) + } + + ctx := NewPrincipalContext(req) + + currentPrincipal := ctx.GetCurrentPrincipal() + if currentPrincipal != nil { + t.Error("Current principal was expected to be null") + } + + ctx.SetCurrentPrincipal(originalPrincipal) + currentPrincipal = ctx.GetCurrentPrincipal().(user) + + if !reflect.DeepEqual(currentPrincipal, originalPrincipal) { + t.Errorf( + "Current principal was expected to be user: '%s'. Instead got: '%s'", + originalPrincipal.GetID(), + currentPrincipal.GetID(), + ) + } + + panicRecovery := func() { + recoveryValue := recover() + if recoveryValue == nil { + t.Error("Expected 'SetCurrentPrincipal' to panic on second call") + } + + currentPrincipal = ctx.GetCurrentPrincipal() + if !reflect.DeepEqual(currentPrincipal, originalPrincipal) { + t.Errorf( + "Current principal was expected to be user: '%s'. Instead got: '%s'", + originalPrincipal.GetID(), + currentPrincipal.GetID(), + ) + } + } + defer panicRecovery() + + ctx.SetCurrentPrincipal(newPrincipal) +} + func TestUserCanAccess(t *testing.T) { u := NewUser("test", []string{"returntrue"}) a1 := knox.Access{ID: "test", AccessType: knox.Write, Type: knox.User} diff --git a/server/decorators.go b/server/decorators.go index 1204ad9..89549b3 100644 --- a/server/decorators.go +++ b/server/decorators.go @@ -35,14 +35,16 @@ func setAPIError(r *http.Request, val *HTTPError) { // GetPrincipal gets the principal authenticated through the authentication decorator func GetPrincipal(r *http.Request) knox.Principal { - if rv := context.Get(r, principalContext); rv != nil { - return rv.(knox.Principal) - } - return nil + ctx := getOrInitializePrincipalContext(r) + return ctx.GetCurrentPrincipal() } -func setPrincipal(r *http.Request, val knox.Principal) { - context.Set(r, principalContext, val) +// SetPrincipal sets the principal authenticated through the authentication decorator. +// For security reasons, this method will only set the Principal in the context for +// the first invocation. Subsequent invocations WILL cause a panic. +func SetPrincipal(r *http.Request, val knox.Principal) { + ctx := getOrInitializePrincipalContext(r) + ctx.SetCurrentPrincipal(val) } // GetParams gets the parameters for the request through the parameters context. @@ -68,6 +70,15 @@ func setDB(r *http.Request, val KeyManager) { context.Set(r, dbContext, val) } +func getOrInitializePrincipalContext(r *http.Request) auth.PrincipalContext { + if ctx := context.Get(r, principalContext); ctx != nil { + return ctx.(auth.PrincipalContext) + } + ctx := auth.NewPrincipalContext(r) + context.Set(r, principalContext, ctx) + return ctx +} + // GetRouteID gets the short form function name for the route being called. Used for logging/metrics. func GetRouteID(r *http.Request) string { if rv := context.Get(r, idContext); rv != nil { @@ -223,7 +234,7 @@ func Authentication(providers []auth.Provider) func(http.HandlerFunc) http.Handl return } - setPrincipal(r, knox.NewPrincipalMux(defaultPrincipal, allPrincipals)) + SetPrincipal(r, knox.NewPrincipalMux(defaultPrincipal, allPrincipals)) f(w, r) return } From ede2f972efd4fca9e7c8fb38916b448ffd02652b Mon Sep 17 00:00:00 2001 From: Vedant Radhakrishnan Date: Fri, 1 Sep 2023 16:02:37 -0700 Subject: [PATCH 2/2] Make writeErr, and writeData methods public --- server/api.go | 14 +++++++++----- server/api_test.go | 2 +- server/decorators.go | 2 +- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/server/api.go b/server/api.go index bdccfc3..dce3c6d 100644 --- a/server/api.go +++ b/server/api.go @@ -105,7 +105,7 @@ func GetRouter( m := NewKeyManager(cryptor, db) - r.NotFoundHandler = setupRoute("404", m)(decorator(writeErr(errF(knox.NotFoundCode, "")))) + r.NotFoundHandler = setupRoute("404", m)(decorator(WriteErr(errF(knox.NotFoundCode, "")))) for _, route := range allRoutes { addRoute(r, route, decorator, m) @@ -226,7 +226,9 @@ type Route struct { Parameters []Parameter } -func writeErr(apiErr *HTTPError) http.HandlerFunc { +// WriteErr returns a function that can encode error information and set an +// HTTP error response code in the specified HTTP response writer +func WriteErr(apiErr *HTTPError) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { resp := new(knox.Response) hostname, err := os.Hostname() @@ -249,7 +251,9 @@ func writeErr(apiErr *HTTPError) http.HandlerFunc { } } -func writeData(w http.ResponseWriter, data interface{}) { +// WriteData returns a function that can write arbitrary data to the specified +// HTTP response writer +func WriteData(w http.ResponseWriter, data interface{}) { r := new(knox.Response) r.Message = "" r.Code = knox.OKCode @@ -275,9 +279,9 @@ func (r Route) ServeHTTP(w http.ResponseWriter, req *http.Request) { data, err := r.Handler(db, principal, ps) if err != nil { - writeErr(err)(w, req) + WriteErr(err)(w, req) } else { - writeData(w, data) + WriteData(w, data) } } diff --git a/server/api_test.go b/server/api_test.go index d43c3f0..e5314a0 100644 --- a/server/api_test.go +++ b/server/api_test.go @@ -174,7 +174,7 @@ func checkinternalServerErrorResponse(t *testing.T, w *httptest.ResponseRecorder func TestErrorHandler(t *testing.T) { testErr := errF(knox.InternalServerErrorCode, "") - handler := writeErr(testErr) + handler := WriteErr(testErr) w := httptest.NewRecorder() handler(w, nil) diff --git a/server/decorators.go b/server/decorators.go index 89549b3..1df68e8 100644 --- a/server/decorators.go +++ b/server/decorators.go @@ -230,7 +230,7 @@ func Authentication(providers []auth.Provider) func(http.HandlerFunc) http.Handl } } if defaultPrincipal == nil { - writeErr(errF(knox.UnauthenticatedCode, errReturned.Error()))(w, r) + WriteErr(errF(knox.UnauthenticatedCode, errReturned.Error()))(w, r) return }