From 6242cf4988f007d29491c4295ded76c92b01c419 Mon Sep 17 00:00:00 2001 From: Joris Conijn Date: Wed, 26 Jul 2023 17:38:49 +0200 Subject: [PATCH] fix: missing egress rules (#5) --- aws_network_firewall/rule.py | 12 +++++------ aws_network_firewall/suricata/host.py | 5 +++-- tests/test_rule.py | 29 +++++++++++++++------------ 3 files changed, 25 insertions(+), 21 deletions(-) diff --git a/aws_network_firewall/rule.py b/aws_network_firewall/rule.py index f3220ec..7e7939f 100644 --- a/aws_network_firewall/rule.py +++ b/aws_network_firewall/rule.py @@ -35,7 +35,7 @@ def is_egress_rule(self) -> bool: @property def __suricata_source(self) -> List[SuricataHost]: def convert_source(source: Source) -> Optional[SuricataHost]: - return SuricataHost(address=source.cidr) if source.cidr else None + return SuricataHost(address=source.cidr, port=0) if source.cidr else None return list(filter(None, map(convert_source, self.sources))) @@ -76,15 +76,15 @@ def __resolve_options(self, destination: Destination) -> List[SuricataOption]: SuricataOption(name="sid", value="XXX", quoted_value=False), ] - def __resolve_rule(self, destination: Destination) -> Optional[SuricataRule]: - if not destination.cidr: - return None - + def __resolve_rule(self, destination: Destination) -> SuricataRule: return SuricataRule( action="pass", protocol=destination.protocol, sources=self.__suricata_source, - destination=SuricataHost(address=destination.cidr, port=destination.port), + destination=SuricataHost( + address=destination.cidr if destination.cidr else "", + port=destination.port if destination.port else 0, + ), options=self.__resolve_options(destination), ) diff --git a/aws_network_firewall/suricata/host.py b/aws_network_firewall/suricata/host.py index e4cc7f7..935e241 100644 --- a/aws_network_firewall/suricata/host.py +++ b/aws_network_firewall/suricata/host.py @@ -10,11 +10,12 @@ class Host: Understands a source and/or destination defenition """ - address: str = "any" - port: Optional[int] = None + address: str + port: int def __post_init__(self): self.port = "any" if not self.port else self.port + self.address = "any" if not self.address else self.address def __str__(self): return f"{self.address} {self.port}" diff --git a/tests/test_rule.py b/tests/test_rule.py index 1754a42..1f481fc 100644 --- a/tests/test_rule.py +++ b/tests/test_rule.py @@ -104,7 +104,7 @@ def test_rule_with_tcp_cidr() -> None: ) -def test_rule_no_cidr() -> None: +def test_icmp_rule() -> None: rule = Rule( workload="my-workload", name="my-rule", @@ -114,38 +114,41 @@ def test_rule_no_cidr() -> None: destinations=[ Destination( description="my destination", - protocol="TCP", - port=443, - cidr=None, + protocol="ICMP", + port=None, + cidr="10.0.1.0/24", endpoint=None, region=None, ) ], ) - assert "" == str(rule) + assert ( + 'pass icmp 10.0.0.0/24 any <> 10.0.1.0/24 any (msg: "my-workload | my-rule"; rev: 1; sid: XXX;)' + == str(rule) + ) -def test_icmp_rule() -> None: +def test_egress_tls_rule() -> None: rule = Rule( workload="my-workload", name="my-rule", - type=Rule.INSPECTION, + type=Rule.EGRESS, description="My description", - sources=[Source(description="my source", cidr="10.0.0.0/24", region=None)], + sources=[Source(description="my source", cidr=None, region="eu-west-1")], destinations=[ Destination( description="my destination", - protocol="ICMP", - port=None, - cidr="10.0.1.0/24", - endpoint=None, + protocol="TLS", + port=443, + cidr=None, + endpoint="xebia.com", region=None, ) ], ) assert ( - 'pass icmp 10.0.0.0/24 any <> 10.0.1.0/24 any (msg: "my-workload | my-rule"; rev: 1; sid: XXX;)' + 'pass tls any -> any 443 (tls.sni; tls.version: 1.2; tls.version: 1.3; content: "xebia.com"; nocase; startswith; endswith; msg: "my-workload | my-rule"; rev: 1; sid: XXX;)' == str(rule) )