Skip to content

Commit

Permalink
enhanced authBackend to support authReadTx
Browse files Browse the repository at this point in the history
  • Loading branch information
ahrtr committed Apr 2, 2022
1 parent dad9b9d commit a7ac307
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 57 deletions.
4 changes: 2 additions & 2 deletions server/auth/range_perm_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import (
"go.uber.org/zap"
)

func getMergedPerms(tx AuthBatchTx, userName string) *unifiedRangePermissions {
func getMergedPerms(tx AuthReadTx, userName string) *unifiedRangePermissions {
user := tx.UnsafeGetUser(userName)
if user == nil {
return nil
Expand Down Expand Up @@ -103,7 +103,7 @@ func checkKeyPoint(lg *zap.Logger, cachedPerms *unifiedRangePermissions, key []b
return false
}

func (as *authStore) isRangeOpPermitted(tx AuthBatchTx, userName string, key, rangeEnd []byte, permtyp authpb.Permission_Type) bool {
func (as *authStore) isRangeOpPermitted(tx AuthReadTx, userName string, key, rangeEnd []byte, permtyp authpb.Permission_Type) bool {
// assumption: tx is Lock()ed
_, ok := as.rangePermCache[userName]
if !ok {
Expand Down
12 changes: 9 additions & 3 deletions server/auth/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ type TokenProvider interface {
type AuthBackend interface {
CreateAuthBuckets()
ForceCommit()
ReadTx() AuthReadTx
BatchTx() AuthBatchTx

GetUser(string) *authpb.User
Expand Down Expand Up @@ -345,7 +346,7 @@ func (as *authStore) CheckPassword(username, password string) (uint64, error) {
// CompareHashAndPassword is very expensive, so we use closures
// to avoid putting it in the critical section of the tx lock.
revision, err := func() (uint64, error) {
tx := as.be.BatchTx()
tx := as.be.ReadTx()
tx.Lock()
defer tx.Unlock()

Expand Down Expand Up @@ -855,7 +856,7 @@ func (as *authStore) isOpPermitted(userName string, revision uint64, key, rangeE
return ErrAuthOldRevision
}

tx := as.be.BatchTx()
tx := as.be.ReadTx()
tx.Lock()
defer tx.Unlock()

Expand Down Expand Up @@ -897,7 +898,10 @@ func (as *authStore) IsAdminPermitted(authInfo *AuthInfo) error {
return ErrUserEmpty
}

u := as.be.GetUser(authInfo.Username)
tx := as.be.ReadTx()
tx.Lock()
defer tx.Unlock()
u := tx.UnsafeGetUser(authInfo.Username)

if u == nil {
return ErrUserNotFound
Expand Down Expand Up @@ -935,6 +939,8 @@ func NewAuthStore(lg *zap.Logger, be AuthBackend, tp TokenProvider, bcryptCost i

be.CreateAuthBuckets()
tx := be.BatchTx()
// We should call LockWithoutHook here, but the txPostLockHoos isn't set
// to EtcdServer yet, so it's OK.
tx.Lock()
enabled := tx.UnsafeReadAuthEnabled()
as := &authStore{
Expand Down
4 changes: 4 additions & 0 deletions server/auth/store_mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ func (b *backendMock) CreateAuthBuckets() {
func (b *backendMock) ForceCommit() {
}

func (b *backendMock) ReadTx() AuthReadTx {
return &txMock{be: b}
}

func (b *backendMock) BatchTx() AuthBatchTx {
return &txMock{be: b}
}
Expand Down
2 changes: 1 addition & 1 deletion server/etcdserver/cindex/cindex.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func (ci *consistentIndex) UnsafeConsistentIndex() uint64 {
return index
}

v, term := schema.UnsafeReadConsistentIndex(ci.be.BatchTx())
v, term := schema.UnsafeReadConsistentIndex(ci.be.ReadTx())
ci.SetConsistentIndex(v, term)
return v
}
Expand Down
5 changes: 4 additions & 1 deletion server/etcdserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,6 @@ func NewServer(cfg config.ServerConfig) (srv *EtcdServer, err error) {
srv.applyV2 = NewApplierV2(cfg.Logger, srv.v2store, srv.cluster)

srv.be = b.storage.backend.be
srv.be.SetTxPostLockHook(srv.getTxPostLockHook())
srv.beHooks = b.storage.backend.beHooks
minTTL := time.Duration((3*cfg.ElectionTicks)/2) * heartbeat

Expand Down Expand Up @@ -404,6 +403,10 @@ func NewServer(cfg config.ServerConfig) (srv *EtcdServer, err error) {
})
}

// Set the hook after EtcdServer finishes the initialization to avoid
// the hook being called during the initialization process.
srv.be.SetTxPostLockHook(srv.getTxPostLockHook())

// TODO: move transport initialization near the definition of remote
tr := &rafthttp.Transport{
Logger: cfg.Logger,
Expand Down
38 changes: 33 additions & 5 deletions server/storage/schema/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,25 @@ func (abe *authBackend) ForceCommit() {
abe.be.ForceCommit()
}

func (abe *authBackend) ReadTx() auth.AuthReadTx {
return &authReadTx{tx: abe.be.ReadTx(), lg: abe.lg}
}

func (abe *authBackend) BatchTx() auth.AuthBatchTx {
return &authBatchTx{tx: abe.be.BatchTx(), lg: abe.lg}
}

type authReadTx struct {
tx backend.ReadTx
lg *zap.Logger
}

type authBatchTx struct {
tx backend.BatchTx
lg *zap.Logger
}

var _ auth.AuthReadTx = (*authReadTx)(nil)
var _ auth.AuthBatchTx = (*authBatchTx)(nil)

func (atx *authBatchTx) UnsafeSaveAuthEnabled(enabled bool) {
Expand All @@ -86,6 +96,24 @@ func (atx *authBatchTx) UnsafeSaveAuthRevision(rev uint64) {
}

func (atx *authBatchTx) UnsafeReadAuthEnabled() bool {
arx := &authReadTx{tx: atx.tx, lg: atx.lg}
return arx.UnsafeReadAuthEnabled()
}

func (atx *authBatchTx) UnsafeReadAuthRevision() uint64 {
arx := &authReadTx{tx: atx.tx, lg: atx.lg}
return arx.UnsafeReadAuthRevision()
}

func (atx *authBatchTx) Lock() {
atx.tx.Lock()
}

func (atx *authBatchTx) Unlock() {
atx.tx.Unlock()
}

func (atx *authReadTx) UnsafeReadAuthEnabled() bool {
_, vs := atx.tx.UnsafeRange(Auth, AuthEnabledKeyName, nil, 0)
if len(vs) == 1 {
if bytes.Equal(vs[0], authEnabled) {
Expand All @@ -95,7 +123,7 @@ func (atx *authBatchTx) UnsafeReadAuthEnabled() bool {
return false
}

func (atx *authBatchTx) UnsafeReadAuthRevision() uint64 {
func (atx *authReadTx) UnsafeReadAuthRevision() uint64 {
_, vs := atx.tx.UnsafeRange(Auth, AuthRevisionKeyName, nil, 0)
if len(vs) != 1 {
// this can happen in the initialization phase
Expand All @@ -104,10 +132,10 @@ func (atx *authBatchTx) UnsafeReadAuthRevision() uint64 {
return binary.BigEndian.Uint64(vs[0])
}

func (atx *authBatchTx) Lock() {
atx.tx.Lock()
func (atx *authReadTx) Lock() {
atx.tx.RLock()
}

func (atx *authBatchTx) Unlock() {
atx.tx.Unlock()
func (atx *authReadTx) Unlock() {
atx.tx.RUnlock()
}
60 changes: 35 additions & 25 deletions server/storage/schema/auth_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,40 @@ func (abe *authBackend) GetRole(roleName string) *authpb.Role {
}

func (atx *authBatchTx) UnsafeGetRole(roleName string) *authpb.Role {
arx := &authReadTx{tx: atx.tx, lg: atx.lg}
return arx.UnsafeGetRole(roleName)
}

func (abe *authBackend) GetAllRoles() []*authpb.Role {
tx := abe.BatchTx()
tx.Lock()
defer tx.Unlock()
return tx.UnsafeGetAllRoles()
}

func (atx *authBatchTx) UnsafeGetAllRoles() []*authpb.Role {
arx := &authReadTx{tx: atx.tx, lg: atx.lg}
return arx.UnsafeGetAllRoles()
}

func (atx *authBatchTx) UnsafePutRole(role *authpb.Role) {
b, err := role.Marshal()
if err != nil {
atx.lg.Panic(
"failed to marshal 'authpb.Role'",
zap.String("role-name", string(role.Name)),
zap.Error(err),
)
}

atx.tx.UnsafePut(AuthRoles, role.Name, b)
}

func (atx *authBatchTx) UnsafeDeleteRole(rolename string) {
atx.tx.UnsafeDelete(AuthRoles, []byte(rolename))
}

func (atx *authReadTx) UnsafeGetRole(roleName string) *authpb.Role {
_, vs := atx.tx.UnsafeRange(AuthRoles, []byte(roleName), nil, 0)
if len(vs) == 0 {
return nil
Expand All @@ -45,14 +79,7 @@ func (atx *authBatchTx) UnsafeGetRole(roleName string) *authpb.Role {
return role
}

func (abe *authBackend) GetAllRoles() []*authpb.Role {
tx := abe.BatchTx()
tx.Lock()
defer tx.Unlock()
return tx.UnsafeGetAllRoles()
}

func (atx *authBatchTx) UnsafeGetAllRoles() []*authpb.Role {
func (atx *authReadTx) UnsafeGetAllRoles() []*authpb.Role {
_, vs := atx.tx.UnsafeRange(AuthRoles, []byte{0}, []byte{0xff}, -1)
if len(vs) == 0 {
return nil
Expand All @@ -69,20 +96,3 @@ func (atx *authBatchTx) UnsafeGetAllRoles() []*authpb.Role {
}
return roles
}

func (atx *authBatchTx) UnsafePutRole(role *authpb.Role) {
b, err := role.Marshal()
if err != nil {
atx.lg.Panic(
"failed to marshal 'authpb.Role'",
zap.String("role-name", string(role.Name)),
zap.Error(err),
)
}

atx.tx.UnsafePut(AuthRoles, role.Name, b)
}

func (atx *authBatchTx) UnsafeDeleteRole(rolename string) {
atx.tx.UnsafeDelete(AuthRoles, []byte(rolename))
}
50 changes: 30 additions & 20 deletions server/storage/schema/auth_users.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,35 @@ func (abe *authBackend) GetUser(username string) *authpb.User {
}

func (atx *authBatchTx) UnsafeGetUser(username string) *authpb.User {
arx := &authReadTx{tx: atx.tx, lg: atx.lg}
return arx.UnsafeGetUser(username)
}

func (abe *authBackend) GetAllUsers() []*authpb.User {
tx := abe.BatchTx()
tx.Lock()
defer tx.Unlock()
return tx.UnsafeGetAllUsers()
}

func (atx *authBatchTx) UnsafeGetAllUsers() []*authpb.User {
arx := &authReadTx{tx: atx.tx, lg: atx.lg}
return arx.UnsafeGetAllUsers()
}

func (atx *authBatchTx) UnsafePutUser(user *authpb.User) {
b, err := user.Marshal()
if err != nil {
atx.lg.Panic("failed to unmarshal 'authpb.User'", zap.Error(err))
}
atx.tx.UnsafePut(AuthUsers, user.Name, b)
}

func (atx *authBatchTx) UnsafeDeleteUser(username string) {
atx.tx.UnsafeDelete(AuthUsers, []byte(username))
}

func (atx *authReadTx) UnsafeGetUser(username string) *authpb.User {
_, vs := atx.tx.UnsafeRange(AuthUsers, []byte(username), nil, 0)
if len(vs) == 0 {
return nil
Expand All @@ -44,14 +73,7 @@ func (atx *authBatchTx) UnsafeGetUser(username string) *authpb.User {
return user
}

func (abe *authBackend) GetAllUsers() []*authpb.User {
tx := abe.BatchTx()
tx.Lock()
defer tx.Unlock()
return tx.UnsafeGetAllUsers()
}

func (atx *authBatchTx) UnsafeGetAllUsers() []*authpb.User {
func (atx *authReadTx) UnsafeGetAllUsers() []*authpb.User {
_, vs := atx.tx.UnsafeRange(AuthUsers, []byte{0}, []byte{0xff}, -1)
if len(vs) == 0 {
return nil
Expand All @@ -68,15 +90,3 @@ func (atx *authBatchTx) UnsafeGetAllUsers() []*authpb.User {
}
return users
}

func (atx *authBatchTx) UnsafePutUser(user *authpb.User) {
b, err := user.Marshal()
if err != nil {
atx.lg.Panic("failed to unmarshal 'authpb.User'", zap.Error(err))
}
atx.tx.UnsafePut(AuthUsers, user.Name, b)
}

func (atx *authBatchTx) UnsafeDeleteUser(username string) {
atx.tx.UnsafeDelete(AuthUsers, []byte(username))
}

0 comments on commit a7ac307

Please sign in to comment.