diff --git a/.github/workflows/coverage.yaml b/.github/workflows/coverage.yaml index 9af81145..cf76d2ab 100644 --- a/.github/workflows/coverage.yaml +++ b/.github/workflows/coverage.yaml @@ -17,8 +17,8 @@ jobs: - name: Calculate coverage run: | go test -v -covermode=atomic -coverprofile=cover.out.raw -coverpkg=./... ./... - # remove mocks from coverage calculation - grep -v mock_ cover.out.raw > cover.out + # remove generated code from coverage calculation + grep -Ev 'internal/mock|_enumer.go' cover.out.raw > cover.out - name: Generage coverage badge uses: vladopajic/go-test-coverage@bcd064e5ceef1ccec5441519eb054263b6a44787 # v2.8.2 with: diff --git a/.gitignore b/.gitignore index dc0c8d2f..dd43b2de 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /dist /cover.out +/cover.out.raw diff --git a/Makefile b/Makefile index 64d11a6f..37854662 100644 --- a/Makefile +++ b/Makefile @@ -23,5 +23,6 @@ fuzz: mod-tidy generate .PHONY: cover cover: mod-tidy generate - go test -v -covermode=atomic -coverprofile=cover.out -coverpkg=./... ./... + go test -v -covermode=atomic -coverprofile=cover.out.raw -coverpkg=./... ./... + grep -Ev 'internal/mock|_enumer.go' cover.out.raw > cover.out go tool cover -html=cover.out diff --git a/go.mod b/go.mod index f6c708fe..8c6e7728 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/prometheus/client_golang v1.18.0 github.com/zitadel/oidc/v3 v3.11.1 go.opentelemetry.io/otel v1.23.1 + go.uber.org/mock v0.4.0 golang.org/x/crypto v0.19.0 golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 golang.org/x/oauth2 v0.17.0 diff --git a/go.sum b/go.sum index e858cdc0..49c78d2c 100644 --- a/go.sum +++ b/go.sum @@ -169,6 +169,8 @@ go.opentelemetry.io/otel/metric v1.23.1 h1:PQJmqJ9u2QaJLBOELl1cxIdPcpbwzbkjfEyel go.opentelemetry.io/otel/metric v1.23.1/go.mod h1:mpG2QPlAfnK8yNhNJAxDZruU9Y1/HubbC+KyH8FaCWI= go.opentelemetry.io/otel/trace v1.23.1 h1:4LrmmEd8AU2rFvU1zegmvqW7+kWarxtNOPyeL6HmYY8= go.opentelemetry.io/otel/trace v1.23.1/go.mod h1:4IpnpJFwr1mo/6HL8XIPJaE9y0+u1KcVmuW7dwFSVrI= +go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= +go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= diff --git a/internal/mock/ssh.go b/internal/mock/ssh.go new file mode 100644 index 00000000..430cc30d --- /dev/null +++ b/internal/mock/ssh.go @@ -0,0 +1,540 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/gliderlabs/ssh (interfaces: Session,Context) +// +// Generated by this command: +// +// mockgen -package=mock -destination=ssh.go -write_generate_directive github.com/gliderlabs/ssh Session,Context +// + +// Package mock is a generated GoMock package. +package mock + +import ( + io "io" + net "net" + reflect "reflect" + time "time" + + ssh "github.com/gliderlabs/ssh" + gomock "go.uber.org/mock/gomock" +) + +//go:generate mockgen -package=mock -destination=ssh.go -write_generate_directive github.com/gliderlabs/ssh Session,Context + +// MockSession is a mock of Session interface. +type MockSession struct { + ctrl *gomock.Controller + recorder *MockSessionMockRecorder +} + +// MockSessionMockRecorder is the mock recorder for MockSession. +type MockSessionMockRecorder struct { + mock *MockSession +} + +// NewMockSession creates a new mock instance. +func NewMockSession(ctrl *gomock.Controller) *MockSession { + mock := &MockSession{ctrl: ctrl} + mock.recorder = &MockSessionMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSession) EXPECT() *MockSessionMockRecorder { + return m.recorder +} + +// Break mocks base method. +func (m *MockSession) Break(arg0 chan<- bool) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Break", arg0) +} + +// Break indicates an expected call of Break. +func (mr *MockSessionMockRecorder) Break(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Break", reflect.TypeOf((*MockSession)(nil).Break), arg0) +} + +// Close mocks base method. +func (m *MockSession) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockSessionMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSession)(nil).Close)) +} + +// CloseWrite mocks base method. +func (m *MockSession) CloseWrite() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CloseWrite") + ret0, _ := ret[0].(error) + return ret0 +} + +// CloseWrite indicates an expected call of CloseWrite. +func (mr *MockSessionMockRecorder) CloseWrite() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWrite", reflect.TypeOf((*MockSession)(nil).CloseWrite)) +} + +// Command mocks base method. +func (m *MockSession) Command() []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Command") + ret0, _ := ret[0].([]string) + return ret0 +} + +// Command indicates an expected call of Command. +func (mr *MockSessionMockRecorder) Command() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Command", reflect.TypeOf((*MockSession)(nil).Command)) +} + +// Context mocks base method. +func (m *MockSession) Context() ssh.Context { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Context") + ret0, _ := ret[0].(ssh.Context) + return ret0 +} + +// Context indicates an expected call of Context. +func (mr *MockSessionMockRecorder) Context() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockSession)(nil).Context)) +} + +// Environ mocks base method. +func (m *MockSession) Environ() []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Environ") + ret0, _ := ret[0].([]string) + return ret0 +} + +// Environ indicates an expected call of Environ. +func (mr *MockSessionMockRecorder) Environ() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Environ", reflect.TypeOf((*MockSession)(nil).Environ)) +} + +// Exit mocks base method. +func (m *MockSession) Exit(arg0 int) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Exit", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Exit indicates an expected call of Exit. +func (mr *MockSessionMockRecorder) Exit(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exit", reflect.TypeOf((*MockSession)(nil).Exit), arg0) +} + +// LocalAddr mocks base method. +func (m *MockSession) LocalAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LocalAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// LocalAddr indicates an expected call of LocalAddr. +func (mr *MockSessionMockRecorder) LocalAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockSession)(nil).LocalAddr)) +} + +// Permissions mocks base method. +func (m *MockSession) Permissions() ssh.Permissions { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Permissions") + ret0, _ := ret[0].(ssh.Permissions) + return ret0 +} + +// Permissions indicates an expected call of Permissions. +func (mr *MockSessionMockRecorder) Permissions() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Permissions", reflect.TypeOf((*MockSession)(nil).Permissions)) +} + +// Pty mocks base method. +func (m *MockSession) Pty() (ssh.Pty, <-chan ssh.Window, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Pty") + ret0, _ := ret[0].(ssh.Pty) + ret1, _ := ret[1].(<-chan ssh.Window) + ret2, _ := ret[2].(bool) + return ret0, ret1, ret2 +} + +// Pty indicates an expected call of Pty. +func (mr *MockSessionMockRecorder) Pty() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Pty", reflect.TypeOf((*MockSession)(nil).Pty)) +} + +// PublicKey mocks base method. +func (m *MockSession) PublicKey() ssh.PublicKey { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PublicKey") + ret0, _ := ret[0].(ssh.PublicKey) + return ret0 +} + +// PublicKey indicates an expected call of PublicKey. +func (mr *MockSessionMockRecorder) PublicKey() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublicKey", reflect.TypeOf((*MockSession)(nil).PublicKey)) +} + +// RawCommand mocks base method. +func (m *MockSession) RawCommand() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RawCommand") + ret0, _ := ret[0].(string) + return ret0 +} + +// RawCommand indicates an expected call of RawCommand. +func (mr *MockSessionMockRecorder) RawCommand() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RawCommand", reflect.TypeOf((*MockSession)(nil).RawCommand)) +} + +// Read mocks base method. +func (m *MockSession) Read(arg0 []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Read", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Read indicates an expected call of Read. +func (mr *MockSessionMockRecorder) Read(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockSession)(nil).Read), arg0) +} + +// RemoteAddr mocks base method. +func (m *MockSession) RemoteAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoteAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// RemoteAddr indicates an expected call of RemoteAddr. +func (mr *MockSessionMockRecorder) RemoteAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockSession)(nil).RemoteAddr)) +} + +// SendRequest mocks base method. +func (m *MockSession) SendRequest(arg0 string, arg1 bool, arg2 []byte) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendRequest", arg0, arg1, arg2) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SendRequest indicates an expected call of SendRequest. +func (mr *MockSessionMockRecorder) SendRequest(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendRequest", reflect.TypeOf((*MockSession)(nil).SendRequest), arg0, arg1, arg2) +} + +// Signals mocks base method. +func (m *MockSession) Signals(arg0 chan<- ssh.Signal) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Signals", arg0) +} + +// Signals indicates an expected call of Signals. +func (mr *MockSessionMockRecorder) Signals(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Signals", reflect.TypeOf((*MockSession)(nil).Signals), arg0) +} + +// Stderr mocks base method. +func (m *MockSession) Stderr() io.ReadWriter { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Stderr") + ret0, _ := ret[0].(io.ReadWriter) + return ret0 +} + +// Stderr indicates an expected call of Stderr. +func (mr *MockSessionMockRecorder) Stderr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stderr", reflect.TypeOf((*MockSession)(nil).Stderr)) +} + +// Subsystem mocks base method. +func (m *MockSession) Subsystem() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Subsystem") + ret0, _ := ret[0].(string) + return ret0 +} + +// Subsystem indicates an expected call of Subsystem. +func (mr *MockSessionMockRecorder) Subsystem() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Subsystem", reflect.TypeOf((*MockSession)(nil).Subsystem)) +} + +// User mocks base method. +func (m *MockSession) User() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "User") + ret0, _ := ret[0].(string) + return ret0 +} + +// User indicates an expected call of User. +func (mr *MockSessionMockRecorder) User() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "User", reflect.TypeOf((*MockSession)(nil).User)) +} + +// Write mocks base method. +func (m *MockSession) Write(arg0 []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Write", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Write indicates an expected call of Write. +func (mr *MockSessionMockRecorder) Write(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockSession)(nil).Write), arg0) +} + +// MockContext is a mock of Context interface. +type MockContext struct { + ctrl *gomock.Controller + recorder *MockContextMockRecorder +} + +// MockContextMockRecorder is the mock recorder for MockContext. +type MockContextMockRecorder struct { + mock *MockContext +} + +// NewMockContext creates a new mock instance. +func NewMockContext(ctrl *gomock.Controller) *MockContext { + mock := &MockContext{ctrl: ctrl} + mock.recorder = &MockContextMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockContext) EXPECT() *MockContextMockRecorder { + return m.recorder +} + +// ClientVersion mocks base method. +func (m *MockContext) ClientVersion() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ClientVersion") + ret0, _ := ret[0].(string) + return ret0 +} + +// ClientVersion indicates an expected call of ClientVersion. +func (mr *MockContextMockRecorder) ClientVersion() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClientVersion", reflect.TypeOf((*MockContext)(nil).ClientVersion)) +} + +// Deadline mocks base method. +func (m *MockContext) Deadline() (time.Time, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Deadline") + ret0, _ := ret[0].(time.Time) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// Deadline indicates an expected call of Deadline. +func (mr *MockContextMockRecorder) Deadline() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Deadline", reflect.TypeOf((*MockContext)(nil).Deadline)) +} + +// Done mocks base method. +func (m *MockContext) Done() <-chan struct{} { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Done") + ret0, _ := ret[0].(<-chan struct{}) + return ret0 +} + +// Done indicates an expected call of Done. +func (mr *MockContextMockRecorder) Done() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Done", reflect.TypeOf((*MockContext)(nil).Done)) +} + +// Err mocks base method. +func (m *MockContext) Err() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Err") + ret0, _ := ret[0].(error) + return ret0 +} + +// Err indicates an expected call of Err. +func (mr *MockContextMockRecorder) Err() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Err", reflect.TypeOf((*MockContext)(nil).Err)) +} + +// LocalAddr mocks base method. +func (m *MockContext) LocalAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LocalAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// LocalAddr indicates an expected call of LocalAddr. +func (mr *MockContextMockRecorder) LocalAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockContext)(nil).LocalAddr)) +} + +// Lock mocks base method. +func (m *MockContext) Lock() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Lock") +} + +// Lock indicates an expected call of Lock. +func (mr *MockContextMockRecorder) Lock() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Lock", reflect.TypeOf((*MockContext)(nil).Lock)) +} + +// Permissions mocks base method. +func (m *MockContext) Permissions() *ssh.Permissions { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Permissions") + ret0, _ := ret[0].(*ssh.Permissions) + return ret0 +} + +// Permissions indicates an expected call of Permissions. +func (mr *MockContextMockRecorder) Permissions() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Permissions", reflect.TypeOf((*MockContext)(nil).Permissions)) +} + +// RemoteAddr mocks base method. +func (m *MockContext) RemoteAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoteAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// RemoteAddr indicates an expected call of RemoteAddr. +func (mr *MockContextMockRecorder) RemoteAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockContext)(nil).RemoteAddr)) +} + +// ServerVersion mocks base method. +func (m *MockContext) ServerVersion() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ServerVersion") + ret0, _ := ret[0].(string) + return ret0 +} + +// ServerVersion indicates an expected call of ServerVersion. +func (mr *MockContextMockRecorder) ServerVersion() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServerVersion", reflect.TypeOf((*MockContext)(nil).ServerVersion)) +} + +// SessionID mocks base method. +func (m *MockContext) SessionID() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SessionID") + ret0, _ := ret[0].(string) + return ret0 +} + +// SessionID indicates an expected call of SessionID. +func (mr *MockContextMockRecorder) SessionID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SessionID", reflect.TypeOf((*MockContext)(nil).SessionID)) +} + +// SetValue mocks base method. +func (m *MockContext) SetValue(arg0, arg1 any) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetValue", arg0, arg1) +} + +// SetValue indicates an expected call of SetValue. +func (mr *MockContextMockRecorder) SetValue(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetValue", reflect.TypeOf((*MockContext)(nil).SetValue), arg0, arg1) +} + +// Unlock mocks base method. +func (m *MockContext) Unlock() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Unlock") +} + +// Unlock indicates an expected call of Unlock. +func (mr *MockContextMockRecorder) Unlock() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unlock", reflect.TypeOf((*MockContext)(nil).Unlock)) +} + +// User mocks base method. +func (m *MockContext) User() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "User") + ret0, _ := ret[0].(string) + return ret0 +} + +// User indicates an expected call of User. +func (mr *MockContextMockRecorder) User() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "User", reflect.TypeOf((*MockContext)(nil).User)) +} + +// Value mocks base method. +func (m *MockContext) Value(arg0 any) any { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Value", arg0) + ret0, _ := ret[0].(any) + return ret0 +} + +// Value indicates an expected call of Value. +func (mr *MockContextMockRecorder) Value(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Value", reflect.TypeOf((*MockContext)(nil).Value), arg0) +} diff --git a/internal/mock/sshserver_sessionhandler.go b/internal/mock/sshserver_sessionhandler.go new file mode 100644 index 00000000..e87a6b29 --- /dev/null +++ b/internal/mock/sshserver_sessionhandler.go @@ -0,0 +1,87 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ../sshserver/sessionhandler.go +// +// Generated by this command: +// +// mockgen -source=../sshserver/sessionhandler.go -package=mock -destination=sshserver_sessionhandler.go -write_generate_directive +// + +// Package mock is a generated GoMock package. +package mock + +import ( + context "context" + io "io" + reflect "reflect" + + ssh "github.com/gliderlabs/ssh" + gomock "go.uber.org/mock/gomock" +) + +//go:generate mockgen -source=../sshserver/sessionhandler.go -package=mock -destination=sshserver_sessionhandler.go -write_generate_directive + +// MockK8SAPIService is a mock of K8SAPIService interface. +type MockK8SAPIService struct { + ctrl *gomock.Controller + recorder *MockK8SAPIServiceMockRecorder +} + +// MockK8SAPIServiceMockRecorder is the mock recorder for MockK8SAPIService. +type MockK8SAPIServiceMockRecorder struct { + mock *MockK8SAPIService +} + +// NewMockK8SAPIService creates a new mock instance. +func NewMockK8SAPIService(ctrl *gomock.Controller) *MockK8SAPIService { + mock := &MockK8SAPIService{ctrl: ctrl} + mock.recorder = &MockK8SAPIServiceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockK8SAPIService) EXPECT() *MockK8SAPIServiceMockRecorder { + return m.recorder +} + +// Exec mocks base method. +func (m *MockK8SAPIService) Exec(arg0 context.Context, arg1, arg2, arg3 string, arg4 []string, arg5 io.ReadWriter, arg6 io.Writer, arg7 bool, arg8 <-chan ssh.Window) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Exec", arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8) + ret0, _ := ret[0].(error) + return ret0 +} + +// Exec indicates an expected call of Exec. +func (mr *MockK8SAPIServiceMockRecorder) Exec(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockK8SAPIService)(nil).Exec), arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8) +} + +// FindDeployment mocks base method. +func (m *MockK8SAPIService) FindDeployment(arg0 context.Context, arg1, arg2 string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FindDeployment", arg0, arg1, arg2) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FindDeployment indicates an expected call of FindDeployment. +func (mr *MockK8SAPIServiceMockRecorder) FindDeployment(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindDeployment", reflect.TypeOf((*MockK8SAPIService)(nil).FindDeployment), arg0, arg1, arg2) +} + +// Logs mocks base method. +func (m *MockK8SAPIService) Logs(arg0 context.Context, arg1, arg2, arg3 string, arg4 bool, arg5 int64, arg6 io.ReadWriter) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Logs", arg0, arg1, arg2, arg3, arg4, arg5, arg6) + ret0, _ := ret[0].(error) + return ret0 +} + +// Logs indicates an expected call of Logs. +func (mr *MockK8SAPIServiceMockRecorder) Logs(arg0, arg1, arg2, arg3, arg4, arg5, arg6 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Logs", reflect.TypeOf((*MockK8SAPIService)(nil).Logs), arg0, arg1, arg2, arg3, arg4, arg5, arg6) +} diff --git a/internal/sshserver/helper_test.go b/internal/sshserver/helper_test.go index 598faa70..d92afc5c 100644 --- a/internal/sshserver/helper_test.go +++ b/internal/sshserver/helper_test.go @@ -2,11 +2,13 @@ package sshserver // ParseConnectionParams exposes the private parseConnectionParams for testing // only. -func ParseConnectionParams(args []string) (string, string, string, []string) { - return parseConnectionParams(args) -} +var ParseConnectionParams = parseConnectionParams // ParseLogsArg exposes the private parseLogsArg for testing only. -func ParseLogsArg(service, logs string, args []string) (bool, int64, error) { - return parseLogsArg(service, logs, args) -} +var ParseLogsArg = parseLogsArg + +// SessionHandler exposes the private sessionHandler for testing only. +var SessionHandler = sessionHandler + +// CtxKey exposes the private ctxKey for testing only. +type CtxKey = ctxKey diff --git a/internal/sshserver/serve.go b/internal/sshserver/serve.go index 84389251..fadedc8a 100644 --- a/internal/sshserver/serve.go +++ b/internal/sshserver/serve.go @@ -15,8 +15,14 @@ import ( gossh "golang.org/x/crypto/ssh" ) +// default server shutdown timeout once the top-level context is cancelled +// (e.g. via signal) +const shutdownTimeout = 8 * time.Second + // disableSHA1Kex returns a ServerConfig which relies on default for everything // except key exchange algorithms. There it removes the SHA1 based algorithms. +// +// This works around https://github.com/golang/go/issues/59593 func disableSHA1Kex(_ ssh.Context) *gossh.ServerConfig { c := gossh.ServerConfig{} c.Config.KeyExchanges = []string{ @@ -30,9 +36,16 @@ func disableSHA1Kex(_ ssh.Context) *gossh.ServerConfig { return &c } -// Serve contains the main ssh session logic -func Serve(ctx context.Context, log *slog.Logger, nc *nats.EncodedConn, - l net.Listener, c *k8s.Client, hostKeys [][]byte, logAccessEnabled bool) error { +// Serve implements the ssh server logic. +func Serve( + ctx context.Context, + log *slog.Logger, + nc *nats.EncodedConn, + l net.Listener, + c *k8s.Client, + hostKeys [][]byte, + logAccessEnabled bool, +) error { srv := ssh.Server{ Handler: sessionHandler(log, c, false, logAccessEnabled), SubsystemHandlers: map[string]ssh.SubsystemHandler{ @@ -48,9 +61,8 @@ func Serve(ctx context.Context, log *slog.Logger, nc *nats.EncodedConn, } go func() { // As soon as the top level context is cancelled, shut down the server. - // Give an 8 second deadline to do this. <-ctx.Done() - shutCtx, cancel := context.WithTimeout(context.Background(), 8*time.Second) + shutCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) defer cancel() if err := srv.Shutdown(shutCtx); err != nil { log.Warn("couldn't shutdown cleanly", slog.Any("error", err)) diff --git a/internal/sshserver/sessionhandler.go b/internal/sshserver/sessionhandler.go index 0effdb2e..ad028deb 100644 --- a/internal/sshserver/sessionhandler.go +++ b/internal/sshserver/sessionhandler.go @@ -3,6 +3,7 @@ package sshserver import ( "context" "fmt" + "io" "log/slog" "strings" "time" @@ -14,6 +15,14 @@ import ( "k8s.io/utils/exec" ) +// K8SAPIService provides methods for querying the Kubernetes API. +type K8SAPIService interface { + Exec(context.Context, string, string, string, []string, io.ReadWriter, + io.Writer, bool, <-chan ssh.Window) error + FindDeployment(context.Context, string, string) (string, error) + Logs(context.Context, string, string, string, bool, int64, io.ReadWriter) error +} + var ( sessionTotal = promauto.NewCounter(prometheus.CounterOpts{ Name: "sshportal_sessions_total", @@ -46,7 +55,7 @@ func getSSHIntent(sftp bool, cmd []string) []string { // handler is that the command is set to sftp-server. This implies that the // target container must have a sftp-server binary installed for sftp to work. // There is no support for a built-in sftp server. -func sessionHandler(log *slog.Logger, c *k8s.Client, +func sessionHandler(log *slog.Logger, c K8SAPIService, sftp, logAccessEnabled bool) ssh.Handler { return func(s ssh.Session) { sessionTotal.Inc() @@ -211,7 +220,7 @@ func startClientKeepalive(ctx context.Context, cancel context.CancelFunc, } func doLogs(ctx ssh.Context, log *slog.Logger, s ssh.Session, deployment, - container string, follow bool, tailLines int64, c *k8s.Client) { + container string, follow bool, tailLines int64, c K8SAPIService) { // Wrap the ssh.Context so we can cancel goroutines started from this // function without affecting the SSH session. childCtx, cancel := context.WithCancel(ctx) @@ -244,7 +253,7 @@ func doLogs(ctx ssh.Context, log *slog.Logger, s ssh.Session, deployment, } func doExec(ctx ssh.Context, log *slog.Logger, s ssh.Session, deployment, - container string, cmd []string, c *k8s.Client, pty bool, + container string, cmd []string, c K8SAPIService, pty bool, winch <-chan ssh.Window) { err := c.Exec(ctx, s.User(), deployment, container, cmd, s, s.Stderr(), pty, winch) diff --git a/internal/sshserver/sessionhandler_test.go b/internal/sshserver/sessionhandler_test.go new file mode 100644 index 00000000..36cdbe61 --- /dev/null +++ b/internal/sshserver/sessionhandler_test.go @@ -0,0 +1,170 @@ +package sshserver_test + +import ( + "context" + "log/slog" + "os" + "testing" + + "github.com/gliderlabs/ssh" + "github.com/uselagoon/ssh-portal/internal/mock" + "github.com/uselagoon/ssh-portal/internal/sshserver" + "go.uber.org/mock/gomock" +) + +func TestExec(t *testing.T) { + log := slog.New(slog.NewJSONHandler(os.Stderr, nil)) + var testCases = map[string]struct { + user string + deployment string + rawCommand []string + command []string + sftp bool + logAccessEnabled bool + pty bool + }{ + "bare interactive shell": { + user: "project-test", + deployment: "cli", + rawCommand: nil, + command: []string{"sh"}, + sftp: false, + logAccessEnabled: false, + pty: true, + }, + "non-interactive id command": { + user: "project-test", + deployment: "cli", + rawCommand: []string{"id"}, + command: []string{"sh", "-c", "id"}, + sftp: false, + logAccessEnabled: false, + pty: false, + }, + } + for name, tc := range testCases { + t.Run(name, func(tt *testing.T) { + // set up mocks + ctrl := gomock.NewController(tt) + k8sService := mock.NewMockK8SAPIService(ctrl) + sshSession := mock.NewMockSession(ctrl) + sshContext := mock.NewMockContext(ctrl) + // configure callback + callback := sshserver.SessionHandler( + log, + k8sService, + tc.sftp, + tc.logAccessEnabled, + ) + // configure mocks + sshSession.EXPECT().Context().Return(sshContext) + sshContext.EXPECT().SessionID().Return("test_session_id") + sshSession.EXPECT().Command().Return(tc.rawCommand).AnyTimes() + sshSession.EXPECT().Subsystem().Return("") + sshSession.EXPECT().User().Return(tc.user).AnyTimes() + k8sService.EXPECT().FindDeployment( + sshContext, + tc.user, + tc.deployment, + ).Return(tc.deployment, nil) + sshContext.EXPECT().Value(sshserver.CtxKey(0)).Return(0) + sshContext.EXPECT().Value(sshserver.CtxKey(1)).Return("test") + sshContext.EXPECT().Value(sshserver.CtxKey(2)).Return(0) + sshContext.EXPECT().Value(sshserver.CtxKey(3)).Return("project") + sshContext.EXPECT().Value(sshserver.CtxKey(4)).Return("fingerprint") + winch := make(<-chan ssh.Window) + sshSession.EXPECT().Pty().Return(ssh.Pty{}, winch, tc.pty) + sshSession.EXPECT().Stderr().Return(os.Stderr) + k8sService.EXPECT().Exec( + sshContext, + tc.user, + tc.deployment, + "", + tc.command, + sshSession, + os.Stderr, + tc.pty, + winch, + ).Return(nil) + // execute callback + callback(sshSession) + }) + } +} + +func TestLogs(t *testing.T) { + log := slog.New(slog.NewJSONHandler(os.Stderr, nil)) + var testCases = map[string]struct { + user string + deployment string + rawCommand []string + command []string + sftp bool + logAccessEnabled bool + pty bool + follow bool + taillines int64 + }{ + "nginx logs": { + user: "project-test", + deployment: "nginx", + rawCommand: []string{"service=nginx", "logs=tailLines=10"}, + command: nil, + sftp: false, + logAccessEnabled: true, + pty: false, + follow: false, + taillines: 10, + }, + } + for name, tc := range testCases { + t.Run(name, func(tt *testing.T) { + // set up mocks + ctrl := gomock.NewController(tt) + k8sService := mock.NewMockK8SAPIService(ctrl) + sshSession := mock.NewMockSession(ctrl) + sshContext := mock.NewMockContext(ctrl) + // configure callback + callback := sshserver.SessionHandler( + log, + k8sService, + tc.sftp, + tc.logAccessEnabled, + ) + // configure mocks + sshSession.EXPECT().Context().Return(sshContext) + sshContext.EXPECT().SessionID().Return("test_session_id") + sshSession.EXPECT().Command().Return(tc.rawCommand).AnyTimes() + sshSession.EXPECT().Subsystem().Return("") + sshSession.EXPECT().User().Return(tc.user).AnyTimes() + k8sService.EXPECT().FindDeployment( + sshContext, + tc.user, + tc.deployment, + ).Return(tc.deployment, nil) + sshContext.EXPECT().Value(sshserver.CtxKey(0)).Return(0) + sshContext.EXPECT().Value(sshserver.CtxKey(1)).Return("test") + sshContext.EXPECT().Value(sshserver.CtxKey(2)).Return(0) + sshContext.EXPECT().Value(sshserver.CtxKey(3)).Return("project") + sshContext.EXPECT().Value(sshserver.CtxKey(4)).Return("fingerprint") + + // this call is executed by context.WithCancel() + sshContext.EXPECT().Value(gomock.Any()).Return(nil).Times(4) + + sshContext.EXPECT().Done().Return(make(<-chan struct{})).AnyTimes() + childCtx, cancel := context.WithCancel(sshContext) + defer cancel() + k8sService.EXPECT().Logs( + childCtx, + tc.user, + tc.deployment, + "", + tc.follow, + tc.taillines, + sshSession, + ).Return(nil) + // execute callback + callback(sshSession) + }) + } +}