Skip to content

Commit

Permalink
Support clone_admin grants for authenticating to remotesapi
Browse files Browse the repository at this point in the history
  • Loading branch information
tbantle22 committed Jul 21, 2023
1 parent 61acf32 commit e429857
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 26 deletions.
32 changes: 31 additions & 1 deletion go/cmd/dolt/commands/sqlserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (

"github.com/dolthub/go-mysql-server/server"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/plan"
"github.com/dolthub/go-mysql-server/sql/types"
"github.com/dolthub/vitess/go/mysql"
"github.com/prometheus/client_golang/prometheus/promhttp"
Expand Down Expand Up @@ -242,6 +243,7 @@ func Serve(

var remoteSrv *remotesrv.Server
if serverConfig.RemotesapiPort() != nil {

port := *serverConfig.RemotesapiPort()
if remoteSrvSqlCtx, err := sqlEngine.NewDefaultContext(ctx); err == nil {
listenaddr := fmt.Sprintf(":%d", port)
Expand All @@ -251,7 +253,11 @@ func Serve(
HttpListenAddr: listenaddr,
GrpcListenAddr: listenaddr,
})
args = sqle.WithUserPasswordAuth(args, remotesrv.UserAuth{User: serverConfig.User(), Password: serverConfig.Password()})

authen := newAuthenticator(remoteSrvSqlCtx, serverConfig, sqlEngine)
shouldAuth := len(serverConfig.User()) > 0 || len(serverConfig.Password()) > 0
args = sqle.WithUserPasswordAuth(args, &authen, shouldAuth)

args.TLSConfig = serverConf.TLSConfig
remoteSrv, err = remotesrv.NewServer(args)
if err != nil {
Expand Down Expand Up @@ -357,6 +363,30 @@ func Serve(
return
}

type remotesapiAuth struct {
ctx *sql.Context
serverConfig ServerConfig
sqlEngine *engine.SqlEngine
}

func newAuthenticator(ctx *sql.Context, serverConfig ServerConfig, sqlEngine *engine.SqlEngine) remotesrv.Authenticator {
return &remotesapiAuth{ctx, serverConfig, sqlEngine}
}

func (r *remotesapiAuth) Authenticate(ctx context.Context, creds *remotesrv.RequestCredentials) bool {
if creds == nil {
return true
}
if r.serverConfig.User() == creds.Username {
return r.serverConfig.Password() == creds.Password
}

r.ctx.Session.SetClient(sql.Client{User: creds.Username, Address: "localhost", Capabilities: 0})

privOp := sql.NewDynamicPrivilegedOperation(plan.DynamicPrivilege_CloneAdmin)
return r.sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb.UserHasPrivileges(r.ctx, privOp)
}

func LoadClusterTLSConfig(cfg cluster.Config) (*tls.Config, error) {
rcfg := cfg.RemotesAPIConfig()
if rcfg.TLSKey() == "" && rcfg.TLSCert() == "" {
Expand Down
1 change: 1 addition & 0 deletions go/libraries/doltcore/remotesrv/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,7 @@ func (rs *RemoteChunkStore) GetRepoMetadata(ctx context.Context, req *remotesapi
if err := ValidateGetRepoMetadataRequest(req); err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}

repoPath := getRepoPath(req)
logger = logger.WithField(RepoPathField, repoPath)
defer func() { logger.Info("finished") }()
Expand Down
27 changes: 15 additions & 12 deletions go/libraries/doltcore/remotesrv/interceptors.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ package remotesrv

import (
"context"
"crypto/subtle"
"encoding/base64"
"fmt"
"strings"

"github.com/sirupsen/logrus"
Expand All @@ -28,14 +26,19 @@ import (
"google.golang.org/grpc/status"
)

type UserAuth struct {
User string
type RequestCredentials struct {
Username string
Password string
}

type ServerInterceptor struct {
Lgr *logrus.Entry
ExpectedUserAuth UserAuth
Lgr *logrus.Entry
Authenticator *Authenticator
ShouldAuth bool
}

type Authenticator interface {
Authenticate(ctx context.Context, creds *RequestCredentials) bool
}

func (si *ServerInterceptor) Stream() grpc.StreamServerInterceptor {
Expand Down Expand Up @@ -66,9 +69,10 @@ func (si *ServerInterceptor) Options() []grpc.ServerOption {
}

func (si *ServerInterceptor) authenticate(ctx context.Context) error {
if len(si.ExpectedUserAuth.User) == 0 && len(si.ExpectedUserAuth.Password) == 0 {
if !si.ShouldAuth || si.Authenticator == nil {
return nil
}

if md, ok := metadata.FromIncomingContext(ctx); ok {
auths := md.Get("authorization")
if len(auths) != 1 {
Expand All @@ -86,14 +90,13 @@ func (si *ServerInterceptor) authenticate(ctx context.Context) error {
si.Lgr.Infof("incoming request authorization header failed to decode: %v", err)
return status.Error(codes.Unauthenticated, "unauthenticated")
}
uExp := []byte(fmt.Sprintf("%s:%s", si.ExpectedUserAuth.User, si.ExpectedUserAuth.Password))
compare := subtle.ConstantTimeCompare(uDec, uExp)
if compare == 0 {

si.Lgr.Infof("incoming request authorization header failed to match")
userPass := strings.Split(string(uDec), ":")
authen := *si.Authenticator
if authed := authen.Authenticate(ctx, &RequestCredentials{Username: userPass[0], Password: userPass[1]}); !authed {
return status.Error(codes.Unauthenticated, "unauthenticated")
}
return nil
}

return status.Error(codes.Unauthenticated, "unauthenticated")
}
12 changes: 6 additions & 6 deletions go/libraries/doltcore/sqle/dsess/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ const (
DbRevisionDelimiter = "/"
)

var ErrSessionNotPeristable = errors.New("session is not persistable")
var ErrSessionNotPersistable = errors.New("session is not persistable")

// DoltSession is the sql.Session implementation used by dolt. It is accessible through a *sql.Context instance
type DoltSession struct {
Expand Down Expand Up @@ -1383,7 +1383,7 @@ func (d DoltSession) WithGlobals(conf config.ReadWriteConfig) *DoltSession {
// PersistGlobal implements sql.PersistableSession
func (d *DoltSession) PersistGlobal(sysVarName string, value interface{}) error {
if d.globalsConf == nil {
return ErrSessionNotPeristable
return ErrSessionNotPersistable
}

sysVar, _, err := validatePersistableSysVar(sysVarName)
Expand All @@ -1399,7 +1399,7 @@ func (d *DoltSession) PersistGlobal(sysVarName string, value interface{}) error
// RemovePersistedGlobal implements sql.PersistableSession
func (d *DoltSession) RemovePersistedGlobal(sysVarName string) error {
if d.globalsConf == nil {
return ErrSessionNotPeristable
return ErrSessionNotPersistable
}

sysVar, _, err := validatePersistableSysVar(sysVarName)
Expand All @@ -1415,7 +1415,7 @@ func (d *DoltSession) RemovePersistedGlobal(sysVarName string) error {
// RemoveAllPersistedGlobals implements sql.PersistableSession
func (d *DoltSession) RemoveAllPersistedGlobals() error {
if d.globalsConf == nil {
return ErrSessionNotPeristable
return ErrSessionNotPersistable
}

allVars := make([]string, d.globalsConf.Size())
Expand All @@ -1434,7 +1434,7 @@ func (d *DoltSession) RemoveAllPersistedGlobals() error {
// RemoveAllPersistedGlobals implements sql.PersistableSession
func (d *DoltSession) GetPersistedValue(k string) (interface{}, error) {
if d.globalsConf == nil {
return nil, ErrSessionNotPeristable
return nil, ErrSessionNotPersistable
}

return getPersistedValue(d.globalsConf, k)
Expand All @@ -1443,7 +1443,7 @@ func (d *DoltSession) GetPersistedValue(k string) (interface{}, error) {
// SystemVariablesInConfig returns a list of System Variables associated with the session
func (d *DoltSession) SystemVariablesInConfig() ([]sql.SystemVariable, error) {
if d.globalsConf == nil {
return nil, ErrSessionNotPeristable
return nil, ErrSessionNotPersistable
}
sysVars, _, err := SystemVariablesInConfig(d.globalsConf)
if err != nil {
Expand Down
8 changes: 5 additions & 3 deletions go/libraries/doltcore/sqle/remotesrv.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ func (s remotesrvStore) Get(path, nbfVerStr string) (remotesrv.RemoteSrvStore, e
return nil, err
}
}

sdb, ok := db.(dsess.SqlDatabase)
if !ok {
return nil, remotesrv.ErrUnimplemented
Expand All @@ -67,10 +68,11 @@ func RemoteSrvServerArgs(ctx *sql.Context, args remotesrv.ServerArgs) remotesrv.
return args
}

func WithUserPasswordAuth(args remotesrv.ServerArgs, userAuth remotesrv.UserAuth) remotesrv.ServerArgs {
func WithUserPasswordAuth(args remotesrv.ServerArgs, auth *remotesrv.Authenticator, shouldAuth bool) remotesrv.ServerArgs {
si := remotesrv.ServerInterceptor{
Lgr: args.Logger,
ExpectedUserAuth: userAuth,
Lgr: args.Logger,
Authenticator: auth,
ShouldAuth: shouldAuth,
}
args.Options = append(args.Options, si.Options()...)
return args
Expand Down
86 changes: 82 additions & 4 deletions integration-tests/bats/sql-server-remotesrv.bats
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,10 @@ SQL
[[ "$output" =~ "| 5 " ]] || false
}

@test "sql-server-remotesrv: can read from sql-server with --remotesapi-port with clone/fetch/pull authentication" {
@test "sql-server-remotesrv: clone/fetch/pull from remotesapi port with authentication" {
mkdir remote
cd remote
dolt init
dolt --privilege-file=privs.json sql -q "CREATE USER user IDENTIFIED BY 'pass0'"
dolt sql -q 'create table vals (i int);'
dolt sql -q 'insert into vals (i) values (1), (2), (3), (4), (5);'
dolt add vals
Expand Down Expand Up @@ -233,7 +232,86 @@ SQL
[[ "$output" =~ "11" ]] || false
}

@test "sql-server-remotesrv: dolt clone without authentication errors" {
@test "sql-server-remotesrv: clone/fetch/pull from remotesapi port with clone_admin authentication" {
mkdir remote
cd remote
dolt init
dolt sql -q 'create table vals (i int);'
dolt sql -q 'insert into vals (i) values (1), (2), (3), (4), (5);'
dolt add vals
dolt commit -m 'initial vals.'

dolt sql-server --port 3307 -u user0 -p pass0 --remotesapi-port 50051 &
srv_pid=$!
sleep 2 # wait for server to start so we don't lock it out

run dolt sql-client --port 3307 -u user0 -p pass0 <<SQL
CREATE USER clone_admin_user@'%' IDENTIFIED BY 'pass1';
GRANT CLONE_ADMIN ON *.* TO clone_admin_user@'%';
select user from mysql.user;
SQL
[ $status -eq 0 ]
[[ $output =~ user0 ]] || false
[[ $output =~ clone_admin_user ]] || false

export DOLT_REMOTE_PASSWORD="pass1"
cd ../
dolt clone http://localhost:50051/remote repo1 -u clone_admin_user
cd repo1
run dolt ls
[[ "$output" =~ "vals" ]] || false
run dolt sql -q 'select count(*) from vals'
[[ "$output" =~ "5" ]] || false

dolt sql-client --port 3307 -u user0 -p pass0 <<SQL
use remote;
call dolt_checkout('-b', 'new_branch');
insert into vals (i) values (6), (7), (8), (9), (10);
call dolt_commit('-am', 'add some vals');
SQL

run dolt branch -v -a
[ "$status" -eq 0 ]
[[ "$output" =~ "remotes/origin/main" ]] || false
[[ ! "$output" =~ "remotes/origin/new_branch" ]] || false

# No auth fetch
run dolt fetch
[[ "$status" != 0 ]] || false
[[ "$output" =~ "Unauthenticated" ]] || false

# # With auth fetch
run dolt fetch -u clone_admin_user
[[ "$status" -eq 0 ]] || false

run dolt branch -v -a
[ "$status" -eq 0 ]
[[ "$output" =~ "remotes/origin/main" ]] || false
[[ "$output" =~ "remotes/origin/new_branch" ]] || false

run dolt checkout new_branch
[[ "$status" -eq 0 ]] || false

dolt sql-client --port 3307 -u user0 -p pass0 <<SQL
use remote;
call dolt_checkout('new_branch');
insert into vals (i) values (11);
call dolt_commit('-am', 'add one val');
SQL

# No auth pull
run dolt pull
[[ "$status" != 0 ]] || false
[[ "$output" =~ "Unauthenticated" ]] || false

# With auth pull
run dolt pull -u clone_admin_user
[[ "$status" -eq 0 ]] || false
run dolt sql -q 'select count(*) from vals;'
[[ "$output" =~ "11" ]] || false
}

@test "sql-server-remotesrv: dolt clone without authentication returns error" {
mkdir remote
cd remote
dolt init
Expand All @@ -254,7 +332,7 @@ SQL
[[ "$output" =~ "Unauthenticated" ]] || false
}

@test "sql-server-remotesrv: dolt clone with incorrect authentication errors" {
@test "sql-server-remotesrv: dolt clone with incorrect authentication returns error" {
mkdir remote
cd remote
dolt init
Expand Down

0 comments on commit e429857

Please sign in to comment.