Skip to content

Commit

Permalink
Refactor posture checks run on sources and updated the validation func
Browse files Browse the repository at this point in the history
  • Loading branch information
bcmmbaga committed Jan 17, 2024
1 parent d0e4743 commit 2113402
Showing 1 changed file with 25 additions and 45 deletions.
70 changes: 25 additions & 45 deletions management/server/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package server

import (
_ "embed"
"fmt"
"strconv"
"strings"

Expand Down Expand Up @@ -221,20 +220,13 @@ func (a *Account) getPeerConnectionResources(peerID string) ([]*nbpeer.Peer, []*
continue
}

// if peer validation fails, the peer should not be able to connect to the policy peer's
// we return an empty list of peers and firewall rule for that policy
err := a.validatePostureChecksOnPeer(policy.SourcePostureChecks, peerID)
if err != nil {
return nil, nil
}

for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}

sourcePeers, peerInSources := getAllPeersFromGroups(a, rule.Sources, peerID)
destinationPeers, peerInDestinations := getAllPeersFromGroups(a, rule.Destinations, peerID)
sourcePeers, peerInSources := getAllPeersFromGroups(a, rule.Sources, peerID, policy.SourcePostureChecks)
destinationPeers, peerInDestinations := getAllPeersFromGroups(a, rule.Destinations, peerID, nil)
sourcePeers = additions.ValidatePeers(sourcePeers)
destinationPeers = additions.ValidatePeers(destinationPeers)

Expand Down Expand Up @@ -278,10 +270,11 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*nbpeer.Peer, in
}

return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) {
validGroupPeers := a.getValidatedPeersByPostureChecks(groupPeers)

isAll := (len(all.Peers) - 1) == len(validGroupPeers)
for _, peer := range validGroupPeers {
isAll := (len(all.Peers) - 1) == len(groupPeers)
for _, peer := range groupPeers {
if peer == nil {
continue
}

if _, ok := peersExists[peer.ID]; !ok {
peers = append(peers, peer)
Expand Down Expand Up @@ -495,8 +488,9 @@ func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {

// getAllPeersFromGroups for given peer ID and list of groups
//
// Returns list of peers and boolean indicating if peer is in any of the groups
func getAllPeersFromGroups(account *Account, groups []string, peerID string) ([]*nbpeer.Peer, bool) {
// Returns list of peers from the provided groups that pass the posture checks
// if the sourcePostureChecksIDs is set.
func getAllPeersFromGroups(account *Account, groups []string, peerID string, sourcePostureChecksIDs []string) ([]*nbpeer.Peer, bool) {
peerInGroups := false
filteredPeers := make([]*nbpeer.Peer, 0, len(groups))
for _, g := range groups {
Expand All @@ -511,6 +505,12 @@ func getAllPeersFromGroups(account *Account, groups []string, peerID string) ([]
continue
}

// validate the peer
isValid := account.validatePostureChecksOnPeer(sourcePostureChecksIDs, peer.ID)
if !isValid {
continue
}

if peer.ID == peerID {
peerInGroups = true
continue
Expand All @@ -523,10 +523,10 @@ func getAllPeersFromGroups(account *Account, groups []string, peerID string) ([]
}

// validatePostureChecksOnPeer validates the posture checks on a peer
func (a *Account) validatePostureChecksOnPeer(sourcePostureChecksID []string, peerID string) error {
func (a *Account) validatePostureChecksOnPeer(sourcePostureChecksID []string, peerID string) bool {
peer, ok := a.Peers[peerID]
if !ok && peer == nil {
return fmt.Errorf("peer %s does not exists", peerID)
return false
}

for _, postureChecksID := range sourcePostureChecksID {
Expand All @@ -536,37 +536,17 @@ func (a *Account) validatePostureChecksOnPeer(sourcePostureChecksID []string, pe
}

for _, check := range postureChecks.Checks {
if err := check.Check(*peer); err != nil {
return fmt.Errorf("an error occurred on check %s: %s", check.Name(), err.Error())
}
}
}

return nil
}

// getValidatedPeersByPostureChecks returns a slice of valid peers based on applied policy posture checks
func (a *Account) getValidatedPeersByPostureChecks(groupPeers []*nbpeer.Peer) []*nbpeer.Peer {
validPeers := make([]*nbpeer.Peer, 0)
for _, peer := range groupPeers {
if peer == nil {
continue
}

isValidPeer := true
for _, policy := range a.Policies {
err := a.validatePostureChecksOnPeer(policy.SourcePostureChecks, peer.ID)
if err != nil {
isValidPeer = false
break
isValid, err := check.Check(*peer)
if !isValid {
if err != nil {
log.Debugf("an error occured check %s: on peer: %s :%s", check.Name(), peer.ID, err.Error())

Check failure on line 542 in management/server/policy.go

View workflow job for this annotation

GitHub Actions / codespell

occured ==> occurred

Check failure on line 542 in management/server/policy.go

View workflow job for this annotation

GitHub Actions / lint (macos-latest)

`occured` is a misspelling of `occurred` (misspell)

Check failure on line 542 in management/server/policy.go

View workflow job for this annotation

GitHub Actions / lint (windows-latest)

`occured` is a misspelling of `occurred` (misspell)

Check failure on line 542 in management/server/policy.go

View workflow job for this annotation

GitHub Actions / lint (ubuntu-latest)

`occured` is a misspelling of `occurred` (misspell)
}
return false
}
}

if isValidPeer {
validPeers = append(validPeers, peer)
}
}
return validPeers
return true
}

func getPostureChecks(account *Account, postureChecksID string) *posture.Checks {
Expand Down

0 comments on commit 2113402

Please sign in to comment.