diff --git a/include/net/espintcp.h b/include/net/espintcp.h index c70efd704b6d5..4689ed3099331 100644 --- a/include/net/espintcp.h +++ b/include/net/espintcp.h @@ -35,6 +35,6 @@ static inline struct espintcp_ctx *espintcp_getctx(const struct sock *sk) const struct inet_connection_sock *icsk = inet_csk(sk); /* RCU is only needed for diag */ - return (__force void *)icsk->icsk_ulp_data; + return (__force void *)icsk->icsk_ulp_data[ULP_INDEX_DEFAULT]; } #endif diff --git a/include/net/inet_connection_sock.h b/include/net/inet_connection_sock.h index 7d6b1254c92d5..dcfb937b46ba7 100644 --- a/include/net/inet_connection_sock.h +++ b/include/net/inet_connection_sock.h @@ -28,6 +28,12 @@ struct inet_bind_bucket; struct inet_bind2_bucket; struct tcp_congestion_ops; +enum ulp_index { + ULP_INDEX_DEFAULT, + ULP_INDEX_MPTCP, + ULP_INDEX_MAX, +}; + /* * Pointers to address related TCP functions * (i.e. things that depend on the address family) @@ -94,8 +100,8 @@ struct inet_connection_sock { __u32 icsk_pmtu_cookie; const struct tcp_congestion_ops *icsk_ca_ops; const struct inet_connection_sock_af_ops *icsk_af_ops; - const struct tcp_ulp_ops *icsk_ulp_ops; - void __rcu *icsk_ulp_data; + const struct tcp_ulp_ops *icsk_ulp_ops[ULP_INDEX_MAX]; + void __rcu *icsk_ulp_data[ULP_INDEX_MAX]; void (*icsk_clean_acked)(struct sock *sk, u32 acked_seq); unsigned int (*icsk_sync_mss)(struct sock *sk, u32 pmtu); __u8 icsk_ca_state:5, @@ -352,7 +358,8 @@ static inline void inet_csk_inc_pingpong_cnt(struct sock *sk) static inline bool inet_csk_has_ulp(const struct sock *sk) { - return inet_test_bit(IS_ICSK, sk) && !!inet_csk(sk)->icsk_ulp_ops; + return inet_test_bit(IS_ICSK, sk) && (!!inet_csk(sk)->icsk_ulp_ops[ULP_INDEX_DEFAULT] || + !!inet_csk(sk)->icsk_ulp_ops[ULP_INDEX_MPTCP]); } static inline void inet_init_csk_locks(struct sock *sk) diff --git a/include/net/tcp.h b/include/net/tcp.h index 32815a40dea16..353ebf207914b 100644 --- a/include/net/tcp.h +++ b/include/net/tcp.h @@ -2543,6 +2543,7 @@ enum hrtimer_restart tcp_pace_kick(struct hrtimer *timer); struct tcp_ulp_ops { struct list_head list; + int id; /* initialize ulp */ int (*init)(struct sock *sk); diff --git a/include/net/tls.h b/include/net/tls.h index 3a33924db2bc7..d5bb7a9288f74 100644 --- a/include/net/tls.h +++ b/include/net/tls.h @@ -367,7 +367,7 @@ static inline struct tls_context *tls_get_ctx(const struct sock *sk) /* Use RCU on icsk_ulp_data only for sock diag code, * TLS data path doesn't need rcu_dereference(). */ - return (__force void *)icsk->icsk_ulp_data; + return (__force void *)icsk->icsk_ulp_data[ULP_INDEX_DEFAULT]; } static inline struct tls_sw_context_rx *tls_sw_ctx_rx( diff --git a/net/ipv4/inet_connection_sock.c b/net/ipv4/inet_connection_sock.c index d81f74ce0f02e..08a0ad03d7361 100644 --- a/net/ipv4/inet_connection_sock.c +++ b/net/ipv4/inet_connection_sock.c @@ -1148,11 +1148,14 @@ static void inet_clone_ulp(const struct request_sock *req, struct sock *newsk, const gfp_t priority) { struct inet_connection_sock *icsk = inet_csk(newsk); + const struct tcp_ulp_ops *ulp_ops; + int i; - if (!icsk->icsk_ulp_ops) - return; - - icsk->icsk_ulp_ops->clone(req, newsk, priority); + for (i = 0; i < ULP_INDEX_MAX; i++) { + ulp_ops = icsk->icsk_ulp_ops[i]; + if (ulp_ops && ulp_ops->clone) + ulp_ops->clone(req, newsk, priority); + } } /** @@ -1251,9 +1254,14 @@ EXPORT_SYMBOL(inet_csk_prepare_forced_close); static int inet_ulp_can_listen(const struct sock *sk) { const struct inet_connection_sock *icsk = inet_csk(sk); + const struct tcp_ulp_ops *ulp_ops; + int i; - if (icsk->icsk_ulp_ops && !icsk->icsk_ulp_ops->clone) - return -EINVAL; + for (i = 0; i < ULP_INDEX_MAX; i++) { + ulp_ops = icsk->icsk_ulp_ops[i]; + if (ulp_ops && !ulp_ops->clone) + return -EINVAL; + } return 0; } diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c index c1cac83c0d225..4a2174decd2cb 100644 --- a/net/ipv4/tcp.c +++ b/net/ipv4/tcp.c @@ -4133,7 +4133,7 @@ int do_tcp_getsockopt(struct sock *sk, int level, if (copy_from_sockptr(&len, optlen, sizeof(int))) return -EFAULT; len = min_t(unsigned int, len, TCP_ULP_NAME_MAX); - if (!icsk->icsk_ulp_ops) { + if (!icsk->icsk_ulp_ops[ULP_INDEX_DEFAULT]) { len = 0; if (copy_to_sockptr(optlen, &len, sizeof(int))) return -EFAULT; @@ -4141,7 +4141,7 @@ int do_tcp_getsockopt(struct sock *sk, int level, } if (copy_to_sockptr(optlen, &len, sizeof(int))) return -EFAULT; - if (copy_to_sockptr(optval, icsk->icsk_ulp_ops->name, len)) + if (copy_to_sockptr(optval, icsk->icsk_ulp_ops[ULP_INDEX_DEFAULT]->name, len)) return -EFAULT; return 0; diff --git a/net/ipv4/tcp_diag.c b/net/ipv4/tcp_diag.c index f428ecf9120f2..a011d0a11eb39 100644 --- a/net/ipv4/tcp_diag.c +++ b/net/ipv4/tcp_diag.c @@ -132,7 +132,7 @@ static int tcp_diag_get_aux(struct sock *sk, bool net_admin, if (net_admin) { const struct tcp_ulp_ops *ulp_ops; - ulp_ops = icsk->icsk_ulp_ops; + ulp_ops = icsk->icsk_ulp_ops[ULP_INDEX_DEFAULT]; if (ulp_ops) err = tcp_diag_put_ulp(skb, sk, ulp_ops); if (err) @@ -167,7 +167,7 @@ static size_t tcp_diag_get_aux_size(struct sock *sk, bool net_admin) if (net_admin && sk_fullsock(sk)) { const struct tcp_ulp_ops *ulp_ops; - ulp_ops = icsk->icsk_ulp_ops; + ulp_ops = icsk->icsk_ulp_ops[ULP_INDEX_DEFAULT]; if (ulp_ops) { size += nla_total_size(0) + nla_total_size(TCP_ULP_NAME_MAX); diff --git a/net/ipv4/tcp_ulp.c b/net/ipv4/tcp_ulp.c index 2aa442128630e..e01fe5cab7691 100644 --- a/net/ipv4/tcp_ulp.c +++ b/net/ipv4/tcp_ulp.c @@ -104,27 +104,36 @@ void tcp_update_ulp(struct sock *sk, struct proto *proto, void (*write_space)(struct sock *sk)) { struct inet_connection_sock *icsk = inet_csk(sk); + const struct tcp_ulp_ops *ulp_ops; + int i; - if (icsk->icsk_ulp_ops->update) - icsk->icsk_ulp_ops->update(sk, proto, write_space); + for (i = 0; i < ULP_INDEX_MAX; i++) { + ulp_ops = icsk->icsk_ulp_ops[i]; + if (ulp_ops->update) + ulp_ops->update(sk, proto, write_space); + } } void tcp_cleanup_ulp(struct sock *sk) { struct inet_connection_sock *icsk = inet_csk(sk); + const struct tcp_ulp_ops *ulp_ops; + int i; /* No sock_owned_by_me() check here as at the time the * stack calls this function, the socket is dead and * about to be destroyed. */ - if (!icsk->icsk_ulp_ops) - return; - - if (icsk->icsk_ulp_ops->release) - icsk->icsk_ulp_ops->release(sk); - module_put(icsk->icsk_ulp_ops->owner); - - icsk->icsk_ulp_ops = NULL; + for (i = 0; i < ULP_INDEX_MAX; i++) { + ulp_ops = icsk->icsk_ulp_ops[i]; + if (ulp_ops) { + //pr_info("%s ulp_ops->name=%s\n", __func__, ulp_ops->name); + if (ulp_ops->release) + ulp_ops->release(sk); + module_put(ulp_ops->owner); + ulp_ops = NULL; + } + } } static int __tcp_set_ulp(struct sock *sk, const struct tcp_ulp_ops *ulp_ops) @@ -133,7 +142,7 @@ static int __tcp_set_ulp(struct sock *sk, const struct tcp_ulp_ops *ulp_ops) int err; err = -EEXIST; - if (icsk->icsk_ulp_ops) + if (icsk->icsk_ulp_ops[ulp_ops->id]) goto out_err; if (sk->sk_socket) @@ -147,7 +156,7 @@ static int __tcp_set_ulp(struct sock *sk, const struct tcp_ulp_ops *ulp_ops) if (err) goto out_err; - icsk->icsk_ulp_ops = ulp_ops; + icsk->icsk_ulp_ops[ulp_ops->id] = ulp_ops; return 0; out_err: module_put(ulp_ops->owner); diff --git a/net/mptcp/diag.c b/net/mptcp/diag.c index 3ae46b545d2c2..1a4a332358fd2 100644 --- a/net/mptcp/diag.c +++ b/net/mptcp/diag.c @@ -29,7 +29,7 @@ static int subflow_get_info(struct sock *sk, struct sk_buff *skb) slow = lock_sock_fast(sk); rcu_read_lock(); - sf = rcu_dereference(inet_csk(sk)->icsk_ulp_data); + sf = rcu_dereference(inet_csk(sk)->icsk_ulp_data[ULP_INDEX_MPTCP]); if (!sf) { err = 0; goto nla_failure; diff --git a/net/mptcp/mptcp_diag.c b/net/mptcp/mptcp_diag.c index 0566dd793810a..d050ff4a5f658 100644 --- a/net/mptcp/mptcp_diag.c +++ b/net/mptcp/mptcp_diag.c @@ -103,7 +103,7 @@ static void mptcp_diag_dump_listeners(struct sk_buff *skb, struct netlink_callba if (num < diag_ctx->l_num) goto next_listen; - if (!ctx || strcmp(inet_csk(sk)->icsk_ulp_ops->name, "mptcp")) + if (!ctx || strcmp(inet_csk(sk)->icsk_ulp_ops[ULP_INDEX_MPTCP]->name, "mptcp")) goto next_listen; sk = ctx->conn; diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c index fcef499c6d57a..7f126dc2ebe29 100644 --- a/net/mptcp/protocol.c +++ b/net/mptcp/protocol.c @@ -2462,7 +2462,7 @@ static void __mptcp_close_ssk(struct sock *sk, struct sock *ssk, * the ssk has been already destroyed, we just need to release the * reference owned by msk; */ - if (!inet_csk(ssk)->icsk_ulp_ops) { + if (!inet_csk(ssk)->icsk_ulp_ops[ULP_INDEX_MPTCP]) { WARN_ON_ONCE(!sock_flag(ssk, SOCK_DEAD)); kfree_rcu(subflow, rcu); } else { diff --git a/net/mptcp/protocol.h b/net/mptcp/protocol.h index e24d4f3055df7..37120232b4713 100644 --- a/net/mptcp/protocol.h +++ b/net/mptcp/protocol.h @@ -14,6 +14,7 @@ #include #include #include +#include #define MPTCP_SUPPORTED_VERSION 1 @@ -499,6 +500,7 @@ DECLARE_PER_CPU(struct mptcp_delegated_action, mptcp_delegated_actions); struct mptcp_subflow_context { struct list_head node;/* conn_list of subflows */ + struct tls_context tls; struct_group(reset, unsigned long avg_pacing_rate; /* protected by msk socket lock */ @@ -584,7 +586,7 @@ mptcp_subflow_ctx(const struct sock *sk) const struct inet_connection_sock *icsk = inet_csk(sk); /* Use RCU on icsk_ulp_data only for sock diag code */ - return (__force struct mptcp_subflow_context *)icsk->icsk_ulp_data; + return (__force struct mptcp_subflow_context *)icsk->icsk_ulp_data[ULP_INDEX_MPTCP]; } static inline struct sock * diff --git a/net/mptcp/sockopt.c b/net/mptcp/sockopt.c index 2026a9a36f804..7f5fe2fb0a5d4 100644 --- a/net/mptcp/sockopt.c +++ b/net/mptcp/sockopt.c @@ -567,6 +567,7 @@ static bool mptcp_supported_sockopt(int level, int optname) case TCP_FASTOPEN_CONNECT: case TCP_FASTOPEN_KEY: case TCP_FASTOPEN_NO_COOKIE: + case TCP_ULP: return true; } @@ -576,9 +577,54 @@ static bool mptcp_supported_sockopt(int level, int optname) * TCP_REPAIR_WINDOW are not supported, better avoid this mess */ } + if (level == SOL_TLS) { + switch (optname) { + case TLS_TX: + case TLS_RX: + return true; + } + } return false; } +static int mptcp_setsockopt_sol_tcp_ulp(struct mptcp_sock *msk, sockptr_t optval, + unsigned int optlen) +{ + struct mptcp_subflow_context *subflow; + struct sock *sk = (struct sock *)msk; + char name[TCP_ULP_NAME_MAX]; + int ret; + + if (optlen < 1) + return -EINVAL; + + ret = strncpy_from_sockptr(name, optval, + min_t(long, TCP_ULP_NAME_MAX - 1, optlen)); + if (ret < 0) + return -EFAULT; + name[ret] = 0; + + ret = 0; + lock_sock(sk); + sockopt_seq_inc(msk); + mptcp_for_each_subflow(msk, subflow) { + struct sock *ssk = mptcp_subflow_tcp_sock(subflow); + int err; + + lock_sock(ssk); + err = tcp_set_ulp(ssk, name); + if (err < 0 && ret == 0) { + pr_info("%s err=%d\n", __func__, err); + ret = err; + } + subflow->setsockopt_seq = msk->setsockopt_seq; + release_sock(ssk); + } + + release_sock(sk); + return ret; +} + static int mptcp_setsockopt_sol_tcp_congestion(struct mptcp_sock *msk, sockptr_t optval, unsigned int optlen) { @@ -806,7 +852,7 @@ static int mptcp_setsockopt_sol_tcp(struct mptcp_sock *msk, int optname, switch (optname) { case TCP_ULP: - return -EOPNOTSUPP; + return mptcp_setsockopt_sol_tcp_ulp(msk, optval, optlen); case TCP_CONGESTION: return mptcp_setsockopt_sol_tcp_congestion(msk, optval, optlen); case TCP_DEFER_ACCEPT: @@ -867,6 +913,8 @@ static int mptcp_setsockopt_sol_tcp(struct mptcp_sock *msk, int optname, return ret; } +int tls_setsockopt(struct sock *sk, int level, int optname, + sockptr_t optval, unsigned int optlen); int mptcp_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval, unsigned int optlen) { @@ -902,6 +950,9 @@ int mptcp_setsockopt(struct sock *sk, int level, int optname, if (level == SOL_TCP) return mptcp_setsockopt_sol_tcp(msk, optname, optval, optlen); + if (level == SOL_TLS) + return tls_setsockopt(msk->first, level, optname, optval, optlen); + return -EOPNOTSUPP; } diff --git a/net/mptcp/subflow.c b/net/mptcp/subflow.c index 39e2cbdf38019..222647829a959 100644 --- a/net/mptcp/subflow.c +++ b/net/mptcp/subflow.c @@ -747,8 +747,8 @@ static void subflow_ulp_fallback(struct sock *sk, struct inet_connection_sock *icsk = inet_csk(sk); mptcp_subflow_tcp_fallback(sk, old_ctx); - icsk->icsk_ulp_ops = NULL; - rcu_assign_pointer(icsk->icsk_ulp_data, NULL); + icsk->icsk_ulp_ops[ULP_INDEX_MPTCP] = NULL; + rcu_assign_pointer(icsk->icsk_ulp_data[ULP_INDEX_MPTCP], NULL); tcp_sk(sk)->is_mptcp = 0; mptcp_subflow_ops_undo_override(sk); @@ -762,7 +762,7 @@ void mptcp_subflow_drop_ctx(struct sock *ssk) return; list_del(&mptcp_subflow_ctx(ssk)->node); - if (inet_csk(ssk)->icsk_ulp_ops) { + if (inet_csk(ssk)->icsk_ulp_ops[ULP_INDEX_MPTCP]) { subflow_ulp_fallback(ssk, ctx); if (ctx->conn) sock_put(ctx->conn); @@ -1755,7 +1755,7 @@ static struct mptcp_subflow_context *subflow_create_ctx(struct sock *sk, if (!ctx) return NULL; - rcu_assign_pointer(icsk->icsk_ulp_data, ctx); + rcu_assign_pointer(icsk->icsk_ulp_data[ULP_INDEX_MPTCP], ctx); INIT_LIST_HEAD(&ctx->node); INIT_LIST_HEAD(&ctx->delegated_node); @@ -2044,6 +2044,7 @@ static int tcp_abort_override(struct sock *ssk, int err) } static struct tcp_ulp_ops subflow_ulp_ops __read_mostly = { + .id = ULP_INDEX_MPTCP, .name = "mptcp", .owner = THIS_MODULE, .init = subflow_ulp_init, diff --git a/net/mptcp/token_test.c b/net/mptcp/token_test.c index 4fc39fa2e262d..f439fd01d9a73 100644 --- a/net/mptcp/token_test.c +++ b/net/mptcp/token_test.c @@ -76,7 +76,7 @@ static void mptcp_token_test_msk_basic(struct kunit *test) struct mptcp_sock *null_msk = NULL; struct sock *sk; - rcu_assign_pointer(icsk->icsk_ulp_data, ctx); + rcu_assign_pointer(icsk->icsk_ulp_data[ULP_INDEX_MPTCP], ctx); ctx->conn = (struct sock *)msk; sk = (struct sock *)msk; diff --git a/net/smc/af_smc.c b/net/smc/af_smc.c index e50a286fd0fb7..69c1cbf358937 100644 --- a/net/smc/af_smc.c +++ b/net/smc/af_smc.c @@ -3439,7 +3439,7 @@ static void smc_ulp_clone(const struct request_sock *req, struct sock *newsk, struct inet_connection_sock *icsk = inet_csk(newsk); /* don't inherit ulp ops to child when listen */ - icsk->icsk_ulp_ops = NULL; + icsk->icsk_ulp_ops[ULP_INDEX_DEFAULT] = NULL; } static struct tcp_ulp_ops smc_ulp_ops __read_mostly = { diff --git a/net/tls/tls.h b/net/tls/tls.h index e5e47452308ab..30286d186e876 100644 --- a/net/tls/tls.h +++ b/net/tls/tls.h @@ -166,6 +166,8 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, int tls_sw_read_sock(struct sock *sk, read_descriptor_t *desc, sk_read_actor_t read_actor); +int tls_setsockopt(struct sock *sk, int level, int optname, + sockptr_t optval, unsigned int optlen); int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size); void tls_device_splice_eof(struct socket *sock); int tls_tx_records(struct sock *sk, int flags); diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index 90b7f253d3632..43ae0df65b2b0 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -383,7 +383,7 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) write_lock_bh(&sk->sk_callback_lock); if (free_ctx) - rcu_assign_pointer(icsk->icsk_ulp_data, NULL); + rcu_assign_pointer(icsk->icsk_ulp_data[ULP_INDEX_DEFAULT], NULL); WRITE_ONCE(sk->sk_prot, ctx->sk_proto); if (sk->sk_write_space == tls_write_space) sk->sk_write_space = ctx->sk_write_space; @@ -794,8 +794,8 @@ static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval, return rc; } -static int tls_setsockopt(struct sock *sk, int level, int optname, - sockptr_t optval, unsigned int optlen) +int tls_setsockopt(struct sock *sk, int level, int optname, + sockptr_t optval, unsigned int optlen) { struct tls_context *ctx = tls_get_ctx(sk); @@ -826,7 +826,7 @@ struct tls_context *tls_ctx_create(struct sock *sk) * address dependency between sk->sk_proto->{getsockopt,setsockopt} * and ctx->sk_proto. */ - rcu_assign_pointer(icsk->icsk_ulp_data, ctx); + rcu_assign_pointer(icsk->icsk_ulp_data[ULP_INDEX_DEFAULT], ctx); return ctx; } @@ -1023,7 +1023,7 @@ static int tls_get_info(struct sock *sk, struct sk_buff *skb) return -EMSGSIZE; rcu_read_lock(); - ctx = rcu_dereference(inet_csk(sk)->icsk_ulp_data); + ctx = rcu_dereference(inet_csk(sk)->icsk_ulp_data[ULP_INDEX_DEFAULT]); if (!ctx) { err = 0; goto nla_failure; @@ -1115,6 +1115,7 @@ static struct pernet_operations tls_proc_ops = { }; static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = { + .id = ULP_INDEX_DEFAULT, .name = "tls", .owner = THIS_MODULE, .init = tls_init, diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index 870314f40b5f0..a6d62a520c664 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -1362,7 +1362,7 @@ tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock, tls_strp_msg_load(&ctx->strp, released); - pr_info("%s return 1\n", __func__); + //pr_info("%s return 1\n", __func__); return 1; } @@ -1978,7 +1978,7 @@ int tls_sw_recvmsg(struct sock *sk, bool bpf_strp_enabled; bool zc_capable; - pr_info("%s tls_sw_sock_is_readable(sk)=%u\n", __func__, tls_sw_sock_is_readable(sk)); + //pr_info("%s tls_sw_sock_is_readable(sk)=%u\n", __func__, tls_sw_sock_is_readable(sk)); if (unlikely(flags & MSG_ERRQUEUE)) return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR); @@ -2361,7 +2361,7 @@ bool tls_sw_sock_is_readable(struct sock *sk) ret = !ingress_empty || tls_strp_msg_ready(ctx) || !skb_queue_empty(&ctx->rx_list); - pr_info("%s ret=%u ingress_empty=%u\n", __func__, ret, ingress_empty); + //pr_info("%s ret=%u ingress_empty=%u\n", __func__, ret, ingress_empty); return ret; } diff --git a/net/tls/tls_toe.c b/net/tls/tls_toe.c index 825669e1ab479..f98c5f85932e2 100644 --- a/net/tls/tls_toe.c +++ b/net/tls/tls_toe.c @@ -50,7 +50,7 @@ static void tls_toe_sk_destruct(struct sock *sk) ctx->sk_destruct(sk); /* Free ctx */ - rcu_assign_pointer(icsk->icsk_ulp_data, NULL); + rcu_assign_pointer(icsk->icsk_ulp_data[ULP_INDEX_DEFAULT], NULL); tls_ctx_free(sk, ctx); } diff --git a/net/xfrm/espintcp.c b/net/xfrm/espintcp.c index fe82e2d073006..34d73987a0085 100644 --- a/net/xfrm/espintcp.c +++ b/net/xfrm/espintcp.c @@ -495,7 +495,7 @@ static int espintcp_init_sk(struct sock *sk) sk->sk_data_ready = espintcp_data_ready; sk->sk_write_space = espintcp_write_space; sk->sk_destruct = espintcp_destruct; - rcu_assign_pointer(icsk->icsk_ulp_data, ctx); + rcu_assign_pointer(icsk->icsk_ulp_data[ULP_INDEX_DEFAUL], ctx); INIT_WORK(&ctx->work, espintcp_tx_work); /* avoid using task_frag */ @@ -578,6 +578,7 @@ static void build_protos(struct proto *espintcp_prot, } static struct tcp_ulp_ops espintcp_ulp __read_mostly = { + .id = ULP_INDEX_DEFAULT, .name = "espintcp", .owner = THIS_MODULE, .init = espintcp_init_sk, diff --git a/tools/testing/selftests/bpf/config b/tools/testing/selftests/bpf/config index de358b51a05a1..ddc5ab6d56513 100644 --- a/tools/testing/selftests/bpf/config +++ b/tools/testing/selftests/bpf/config @@ -100,3 +100,4 @@ CONFIG_INET_IPCOMP=y CONFIG_INET_XFRM_TUNNEL=y CONFIG_INET6_IPCOMP=y CONFIG_INET6_XFRM_TUNNEL=y +CONFIG_TLS=y diff --git a/tools/testing/selftests/bpf/prog_tests/mptcp.c b/tools/testing/selftests/bpf/prog_tests/mptcp.c index 0303ca4d37888..f348c1b42a71a 100644 --- a/tools/testing/selftests/bpf/prog_tests/mptcp.c +++ b/tools/testing/selftests/bpf/prog_tests/mptcp.c @@ -4,6 +4,7 @@ #include #include +#include #include #include "cgroup_helpers.h" #include "network_helpers.h" @@ -41,6 +42,10 @@ #define MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED _BITUL(1) #endif +#ifndef TCP_ULP +#define TCP_ULP 31 +#endif + #ifndef TCP_CA_NAME_MAX #define TCP_CA_NAME_MAX 16 #endif @@ -430,6 +435,165 @@ static void test_subflow(void) close(cgroup_fd); } +static int sockmap_init_ktls(int fd) +{ + struct tls12_crypto_info_aes_gcm_128 tls_tx = { + .info = { + .version = TLS_1_2_VERSION, + .cipher_type = TLS_CIPHER_AES_GCM_128, + }, + }; + struct tls12_crypto_info_aes_gcm_128 tls_rx = { + .info = { + .version = TLS_1_2_VERSION, + .cipher_type = TLS_CIPHER_AES_GCM_128, + }, + }; + int so_buf = 6553500; + int err; + + err = setsockopt(fd, SOL_TCP, TCP_ULP, "tls", sizeof("tls")); + if (!ASSERT_OK(err, "setsockopt TCP_ULP")) + return err; + err = setsockopt(fd, SOL_TLS, TLS_TX, (void *)&tls_tx, sizeof(tls_tx)); + if (!ASSERT_OK(err, "setsockopt TLS_TX")) + return err; + err = setsockopt(fd, SOL_TLS, TLS_RX, (void *)&tls_rx, sizeof(tls_rx)); + if (!ASSERT_OK(err, "setsockopt TLS_RX")) + return err; + err = setsockopt(fd, SOL_SOCKET, SO_SNDBUF, &so_buf, sizeof(so_buf)); + if (!ASSERT_OK(err, "setsockopt SO_SNDBUF")) + return err; + err = setsockopt(fd, SOL_SOCKET, SO_RCVBUF, &so_buf, sizeof(so_buf)); + if (!ASSERT_OK(err, "setsockopt SO_RCVBUF")) + return err; + + return 0; +} + +static int has_bytes_sent(char *dst) +{ + return _ss_search(ADDR_1, dst, "sport", "bytes_sent:"); +} + +static int ktls_cb(int fd) +{ + int err; + + err = settimeo(fd, 0); + if (err) + return err; + err = sockmap_init_ktls(fd); + if (err) + return err; + + return 0; +} + +static void run_tcp_ktls(void) +{ + int server_fd, client_fd; + + server_fd = start_server(AF_INET, SOCK_STREAM, ADDR_1, PORT_1, 0); + if (!ASSERT_GE(server_fd, 0, "start_server")) + return; + + client_fd = connect_to_fd(server_fd, 0); + if (!ASSERT_GE(client_fd, 0, "connect to fd")) + goto fail; + + if (!ASSERT_OK(sockmap_init_ktls(client_fd), "init_ktls client_fd")) + goto fail; + + if (!ASSERT_OK(send_recv_data(server_fd, client_fd, + total_bytes, ktls_cb), + "send_recv_data")) + goto fail; + + //CHECK(has_bytes_sent(ADDR_1), "tcp_ktls", "should have bytes_sent on addr1\n"); + CHECK(!has_bytes_sent(ADDR_2), "tcp_ktls", "shouldn't have bytes_sent on addr2\n"); + + close(client_fd); +fail: + close(server_fd); +} + +static void test_tcp_ktls(void) +{ + struct nstoken *nstoken; + int cgroup_fd; + + cgroup_fd = test__join_cgroup("/tcp_ktls"); + if (!ASSERT_GE(cgroup_fd, 0, "join_cgroup: tcp_ktls")) + return; + + nstoken = create_netns(NS_TEST); + if (!ASSERT_OK_PTR(nstoken, "create_netns: tcp_ktls")) + goto close_cgroup; + + if (!ASSERT_OK(endpoint_init("subflow"), "endpoint_init: tcp_ktls")) + goto close_netns; + + run_tcp_ktls(); + +close_netns: + cleanup_netns(nstoken); +close_cgroup: + close(cgroup_fd); +} + +static void run_mptcp_ktls(void) +{ + int server_fd, client_fd; + + server_fd = start_mptcp_server(AF_INET, ADDR_1, PORT_1, 0); + if (!ASSERT_GE(server_fd, 0, "start_mptcp_server")) + return; + + client_fd = connect_to_fd(server_fd, 0); + if (!ASSERT_GE(client_fd, 0, "connect to fd")) + goto fail; + + if (!ASSERT_OK(sockmap_init_ktls(client_fd), "init_ktls client_fd")) + goto fail; + + if (!ASSERT_OK(send_recv_data(server_fd, client_fd, + total_bytes, ktls_cb), + "send_recv_data")) + goto fail; + + CHECK(has_bytes_sent(ADDR_1), "mptcp ktls", "should have bytes_sent on addr1\n"); + CHECK(has_bytes_sent(ADDR_2), "mptcp ktls", "should have bytes_sent on addr2\n"); + + close(client_fd); +fail: + close(server_fd); +} + +static void test_mptcp_ktls(void) +{ + struct nstoken *nstoken; + int cgroup_fd; + + cgroup_fd = test__join_cgroup("/mptcp_ktls"); + if (!ASSERT_GE(cgroup_fd, 0, "join_cgroup: mptcp_ktls")) + return; + + nstoken = create_netns(NS_TEST); + if (!ASSERT_OK_PTR(nstoken, "create_netns: mptcp_ktls")) + goto close_cgroup; + + if (!ASSERT_OK(endpoint_init("subflow"), "endpoint_init: mptcp_ktls")) + goto close_netns; + + run_mptcp_ktls(); + +close_netns: + cleanup_netns(nstoken); +close_cgroup: + close(cgroup_fd); +} + static struct nstoken *sched_init(char *flags, char *sched) { struct nstoken *nstoken; @@ -449,11 +613,6 @@ static struct nstoken *sched_init(char *flags, char *sched) return NULL; } -static int has_bytes_sent(char *dst) -{ - return _ss_search(ADDR_1, dst, "sport", "bytes_sent:"); -} - static void send_data_and_verify(char *sched, bool addr1, bool addr2) { struct timespec start, end; @@ -617,6 +776,10 @@ void test_mptcp(void) test_mptcpify(); if (test__start_subtest("subflow")) test_subflow(); + if (test__start_subtest("tcp_ktls")) + test_tcp_ktls(); + if (test__start_subtest("mptcp_ktls")) + test_mptcp_ktls(); if (test__start_subtest("default")) test_default(); if (test__start_subtest("first"))