diff --git a/net/mptcp/sockopt.c b/net/mptcp/sockopt.c index 1e74851614e8..ef140cf5f096 100644 --- a/net/mptcp/sockopt.c +++ b/net/mptcp/sockopt.c @@ -1498,6 +1498,10 @@ int mptcp_set_rcvlowat(struct sock *sk, int val) struct mptcp_subflow_context *subflow; int space, cap; + /* bpf can land here with a wrong sk type */ + if (sk->sk_protocol == IPPROTO_TCP) + return -EINVAL; + if (sk->sk_userlocks & SOCK_RCVBUF_LOCK) cap = sk->sk_rcvbuf >> 1; else diff --git a/tools/testing/selftests/bpf/network_helpers.c b/tools/testing/selftests/bpf/network_helpers.c index efc99056e8a6..c595bca4f97e 100644 --- a/tools/testing/selftests/bpf/network_helpers.c +++ b/tools/testing/selftests/bpf/network_helpers.c @@ -511,7 +511,6 @@ struct arg { int fd; unsigned total_bytes; int stop; - bool again; }; static void *server(void *arg) @@ -539,8 +538,6 @@ static void *server(void *arg) MIN(a->total_bytes - bytes, sizeof(batch)), 0); if (nr_sent == -1 && errno == EINTR) continue; - if (nr_sent == -1 && a->again && errno == EAGAIN) - continue; if (nr_sent == -1) { err = -errno; break; @@ -561,15 +558,13 @@ static void *server(void *arg) } void send_recv_data(int lfd, int fd, - unsigned total_bytes, - bool again) + unsigned total_bytes) { ssize_t nr_recv = 0, bytes = 0; pthread_t srv_thread; struct arg arg = { .fd = lfd, .total_bytes = total_bytes, - .again = again, }; void *thread_ret; char batch[1500]; @@ -587,8 +582,6 @@ void send_recv_data(int lfd, int fd, MIN(total_bytes - bytes, sizeof(batch)), 0); if (nr_recv == -1 && errno == EINTR) continue; - if (nr_recv == -1 && again && errno == EAGAIN) - continue; if (nr_recv == -1) break; bytes += nr_recv; diff --git a/tools/testing/selftests/bpf/network_helpers.h b/tools/testing/selftests/bpf/network_helpers.h index 992c9dce3386..ec83b25f1e46 100644 --- a/tools/testing/selftests/bpf/network_helpers.h +++ b/tools/testing/selftests/bpf/network_helpers.h @@ -73,8 +73,7 @@ struct nstoken *open_netns(const char *name); void close_netns(struct nstoken *token); int send_byte(int fd); void send_recv_data(int lfd, int fd, - unsigned total_bytes, - bool again); + unsigned total_bytes); static __u16 csum_fold(__u32 csum) { diff --git a/tools/testing/selftests/bpf/prog_tests/bpf_tcp_ca.c b/tools/testing/selftests/bpf/prog_tests/bpf_tcp_ca.c index f67bf931cf6a..162eee699975 100644 --- a/tools/testing/selftests/bpf/prog_tests/bpf_tcp_ca.c +++ b/tools/testing/selftests/bpf/prog_tests/bpf_tcp_ca.c @@ -72,7 +72,7 @@ static void do_test(const char *tcp_ca, const struct bpf_map *sk_stg_map) goto done; } - send_recv_data(lfd, fd, total_bytes, false); + send_recv_data(lfd, fd, total_bytes); done: close(lfd); diff --git a/tools/testing/selftests/bpf/prog_tests/mptcp.c b/tools/testing/selftests/bpf/prog_tests/mptcp.c index 12f9dd7acba6..a636d1201941 100644 --- a/tools/testing/selftests/bpf/prog_tests/mptcp.c +++ b/tools/testing/selftests/bpf/prog_tests/mptcp.c @@ -21,6 +21,7 @@ #define ADDR_1 "10.0.1.1" #define ADDR_2 "10.0.1.2" #define PORT_1 10001 +#define TIMEOUT_TEST 60 #ifndef IPPROTO_MPTCP #define IPPROTO_MPTCP 262 @@ -305,6 +306,15 @@ static int ss_search(char *src, char *keyword) return _ss_search(src, ADDR_1, "dport", keyword); } +static int set_nonblock(int fd) +{ + int flags = O_NONBLOCK; + + if (fcntl(fd, flags) < 0) + return -1; + return 0; +} + static void run_mptcp_subflow(int cgroup_fd, struct mptcp_subflow *skel) { int server_fd, client_fd, prog_fd, err; @@ -321,15 +331,18 @@ static void run_mptcp_subflow(int cgroup_fd, struct mptcp_subflow *skel) if (!ASSERT_OK(err, "prog_attach")) return; - server_fd = start_mptcp_server(AF_INET, ADDR_1, PORT_1, 0); + server_fd = start_mptcp_server(AF_INET, ADDR_1, PORT_1, TIMEOUT_TEST); if (!ASSERT_GE(server_fd, 0, "start_mptcp_server")) return; - client_fd = connect_to_fd(server_fd, 0); + client_fd = connect_to_fd(server_fd, TIMEOUT_TEST); if (!ASSERT_GE(client_fd, 0, "connect to fd")) goto close_server; - send_recv_data(server_fd, client_fd, total_bytes, true); + if (set_nonblock(server_fd)) + goto close_server; + + send_recv_data(server_fd, client_fd, total_bytes * 20); ASSERT_OK(ss_search(ADDR_1, "fwmark:0x1"), "ss_search fwmark:0x1"); ASSERT_OK(ss_search(ADDR_2, "fwmark:0x2"), "ss_search fwmark:0x2"); @@ -407,20 +420,23 @@ static void send_data_and_verify(char *msg, int addr1, int addr2) int server_fd, client_fd; unsigned int delta_ms; - server_fd = start_mptcp_server(AF_INET, ADDR_1, PORT_1, 0); + server_fd = start_mptcp_server(AF_INET, ADDR_1, PORT_1, TIMEOUT_TEST); if (!ASSERT_NEQ(server_fd, -1, "start_mptcp_server")) return; - client_fd = connect_to_fd(server_fd, 0); + client_fd = connect_to_fd(server_fd, TIMEOUT_TEST); if (!ASSERT_NEQ(client_fd, -1, "connect_to_fd")) { close(server_fd); return; } if (clock_gettime(CLOCK_MONOTONIC, &start) < 0) - return; + goto close_server; - send_recv_data(server_fd, client_fd, total_bytes, true); + if (set_nonblock(server_fd)) + goto close_server; + + send_recv_data(server_fd, client_fd, total_bytes * 20); if (clock_gettime(CLOCK_MONOTONIC, &end) < 0) return; @@ -438,6 +454,7 @@ static void send_data_and_verify(char *msg, int addr1, int addr2) ASSERT_GT(has_bytes_sent(ADDR_2), 0, "Shouldn't have bytes_sent on addr2"); close(client_fd); +close_server: close(server_fd); }