From 978a45bcf1aa6f72b6e9c23c3b1e51c2a9e53eec Mon Sep 17 00:00:00 2001 From: Hidetake Iwata Date: Mon, 27 Aug 2018 14:49:25 +0900 Subject: [PATCH] Refactor --- auth/authcode.go | 113 ++++++++++++++++++++++------------------------- auth/oidc.go | 62 +++++++++++++++----------- auth/state.go | 15 +++++++ cli/cli.go | 13 ++++-- 4 files changed, 112 insertions(+), 91 deletions(-) create mode 100644 auth/state.go diff --git a/auth/authcode.go b/auth/authcode.go index a4df7b65..84e44006 100644 --- a/auth/authcode.go +++ b/auth/authcode.go @@ -2,71 +2,55 @@ package auth import ( "context" - "crypto/rand" - "encoding/binary" "fmt" "log" "net/http" "github.com/pkg/browser" - "golang.org/x/oauth2" ) -// BrowserAuthCodeFlow is a flow to get a token by browser interaction. -type BrowserAuthCodeFlow struct { - oauth2.Config - Port int // HTTP server port +type authCodeFlow struct { + Config *oauth2.Config + ServerPort int // HTTP server port SkipOpenBrowser bool // skip opening browser if true } -// GetToken returns a token. -func (f *BrowserAuthCodeFlow) GetToken(ctx context.Context) (*oauth2.Token, error) { - f.Config.RedirectURL = fmt.Sprintf("http://localhost:%d/", f.Port) - state, err := generateState() - if err != nil { - return nil, fmt.Errorf("Could not generate state parameter: %s", err) - } - log.Printf("Open http://localhost:%d for authorization", f.Port) - if !f.SkipOpenBrowser { - browser.OpenURL(fmt.Sprintf("http://localhost:%d/", f.Port)) - } - code, err := f.getCode(ctx, &f.Config, state) +func (f *authCodeFlow) getToken(ctx context.Context) (*oauth2.Token, error) { + code, err := f.getAuthCode(ctx) if err != nil { - return nil, err + return nil, fmt.Errorf("Could not get an auth code: %s", err) } token, err := f.Config.Exchange(ctx, code) if err != nil { - return nil, fmt.Errorf("Could not exchange oauth code: %s", err) + return nil, fmt.Errorf("Could not exchange token: %s", err) } return token, nil } -func generateState() (string, error) { - var n uint64 - if err := binary.Read(rand.Reader, binary.LittleEndian, &n); err != nil { - return "", err +func (f *authCodeFlow) getAuthCode(ctx context.Context) (string, error) { + state, err := generateState() + if err != nil { + return "", fmt.Errorf("Could not generate state parameter: %s", err) } - return fmt.Sprintf("%x", n), nil -} - -func (f *BrowserAuthCodeFlow) getCode(ctx context.Context, config *oauth2.Config, state string) (string, error) { codeCh := make(chan string) + defer close(codeCh) errCh := make(chan error) + defer close(errCh) server := http.Server{ - Addr: fmt.Sprintf("localhost:%d", f.Port), - Handler: &handler{ - AuthCodeURL: config.AuthCodeURL(state), - Callback: func(code string, actualState string, err error) { - switch { - case err != nil: - errCh <- err - case actualState != state: - errCh <- fmt.Errorf("OAuth state did not match, should be %s but %s", state, actualState) - default: + Addr: fmt.Sprintf("localhost:%d", f.ServerPort), + Handler: &authCodeHandler{ + authCodeURL: f.Config.AuthCodeURL(state), + gotCode: func(code string, gotState string) { + if gotState == state { codeCh <- code + } else { + errCh <- fmt.Errorf("State does not match, wants %s but %s", state, gotState) } }, + gotError: func(err error) { + errCh <- err + }, }, } go func() { @@ -74,6 +58,12 @@ func (f *BrowserAuthCodeFlow) getCode(ctx context.Context, config *oauth2.Config errCh <- err } }() + go func() { + log.Printf("Open http://localhost:%d for authorization", f.ServerPort) + if !f.SkipOpenBrowser { + browser.OpenURL(fmt.Sprintf("http://localhost:%d/", f.ServerPort)) + } + }() select { case err := <-errCh: server.Shutdown(ctx) @@ -81,33 +71,36 @@ func (f *BrowserAuthCodeFlow) getCode(ctx context.Context, config *oauth2.Config case code := <-codeCh: server.Shutdown(ctx) return code, nil + case <-ctx.Done(): + server.Shutdown(ctx) + return "", ctx.Err() } } -type handler struct { - AuthCodeURL string - Callback func(code string, state string, err error) +type authCodeHandler struct { + authCodeURL string + gotCode func(code string, state string) + gotError func(err error) } -func (s *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func (h *authCodeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { log.Printf("%s %s", r.Method, r.RequestURI) - switch r.URL.Path { - case "/": - code := r.URL.Query().Get("code") - state := r.URL.Query().Get("state") - errorCode := r.URL.Query().Get("error") - errorDescription := r.URL.Query().Get("error_description") - switch { - case code != "": - s.Callback(code, state, nil) - w.Header().Add("Content-Type", "text/html") - fmt.Fprintf(w, `OK`) - case errorCode != "": - s.Callback("", "", fmt.Errorf("OAuth Error: %s %s", errorCode, errorDescription)) - http.Error(w, "OAuth Error", 500) - default: - http.Redirect(w, r, s.AuthCodeURL, 302) - } + m := r.Method + p := r.URL.Path + q := r.URL.Query() + switch { + case m == "GET" && p == "/" && q.Get("error") != "": + h.gotError(fmt.Errorf("OAuth Error: %s %s", q.Get("error"), q.Get("error_description"))) + http.Error(w, "OAuth Error", 500) + + case m == "GET" && p == "/" && q.Get("code") != "": + h.gotCode(q.Get("code"), q.Get("state")) + w.Header().Add("Content-Type", "text/html") + fmt.Fprintf(w, `OK`) + + case m == "GET" && p == "/": + http.Redirect(w, r, h.authCodeURL, 302) + default: http.Error(w, "Not Found", 404) } diff --git a/auth/oidc.go b/auth/oidc.go index ad3b3112..5da689b7 100644 --- a/auth/oidc.go +++ b/auth/oidc.go @@ -3,6 +3,8 @@ package auth import ( "context" "fmt" + "log" + "net/http" oidc "github.com/coreos/go-oidc" "golang.org/x/oauth2" @@ -12,50 +14,56 @@ import ( type TokenSet struct { IDToken string RefreshToken string - Claims *Claims } -// Claims represents properties in the ID token. -type Claims struct { - Email string `json:"email"` +// Config represents OIDC configuration. +type Config struct { + Issuer string + ClientID string + ClientSecret string + ExtraScopes []string // Additional scopes + Client *http.Client // HTTP client for oidc and oauth2 + ServerPort int // HTTP server port + SkipOpenBrowser bool // skip opening browser if true } -// GetTokenSet retrieves a token from the OIDC provider. -func GetTokenSet(ctx context.Context, issuer string, clientID string, clientSecret string, skipOpenBrowser bool) (*TokenSet, error) { - provider, err := oidc.NewProvider(ctx, issuer) +// GetTokenSet retrives a token from the OIDC provider and returns a TokenSet. +func (c *Config) GetTokenSet(ctx context.Context) (*TokenSet, error) { + if c.Client != nil { + ctx = context.WithValue(ctx, oauth2.HTTPClient, c.Client) + } + provider, err := oidc.NewProvider(ctx, c.Issuer) if err != nil { - return nil, fmt.Errorf("Could not access OIDC issuer: %s", err) + return nil, fmt.Errorf("Could not discovery the OIDC issuer: %s", err) + } + oauth2Config := &oauth2.Config{ + Endpoint: provider.Endpoint(), + ClientID: c.ClientID, + ClientSecret: c.ClientSecret, + Scopes: append(c.ExtraScopes, oidc.ScopeOpenID), + RedirectURL: fmt.Sprintf("http://localhost:%d/", c.ServerPort), } - flow := BrowserAuthCodeFlow{ - Port: 8000, - SkipOpenBrowser: skipOpenBrowser, - Config: oauth2.Config{ - Endpoint: provider.Endpoint(), - ClientID: clientID, - ClientSecret: clientSecret, - Scopes: []string{oidc.ScopeOpenID, "email"}, - }, + flow := &authCodeFlow{ + ServerPort: c.ServerPort, + SkipOpenBrowser: c.SkipOpenBrowser, + Config: oauth2Config, } - token, err := flow.GetToken(ctx) + token, err := flow.getToken(ctx) if err != nil { return nil, fmt.Errorf("Could not get a token: %s", err) } - rawIDToken, ok := token.Extra("id_token").(string) + idToken, ok := token.Extra("id_token").(string) if !ok { return nil, fmt.Errorf("id_token is missing in the token response: %s", token) } - verifier := provider.Verifier(&oidc.Config{ClientID: clientID}) - idToken, err := verifier.Verify(ctx, rawIDToken) + verifier := provider.Verifier(&oidc.Config{ClientID: c.ClientID}) + verifiedIDToken, err := verifier.Verify(ctx, idToken) if err != nil { return nil, fmt.Errorf("Could not verify the id_token: %s", err) } - var claims Claims - if err := idToken.Claims(&claims); err != nil { - return nil, fmt.Errorf("Could not extract claims from the token response: %s", err) - } + log.Printf("Got token for subject=%s", verifiedIDToken.Subject) return &TokenSet{ - IDToken: rawIDToken, + IDToken: idToken, RefreshToken: token.RefreshToken, - Claims: &claims, }, nil } diff --git a/auth/state.go b/auth/state.go new file mode 100644 index 00000000..d2b75b26 --- /dev/null +++ b/auth/state.go @@ -0,0 +1,15 @@ +package auth + +import ( + "crypto/rand" + "encoding/binary" + "fmt" +) + +func generateState() (string, error) { + var n uint64 + if err := binary.Read(rand.Reader, binary.LittleEndian, &n); err != nil { + return "", err + } + return fmt.Sprintf("%x", n), nil +} diff --git a/cli/cli.go b/cli/cli.go index 808a01a0..fc82528c 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -10,7 +10,6 @@ import ( "github.com/int128/kubelogin/kubeconfig" flags "github.com/jessevdk/go-flags" homedir "github.com/mitchellh/go-homedir" - "golang.org/x/oauth2" ) // Parse parses command line arguments and returns a CLI instance. @@ -67,9 +66,15 @@ func (c *CLI) Run(ctx context.Context) error { if err != nil { return fmt.Errorf("Could not configure TLS: %s", err) } - client := &http.Client{Transport: &http.Transport{TLSClientConfig: tlsConfig}} - ctx = context.WithValue(ctx, oauth2.HTTPClient, client) - token, err := auth.GetTokenSet(ctx, authProvider.IDPIssuerURL(), authProvider.ClientID(), authProvider.ClientSecret(), c.SkipOpenBrowser) + authConfig := &auth.Config{ + Issuer: authProvider.IDPIssuerURL(), + ClientID: authProvider.ClientID(), + ClientSecret: authProvider.ClientSecret(), + Client: &http.Client{Transport: &http.Transport{TLSClientConfig: tlsConfig}}, + ServerPort: 8000, + SkipOpenBrowser: c.SkipOpenBrowser, + } + token, err := authConfig.GetTokenSet(ctx) if err != nil { return fmt.Errorf("Authentication error: %s", err) }