diff --git a/interface-definitions/include/policy/route-common.xml.i b/interface-definitions/include/policy/route-common.xml.i index 97795601eed..203be73e759 100644 --- a/interface-definitions/include/policy/route-common.xml.i +++ b/interface-definitions/include/policy/route-common.xml.i @@ -128,6 +128,24 @@ + + + VRF to forward packet with + + txt + VRF instance name + + + default + Forward into default global VRF + + + default + vrf name + + #include + + TCP Maximum Segment Size diff --git a/python/vyos/firewall.py b/python/vyos/firewall.py index 664df28cc6d..c4116c8ce16 100644 --- a/python/vyos/firewall.py +++ b/python/vyos/firewall.py @@ -30,6 +30,7 @@ from vyos.utils.dict import dict_search_recursive from vyos.utils.process import cmd from vyos.utils.process import run +from vyos.utils.network import get_vrf_table_id # Conntrack def conntrack_required(conf): @@ -469,11 +470,20 @@ def parse_rule(rule_conf, hook, fw_name, rule_id, ip_name): if 'mark' in rule_conf['set']: mark = rule_conf['set']['mark'] output.append(f'meta mark set {mark}') + if 'vrf' in rule_conf['set']: + set_table = True + vrf_name = rule_conf['set']['vrf'] + # NOTE: VRF->table ID lookup depends on the VRF iface already existing. + if vrf_name == 'default': + table = '254' + else: + table = get_vrf_table_id(vrf_name) if 'table' in rule_conf['set']: set_table = True table = rule_conf['set']['table'] if table == 'main': table = '254' + if set_table: mark = 0x7FFFFFFF - int(table) output.append(f'meta mark set {mark}') if 'tcp_mss' in rule_conf['set']: diff --git a/python/vyos/utils/network.py b/python/vyos/utils/network.py index 829124b5721..398d7e07687 100644 --- a/python/vyos/utils/network.py +++ b/python/vyos/utils/network.py @@ -74,6 +74,9 @@ def get_vrf_members(vrf: str) -> list: pass return interfaces +def get_vrf_table_id(vrf: str): + return get_interface_config(vrf)['linkinfo']['info_data']['table'] + def get_interface_vrf(interface): """ Returns VRF of given interface """ from vyos.utils.dict import dict_search diff --git a/smoketest/scripts/cli/test_policy_route.py b/smoketest/scripts/cli/test_policy_route.py index 462fc24d0d5..797ab97704c 100755 --- a/smoketest/scripts/cli/test_policy_route.py +++ b/smoketest/scripts/cli/test_policy_route.py @@ -25,6 +25,8 @@ conn_mark_set = '111' table_mark_offset = 0x7fffffff table_id = '101' +vrf = 'PBRVRF' +vrf_table_id = '102' interface = 'eth0' interface_wc = 'ppp*' interface_ip = '172.16.10.1/24' @@ -39,11 +41,14 @@ def setUpClass(cls): cls.cli_set(cls, ['interfaces', 'ethernet', interface, 'address', interface_ip]) cls.cli_set(cls, ['protocols', 'static', 'table', table_id, 'route', '0.0.0.0/0', 'interface', interface]) + + cls.cli_set(cls, ['vrf', 'name', vrf, 'table', vrf_table_id]) @classmethod def tearDownClass(cls): cls.cli_delete(cls, ['interfaces', 'ethernet', interface, 'address', interface_ip]) cls.cli_delete(cls, ['protocols', 'static', 'table', table_id]) + cls.cli_delete(cls, ['vrf', 'name', vrf]) super(TestPolicyRoute, cls).tearDownClass() @@ -180,6 +185,50 @@ def test_pbr_table(self): self.verify_rules(ip_rule_search) + def test_pbr_vrf(self): + self.cli_set(['policy', 'route', 'smoketest', 'rule', '1', 'protocol', 'tcp']) + self.cli_set(['policy', 'route', 'smoketest', 'rule', '1', 'destination', 'port', '8888']) + self.cli_set(['policy', 'route', 'smoketest', 'rule', '1', 'tcp', 'flags', 'syn']) + self.cli_set(['policy', 'route', 'smoketest', 'rule', '1', 'tcp', 'flags', 'not', 'ack']) + self.cli_set(['policy', 'route', 'smoketest', 'rule', '1', 'set', 'vrf', vrf]) + self.cli_set(['policy', 'route6', 'smoketest6', 'rule', '1', 'protocol', 'tcp_udp']) + self.cli_set(['policy', 'route6', 'smoketest6', 'rule', '1', 'destination', 'port', '8888']) + self.cli_set(['policy', 'route6', 'smoketest6', 'rule', '1', 'set', 'vrf', vrf]) + + self.cli_set(['policy', 'route', 'smoketest', 'interface', interface]) + self.cli_set(['policy', 'route6', 'smoketest6', 'interface', interface]) + + self.cli_commit() + + mark_hex = "{0:#010x}".format(table_mark_offset - int(vrf_table_id)) + + # IPv4 + + nftables_search = [ + [f'iifname "{interface}"', 'jump VYOS_PBR_UD_smoketest'], + ['tcp flags syn / syn,ack', 'tcp dport 8888', 'meta mark set ' + mark_hex] + ] + + self.verify_nftables(nftables_search, 'ip vyos_mangle') + + # IPv6 + + nftables6_search = [ + [f'iifname "{interface}"', 'jump VYOS_PBR6_UD_smoketest'], + ['meta l4proto { tcp, udp }', 'th dport 8888', 'meta mark set ' + mark_hex] + ] + + self.verify_nftables(nftables6_search, 'ip6 vyos_mangle') + + # IP rule fwmark -> table + + ip_rule_search = [ + ['fwmark ' + hex(table_mark_offset - int(vrf_table_id)), 'lookup ' + vrf] + ] + + self.verify_rules(ip_rule_search) + + def test_pbr_matching_criteria(self): self.cli_set(['policy', 'route', 'smoketest', 'default-log']) self.cli_set(['policy', 'route', 'smoketest', 'rule', '1', 'protocol', 'udp']) diff --git a/src/conf_mode/policy_route.py b/src/conf_mode/policy_route.py index c58fe1bce72..e97282860f9 100755 --- a/src/conf_mode/policy_route.py +++ b/src/conf_mode/policy_route.py @@ -25,6 +25,7 @@ from vyos.utils.dict import dict_search_args from vyos.utils.process import cmd from vyos.utils.process import run +from vyos.utils.network import get_vrf_table_id from vyos import ConfigError from vyos import airbag airbag.enable() @@ -83,6 +84,9 @@ def verify_rule(policy, name, rule_conf, ipv6, rule_id): if not tcp_flags or 'syn' not in tcp_flags: raise ConfigError(f'{name} rule {rule_id}: TCP SYN flag must be set to modify TCP-MSS') + if 'vrf' in rule_conf['set'] and 'table' in rule_conf['set']: + raise ConfigError(f'{name} rule {rule_id}: Cannot set both forwarding route table and VRF') + tcp_flags = dict_search_args(rule_conf, 'tcp', 'flags') if tcp_flags: if dict_search_args(rule_conf, 'protocol') != 'tcp': @@ -152,15 +156,26 @@ def apply_table_marks(policy): for name, pol_conf in policy[route].items(): if 'rule' in pol_conf: for rule_id, rule_conf in pol_conf['rule'].items(): + vrf_table_id = None set_table = dict_search_args(rule_conf, 'set', 'table') - if set_table: + set_vrf = dict_search_args(rule_conf, 'set', 'vrf') + if set_vrf: + if set_vrf == 'default': + vrf_table_id = '254' + else: + # str-cast so that tables uniqueness check works below. + vrf_table_id = str(get_vrf_table_id(set_vrf)) + elif set_table: if set_table == 'main': - set_table = '254' - if set_table in tables: + vrf_table_id = '254' + else: + vrf_table_id = set_table + if vrf_table_id is not None: + if vrf_table_id in tables: continue - tables.append(set_table) - table_mark = mark_offset - int(set_table) - cmd(f'{cmd_str} rule add pref {set_table} fwmark {table_mark} table {set_table}') + tables.append(vrf_table_id) + table_mark = mark_offset - int(vrf_table_id) + cmd(f'{cmd_str} rule add pref {vrf_table_id} fwmark {table_mark} table {vrf_table_id}') def cleanup_table_marks(): for cmd_str in ['ip', 'ip -6']: