From cda2eccaac4701cc251ead940cc7209337d2a201 Mon Sep 17 00:00:00 2001 From: Bastian Date: Thu, 22 Dec 2022 00:03:10 +0100 Subject: [PATCH] feat(authentication): add oauth2 device grant (#837) --- go.mod | 11 +- go.sum | 22 ++-- pkg/cmd/authentication.go | 7 ++ pkg/cmd/root.go | 1 + pkg/di/wire_gen.go | 13 +- pkg/oidc/client/client.go | 36 +++++- pkg/oidc/client/factory.go | 21 +++- pkg/oidc/client/mock_Interface.go | 119 ++++++++++++++++-- pkg/usecases/authentication/authentication.go | 12 +- .../authentication/devicecode/devicecode.go | 67 ++++++++++ .../devicecode/devicecode_test.go | 105 ++++++++++++++++ 11 files changed, 375 insertions(+), 39 deletions(-) create mode 100644 pkg/usecases/authentication/devicecode/devicecode.go create mode 100644 pkg/usecases/authentication/devicecode/devicecode_test.go diff --git a/go.mod b/go.mod index 96a8e9d0..8a830c39 100644 --- a/go.mod +++ b/go.mod @@ -10,14 +10,15 @@ require ( github.com/google/go-cmp v0.5.9 github.com/google/wire v0.5.0 github.com/int128/oauth2cli v1.14.0 + github.com/int128/oauth2dev v0.0.0-20221220102744-82ff1a972401 github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 github.com/spf13/cobra v1.6.1 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.8.1 - golang.org/x/net v0.2.0 - golang.org/x/oauth2 v0.2.0 + golang.org/x/net v0.3.0 + golang.org/x/oauth2 v0.3.0 golang.org/x/sync v0.1.0 - golang.org/x/term v0.2.0 + golang.org/x/term v0.3.0 gopkg.in/yaml.v2 v2.4.0 k8s.io/apimachinery v0.25.4 k8s.io/client-go v0.25.4 @@ -46,8 +47,8 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/stretchr/objx v0.5.0 // indirect golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd // indirect - golang.org/x/sys v0.2.0 // indirect - golang.org/x/text v0.4.0 // indirect + golang.org/x/sys v0.3.0 // indirect + golang.org/x/text v0.5.0 // indirect golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect google.golang.org/appengine v1.6.7 // indirect google.golang.org/protobuf v1.28.0 // indirect diff --git a/go.sum b/go.sum index c0d4093d..efbbe7a1 100644 --- a/go.sum +++ b/go.sum @@ -221,6 +221,8 @@ github.com/int128/listener v1.1.0 h1:2Jb41DWLpkQ3I9bIdBzO8H/tNwMvyl/OBZWtCV5Pjuw github.com/int128/listener v1.1.0/go.mod h1:68WkmTN8PQtLzc9DucIaagAKeGVyMnyyKIkW4Xn47UA= github.com/int128/oauth2cli v1.14.0 h1:r63NoO10ybUXIXUQxih8WOmt5HQpJubdTmhWh22B9VE= github.com/int128/oauth2cli v1.14.0/go.mod h1:LIoVAzgAsS2tDDBc8yopkcgY5oZR0+MJAeECkCwtxhA= +github.com/int128/oauth2dev v0.0.0-20221220102744-82ff1a972401 h1:7uAt3uMNEZIbotbRl7DvylgiVjoFCCgpqrO8tP1w7Fs= +github.com/int128/oauth2dev v0.0.0-20221220102744-82ff1a972401/go.mod h1:gBQLN8PsWqaq3+2wAEfepDjUkQB3enABD+sfdKUKdtI= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= @@ -369,8 +371,8 @@ golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su golang.org/x/net v0.0.0-20220607020251-c690dde0001d/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= -golang.org/x/net v0.2.0 h1:sZfSu1wtKLGlWI4ZZayP0ck9Y73K1ynO6gqzTdBVdPU= -golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= +golang.org/x/net v0.3.0 h1:VWL6FNY2bEEmsGVKabSlHu5Irp34xmMRoqb/9lF9lxk= +golang.org/x/net v0.3.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -393,8 +395,8 @@ golang.org/x/oauth2 v0.0.0-20220309155454-6242fa91716a/go.mod h1:DAh4E804XQdzx2j golang.org/x/oauth2 v0.0.0-20220411215720-9780585627b5/go.mod h1:DAh4E804XQdzx2j+YRIaUnCqCV2RuMz24cGBJ5QYIrc= golang.org/x/oauth2 v0.0.0-20220608161450-d0670ef3b1eb/go.mod h1:jaDAt6Dkxork7LmZnYtzbRWj0W47D86a3TGe0YHBvmE= golang.org/x/oauth2 v0.0.0-20220822191816-0ebed06d0094/go.mod h1:h4gKUeWbJ4rQPri7E0u6Gs4e9Ri2zaLxzw5DI5XGrYg= -golang.org/x/oauth2 v0.2.0 h1:GtQkldQ9m7yvzCL1V+LrYow3Khe0eJH0w7RbX/VbaIU= -golang.org/x/oauth2 v0.2.0/go.mod h1:Cwn6afJ8jrQwYMxQDTpISoXmXW9I6qF6vDeuuoX3Ibs= +golang.org/x/oauth2 v0.3.0 h1:6l90koy8/LaBLmLu8jpHeHexzMwEita0zFfYlggy2F8= +golang.org/x/oauth2 v0.3.0/go.mod h1:rQrIauxkUhJ6CuwEXwymO2/eh4xz2ZWF1nBkcxS+tGk= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -470,12 +472,12 @@ golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220610221304-9f5ed59c137d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A= -golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= +golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.2.0 h1:z85xZCsEl7bi/KwbNADeBYoOP0++7W1ipu+aGnpwzRM= -golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= +golang.org/x/term v0.3.0 h1:qoo4akIqOcDME5bhc/NgxUdovd6BSS2uMsVjB56q1xI= +golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -485,8 +487,8 @@ golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.4.0 h1:BrVqGRd7+k1DiOgtnFvAkoQEWQvBc25ouMJM6429SFg= -golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.5.0 h1:OLmvp0KP+FVG99Ct/qFiL/Fhk4zp4QQnZ7b2U+5piUM= +golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/pkg/cmd/authentication.go b/pkg/cmd/authentication.go index 333be198..e71ec7b9 100644 --- a/pkg/cmd/authentication.go +++ b/pkg/cmd/authentication.go @@ -7,6 +7,7 @@ import ( "github.com/int128/kubelogin/pkg/usecases/authentication" "github.com/int128/kubelogin/pkg/usecases/authentication/authcode" + "github.com/int128/kubelogin/pkg/usecases/authentication/devicecode" "github.com/int128/kubelogin/pkg/usecases/authentication/ropc" "github.com/spf13/pflag" ) @@ -47,6 +48,7 @@ var allGrantType = strings.Join([]string{ "authcode", "authcode-keyboard", "password", + "device-code", }, "|") func (o *authenticationOptions) addFlags(f *pflag.FlagSet) { @@ -97,6 +99,11 @@ func (o *authenticationOptions) grantOptionSet() (s authentication.GrantOptionSe Username: o.Username, Password: o.Password, } + case o.GrantType == "device-code": + s.DeviceCodeOption = &devicecode.Option{ + SkipOpenBrowser: o.SkipOpenBrowser, + BrowserCommand: o.BrowserCommand, + } default: err = fmt.Errorf("grant-type must be one of (%s)", allGrantType) } diff --git a/pkg/cmd/root.go b/pkg/cmd/root.go index a3c3371f..8f8ec061 100644 --- a/pkg/cmd/root.go +++ b/pkg/cmd/root.go @@ -2,6 +2,7 @@ package cmd import ( "fmt" + "github.com/int128/kubelogin/pkg/infrastructure/logger" "github.com/int128/kubelogin/pkg/kubeconfig" "github.com/int128/kubelogin/pkg/usecases/standalone" diff --git a/pkg/di/wire_gen.go b/pkg/di/wire_gen.go index 9b37e236..151d7389 100644 --- a/pkg/di/wire_gen.go +++ b/pkg/di/wire_gen.go @@ -1,7 +1,8 @@ // Code generated by Wire. DO NOT EDIT. -//go:generate wire -//+build !wireinject +//go:generate go run github.com/google/wire/cmd/wire +//go:build !wireinject +// +build !wireinject package di @@ -21,6 +22,7 @@ import ( "github.com/int128/kubelogin/pkg/tokencache/repository" "github.com/int128/kubelogin/pkg/usecases/authentication" "github.com/int128/kubelogin/pkg/usecases/authentication/authcode" + "github.com/int128/kubelogin/pkg/usecases/authentication/devicecode" "github.com/int128/kubelogin/pkg/usecases/authentication/ropc" "github.com/int128/kubelogin/pkg/usecases/credentialplugin" "github.com/int128/kubelogin/pkg/usecases/setup" @@ -30,6 +32,7 @@ import ( // Injectors from di.go: +// NewCmd returns an instance of infrastructure.Cmd. func NewCmd() cmd.Interface { clockReal := &clock.Real{} stdin := _wireFileValue @@ -45,6 +48,7 @@ var ( _wireOsFileValue = os.Stdout ) +// NewCmdForHeadless returns an instance of infrastructure.Cmd for headless testing. func NewCmdForHeadless(clockInterface clock.Interface, stdin stdio.Stdin, stdout stdio.Stdout, loggerInterface logger.Interface, browserInterface browser.Interface) cmd.Interface { loaderLoader := loader.Loader{} factory := &client.Factory{ @@ -67,6 +71,10 @@ func NewCmdForHeadless(clockInterface clock.Interface, stdin stdio.Stdin, stdout Reader: readerReader, Logger: loggerInterface, } + deviceCode := &devicecode.DeviceCode{ + Browser: browserInterface, + Logger: loggerInterface, + } authenticationAuthentication := &authentication.Authentication{ ClientFactory: factory, Logger: loggerInterface, @@ -74,6 +82,7 @@ func NewCmdForHeadless(clockInterface clock.Interface, stdin stdio.Stdin, stdout AuthCodeBrowser: authcodeBrowser, AuthCodeKeyboard: keyboard, ROPC: ropcROPC, + DeviceCode: deviceCode, } loader3 := &loader2.Loader{} writerWriter := &writer.Writer{} diff --git a/pkg/oidc/client/client.go b/pkg/oidc/client/client.go index d2e04f95..6f1af768 100644 --- a/pkg/oidc/client/client.go +++ b/pkg/oidc/client/client.go @@ -12,6 +12,7 @@ import ( "github.com/int128/kubelogin/pkg/oidc" "github.com/int128/kubelogin/pkg/pkce" "github.com/int128/oauth2cli" + "github.com/int128/oauth2dev" "golang.org/x/oauth2" ) @@ -20,6 +21,8 @@ type Interface interface { ExchangeAuthCode(ctx context.Context, in ExchangeAuthCodeInput) (*oidc.TokenSet, error) GetTokenByAuthCode(ctx context.Context, in GetTokenByAuthCodeInput, localServerReadyChan chan<- string) (*oidc.TokenSet, error) GetTokenByROPC(ctx context.Context, username, password string) (*oidc.TokenSet, error) + GetDeviceAuthorization(ctx context.Context) (*oauth2dev.AuthorizationResponse, error) + ExchangeDeviceCode(ctx context.Context, authResponse *oauth2dev.AuthorizationResponse) (*oidc.TokenSet, error) Refresh(ctx context.Context, refreshToken string) (*oidc.TokenSet, error) SupportedPKCEMethods() []string } @@ -52,12 +55,13 @@ type GetTokenByAuthCodeInput struct { } type client struct { - httpClient *http.Client - provider *gooidc.Provider - oauth2Config oauth2.Config - clock clock.Interface - logger logger.Interface - supportedPKCEMethods []string + httpClient *http.Client + provider *gooidc.Provider + oauth2Config oauth2.Config + clock clock.Interface + logger logger.Interface + supportedPKCEMethods []string + deviceAuthorizationEndpoint string } func (c *client) wrapContext(ctx context.Context) context.Context { @@ -151,6 +155,26 @@ func (c *client) GetTokenByROPC(ctx context.Context, username, password string) return c.verifyToken(ctx, token, "") } +// GetDeviceAuthorization initializes the device authorization code challenge +func (c *client) GetDeviceAuthorization(ctx context.Context) (*oauth2dev.AuthorizationResponse, error) { + ctx = c.wrapContext(ctx) + config := c.oauth2Config + config.Endpoint = oauth2.Endpoint{ + AuthURL: c.deviceAuthorizationEndpoint, + } + return oauth2dev.RetrieveCode(ctx, config) +} + +// ExchangeDeviceCode exchanges the device to an oidc.TokenSet +func (c *client) ExchangeDeviceCode(ctx context.Context, authResponse *oauth2dev.AuthorizationResponse) (*oidc.TokenSet, error) { + ctx = c.wrapContext(ctx) + tokenResponse, err := oauth2dev.PollToken(ctx, c.oauth2Config, *authResponse) + if err != nil { + return nil, fmt.Errorf("device-code: exchange failed: %w", err) + } + return c.verifyToken(ctx, tokenResponse, "") +} + // Refresh sends a refresh token request and returns a token set. func (c *client) Refresh(ctx context.Context, refreshToken string) (*oidc.TokenSet, error) { ctx = c.wrapContext(ctx) diff --git a/pkg/oidc/client/factory.go b/pkg/oidc/client/factory.go index 17007637..4a9ee332 100644 --- a/pkg/oidc/client/factory.go +++ b/pkg/oidc/client/factory.go @@ -63,6 +63,10 @@ func (f *Factory) New(ctx context.Context, p oidc.Provider, tlsClientConfig tlsc if len(supportedPKCEMethods) == 0 && p.UsePKCE { supportedPKCEMethods = []string{pkce.MethodS256} } + deviceAuthorizationEndpoint, err := extractDeviceAuthorizationEndpoint(provider) + if err != nil { + return nil, fmt.Errorf("could not determine device authorization endpoint: %w", err) + } return &client{ httpClient: httpClient, provider: provider, @@ -72,9 +76,10 @@ func (f *Factory) New(ctx context.Context, p oidc.Provider, tlsClientConfig tlsc ClientSecret: p.ClientSecret, Scopes: append(p.ExtraScopes, gooidc.ScopeOpenID), }, - clock: f.Clock, - logger: f.Logger, - supportedPKCEMethods: supportedPKCEMethods, + clock: f.Clock, + logger: f.Logger, + supportedPKCEMethods: supportedPKCEMethods, + deviceAuthorizationEndpoint: deviceAuthorizationEndpoint, }, nil } @@ -87,3 +92,13 @@ func extractSupportedPKCEMethods(provider *gooidc.Provider) ([]string, error) { } return d.CodeChallengeMethodsSupported, nil } + +func extractDeviceAuthorizationEndpoint(provider *gooidc.Provider) (string, error) { + var d struct { + DeviceAuthorizationEndpoint string `json:"device_authorization_endpoint"` + } + if err := provider.Claims(&d); err != nil { + return "", fmt.Errorf("invalid discovery document: %w", err) + } + return d.DeviceAuthorizationEndpoint, nil +} diff --git a/pkg/oidc/client/mock_Interface.go b/pkg/oidc/client/mock_Interface.go index 3397328e..4f4debd0 100644 --- a/pkg/oidc/client/mock_Interface.go +++ b/pkg/oidc/client/mock_Interface.go @@ -5,8 +5,10 @@ package client import ( context "context" - oidc "github.com/int128/kubelogin/pkg/oidc" + oauth2dev "github.com/int128/oauth2dev" mock "github.com/stretchr/testify/mock" + + oidc "github.com/int128/kubelogin/pkg/oidc" ) // MockInterface is an autogenerated mock type for the Interface type @@ -51,8 +53,8 @@ type MockInterface_ExchangeAuthCode_Call struct { } // ExchangeAuthCode is a helper method to define mock.On call -// - ctx context.Context -// - in ExchangeAuthCodeInput +// - ctx context.Context +// - in ExchangeAuthCodeInput func (_e *MockInterface_Expecter) ExchangeAuthCode(ctx interface{}, in interface{}) *MockInterface_ExchangeAuthCode_Call { return &MockInterface_ExchangeAuthCode_Call{Call: _e.mock.On("ExchangeAuthCode", ctx, in)} } @@ -69,6 +71,53 @@ func (_c *MockInterface_ExchangeAuthCode_Call) Return(_a0 *oidc.TokenSet, _a1 er return _c } +// ExchangeDeviceCode provides a mock function with given fields: ctx, authResponse +func (_m *MockInterface) ExchangeDeviceCode(ctx context.Context, authResponse *oauth2dev.AuthorizationResponse) (*oidc.TokenSet, error) { + ret := _m.Called(ctx, authResponse) + + var r0 *oidc.TokenSet + if rf, ok := ret.Get(0).(func(context.Context, *oauth2dev.AuthorizationResponse) *oidc.TokenSet); ok { + r0 = rf(ctx, authResponse) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*oidc.TokenSet) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *oauth2dev.AuthorizationResponse) error); ok { + r1 = rf(ctx, authResponse) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockInterface_ExchangeDeviceCode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ExchangeDeviceCode' +type MockInterface_ExchangeDeviceCode_Call struct { + *mock.Call +} + +// ExchangeDeviceCode is a helper method to define mock.On call +// - ctx context.Context +// - authResponse *oauth2dev.AuthorizationResponse +func (_e *MockInterface_Expecter) ExchangeDeviceCode(ctx interface{}, authResponse interface{}) *MockInterface_ExchangeDeviceCode_Call { + return &MockInterface_ExchangeDeviceCode_Call{Call: _e.mock.On("ExchangeDeviceCode", ctx, authResponse)} +} + +func (_c *MockInterface_ExchangeDeviceCode_Call) Run(run func(ctx context.Context, authResponse *oauth2dev.AuthorizationResponse)) *MockInterface_ExchangeDeviceCode_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*oauth2dev.AuthorizationResponse)) + }) + return _c +} + +func (_c *MockInterface_ExchangeDeviceCode_Call) Return(_a0 *oidc.TokenSet, _a1 error) *MockInterface_ExchangeDeviceCode_Call { + _c.Call.Return(_a0, _a1) + return _c +} + // GetAuthCodeURL provides a mock function with given fields: in func (_m *MockInterface) GetAuthCodeURL(in AuthCodeURLInput) string { ret := _m.Called(in) @@ -89,7 +138,7 @@ type MockInterface_GetAuthCodeURL_Call struct { } // GetAuthCodeURL is a helper method to define mock.On call -// - in AuthCodeURLInput +// - in AuthCodeURLInput func (_e *MockInterface_Expecter) GetAuthCodeURL(in interface{}) *MockInterface_GetAuthCodeURL_Call { return &MockInterface_GetAuthCodeURL_Call{Call: _e.mock.On("GetAuthCodeURL", in)} } @@ -106,6 +155,52 @@ func (_c *MockInterface_GetAuthCodeURL_Call) Return(_a0 string) *MockInterface_G return _c } +// GetDeviceAuthorization provides a mock function with given fields: ctx +func (_m *MockInterface) GetDeviceAuthorization(ctx context.Context) (*oauth2dev.AuthorizationResponse, error) { + ret := _m.Called(ctx) + + var r0 *oauth2dev.AuthorizationResponse + if rf, ok := ret.Get(0).(func(context.Context) *oauth2dev.AuthorizationResponse); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*oauth2dev.AuthorizationResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockInterface_GetDeviceAuthorization_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetDeviceAuthorization' +type MockInterface_GetDeviceAuthorization_Call struct { + *mock.Call +} + +// GetDeviceAuthorization is a helper method to define mock.On call +// - ctx context.Context +func (_e *MockInterface_Expecter) GetDeviceAuthorization(ctx interface{}) *MockInterface_GetDeviceAuthorization_Call { + return &MockInterface_GetDeviceAuthorization_Call{Call: _e.mock.On("GetDeviceAuthorization", ctx)} +} + +func (_c *MockInterface_GetDeviceAuthorization_Call) Run(run func(ctx context.Context)) *MockInterface_GetDeviceAuthorization_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockInterface_GetDeviceAuthorization_Call) Return(_a0 *oauth2dev.AuthorizationResponse, _a1 error) *MockInterface_GetDeviceAuthorization_Call { + _c.Call.Return(_a0, _a1) + return _c +} + // GetTokenByAuthCode provides a mock function with given fields: ctx, in, localServerReadyChan func (_m *MockInterface) GetTokenByAuthCode(ctx context.Context, in GetTokenByAuthCodeInput, localServerReadyChan chan<- string) (*oidc.TokenSet, error) { ret := _m.Called(ctx, in, localServerReadyChan) @@ -135,9 +230,9 @@ type MockInterface_GetTokenByAuthCode_Call struct { } // GetTokenByAuthCode is a helper method to define mock.On call -// - ctx context.Context -// - in GetTokenByAuthCodeInput -// - localServerReadyChan chan<- string +// - ctx context.Context +// - in GetTokenByAuthCodeInput +// - localServerReadyChan chan<- string func (_e *MockInterface_Expecter) GetTokenByAuthCode(ctx interface{}, in interface{}, localServerReadyChan interface{}) *MockInterface_GetTokenByAuthCode_Call { return &MockInterface_GetTokenByAuthCode_Call{Call: _e.mock.On("GetTokenByAuthCode", ctx, in, localServerReadyChan)} } @@ -183,9 +278,9 @@ type MockInterface_GetTokenByROPC_Call struct { } // GetTokenByROPC is a helper method to define mock.On call -// - ctx context.Context -// - username string -// - password string +// - ctx context.Context +// - username string +// - password string func (_e *MockInterface_Expecter) GetTokenByROPC(ctx interface{}, username interface{}, password interface{}) *MockInterface_GetTokenByROPC_Call { return &MockInterface_GetTokenByROPC_Call{Call: _e.mock.On("GetTokenByROPC", ctx, username, password)} } @@ -231,8 +326,8 @@ type MockInterface_Refresh_Call struct { } // Refresh is a helper method to define mock.On call -// - ctx context.Context -// - refreshToken string +// - ctx context.Context +// - refreshToken string func (_e *MockInterface_Expecter) Refresh(ctx interface{}, refreshToken interface{}) *MockInterface_Refresh_Call { return &MockInterface_Refresh_Call{Call: _e.mock.On("Refresh", ctx, refreshToken)} } diff --git a/pkg/usecases/authentication/authentication.go b/pkg/usecases/authentication/authentication.go index ebbe1fc1..61cf2bf0 100644 --- a/pkg/usecases/authentication/authentication.go +++ b/pkg/usecases/authentication/authentication.go @@ -11,6 +11,7 @@ import ( "github.com/int128/kubelogin/pkg/oidc/client" "github.com/int128/kubelogin/pkg/tlsclientconfig" "github.com/int128/kubelogin/pkg/usecases/authentication/authcode" + "github.com/int128/kubelogin/pkg/usecases/authentication/devicecode" "github.com/int128/kubelogin/pkg/usecases/authentication/ropc" ) @@ -21,6 +22,7 @@ var Set = wire.NewSet( wire.Struct(new(authcode.Browser), "*"), wire.Struct(new(authcode.Keyboard), "*"), wire.Struct(new(ropc.ROPC), "*"), + wire.Struct(new(devicecode.DeviceCode), "*"), ) type Interface interface { @@ -39,6 +41,7 @@ type GrantOptionSet struct { AuthCodeBrowserOption *authcode.BrowserOption AuthCodeKeyboardOption *authcode.KeyboardOption ROPCOption *ropc.Option + DeviceCodeOption *devicecode.Option } // Output represents an output DTO of the Authentication use-case. @@ -59,7 +62,6 @@ type Output struct { // If the Username is not set, it performs the authorization code flow. // Otherwise, it performs the resource owner password credentials flow. // If the Password is not set, it asks a password by the prompt. -// type Authentication struct { ClientFactory client.FactoryInterface Logger logger.Interface @@ -67,6 +69,7 @@ type Authentication struct { AuthCodeBrowser *authcode.Browser AuthCodeKeyboard *authcode.Keyboard ROPC *ropc.ROPC + DeviceCode *devicecode.DeviceCode } func (u *Authentication) Do(ctx context.Context, in Input) (*Output, error) { @@ -125,5 +128,12 @@ func (u *Authentication) Do(ctx context.Context, in Input) (*Output, error) { } return &Output{TokenSet: *tokenSet}, nil } + if in.GrantOptionSet.DeviceCodeOption != nil { + tokenSet, err := u.DeviceCode.Do(ctx, in.GrantOptionSet.DeviceCodeOption, oidcClient) + if err != nil { + return nil, fmt.Errorf("device-code error: %w", err) + } + return &Output{TokenSet: *tokenSet}, nil + } return nil, fmt.Errorf("any authorization grant must be set") } diff --git a/pkg/usecases/authentication/devicecode/devicecode.go b/pkg/usecases/authentication/devicecode/devicecode.go new file mode 100644 index 00000000..098d771c --- /dev/null +++ b/pkg/usecases/authentication/devicecode/devicecode.go @@ -0,0 +1,67 @@ +package devicecode + +import ( + "context" + "fmt" + + "github.com/int128/kubelogin/pkg/infrastructure/browser" + "github.com/int128/kubelogin/pkg/infrastructure/logger" + "github.com/int128/kubelogin/pkg/oidc" + "github.com/int128/kubelogin/pkg/oidc/client" +) + +type Option struct { + SkipOpenBrowser bool + BrowserCommand string +} + +// DeviceCode provides the oauth2 device code flow. +type DeviceCode struct { + Browser browser.Interface + Logger logger.Interface +} + +func (u *DeviceCode) Do(ctx context.Context, in *Option, oidcClient client.Interface) (*oidc.TokenSet, error) { + u.Logger.V(1).Infof("starting the oauth2 device code flow") + + authResponse, err := oidcClient.GetDeviceAuthorization(ctx) + if err != nil { + return nil, err + } + + if authResponse.VerificationURIComplete == "" { + u.Logger.Printf("Please enter the following code when asked in your browser: %s", authResponse.UserCode) + u.openURL(ctx, in, authResponse.VerificationURI) + } else { + u.openURL(ctx, in, authResponse.VerificationURIComplete) + } + + tokenSet, err := oidcClient.ExchangeDeviceCode(ctx, authResponse) + u.Logger.V(1).Infof("finished the oauth2 device code flow") + if err != nil { + return nil, fmt.Errorf("unable to exchange device code: %w", err) + } + return tokenSet, nil +} + +func (u *DeviceCode) openURL(ctx context.Context, o *Option, url string) { + if o != nil && o.SkipOpenBrowser { + u.Logger.Printf("Please visit the following URL in your browser: %s", url) + return + } + + u.Logger.V(1).Infof("opening %s in the browser", url) + if o != nil && o.BrowserCommand != "" { + if err := u.Browser.OpenCommand(ctx, url, o.BrowserCommand); err != nil { + u.Logger.Printf(`error: could not open the browser: %s + +Please visit the following URL in your browser manually: %s`, err, url) + } + return + } + if err := u.Browser.Open(url); err != nil { + u.Logger.Printf(`error: could not open the browser: %s + +Please visit the following URL in your browser manually: %s`, err, url) + } +} diff --git a/pkg/usecases/authentication/devicecode/devicecode_test.go b/pkg/usecases/authentication/devicecode/devicecode_test.go new file mode 100644 index 00000000..bc621743 --- /dev/null +++ b/pkg/usecases/authentication/devicecode/devicecode_test.go @@ -0,0 +1,105 @@ +package devicecode + +import ( + "context" + "errors" + "testing" + + "github.com/int128/kubelogin/pkg/infrastructure/browser" + "github.com/int128/kubelogin/pkg/oidc" + "github.com/int128/kubelogin/pkg/oidc/client" + "github.com/int128/kubelogin/pkg/testing/logger" + "github.com/int128/oauth2dev" + "github.com/stretchr/testify/mock" +) + +func TestDeviceCode(t *testing.T) { + mockBrowser := browser.NewMockInterface(t) + logger := logger.New(t) + mockClient := client.NewMockInterface(t) + + dc := &DeviceCode{ + Browser: mockBrowser, + Logger: logger, + } + + ctx := context.Background() + errTest := errors.New("test error") + + mockClient.EXPECT().GetDeviceAuthorization(ctx).Return(nil, errTest).Once() + _, err := dc.Do(ctx, &Option{}, mockClient) + if !errors.Is(err, errTest) { + t.Errorf("returned error is not the test error: %v", err) + } + + mockResponse := &oauth2dev.AuthorizationResponse{DeviceCode: "device-code-1", UserCode: "", VerificationURI: "", VerificationURIComplete: "https://example.com/verificationComplete?code=code123", VerificationURL: "", ExpiresIn: 2, Interval: 1} + mockClient.EXPECT().GetDeviceAuthorization(ctx).Return(&oauth2dev.AuthorizationResponse{ + Interval: 1, + ExpiresIn: 2, + VerificationURIComplete: "https://example.com/verificationComplete?code=code123", + DeviceCode: "device-code-1", + }, nil).Once() + mockBrowser.EXPECT().Open("https://example.com/verificationComplete?code=code123").Return(nil).Once() + mockClient.EXPECT().ExchangeDeviceCode(mock.Anything, mockResponse).Return(&oidc.TokenSet{ + IDToken: "test-id-token", + }, nil).Once() + ts, err := dc.Do(ctx, &Option{}, mockClient) + if err != nil { + t.Errorf("returned unexpected error: %v", err) + } + if ts.IDToken != "test-id-token" { + t.Errorf("wrong returned tokenset: %v", err) + } + + mockResponseWithoutComplete := &oauth2dev.AuthorizationResponse{DeviceCode: "device-code-1", UserCode: "", VerificationURI: "https://example.com/verificationComplete", VerificationURIComplete: "", VerificationURL: "", ExpiresIn: 2, Interval: 1} + mockClient.EXPECT().GetDeviceAuthorization(ctx).Return(&oauth2dev.AuthorizationResponse{ + Interval: 1, + ExpiresIn: 2, + VerificationURI: "https://example.com/verificationComplete", + DeviceCode: "device-code-1", + }, nil).Once() + mockBrowser.EXPECT().Open("https://example.com/verificationComplete").Return(nil).Once() + mockClient.EXPECT().ExchangeDeviceCode(mock.Anything, mockResponseWithoutComplete).Return(&oidc.TokenSet{ + IDToken: "test-id-token", + }, nil).Once() + ts, err = dc.Do(ctx, &Option{}, mockClient) + if err != nil { + t.Errorf("returned unexpected error: %v", err) + } + if ts.IDToken != "test-id-token" { + t.Errorf("wrong returned tokenset: %v", err) + } + + mockClient.EXPECT().GetDeviceAuthorization(ctx).Return(&oauth2dev.AuthorizationResponse{ + Interval: 1, + ExpiresIn: 2, + VerificationURIComplete: "https://example.com/verificationComplete?code=code123", + DeviceCode: "device-code-1", + }, nil).Once() + mockBrowser.EXPECT().Open("https://example.com/verificationComplete?code=code123").Return(nil).Once() + mockClient.EXPECT().ExchangeDeviceCode(mock.Anything, mockResponse).Return(nil, errTest).Once() + _, err = dc.Do(ctx, &Option{}, mockClient) + if err == nil { + t.Errorf("did not return error: %v", err) + } +} + +func TestOpenUrl(t *testing.T) { + ctx := context.Background() + browserMock := browser.NewMockInterface(t) + deviceCode := &DeviceCode{ + Browser: browserMock, + Logger: logger.New(t), + } + + const url = "https://example.com" + var testError = errors.New("test error") + + browserMock.EXPECT().Open(url).Return(testError).Once() + deviceCode.openURL(ctx, nil, url) + + deviceCode.openURL(ctx, &Option{SkipOpenBrowser: true}, url) + + browserMock.EXPECT().OpenCommand(ctx, url, "test-command").Return(testError).Once() + deviceCode.openURL(ctx, &Option{BrowserCommand: "test-command"}, url) +}