Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flush VM flows after removal #369

Merged
merged 2 commits into from
Sep 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/dp_flow.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ void dp_free_flow(struct dp_ref *ref);
void dp_free_network_nat_port(struct flow_value *cntrack);
void dp_remove_nat_flows(uint16_t port_id, int nat_type); // TODO create proper enum!
void dp_remove_neighnat_flows(uint32_t ipv4, uint32_t vni, uint16_t min_port, uint16_t max_port);
// Removes flows belonging to a VM: every flow whose created_port_id equals port_id,
// plus flows whose key VNI equals vni and whose flow_key[DP_FLOW_DIR_ORG].ip_dst equals ipv4
void dp_remove_vm_flows(uint16_t port_id, uint32_t ipv4, uint32_t vni);

hash_sig_t dp_get_conntrack_flow_hash_value(struct flow_key *key);

Expand Down
39 changes: 31 additions & 8 deletions src/dp_flow.c
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,13 @@ void dp_process_aged_flows_non_offload(void)
}
}

/*
 * Common removal step shared by the flow-flushing functions in this file.
 * When offload mode is enabled, dp_rte_flow_remove() is called first;
 * dp_age_out_flow() is then called regardless of the mode.
 */
static __rte_always_inline void dp_remove_flow(struct flow_value *flow_val)
{
	if (offload_mode_enabled) {
		dp_rte_flow_remove(flow_val);
	}

	dp_age_out_flow(flow_val);
}

void dp_remove_nat_flows(uint16_t port_id, int nat_type)
{
struct flow_value *flow_val = NULL;
Expand All @@ -438,11 +445,8 @@ void dp_remove_nat_flows(uint16_t port_id, int nat_type)
return;
}
// NAT/VIP are in 1:1 relation to a VM (port_id), no need to check IP:port
if (flow_val->created_port_id == port_id && flow_val->nf_info.nat_type == nat_type) {
if (offload_mode_enabled)
dp_rte_flow_remove(flow_val);
dp_age_out_flow(flow_val);
}
if (flow_val->created_port_id == port_id && flow_val->nf_info.nat_type == nat_type)
dp_remove_flow(flow_val);
}
}

Expand All @@ -461,13 +465,32 @@ void dp_remove_neighnat_flows(uint32_t ipv4, uint32_t vni, uint16_t min_port, ui
if (next_key->vni == vni && next_key->ip_dst == ipv4
&& next_key->port_dst >= min_port && next_key->port_dst < max_port
) {
if (offload_mode_enabled)
dp_rte_flow_remove(flow_val);
dp_age_out_flow(flow_val);
dp_remove_flow(flow_val);
}
}
}

/*
 * Walk ipv4_flow_tbl and remove the flows related to one VM.
 *
 * A flow is removed when either
 *  - its created_port_id equals port_id, or
 *  - the iterated key's VNI equals vni and the destination IP stored in
 *    flow_key[DP_FLOW_DIR_ORG] equals ipv4.
 *
 * An iteration error is logged and ends the walk early.
 */
void dp_remove_vm_flows(uint16_t port_id, uint32_t ipv4, uint32_t vni)
{
	const struct flow_key *key;
	struct flow_value *flow = NULL;
	uint32_t pos = 0;
	int ret;

	for (;;) {
		ret = rte_hash_iterate(ipv4_flow_tbl, (const void **)&key, (void **)&flow, &pos);
		// -ENOENT is the regular end-of-table marker, not an error
		if (ret == -ENOENT)
			break;

		if (DP_FAILED(ret)) {
			DPS_LOG_ERR("Iterating flow table failed while removing VM flows", DP_LOG_RET(ret));
			return;
		}

		// Flow was created by the VM's own port
		if (flow->created_port_id == port_id) {
			dp_remove_flow(flow);
			continue;
		}

		// Flow targets the VM's address inside its VNI
		if (key->vni == vni && flow->flow_key[DP_FLOW_DIR_ORG].ip_dst == ipv4)
			dp_remove_flow(flow);
	}
}


hash_sig_t dp_get_conntrack_flow_hash_value(struct flow_key *key)
{
//It is not necessary to first test if this key exists, since for now, this function
Expand Down
6 changes: 6 additions & 0 deletions src/grpc/dp_grpc_impl.c
Original file line number Diff line number Diff line change
Expand Up @@ -555,12 +555,17 @@ static int dp_process_delete_interface(struct dp_grpc_responder *responder)
struct dpgrpc_iface_id *request = &responder->request.del_iface;

int port_id;
uint32_t ipv4;
uint32_t vni;
int ret = DP_GRPC_OK;

port_id = dp_get_portid_with_vm_handle(request->iface_id);
if (DP_FAILED(port_id))
return DP_GRPC_ERR_NOT_FOUND;

ipv4 = dp_get_dhcp_range_ip4(port_id);
vni = dp_get_vm_vni(port_id);

dp_del_vnf_with_vnf_key(dp_get_vm_ul_ip6(port_id));
if (DP_FAILED(dp_port_stop(port_id)))
ret = DP_GRPC_ERR_PORT_STOP;
Expand All @@ -570,6 +575,7 @@ static int dp_process_delete_interface(struct dp_grpc_responder *responder)
#ifdef ENABLE_VIRTSVC
dp_virtsvc_del_vm(port_id);
#endif
dp_remove_vm_flows(port_id, ipv4, vni);
return ret;
}

Expand Down
107 changes: 68 additions & 39 deletions test/tcp_tester.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,35 @@
from config import *
from helpers import *

class TCPTester:
class _TCPTester:
TCP_RESET_REQUEST = "Resetme"
TCP_NORMAL_REQUEST = "Hello"
TCP_NORMAL_RESPONSE = "Same to you"

def __init__(self, client_vm, client_port, client_ul_ipv6, pf_name, server_ip, server_port, client_pkt_check=None, server_pkt_check=None, encaped=True):
self.client_vm = client_vm
def __init__(self, client_tap, client_mac, client_ip, client_port,
server_tap, server_mac, server_ip, server_port,
client_pkt_check=None, server_pkt_check=None):
self.client_tap = client_tap
self.client_mac = client_mac
self.client_ip = client_ip
self.client_port = client_port
self.client_ul_ipv6 = client_ul_ipv6
self.pf_name = pf_name
self.server_tap = server_tap
self.server_mac = server_mac
self.server_ip = server_ip
self.server_port = server_port
self.client_pkt_check = client_pkt_check
self.server_pkt_check = server_pkt_check
self.encaped = encaped

def reset(self):
self.tcp_sender_seq = 100
self.tcp_receiver_seq = 200
self.tcp_used_port = 0


def get_ip_layer_response(self, pkt):
if self.encaped:
return IPv6(dst=self.client_ul_ipv6, src=pkt[IPv6].dst, nh=4) / IP(dst=pkt[IP].src, src=pkt[IP].dst)
else:
return IPv6(dst=self.client_ul_ipv6, src=pkt[IPv6].dst, nh=6)
def get_server_l3_reply(self, pkt):
raise NotImplementedError("This base implementation needs to be overriden")

def get_server_packet(self):
pkt = sniff_packet(self.pf_name, is_tcp_pkt)
pkt = sniff_packet(self.server_tap, is_tcp_pkt)
assert self.tcp_used_port == 0 or pkt[TCP].sport == self.tcp_used_port, \
f"Dp-service port changed during communication {pkt[TCP].sport} vs {self.tcp_used_port}"
self.tcp_used_port = pkt[TCP].sport
Expand All @@ -56,29 +55,26 @@ def reply_tcp(self):
if "F" in pkt[TCP].flags:
flags += "F"

reply_pkt = (Ether(dst=pkt[Ether].src, src=pkt[Ether].dst, type=0x86DD) /
self.get_ip_layer_response(pkt) /
reply_pkt = (self.get_server_l3_reply(pkt) /
TCP(dport=pkt[TCP].sport, sport=pkt[TCP].dport, seq=self.tcp_receiver_seq, flags=flags, ack=pkt[TCP].seq+1, options=[("NOP", None)]))
delayed_sendp(reply_pkt, self.pf_name)
delayed_sendp(reply_pkt, self.server_tap)

if flags != "A":
self.tcp_receiver_seq += 1

# Application-level reply
if pkt[TCP].payload != None and len(pkt[TCP].payload) > 0:
if pkt[TCP].payload == Raw(TCPTester.TCP_RESET_REQUEST):
reply_pkt = (Ether(dst=pkt[Ether].src, src=pkt[Ether].dst, type=0x86DD) /
self.get_ip_layer_response(pkt) /
if pkt[TCP].payload == Raw(_TCPTester.TCP_RESET_REQUEST):
reply_pkt = (self.get_server_l3_reply(pkt) /
TCP(dport=pkt[TCP].sport, sport=pkt[TCP].dport, seq=self.tcp_receiver_seq, flags="R"))
delayed_sendp(reply_pkt, self.pf_name)
delayed_sendp(reply_pkt, self.server_tap)
return
elif pkt[TCP].payload == Raw(TCPTester.TCP_NORMAL_REQUEST):
reply_pkt = (Ether(dst=pkt[Ether].src, src=pkt[Ether].dst, type=0x86DD) /
self.get_ip_layer_response(pkt) /
TCP(dport=pkt[TCP].sport, sport=pkt[TCP].dport, seq=self.tcp_receiver_seq, flags="") /
Raw(TCPTester.TCP_NORMAL_RESPONSE))
delayed_sendp(reply_pkt, self.pf_name)
self.tcp_receiver_seq += len(TCPTester.TCP_NORMAL_RESPONSE)
elif pkt[TCP].payload == Raw(_TCPTester.TCP_NORMAL_REQUEST):
reply_pkt = (self.get_server_l3_reply(pkt) /
TCP(dport=pkt[TCP].sport, sport=pkt[TCP].dport, seq=self.tcp_receiver_seq, flags="") /
Raw(_TCPTester.TCP_NORMAL_RESPONSE))
delayed_sendp(reply_pkt, self.server_tap)
self.tcp_receiver_seq += len(_TCPTester.TCP_NORMAL_RESPONSE)
# and continue with ACK

# Await ACK
Expand All @@ -88,12 +84,12 @@ def reply_tcp(self):


def get_client_packet(self):
pkt = sniff_packet(self.client_vm.tap, is_tcp_pkt)
pkt = sniff_packet(self.client_tap, is_tcp_pkt)
assert pkt[IP].src == self.server_ip, \
"Got answer from wrong server IP"
assert pkt[TCP].sport == self.server_port, \
"Got answer from wrong server TCP port"
assert pkt[IP].dst == self.client_vm.ip, \
assert pkt[IP].dst == self.client_ip, \
"Got answer back to wrong client VM IP"
assert pkt[TCP].dport == self.client_port, \
"Got answer back to wrong client VM TCP port"
Expand All @@ -105,12 +101,12 @@ def request_tcp(self, flags, payload=None):
server_thread = threading.Thread(target=self.reply_tcp)
server_thread.start()

tcp_pkt = (Ether(dst=PF0.mac, src=self.client_vm.mac, type=0x0800) /
IP(dst=self.server_ip, src=self.client_vm.ip) /
tcp_pkt = (Ether(dst=self.server_mac, src=self.client_mac, type=0x0800) /
IP(dst=self.server_ip, src=self.client_ip) /
TCP(dport=self.server_port, sport=self.client_port, seq=self.tcp_sender_seq, flags=flags, options=[("NOP", None)]))
if payload != None:
tcp_pkt /= Raw(payload)
delayed_sendp(tcp_pkt, self.client_vm.tap)
delayed_sendp(tcp_pkt, self.client_tap)

# No reaction to ACK expected
if flags == "A":
Expand All @@ -136,20 +132,20 @@ def request_tcp(self, flags, payload=None):
else:
pkt = self.get_client_packet()
if "R" in pkt[TCP].flags:
assert payload is not None and payload == TCPTester.TCP_RESET_REQUEST, \
assert payload is not None and payload == _TCPTester.TCP_RESET_REQUEST, \
"Unexpected connection reset"
self.reset()
return
else:
assert pkt[TCP].payload == Raw(TCPTester.TCP_NORMAL_RESPONSE), \
assert pkt[TCP].payload == Raw(_TCPTester.TCP_NORMAL_RESPONSE), \
"Bad answer from server"
reply_seq += len(payload)

# send ACK
tcp_pkt = (Ether(dst=PF0.mac, src=self.client_vm.mac, type=0x0800) /
IP(dst=self.server_ip, src=self.client_vm.ip) /
tcp_pkt = (Ether(dst=self.server_mac, src=self.client_mac, type=0x0800) /
IP(dst=self.server_ip, src=self.client_ip) /
TCP(dport=self.server_port, sport=self.client_port, flags="A", seq=self.tcp_sender_seq, ack=reply_seq))
delayed_sendp(tcp_pkt, self.client_vm.tap)
delayed_sendp(tcp_pkt, self.client_tap)

server_thread.join(timeout=1)
assert not server_thread.is_alive(), \
Expand All @@ -162,17 +158,50 @@ def communicate(self):
# 3-way handshake
self.request_tcp("S")
# data
self.request_tcp("", payload=TCPTester.TCP_NORMAL_REQUEST)
self.request_tcp("", payload=_TCPTester.TCP_NORMAL_REQUEST)
# close connection
self.request_tcp("F")

# Helper function to start, send data, and make the server send RST
def request_rst(self):
self.reset()
self.request_tcp("S")
self.request_tcp("", payload=TCPTester.TCP_RESET_REQUEST)
self.request_tcp("", payload=_TCPTester.TCP_RESET_REQUEST)

# Helper function to create a dangling connection
def leave_open(self):
self.reset()
self.request_tcp("S")


class TCPTesterLocal(_TCPTester):
    """TCP tester where both client and server are VMs; the server reply stays in IPv4."""

    def __init__(self, client_vm, client_port, server_vm, server_port, client_pkt_check=None, server_pkt_check=None):
        super().__init__(client_tap=client_vm.tap, client_mac=client_vm.mac, client_ip=client_vm.ip,
                         client_port=client_port,
                         server_tap=server_vm.tap, server_mac=server_vm.mac, server_ip=server_vm.ip,
                         server_port=server_port,
                         client_pkt_check=client_pkt_check, server_pkt_check=server_pkt_check)

    def get_server_l3_reply(self, pkt):
        """Build the Ethernet/IPv4 headers of the reply by swapping the addresses of `pkt`."""
        # VM-VM local communication, stay in IPv4
        ether = Ether(dst=pkt[Ether].src, src=pkt[Ether].dst, type=0x0800)
        ipv4 = IP(dst=pkt[IP].src, src=pkt[IP].dst)
        return ether / ipv4

class TCPTesterVirtsvc(_TCPTester):
    """TCP tester for virtual-service traffic; the server side uses the tap and MAC of `pf_spec`."""

    def __init__(self, client_vm, client_port, pf_spec, server_ip, server_port, client_pkt_check=None, server_pkt_check=None):
        super().__init__(client_tap=client_vm.tap, client_mac=client_vm.mac, client_ip=client_vm.ip,
                         client_port=client_port,
                         server_tap=pf_spec.tap, server_mac=pf_spec.mac, server_ip=server_ip,
                         server_port=server_port,
                         client_pkt_check=client_pkt_check, server_pkt_check=server_pkt_check)

    def get_server_l3_reply(self, pkt):
        """Build the Ethernet/IPv6 headers of the reply; the IPv6 destination is `router_ul_ipv6`."""
        # Virtual-service communication, no tunnel, replace header with IPv6
        ether = Ether(dst=pkt[Ether].src, src=pkt[Ether].dst, type=0x86DD)
        ipv6 = IPv6(dst=router_ul_ipv6, src=pkt[IPv6].dst, nh=6)
        return ether / ipv6

class TCPTesterPublic(_TCPTester):
    """TCP tester for traffic leaving through a PF; the server reply is IPv4 wrapped in IPv6."""

    def __init__(self, client_vm, client_port, nat_ul_ipv6, pf_spec, server_ip, server_port, client_pkt_check=None, server_pkt_check=None):
        super().__init__(client_tap=client_vm.tap, client_mac=client_vm.mac, client_ip=client_vm.ip,
                         client_port=client_port,
                         server_tap=pf_spec.tap, server_mac=pf_spec.mac, server_ip=server_ip,
                         server_port=server_port,
                         client_pkt_check=client_pkt_check, server_pkt_check=server_pkt_check)
        # Outer IPv6 destination of replies; kept as an attribute so it can be changed after construction
        self.nat_ul_ipv6 = nat_ul_ipv6

    def get_server_l3_reply(self, pkt):
        """Build the Ethernet/IPv6/IPv4 headers of the reply, addressed to `self.nat_ul_ipv6`."""
        # Underlay communication, use IP-IP tunnel
        ether = Ether(dst=pkt[Ether].src, src=pkt[Ether].dst, type=0x86DD)
        outer = IPv6(dst=self.nat_ul_ipv6, src=pkt[IPv6].dst, nh=4)
        inner = IP(dst=pkt[IP].src, src=pkt[IP].dst)
        return ether / outer / inner
18 changes: 6 additions & 12 deletions test/test_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import time

from helpers import *
from tcp_tester import TCPTester
from tcp_tester import TCPTesterLocal
from tcp_tester import TCPTesterPublic

nat_only_port = 1024

Expand All @@ -26,24 +27,21 @@ def test_nat_table_flush(prepare_ipv4, grpc_client):
# NAT with one port
nat_ul_ipv6 = grpc_client.addnat(VM1.name, nat_vip, nat_only_port, nat_only_port+1)

tester = TCPTester(client_vm=VM1, client_port=12345, client_ul_ipv6=nat_ul_ipv6, pf_name=PF0.tap,
server_ip=public_ip, server_port=443,
server_pkt_check=tcp_server_nat_pkt_check)
tester = TCPTesterPublic(VM1, 12345, nat_ul_ipv6, PF0, public_ip, 443, server_pkt_check=tcp_server_nat_pkt_check)
tester.communicate()

# Re-create the NAT with a different port range
grpc_client.delnat(VM1.name)
nat_only_port = 1025
nat_ul_ipv6 = grpc_client.addnat(VM1.name, nat_vip, nat_only_port, nat_only_port+1)
tester.client_ul_ipv6 = nat_ul_ipv6
tester.nat_ul_ipv6 = nat_ul_ipv6

# Keep the client port the same, this will cause an established flow to re-use the old NAT port
tester.communicate()

grpc_client.delnat(VM1.name)



def send_bounce_pkt_to_pf(ipv6_nat):
bouce_pkt = (Ether(dst=ipv6_multicast_mac, src=PF0.mac, type=0x86DD) /
IPv6(dst=ipv6_nat, src=router_ul_ipv6, nh=4) /
Expand All @@ -57,9 +55,7 @@ def test_neighnat_table_flush(prepare_ipv4, grpc_client):

global nat_only_port
nat_only_port = nat_local_min_port
tester = TCPTester(client_vm=VM1, client_port=12345, client_ul_ipv6=nat_ul_ipv6, pf_name=PF0.tap,
server_ip=public_ip, server_port=443,
server_pkt_check=tcp_server_nat_pkt_check)
tester = TCPTesterPublic(VM1, 12345, nat_ul_ipv6, PF0, public_ip, 443, server_pkt_check=tcp_server_nat_pkt_check)
tester.communicate()

grpc_client.addneighnat(nat_vip, vni1, nat_neigh_min_port, nat_neigh_max_port, neigh_vni1_ul_ipv6)
Expand Down Expand Up @@ -96,9 +92,7 @@ def test_neighnat_table_flush(prepare_ipv4, grpc_client):
f"Packet still being relayed!"

nat_only_port = nat_neigh_min_port
tester = TCPTester(client_vm=VM1, client_port=12345, client_ul_ipv6=nat_ul_ipv6, pf_name=PF0.tap,
server_ip=public_ip, server_port=443,
server_pkt_check=tcp_server_nat_pkt_check)
tester = TCPTesterPublic(VM1, 12345, nat_ul_ipv6, PF0, public_ip, 443, server_pkt_check=tcp_server_nat_pkt_check)
tester.communicate()

grpc_client.delnat(VM1.name)
10 changes: 3 additions & 7 deletions test/test_virtsvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest

from helpers import *
from tcp_tester import TCPTester
from tcp_tester import TCPTesterVirtsvc

udp_used_port = 0

Expand Down Expand Up @@ -63,14 +63,10 @@ def test_virtsvc_tcp(request, prepare_ipv4, port_redundancy):
if not request.config.getoption("--virtsvc"):
pytest.skip("Virtual services not enabled")

tester = TCPTester(client_vm=VM1, client_port=12345, client_ul_ipv6=router_ul_ipv6, pf_name=PF0.tap,
server_ip=virtsvc_tcp_virtual_ip, server_port=virtsvc_tcp_virtual_port,
server_pkt_check=tcp_server_virtsvc_pkt_check,
encaped=False)
tester = TCPTesterVirtsvc(VM1, 12345, PF0, virtsvc_tcp_virtual_ip, virtsvc_tcp_virtual_port, server_pkt_check=tcp_server_virtsvc_pkt_check)
tester.communicate()

# port number chosen so that they cause the right redirection
if port_redundancy:
tester.client_port = 54321
tester.pf_name = PF1.tap
tester = TCPTesterVirtsvc(VM1, 54321, PF1, virtsvc_tcp_virtual_ip, virtsvc_tcp_virtual_port, server_pkt_check=tcp_server_virtsvc_pkt_check)
tester.communicate()
Loading
Loading