From 0954a435daf83f2883174cd057c8a8e976b8cd8a Mon Sep 17 00:00:00 2001 From: Geliang Tang Date: Wed, 23 Oct 2024 10:01:36 +0800 Subject: [PATCH] fix Signed-off-by: Geliang Tang --- net/mptcp/bpf.c | 7 ++ net/mptcp/pm_netlink.c | 51 ++++---- tools/testing/selftests/bpf/progs/mptcp_bpf.h | 2 + .../selftests/bpf/progs/mptcp_bpf_bytes.c | 4 +- .../selftests/bpf/progs/mptcp_bpf_rr.c | 1 + .../bpf/progs/mptcp_bpf_userspace_pm.c | 115 ++++++++++++------ 6 files changed, 121 insertions(+), 59 deletions(-) diff --git a/net/mptcp/bpf.c b/net/mptcp/bpf.c index d2c2155f8ed7a..8259ccf2e4392 100644 --- a/net/mptcp/bpf.c +++ b/net/mptcp/bpf.c @@ -15,6 +15,7 @@ #include #include #include "protocol.h" +#include "mib.h" #ifdef CONFIG_BPF_JIT static struct bpf_struct_ops bpf_mptcp_pm_ops, @@ -686,6 +687,11 @@ __bpf_kfunc static void bpf_ipv6_addr_set_v4mapped(const __be32 addr, #endif } +__bpf_kfunc static void mptcp_inc_stats_rmsubflow(struct sock *sk) +{ + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_RMSUBFLOW); +} + __bpf_kfunc static bool bpf_mptcp_subflow_queues_empty(struct sock *sk) { return tcp_rtx_queue_empty(sk); @@ -737,6 +743,7 @@ BTF_ID_FLAGS(func, mptcp_pm_remove_addr_entry, KF_SLEEPABLE) BTF_ID_FLAGS(func, __mptcp_subflow_connect, KF_SLEEPABLE) BTF_ID_FLAGS(func, mptcp_subflow_shutdown, KF_SLEEPABLE) BTF_ID_FLAGS(func, mptcp_close_ssk, KF_SLEEPABLE) +BTF_ID_FLAGS(func, mptcp_inc_stats_rmsubflow) BTF_ID_FLAGS(func, mptcp_pm_nl_mp_prio_send_ack, KF_SLEEPABLE) BTF_ID_FLAGS(func, mptcp_subflow_active) BTF_ID_FLAGS(func, mptcp_set_timeout) diff --git a/net/mptcp/pm_netlink.c b/net/mptcp/pm_netlink.c index 852706e43f131..e4c43a005ddf7 100644 --- a/net/mptcp/pm_netlink.c +++ b/net/mptcp/pm_netlink.c @@ -508,7 +508,20 @@ __lookup_addr_by_id(struct pm_nl_pernet *pernet, unsigned int id) } static struct mptcp_pm_addr_entry * -__lookup_addr(struct pm_nl_pernet *pernet, const struct mptcp_addr_info *info) +__lookup_addr_by_id_rcu(struct pm_nl_pernet *pernet, unsigned int id) +{ + struct mptcp_pm_addr_entry *entry; + + list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) { + if (entry->addr.id == id) + return entry; + } + return NULL; +} + +static struct mptcp_pm_addr_entry * +__lookup_addr_rcu(struct pm_nl_pernet *pernet, + const struct mptcp_addr_info *info) { struct mptcp_pm_addr_entry *entry; @@ -544,7 +557,7 @@ static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk) mptcp_local_address((struct sock_common *)msk->first, &mpc_addr); rcu_read_lock(); - entry = __lookup_addr(pernet, &mpc_addr); + entry = __lookup_addr_rcu(pernet, &mpc_addr); if (entry) { __clear_bit(entry->addr.id, msk->pm.id_avail_bitmap.map); msk->mpc_endpoint_id = entry->addr.id; @@ -1127,14 +1140,13 @@ int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct mptcp_pm_addr_entry { struct mptcp_pm_addr_entry *entry; struct pm_nl_pernet *pernet; - int ret = -1; + int ret; pernet = pm_nl_get_pernet_from_msk(msk); rcu_read_lock(); - entry = __lookup_addr(pernet, &local->addr); - if (entry) - ret = entry->addr.id; + entry = __lookup_addr_rcu(pernet, &local->addr); + ret = entry ? entry->addr.id : -1; rcu_read_unlock(); if (ret >= 0) return ret; @@ -1156,12 +1168,11 @@ u8 mptcp_pm_nl_get_flags(struct mptcp_sock *msk, struct mptcp_addr_info *skc) { struct pm_nl_pernet *pernet = pm_nl_get_pernet_from_msk(msk); struct mptcp_pm_addr_entry *entry; - u8 flags = 0; + u8 flags; rcu_read_lock(); - entry = __lookup_addr(pernet, skc); - if (entry) - flags = entry->flags; + entry = __lookup_addr_rcu(pernet, skc); + flags = entry ? entry->flags : 0; rcu_read_unlock(); return flags; @@ -1779,13 +1790,13 @@ static int mptcp_pm_nl_get_addr(u8 id, struct mptcp_pm_addr_entry *addr, pernet = pm_nl_get_pernet(net); - spin_lock_bh(&pernet->lock); - entry = __lookup_addr_by_id(pernet, id); + rcu_read_lock(); + entry = __lookup_addr_by_id_rcu(pernet, id); if (entry) { *addr = *entry; ret = 0; } - spin_unlock_bh(&pernet->lock); + rcu_read_unlock(); return ret; } @@ -1849,9 +1860,7 @@ static int mptcp_pm_nl_dump_addr(struct mptcp_id_bitmap *bitmap, pernet = pm_nl_get_pernet(net); - spin_lock_bh(&pernet->lock); bitmap_copy(bitmap->map, pernet->id_bitmap.map, MPTCP_PM_MAX_ADDR_ID + 1); - spin_unlock_bh(&pernet->lock); return 0; } @@ -2047,17 +2056,17 @@ static int mptcp_pm_nl_set_flags(struct mptcp_pm_addr_entry *loc, if (loc->flags & MPTCP_PM_ADDR_FLAG_BACKUP) bkup = 1; - spin_lock_bh(&pernet->lock); - entry = lookup_by_id ? __lookup_addr_by_id(pernet, loc->addr.id) : - __lookup_addr(pernet, &loc->addr); + rcu_read_lock(); + entry = lookup_by_id ? __lookup_addr_by_id_rcu(pernet, loc->addr.id) : + __lookup_addr_rcu(pernet, &loc->addr); if (!entry) { - spin_unlock_bh(&pernet->lock); + rcu_read_unlock(); GENL_SET_ERR_MSG(info, "address not found"); return -EINVAL; } if ((loc->flags & MPTCP_PM_ADDR_FLAG_FULLMESH) && (entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL)) { - spin_unlock_bh(&pernet->lock); + rcu_read_unlock(); GENL_SET_ERR_MSG(info, "invalid addr flags"); return -EINVAL; } @@ -2065,7 +2074,7 @@ static int mptcp_pm_nl_set_flags(struct mptcp_pm_addr_entry *loc, changed = (loc->flags ^ entry->flags) & mask; entry->flags = (entry->flags & ~mask) | (loc->flags & mask); *loc = *entry; - spin_unlock_bh(&pernet->lock); + rcu_read_unlock(); mptcp_nl_set_flags(net, &loc->addr, bkup, changed); return 0; diff --git a/tools/testing/selftests/bpf/progs/mptcp_bpf.h b/tools/testing/selftests/bpf/progs/mptcp_bpf.h index a3c6294022815..be4ce246b89b8 100644 --- a/tools/testing/selftests/bpf/progs/mptcp_bpf.h +++ b/tools/testing/selftests/bpf/progs/mptcp_bpf.h @@ -23,6 +23,7 @@ extern bool CONFIG_MPTCP_IPV6 __kconfig __weak; #define RCV_SHUTDOWN 1 #define SEND_SHUTDOWN 2 +#define ESRCH 3 /* No such process */ #define ENOMEM 12 /* Out of Memory */ #define EINVAL 22 /* Invalid argument */ @@ -132,6 +133,7 @@ extern void bpf_ipv6_addr_set_v4mapped(const __be32 addr, extern void mptcp_subflow_shutdown(struct sock *sk, struct sock *ssk, int how) __ksym; extern void mptcp_close_ssk(struct sock *sk, struct sock *ssk, struct mptcp_subflow_context *subflow) __ksym; +extern void mptcp_inc_stats_rmsubflow(struct sock *sk) __ksym; extern int mptcp_pm_nl_mp_prio_send_ack(struct mptcp_sock *msk, struct mptcp_addr_info *addr, diff --git a/tools/testing/selftests/bpf/progs/mptcp_bpf_bytes.c b/tools/testing/selftests/bpf/progs/mptcp_bpf_bytes.c index 28b4339331e03..7708e64ad68ef 100644 --- a/tools/testing/selftests/bpf/progs/mptcp_bpf_bytes.c +++ b/tools/testing/selftests/bpf/progs/mptcp_bpf_bytes.c @@ -30,10 +30,10 @@ int BPF_PROG(trace_mptcp_sched_get_send, struct mptcp_sock *msk) tp = bpf_core_cast(ssk, struct tcp_sock); if (subflow->subflow_id == 1) { - bpf_printk("bytes 1: sent %lu received %lu subflows %u", tp->bytes_sent, tp->bytes_received, msk->pm.subflows); + //bpf_printk("bytes 1: sent %lu received %lu subflows %u", tp->bytes_sent, tp->bytes_received, msk->pm.subflows); bytes_sent_1 += tp->bytes_sent; } else if (subflow->subflow_id == 2) { - bpf_printk("bytes 2: sent %lu received %lu subflows %u", tp->bytes_sent, tp->bytes_received, msk->pm.subflows); + //bpf_printk("bytes 2: sent %lu received %lu subflows %u", tp->bytes_sent, tp->bytes_received, msk->pm.subflows); bytes_sent_2 += tp->bytes_sent; } } diff --git a/tools/testing/selftests/bpf/progs/mptcp_bpf_rr.c b/tools/testing/selftests/bpf/progs/mptcp_bpf_rr.c index c901ed045fdc6..7d4d29a50aa00 100644 --- a/tools/testing/selftests/bpf/progs/mptcp_bpf_rr.c +++ b/tools/testing/selftests/bpf/progs/mptcp_bpf_rr.c @@ -60,6 +60,7 @@ int BPF_PROG(bpf_rr_get_subflow, struct mptcp_sock *msk, out: next = bpf_core_cast(next, struct mptcp_subflow_context); mptcp_subflow_set_scheduled(next, true); + //bpf_printk("rr subflow=%u/%u", next->subflow_id, msk->pm.subflows + 1); ptr->last_snd = mptcp_subflow_tcp_sock(next); return 0; } diff --git a/tools/testing/selftests/bpf/progs/mptcp_bpf_userspace_pm.c b/tools/testing/selftests/bpf/progs/mptcp_bpf_userspace_pm.c index 2c93610540730..1e2b0c8260dab 100644 --- a/tools/testing/selftests/bpf/progs/mptcp_bpf_userspace_pm.c +++ b/tools/testing/selftests/bpf/progs/mptcp_bpf_userspace_pm.c @@ -75,24 +75,32 @@ SEC("struct_ops") int BPF_PROG(mptcp_pm_address_announce, struct mptcp_sock *msk, struct mptcp_pm_addr_entry *local) { - int err; + int err = -EINVAL; - if (local->addr.id == 0 || !(local->flags & MPTCP_PM_ADDR_FLAG_SIGNAL)) - return -EINVAL; + if (local->addr.id == 0 || !(local->flags & MPTCP_PM_ADDR_FLAG_SIGNAL)) { + bpf_printk("invalid addr id or flags"); + goto announce_err; + } err = mptcp_userspace_pm_append_new_local_addr(msk, local, false); - if (err < 0) - return err; + if (err < 0) { + bpf_printk("did not match address and id"); + goto announce_err; + } bpf_spin_lock_bh(&msk->pm.lock); + if (mptcp_pm_alloc_anno_list(msk, &local->addr)) { msk->pm.add_addr_signaled++; mptcp_pm_announce_addr(msk, &local->addr, false); mptcp_pm_nl_addr_send_ack(msk); } + bpf_spin_unlock_bh(&msk->pm.lock); - return 0; + err = 0; +announce_err: + return err; } static int mptcp_pm_remove_id_zero_address(struct mptcp_sock *msk) @@ -100,6 +108,7 @@ static int mptcp_pm_remove_id_zero_address(struct mptcp_sock *msk) struct mptcp_rm_list list = { .nr = 0 }; struct mptcp_subflow_context *subflow; bool has_id_0 = false; + int err = -EINVAL; mptcp_for_each_subflow(msk, subflow) { subflow = bpf_core_cast(subflow, struct mptcp_subflow_context); @@ -108,8 +117,10 @@ static int mptcp_pm_remove_id_zero_address(struct mptcp_sock *msk) break; } } - if (!has_id_0) - return -EINVAL; + if (!has_id_0) { + bpf_printk("address with id 0 not found"); + goto remove_err; + } list.ids[list.nr++] = 0; @@ -117,7 +128,10 @@ static int mptcp_pm_remove_id_zero_address(struct mptcp_sock *msk) mptcp_pm_remove_addr(msk, &list); bpf_spin_unlock_bh(&msk->pm.lock); - return 0; + err = 0; + +remove_err: + return err; } static struct mptcp_pm_addr_entry * @@ -137,15 +151,20 @@ int BPF_PROG(mptcp_pm_address_remove, struct mptcp_sock *msk, u8 id) { struct sock *sk = (struct sock *)msk; struct mptcp_pm_addr_entry *entry; + int err = -EINVAL; - if (id == 0) - return mptcp_pm_remove_id_zero_address(msk); + if (id == 0) { + err = mptcp_pm_remove_id_zero_address(msk); + goto out; + } bpf_spin_lock_bh(&msk->pm.lock); entry = mptcp_userspace_pm_lookup_addr_by_id(msk, id); bpf_spin_unlock_bh(&msk->pm.lock); - if (!entry) - return -EINVAL; + if (!entry) { + bpf_printk("address with specified id not found"); + goto out; + } mptcp_pm_remove_addr_entry(msk, entry); @@ -154,7 +173,9 @@ int BPF_PROG(mptcp_pm_address_remove, struct mptcp_sock *msk, u8 id) bpf_pm_free_entry(sk, entry); bpf_spin_unlock_bh(&msk->pm.lock); - return 0; + err = 0; +out: + return err; } static struct mptcp_pm_addr_entry * @@ -192,16 +213,22 @@ int BPF_PROG(mptcp_pm_subflow_create, struct mptcp_sock *msk, struct sock *sk = (struct sock *)msk; int err = -EINVAL; - if (local->flags & MPTCP_PM_ADDR_FLAG_SIGNAL) - return err; + if (local->flags & MPTCP_PM_ADDR_FLAG_SIGNAL) { + bpf_printk("invalid addr flags"); + goto create_err; + } local->flags |= MPTCP_PM_ADDR_FLAG_SUBFLOW; - if (!bpf_mptcp_pm_addr_families_match(sk, &local->addr, remote)) - return err; + if (!bpf_mptcp_pm_addr_families_match(sk, &local->addr, remote)) { + bpf_printk("families mismatch"); + goto create_err; + } err = mptcp_userspace_pm_append_new_local_addr(msk, local, false); - if (err < 0) - return err; + if (err < 0) { + bpf_printk("did not match address and id"); + goto create_err; + } err = __mptcp_subflow_connect(sk, local, remote); bpf_spin_lock_bh(&msk->pm.lock); @@ -213,6 +240,7 @@ int BPF_PROG(mptcp_pm_subflow_create, struct mptcp_sock *msk, bpf_printk("mptcp_pm_subflow_create done"); +create_err: return err; } @@ -267,8 +295,8 @@ int BPF_PROG(mptcp_pm_subflow_destroy, struct mptcp_sock *msk, struct mptcp_pm_addr_entry *local, struct mptcp_addr_info *remote) { struct sock *sk = (struct sock *)msk; - int err = -EINVAL; struct sock *ssk; + int err = 0; if (local->addr.family == AF_INET && bpf_ipv6_addr_v4mapped(remote)) { bpf_ipv6_addr_set_v4mapped(local->addr.addr.s_addr, remote); @@ -279,25 +307,34 @@ int BPF_PROG(mptcp_pm_subflow_destroy, struct mptcp_sock *msk, remote->family = AF_INET6; } - if (local->addr.family != remote->family) - return err; + if (local->addr.family != remote->family) { + bpf_printk("address families do not match"); + err = -EINVAL; + goto destroy_err; + } - if (!local->addr.port || !remote->port) - return err; + if (!local->addr.port || !remote->port) { + bpf_printk("missing local or remote port"); + err = -EINVAL; + goto destroy_err; + } ssk = mptcp_pm_find_ssk(msk, &local->addr, remote); - if (ssk) { - struct mptcp_subflow_context *subflow = bpf_mptcp_subflow_ctx(ssk); - - bpf_spin_lock_bh(&msk->pm.lock); - err = mptcp_userspace_pm_delete_local_addr(msk, local); - bpf_spin_unlock_bh(&msk->pm.lock); - mptcp_subflow_shutdown(sk, ssk, RCV_SHUTDOWN | SEND_SHUTDOWN); - mptcp_close_ssk(sk, ssk, subflow); + if (!ssk) { + err = -ESRCH; + goto destroy_err; } + bpf_spin_lock_bh(&msk->pm.lock); + err = mptcp_userspace_pm_delete_local_addr(msk, local); + bpf_spin_unlock_bh(&msk->pm.lock); + mptcp_subflow_shutdown(sk, ssk, RCV_SHUTDOWN | SEND_SHUTDOWN); + mptcp_close_ssk(sk, ssk, bpf_mptcp_subflow_ctx(ssk)); + mptcp_inc_stats_rmsubflow(sk); + bpf_printk("mptcp_pm_subflow_destroy done"); +destroy_err: return err; } @@ -382,13 +419,16 @@ int BPF_PROG(mptcp_pm_set_flags, struct mptcp_sock *msk, struct mptcp_pm_addr_entry *local, struct mptcp_addr_info *remote) { struct mptcp_pm_addr_entry *entry; + int ret = -EINVAL; u8 bkup = 0; bpf_printk("mptcp_pm_set_flags"); if (local->addr.family == AF_UNSPEC || - remote->family == AF_UNSPEC) - return -EINVAL; + remote->family == AF_UNSPEC) { + bpf_printk("invalid address families"); + goto set_flags_err; + } if (local->flags & MPTCP_PM_ADDR_FLAG_BACKUP) bkup = 1; @@ -403,7 +443,10 @@ int BPF_PROG(mptcp_pm_set_flags, struct mptcp_sock *msk, } bpf_spin_unlock_bh(&msk->pm.lock); - return mptcp_pm_nl_mp_prio_send_ack(msk, &local->addr, remote, bkup); + ret = mptcp_pm_nl_mp_prio_send_ack(msk, &local->addr, remote, bkup); + +set_flags_err: + return ret; } SEC(".struct_ops.link")