Skip to content

Commit

Permalink
feat: display rules per region (#8)
Browse files Browse the repository at this point in the history
**Issue #, if available:**

## Description of changes:

By adding the ability to list rules per region, we make it possible to
render documentation per region. This is useful when you have a firewall
per region.

**Checklist**

<!--- Leave unchecked if your change doesn't seem to apply -->

* [x] Update tests
* [ ] Update docs
* [x] PR title follows [conventional commit
semantics](https://www.conventionalcommits.org/en/v1.0.0-beta.2/#commit-message-for-a-fix-using-an-optional-issue-number)

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
  • Loading branch information
Nr18 authored Jul 28, 2023
1 parent 15dfe97 commit b089566
Show file tree
Hide file tree
Showing 14 changed files with 201 additions and 140 deletions.
47 changes: 27 additions & 20 deletions aws_network_firewall/account.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,54 @@
from __future__ import annotations
from typing import List, Optional

import itertools
from typing import List, Union
from landingzone_organization import Account as LandingZoneAccount
from aws_network_firewall.cidr_ranges import CidrRanges
from aws_network_firewall.destination import Destination
from aws_network_firewall.rule import Rule
from aws_network_firewall.rule_set import RuleSet
from aws_network_firewall.source import Source


class Account(LandingZoneAccount):
__rules: List[Rule]
__rules: RuleSet
__cidr_ranges: CidrRanges

def __init__(
self, name: str, account_id: str, cidr_ranges: CidrRanges, rules: List[Rule]
) -> None:
super().__init__(name, account_id)
self.__cidr_ranges = cidr_ranges
self.__rules = list(map(self.__enrich_rule, rules))
self.__rules = RuleSet(rules=list(map(self.__enrich_rule, rules)))

def __enrich_rule(self, rule: Rule) -> Rule:
list(
map(
lambda source: source.resolve_region_cidr_ranges(self.__cidr_ranges),
rule.sources,
)
)
list(
map(
lambda destination: destination.resolve_region_cidr_ranges(
self.__cidr_ranges
),
rule.destinations,
)
)
cidr_range = self.__cidr_ranges.by_region(rule.region)

def update_cidr_if_not_set(entry: Source) -> None:
if cidr_range and not entry.cidr:
entry.cidr = cidr_range.value

list(map(update_cidr_if_not_set, rule.sources))

return rule

@property
def rules(self) -> List[Rule]:
def regions(self) -> List[str]:
return list(set(filter(None, map(lambda rule: rule.region, self.rules.all))))

def rules_by_region(self, region: str) -> RuleSet:
return RuleSet(
rules=list(filter(lambda rule: region == rule.region, self.rules.all))
)

@property
def rules(self) -> RuleSet:
return self.__rules

@property
def inspection_rules(self) -> List[Rule]:
return list(filter(lambda rule: rule.is_inspection_rule, self.rules))
return list(filter(lambda rule: rule.is_inspection_rule, self.rules.all))

@property
def egress_rules(self) -> List[Rule]:
return list(filter(lambda rule: rule.is_egress_rule, self.rules))
return list(filter(lambda rule: rule.is_egress_rule, self.rules.all))
8 changes: 0 additions & 8 deletions aws_network_firewall/destination.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
from dataclasses import dataclass
from typing import Optional

from aws_network_firewall.cidr_ranges import CidrRanges


@dataclass
class Destination:
Expand All @@ -15,11 +13,5 @@ class Destination:
protocol: str
port: Optional[int]
endpoint: Optional[str]
region: Optional[str]
cidr: Optional[str]
message: Optional[str]

def resolve_region_cidr_ranges(self, ranges: CidrRanges) -> None:
if self.region and not self.cidr:
cidr = ranges.by_region(self.region)
self.cidr = cidr.value if cidr else None
1 change: 1 addition & 0 deletions aws_network_firewall/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class Rule:
workload: str
name: str
type: str
region: str
description: str
sources: List[Source]
destinations: List[Destination]
Expand Down
31 changes: 31 additions & 0 deletions aws_network_firewall/rule_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from __future__ import annotations

from typing import List

from aws_network_firewall.rule import Rule


class RuleSet:
__rules: List[Rule]

def __init__(self, rules: List[Rule]) -> None:
self.__rules = rules

def __len__(self) -> int:
return len(self.all)

def __iter__(self):
for value in self.all:
yield value

@property
def all(self) -> List[Rule]:
return self.__rules

@property
def inspection_rules(self) -> List[Rule]:
return list(filter(lambda rule: rule.is_inspection_rule, self.all))

@property
def egress_rules(self) -> List[Rule]:
return list(filter(lambda rule: rule.is_egress_rule, self.all))
3 changes: 1 addition & 2 deletions aws_network_firewall/schemas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def source_resolver(entry: dict) -> Source:
return Source(
description=entry["Description"],
cidr=entry.get("Cidr"),
region=entry.get("Region"),
)


Expand All @@ -26,7 +25,6 @@ def destination_resolver(entry: dict) -> Destination:
protocol=entry["Protocol"],
port=entry.get("Port"),
endpoint=entry.get("Endpoint"),
region=entry.get("Region"),
cidr=entry.get("Cidr"),
message=entry.get("Message"),
)
Expand All @@ -36,6 +34,7 @@ def rule_resolver(workload: str, entry: dict) -> Rule:
return Rule(
workload=workload,
type=entry["Type"],
region=entry["Region"],
name=entry["Name"],
description=entry["Description"],
sources=list(map(source_resolver, entry["Sources"])),
Expand Down
12 changes: 6 additions & 6 deletions aws_network_firewall/schemas/environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ definitions:
required:
- Name
- Type
- Region
- Description
- Sources
- Destinations
Expand All @@ -59,6 +60,8 @@ definitions:
enum: [ "Egress", "Inspection" ]
Description:
type: string
Region:
type: string
Sources:
type: array
items:
Expand All @@ -78,8 +81,6 @@ definitions:
type: string
Cidr:
type: string
Region:
type: string
examples:
- Description: Allow access from `10.0.0.0/8` to the defined destinations.
Cidr: 10.0.0.0/8
Expand All @@ -93,12 +94,11 @@ definitions:
- Description
- Protocol
anyOf:
- required: ["Endpoint", "Cidr"]
- required: ["Endpoint"]
not: { required: ["Region", "Cidr"] }
not: { required: ["Cidr"] }
- required: ["Cidr"]
not: { required: ["Endpoint", "Region"] }
- required: ["Region"]
not: { required: ["Endpoint", "Cidr"] }
not: { required: ["Endpoint"] }
# Port is not required when Protocol is ICMP
properties:
Description:
Expand Down
8 changes: 0 additions & 8 deletions aws_network_firewall/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
from dataclasses import dataclass
from typing import Optional

from aws_network_firewall.cidr_ranges import CidrRanges


@dataclass
class Source:
Expand All @@ -13,9 +11,3 @@ class Source:

description: str
cidr: Optional[str]
region: Optional[str]

def resolve_region_cidr_ranges(self, ranges: CidrRanges) -> None:
if self.region and not self.cidr:
cidr = ranges.by_region(self.region)
self.cidr = cidr.value if cidr else None
125 changes: 104 additions & 21 deletions tests/test_account.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List

from aws_network_firewall.account import Account
from aws_network_firewall.cidr_range import CidrRange
from aws_network_firewall.cidr_ranges import CidrRanges
Expand All @@ -6,44 +8,58 @@
from aws_network_firewall.source import Source


def generate_rule(type: str) -> Rule:
def generate_account(rules: List[Rule]) -> Account:
return Account(
name="my-account",
account_id="123412341234",
cidr_ranges=CidrRanges(
cidr_ranges=[CidrRange(region="eu-west-1", value="10.0.0.0/8")]
),
rules=rules,
)


def generate_rule(type: str, region: str) -> Rule:
return Rule(
workload="my-workload",
name="my-rule",
region=region,
type=type,
description="My description",
sources=[Source(description="my source", cidr="10.0.0.0/24", region=None)],
sources=[Source(description="my source", cidr="10.0.0.0/24")],
destinations=[
Destination(
description="my destination",
protocol="TCP",
port=443,
cidr=None,
endpoint=None,
region=None,
message=None,
)
],
)


outbound_xebia = Destination(
description="Allow outbound connectivity to xebia.com",
protocol="TCP",
port=443,
cidr=None,
endpoint="xebia.com",
message=None,
)


def test_no_rules() -> None:
rules = []
account = Account(
name="my-account",
account_id="123412341234",
cidr_ranges=CidrRanges(
cidr_ranges=[CidrRange(region="eu-west-1", value="10.0.0.0/24")]
),
rules=rules,
)
account = generate_account(rules=rules)
assert len(account.rules) == 0
assert len(account.egress_rules) == 0
assert len(account.inspection_rules) == 0


def test_inspection_rules() -> None:
rules = [generate_rule(Rule.INSPECTION)]
rules = [generate_rule(Rule.INSPECTION, region="eu-west-1")]
account = Account(
name="my-account",
account_id="123412341234",
Expand All @@ -58,15 +74,82 @@ def test_inspection_rules() -> None:


def test_egress_rules() -> None:
rules = [generate_rule(Rule.EGRESS)]
account = Account(
name="my-account",
account_id="123412341234",
cidr_ranges=CidrRanges(
cidr_ranges=[CidrRange(region="eu-west-1", value="10.0.0.0/8")]
),
rules=rules,
)
rules = [generate_rule(Rule.EGRESS, region="eu-west-1")]
account = generate_account(rules=rules)
assert len(account.rules) == 1
assert len(account.egress_rules) == 1
assert len(account.inspection_rules) == 0


def test_rules_resolve_single_region_egress() -> None:
rules = [generate_rule(Rule.EGRESS, region="eu-west-1")]
account = generate_account(rules=rules)
assert len(account.rules) == 1
assert len(account.egress_rules) == 1
assert len(account.inspection_rules) == 0
assert "eu-west-1" in account.regions


def test_rules_resolve_2_regions_egress() -> None:
rules = [
generate_rule(Rule.EGRESS, region="eu-west-1"),
generate_rule(Rule.EGRESS, region="eu-central-1"),
]
account = generate_account(rules=rules)
assert len(account.rules) == 2
assert len(account.egress_rules) == 2
assert len(account.inspection_rules) == 0
assert "eu-west-1" in account.regions
assert "eu-central-1" in account.regions

rules = account.rules_by_region("eu-west-1")
assert len(rules) == 1
assert len(rules.egress_rules) == 1
assert len(rules.inspection_rules) == 0

rules = account.rules_by_region("eu-central-1")
assert len(rules) == 1
assert len(rules.egress_rules) == 1
assert len(rules.inspection_rules) == 0


def test_rules_resolve_single_source_region_inspection() -> None:
rules = [generate_rule(Rule.INSPECTION, region="eu-west-1")]
account = generate_account(rules=rules)
assert len(account.rules) == 1
assert len(account.egress_rules) == 0
assert len(account.inspection_rules) == 1
assert "eu-west-1" in account.regions

rules = account.rules_by_region("eu-west-1")
assert len(rules) == 1
assert len(rules.egress_rules) == 0
assert len(rules.inspection_rules) == 1

rules = account.rules_by_region("eu-central-1")
assert len(rules) == 0
assert len(rules.egress_rules) == 0
assert len(rules.inspection_rules) == 0


def test_rules_resolve_2_source_regions_inspection() -> None:
rules = [
generate_rule(Rule.INSPECTION, region="eu-west-1"),
generate_rule(Rule.INSPECTION, region="eu-central-1"),
]
account = generate_account(rules=rules)
assert len(account.rules) == 2
assert len(account.egress_rules) == 0
assert len(account.inspection_rules) == 2
assert "eu-west-1" in account.regions
assert "eu-central-1" in account.regions

rules = account.rules_by_region("eu-west-1")
assert len(rules) == 1
assert len(rules.egress_rules) == 0
assert len(rules.inspection_rules) == 1

rules = account.rules_by_region("eu-central-1")
assert len(rules) == 1
assert len(rules.egress_rules) == 0
assert len(rules.inspection_rules) == 1
Loading

0 comments on commit b089566

Please sign in to comment.