Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Code change to support public 'SetPrincipal' method #98

Merged
merged 2 commits into from
Sep 1, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions server/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,61 @@ import (
"io/ioutil"
"net/http"
"strings"
"sync"
"time"

"github.com/gorilla/context"
"github.com/pinterest/knox"
)

type contextKey int

const (
principalCtxKey contextKey = iota
)

type PrincipalContext interface {
SetCurrentPrincipal(principal knox.Principal)
GetCurrentPrincipal() knox.Principal
}

type principalContext struct {
principalCtxKey contextKey
request *http.Request
once sync.Once
invocationCount int
}

func NewPrincipalContext(request *http.Request) PrincipalContext {
return &principalContext{
principalCtxKey,
request,
sync.Once{},
0,
}
}

func (ctx *principalContext) GetCurrentPrincipal() knox.Principal {
if rv := context.Get(ctx.request, principalCtxKey); rv != nil {
return rv.(knox.Principal)
}
return nil
}

func (ctx *principalContext) SetCurrentPrincipal(principal knox.Principal) {
if ctx.invocationCount > 0 {
panic("SetPrincipal was called more than once during the lifetime of the HTTP request")
}
krockpot marked this conversation as resolved.
Show resolved Hide resolved
ctx.setPrincipalInner(ctx.request, principal)
}

func (ctx *principalContext) setPrincipalInner(httpRequest *http.Request, principal knox.Principal) {
ctx.once.Do(func() {
context.Set(httpRequest, ctx.principalCtxKey, principal)
ctx.invocationCount += 1
})
}

// Provider is used for authenticating requests via the authentication decorator.
type Provider interface {
Name() string
Expand Down
48 changes: 48 additions & 0 deletions server/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/x509"
"encoding/base64"
"net/http"
"reflect"
"strings"
"time"

Expand All @@ -13,6 +14,53 @@ import (
"github.com/pinterest/knox"
)

func TestPrincipalContext(t *testing.T) {
originalPrincipal := NewUser("test", []string{"returntrue"}).(user)
newPrincipal := NewUser("hacker", []string{"returntrue"}).(user)

req, err := http.NewRequest("GET", "http://localhost/", nil)
if err != nil {
t.Fatal(err.Error())
}

ctx := NewPrincipalContext(req)

currentPrincipal := ctx.GetCurrentPrincipal()
if currentPrincipal != nil {
t.Error("Current principal was expected to be null")
}

ctx.SetCurrentPrincipal(originalPrincipal)
currentPrincipal = ctx.GetCurrentPrincipal().(user)

if !reflect.DeepEqual(currentPrincipal, originalPrincipal) {
t.Errorf(
"Current principal was expected to be user: '%s'. Instead got: '%s'",
originalPrincipal.GetID(),
currentPrincipal.GetID(),
)
}

panicRecovery := func() {
recoveryValue := recover()
if recoveryValue == nil {
t.Error("Expected 'SetCurrentPrincipal' to panic on second call")
}

currentPrincipal = ctx.GetCurrentPrincipal()
if !reflect.DeepEqual(currentPrincipal, originalPrincipal) {
t.Errorf(
"Current principal was expected to be user: '%s'. Instead got: '%s'",
originalPrincipal.GetID(),
currentPrincipal.GetID(),
)
}
}
defer panicRecovery()

ctx.SetCurrentPrincipal(newPrincipal)
}

func TestUserCanAccess(t *testing.T) {
u := NewUser("test", []string{"returntrue"})
a1 := knox.Access{ID: "test", AccessType: knox.Write, Type: knox.User}
Expand Down
25 changes: 18 additions & 7 deletions server/decorators.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,16 @@ func setAPIError(r *http.Request, val *HTTPError) {

// GetPrincipal gets the principal authenticated through the authentication decorator
func GetPrincipal(r *http.Request) knox.Principal {
if rv := context.Get(r, principalContext); rv != nil {
return rv.(knox.Principal)
}
return nil
ctx := getOrInitializePrincipalContext(r)
return ctx.GetCurrentPrincipal()
}

func setPrincipal(r *http.Request, val knox.Principal) {
context.Set(r, principalContext, val)
// SetPrincipal sets the principal authenticated through the authentication decorator.
// For security reasons, this method will only set the Principal in the context for
// the first invocation. Subsequent invocations WILL cause a panic.
func SetPrincipal(r *http.Request, val knox.Principal) {
ctx := getOrInitializePrincipalContext(r)
ctx.SetCurrentPrincipal(val)
}

// GetParams gets the parameters for the request through the parameters context.
Expand All @@ -68,6 +70,15 @@ func setDB(r *http.Request, val KeyManager) {
context.Set(r, dbContext, val)
}

func getOrInitializePrincipalContext(r *http.Request) auth.PrincipalContext {
if ctx := context.Get(r, principalContext); ctx != nil {
return ctx.(auth.PrincipalContext)
}
ctx := auth.NewPrincipalContext(r)
context.Set(r, principalContext, ctx)
return ctx
}

// GetRouteID gets the short form function name for the route being called. Used for logging/metrics.
func GetRouteID(r *http.Request) string {
if rv := context.Get(r, idContext); rv != nil {
Expand Down Expand Up @@ -223,7 +234,7 @@ func Authentication(providers []auth.Provider) func(http.HandlerFunc) http.Handl
return
}

setPrincipal(r, knox.NewPrincipalMux(defaultPrincipal, allPrincipals))
SetPrincipal(r, knox.NewPrincipalMux(defaultPrincipal, allPrincipals))
f(w, r)
return
}
Expand Down
Loading