diff --git a/net/mptcp/sockopt.c b/net/mptcp/sockopt.c index 2026a9a36f80..2b648a2fc83c 100644 --- a/net/mptcp/sockopt.c +++ b/net/mptcp/sockopt.c @@ -567,6 +567,7 @@ static bool mptcp_supported_sockopt(int level, int optname) case TCP_FASTOPEN_CONNECT: case TCP_FASTOPEN_KEY: case TCP_FASTOPEN_NO_COOKIE: + case TCP_ULP: return true; } @@ -579,6 +580,42 @@ static bool mptcp_supported_sockopt(int level, int optname) return false; } +static int mptcp_setsockopt_sol_tcp_ulp(struct mptcp_sock *msk, sockptr_t optval, + unsigned int optlen) +{ + struct mptcp_subflow_context *subflow; + struct sock *sk = (struct sock *)msk; + char name[TCP_CA_NAME_MAX]; + int ret; + + if (optlen < 1) + return -EINVAL; + + ret = strncpy_from_sockptr(name, optval, + min_t(long, TCP_ULP_NAME_MAX - 1, optlen)); + if (ret < 0) + return -EFAULT; + name[ret] = 0; + + ret = 0; + lock_sock(sk); + sockopt_seq_inc(msk); + mptcp_for_each_subflow(msk, subflow) { + struct sock *ssk = mptcp_subflow_tcp_sock(subflow); + int err; + + lock_sock(ssk); + err = tcp_set_ulp(ssk, name); + if (err < 0 && ret == 0) + ret = err; + subflow->setsockopt_seq = msk->setsockopt_seq; + release_sock(ssk); + } + + release_sock(sk); + return ret; +} + static int mptcp_setsockopt_sol_tcp_congestion(struct mptcp_sock *msk, sockptr_t optval, unsigned int optlen) { @@ -806,7 +843,7 @@ static int mptcp_setsockopt_sol_tcp(struct mptcp_sock *msk, int optname, switch (optname) { case TCP_ULP: - return -EOPNOTSUPP; + return mptcp_setsockopt_sol_tcp_ulp(msk, optval, optlen); case TCP_CONGESTION: return mptcp_setsockopt_sol_tcp_congestion(msk, optval, optlen); case TCP_DEFER_ACCEPT: diff --git a/tools/testing/selftests/bpf/network_helpers.c b/tools/testing/selftests/bpf/network_helpers.c index 764c41043181..019b9d4e2c7f 100644 --- a/tools/testing/selftests/bpf/network_helpers.c +++ b/tools/testing/selftests/bpf/network_helpers.c @@ -616,6 +616,7 @@ struct send_recv_arg { int fd; uint32_t bytes; int stop; + int (*cb)(int fd); }; static void *send_recv_server(void *arg) @@ -638,6 +639,11 @@ static void *send_recv_server(void *arg) goto done; } + if (a->cb && a->cb(fd)) { + err = -errno; + goto done; + } + while (bytes < a->bytes && !READ_ONCE(a->stop)) { nr_sent = send(fd, &batch, MIN(a->bytes - bytes, sizeof(batch)), 0); @@ -666,13 +672,15 @@ static void *send_recv_server(void *arg) return NULL; } -int send_recv_data(int lfd, int fd, uint32_t total_bytes) +int send_recv_data(int lfd, int fd, uint32_t total_bytes, + int (*post_accept_cb)(int fd)) { ssize_t nr_recv = 0, bytes = 0; struct send_recv_arg arg = { .fd = lfd, .bytes = total_bytes, .stop = 0, + .cb = post_accept_cb, }; pthread_t srv_thread; void *thread_ret; diff --git a/tools/testing/selftests/bpf/network_helpers.h b/tools/testing/selftests/bpf/network_helpers.h index 622ca0041609..0ee54d337cad 100644 --- a/tools/testing/selftests/bpf/network_helpers.h +++ b/tools/testing/selftests/bpf/network_helpers.h @@ -87,7 +87,8 @@ struct nstoken *open_netns(const char *name); void close_netns(struct nstoken *token); struct nstoken *create_netns(const char *name); void cleanup_netns(struct nstoken *token); -int send_recv_data(int lfd, int fd, uint32_t total_bytes); +int send_recv_data(int lfd, int fd, uint32_t total_bytes, + int (*post_accept_cb)(int fd)); int unshare_netns(void); static __u16 csum_fold(__u32 csum) diff --git a/tools/testing/selftests/bpf/prog_tests/bpf_tcp_ca.c b/tools/testing/selftests/bpf/prog_tests/bpf_tcp_ca.c index 499e80bf673b..8d04e15c153a 100644 --- a/tools/testing/selftests/bpf/prog_tests/bpf_tcp_ca.c +++ b/tools/testing/selftests/bpf/prog_tests/bpf_tcp_ca.c @@ -50,7 +50,7 @@ static void do_test(const struct network_helper_opts *opts) if (!ASSERT_NEQ(fd, -1, "connect_to_fd_opts")) goto done; - ASSERT_OK(send_recv_data(lfd, fd, total_bytes), "send_recv_data"); + ASSERT_OK(send_recv_data(lfd, fd, total_bytes, NULL), "send_recv_data"); done: close(lfd); diff --git a/tools/testing/selftests/bpf/prog_tests/mptcp.c b/tools/testing/selftests/bpf/prog_tests/mptcp.c index 7655d342d886..e8a0d9169dca 100644 --- a/tools/testing/selftests/bpf/prog_tests/mptcp.c +++ b/tools/testing/selftests/bpf/prog_tests/mptcp.c @@ -4,6 +4,7 @@ #include #include +#include #include #include "cgroup_helpers.h" #include "network_helpers.h" @@ -41,6 +42,10 @@ #define MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED _BITUL(1) #endif +#ifndef TCP_ULP +#define TCP_ULP 31 +#endif + #ifndef TCP_CA_NAME_MAX #define TCP_CA_NAME_MAX 16 #endif @@ -436,6 +441,100 @@ static void test_subflow(void) close(cgroup_fd); } +static int sockmap_init_ktls(int fd) +{ + struct tls12_crypto_info_aes_gcm_128 tls_tx = { + .info = { + .version = TLS_1_2_VERSION, + .cipher_type = TLS_CIPHER_AES_GCM_128, + }, + }; + struct tls12_crypto_info_aes_gcm_128 tls_rx = { + .info = { + .version = TLS_1_2_VERSION, + .cipher_type = TLS_CIPHER_AES_GCM_128, + }, + }; + int so_buf = 6553500; + int err; + + err = setsockopt(fd, SOL_TCP, TCP_ULP, "tls", sizeof("tls")); + if (!ASSERT_OK(err, "setsockopt TCP_ULP")) + return err; + err = setsockopt(fd, SOL_TLS, TLS_TX, (void *)&tls_tx, sizeof(tls_tx)); + if (!ASSERT_OK(err, "setsockopt TLS_TX")) + return err; + err = setsockopt(fd, SOL_TLS, TLS_RX, (void *)&tls_rx, sizeof(tls_rx)); + if (!ASSERT_OK(err, "setsockopt TLS_RX")) + return err; + err = setsockopt(fd, SOL_SOCKET, SO_SNDBUF, &so_buf, sizeof(so_buf)); + if (!ASSERT_OK(err, "setsockopt SO_SNDBUF")) + return err; + err = setsockopt(fd, SOL_SOCKET, SO_RCVBUF, &so_buf, sizeof(so_buf)); + if (!ASSERT_OK(err, "setsockopt SO_RCVBUF")) + return err; + + return 0; +} + +static int has_bytes_sent(char *dst) +{ + return _ss_search(ADDR_1, dst, "sport", "bytes_sent:"); +} + +static void run_ktls(void) +{ + int server_fd, client_fd; + + server_fd = start_server(AF_INET, SOCK_STREAM, ADDR_1, PORT_1, 0); + //server_fd = start_mptcp_server(AF_INET, ADDR_1, PORT_1, 0); + if (!ASSERT_GE(server_fd, 0, "start_mptcp_server")) + return; + + client_fd = connect_to_fd(server_fd, 0); + if (!ASSERT_GE(client_fd, 0, "connect to fd")) + goto fail; + + if (!ASSERT_OK(sockmap_init_ktls(client_fd), "init_ktls client_fd")) + goto fail; + + if (!ASSERT_OK(send_recv_data(server_fd, client_fd, + total_bytes, sockmap_init_ktls), + "send_recv_data")) + goto fail; + + CHECK(has_bytes_sent(ADDR_1), "ktls", "should have bytes_sent on addr1\n"); + //CHECK(has_bytes_sent(ADDR_2), "ktls", "should have bytes_sent on addr2\n"); + + close(client_fd); +fail: + close(server_fd); +} + +static void test_ktls(void) +{ + struct nstoken *nstoken; + int cgroup_fd; + + cgroup_fd = test__join_cgroup("/mptcp_ktls"); + if (!ASSERT_GE(cgroup_fd, 0, "join_cgroup: mptcp_ktls")) + return; + + nstoken = create_netns(NS_TEST); + if (!ASSERT_OK_PTR(nstoken, "create_netns: mptcp_ktls")) + goto close_cgroup; + + if (!ASSERT_OK(endpoint_init("subflow"), "endpoint_init")) + goto close_netns; + + run_ktls(); + +close_netns: + cleanup_netns(nstoken); +close_cgroup: + close(cgroup_fd); +} + static struct nstoken *sched_init(char *flags, char *sched) { struct nstoken *nstoken; @@ -455,11 +554,6 @@ static struct nstoken *sched_init(char *flags, char *sched) return NULL; } -static int has_bytes_sent(char *dst) -{ - return _ss_search(ADDR_1, dst, "sport", "bytes_sent:"); -} - static void send_data_and_verify(char *sched, bool addr1, bool addr2) { struct timespec start, end; @@ -477,7 +571,7 @@ static void send_data_and_verify(char *sched, bool addr1, bool addr2) if (clock_gettime(CLOCK_MONOTONIC, &start) < 0) goto fail; - if (!ASSERT_OK(send_recv_data(server_fd, client_fd, total_bytes), + if (!ASSERT_OK(send_recv_data(server_fd, client_fd, total_bytes, NULL), "send_recv_data")) goto fail; @@ -623,6 +717,8 @@ void test_mptcp(void) test_mptcpify(); if (test__start_subtest("subflow")) test_subflow(); + if (test__start_subtest("ktls")) + test_ktls(); if (test__start_subtest("default")) test_default(); if (test__start_subtest("first"))