diff --git a/go/cmd/dolt/commands/sqlserver/server.go b/go/cmd/dolt/commands/sqlserver/server.go index 98ae550f9e..2424ffd33d 100644 --- a/go/cmd/dolt/commands/sqlserver/server.go +++ b/go/cmd/dolt/commands/sqlserver/server.go @@ -27,6 +27,7 @@ import ( "github.com/dolthub/go-mysql-server/server" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/plan" "github.com/dolthub/go-mysql-server/sql/types" "github.com/dolthub/vitess/go/mysql" "github.com/prometheus/client_golang/prometheus/promhttp" @@ -242,6 +243,7 @@ func Serve( var remoteSrv *remotesrv.Server if serverConfig.RemotesapiPort() != nil { + port := *serverConfig.RemotesapiPort() if remoteSrvSqlCtx, err := sqlEngine.NewDefaultContext(ctx); err == nil { listenaddr := fmt.Sprintf(":%d", port) @@ -251,7 +253,11 @@ func Serve( HttpListenAddr: listenaddr, GrpcListenAddr: listenaddr, }) - args = sqle.WithUserPasswordAuth(args, remotesrv.UserAuth{User: serverConfig.User(), Password: serverConfig.Password()}) + + authen := newAuthenticator(remoteSrvSqlCtx, serverConfig, sqlEngine) + shouldAuth := len(serverConfig.User()) > 0 || len(serverConfig.Password()) > 0 + args = sqle.WithUserPasswordAuth(args, &authen, shouldAuth) + args.TLSConfig = serverConf.TLSConfig remoteSrv, err = remotesrv.NewServer(args) if err != nil { @@ -357,6 +363,30 @@ func Serve( return } +type remotesapiAuth struct { + ctx *sql.Context + serverConfig ServerConfig + sqlEngine *engine.SqlEngine +} + +func newAuthenticator(ctx *sql.Context, serverConfig ServerConfig, sqlEngine *engine.SqlEngine) remotesrv.Authenticator { + return &remotesapiAuth{ctx, serverConfig, sqlEngine} +} + +func (r *remotesapiAuth) Authenticate(ctx context.Context, creds *remotesrv.RequestCredentials) bool { + if creds == nil { + return true + } + if r.serverConfig.User() == creds.Username { + return r.serverConfig.Password() == creds.Password + } + + r.ctx.Session.SetClient(sql.Client{User: creds.Username, Address: "localhost", Capabilities: 0}) + + privOp := sql.NewDynamicPrivilegedOperation(plan.DynamicPrivilege_CloneAdmin) + return r.sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb.UserHasPrivileges(r.ctx, privOp) +} + func LoadClusterTLSConfig(cfg cluster.Config) (*tls.Config, error) { rcfg := cfg.RemotesAPIConfig() if rcfg.TLSKey() == "" && rcfg.TLSCert() == "" { diff --git a/go/libraries/doltcore/remotesrv/grpc.go b/go/libraries/doltcore/remotesrv/grpc.go index 024fc4abab..283a10fc28 100644 --- a/go/libraries/doltcore/remotesrv/grpc.go +++ b/go/libraries/doltcore/remotesrv/grpc.go @@ -504,6 +504,7 @@ func (rs *RemoteChunkStore) GetRepoMetadata(ctx context.Context, req *remotesapi if err := ValidateGetRepoMetadataRequest(req); err != nil { return nil, status.Error(codes.InvalidArgument, err.Error()) } + repoPath := getRepoPath(req) logger = logger.WithField(RepoPathField, repoPath) defer func() { logger.Info("finished") }() diff --git a/go/libraries/doltcore/remotesrv/interceptors.go b/go/libraries/doltcore/remotesrv/interceptors.go index b17c591ba9..affab35663 100644 --- a/go/libraries/doltcore/remotesrv/interceptors.go +++ b/go/libraries/doltcore/remotesrv/interceptors.go @@ -16,9 +16,7 @@ package remotesrv import ( "context" - "crypto/subtle" "encoding/base64" - "fmt" "strings" "github.com/sirupsen/logrus" @@ -28,14 +26,19 @@ import ( "google.golang.org/grpc/status" ) -type UserAuth struct { - User string +type RequestCredentials struct { + Username string Password string } type ServerInterceptor struct { - Lgr *logrus.Entry - ExpectedUserAuth UserAuth + Lgr *logrus.Entry + Authenticator *Authenticator + ShouldAuth bool +} + +type Authenticator interface { + Authenticate(ctx context.Context, creds *RequestCredentials) bool } func (si *ServerInterceptor) Stream() grpc.StreamServerInterceptor { @@ -66,9 +69,10 @@ func (si *ServerInterceptor) Options() []grpc.ServerOption { } func (si *ServerInterceptor) authenticate(ctx context.Context) error { - if len(si.ExpectedUserAuth.User) == 0 && len(si.ExpectedUserAuth.Password) == 0 { + if !si.ShouldAuth || si.Authenticator == nil { return nil } + if md, ok := metadata.FromIncomingContext(ctx); ok { auths := md.Get("authorization") if len(auths) != 1 { @@ -86,14 +90,13 @@ func (si *ServerInterceptor) authenticate(ctx context.Context) error { si.Lgr.Infof("incoming request authorization header failed to decode: %v", err) return status.Error(codes.Unauthenticated, "unauthenticated") } - uExp := []byte(fmt.Sprintf("%s:%s", si.ExpectedUserAuth.User, si.ExpectedUserAuth.Password)) - compare := subtle.ConstantTimeCompare(uDec, uExp) - if compare == 0 { - - si.Lgr.Infof("incoming request authorization header failed to match") + userPass := strings.Split(string(uDec), ":") + authen := *si.Authenticator + if authed := authen.Authenticate(ctx, &RequestCredentials{Username: userPass[0], Password: userPass[1]}); !authed { return status.Error(codes.Unauthenticated, "unauthenticated") } return nil } + return status.Error(codes.Unauthenticated, "unauthenticated") } diff --git a/go/libraries/doltcore/sqle/dsess/session.go b/go/libraries/doltcore/sqle/dsess/session.go index 007a88ec91..a1feb56f65 100644 --- a/go/libraries/doltcore/sqle/dsess/session.go +++ b/go/libraries/doltcore/sqle/dsess/session.go @@ -45,7 +45,7 @@ const ( DbRevisionDelimiter = "/" ) -var ErrSessionNotPeristable = errors.New("session is not persistable") +var ErrSessionNotPersistable = errors.New("session is not persistable") // DoltSession is the sql.Session implementation used by dolt. It is accessible through a *sql.Context instance type DoltSession struct { @@ -1383,7 +1383,7 @@ func (d DoltSession) WithGlobals(conf config.ReadWriteConfig) *DoltSession { // PersistGlobal implements sql.PersistableSession func (d *DoltSession) PersistGlobal(sysVarName string, value interface{}) error { if d.globalsConf == nil { - return ErrSessionNotPeristable + return ErrSessionNotPersistable } sysVar, _, err := validatePersistableSysVar(sysVarName) @@ -1399,7 +1399,7 @@ func (d *DoltSession) PersistGlobal(sysVarName string, value interface{}) error // RemovePersistedGlobal implements sql.PersistableSession func (d *DoltSession) RemovePersistedGlobal(sysVarName string) error { if d.globalsConf == nil { - return ErrSessionNotPeristable + return ErrSessionNotPersistable } sysVar, _, err := validatePersistableSysVar(sysVarName) @@ -1415,7 +1415,7 @@ func (d *DoltSession) RemovePersistedGlobal(sysVarName string) error { // RemoveAllPersistedGlobals implements sql.PersistableSession func (d *DoltSession) RemoveAllPersistedGlobals() error { if d.globalsConf == nil { - return ErrSessionNotPeristable + return ErrSessionNotPersistable } allVars := make([]string, d.globalsConf.Size()) @@ -1434,7 +1434,7 @@ func (d *DoltSession) RemoveAllPersistedGlobals() error { // RemoveAllPersistedGlobals implements sql.PersistableSession func (d *DoltSession) GetPersistedValue(k string) (interface{}, error) { if d.globalsConf == nil { - return nil, ErrSessionNotPeristable + return nil, ErrSessionNotPersistable } return getPersistedValue(d.globalsConf, k) @@ -1443,7 +1443,7 @@ func (d *DoltSession) GetPersistedValue(k string) (interface{}, error) { // SystemVariablesInConfig returns a list of System Variables associated with the session func (d *DoltSession) SystemVariablesInConfig() ([]sql.SystemVariable, error) { if d.globalsConf == nil { - return nil, ErrSessionNotPeristable + return nil, ErrSessionNotPersistable } sysVars, _, err := SystemVariablesInConfig(d.globalsConf) if err != nil { diff --git a/go/libraries/doltcore/sqle/remotesrv.go b/go/libraries/doltcore/sqle/remotesrv.go index c788695bf3..109fe40281 100644 --- a/go/libraries/doltcore/sqle/remotesrv.go +++ b/go/libraries/doltcore/sqle/remotesrv.go @@ -47,6 +47,7 @@ func (s remotesrvStore) Get(path, nbfVerStr string) (remotesrv.RemoteSrvStore, e return nil, err } } + sdb, ok := db.(dsess.SqlDatabase) if !ok { return nil, remotesrv.ErrUnimplemented @@ -67,10 +68,11 @@ func RemoteSrvServerArgs(ctx *sql.Context, args remotesrv.ServerArgs) remotesrv. return args } -func WithUserPasswordAuth(args remotesrv.ServerArgs, userAuth remotesrv.UserAuth) remotesrv.ServerArgs { +func WithUserPasswordAuth(args remotesrv.ServerArgs, auth *remotesrv.Authenticator, shouldAuth bool) remotesrv.ServerArgs { si := remotesrv.ServerInterceptor{ - Lgr: args.Logger, - ExpectedUserAuth: userAuth, + Lgr: args.Logger, + Authenticator: auth, + ShouldAuth: shouldAuth, } args.Options = append(args.Options, si.Options()...) return args diff --git a/integration-tests/bats/sql-server-remotesrv.bats b/integration-tests/bats/sql-server-remotesrv.bats index aa70568790..c978eead1e 100644 --- a/integration-tests/bats/sql-server-remotesrv.bats +++ b/integration-tests/bats/sql-server-remotesrv.bats @@ -161,11 +161,10 @@ SQL [[ "$output" =~ "| 5 " ]] || false } -@test "sql-server-remotesrv: can read from sql-server with --remotesapi-port with clone/fetch/pull authentication" { +@test "sql-server-remotesrv: clone/fetch/pull from remotesapi port with authentication" { mkdir remote cd remote dolt init - dolt --privilege-file=privs.json sql -q "CREATE USER user IDENTIFIED BY 'pass0'" dolt sql -q 'create table vals (i int);' dolt sql -q 'insert into vals (i) values (1), (2), (3), (4), (5);' dolt add vals @@ -233,7 +232,86 @@ SQL [[ "$output" =~ "11" ]] || false } -@test "sql-server-remotesrv: dolt clone without authentication errors" { +@test "sql-server-remotesrv: clone/fetch/pull from remotesapi port with clone_admin authentication" { + mkdir remote + cd remote + dolt init + dolt sql -q 'create table vals (i int);' + dolt sql -q 'insert into vals (i) values (1), (2), (3), (4), (5);' + dolt add vals + dolt commit -m 'initial vals.' + + dolt sql-server --port 3307 -u user0 -p pass0 --remotesapi-port 50051 & + srv_pid=$! + sleep 2 # wait for server to start so we don't lock it out + + run dolt sql-client --port 3307 -u user0 -p pass0 <