Skip to content

Commit

Permalink
[v2.8] [backport] AzureAD support (#363)
Browse files Browse the repository at this point in the history
* import of rancher types

* fixes and cleanup

* removed unused func

bump rancher

added dashboard implementation

* added device auth flow

updated login request

removed unused import

removed prompt flag

* updated go.mod, removed pkce package

go mod fix

* added cluster specific kubeconfig

* go mod update

* bump of rancher client-go

* dropped rancher fork of client-go

* fix lint and errors

* added getClient func

- added getClient func to create once the HTTP client with the same TLS configuration.
- added tests for the getAuthProviders func

* insecureRequest

* removed personal references

* revert name

* fix merge

* fix deps

* updated go mod
  • Loading branch information
enrichman authored May 31, 2024
1 parent 2aa5932 commit 1d80411
Show file tree
Hide file tree
Showing 5 changed files with 499 additions and 127 deletions.
221 changes: 142 additions & 79 deletions cmd/kubectl_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"crypto/x509"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"math/big"
Expand All @@ -17,12 +18,14 @@ import (
"os"
"os/signal"
"runtime"
"strconv"
"strings"
"time"

"github.com/rancher/cli/config"
"github.com/rancher/norman/types/convert"
apiv3 "github.com/rancher/rancher/pkg/apis/management.cattle.io/v3"
managementClient "github.com/rancher/rancher/pkg/client/generated/management/v3"
"github.com/tidwall/gjson"
"github.com/urfave/cli"
"golang.org/x/term"
)
Expand Down Expand Up @@ -61,6 +64,10 @@ var samlProviders = map[string]bool{
"shibbolethProvider": true,
}

var oauthProviders = map[string]bool{
"azureADProvider": true,
}

var supportedAuthProviders = map[string]bool{
"localProvider": true,
"freeIpaProvider": true,
Expand All @@ -73,6 +80,9 @@ var supportedAuthProviders = map[string]bool{
"keyCloakProvider": true,
"oktaProvider": true,
"shibbolethProvider": true,

// oauth providers
"azureADProvider": true,
}

func CredentialCommand() cli.Command {
Expand Down Expand Up @@ -295,30 +305,43 @@ func cacheCredential(ctx *cli.Context, cred *config.ExecCredential, id string) e
}

func loginAndGenerateCred(input *LoginInput) (*config.ExecCredential, error) {
if input.authProvider == "" {
provider, err := getAuthProvider(input.server)
if err != nil {
return nil, err
}
input.authProvider = provider
// setup a client with the provided TLS configuration
client, err := getClient(input.skipVerify, input.caCerts)
if err != nil {
return nil, err
}
tlsConfig, err := getTLSConfig(input)

authProviders, err := getAuthProviders(input.server)
if err != nil {
return nil, err
}

selectedProvider, err := selectAuthProvider(authProviders, input.authProvider)
if err != nil {
return nil, err
}
input.authProvider = selectedProvider.GetType()

token := managementClient.Token{}
if samlProviders[input.authProvider] {
token, err = samlAuth(input, tlsConfig)
token, err = samlAuth(input, client)
if err != nil {
return nil, err
}
} else if oauthProviders[input.authProvider] {
tokenPtr, err := oauthAuth(input, selectedProvider)
if err != nil {
return nil, err
}
token = *tokenPtr
} else {
customPrint(fmt.Sprintf("Enter credentials for %s \n", input.authProvider))
token, err = basicAuth(input, tlsConfig)
token, err = basicAuth(input)
if err != nil {
return nil, err
}
}

cred := &config.ExecCredential{
TypeMeta: config.TypeMeta{
Kind: "ExecCredential",
Expand All @@ -340,14 +363,14 @@ func loginAndGenerateCred(input *LoginInput) (*config.ExecCredential, error) {

}

func basicAuth(input *LoginInput, tlsConfig *tls.Config) (managementClient.Token, error) {
func basicAuth(input *LoginInput) (managementClient.Token, error) {
token := managementClient.Token{}
username, err := customPrompt("username", true)
username, err := customPrompt("Enter username: ", true)
if err != nil {
return token, err
}

password, err := customPrompt("password", false)
password, err := customPrompt("Enter password: ", false)
if err != nil {
return token, err
}
Expand Down Expand Up @@ -385,7 +408,7 @@ func basicAuth(input *LoginInput, tlsConfig *tls.Config) (managementClient.Token
return token, nil
}

func samlAuth(input *LoginInput, tlsConfig *tls.Config) (managementClient.Token, error) {
func samlAuth(input *LoginInput, client *http.Client) (managementClient.Token, error) {
token := managementClient.Token{}
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
Expand Down Expand Up @@ -418,11 +441,7 @@ func samlAuth(input *LoginInput, tlsConfig *tls.Config) (managementClient.Token,
req.Header.Set("content-type", "application/json")
req.Header.Set("accept", "application/json")

tr := &http.Transport{
TLSClientConfig: tlsConfig,
}

client := &http.Client{Transport: tr, Timeout: 300 * time.Second}
client.Timeout = 300 * time.Second

loginRequest := fmt.Sprintf("%s/dashboard/auth/login?requestId=%s&publicKey=%s&responseType=%s",
input.server, id, encodedKey, responseType)
Expand Down Expand Up @@ -476,10 +495,9 @@ func samlAuth(input *LoginInput, tlsConfig *tls.Config) (managementClient.Token,
}
req.Header.Set("content-type", "application/json")
req.Header.Set("accept", "application/json")
tr := &http.Transport{
TLSClientConfig: tlsConfig,
}
client = &http.Client{Transport: tr, Timeout: 150 * time.Second}

client.Timeout = 150 * time.Second

res, err = client.Do(req)
if err != nil {
// log error and use the token if login succeeds
Expand All @@ -501,7 +519,11 @@ func samlAuth(input *LoginInput, tlsConfig *tls.Config) (managementClient.Token,
}
}

func getAuthProviders(server string) (map[string]string, error) {
type TypedProvider interface {
GetType() string
}

func getAuthProviders(server string) ([]TypedProvider, error) {
authProviders := fmt.Sprintf(authProviderURL, server)
customPrint(authProviders)

Expand All @@ -510,58 +532,84 @@ func getAuthProviders(server string) (map[string]string, error) {
return nil, err
}

data := map[string]interface{}{}
err = json.Unmarshal(response, &data)
if err != nil {
return nil, err
if !gjson.ValidBytes(response) {
return nil, errors.New("invalid JSON input")
}

providers := map[string]string{}
i := 0
for _, value := range convert.ToMapSlice(data["data"]) {
provider := convert.ToString(value["type"])
if provider != "" && supportedAuthProviders[provider] {
providers[fmt.Sprintf("%v", i)] = provider
i++
data := gjson.GetBytes(response, "data").Array()

supportedProviders := []TypedProvider{}
for _, provider := range data {
providerType := provider.Get("type").String()

if providerType != "" && supportedAuthProviders[providerType] {
var typedProvider TypedProvider

switch providerType {
case "azureADProvider":
typedProvider = &apiv3.AzureADProvider{}
case "localProvider":
typedProvider = &apiv3.LocalProvider{}
default:
typedProvider = &apiv3.AuthProvider{}
}

err = json.Unmarshal([]byte(provider.Raw), typedProvider)
if err != nil {
return nil, err
}
supportedProviders = append(supportedProviders, typedProvider)
}
}
return providers, err

return supportedProviders, err
}

func getAuthProvider(server string) (string, error) {
authProviders, err := getAuthProviders(server)
if err != nil || authProviders == nil {
return "", err
}
func selectAuthProvider(authProviders []TypedProvider, providerType string) (TypedProvider, error) {
if len(authProviders) == 0 {
return "", fmt.Errorf("no auth provider configured")
return nil, fmt.Errorf("no auth provider configured")
}

// if providerType was specified, look for it
if providerType != "" {
for _, p := range authProviders {
if p.GetType() == providerType {
return p, nil
}
}
return nil, fmt.Errorf("provider %s not found", providerType)
}

// otherwise ask to the user (if more than one)
if len(authProviders) == 1 {
return authProviders["0"], nil
return authProviders[0], nil
}
try := 0

var providers []string
for key, val := range authProviders {
providers = append(providers, fmt.Sprintf("%s - %s", key, val))
for i, val := range authProviders {
providers = append(providers, fmt.Sprintf("%d - %s", i, val.GetType()))
}

try := 0
for try < 3 {
provider, err := customPrompt(fmt.Sprintf("auth provider\n%v",
strings.Join(providers, "\n")), true)
customPrint(fmt.Sprintf("Auth providers:\n%v", strings.Join(providers, "\n")))
providerIndexStr, err := customPrompt("Select auth provider: ", true)
if err != nil {
try++
continue
}
if _, ok := authProviders[provider]; !ok {
customPrint("pick valid auth provider")

providerIndex, err := strconv.Atoi(providerIndexStr)
if err != nil || (providerIndex < 0 || providerIndex > len(providers)-1) {
customPrint("Pick a valid auth provider")
try++
continue
}
provider = authProviders[provider]
return provider, nil
}

return "", fmt.Errorf("invalid auth provider")
return authProviders[providerIndex], nil
}

return nil, fmt.Errorf("invalid auth provider")
}

func generateKey() (string, error) {
Expand All @@ -579,56 +627,72 @@ func generateKey() (string, error) {
return string(token), nil
}

func getTLSConfig(input *LoginInput) (*tls.Config, error) {
config := &tls.Config{}
if input.skipVerify || input.caCerts == "" {
config = &tls.Config{
InsecureSkipVerify: true,
}
// getClient return a client with the provided TLS configuration
func getClient(skipVerify bool, caCerts string) (*http.Client, error) {
tlsConfig, err := getTLSConfig(skipVerify, caCerts)
if err != nil {
return nil, err
}

// clone the DefaultTransport to get the default values
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.TLSClientConfig = tlsConfig
return &http.Client{Transport: transport}, nil
}

func getTLSConfig(skipVerify bool, caCerts string) (*tls.Config, error) {
config := &tls.Config{
InsecureSkipVerify: skipVerify,
}

if caCerts == "" {
return config, nil
}

if input.caCerts != "" {
cert, err := loadAndVerifyCert(input.caCerts)
if err != nil {
return nil, err
}
roots := x509.NewCertPool()
ok := roots.AppendCertsFromPEM([]byte(cert))
if !ok {
return nil, err
}
config.RootCAs = roots
// load custom certs
cert, err := loadAndVerifyCert(caCerts)
if err != nil {
return nil, err
}

roots := x509.NewCertPool()
ok := roots.AppendCertsFromPEM([]byte(cert))
if !ok {
return nil, err
}
config.RootCAs = roots

return config, nil
}

func request(method, url string, body io.Reader) ([]byte, error) {
var response []byte
var client *http.Client

req, err := http.NewRequest(method, url, body)
if err != nil {
return response, err
return nil, err
}
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},

client, err := getClient(true, "")
if err != nil {
return nil, err
}
client = &http.Client{Transport: tr}

res, err := client.Do(req)
if err != nil {
return nil, err
}
defer res.Body.Close()

response, err = io.ReadAll(res.Body)
if err != nil {
return nil, err
}
return response, nil
}

func customPrompt(field string, show bool) (result string, err error) {
fmt.Fprintf(os.Stderr, "Enter %s: ", field)
func customPrompt(msg string, show bool) (result string, err error) {
fmt.Fprint(os.Stderr, msg)
if show {
_, err = fmt.Fscan(os.Stdin, &result)
} else {
Expand All @@ -638,7 +702,6 @@ func customPrompt(field string, show bool) (result string, err error) {
fmt.Fprintf(os.Stderr, "\n")
}
return result, err

}

func customPrint(data interface{}) {
Expand Down
Loading

0 comments on commit 1d80411

Please sign in to comment.