Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add active subscription to context #69

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion auth/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ type AuthenticatedUser struct {
OrganizationIDs []string
// AuthenticationType is the type of authentication used to authenticate the user (JWT, PAT, API Token)
AuthenticationType AuthenticationType
// ActiveSubscription is the active subscription for the user
ActiveSubscription bool
}

// GetContextName returns the name of the context key
Expand Down Expand Up @@ -310,7 +312,7 @@ func AddOrganizationIDToContext(ctx context.Context, orgID string) error {
return addOrganizationIDsToEchoContext(ec, orgID)
}

// getOrganizationIDsFromEchoContext appends an authorized organization ID to the echo context
// addOrganizationIDsToEchoContext appends an authorized organization ID to the echo context
func addOrganizationIDsToEchoContext(c echo.Context, orgID string) error {
if v := c.Get(ContextAuthenticatedUser.name); v != nil {
a, ok := v.(*AuthenticatedUser)
Expand Down Expand Up @@ -370,3 +372,53 @@ func GetRefreshTokenContext(c context.Context) (string, error) {

return token, nil
}

// AddSubscriptionToContext appends a subscription to the context
func AddSubscriptionToContext(ctx context.Context, subscription bool) error {
ec, err := echocontext.EchoContextFromContext(ctx)
if err != nil {
return err
}

return addSubscriptionToEchoContext(ec, subscription)
}

// addSubscriptionToEchoContext appends a subscription to the echo context
func addSubscriptionToEchoContext(c echo.Context, subscription bool) error {
if v := c.Get(ContextAuthenticatedUser.name); v != nil {
a, ok := v.(*AuthenticatedUser)
if !ok {
return ErrNoAuthUser
}

a.ActiveSubscription = subscription

return nil
}

return ErrNoAuthUser
}

// getSubscriptionFromContext returns the active subscription from the echo context
func getSubscriptionFromContext(c echo.Context) (bool, error) {
if v := c.Get(ContextAuthenticatedUser.name); v != nil {
a, ok := v.(*AuthenticatedUser)
if !ok {
return false, ErrNoAuthUser
}

return a.ActiveSubscription, nil
}

return false, nil
}

// GetSubscriptionFromContext returns the active subscription from the context
func GetSubscriptionFromContext(ctx context.Context) (bool, error) {
ec, err := echocontext.EchoContextFromContext(ctx)
if err != nil {
return false, err
}

return getSubscriptionFromContext(ec)
}
58 changes: 58 additions & 0 deletions auth/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,61 @@ func TestGetOrganizationIDFromContext(t *testing.T) {
})
}
}

func TestGetSubscriptionFromContext(t *testing.T) {
validsubscription := true
invalidsubscription := false

ec := echocontext.NewTestEchoContext()

basicContext := context.WithValue(ec.Request().Context(), echocontext.EchoContextKey, ec)

ec.SetRequest(ec.Request().WithContext(basicContext))

invalidCtx, err := auth.NewTestContextWithValidUser(ulids.Null.String())
if err != nil {
t.Fatal()
}

validCtx, err := auth.NewTestContextWithValidUser(ulids.New().String())
if err != nil {
t.Fatal()
}

if err := auth.AddSubscriptionToContext(validCtx, true); err != nil {
t.Fatal(err)
}

testCases := []struct {
name string
ctx context.Context
expect bool
}{
{
name: "happy path",
ctx: invalidCtx,
expect: invalidsubscription,
},
{
name: "MITB BABBYYYYY",
ctx: validCtx,
expect: validsubscription,
},
}

for _, tc := range testCases {
t.Run("Get "+tc.name, func(t *testing.T) {
got, err := auth.GetSubscriptionFromContext(tc.ctx)

assert.NoError(t, err)

if tc.expect == validsubscription {
assert.Equal(t, validsubscription, got)
}

if tc.expect == invalidsubscription {
assert.Equal(t, invalidsubscription, got)
}
Comment on lines +180 to +186
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should be able to simplify this to

Suggested change
if tc.expect == validsubscription {
assert.Equal(t, validsubscription, got)
}
if tc.expect == invalidsubscription {
assert.Equal(t, invalidsubscription, got)
}
assert.Equal(t, tc.expect, got)

})
}
}
12 changes: 2 additions & 10 deletions auth/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,26 @@ import (
var (
// ErrNoClaims is returned when no claims are found on the request context
ErrNoClaims = errors.New("no claims found on the request context")

// ErrNoUserInfo is returned when no user info is found on the request context
ErrNoUserInfo = errors.New("no user info found on the request context")

// ErrNoAuthUser is returned when no authenticated user is found on the request context
ErrNoAuthUser = errors.New("could not identify authenticated user in request")

// ErrUnverifiedUser is returned when the user is not verified
ErrUnverifiedUser = errors.New("user is not verified")

// ErrParseBearer is returned when the bearer token could not be parsed from the authorization header
ErrParseBearer = errors.New("could not parse bearer token from authorization header")

// ErrNoAuthorization is returned when no authorization header is found in the request
ErrNoAuthorization = errors.New("no authorization header in request")

// ErrNoRequest is returned when no request is found on the context
ErrNoRequest = errors.New("no request found on the context")

// ErrNoRefreshToken is returned when no refresh token is found on the request
ErrNoRefreshToken = errors.New("no refresh token available on request")

// ErrRefreshDisabled is returned when re-authentication with refresh tokens is disabled
ErrRefreshDisabled = errors.New("re-authentication with refresh tokens disabled")

// ErrUnableToConstructValidator is returned when the validator cannot be constructed
ErrUnableToConstructValidator = errors.New("unable to construct validator")

// ErrPasswordTooWeak is returned when the password is too weak
ErrPasswordTooWeak = errors.New("password is too weak: use a combination of upper and lower case letters, numbers, and special characters")
// ErrCouldNotFetchSubscription is returned when the subscription could not be fetched
ErrCouldNotFetchSubscription = errors.New("could not fetch subscription")
)
34 changes: 33 additions & 1 deletion auth/test_tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import (

"github.com/golang-jwt/jwt/v5"
echo "github.com/theopenlane/echox"

"github.com/theopenlane/echox/middleware/echocontext"
"github.com/theopenlane/utils/ulids"

"github.com/theopenlane/iam/tokens"
)
Expand Down Expand Up @@ -99,6 +99,7 @@ func NewTestEchoContextWithOrgID(sub, orgID string) (echo.Context, error) {
return ec, nil
}

// NewTestContextWithOrgID creates a context with a fake orgID for testing purposes only (why all caps jeez keep it down)
func NewTestContextWithOrgID(sub, orgID string) (context.Context, error) {
ec, err := NewTestEchoContextWithOrgID(sub, orgID)
if err != nil {
Expand All @@ -111,3 +112,34 @@ func NewTestContextWithOrgID(sub, orgID string) (context.Context, error) {

return reqCtx, nil
}

// NewTestEchoContextWithOrgID creates an echo context with a fake orgID for testing purposes ONLY
func NewTestEchoContextWithSubscription(subscription bool) (echo.Context, error) {
ec := echocontext.NewTestEchoContext()

claims := newValidClaimsOrgID(ulids.New().String(), ulids.New().String())

SetAuthenticatedUserContext(ec, &AuthenticatedUser{
SubjectID: claims.UserID,
OrganizationID: claims.OrgID,
OrganizationIDs: []string{claims.OrgID},
AuthenticationType: "jwt",
ActiveSubscription: subscription,
})

return ec, nil
}

// NewTestContextWithOrgID creates a context with a fake orgID for testing purposes only (why all caps jeez keep it down)
func NewTestContextWithSubscription(subscription bool) (context.Context, error) {
ec, err := NewTestEchoContextWithSubscription(subscription)
if err != nil {
return nil, err
}

reqCtx := context.WithValue(ec.Request().Context(), echocontext.EchoContextKey, ec)

ec.SetRequest(ec.Request().WithContext(reqCtx))

return reqCtx, nil
}