diff --git a/server/auth/range_perm_cache.go b/server/auth/range_perm_cache.go index bae07ef5242..2ebe5439b58 100644 --- a/server/auth/range_perm_cache.go +++ b/server/auth/range_perm_cache.go @@ -20,7 +20,7 @@ import ( "go.uber.org/zap" ) -func getMergedPerms(tx AuthBatchTx, userName string) *unifiedRangePermissions { +func getMergedPerms(tx AuthReadTx, userName string) *unifiedRangePermissions { user := tx.UnsafeGetUser(userName) if user == nil { return nil @@ -103,7 +103,7 @@ func checkKeyPoint(lg *zap.Logger, cachedPerms *unifiedRangePermissions, key []b return false } -func (as *authStore) isRangeOpPermitted(tx AuthBatchTx, userName string, key, rangeEnd []byte, permtyp authpb.Permission_Type) bool { +func (as *authStore) isRangeOpPermitted(tx AuthReadTx, userName string, key, rangeEnd []byte, permtyp authpb.Permission_Type) bool { // assumption: tx is Lock()ed _, ok := as.rangePermCache[userName] if !ok { diff --git a/server/auth/store.go b/server/auth/store.go index 408b235babd..762caecd780 100644 --- a/server/auth/store.go +++ b/server/auth/store.go @@ -196,6 +196,7 @@ type TokenProvider interface { type AuthBackend interface { CreateAuthBuckets() ForceCommit() + ReadTx() AuthReadTx BatchTx() AuthBatchTx GetUser(string) *authpb.User @@ -345,7 +346,7 @@ func (as *authStore) CheckPassword(username, password string) (uint64, error) { // CompareHashAndPassword is very expensive, so we use closures // to avoid putting it in the critical section of the tx lock. revision, err := func() (uint64, error) { - tx := as.be.BatchTx() + tx := as.be.ReadTx() tx.Lock() defer tx.Unlock() @@ -855,7 +856,7 @@ func (as *authStore) isOpPermitted(userName string, revision uint64, key, rangeE return ErrAuthOldRevision } - tx := as.be.BatchTx() + tx := as.be.ReadTx() tx.Lock() defer tx.Unlock() @@ -897,7 +898,10 @@ func (as *authStore) IsAdminPermitted(authInfo *AuthInfo) error { return ErrUserEmpty } - u := as.be.GetUser(authInfo.Username) + tx := as.be.ReadTx() + tx.Lock() + defer tx.Unlock() + u := tx.UnsafeGetUser(authInfo.Username) if u == nil { return ErrUserNotFound @@ -935,6 +939,8 @@ func NewAuthStore(lg *zap.Logger, be AuthBackend, tp TokenProvider, bcryptCost i be.CreateAuthBuckets() tx := be.BatchTx() + // We should call LockWithoutHook here, but the txPostLockHoos isn't set + // to EtcdServer yet, so it's OK. tx.Lock() enabled := tx.UnsafeReadAuthEnabled() as := &authStore{ diff --git a/server/auth/store_mock_test.go b/server/auth/store_mock_test.go index d49f8dd333f..39c3f6d139a 100644 --- a/server/auth/store_mock_test.go +++ b/server/auth/store_mock_test.go @@ -36,6 +36,10 @@ func (b *backendMock) CreateAuthBuckets() { func (b *backendMock) ForceCommit() { } +func (b *backendMock) ReadTx() AuthReadTx { + return &txMock{be: b} +} + func (b *backendMock) BatchTx() AuthBatchTx { return &txMock{be: b} } diff --git a/server/etcdserver/cindex/cindex.go b/server/etcdserver/cindex/cindex.go index 7ec1b121283..6367967f875 100644 --- a/server/etcdserver/cindex/cindex.go +++ b/server/etcdserver/cindex/cindex.go @@ -89,7 +89,7 @@ func (ci *consistentIndex) UnsafeConsistentIndex() uint64 { return index } - v, term := schema.UnsafeReadConsistentIndex(ci.be.BatchTx()) + v, term := schema.UnsafeReadConsistentIndex(ci.be.ReadTx()) ci.SetConsistentIndex(v, term) return v } diff --git a/server/etcdserver/server.go b/server/etcdserver/server.go index dc36b5cc4b8..73eaa3a7182 100644 --- a/server/etcdserver/server.go +++ b/server/etcdserver/server.go @@ -343,7 +343,6 @@ func NewServer(cfg config.ServerConfig) (srv *EtcdServer, err error) { srv.applyV2 = NewApplierV2(cfg.Logger, srv.v2store, srv.cluster) srv.be = b.storage.backend.be - srv.be.SetTxPostLockHook(srv.getTxPostLockHook()) srv.beHooks = b.storage.backend.beHooks minTTL := time.Duration((3*cfg.ElectionTicks)/2) * heartbeat @@ -404,6 +403,10 @@ func NewServer(cfg config.ServerConfig) (srv *EtcdServer, err error) { }) } + // Set the hook after EtcdServer finishes the initialization to avoid + // the hook being called during the initialization process. + srv.be.SetTxPostLockHook(srv.getTxPostLockHook()) + // TODO: move transport initialization near the definition of remote tr := &rafthttp.Transport{ Logger: cfg.Logger, diff --git a/server/storage/schema/auth.go b/server/storage/schema/auth.go index fc334a8bcf9..3956ca782f9 100644 --- a/server/storage/schema/auth.go +++ b/server/storage/schema/auth.go @@ -60,15 +60,25 @@ func (abe *authBackend) ForceCommit() { abe.be.ForceCommit() } +func (abe *authBackend) ReadTx() auth.AuthReadTx { + return &authReadTx{tx: abe.be.ReadTx(), lg: abe.lg} +} + func (abe *authBackend) BatchTx() auth.AuthBatchTx { return &authBatchTx{tx: abe.be.BatchTx(), lg: abe.lg} } +type authReadTx struct { + tx backend.ReadTx + lg *zap.Logger +} + type authBatchTx struct { tx backend.BatchTx lg *zap.Logger } +var _ auth.AuthReadTx = (*authReadTx)(nil) var _ auth.AuthBatchTx = (*authBatchTx)(nil) func (atx *authBatchTx) UnsafeSaveAuthEnabled(enabled bool) { @@ -86,6 +96,24 @@ func (atx *authBatchTx) UnsafeSaveAuthRevision(rev uint64) { } func (atx *authBatchTx) UnsafeReadAuthEnabled() bool { + arx := &authReadTx{tx: atx.tx, lg: atx.lg} + return arx.UnsafeReadAuthEnabled() +} + +func (atx *authBatchTx) UnsafeReadAuthRevision() uint64 { + arx := &authReadTx{tx: atx.tx, lg: atx.lg} + return arx.UnsafeReadAuthRevision() +} + +func (atx *authBatchTx) Lock() { + atx.tx.Lock() +} + +func (atx *authBatchTx) Unlock() { + atx.tx.Unlock() +} + +func (atx *authReadTx) UnsafeReadAuthEnabled() bool { _, vs := atx.tx.UnsafeRange(Auth, AuthEnabledKeyName, nil, 0) if len(vs) == 1 { if bytes.Equal(vs[0], authEnabled) { @@ -95,7 +123,7 @@ func (atx *authBatchTx) UnsafeReadAuthEnabled() bool { return false } -func (atx *authBatchTx) UnsafeReadAuthRevision() uint64 { +func (atx *authReadTx) UnsafeReadAuthRevision() uint64 { _, vs := atx.tx.UnsafeRange(Auth, AuthRevisionKeyName, nil, 0) if len(vs) != 1 { // this can happen in the initialization phase @@ -104,10 +132,10 @@ func (atx *authBatchTx) UnsafeReadAuthRevision() uint64 { return binary.BigEndian.Uint64(vs[0]) } -func (atx *authBatchTx) Lock() { - atx.tx.Lock() +func (atx *authReadTx) Lock() { + atx.tx.RLock() } -func (atx *authBatchTx) Unlock() { - atx.tx.Unlock() +func (atx *authReadTx) Unlock() { + atx.tx.RUnlock() } diff --git a/server/storage/schema/auth_roles.go b/server/storage/schema/auth_roles.go index 541e37b7191..dfda7ce5b7b 100644 --- a/server/storage/schema/auth_roles.go +++ b/server/storage/schema/auth_roles.go @@ -32,6 +32,40 @@ func (abe *authBackend) GetRole(roleName string) *authpb.Role { } func (atx *authBatchTx) UnsafeGetRole(roleName string) *authpb.Role { + arx := &authReadTx{tx: atx.tx, lg: atx.lg} + return arx.UnsafeGetRole(roleName) +} + +func (abe *authBackend) GetAllRoles() []*authpb.Role { + tx := abe.BatchTx() + tx.Lock() + defer tx.Unlock() + return tx.UnsafeGetAllRoles() +} + +func (atx *authBatchTx) UnsafeGetAllRoles() []*authpb.Role { + arx := &authReadTx{tx: atx.tx, lg: atx.lg} + return arx.UnsafeGetAllRoles() +} + +func (atx *authBatchTx) UnsafePutRole(role *authpb.Role) { + b, err := role.Marshal() + if err != nil { + atx.lg.Panic( + "failed to marshal 'authpb.Role'", + zap.String("role-name", string(role.Name)), + zap.Error(err), + ) + } + + atx.tx.UnsafePut(AuthRoles, role.Name, b) +} + +func (atx *authBatchTx) UnsafeDeleteRole(rolename string) { + atx.tx.UnsafeDelete(AuthRoles, []byte(rolename)) +} + +func (atx *authReadTx) UnsafeGetRole(roleName string) *authpb.Role { _, vs := atx.tx.UnsafeRange(AuthRoles, []byte(roleName), nil, 0) if len(vs) == 0 { return nil @@ -45,14 +79,7 @@ func (atx *authBatchTx) UnsafeGetRole(roleName string) *authpb.Role { return role } -func (abe *authBackend) GetAllRoles() []*authpb.Role { - tx := abe.BatchTx() - tx.Lock() - defer tx.Unlock() - return tx.UnsafeGetAllRoles() -} - -func (atx *authBatchTx) UnsafeGetAllRoles() []*authpb.Role { +func (atx *authReadTx) UnsafeGetAllRoles() []*authpb.Role { _, vs := atx.tx.UnsafeRange(AuthRoles, []byte{0}, []byte{0xff}, -1) if len(vs) == 0 { return nil @@ -69,20 +96,3 @@ func (atx *authBatchTx) UnsafeGetAllRoles() []*authpb.Role { } return roles } - -func (atx *authBatchTx) UnsafePutRole(role *authpb.Role) { - b, err := role.Marshal() - if err != nil { - atx.lg.Panic( - "failed to marshal 'authpb.Role'", - zap.String("role-name", string(role.Name)), - zap.Error(err), - ) - } - - atx.tx.UnsafePut(AuthRoles, role.Name, b) -} - -func (atx *authBatchTx) UnsafeDeleteRole(rolename string) { - atx.tx.UnsafeDelete(AuthRoles, []byte(rolename)) -} diff --git a/server/storage/schema/auth_users.go b/server/storage/schema/auth_users.go index f385afa5122..c3e7a92ff39 100644 --- a/server/storage/schema/auth_users.go +++ b/server/storage/schema/auth_users.go @@ -27,6 +27,35 @@ func (abe *authBackend) GetUser(username string) *authpb.User { } func (atx *authBatchTx) UnsafeGetUser(username string) *authpb.User { + arx := &authReadTx{tx: atx.tx, lg: atx.lg} + return arx.UnsafeGetUser(username) +} + +func (abe *authBackend) GetAllUsers() []*authpb.User { + tx := abe.BatchTx() + tx.Lock() + defer tx.Unlock() + return tx.UnsafeGetAllUsers() +} + +func (atx *authBatchTx) UnsafeGetAllUsers() []*authpb.User { + arx := &authReadTx{tx: atx.tx, lg: atx.lg} + return arx.UnsafeGetAllUsers() +} + +func (atx *authBatchTx) UnsafePutUser(user *authpb.User) { + b, err := user.Marshal() + if err != nil { + atx.lg.Panic("failed to unmarshal 'authpb.User'", zap.Error(err)) + } + atx.tx.UnsafePut(AuthUsers, user.Name, b) +} + +func (atx *authBatchTx) UnsafeDeleteUser(username string) { + atx.tx.UnsafeDelete(AuthUsers, []byte(username)) +} + +func (atx *authReadTx) UnsafeGetUser(username string) *authpb.User { _, vs := atx.tx.UnsafeRange(AuthUsers, []byte(username), nil, 0) if len(vs) == 0 { return nil @@ -44,14 +73,7 @@ func (atx *authBatchTx) UnsafeGetUser(username string) *authpb.User { return user } -func (abe *authBackend) GetAllUsers() []*authpb.User { - tx := abe.BatchTx() - tx.Lock() - defer tx.Unlock() - return tx.UnsafeGetAllUsers() -} - -func (atx *authBatchTx) UnsafeGetAllUsers() []*authpb.User { +func (atx *authReadTx) UnsafeGetAllUsers() []*authpb.User { _, vs := atx.tx.UnsafeRange(AuthUsers, []byte{0}, []byte{0xff}, -1) if len(vs) == 0 { return nil @@ -68,15 +90,3 @@ func (atx *authBatchTx) UnsafeGetAllUsers() []*authpb.User { } return users } - -func (atx *authBatchTx) UnsafePutUser(user *authpb.User) { - b, err := user.Marshal() - if err != nil { - atx.lg.Panic("failed to unmarshal 'authpb.User'", zap.Error(err)) - } - atx.tx.UnsafePut(AuthUsers, user.Name, b) -} - -func (atx *authBatchTx) UnsafeDeleteUser(username string) { - atx.tx.UnsafeDelete(AuthUsers, []byte(username)) -}