Skip to content

Commit

Permalink
Code change to support public 'SetPrincipal' method (#98)
Browse files Browse the repository at this point in the history
* Code change to support public 'SetPrincipal' method, with a security check to ensure single use

* Make writeErr, and writeData methods public
  • Loading branch information
vedantr authored Sep 1, 2023
1 parent 2e70fd2 commit f6207c8
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 14 deletions.
14 changes: 9 additions & 5 deletions server/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func GetRouter(

m := NewKeyManager(cryptor, db)

r.NotFoundHandler = setupRoute("404", m)(decorator(writeErr(errF(knox.NotFoundCode, ""))))
r.NotFoundHandler = setupRoute("404", m)(decorator(WriteErr(errF(knox.NotFoundCode, ""))))

for _, route := range allRoutes {
addRoute(r, route, decorator, m)
Expand Down Expand Up @@ -226,7 +226,9 @@ type Route struct {
Parameters []Parameter
}

func writeErr(apiErr *HTTPError) http.HandlerFunc {
// WriteErr returns a function that can encode error information and set an
// HTTP error response code in the specified HTTP response writer
func WriteErr(apiErr *HTTPError) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
resp := new(knox.Response)
hostname, err := os.Hostname()
Expand All @@ -249,7 +251,9 @@ func writeErr(apiErr *HTTPError) http.HandlerFunc {
}
}

func writeData(w http.ResponseWriter, data interface{}) {
// WriteData returns a function that can write arbitrary data to the specified
// HTTP response writer
func WriteData(w http.ResponseWriter, data interface{}) {
r := new(knox.Response)
r.Message = ""
r.Code = knox.OKCode
Expand All @@ -275,9 +279,9 @@ func (r Route) ServeHTTP(w http.ResponseWriter, req *http.Request) {
data, err := r.Handler(db, principal, ps)

if err != nil {
writeErr(err)(w, req)
WriteErr(err)(w, req)
} else {
writeData(w, data)
WriteData(w, data)
}
}

Expand Down
2 changes: 1 addition & 1 deletion server/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ func checkinternalServerErrorResponse(t *testing.T, w *httptest.ResponseRecorder

func TestErrorHandler(t *testing.T) {
testErr := errF(knox.InternalServerErrorCode, "")
handler := writeErr(testErr)
handler := WriteErr(testErr)

w := httptest.NewRecorder()
handler(w, nil)
Expand Down
50 changes: 50 additions & 0 deletions server/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,61 @@ import (
"io/ioutil"
"net/http"
"strings"
"sync"
"time"

"github.com/gorilla/context"
"github.com/pinterest/knox"
)

type contextKey int

const (
principalCtxKey contextKey = iota
)

type PrincipalContext interface {
SetCurrentPrincipal(principal knox.Principal)
GetCurrentPrincipal() knox.Principal
}

type principalContext struct {
principalCtxKey contextKey
request *http.Request
once sync.Once
invocationCount int
}

func NewPrincipalContext(request *http.Request) PrincipalContext {
return &principalContext{
principalCtxKey,
request,
sync.Once{},
0,
}
}

func (ctx *principalContext) GetCurrentPrincipal() knox.Principal {
if rv := context.Get(ctx.request, principalCtxKey); rv != nil {
return rv.(knox.Principal)
}
return nil
}

func (ctx *principalContext) SetCurrentPrincipal(principal knox.Principal) {
if ctx.invocationCount > 0 {
panic("SetPrincipal was called more than once during the lifetime of the HTTP request")
}
ctx.setPrincipalInner(ctx.request, principal)
}

func (ctx *principalContext) setPrincipalInner(httpRequest *http.Request, principal knox.Principal) {
ctx.once.Do(func() {
context.Set(httpRequest, ctx.principalCtxKey, principal)
ctx.invocationCount += 1
})
}

// Provider is used for authenticating requests via the authentication decorator.
type Provider interface {
Name() string
Expand Down
48 changes: 48 additions & 0 deletions server/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/x509"
"encoding/base64"
"net/http"
"reflect"
"strings"
"time"

Expand All @@ -13,6 +14,53 @@ import (
"github.com/pinterest/knox"
)

func TestPrincipalContext(t *testing.T) {
originalPrincipal := NewUser("test", []string{"returntrue"}).(user)
newPrincipal := NewUser("hacker", []string{"returntrue"}).(user)

req, err := http.NewRequest("GET", "http://localhost/", nil)
if err != nil {
t.Fatal(err.Error())
}

ctx := NewPrincipalContext(req)

currentPrincipal := ctx.GetCurrentPrincipal()
if currentPrincipal != nil {
t.Error("Current principal was expected to be null")
}

ctx.SetCurrentPrincipal(originalPrincipal)
currentPrincipal = ctx.GetCurrentPrincipal().(user)

if !reflect.DeepEqual(currentPrincipal, originalPrincipal) {
t.Errorf(
"Current principal was expected to be user: '%s'. Instead got: '%s'",
originalPrincipal.GetID(),
currentPrincipal.GetID(),
)
}

panicRecovery := func() {
recoveryValue := recover()
if recoveryValue == nil {
t.Error("Expected 'SetCurrentPrincipal' to panic on second call")
}

currentPrincipal = ctx.GetCurrentPrincipal()
if !reflect.DeepEqual(currentPrincipal, originalPrincipal) {
t.Errorf(
"Current principal was expected to be user: '%s'. Instead got: '%s'",
originalPrincipal.GetID(),
currentPrincipal.GetID(),
)
}
}
defer panicRecovery()

ctx.SetCurrentPrincipal(newPrincipal)
}

func TestUserCanAccess(t *testing.T) {
u := NewUser("test", []string{"returntrue"})
a1 := knox.Access{ID: "test", AccessType: knox.Write, Type: knox.User}
Expand Down
27 changes: 19 additions & 8 deletions server/decorators.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,16 @@ func setAPIError(r *http.Request, val *HTTPError) {

// GetPrincipal gets the principal authenticated through the authentication decorator
func GetPrincipal(r *http.Request) knox.Principal {
if rv := context.Get(r, principalContext); rv != nil {
return rv.(knox.Principal)
}
return nil
ctx := getOrInitializePrincipalContext(r)
return ctx.GetCurrentPrincipal()
}

func setPrincipal(r *http.Request, val knox.Principal) {
context.Set(r, principalContext, val)
// SetPrincipal sets the principal authenticated through the authentication decorator.
// For security reasons, this method will only set the Principal in the context for
// the first invocation. Subsequent invocations WILL cause a panic.
func SetPrincipal(r *http.Request, val knox.Principal) {
ctx := getOrInitializePrincipalContext(r)
ctx.SetCurrentPrincipal(val)
}

// GetParams gets the parameters for the request through the parameters context.
Expand All @@ -68,6 +70,15 @@ func setDB(r *http.Request, val KeyManager) {
context.Set(r, dbContext, val)
}

func getOrInitializePrincipalContext(r *http.Request) auth.PrincipalContext {
if ctx := context.Get(r, principalContext); ctx != nil {
return ctx.(auth.PrincipalContext)
}
ctx := auth.NewPrincipalContext(r)
context.Set(r, principalContext, ctx)
return ctx
}

// GetRouteID gets the short form function name for the route being called. Used for logging/metrics.
func GetRouteID(r *http.Request) string {
if rv := context.Get(r, idContext); rv != nil {
Expand Down Expand Up @@ -219,11 +230,11 @@ func Authentication(providers []auth.Provider) func(http.HandlerFunc) http.Handl
}
}
if defaultPrincipal == nil {
writeErr(errF(knox.UnauthenticatedCode, errReturned.Error()))(w, r)
WriteErr(errF(knox.UnauthenticatedCode, errReturned.Error()))(w, r)
return
}

setPrincipal(r, knox.NewPrincipalMux(defaultPrincipal, allPrincipals))
SetPrincipal(r, knox.NewPrincipalMux(defaultPrincipal, allPrincipals))
f(w, r)
return
}
Expand Down

0 comments on commit f6207c8

Please sign in to comment.