Skip to content

Commit

Permalink
mptcp: tls support
Browse files Browse the repository at this point in the history
ulp array
mptcp tls fix

Signed-off-by: Geliang Tang <[email protected]>
  • Loading branch information
geliangtang committed Jun 4, 2024
1 parent 3fa2d03 commit aac1055
Show file tree
Hide file tree
Showing 23 changed files with 300 additions and 53 deletions.
2 changes: 1 addition & 1 deletion include/net/espintcp.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,6 @@ static inline struct espintcp_ctx *espintcp_getctx(const struct sock *sk)
const struct inet_connection_sock *icsk = inet_csk(sk);

/* RCU is only needed for diag */
return (__force void *)icsk->icsk_ulp_data;
return (__force void *)icsk->icsk_ulp_data[ULP_INDEX_DEFAULT];
}
#endif
13 changes: 10 additions & 3 deletions include/net/inet_connection_sock.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ struct inet_bind_bucket;
struct inet_bind2_bucket;
struct tcp_congestion_ops;

enum ulp_index {
ULP_INDEX_DEFAULT,
ULP_INDEX_MPTCP,
ULP_INDEX_MAX,
};

/*
* Pointers to address related TCP functions
* (i.e. things that depend on the address family)
Expand Down Expand Up @@ -94,8 +100,8 @@ struct inet_connection_sock {
__u32 icsk_pmtu_cookie;
const struct tcp_congestion_ops *icsk_ca_ops;
const struct inet_connection_sock_af_ops *icsk_af_ops;
const struct tcp_ulp_ops *icsk_ulp_ops;
void __rcu *icsk_ulp_data;
const struct tcp_ulp_ops *icsk_ulp_ops[ULP_INDEX_MAX];
void __rcu *icsk_ulp_data[ULP_INDEX_MAX];
void (*icsk_clean_acked)(struct sock *sk, u32 acked_seq);
unsigned int (*icsk_sync_mss)(struct sock *sk, u32 pmtu);
__u8 icsk_ca_state:5,
Expand Down Expand Up @@ -352,7 +358,8 @@ static inline void inet_csk_inc_pingpong_cnt(struct sock *sk)

static inline bool inet_csk_has_ulp(const struct sock *sk)
{
return inet_test_bit(IS_ICSK, sk) && !!inet_csk(sk)->icsk_ulp_ops;
return inet_test_bit(IS_ICSK, sk) && (!!inet_csk(sk)->icsk_ulp_ops[ULP_INDEX_DEFAULT] ||
!!inet_csk(sk)->icsk_ulp_ops[ULP_INDEX_MPTCP]);
}

static inline void inet_init_csk_locks(struct sock *sk)
Expand Down
1 change: 1 addition & 0 deletions include/net/tcp.h
Original file line number Diff line number Diff line change
Expand Up @@ -2543,6 +2543,7 @@ enum hrtimer_restart tcp_pace_kick(struct hrtimer *timer);

struct tcp_ulp_ops {
struct list_head list;
int id;

/* initialize ulp */
int (*init)(struct sock *sk);
Expand Down
2 changes: 1 addition & 1 deletion include/net/tls.h
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ static inline struct tls_context *tls_get_ctx(const struct sock *sk)
/* Use RCU on icsk_ulp_data only for sock diag code,
* TLS data path doesn't need rcu_dereference().
*/
return (__force void *)icsk->icsk_ulp_data;
return (__force void *)icsk->icsk_ulp_data[ULP_INDEX_DEFAULT];
}

static inline struct tls_sw_context_rx *tls_sw_ctx_rx(
Expand Down
20 changes: 14 additions & 6 deletions net/ipv4/inet_connection_sock.c
Original file line number Diff line number Diff line change
Expand Up @@ -1148,11 +1148,14 @@ static void inet_clone_ulp(const struct request_sock *req, struct sock *newsk,
const gfp_t priority)
{
struct inet_connection_sock *icsk = inet_csk(newsk);
const struct tcp_ulp_ops *ulp_ops;
int i;

if (!icsk->icsk_ulp_ops)
return;

icsk->icsk_ulp_ops->clone(req, newsk, priority);
for (i = 0; i < ULP_INDEX_MAX; i++) {
ulp_ops = icsk->icsk_ulp_ops[i];
if (ulp_ops && ulp_ops->clone)
ulp_ops->clone(req, newsk, priority);
}
}

/**
Expand Down Expand Up @@ -1251,9 +1254,14 @@ EXPORT_SYMBOL(inet_csk_prepare_forced_close);
static int inet_ulp_can_listen(const struct sock *sk)
{
const struct inet_connection_sock *icsk = inet_csk(sk);
const struct tcp_ulp_ops *ulp_ops;
int i;

if (icsk->icsk_ulp_ops && !icsk->icsk_ulp_ops->clone)
return -EINVAL;
for (i = 0; i < ULP_INDEX_MAX; i++) {
ulp_ops = icsk->icsk_ulp_ops[i];
if (ulp_ops && !ulp_ops->clone)
return -EINVAL;
}

return 0;
}
Expand Down
4 changes: 2 additions & 2 deletions net/ipv4/tcp.c
Original file line number Diff line number Diff line change
Expand Up @@ -4133,15 +4133,15 @@ int do_tcp_getsockopt(struct sock *sk, int level,
if (copy_from_sockptr(&len, optlen, sizeof(int)))
return -EFAULT;
len = min_t(unsigned int, len, TCP_ULP_NAME_MAX);
if (!icsk->icsk_ulp_ops) {
if (!icsk->icsk_ulp_ops[ULP_INDEX_DEFAULT]) {
len = 0;
if (copy_to_sockptr(optlen, &len, sizeof(int)))
return -EFAULT;
return 0;
}
if (copy_to_sockptr(optlen, &len, sizeof(int)))
return -EFAULT;
if (copy_to_sockptr(optval, icsk->icsk_ulp_ops->name, len))
if (copy_to_sockptr(optval, icsk->icsk_ulp_ops[ULP_INDEX_DEFAULT]->name, len))
return -EFAULT;
return 0;

Expand Down
4 changes: 2 additions & 2 deletions net/ipv4/tcp_diag.c
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ static int tcp_diag_get_aux(struct sock *sk, bool net_admin,
if (net_admin) {
const struct tcp_ulp_ops *ulp_ops;

ulp_ops = icsk->icsk_ulp_ops;
ulp_ops = icsk->icsk_ulp_ops[ULP_INDEX_DEFAULT];
if (ulp_ops)
err = tcp_diag_put_ulp(skb, sk, ulp_ops);
if (err)
Expand Down Expand Up @@ -167,7 +167,7 @@ static size_t tcp_diag_get_aux_size(struct sock *sk, bool net_admin)
if (net_admin && sk_fullsock(sk)) {
const struct tcp_ulp_ops *ulp_ops;

ulp_ops = icsk->icsk_ulp_ops;
ulp_ops = icsk->icsk_ulp_ops[ULP_INDEX_DEFAULT];
if (ulp_ops) {
size += nla_total_size(0) +
nla_total_size(TCP_ULP_NAME_MAX);
Expand Down
33 changes: 21 additions & 12 deletions net/ipv4/tcp_ulp.c
Original file line number Diff line number Diff line change
Expand Up @@ -104,27 +104,36 @@ void tcp_update_ulp(struct sock *sk, struct proto *proto,
void (*write_space)(struct sock *sk))
{
struct inet_connection_sock *icsk = inet_csk(sk);
const struct tcp_ulp_ops *ulp_ops;
int i;

if (icsk->icsk_ulp_ops->update)
icsk->icsk_ulp_ops->update(sk, proto, write_space);
for (i = 0; i < ULP_INDEX_MAX; i++) {
ulp_ops = icsk->icsk_ulp_ops[i];
if (ulp_ops->update)
ulp_ops->update(sk, proto, write_space);
}
}

void tcp_cleanup_ulp(struct sock *sk)
{
struct inet_connection_sock *icsk = inet_csk(sk);
const struct tcp_ulp_ops *ulp_ops;
int i;

/* No sock_owned_by_me() check here as at the time the
* stack calls this function, the socket is dead and
* about to be destroyed.
*/
if (!icsk->icsk_ulp_ops)
return;

if (icsk->icsk_ulp_ops->release)
icsk->icsk_ulp_ops->release(sk);
module_put(icsk->icsk_ulp_ops->owner);

icsk->icsk_ulp_ops = NULL;
for (i = 0; i < ULP_INDEX_MAX; i++) {
ulp_ops = icsk->icsk_ulp_ops[i];
if (ulp_ops) {
//pr_info("%s ulp_ops->name=%s\n", __func__, ulp_ops->name);
if (ulp_ops->release)
ulp_ops->release(sk);
module_put(ulp_ops->owner);
ulp_ops = NULL;
}
}
}

static int __tcp_set_ulp(struct sock *sk, const struct tcp_ulp_ops *ulp_ops)
Expand All @@ -133,7 +142,7 @@ static int __tcp_set_ulp(struct sock *sk, const struct tcp_ulp_ops *ulp_ops)
int err;

err = -EEXIST;
if (icsk->icsk_ulp_ops)
if (icsk->icsk_ulp_ops[ulp_ops->id])
goto out_err;

if (sk->sk_socket)
Expand All @@ -147,7 +156,7 @@ static int __tcp_set_ulp(struct sock *sk, const struct tcp_ulp_ops *ulp_ops)
if (err)
goto out_err;

icsk->icsk_ulp_ops = ulp_ops;
icsk->icsk_ulp_ops[ulp_ops->id] = ulp_ops;
return 0;
out_err:
module_put(ulp_ops->owner);
Expand Down
2 changes: 1 addition & 1 deletion net/mptcp/diag.c
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ static int subflow_get_info(struct sock *sk, struct sk_buff *skb)

slow = lock_sock_fast(sk);
rcu_read_lock();
sf = rcu_dereference(inet_csk(sk)->icsk_ulp_data);
sf = rcu_dereference(inet_csk(sk)->icsk_ulp_data[ULP_INDEX_MPTCP]);
if (!sf) {
err = 0;
goto nla_failure;
Expand Down
2 changes: 1 addition & 1 deletion net/mptcp/mptcp_diag.c
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ static void mptcp_diag_dump_listeners(struct sk_buff *skb, struct netlink_callba
if (num < diag_ctx->l_num)
goto next_listen;

if (!ctx || strcmp(inet_csk(sk)->icsk_ulp_ops->name, "mptcp"))
if (!ctx || strcmp(inet_csk(sk)->icsk_ulp_ops[ULP_INDEX_MPTCP]->name, "mptcp"))
goto next_listen;

sk = ctx->conn;
Expand Down
2 changes: 1 addition & 1 deletion net/mptcp/protocol.c
Original file line number Diff line number Diff line change
Expand Up @@ -2462,7 +2462,7 @@ static void __mptcp_close_ssk(struct sock *sk, struct sock *ssk,
* the ssk has been already destroyed, we just need to release the
* reference owned by msk;
*/
if (!inet_csk(ssk)->icsk_ulp_ops) {
if (!inet_csk(ssk)->icsk_ulp_ops[ULP_INDEX_MPTCP]) {
WARN_ON_ONCE(!sock_flag(ssk, SOCK_DEAD));
kfree_rcu(subflow, rcu);
} else {
Expand Down
4 changes: 3 additions & 1 deletion net/mptcp/protocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <uapi/linux/mptcp.h>
#include <net/genetlink.h>
#include <net/rstreason.h>
#include <net/tls.h>

#define MPTCP_SUPPORTED_VERSION 1

Expand Down Expand Up @@ -499,6 +500,7 @@ DECLARE_PER_CPU(struct mptcp_delegated_action, mptcp_delegated_actions);
struct mptcp_subflow_context {
struct list_head node;/* conn_list of subflows */

struct tls_context tls;
struct_group(reset,

unsigned long avg_pacing_rate; /* protected by msk socket lock */
Expand Down Expand Up @@ -584,7 +586,7 @@ mptcp_subflow_ctx(const struct sock *sk)
const struct inet_connection_sock *icsk = inet_csk(sk);

/* Use RCU on icsk_ulp_data only for sock diag code */
return (__force struct mptcp_subflow_context *)icsk->icsk_ulp_data;
return (__force struct mptcp_subflow_context *)icsk->icsk_ulp_data[ULP_INDEX_MPTCP];
}

static inline struct sock *
Expand Down
53 changes: 52 additions & 1 deletion net/mptcp/sockopt.c
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,7 @@ static bool mptcp_supported_sockopt(int level, int optname)
case TCP_FASTOPEN_CONNECT:
case TCP_FASTOPEN_KEY:
case TCP_FASTOPEN_NO_COOKIE:
case TCP_ULP:
return true;
}

Expand All @@ -576,9 +577,54 @@ static bool mptcp_supported_sockopt(int level, int optname)
* TCP_REPAIR_WINDOW are not supported, better avoid this mess
*/
}
if (level == SOL_TLS) {
switch (optname) {
case TLS_TX:
case TLS_RX:
return true;
}
}
return false;
}

static int mptcp_setsockopt_sol_tcp_ulp(struct mptcp_sock *msk, sockptr_t optval,
unsigned int optlen)
{
struct mptcp_subflow_context *subflow;
struct sock *sk = (struct sock *)msk;
char name[TCP_ULP_NAME_MAX];
int ret;

if (optlen < 1)
return -EINVAL;

ret = strncpy_from_sockptr(name, optval,
min_t(long, TCP_ULP_NAME_MAX - 1, optlen));
if (ret < 0)
return -EFAULT;
name[ret] = 0;

ret = 0;
lock_sock(sk);
sockopt_seq_inc(msk);
mptcp_for_each_subflow(msk, subflow) {
struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
int err;

lock_sock(ssk);
err = tcp_set_ulp(ssk, name);
if (err < 0 && ret == 0) {
pr_info("%s err=%d\n", __func__, err);
ret = err;
}
subflow->setsockopt_seq = msk->setsockopt_seq;
release_sock(ssk);
}

release_sock(sk);
return ret;
}

static int mptcp_setsockopt_sol_tcp_congestion(struct mptcp_sock *msk, sockptr_t optval,
unsigned int optlen)
{
Expand Down Expand Up @@ -806,7 +852,7 @@ static int mptcp_setsockopt_sol_tcp(struct mptcp_sock *msk, int optname,

switch (optname) {
case TCP_ULP:
return -EOPNOTSUPP;
return mptcp_setsockopt_sol_tcp_ulp(msk, optval, optlen);
case TCP_CONGESTION:
return mptcp_setsockopt_sol_tcp_congestion(msk, optval, optlen);
case TCP_DEFER_ACCEPT:
Expand Down Expand Up @@ -867,6 +913,8 @@ static int mptcp_setsockopt_sol_tcp(struct mptcp_sock *msk, int optname,
return ret;
}

int tls_setsockopt(struct sock *sk, int level, int optname,
sockptr_t optval, unsigned int optlen);
int mptcp_setsockopt(struct sock *sk, int level, int optname,
sockptr_t optval, unsigned int optlen)
{
Expand Down Expand Up @@ -902,6 +950,9 @@ int mptcp_setsockopt(struct sock *sk, int level, int optname,
if (level == SOL_TCP)
return mptcp_setsockopt_sol_tcp(msk, optname, optval, optlen);

if (level == SOL_TLS)
return tls_setsockopt(msk->first, level, optname, optval, optlen);

return -EOPNOTSUPP;
}

Expand Down
9 changes: 5 additions & 4 deletions net/mptcp/subflow.c
Original file line number Diff line number Diff line change
Expand Up @@ -747,8 +747,8 @@ static void subflow_ulp_fallback(struct sock *sk,
struct inet_connection_sock *icsk = inet_csk(sk);

mptcp_subflow_tcp_fallback(sk, old_ctx);
icsk->icsk_ulp_ops = NULL;
rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
icsk->icsk_ulp_ops[ULP_INDEX_MPTCP] = NULL;
rcu_assign_pointer(icsk->icsk_ulp_data[ULP_INDEX_MPTCP], NULL);
tcp_sk(sk)->is_mptcp = 0;

mptcp_subflow_ops_undo_override(sk);
Expand All @@ -762,7 +762,7 @@ void mptcp_subflow_drop_ctx(struct sock *ssk)
return;

list_del(&mptcp_subflow_ctx(ssk)->node);
if (inet_csk(ssk)->icsk_ulp_ops) {
if (inet_csk(ssk)->icsk_ulp_ops[ULP_INDEX_MPTCP]) {
subflow_ulp_fallback(ssk, ctx);
if (ctx->conn)
sock_put(ctx->conn);
Expand Down Expand Up @@ -1755,7 +1755,7 @@ static struct mptcp_subflow_context *subflow_create_ctx(struct sock *sk,
if (!ctx)
return NULL;

rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
rcu_assign_pointer(icsk->icsk_ulp_data[ULP_INDEX_MPTCP], ctx);
INIT_LIST_HEAD(&ctx->node);
INIT_LIST_HEAD(&ctx->delegated_node);

Expand Down Expand Up @@ -2044,6 +2044,7 @@ static int tcp_abort_override(struct sock *ssk, int err)
}

static struct tcp_ulp_ops subflow_ulp_ops __read_mostly = {
.id = ULP_INDEX_MPTCP,
.name = "mptcp",
.owner = THIS_MODULE,
.init = subflow_ulp_init,
Expand Down
2 changes: 1 addition & 1 deletion net/mptcp/token_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ static void mptcp_token_test_msk_basic(struct kunit *test)
struct mptcp_sock *null_msk = NULL;
struct sock *sk;

rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
rcu_assign_pointer(icsk->icsk_ulp_data[ULP_INDEX_MPTCP], ctx);
ctx->conn = (struct sock *)msk;
sk = (struct sock *)msk;

Expand Down
2 changes: 1 addition & 1 deletion net/smc/af_smc.c
Original file line number Diff line number Diff line change
Expand Up @@ -3439,7 +3439,7 @@ static void smc_ulp_clone(const struct request_sock *req, struct sock *newsk,
struct inet_connection_sock *icsk = inet_csk(newsk);

/* don't inherit ulp ops to child when listen */
icsk->icsk_ulp_ops = NULL;
icsk->icsk_ulp_ops[ULP_INDEX_DEFAULT] = NULL;
}

static struct tcp_ulp_ops smc_ulp_ops __read_mostly = {
Expand Down
Loading

0 comments on commit aac1055

Please sign in to comment.