diff --git a/src/stirling/source_connectors/socket_tracer/bcc_bpf/socket_trace.c b/src/stirling/source_connectors/socket_tracer/bcc_bpf/socket_trace.c index f03c05479ac..68a23dda923 100644 --- a/src/stirling/source_connectors/socket_tracer/bcc_bpf/socket_trace.c +++ b/src/stirling/source_connectors/socket_tracer/bcc_bpf/socket_trace.c @@ -81,6 +81,11 @@ BPF_HASH(active_accept_args_map, uint64_t, struct accept_args_t); // Key is {tgid, pid}. BPF_HASH(active_connect_args_map, uint64_t, struct connect_args_t); +// Map from thread to its sock* struct. This facilitates capturing +// the local address of a tcp socket during connect() syscalls. +// Key is {tgid, pid}. +BPF_HASH(tcp_connect_args_map, uint64_t, struct sock*); + // Map from thread to its ongoing write() syscall's input argument. // Tracks write() call from entry -> exit. // Key is {tgid, pid}. @@ -345,19 +350,17 @@ static __inline void update_traffic_class(struct conn_info_t* conn_info, * Perf submit functions ***********************************************************/ -static __inline void read_sockaddr_kernel(struct conn_info_t* conn_info, - const struct socket* socket) { - // Use BPF_PROBE_READ_KERNEL_VAR since BCC cannot insert them as expected. - struct sock* sk = NULL; - BPF_PROBE_READ_KERNEL_VAR(sk, &socket->sk); - - struct sock_common* sk_common = &sk->__sk_common; +static __inline void read_sockaddr_kernel(struct conn_info_t* conn_info, const struct sock* sk) { + const struct sock_common* sk_common = &sk->__sk_common; uint16_t family = -1; uint16_t lport = -1; uint16_t rport = -1; BPF_PROBE_READ_KERNEL_VAR(family, &sk_common->skc_family); BPF_PROBE_READ_KERNEL_VAR(lport, &sk_common->skc_num); + // skc_num is stored in host byte order. The rest of our user space processing + // assumes network byte order so convert it here. + lport = htons(lport); BPF_PROBE_READ_KERNEL_VAR(rport, &sk_common->skc_dport); conn_info->laddr.sa.sa_family = family; @@ -377,12 +380,12 @@ static __inline void read_sockaddr_kernel(struct conn_info_t* conn_info, } static __inline void submit_new_conn(struct pt_regs* ctx, uint32_t tgid, int32_t fd, - const struct sockaddr* addr, const struct socket* socket, + const struct sockaddr* addr, const struct sock* sock, enum endpoint_role_t role, enum source_function_t source_fn) { struct conn_info_t conn_info = {}; init_conn_info(tgid, fd, &conn_info); - if (socket != NULL) { - read_sockaddr_kernel(&conn_info, socket); + if (sock != NULL) { + read_sockaddr_kernel(&conn_info, sock); } else if (addr != NULL) { conn_info.raddr = *((union sockaddr_t*)addr); } @@ -585,6 +588,52 @@ int conn_cleanup_uprobe(struct pt_regs* ctx) { return 0; } +// These probes are used to capture the *sock struct during client side tracing +// of connect() syscalls. This is necessary to capture the socket's local address, +// which is not accessible via the connect() and later syscalls. +// +// This function requires that the function being probed receives a struct sock* as its +// first argument and that the active_connect_args_map is populated when this probe fires. +// This means the function being probed must be part of the connect() syscall path or similar +// syscall path. +// +// Using the struct sock* for capturing a socket's local address only works for TCP sockets. +// The equivalent UDP functions (udp_v4_connect, udp_v6_connect and upd_sendmsg) always receive a +// sock struct with a 0.0.0.0 or ::1 local address. This is deemed acceptable since our local +// address population for server side tracing relies on accept/accept4, which only applies for TCP. +// +// int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len); +// static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len); +// int tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size); +int probe_entry_populate_active_connect_sock(struct pt_regs* ctx) { + uint64_t id = bpf_get_current_pid_tgid(); + + const struct connect_args_t* connect_args = active_connect_args_map.lookup(&id); + if (connect_args == NULL) { + return 0; + } + struct sock* sk = (struct sock*)PT_REGS_PARM1(ctx); + tcp_connect_args_map.update(&id, &sk); + + return 0; +} + +int probe_ret_populate_active_connect_sock(struct pt_regs* ctx) { + uint64_t id = bpf_get_current_pid_tgid(); + + struct sock** sk = tcp_connect_args_map.lookup(&id); + if (sk == NULL) { + return 0; + } + struct connect_args_t* connect_args = active_connect_args_map.lookup(&id); + if (connect_args != NULL) { + connect_args->connect_sock = *sk; + } + + tcp_connect_args_map.delete(&id); + return 0; +} + /*********************************************************** * BPF syscall processing functions ***********************************************************/ @@ -629,7 +678,8 @@ static __inline void process_syscall_connect(struct pt_regs* ctx, uint64_t id, return; } - submit_new_conn(ctx, tgid, args->fd, args->addr, /*socket*/ NULL, kRoleClient, kSyscallConnect); + submit_new_conn(ctx, tgid, args->fd, args->addr, args->connect_sock, kRoleClient, + kSyscallConnect); } static __inline void process_syscall_accept(struct pt_regs* ctx, uint64_t id, @@ -645,8 +695,11 @@ static __inline void process_syscall_accept(struct pt_regs* ctx, uint64_t id, return; } - submit_new_conn(ctx, tgid, ret_fd, args->addr, args->sock_alloc_socket, kRoleServer, - kSyscallAccept); + const struct sock* sk = NULL; + if (args->sock_alloc_socket != NULL) { + BPF_PROBE_READ_KERNEL_VAR(sk, &args->sock_alloc_socket->sk); + } + submit_new_conn(ctx, tgid, ret_fd, args->addr, sk, kRoleServer, kSyscallAccept); } // TODO(oazizi): This is badly broken (but better than before). @@ -690,7 +743,7 @@ static __inline void process_implicit_conn(struct pt_regs* ctx, uint64_t id, return; } - submit_new_conn(ctx, tgid, args->fd, args->addr, /*socket*/ NULL, kRoleUnknown, source_fn); + submit_new_conn(ctx, tgid, args->fd, args->addr, args->connect_sock, kRoleUnknown, source_fn); } static __inline bool should_send_data(uint32_t tgid, uint64_t conn_disabled_tsid, diff --git a/src/stirling/source_connectors/socket_tracer/bcc_bpf_intf/socket_trace.h b/src/stirling/source_connectors/socket_tracer/bcc_bpf_intf/socket_trace.h index a66e30d7e62..c8339fd412d 100644 --- a/src/stirling/source_connectors/socket_tracer/bcc_bpf_intf/socket_trace.h +++ b/src/stirling/source_connectors/socket_tracer/bcc_bpf_intf/socket_trace.h @@ -263,6 +263,7 @@ struct socket_control_event_t { struct connect_args_t { const struct sockaddr* addr; + const struct sock* connect_sock; int32_t fd; }; diff --git a/src/stirling/source_connectors/socket_tracer/socket_trace_bpf_test.cc b/src/stirling/source_connectors/socket_tracer/socket_trace_bpf_test.cc index fde841643a4..8de89aa4b22 100644 --- a/src/stirling/source_connectors/socket_tracer/socket_trace_bpf_test.cc +++ b/src/stirling/source_connectors/socket_tracer/socket_trace_bpf_test.cc @@ -39,6 +39,7 @@ #include "src/stirling/source_connectors/socket_tracer/bcc_bpf_intf/socket_trace.hpp" #include "src/stirling/source_connectors/socket_tracer/socket_trace_connector.h" #include "src/stirling/source_connectors/socket_tracer/testing/client_server_system.h" +#include "src/stirling/source_connectors/socket_tracer/testing/protocol_checkers.h" #include "src/stirling/source_connectors/socket_tracer/testing/socket_trace_bpf_test_fixture.h" #include "src/stirling/testing/common.h" @@ -46,6 +47,8 @@ namespace px { namespace stirling { using ::px::stirling::testing::FindRecordsMatchingPID; +using ::px::stirling::testing::GetLocalAddrs; +using ::px::stirling::testing::GetLocalPorts; using ::px::stirling::testing::RecordBatchSizeIs; using ::px::system::TCPSocket; using ::px::system::UDPSocket; @@ -747,6 +750,155 @@ TEST_F(NullRemoteAddrTest, IPv6Accept4WithNullRemoteAddr) { EXPECT_EQ(records[kHTTPRemotePortIdx]->Get(0), port); } +using LocalAddrTest = testing::SocketTraceBPFTestFixture; + +TEST_F(LocalAddrTest, IPv4ConnectPopulatesLocalAddr) { + StartTransferDataThread(); + + TCPSocket client; + TCPSocket server; + + std::atomic server_ready = true; + + std::thread server_thread([&server, &server_ready]() { + server.BindAndListen(); + server_ready = true; + auto conn = server.Accept(/* populate_remote_addr */ true); + + std::string data; + + conn->Read(&data); + conn->Write(kHTTPRespMsg1); + }); + + // Wait for server thread to start listening. + while (!server_ready) { + } + // After server_ready, server.Accept() needs to enter the accepting state, before the client + // connection can succeed below. We don't have a simple and robust way to signal that from inside + // the server thread, so we just use sleep to avoid the race condition. + std::this_thread::sleep_for(std::chrono::seconds(1)); + + std::thread client_thread([&client, &server]() { + client.Connect(server); + + std::string data; + + client.Write(kHTTPReqMsg1); + client.Read(&data); + }); + + server_thread.join(); + client_thread.join(); + + // Get the remote port seen by server from client's local port. + struct sockaddr_in client_sockaddr = {}; + socklen_t client_sockaddr_len = sizeof(client_sockaddr); + struct sockaddr* client_sockaddr_ptr = reinterpret_cast(&client_sockaddr); + ASSERT_EQ(getsockname(client.sockfd(), client_sockaddr_ptr, &client_sockaddr_len), 0); + + // Close after getting the sockaddr from fd, otherwise getsockname() wont work. + client.Close(); + server.Close(); + + StopTransferDataThread(); + + std::vector tablets = ConsumeRecords(kHTTPTableNum); + ASSERT_NOT_EMPTY_AND_GET_RECORDS(const types::ColumnWrapperRecordBatch& record_batch, tablets); + + std::vector indices = + testing::FindRecordIdxMatchesPID(record_batch, kHTTPUPIDIdx, getpid()); + ColumnWrapperRecordBatch records = testing::SelectRecordBatchRows(record_batch, indices); + + ASSERT_THAT(records, RecordBatchSizeIs(2)); + + // Make sure that the socket info resolution works. + ASSERT_OK_AND_ASSIGN(std::string remote_addr, IPv4AddrToString(client_sockaddr.sin_addr)); + EXPECT_THAT(GetLocalAddrs(records, kHTTPLocalAddrIdx, indices), Contains("127.0.0.1").Times(2)); + EXPECT_EQ(remote_addr, "127.0.0.1"); + + bool found_port = false; + uint16_t port = ntohs(client_sockaddr.sin_port); + for (auto lport : GetLocalPorts(records, kHTTPLocalPortIdx, indices)) { + if (lport == port) { + found_port = true; + break; + } + } + EXPECT_TRUE(found_port); +} + +TEST_F(LocalAddrTest, IPv6ConnectPopulatesLocalAddr) { + StartTransferDataThread(); + + TCPSocket client(AF_INET6); + TCPSocket server(AF_INET6); + + std::atomic server_ready = false; + + std::thread server_thread([&server, &server_ready]() { + server.BindAndListen(); + server_ready = true; + auto conn = server.Accept(/* populate_remote_addr */ false); + + std::string data; + + conn->Read(&data); + conn->Write(kHTTPRespMsg1); + }); + + while (!server_ready) { + } + + std::thread client_thread([&client, &server]() { + client.Connect(server); + + std::string data; + + client.Write(kHTTPReqMsg1); + client.Read(&data); + }); + + server_thread.join(); + client_thread.join(); + + // Get the remote port seen by server from client's local port. + struct sockaddr_in6 client_sockaddr = {}; + socklen_t client_sockaddr_len = sizeof(client_sockaddr); + struct sockaddr* client_sockaddr_ptr = reinterpret_cast(&client_sockaddr); + ASSERT_EQ(getsockname(client.sockfd(), client_sockaddr_ptr, &client_sockaddr_len), 0); + + // Close after getting the sockaddr from fd, otherwise getsockname() wont work. + client.Close(); + server.Close(); + + StopTransferDataThread(); + + std::vector tablets = ConsumeRecords(kHTTPTableNum); + ASSERT_NOT_EMPTY_AND_GET_RECORDS(const types::ColumnWrapperRecordBatch& record_batch, tablets); + + std::vector indices = + testing::FindRecordIdxMatchesPID(record_batch, kHTTPUPIDIdx, getpid()); + ColumnWrapperRecordBatch records = testing::SelectRecordBatchRows(record_batch, indices); + + ASSERT_THAT(records, RecordBatchSizeIs(2)); + + // Make sure that the socket info resolution works. + ASSERT_OK_AND_ASSIGN(std::string remote_addr, IPv6AddrToString(client_sockaddr.sin6_addr)); + EXPECT_THAT(GetLocalAddrs(records, kHTTPLocalAddrIdx, indices), Contains("::1").Times(2)); + EXPECT_EQ(remote_addr, "::1"); + + bool found_port = false; + uint16_t port = ntohs(client_sockaddr.sin6_port); + for (auto lport : GetLocalPorts(records, kHTTPLocalPortIdx, indices)) { + if (lport == port) { + found_port = true; + break; + } + } + EXPECT_TRUE(found_port); +} + // Run a UDP-based client-server system. class UDPSocketTraceBPFTest : public SocketTraceBPFTest { protected: diff --git a/src/stirling/source_connectors/socket_tracer/socket_trace_connector.cc b/src/stirling/source_connectors/socket_tracer/socket_trace_connector.cc index 76c20b5a1a5..c4c273b96ba 100644 --- a/src/stirling/source_connectors/socket_tracer/socket_trace_connector.cc +++ b/src/stirling/source_connectors/socket_tracer/socket_trace_connector.cc @@ -176,7 +176,7 @@ DEFINE_bool( stirling_debug_tls_sources, gflags::BoolFromEnv("PX_DEBUG_TLS_SOURCES", false), "If true, stirling will add additional prometheus metrics regarding the traced tls sources"); -DEFINE_uint32(stirling_bpf_loop_limit, 42, +DEFINE_uint32(stirling_bpf_loop_limit, 41, "The maximum number of iovecs to capture for syscalls. " "Set conservatively for older kernels by default to keep the instruction count below " "BPF's limit for version 4 kernels (4096 per probe)."); @@ -342,6 +342,18 @@ const auto kProbeSpecs = MakeArray({ {"close", ProbeType::kReturn, "syscall__probe_ret_close"}, {"mmap", ProbeType::kEntry, "syscall__probe_entry_mmap"}, {"sock_alloc", ProbeType::kReturn, "probe_ret_sock_alloc", /*is_syscall*/ false}, + {"tcp_v4_connect", ProbeType::kEntry, "probe_entry_populate_active_connect_sock", + /*is_syscall*/ false}, + {"tcp_v4_connect", ProbeType::kReturn, "probe_ret_populate_active_connect_sock", + /*is_syscall*/ false}, + {"tcp_v6_connect", ProbeType::kEntry, "probe_entry_populate_active_connect_sock", + /*is_syscall*/ false}, + {"tcp_v6_connect", ProbeType::kReturn, "probe_ret_populate_active_connect_sock", + /*is_syscall*/ false}, + {"tcp_sendmsg", ProbeType::kEntry, "probe_entry_populate_active_connect_sock", + /*is_syscall*/ false}, + {"tcp_sendmsg", ProbeType::kReturn, "probe_ret_populate_active_connect_sock", + /*is_syscall*/ false}, {"security_socket_sendmsg", ProbeType::kEntry, "probe_entry_socket_sendmsg", /*is_syscall*/ false, /* is_optional */ false, std::make_shared(bpf_tools::KProbeSpec{ diff --git a/src/stirling/source_connectors/socket_tracer/testing/protocol_checkers.h b/src/stirling/source_connectors/socket_tracer/testing/protocol_checkers.h index 0cb66f59a14..207eb68e89b 100644 --- a/src/stirling/source_connectors/socket_tracer/testing/protocol_checkers.h +++ b/src/stirling/source_connectors/socket_tracer/testing/protocol_checkers.h @@ -135,6 +135,26 @@ inline std::vector GetEncrypted(const types::ColumnWrapperRecordBatch& rb, return encrypted; } +inline std::vector GetLocalAddrs(const types::ColumnWrapperRecordBatch& rb, + const int local_addr_idx, + const std::vector& indices) { + std::vector laddrs; + for (size_t idx : indices) { + laddrs.push_back(rb[local_addr_idx]->Get(idx)); + } + return laddrs; +} + +inline std::vector GetLocalPorts(const types::ColumnWrapperRecordBatch& rb, + const int local_port_idx, + const std::vector& indices) { + std::vector ports; + for (size_t idx : indices) { + ports.push_back(rb[local_port_idx]->Get(idx).val); + } + return ports; +} + inline std::vector GetRemotePorts(const types::ColumnWrapperRecordBatch& rb, const std::vector& indices) { std::vector addrs;