Skip to content

Commit

Permalink
Merge pull request #36 from kostrse/rate-limit-session
Browse files Browse the repository at this point in the history
Rate limit sessions for throttling control
  • Loading branch information
aangelisc authored Mar 28, 2024
2 parents c955531 + 9fe63a8 commit f459ab1
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 4 deletions.
27 changes: 23 additions & 4 deletions azhttpclient/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ func AzureMiddleware(authOpts *AuthOptions, credentials azcredentials.AzureCrede
return httpclient.NamedMiddlewareFunc(azureMiddlewareName, func(clientOpts httpclient.Options, next http.RoundTripper) http.RoundTripper {
var err error
var tokenProvider aztokenprovider.AzureTokenProvider = nil
var sessionProvider *userSessionProvider = nil

if tokenProviderFactory, ok := authOpts.customProviders[credentials.AzureAuthType()]; ok && tokenProviderFactory != nil {
tokenProvider, err = tokenProviderFactory(authOpts.settings, credentials)
Expand All @@ -27,31 +28,49 @@ func AzureMiddleware(authOpts *AuthOptions, credentials azcredentials.AzureCrede
return errorResponse(err)
}

if authOpts.rateLimitSession {
sessionProvider, err = newSessionProvider()
if err != nil {
return errorResponse(err)
}
}

if len(authOpts.scopes) == 0 {
err = errors.New("scopes not configured")
return errorResponse(err)
}

return applyAzureAuth(tokenProvider, authOpts.scopes, authOpts.endpoints, next)
return applyAzureAuth(tokenProvider, sessionProvider, authOpts.scopes, authOpts.endpoints, next)
})
}

func applyAzureAuth(tokenProvider aztokenprovider.AzureTokenProvider, scopes []string,
endpoints *azendpoint.EndpointAllowlist, next http.RoundTripper) http.RoundTripper {
func applyAzureAuth(tokenProvider aztokenprovider.AzureTokenProvider, sessionProvider *userSessionProvider,
scopes []string, endpoints *azendpoint.EndpointAllowlist, next http.RoundTripper) http.RoundTripper {
return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
reqContext := req.Context()

if endpoints != nil {
endpoint := azendpoint.Endpoint(*req.URL)
if !endpoints.IsAllowed(endpoint) {
return nil, fmt.Errorf("request to endpoint '%s' is not allowed by the datasource", endpoint.String())
}
}

token, err := tokenProvider.GetAccessToken(req.Context(), scopes)
token, err := tokenProvider.GetAccessToken(reqContext, scopes)
if err != nil {
return nil, fmt.Errorf("failed to retrieve Azure access token: %w", err)
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))

if sessionProvider != nil {
sessionId, err := sessionProvider.GetSessionId(reqContext)
if err != nil {
return nil, fmt.Errorf("failed to obtain user session: %w", err)
} else if sessionId != "" {
req.Header.Set("x-ms-ratelimit-id", sessionId)
}
}

return next.RoundTrip(req)
})
}
Expand Down
5 changes: 5 additions & 0 deletions azhttpclient/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type AuthOptions struct {
endpoints *azendpoint.EndpointAllowlist
scopes []string
userIdentitySupported bool
rateLimitSession bool
customProviders map[string]AzureTokenProviderFactory
}

Expand Down Expand Up @@ -50,6 +51,10 @@ func (opts *AuthOptions) AllowUserIdentity() {
opts.userIdentitySupported = true
}

func (opts *AuthOptions) AddRateLimitSession(enable bool) {
opts.rateLimitSession = enable
}

func (opts *AuthOptions) AddTokenProvider(authType string, factory AzureTokenProviderFactory) {
if factory == nil {
return
Expand Down
71 changes: 71 additions & 0 deletions azhttpclient/session_provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package azhttpclient

import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"sync"

"github.com/grafana/grafana-azure-sdk-go/azusercontext"
)

var (
once sync.Once
processSeed []byte
processSeedOk bool
)

type userSessionProvider struct {
seed []byte
}

func newSessionProvider() (*userSessionProvider, error) {
// Session anonymized with an in-memory seed generated for the process instance
seed, err := perProcessSeed()
if err != nil {
return nil, errors.New("failed to initialize the user session provider")
}

return &userSessionProvider{
seed,
}, nil
}

func perProcessSeed() ([]byte, error) {
once.Do(func() {
seed := make([]byte, 32)
_, err := rand.Read(seed)
if err == nil {
processSeed = seed
processSeedOk = true
}
})

if !processSeedOk {
return nil, errors.New("failed to generate seed")
}
return processSeed, nil
}

func (p *userSessionProvider) GetSessionId(ctx context.Context) (string, error) {
if ctx == nil {
err := fmt.Errorf("parameter 'ctx' cannot be nil")
return "", err
}

currentUser, ok := azusercontext.GetCurrentUser(ctx)
if !ok {
err := fmt.Errorf("user context not configured")
return "", err
}

hash := sha256.New()
_, _ = hash.Write(p.seed)
_, _ = hash.Write([]byte(currentUser.User.Login))
sessionId := base64.URLEncoding.EncodeToString(hash.Sum(nil))

return sessionId, nil
}

0 comments on commit f459ab1

Please sign in to comment.