From 21134024883ca16740392ebae6bf6b97757ff4ed Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Wed, 17 Jan 2024 20:13:42 +0300 Subject: [PATCH] Refactor posture checks run on sources and updated the validation func --- management/server/policy.go | 70 +++++++++++++------------------------ 1 file changed, 25 insertions(+), 45 deletions(-) diff --git a/management/server/policy.go b/management/server/policy.go index ce09db08e1..92d1e11802 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -2,7 +2,6 @@ package server import ( _ "embed" - "fmt" "strconv" "strings" @@ -221,20 +220,13 @@ func (a *Account) getPeerConnectionResources(peerID string) ([]*nbpeer.Peer, []* continue } - // if peer validation fails, the peer should not be able to connect to the policy peer's - // we return an empty list of peers and firewall rule for that policy - err := a.validatePostureChecksOnPeer(policy.SourcePostureChecks, peerID) - if err != nil { - return nil, nil - } - for _, rule := range policy.Rules { if !rule.Enabled { continue } - sourcePeers, peerInSources := getAllPeersFromGroups(a, rule.Sources, peerID) - destinationPeers, peerInDestinations := getAllPeersFromGroups(a, rule.Destinations, peerID) + sourcePeers, peerInSources := getAllPeersFromGroups(a, rule.Sources, peerID, policy.SourcePostureChecks) + destinationPeers, peerInDestinations := getAllPeersFromGroups(a, rule.Destinations, peerID, nil) sourcePeers = additions.ValidatePeers(sourcePeers) destinationPeers = additions.ValidatePeers(destinationPeers) @@ -278,10 +270,11 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*nbpeer.Peer, in } return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) { - validGroupPeers := a.getValidatedPeersByPostureChecks(groupPeers) - - isAll := (len(all.Peers) - 1) == len(validGroupPeers) - for _, peer := range validGroupPeers { + isAll := (len(all.Peers) - 1) == len(groupPeers) + for _, peer := range groupPeers { + if peer == nil { + continue + } if _, ok := peersExists[peer.ID]; !ok { peers = append(peers, peer) @@ -495,8 +488,9 @@ func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule { // getAllPeersFromGroups for given peer ID and list of groups // -// Returns list of peers and boolean indicating if peer is in any of the groups -func getAllPeersFromGroups(account *Account, groups []string, peerID string) ([]*nbpeer.Peer, bool) { +// Returns list of peers from the provided groups that pass the posture checks +// if the sourcePostureChecksIDs is set. +func getAllPeersFromGroups(account *Account, groups []string, peerID string, sourcePostureChecksIDs []string) ([]*nbpeer.Peer, bool) { peerInGroups := false filteredPeers := make([]*nbpeer.Peer, 0, len(groups)) for _, g := range groups { @@ -511,6 +505,12 @@ func getAllPeersFromGroups(account *Account, groups []string, peerID string) ([] continue } + // validate the peer + isValid := account.validatePostureChecksOnPeer(sourcePostureChecksIDs, peer.ID) + if !isValid { + continue + } + if peer.ID == peerID { peerInGroups = true continue @@ -523,10 +523,10 @@ func getAllPeersFromGroups(account *Account, groups []string, peerID string) ([] } // validatePostureChecksOnPeer validates the posture checks on a peer -func (a *Account) validatePostureChecksOnPeer(sourcePostureChecksID []string, peerID string) error { +func (a *Account) validatePostureChecksOnPeer(sourcePostureChecksID []string, peerID string) bool { peer, ok := a.Peers[peerID] if !ok && peer == nil { - return fmt.Errorf("peer %s does not exists", peerID) + return false } for _, postureChecksID := range sourcePostureChecksID { @@ -536,37 +536,17 @@ func (a *Account) validatePostureChecksOnPeer(sourcePostureChecksID []string, pe } for _, check := range postureChecks.Checks { - if err := check.Check(*peer); err != nil { - return fmt.Errorf("an error occurred on check %s: %s", check.Name(), err.Error()) - } - } - } - - return nil -} - -// getValidatedPeersByPostureChecks returns a slice of valid peers based on applied policy posture checks -func (a *Account) getValidatedPeersByPostureChecks(groupPeers []*nbpeer.Peer) []*nbpeer.Peer { - validPeers := make([]*nbpeer.Peer, 0) - for _, peer := range groupPeers { - if peer == nil { - continue - } - - isValidPeer := true - for _, policy := range a.Policies { - err := a.validatePostureChecksOnPeer(policy.SourcePostureChecks, peer.ID) - if err != nil { - isValidPeer = false - break + isValid, err := check.Check(*peer) + if !isValid { + if err != nil { + log.Debugf("an error occured check %s: on peer: %s :%s", check.Name(), peer.ID, err.Error()) + } + return false } - } - if isValidPeer { - validPeers = append(validPeers, peer) } } - return validPeers + return true } func getPostureChecks(account *Account, postureChecksID string) *posture.Checks {