Skip to content

Commit

Permalink
feat: make tls version configurable per destination
Browse files Browse the repository at this point in the history
Also support multiple TLS versions for a single endpoint.
  • Loading branch information
Nr18 committed Jul 31, 2023
1 parent b089566 commit 0363251
Show file tree
Hide file tree
Showing 8 changed files with 179 additions and 33 deletions.
3 changes: 2 additions & 1 deletion aws_network_firewall/destination.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
from typing import Optional, List


@dataclass
Expand All @@ -15,3 +15,4 @@ class Destination:
endpoint: Optional[str]
cidr: Optional[str]
message: Optional[str]
tls_versions: List[str]
103 changes: 76 additions & 27 deletions aws_network_firewall/rule.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import itertools
from dataclasses import dataclass
from typing import List, Optional, ClassVar

Expand Down Expand Up @@ -41,25 +42,29 @@ def convert_source(source: Source) -> Optional[SuricataHost]:
return list(filter(None, map(convert_source, self.sources)))

@staticmethod
def __tls_endpoint_options(endpoint: str) -> List[SuricataOption]:
def __resolve_tls_options(
destination: Destination, tls_version: Optional[str]
) -> List[SuricataOption]:
options = [
SuricataOption(name="tls.sni"),
SuricataOption(name="tls.version", value="1.2", quoted_value=False),
# When using multiple tls versions you need 2 rules
# openssl 1.1.1 is needed for tls1.3
# SuricataOption(name="tls.version", value="1.3", quoted_value=False),
]

if endpoint.startswith("*"):
if tls_version:
options.append(
SuricataOption(
name="tls.version", value=tls_version, quoted_value=False
)
)

if destination.endpoint.startswith("*"): # type: ignore
options += [
SuricataOption(name="dotprefix"),
SuricataOption(name="content", value=endpoint[1:]),
SuricataOption(name="content", value=destination.endpoint[1:]), # type: ignore
SuricataOption(name="nocase"),
SuricataOption(name="endswith"),
]
else:
options += [
SuricataOption(name="content", value=endpoint),
SuricataOption(name="content", value=destination.endpoint),
SuricataOption(name="nocase"),
SuricataOption(name="startswith"),
SuricataOption(name="endswith"),
Expand All @@ -68,39 +73,83 @@ def __tls_endpoint_options(endpoint: str) -> List[SuricataOption]:
return options

def __resolve_options(self, destination: Destination) -> List[SuricataOption]:
options = []

if destination.protocol == "TLS" and destination.endpoint:
options = self.__tls_endpoint_options(destination.endpoint)

message = (
f"{destination.message} | {self.workload} | {self.name}"
if destination.message
else f"{self.workload} | {self.name}"
)

return options + [
return [
SuricataOption(name="msg", value=message),
SuricataOption(name="rev", value="1", quoted_value=False),
SuricataOption(name="sid", value="XXX", quoted_value=False),
]

def __resolve_rule(self, destination: Destination) -> SuricataRule:
return SuricataRule(
action="pass",
protocol=destination.protocol,
sources=self.__suricata_source,
destination=SuricataHost(
address=destination.cidr if destination.cidr else "",
port=destination.port if destination.port else 0,
),
options=self.__resolve_options(destination),
)
def __resolve_tls_version_rules(
self, destination: Destination
) -> List[SuricataRule]:
rules = []

for tls_version in destination.tls_versions:
rules.append(
SuricataRule(
action="pass",
protocol=destination.protocol,
sources=self.__suricata_source,
destination=SuricataHost(
address=destination.cidr if destination.cidr else "",
port=destination.port if destination.port else 0,
),
options=self.__resolve_tls_options(
destination=destination, tls_version=tls_version
)
+ self.__resolve_options(destination=destination),
)
)

return rules

def __resolve_tls_rules(self, destination: Destination) -> List[SuricataRule]:
if destination.tls_versions:
return self.__resolve_tls_version_rules(destination)

return [
SuricataRule(
action="pass",
protocol=destination.protocol,
sources=self.__suricata_source,
destination=SuricataHost(
address=destination.cidr if destination.cidr else "",
port=destination.port if destination.port else 0,
),
options=self.__resolve_tls_options(
destination=destination, tls_version=None
)
+ self.__resolve_options(destination=destination),
)
]

def __resolve_rule(self, destination: Destination) -> List[SuricataRule]:
if destination.protocol == "TLS" and destination.endpoint:
return self.__resolve_tls_rules(destination=destination)

return [
SuricataRule(
action="pass",
protocol=destination.protocol,
sources=self.__suricata_source,
destination=SuricataHost(
address=destination.cidr if destination.cidr else "",
port=destination.port if destination.port else 0,
),
options=self.__resolve_options(destination),
)
]

@property
def suricata_rules(self) -> List[SuricataRule]:
rules = list(filter(None, map(self.__resolve_rule, self.destinations)))
return rules
return list(itertools.chain.from_iterable(rules))

def __str__(self) -> str:
return "\n".join(map(str, self.suricata_rules))
1 change: 1 addition & 0 deletions aws_network_firewall/schemas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def destination_resolver(entry: dict) -> Destination:
endpoint=entry.get("Endpoint"),
cidr=entry.get("Cidr"),
message=entry.get("Message"),
tls_versions=entry.get("TLSVersions", []),
)


Expand Down
2 changes: 2 additions & 0 deletions aws_network_firewall/schemas/environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ definitions:
type: integer
Message:
type: string
TLSVersions:
enum: [ "1.2", "1.3" ]
examples:
- Description: Website of Xebia
Protocol: TLS
Expand Down
2 changes: 2 additions & 0 deletions tests/test_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def generate_rule(type: str, region: str) -> Rule:
cidr=None,
endpoint=None,
message=None,
tls_versions=[],
)
],
)
Expand All @@ -47,6 +48,7 @@ def generate_rule(type: str, region: str) -> Rule:
cidr=None,
endpoint="xebia.com",
message=None,
tls_versions=[],
)


Expand Down
2 changes: 2 additions & 0 deletions tests/test_destination.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@ def test_destination_properties() -> None:
endpoint="xebia.com",
cidr="10.0.0.0/24",
message="Important Message",
tls_versions=["1.2"],
)
assert destination.description == "My Description"
assert destination.protocol == "TLS"
assert destination.port == 443
assert destination.endpoint == "xebia.com"
assert destination.cidr == "10.0.0.0/24"
assert destination.message == "Important Message"
assert destination.tls_versions == ["1.2"]
97 changes: 93 additions & 4 deletions tests/test_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,34 @@ def test_rule_with_tls_endpoint() -> None:
endpoint="xebia.com",
cidr="10.0.1.0/24",
message=None,
tls_versions=[],
)
],
)

assert (
'pass tls 10.0.0.0/24 any -> 10.0.1.0/24 443 (tls.sni; content:"xebia.com"; nocase; startswith; endswith; msg:"my-workload | my-rule"; rev:1; sid:XXX;)'
== str(rule)
)


def test_rule_with_tls_1_2_endpoint() -> None:
rule = Rule(
workload="my-workload",
name="my-rule",
region="eu-west-1",
type=Rule.INSPECTION,
description="My description",
sources=[Source(description="my source", cidr="10.0.0.0/24")],
destinations=[
Destination(
description="my destination",
protocol="TLS",
port=443,
endpoint="xebia.com",
cidr="10.0.1.0/24",
message=None,
tls_versions=["1.2"],
)
],
)
Expand All @@ -29,6 +57,61 @@ def test_rule_with_tls_endpoint() -> None:
)


def test_rule_with_tls_1_3_endpoint() -> None:
rule = Rule(
workload="my-workload",
name="my-rule",
region="eu-west-1",
type=Rule.INSPECTION,
description="My description",
sources=[Source(description="my source", cidr="10.0.0.0/24")],
destinations=[
Destination(
description="my destination",
protocol="TLS",
port=443,
endpoint="xebia.com",
cidr="10.0.1.0/24",
message=None,
tls_versions=["1.3"],
)
],
)

assert (
'pass tls 10.0.0.0/24 any -> 10.0.1.0/24 443 (tls.sni; tls.version:1.3; content:"xebia.com"; nocase; startswith; endswith; msg:"my-workload | my-rule"; rev:1; sid:XXX;)'
== str(rule)
)


def test_rule_with_tls_1_2_and_1_3_endpoint() -> None:
rule = Rule(
workload="my-workload",
name="my-rule",
region="eu-west-1",
type=Rule.INSPECTION,
description="My description",
sources=[Source(description="my source", cidr="10.0.0.0/24")],
destinations=[
Destination(
description="my destination",
protocol="TLS",
port=443,
endpoint="xebia.com",
cidr="10.0.1.0/24",
message=None,
tls_versions=["1.2", "1.3"],
)
],
)

assert (
'pass tls 10.0.0.0/24 any -> 10.0.1.0/24 443 (tls.sni; tls.version:1.2; content:"xebia.com"; nocase; startswith; endswith; msg:"my-workload | my-rule"; rev:1; sid:XXX;)\n'
+ 'pass tls 10.0.0.0/24 any -> 10.0.1.0/24 443 (tls.sni; tls.version:1.3; content:"xebia.com"; nocase; startswith; endswith; msg:"my-workload | my-rule"; rev:1; sid:XXX;)'
== str(rule)
)


def test_rule_with_tls_wildcard_endpoint() -> None:
rule = Rule(
workload="my-workload",
Expand All @@ -45,12 +128,13 @@ def test_rule_with_tls_wildcard_endpoint() -> None:
endpoint="*.xebia.com",
cidr="10.0.1.0/24",
message=None,
tls_versions=[],
)
],
)

assert (
'pass tls 10.0.0.0/24 any -> 10.0.1.0/24 443 (tls.sni; tls.version:1.2; dotprefix; content:".xebia.com"; nocase; endswith; msg:"my-workload | my-rule"; rev:1; sid:XXX;)'
'pass tls 10.0.0.0/24 any -> 10.0.1.0/24 443 (tls.sni; dotprefix; content:".xebia.com"; nocase; endswith; msg:"my-workload | my-rule"; rev:1; sid:XXX;)'
== str(rule)
)

Expand All @@ -71,12 +155,13 @@ def test_rule_with_tls_endpoint_non_standard_port() -> None:
endpoint="xebia.com",
cidr="10.0.1.0/24",
message=None,
tls_versions=[],
)
],
)

assert (
'pass tls 10.0.0.0/24 any -> 10.0.1.0/24 444 (tls.sni; tls.version:1.2; content:"xebia.com"; nocase; startswith; endswith; msg:"my-workload | my-rule"; rev:1; sid:XXX;)\n'
'pass tls 10.0.0.0/24 any -> 10.0.1.0/24 444 (tls.sni; content:"xebia.com"; nocase; startswith; endswith; msg:"my-workload | my-rule"; rev:1; sid:XXX;)\n'
+ 'pass tcp 10.0.0.0/24 any <> 10.0.1.0/24 444 (msg:"my-workload | my-rule | Pass non-established TCP for 3-way handshake"; flow:"not_established"; rev:1; sid:XXX;)'
== str(rule)
)
Expand All @@ -98,6 +183,7 @@ def test_rule_with_tcp_cidr() -> None:
cidr="10.0.1.0/24",
endpoint=None,
message=None,
tls_versions=[],
)
],
)
Expand All @@ -124,6 +210,7 @@ def test_icmp_rule() -> None:
cidr="10.0.1.0/24",
endpoint=None,
message=None,
tls_versions=[],
)
],
)
Expand All @@ -150,12 +237,13 @@ def test_egress_tls_rule() -> None:
cidr=None,
endpoint="xebia.com",
message=None,
tls_versions=[],
)
],
)

assert (
'pass tls any -> any 443 (tls.sni; tls.version:1.2; content:"xebia.com"; nocase; startswith; endswith; msg:"my-workload | my-rule"; rev:1; sid:XXX;)'
'pass tls any -> any 443 (tls.sni; content:"xebia.com"; nocase; startswith; endswith; msg:"my-workload | my-rule"; rev:1; sid:XXX;)'
== str(rule)
)

Expand All @@ -176,11 +264,12 @@ def test_egress_tls_rule_with_message() -> None:
cidr=None,
endpoint="xebia.com",
message="IMPORTANT BECAUSE ...",
tls_versions=[],
)
],
)

assert (
'pass tls any -> any 443 (tls.sni; tls.version:1.2; content:"xebia.com"; nocase; startswith; endswith; msg:"IMPORTANT BECAUSE ... | my-workload | my-rule"; rev:1; sid:XXX;)'
'pass tls any -> any 443 (tls.sni; content:"xebia.com"; nocase; startswith; endswith; msg:"IMPORTANT BECAUSE ... | my-workload | my-rule"; rev:1; sid:XXX;)'
== str(rule)
)
2 changes: 1 addition & 1 deletion tests/workloads/example-workload/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ xebia.com | None | TLS | 443 | My destination
Based on the above defined sources and destination the following firewall rules are required:

```
pass tls 192.168.0.0/21 any -> any 443 (tls.sni; tls.version:1.2; content:"xebia.com"; nocase; startswith; endswith; msg:"binxio-example-workload-development | My Rule name"; rev:1; sid:XXX;)
pass tls 192.168.0.0/21 any -> any 443 (tls.sni; content:"xebia.com"; nocase; startswith; endswith; msg:"binxio-example-workload-development | My Rule name"; rev:1; sid:XXX;)
```

Expand Down

0 comments on commit 0363251

Please sign in to comment.