diff --git a/example_test.go b/example_test.go index 972d3ef..2c3c990 100644 --- a/example_test.go +++ b/example_test.go @@ -1,6 +1,7 @@ package ssh_test import ( + "fmt" "io" "io/ioutil" @@ -21,6 +22,17 @@ func ExamplePasswordAuth() { ) } +func ExamplePasswordAuthE() { + ssh.ListenAndServe(":2222", nil, + ssh.PasswordAuthE(func(ctx ssh.Context, pass string) error { + if pass == "secret" { + return nil + } + return fmt.Errorf("password incorrect") + }), + ) +} + func ExampleNoPty() { ssh.ListenAndServe(":2222", nil, ssh.NoPty()) } diff --git a/options.go b/options.go index 303dcc3..d0591a6 100644 --- a/options.go +++ b/options.go @@ -14,6 +14,14 @@ func PasswordAuth(fn PasswordHandler) Option { } } +// PasswordAuthE returns a functional option that sets PasswordHandlerE on the server. +func PasswordAuthE(fn PasswordHandlerE) Option { + return func(srv *Server) error { + srv.PasswordHandlerE = fn + return nil + } +} + // PublicKeyAuth returns a functional option that sets PublicKeyHandler on the server. func PublicKeyAuth(fn PublicKeyHandler) Option { return func(srv *Server) error { diff --git a/options_test.go b/options_test.go index 23fca5a..2364ae7 100644 --- a/options_test.go +++ b/options_test.go @@ -69,6 +69,57 @@ func TestPasswordAuthBadPass(t *testing.T) { } } +func TestPasswordAuthE(t *testing.T) { + t.Parallel() + testUser := "testuser" + testPass := "testpass" + session, _, cleanup := newTestSessionWithOptions(t, &Server{ + Handler: func(s Session) { + // noop + }, + }, &gossh.ClientConfig{ + User: testUser, + Auth: []gossh.AuthMethod{ + gossh.Password(testPass), + }, + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + }, PasswordAuthE(func(ctx Context, password string) error { + if ctx.User() != testUser { + t.Fatalf("user = %#v; want %#v", ctx.User(), testUser) + } + if password != testPass { + t.Fatalf("user = %#v; want %#v", password, testPass) + } + return nil + })) + defer cleanup() + if err := session.Run(""); err != nil { + t.Fatal(err) + } +} + +func TestPasswordAuthEBadPass(t *testing.T) { + t.Parallel() + l := newLocalListener() + srv := &Server{Handler: func(s Session) {}} + srv.SetOption(PasswordAuthE(func(ctx Context, password string) error { + return nil + })) + go srv.serveOnce(l) + _, err := gossh.Dial("tcp", l.Addr().String(), &gossh.ClientConfig{ + User: "testuser", + Auth: []gossh.AuthMethod{ + gossh.Password("testpass"), + }, + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + }) + if err != nil { + if !strings.Contains(err.Error(), "unable to authenticate") { + t.Fatal(err) + } + } +} + type wrappedConn struct { net.Conn written int32 diff --git a/server.go b/server.go index be4355e..e8ee7c5 100644 --- a/server.go +++ b/server.go @@ -40,6 +40,7 @@ type Server struct { KeyboardInteractiveHandler KeyboardInteractiveHandler // keyboard-interactive authentication handler PasswordHandler PasswordHandler // password authentication handler + PasswordHandlerE PasswordHandlerE // password authentiication handler with error, if it is set, it overrides PasswordHandler PublicKeyHandler PublicKeyHandler // public key authentication handler PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling @@ -126,7 +127,7 @@ func (srv *Server) config(ctx Context) *gossh.ServerConfig { for _, signer := range srv.HostSigners { config.AddHostKey(signer) } - if srv.PasswordHandler == nil && srv.PublicKeyHandler == nil && srv.KeyboardInteractiveHandler == nil { + if srv.PasswordHandler == nil && srv.PasswordHandlerE == nil && srv.PublicKeyHandler == nil && srv.KeyboardInteractiveHandler == nil { config.NoClientAuth = true } if srv.Version != "" { @@ -141,6 +142,13 @@ func (srv *Server) config(ctx Context) *gossh.ServerConfig { return ctx.Permissions().Permissions, nil } } + if srv.PasswordHandlerE != nil { + config.PasswordCallback = func(conn gossh.ConnMetadata, password []byte) (*gossh.Permissions, error) { + applyConnMetadata(ctx, conn) + err := srv.PasswordHandlerE(ctx, string(password)) + return ctx.Permissions().Permissions, err + } + } if srv.PublicKeyHandler != nil { config.PublicKeyCallback = func(conn gossh.ConnMetadata, key gossh.PublicKey) (*gossh.Permissions, error) { applyConnMetadata(ctx, conn) diff --git a/ssh.go b/ssh.go index fbeb150..825d96d 100644 --- a/ssh.go +++ b/ssh.go @@ -41,6 +41,9 @@ type PublicKeyHandler func(ctx Context, key PublicKey) bool // PasswordHandler is a callback for performing password authentication. type PasswordHandler func(ctx Context, password string) bool +// PasswordHandlerE is like PasswordHandler, but returns error +type PasswordHandlerE func(ctx Context, password string) error + // KeyboardInteractiveHandler is a callback for performing keyboard-interactive authentication. type KeyboardInteractiveHandler func(ctx Context, challenger gossh.KeyboardInteractiveChallenge) bool