diff --git a/management/server/account.go b/management/server/account.go index d5e8c8cf8b..da32038527 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -76,7 +76,7 @@ type AccountManager interface { SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) - GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) + GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error) @@ -478,12 +478,12 @@ func (a *Account) GetPeerNetworkMap( } nm := &NetworkMap{ - Peers: peersToConnect, - Network: a.Network.Copy(), - Routes: routesUpdate, - DNSConfig: dnsUpdate, - OfflinePeers: expiredPeers, - FirewallRules: firewallRules, + Peers: peersToConnect, + Network: a.Network.Copy(), + Routes: routesUpdate, + DNSConfig: dnsUpdate, + OfflinePeers: expiredPeers, + FirewallRules: firewallRules, RoutesFirewallRules: routesFirewallRules, } @@ -843,55 +843,54 @@ func (a *Account) GetPeer(peerID string) *nbpeer.Peer { return a.Peers[peerID] } -// SetJWTGroups updates the user's auto groups by synchronizing JWT groups. -// Returns true if there are changes in the JWT group membership. -func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool { - user, ok := a.Users[userID] - if !ok { - return false - } - +// getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups. +// Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups, +// newly groups to create and an error if any occurred. +func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgroup.Group, groupNames []string) (bool, []string, []*nbgroup.Group, error) { existedGroupsByName := make(map[string]*nbgroup.Group) - for _, group := range a.Groups { + for _, group := range groups { existedGroupsByName[group.Name] = group } - newAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, a.Groups) - groupsToAdd := difference(groupsNames, maps.Keys(jwtGroupsMap)) - groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupsNames) + newUserAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, groups) + + groupsToAdd := difference(groupNames, maps.Keys(jwtGroupsMap)) + groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupNames) // If no groups are added or removed, we should not sync account if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { - return false + return false, nil, nil, nil } + newGroupsToCreate := make([]*nbgroup.Group, 0) + var modified bool for _, name := range groupsToAdd { group, exists := existedGroupsByName[name] if !exists { group = &nbgroup.Group{ - ID: xid.New().String(), - Name: name, - Issued: nbgroup.GroupIssuedJWT, + ID: xid.New().String(), + AccountID: user.AccountID, + Name: name, + Issued: nbgroup.GroupIssuedJWT, } - a.Groups[group.ID] = group + newGroupsToCreate = append(newGroupsToCreate, group) } if group.Issued == nbgroup.GroupIssuedJWT { - newAutoGroups = append(newAutoGroups, group.ID) + newUserAutoGroups = append(newUserAutoGroups, group.ID) modified = true } } for name, id := range jwtGroupsMap { if !slices.Contains(groupsToRemove, name) { - newAutoGroups = append(newAutoGroups, id) + newUserAutoGroups = append(newUserAutoGroups, id) continue } modified = true } - user.AutoGroups = newAutoGroups - return modified + return modified, newUserAutoGroups, newGroupsToCreate, nil } // UserGroupsAddToPeers adds groups to all peers of user @@ -1262,37 +1261,31 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u return nil } -// GetAccountIDByUserOrAccountID retrieves the account ID based on either the userID or accountID provided. -// If an accountID is provided, it checks if the account exists and returns it. -// If no accountID is provided, but a userID is given, it tries to retrieve the account by userID. +// GetAccountIDByUserID retrieves the account ID based on the userID provided. +// If user does have an account, it returns the user's account ID. // If the user doesn't have an account, it creates one using the provided domain. // Returns the account ID or an error if none is found or created. -func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) { - if accountID != "" { - exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID) - if err != nil { - return "", err - } - if !exists { - return "", status.Errorf(status.NotFound, "account %s does not exist", accountID) - } - return accountID, nil +func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) { + if userID == "" { + return "", status.Errorf(status.NotFound, "no valid userID provided") } - if userID != "" { - account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) - if err != nil { - return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID) - } + accountID, err := am.Store.GetAccountIDByUserID(userID) + if err != nil { + if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { + account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) + if err != nil { + return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID) + } - if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil { - return "", err + if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil { + return "", err + } + return account.Id, nil } - - return account.Id, nil + return "", err } - - return "", status.Errorf(status.NotFound, "no valid userID or accountID provided") + return accountID, nil } func isNil(i idp.Manager) bool { @@ -1796,6 +1789,10 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId) } + if user.AccountID != accountID { + return "", "", status.Errorf(status.PermissionDenied, "user %s is not part of the account %s", claims.UserId, accountID) + } + if !user.IsServiceUser && claims.Invited { err = am.redeemInvite(ctx, accountID, user.Id) if err != nil { @@ -1803,7 +1800,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai } } - if err = am.syncJWTGroups(ctx, accountID, user, claims); err != nil { + if err = am.syncJWTGroups(ctx, accountID, claims); err != nil { return "", "", err } @@ -1812,7 +1809,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // and propagates changes to peers if group propagation is enabled. -func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, user *User, claims jwtclaims.AuthorizationClaims) error { +func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error { settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) if err != nil { return err @@ -1823,69 +1820,136 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } if settings.JWTGroupsClaimName == "" { - log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set") + log.WithContext(ctx).Debugf("JWT groups are enabled but no claim name is set") return nil } - // TODO: Remove GetAccount after refactoring account peer's update - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims) - oldGroups := make([]string, len(user.AutoGroups)) - copy(oldGroups, user.AutoGroups) + unlockPeer := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer func() { + if unlockPeer != nil { + unlockPeer() + } + }() - // Update the account if group membership changes - if account.SetJWTGroups(claims.UserId, jwtGroupsNames) { - addNewGroups := difference(user.AutoGroups, oldGroups) - removeOldGroups := difference(oldGroups, user.AutoGroups) + var addNewGroups []string + var removeOldGroups []string + var hasChanges bool + var user *User + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + user, err = am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) + if err != nil { + return fmt.Errorf("error getting user: %w", err) + } - if settings.GroupsPropagationEnabled { - account.UserGroupsAddToPeers(claims.UserId, addNewGroups...) - account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...) - account.Network.IncSerial() + groups, err := am.Store.GetAccountGroups(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account groups: %w", err) } - if err := am.Store.SaveAccount(ctx, account); err != nil { - log.WithContext(ctx).Errorf("failed to save account: %v", err) + changed, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(user, groups, jwtGroupsNames) + if err != nil { + return fmt.Errorf("error getting JWT groups changes: %w", err) + } + + hasChanges = changed + // skip update if no changes + if !changed { return nil } + if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil { + return fmt.Errorf("error saving groups: %w", err) + } + + addNewGroups = difference(updatedAutoGroups, user.AutoGroups) + removeOldGroups = difference(user.AutoGroups, updatedAutoGroups) + + user.AutoGroups = updatedAutoGroups + if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil { + return fmt.Errorf("error saving user: %w", err) + } + // Propagate changes to peers if group propagation is enabled if settings.GroupsPropagationEnabled { - log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) - am.updateAccountPeers(ctx, account) - } - - for _, g := range addNewGroups { - if group := account.GetGroup(g); group != nil { - am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser, - map[string]any{ - "group": group.Name, - "group_id": group.ID, - "is_service_user": user.IsServiceUser, - "user_name": user.ServiceUserName}) + groups, err = transaction.GetAccountGroups(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account groups: %w", err) + } + + groupsMap := make(map[string]*nbgroup.Group, len(groups)) + for _, group := range groups { + groupsMap[group.ID] = group + } + + peers, err := transaction.GetUserPeers(ctx, LockingStrengthShare, accountID, claims.UserId) + if err != nil { + return fmt.Errorf("error getting user peers: %w", err) + } + + updatedGroups, err := am.updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups) + if err != nil { + return fmt.Errorf("error modifying user peers in groups: %w", err) + } + + if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, updatedGroups); err != nil { + return fmt.Errorf("error saving groups: %w", err) + } + + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + return fmt.Errorf("error incrementing network serial: %w", err) } } + unlockPeer() + unlockPeer = nil + + return nil + }) + if err != nil { + return err + } + + if !hasChanges { + return nil + } - for _, g := range removeOldGroups { - if group := account.GetGroup(g); group != nil { - am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser, - map[string]any{ - "group": group.Name, - "group_id": group.ID, - "is_service_user": user.IsServiceUser, - "user_name": user.ServiceUserName}) + for _, g := range addNewGroups { + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + if err != nil { + log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) + } else { + meta := map[string]any{ + "group": group.Name, "group_id": group.ID, + "is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName, + } + am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupAddedToUser, meta) + } + } + + for _, g := range removeOldGroups { + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + if err != nil { + log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) + } else { + meta := map[string]any{ + "group": group.Name, "group_id": group.ID, + "is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName, } + am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupRemovedFromUser, meta) } } + if settings.GroupsPropagationEnabled { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account: %w", err) + } + + log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) + am.updateAccountPeers(ctx, account) + } + return nil } @@ -1916,7 +1980,17 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context // if Account ID is part of the claims // it means that we've already classified the domain and user has an account if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { - return am.GetAccountIDByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain) + if claims.AccountId != "" { + exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, claims.AccountId) + if err != nil { + return "", err + } + if !exists { + return "", status.Errorf(status.NotFound, "account %s does not exist", claims.AccountId) + } + return claims.AccountId, nil + } + return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain) } else if claims.AccountId != "" { userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) if err != nil { @@ -2229,7 +2303,11 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac routes := make(map[route.ID]*route.Route) setupKeys := map[string]*SetupKey{} nameServersGroups := make(map[string]*nbdns.NameServerGroup) - users[userID] = NewOwnerUser(userID) + + owner := NewOwnerUser(userID) + owner.AccountID = accountID + users[userID] = owner + dnsSettings := DNSSettings{ DisabledManagementGroups: make([]string, 0), } @@ -2297,12 +2375,17 @@ func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool { // separateGroups separates user's auto groups into non-JWT and JWT groups. // Returns the list of standard auto groups and a map of JWT auto groups, // where the keys are the group names and the values are the group IDs. -func separateGroups(autoGroups []string, allGroups map[string]*nbgroup.Group) ([]string, map[string]string) { +func separateGroups(autoGroups []string, allGroups []*nbgroup.Group) ([]string, map[string]string) { newAutoGroups := make([]string, 0) jwtAutoGroups := make(map[string]string) // map of group name to group ID + allGroupsMap := make(map[string]*nbgroup.Group, len(allGroups)) + for _, group := range allGroups { + allGroupsMap[group.ID] = group + } + for _, id := range autoGroups { - if group, ok := allGroups[id]; ok { + if group, ok := allGroupsMap[id]; ok { if group.Issued == nbgroup.GroupIssuedJWT { jwtAutoGroups[group.Name] = id } else { @@ -2310,5 +2393,6 @@ func separateGroups(autoGroups []string, allGroups map[string]*nbgroup.Group) ([ } } } + return newAutoGroups, jwtAutoGroups } diff --git a/management/server/account_test.go b/management/server/account_test.go index 198775bc33..c417e4bc89 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -633,7 +633,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain) + accountID, err := manager.GetAccountIDByUserID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.Domain) require.NoError(t, err, "create init user failed") initAccount, err := manager.Store.GetAccount(context.Background(), accountID) @@ -671,17 +671,16 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { userId := "user-id" domain := "test.domain" - initAccount := newAccountWithId(context.Background(), "", userId, domain) + _ = newAccountWithId(context.Background(), "", userId, domain) manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID := initAccount.Id - accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userId, accountID, domain) + accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain) require.NoError(t, err, "create init user failed") // as initAccount was created without account id we have to take the id after account initialization - // that happens inside the GetAccountIDByUserOrAccountID where the id is getting generated + // that happens inside the GetAccountIDByUserID where the id is getting generated // it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it - initAccount, err = manager.Store.GetAccount(context.Background(), accountID) + initAccount, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "get init account failed") claims := jwtclaims.AuthorizationClaims{ @@ -885,7 +884,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) { } } -func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { +func TestAccountManager_GetAccountByUserID(t *testing.T) { manager, err := createManager(t) if err != nil { t.Fatal(err) @@ -894,7 +893,7 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { userId := "test_user" - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userId, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, "") if err != nil { t.Fatal(err) } @@ -903,14 +902,13 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { return } - _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "") - if err != nil { - t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountID) - } + exists, err := manager.Store.AccountExists(context.Background(), LockingStrengthShare, accountID) + assert.NoError(t, err) + assert.True(t, exists, "expected to get existing account after creation using userid") - _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", "", "") + _, err = manager.GetAccountIDByUserID(context.Background(), "", "") if err == nil { - t.Errorf("expected an error when user and account IDs are empty") + t.Errorf("expected an error when user ID is empty") } } @@ -1669,7 +1667,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) @@ -1684,7 +1682,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + _, err = manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1696,7 +1694,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { }) require.NoError(t, err, "unable to add peer") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to get the account") account, err := manager.Store.GetAccount(context.Background(), accountID) @@ -1742,7 +1740,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1770,7 +1768,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. }, } - accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err = manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to get the account") account, err := manager.Store.GetAccount(context.Background(), accountID) @@ -1790,7 +1788,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + _, err = manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1802,7 +1800,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test }) require.NoError(t, err, "unable to add peer") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to get the account") account, err := manager.Store.GetAccount(context.Background(), accountID) @@ -1850,7 +1848,7 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ @@ -1861,9 +1859,6 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { assert.False(t, updated.Settings.PeerLoginExpirationEnabled) assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour) - accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "") - require.NoError(t, err, "unable to get account by ID") - settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) require.NoError(t, err, "unable to get account settings") @@ -2199,8 +2194,12 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) { } func TestAccount_SetJWTGroups(t *testing.T) { + manager, err := createManager(t) + require.NoError(t, err, "unable to create account manager") + // create a new account account := &Account{ + Id: "accountID", Peers: map[string]*nbpeer.Peer{ "peer1": {ID: "peer1", Key: "key1", UserID: "user1"}, "peer2": {ID: "peer2", Key: "key2", UserID: "user1"}, @@ -2211,62 +2210,120 @@ func TestAccount_SetJWTGroups(t *testing.T) { Groups: map[string]*group.Group{ "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}}, }, - Settings: &Settings{GroupsPropagationEnabled: true}, + Settings: &Settings{GroupsPropagationEnabled: true, JWTGroupsEnabled: true, JWTGroupsClaimName: "groups"}, Users: map[string]*User{ - "user1": {Id: "user1"}, - "user2": {Id: "user2"}, + "user1": {Id: "user1", AccountID: "accountID"}, + "user2": {Id: "user2", AccountID: "accountID"}, }, } + assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account") + t.Run("empty jwt groups", func(t *testing.T) { - updated := account.SetJWTGroups("user1", []string{}) - assert.False(t, updated, "account should not be updated") - assert.Empty(t, account.Users["user1"].AutoGroups, "auto groups must be empty") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{}}, + } + err := manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + assert.NoError(t, err, "unable to get user") + assert.Empty(t, user.AutoGroups, "auto groups must be empty") }) t.Run("jwt match existing api group", func(t *testing.T) { - updated := account.SetJWTGroups("user1", []string{"group1"}) - assert.False(t, updated, "account should not be updated") - assert.Equal(t, 0, len(account.Users["user1"].AutoGroups)) - assert.Equal(t, account.Groups["group1"].Issued, group.GroupIssuedAPI, "group should be api issued") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{"group1"}}, + } + err := manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 0) + + group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") + assert.NoError(t, err, "unable to get group") + assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") }) t.Run("jwt match existing api group in user auto groups", func(t *testing.T) { account.Users["user1"].AutoGroups = []string{"group1"} + assert.NoError(t, manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, account.Users["user1"])) - updated := account.SetJWTGroups("user1", []string{"group1"}) - assert.False(t, updated, "account should not be updated") - assert.Equal(t, 1, len(account.Users["user1"].AutoGroups)) - assert.Equal(t, account.Groups["group1"].Issued, group.GroupIssuedAPI, "group should be api issued") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{"group1"}}, + } + err = manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 1) + + group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") + assert.NoError(t, err, "unable to get group") + assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") }) t.Run("add jwt group", func(t *testing.T) { - updated := account.SetJWTGroups("user1", []string{"group1", "group2"}) - assert.True(t, updated, "account should be updated") - assert.Len(t, account.Groups, 2, "new group should be added") - assert.Len(t, account.Users["user1"].AutoGroups, 2, "new group should be added") - assert.Contains(t, account.Groups, account.Users["user1"].AutoGroups[0], "groups must contain group2 from user groups") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{"group1", "group2"}}, + } + err = manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 2, "groups count should not be change") }) t.Run("existed group not update", func(t *testing.T) { - updated := account.SetJWTGroups("user1", []string{"group2"}) - assert.False(t, updated, "account should not be updated") - assert.Len(t, account.Groups, 2, "groups count should not be changed") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{"group2"}}, + } + err = manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 2, "groups count should not be change") }) t.Run("add new group", func(t *testing.T) { - updated := account.SetJWTGroups("user2", []string{"group1", "group3"}) - assert.True(t, updated, "account should be updated") - assert.Len(t, account.Groups, 3, "new group should be added") - assert.Len(t, account.Users["user2"].AutoGroups, 1, "new group should be added") - assert.Contains(t, account.Groups, account.Users["user2"].AutoGroups[0], "groups must contain group3 from user groups") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user2", + Raw: jwt.MapClaims{"groups": []interface{}{"group1", "group3"}}, + } + err = manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + groups, err := manager.Store.GetAccountGroups(context.Background(), "accountID") + assert.NoError(t, err) + assert.Len(t, groups, 3, "new group3 should be added") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user2") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 1, "new group should be added") }) t.Run("remove all JWT groups", func(t *testing.T) { - updated := account.SetJWTGroups("user1", []string{}) - assert.True(t, updated, "account should be updated") - assert.Len(t, account.Users["user1"].AutoGroups, 1, "only non-JWT groups should remain") - assert.Contains(t, account.Users["user1"].AutoGroups, "group1", " group1 should still be present") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{}}, + } + err = manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain") + assert.Contains(t, user.AutoGroups, "group1", " group1 should still be present") }) } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index b399be8228..b6283a7e69 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -27,7 +27,7 @@ type MockAccountManager struct { CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) - GetAccountIDByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (string, error) + GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error) GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error) GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) @@ -194,14 +194,14 @@ func (am *MockAccountManager) CreateSetupKey( return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented") } -// GetAccountIDByUserOrAccountID mock implementation of GetAccountIDByUserOrAccountID from server.AccountManager interface -func (am *MockAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userId, accountId, domain string) (string, error) { - if am.GetAccountIDByUserOrAccountIdFunc != nil { - return am.GetAccountIDByUserOrAccountIdFunc(ctx, userId, accountId, domain) +// GetAccountIDByUserID mock implementation of GetAccountIDByUserID from server.AccountManager interface +func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId, domain string) (string, error) { + if am.GetAccountIDByUserIdFunc != nil { + return am.GetAccountIDByUserIdFunc(ctx, userId, domain) } return "", status.Errorf( codes.Unimplemented, - "method GetAccountIDByUserOrAccountID is not implemented", + "method GetAccountIDByUserID is not implemented", ) } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index cce748a0f8..9e1ab27dcc 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -10,6 +10,7 @@ import ( "path/filepath" "runtime" "runtime/debug" + "slices" "strings" "sync" "time" @@ -378,15 +379,26 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error { Create(&usersToSave).Error } +// SaveUser saves the given user to the database. +func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user) + if result.Error != nil { + return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error) + } + return nil +} + // SaveGroups saves the given list of groups to the database. -// It updates existing groups if a conflict occurs. -func (s *SqlStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error { - groupsToSave := make([]nbgroup.Group, 0, len(groups)) - for _, group := range groups { - group.AccountID = accountID - groupsToSave = append(groupsToSave, *group) +func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error { + if len(groups) == 0 { + return nil + } + + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups) + if result.Error != nil { + return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error) } - return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&groupsToSave).Error + return nil } // DeleteHashedPAT2TokenIDIndex is noop in SqlStore @@ -1021,6 +1033,89 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId return nil } +// AddUserPeersToGroups adds the user's peers to specified groups in database. +func (s *SqlStore) AddUserPeersToGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error { + if len(groupIDs) == 0 { + return nil + } + + var userPeerIDs []string + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(LockingStrengthShare)}).Select("id"). + Where("account_id = ? AND user_id = ?", accountID, userID).Model(&nbpeer.Peer{}).Find(&userPeerIDs) + if result.Error != nil { + return status.Errorf(status.Internal, "issue finding user peers") + } + + groupsToUpdate := make([]*nbgroup.Group, 0, len(groupIDs)) + for _, gid := range groupIDs { + group, err := s.GetGroupByID(ctx, LockingStrengthShare, gid, accountID) + if err != nil { + return err + } + + groupPeers := make(map[string]struct{}) + for _, pid := range group.Peers { + groupPeers[pid] = struct{}{} + } + + for _, pid := range userPeerIDs { + groupPeers[pid] = struct{}{} + } + + group.Peers = group.Peers[:0] + for pid := range groupPeers { + group.Peers = append(group.Peers, pid) + } + + groupsToUpdate = append(groupsToUpdate, group) + } + + return s.SaveGroups(ctx, LockingStrengthUpdate, groupsToUpdate) +} + +// RemoveUserPeersFromGroups removes the user's peers from specified groups in database. +func (s *SqlStore) RemoveUserPeersFromGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error { + if len(groupIDs) == 0 { + return nil + } + + var userPeerIDs []string + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(LockingStrengthShare)}).Select("id"). + Where("account_id = ? AND user_id = ?", accountID, userID).Model(&nbpeer.Peer{}).Find(&userPeerIDs) + if result.Error != nil { + return status.Errorf(status.Internal, "issue finding user peers") + } + + groupsToUpdate := make([]*nbgroup.Group, 0, len(groupIDs)) + for _, gid := range groupIDs { + group, err := s.GetGroupByID(ctx, LockingStrengthShare, gid, accountID) + if err != nil { + return err + } + + if group.Name == "All" { + continue + } + + update := make([]string, 0, len(group.Peers)) + for _, pid := range group.Peers { + if !slices.Contains(userPeerIDs, pid) { + update = append(update, pid) + } + } + + group.Peers = update + groupsToUpdate = append(groupsToUpdate, group) + } + + return s.SaveGroups(ctx, LockingStrengthUpdate, groupsToUpdate) +} + +// GetUserPeers retrieves peers for a user. +func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) { + return getRecords[*nbpeer.Peer](s.db.WithContext(ctx).Where("user_id = ?", userID), lockStrength, accountID) +} + func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { if err := s.db.WithContext(ctx).Create(peer).Error; err != nil { return status.Errorf(status.Internal, "issue adding peer to account") @@ -1127,6 +1222,15 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren return &group, nil } +// SaveGroup saves a group to the store. +func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group) + if result.Error != nil { + return status.Errorf(status.Internal, "failed to save group to store: %v", result.Error) + } + return nil +} + // GetAccountPolicies retrieves policies for an account. func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index dc07849d9b..4eed09c69b 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1185,3 +1185,33 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) { require.NoError(t, err) assert.Equal(t, 2, setupKey.UsedTimes) } + +func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + t.Cleanup(cleanup) + if err != nil { + t.Fatal(err) + } + group := &nbgroup.Group{ + ID: "group-id", + AccountID: "account-id", + Name: "group-name", + Issued: "api", + Peers: nil, + } + err = store.ExecuteInTransaction(context.Background(), func(transaction Store) error { + err := transaction.SaveGroup(context.Background(), LockingStrengthUpdate, group) + if err != nil { + t.Fatal("failed to save group") + return err + } + group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID) + if err != nil { + t.Fatal("failed to get group") + return err + } + t.Logf("group: %v", group) + return nil + }) + assert.NoError(t, err) +} diff --git a/management/server/store.go b/management/server/store.go index 041c936ae5..50bc6afdfd 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -60,6 +60,7 @@ type Store interface { GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) SaveUsers(accountID string, users map[string]*User) error + SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) DeleteHashedPAT2TokenIDIndex(hashedToken string) error @@ -68,7 +69,8 @@ type Store interface { GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) - SaveGroups(accountID string, groups map[string]*nbgroup.Group) error + SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error + SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) @@ -82,6 +84,7 @@ type Store interface { AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) + GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error diff --git a/management/server/user.go b/management/server/user.go index 6d01561c6c..38a8ac0c40 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -8,14 +8,14 @@ import ( "time" "github.com/google/uuid" - log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/server/activity" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" + log "github.com/sirupsen/logrus" ) const ( @@ -1254,6 +1254,74 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, nil } +// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them. +func (am *DefaultAccountManager) updateUserPeersInGroups(accountGroups map[string]*nbgroup.Group, peers []*nbpeer.Peer, groupsToAdd, + groupsToRemove []string) (groupsToUpdate []*nbgroup.Group, err error) { + + if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { + return + } + + userPeerIDMap := make(map[string]struct{}, len(peers)) + for _, peer := range peers { + userPeerIDMap[peer.ID] = struct{}{} + } + + for _, gid := range groupsToAdd { + group, ok := accountGroups[gid] + if !ok { + return nil, errors.New("group not found") + } + addUserPeersToGroup(userPeerIDMap, group) + groupsToUpdate = append(groupsToUpdate, group) + } + + for _, gid := range groupsToRemove { + group, ok := accountGroups[gid] + if !ok { + return nil, errors.New("group not found") + } + removeUserPeersFromGroup(userPeerIDMap, group) + groupsToUpdate = append(groupsToUpdate, group) + } + + return groupsToUpdate, nil +} + +// addUserPeersToGroup adds the user's peers to the group. +func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) { + groupPeers := make(map[string]struct{}, len(group.Peers)) + for _, pid := range group.Peers { + groupPeers[pid] = struct{}{} + } + + for pid := range userPeerIDs { + groupPeers[pid] = struct{}{} + } + + group.Peers = make([]string, 0, len(groupPeers)) + for pid := range groupPeers { + group.Peers = append(group.Peers, pid) + } +} + +// removeUserPeersFromGroup removes user's peers from the group. +func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) { + // skip removing peers from group All + if group.Name == "All" { + return + } + + updatedPeers := make([]string, 0, len(group.Peers)) + for _, pid := range group.Peers { + if _, found := userPeerIDs[pid]; !found { + updatedPeers = append(updatedPeers, pid) + } + } + + group.Peers = updatedPeers +} + func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) { for _, user := range userData { if user.ID == userID { diff --git a/management/server/user_test.go b/management/server/user_test.go index ec0a106957..1a5704551b 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -813,10 +813,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) { assert.NoError(t, err) } - accID, err := am.GetAccountIDByUserOrAccountID(context.Background(), "", account.Id, "") - assert.NoError(t, err) - - acc, err := am.Store.GetAccount(context.Background(), accID) + acc, err := am.Store.GetAccount(context.Background(), account.Id) assert.NoError(t, err) for _, id := range tc.expectedDeleted {