diff --git a/net/mptcp/bpf.c b/net/mptcp/bpf.c index cfeb8537f423..8950d1f0626e 100644 --- a/net/mptcp/bpf.c +++ b/net/mptcp/bpf.c @@ -486,18 +486,27 @@ __bpf_kfunc_start_defs(); __bpf_kfunc static struct mptcp_sock *bpf_mptcp_sk(struct sock *sk) { + if (!sk || sk->sk_protocol != IPPROTO_MPTCP) + return NULL; + return mptcp_sk(sk); } __bpf_kfunc static struct mptcp_subflow_context * bpf_mptcp_subflow_ctx(const struct sock *sk) { + if (!sk) + return NULL; + return mptcp_subflow_ctx(sk); } __bpf_kfunc static struct sock * bpf_mptcp_subflow_tcp_sock(const struct mptcp_subflow_context *subflow) { + if (!subflow) + return NULL; + return mptcp_subflow_tcp_sock(subflow); } @@ -511,7 +520,8 @@ bpf_iter_mptcp_subflow_new(struct bpf_iter_mptcp_subflow *it, if (!msk) return -EINVAL; - msk_owned_by_me(msk); + if (!lockdep_sock_is_held((const struct sock *)msk)) + return -EINVAL; kit->pos = &msk->conn_list; return 0; @@ -544,7 +554,8 @@ bpf_iter_mptcp_userspace_pm_addr_new(struct bpf_iter_mptcp_userspace_pm_addr *it if (!msk) return -EINVAL; - lockdep_assert_held(&msk->pm.lock); + if (!lockdep_is_held(&msk->pm.lock)) + return -EINVAL; kit->pos = &msk->pm.userspace_pm_local_addr_list; return 0; @@ -692,10 +703,10 @@ __bpf_kfunc static bool bpf_mptcp_subflow_queues_empty(struct sock *sk) __bpf_kfunc_end_defs(); -BTF_KFUNCS_START(bpf_mptcp_common_kfunc_ids) -BTF_ID_FLAGS(func, bpf_mptcp_sk) -BTF_ID_FLAGS(func, bpf_mptcp_subflow_ctx) -BTF_ID_FLAGS(func, bpf_mptcp_subflow_tcp_sock) +BTF_KFUNCS_START(bpf_mptcp_kfunc_ids) +BTF_ID_FLAGS(func, bpf_mptcp_sk, KF_TRUSTED_ARGS | KF_RET_NULL) +BTF_ID_FLAGS(func, bpf_mptcp_subflow_ctx, KF_RET_NULL) +BTF_ID_FLAGS(func, bpf_mptcp_subflow_tcp_sock, KF_RET_NULL) BTF_ID_FLAGS(func, bpf_iter_mptcp_subflow_new, KF_ITER_NEW | KF_TRUSTED_ARGS) BTF_ID_FLAGS(func, bpf_iter_mptcp_subflow_next, KF_ITER_NEXT | KF_RET_NULL) BTF_ID_FLAGS(func, bpf_iter_mptcp_subflow_destroy, KF_ITER_DESTROY) @@ -707,14 +718,6 @@ BTF_ID_FLAGS(func, bpf_mptcp_sock_release, KF_RELEASE) BTF_ID_FLAGS(func, bpf_spin_lock_bh) BTF_ID_FLAGS(func, bpf_spin_unlock_bh) BTF_ID_FLAGS(func, bpf_ipv6_addr_v4mapped) -BTF_KFUNCS_END(bpf_mptcp_common_kfunc_ids) - -static const struct btf_kfunc_id_set bpf_mptcp_common_kfunc_set = { - .owner = THIS_MODULE, - .set = &bpf_mptcp_common_kfunc_ids, -}; - -BTF_KFUNCS_START(bpf_mptcp_struct_ops_kfunc_ids) BTF_ID_FLAGS(func, bpf_ipv6_addr_set_v4mapped) BTF_ID_FLAGS(func, bpf_list_add_tail_rcu) BTF_ID_FLAGS(func, bpf_list_del_rcu) @@ -744,11 +747,11 @@ BTF_ID_FLAGS(func, mptcp_wnd_end) BTF_ID_FLAGS(func, tcp_stream_memory_free) BTF_ID_FLAGS(func, bpf_mptcp_subflow_queues_empty) BTF_ID_FLAGS(func, mptcp_pm_subflow_chk_stale) -BTF_KFUNCS_END(bpf_mptcp_struct_ops_kfunc_ids) +BTF_KFUNCS_END(bpf_mptcp_kfunc_ids) -static const struct btf_kfunc_id_set bpf_mptcp_struct_ops_kfunc_set = { +static const struct btf_kfunc_id_set bpf_mptcp_kfunc_set = { .owner = THIS_MODULE, - .set = &bpf_mptcp_struct_ops_kfunc_ids, + .set = &bpf_mptcp_kfunc_ids, }; static int __init bpf_mptcp_kfunc_init(void) @@ -756,10 +759,10 @@ static int __init bpf_mptcp_kfunc_init(void) int ret; ret = register_btf_fmodret_id_set(&bpf_mptcp_fmodret_set); - ret = ret ?: register_btf_kfunc_id_set(BPF_PROG_TYPE_UNSPEC, - &bpf_mptcp_common_kfunc_set); + ret = ret ?: register_btf_kfunc_id_set(BPF_PROG_TYPE_CGROUP_SOCKOPT, + &bpf_mptcp_kfunc_set); ret = ret ?: register_btf_kfunc_id_set(BPF_PROG_TYPE_STRUCT_OPS, - &bpf_mptcp_struct_ops_kfunc_set); + &bpf_mptcp_kfunc_set); #ifdef CONFIG_BPF_JIT ret = ret ?: register_bpf_struct_ops(&bpf_mptcp_pm_ops, mptcp_pm_ops); ret = ret ?: register_bpf_struct_ops(&bpf_mptcp_sched_ops, mptcp_sched_ops); diff --git a/tools/testing/selftests/bpf/progs/mptcp_bpf_first.c b/tools/testing/selftests/bpf/progs/mptcp_bpf_first.c index 5d0f89c636f0..7b762984f078 100644 --- a/tools/testing/selftests/bpf/progs/mptcp_bpf_first.c +++ b/tools/testing/selftests/bpf/progs/mptcp_bpf_first.c @@ -20,7 +20,13 @@ SEC("struct_ops") int BPF_PROG(bpf_first_get_subflow, struct mptcp_sock *msk, struct mptcp_sched_data *data) { - mptcp_subflow_set_scheduled(bpf_mptcp_subflow_ctx(msk->first), true); + struct mptcp_subflow_context *subflow; + + subflow = bpf_mptcp_subflow_ctx(msk->first); + if (!subflow) + return -1; + + mptcp_subflow_set_scheduled(subflow, true); return 0; } diff --git a/tools/testing/selftests/bpf/progs/mptcp_bpf_iters.c b/tools/testing/selftests/bpf/progs/mptcp_bpf_iters.c index 48511faf7a2d..6c8ac82c3951 100644 --- a/tools/testing/selftests/bpf/progs/mptcp_bpf_iters.c +++ b/tools/testing/selftests/bpf/progs/mptcp_bpf_iters.c @@ -26,7 +26,7 @@ int iters_subflow(struct bpf_sockopt *ctx) return 1; msk = bpf_mptcp_sk((struct sock *)sk); - if (msk->pm.server_side || !msk->pm.subflows) + if (!msk || msk->pm.server_side || !msk->pm.subflows) return 1; msk = bpf_mptcp_sock_acquire(msk); @@ -53,7 +53,7 @@ int iters_subflow(struct bpf_sockopt *ctx) /* only to check the following kfunc works */ subflow = bpf_mptcp_subflow_ctx(ssk); - if (subflow->token != msk->token) + if (!subflow || subflow->token != msk->token) goto out; ids = local_ids; @@ -71,12 +71,11 @@ int iters_address(struct bpf_sockopt *ctx) struct mptcp_sock *msk; int local_ids = 0; - if (!sk || sk->protocol != IPPROTO_MPTCP || - ctx->level != SOL_TCP || ctx->optname != TCP_IS_MPTCP) + if (ctx->level != SOL_TCP || ctx->optname != TCP_IS_MPTCP) return 1; msk = bpf_mptcp_sk((struct sock *)sk); - if (msk->pm.server_side) + if (!msk || msk->pm.server_side) return 1; msk = bpf_mptcp_sock_acquire(msk); diff --git a/tools/testing/selftests/bpf/progs/mptcp_bpf_rr.c b/tools/testing/selftests/bpf/progs/mptcp_bpf_rr.c index 62dd15223847..a3ebee31fb17 100644 --- a/tools/testing/selftests/bpf/progs/mptcp_bpf_rr.c +++ b/tools/testing/selftests/bpf/progs/mptcp_bpf_rr.c @@ -43,6 +43,8 @@ int BPF_PROG(bpf_rr_get_subflow, struct mptcp_sock *msk, return -1; next = bpf_mptcp_subflow_ctx(msk->first); + if (!next) + return -1; if (!ptr->last_snd) goto out; diff --git a/tools/testing/selftests/bpf/progs/mptcp_bpf_sockopt.c b/tools/testing/selftests/bpf/progs/mptcp_bpf_sockopt.c index a705be8d65a1..e4dd0c7a3908 100644 --- a/tools/testing/selftests/bpf/progs/mptcp_bpf_sockopt.c +++ b/tools/testing/selftests/bpf/progs/mptcp_bpf_sockopt.c @@ -20,7 +20,10 @@ static int mptcp_setsockopt_mark(struct bpf_sock *sk, struct bpf_sockopt *ctx) mark = *optval; - msk = bpf_mptcp_sock_acquire(bpf_mptcp_sk((struct sock *)sk)); + msk = bpf_mptcp_sk((struct sock *)sk); + if (!msk) + return 1; + msk = bpf_mptcp_sock_acquire(msk); if (!msk) return 1; @@ -54,7 +57,10 @@ static int mptcp_setsockopt_cc(struct bpf_sock *sk, struct bpf_sockopt *ctx) __builtin_memcpy(cc, optval, TCP_CA_NAME_MAX); - msk = bpf_mptcp_sock_acquire(bpf_mptcp_sk((struct sock *)sk)); + msk = bpf_mptcp_sk((struct sock *)sk); + if (!msk) + return 1; + msk = bpf_mptcp_sock_acquire(msk); if (!msk) return 1; @@ -96,7 +102,10 @@ static int mptcp_getsockopt_mark(struct bpf_sock *sk, struct bpf_sockopt *ctx) struct mptcp_sock *msk; int i = 0; - msk = bpf_mptcp_sock_acquire(bpf_mptcp_sk((struct sock *)sk)); + msk = bpf_mptcp_sk((struct sock *)sk); + if (!msk) + return 1; + msk = bpf_mptcp_sock_acquire(msk); if (!msk) return 1; @@ -121,7 +130,10 @@ static int mptcp_getsockopt_cc(struct bpf_sock *sk, struct bpf_sockopt *ctx) struct mptcp_sock *msk; int i = 0; - msk = bpf_mptcp_sock_acquire(bpf_mptcp_sk((struct sock *)sk)); + msk = bpf_mptcp_sk((struct sock *)sk); + if (!msk) + return 1; + msk = bpf_mptcp_sock_acquire(msk); if (!msk) return 1; diff --git a/tools/testing/selftests/bpf/progs/mptcp_bpf_stale.c b/tools/testing/selftests/bpf/progs/mptcp_bpf_stale.c index f6831ed2ab33..945ea267cdbb 100644 --- a/tools/testing/selftests/bpf/progs/mptcp_bpf_stale.c +++ b/tools/testing/selftests/bpf/progs/mptcp_bpf_stale.c @@ -111,7 +111,10 @@ int BPF_PROG(bpf_stale_get_subflow, struct mptcp_sock *msk, int i; if (!msk->pm.subflows) { - mptcp_subflow_set_scheduled(bpf_mptcp_subflow_ctx(msk->first), true); + subflow = bpf_mptcp_subflow_ctx(msk->first); + if (!subflow) + return -1; + mptcp_subflow_set_scheduled(subflow, true); return 0; } diff --git a/tools/testing/selftests/bpf/progs/mptcp_bpf_userspace_pm.c b/tools/testing/selftests/bpf/progs/mptcp_bpf_userspace_pm.c index ade56408a958..8880ee085a59 100644 --- a/tools/testing/selftests/bpf/progs/mptcp_bpf_userspace_pm.c +++ b/tools/testing/selftests/bpf/progs/mptcp_bpf_userspace_pm.c @@ -259,6 +259,8 @@ static struct sock *mptcp_pm_find_ssk(struct mptcp_sock *msk, struct sock *ssk; ssk = bpf_mptcp_subflow_tcp_sock(subflow); + if (!ssk) + continue; if (local->family != ssk->sk_family) continue; @@ -295,6 +297,7 @@ SEC("struct_ops") int BPF_PROG(mptcp_pm_subflow_destroy, struct mptcp_sock *msk, struct mptcp_pm_addr_entry *local, struct mptcp_addr_info *remote) { + struct mptcp_subflow_context *subflow; struct sock *sk = (struct sock *)msk; struct sock *ssk; int err = 0; @@ -330,7 +333,9 @@ int BPF_PROG(mptcp_pm_subflow_destroy, struct mptcp_sock *msk, err = mptcp_userspace_pm_delete_local_addr(msk, local); bpf_spin_unlock_bh(&msk->pm.lock); mptcp_subflow_shutdown(sk, ssk, RCV_SHUTDOWN | SEND_SHUTDOWN); - mptcp_close_ssk(sk, ssk, bpf_mptcp_subflow_ctx(ssk)); + subflow = bpf_mptcp_subflow_ctx(ssk); + if (subflow) + mptcp_close_ssk(sk, ssk, subflow); BPF_MPTCP_INC_STATS(sk, MPTCP_MIB_RMSUBFLOW); bpf_printk("mptcp_pm_subflow_destroy done");