Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Populate client side trace's local address via tcp kprobes #1989

Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ BPF_HASH(active_accept_args_map, uint64_t, struct accept_args_t);
// Key is {tgid, pid}.
BPF_HASH(active_connect_args_map, uint64_t, struct connect_args_t);

// Map from thread to its sock* struct. This facilitates capturing
// the local address of a tcp socket during connect() syscalls.
// Key is {tgid, pid}.
BPF_HASH(tcp_connect_args_map, uint64_t, struct sock*);

// Map from thread to its ongoing write() syscall's input argument.
// Tracks write() call from entry -> exit.
// Key is {tgid, pid}.
Expand Down Expand Up @@ -345,19 +350,17 @@ static __inline void update_traffic_class(struct conn_info_t* conn_info,
* Perf submit functions
***********************************************************/

static __inline void read_sockaddr_kernel(struct conn_info_t* conn_info,
const struct socket* socket) {
// Use BPF_PROBE_READ_KERNEL_VAR since BCC cannot insert them as expected.
struct sock* sk = NULL;
BPF_PROBE_READ_KERNEL_VAR(sk, &socket->sk);

struct sock_common* sk_common = &sk->__sk_common;
static __inline void read_sockaddr_kernel(struct conn_info_t* conn_info, const struct sock* sk) {
const struct sock_common* sk_common = &sk->__sk_common;
uint16_t family = -1;
uint16_t lport = -1;
uint16_t rport = -1;

BPF_PROBE_READ_KERNEL_VAR(family, &sk_common->skc_family);
BPF_PROBE_READ_KERNEL_VAR(lport, &sk_common->skc_num);
// skc_num is stored in host byte order. The rest of our user space processing
// assumes network byte order so convert it here.
lport = htons(lport);
BPF_PROBE_READ_KERNEL_VAR(rport, &sk_common->skc_dport);

conn_info->laddr.sa.sa_family = family;
Expand All @@ -377,12 +380,12 @@ static __inline void read_sockaddr_kernel(struct conn_info_t* conn_info,
}

static __inline void submit_new_conn(struct pt_regs* ctx, uint32_t tgid, int32_t fd,
const struct sockaddr* addr, const struct socket* socket,
const struct sockaddr* addr, const struct sock* sock,
enum endpoint_role_t role, enum source_function_t source_fn) {
struct conn_info_t conn_info = {};
init_conn_info(tgid, fd, &conn_info);
if (socket != NULL) {
read_sockaddr_kernel(&conn_info, socket);
if (sock != NULL) {
read_sockaddr_kernel(&conn_info, sock);
} else if (addr != NULL) {
conn_info.raddr = *((union sockaddr_t*)addr);
}
Expand Down Expand Up @@ -585,6 +588,52 @@ int conn_cleanup_uprobe(struct pt_regs* ctx) {
return 0;
}

// These probes are used to capture the *sock struct during client side tracing
// of connect() syscalls. This is necessary to capture the socket's local address,
// which is not accessible via the connect() and later syscalls.
//
// This function requires that the function being probed receives a struct sock* as its
// first argument and that the active_connect_args_map is populated when this probe fires.
// This means the function being probed must be part of the connect() syscall path or similar
// syscall path.
//
// Using the struct sock* for capturing a socket's local address only works for TCP sockets.
// The equivalent UDP functions (udp_v4_connect, udp_v6_connect and upd_sendmsg) always receive a
// sock struct with a 0.0.0.0 or ::1 local address. This is deemed acceptable since our local
// address population for server side tracing relies on accept/accept4, which only applies for TCP.
//
// int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len);
// static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len);
// int tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
int probe_entry_populate_active_connect_sock(struct pt_regs* ctx) {
uint64_t id = bpf_get_current_pid_tgid();

const struct connect_args_t* connect_args = active_connect_args_map.lookup(&id);
if (connect_args == NULL) {
return 0;
}
struct sock* sk = (struct sock*)PT_REGS_PARM1(ctx);
tcp_connect_args_map.update(&id, &sk);

return 0;
}

int probe_ret_populate_active_connect_sock(struct pt_regs* ctx) {
uint64_t id = bpf_get_current_pid_tgid();

struct sock** sk = tcp_connect_args_map.lookup(&id);
if (sk == NULL) {
return 0;
}
struct connect_args_t* connect_args = active_connect_args_map.lookup(&id);
if (connect_args != NULL) {
connect_args->connect_sock = *sk;
}

tcp_connect_args_map.delete(&id);
return 0;
}

/***********************************************************
* BPF syscall processing functions
***********************************************************/
Expand Down Expand Up @@ -629,7 +678,8 @@ static __inline void process_syscall_connect(struct pt_regs* ctx, uint64_t id,
return;
}

submit_new_conn(ctx, tgid, args->fd, args->addr, /*socket*/ NULL, kRoleClient, kSyscallConnect);
submit_new_conn(ctx, tgid, args->fd, args->addr, args->connect_sock, kRoleClient,
kSyscallConnect);
}

static __inline void process_syscall_accept(struct pt_regs* ctx, uint64_t id,
Expand All @@ -645,8 +695,11 @@ static __inline void process_syscall_accept(struct pt_regs* ctx, uint64_t id,
return;
}

submit_new_conn(ctx, tgid, ret_fd, args->addr, args->sock_alloc_socket, kRoleServer,
kSyscallAccept);
const struct sock* sk = NULL;
if (args->sock_alloc_socket != NULL) {
BPF_PROBE_READ_KERNEL_VAR(sk, &args->sock_alloc_socket->sk);
}
submit_new_conn(ctx, tgid, ret_fd, args->addr, sk, kRoleServer, kSyscallAccept);
}

// TODO(oazizi): This is badly broken (but better than before).
Expand Down Expand Up @@ -690,7 +743,7 @@ static __inline void process_implicit_conn(struct pt_regs* ctx, uint64_t id,
return;
}

submit_new_conn(ctx, tgid, args->fd, args->addr, /*socket*/ NULL, kRoleUnknown, source_fn);
submit_new_conn(ctx, tgid, args->fd, args->addr, args->connect_sock, kRoleUnknown, source_fn);
}

static __inline bool should_send_data(uint32_t tgid, uint64_t conn_disabled_tsid,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ struct socket_control_event_t {

struct connect_args_t {
const struct sockaddr* addr;
const struct sock* connect_sock;
int32_t fd;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,16 @@
#include "src/stirling/source_connectors/socket_tracer/bcc_bpf_intf/socket_trace.hpp"
#include "src/stirling/source_connectors/socket_tracer/socket_trace_connector.h"
#include "src/stirling/source_connectors/socket_tracer/testing/client_server_system.h"
#include "src/stirling/source_connectors/socket_tracer/testing/protocol_checkers.h"
#include "src/stirling/source_connectors/socket_tracer/testing/socket_trace_bpf_test_fixture.h"
#include "src/stirling/testing/common.h"

namespace px {
namespace stirling {

using ::px::stirling::testing::FindRecordsMatchingPID;
using ::px::stirling::testing::GetLocalAddrs;
using ::px::stirling::testing::GetLocalPorts;
using ::px::stirling::testing::RecordBatchSizeIs;
using ::px::system::TCPSocket;
using ::px::system::UDPSocket;
Expand Down Expand Up @@ -747,6 +750,155 @@ TEST_F(NullRemoteAddrTest, IPv6Accept4WithNullRemoteAddr) {
EXPECT_EQ(records[kHTTPRemotePortIdx]->Get<types::Int64Value>(0), port);
}

using LocalAddrTest = testing::SocketTraceBPFTestFixture</* TClientSideTracing */ true>;

TEST_F(LocalAddrTest, IPv4ConnectPopulatesLocalAddr) {
StartTransferDataThread();

TCPSocket client;
TCPSocket server;

std::atomic<bool> server_ready = true;

std::thread server_thread([&server, &server_ready]() {
server.BindAndListen();
server_ready = true;
auto conn = server.Accept(/* populate_remote_addr */ true);

std::string data;

conn->Read(&data);
conn->Write(kHTTPRespMsg1);
});

// Wait for server thread to start listening.
while (!server_ready) {
}
// After server_ready, server.Accept() needs to enter the accepting state, before the client
// connection can succeed below. We don't have a simple and robust way to signal that from inside
// the server thread, so we just use sleep to avoid the race condition.
std::this_thread::sleep_for(std::chrono::seconds(1));

std::thread client_thread([&client, &server]() {
client.Connect(server);

std::string data;

client.Write(kHTTPReqMsg1);
client.Read(&data);
});

server_thread.join();
client_thread.join();

// Get the remote port seen by server from client's local port.
struct sockaddr_in client_sockaddr = {};
socklen_t client_sockaddr_len = sizeof(client_sockaddr);
struct sockaddr* client_sockaddr_ptr = reinterpret_cast<struct sockaddr*>(&client_sockaddr);
ASSERT_EQ(getsockname(client.sockfd(), client_sockaddr_ptr, &client_sockaddr_len), 0);

// Close after getting the sockaddr from fd, otherwise getsockname() wont work.
client.Close();
server.Close();

StopTransferDataThread();

std::vector<TaggedRecordBatch> tablets = ConsumeRecords(kHTTPTableNum);
ASSERT_NOT_EMPTY_AND_GET_RECORDS(const types::ColumnWrapperRecordBatch& record_batch, tablets);

std::vector<size_t> indices =
testing::FindRecordIdxMatchesPID(record_batch, kHTTPUPIDIdx, getpid());
ColumnWrapperRecordBatch records = testing::SelectRecordBatchRows(record_batch, indices);

ASSERT_THAT(records, RecordBatchSizeIs(2));

// Make sure that the socket info resolution works.
ASSERT_OK_AND_ASSIGN(std::string remote_addr, IPv4AddrToString(client_sockaddr.sin_addr));
EXPECT_THAT(GetLocalAddrs(records, kHTTPLocalAddrIdx, indices), Contains("127.0.0.1").Times(2));
EXPECT_EQ(remote_addr, "127.0.0.1");

bool found_port = false;
uint16_t port = ntohs(client_sockaddr.sin_port);
for (auto lport : GetLocalPorts(records, kHTTPLocalPortIdx, indices)) {
if (lport == port) {
found_port = true;
break;
}
}
EXPECT_TRUE(found_port);
}

TEST_F(LocalAddrTest, IPv6ConnectPopulatesLocalAddr) {
StartTransferDataThread();

TCPSocket client(AF_INET6);
TCPSocket server(AF_INET6);

std::atomic<bool> server_ready = false;

std::thread server_thread([&server, &server_ready]() {
server.BindAndListen();
server_ready = true;
auto conn = server.Accept(/* populate_remote_addr */ false);

std::string data;

conn->Read(&data);
conn->Write(kHTTPRespMsg1);
});

while (!server_ready) {
}

std::thread client_thread([&client, &server]() {
client.Connect(server);

std::string data;

client.Write(kHTTPReqMsg1);
client.Read(&data);
});

server_thread.join();
client_thread.join();

// Get the remote port seen by server from client's local port.
struct sockaddr_in6 client_sockaddr = {};
socklen_t client_sockaddr_len = sizeof(client_sockaddr);
struct sockaddr* client_sockaddr_ptr = reinterpret_cast<struct sockaddr*>(&client_sockaddr);
ASSERT_EQ(getsockname(client.sockfd(), client_sockaddr_ptr, &client_sockaddr_len), 0);

// Close after getting the sockaddr from fd, otherwise getsockname() wont work.
client.Close();
server.Close();

StopTransferDataThread();

std::vector<TaggedRecordBatch> tablets = ConsumeRecords(kHTTPTableNum);
ASSERT_NOT_EMPTY_AND_GET_RECORDS(const types::ColumnWrapperRecordBatch& record_batch, tablets);

std::vector<size_t> indices =
testing::FindRecordIdxMatchesPID(record_batch, kHTTPUPIDIdx, getpid());
ColumnWrapperRecordBatch records = testing::SelectRecordBatchRows(record_batch, indices);

ASSERT_THAT(records, RecordBatchSizeIs(2));

// Make sure that the socket info resolution works.
ASSERT_OK_AND_ASSIGN(std::string remote_addr, IPv6AddrToString(client_sockaddr.sin6_addr));
EXPECT_THAT(GetLocalAddrs(records, kHTTPLocalAddrIdx, indices), Contains("::1").Times(2));
EXPECT_EQ(remote_addr, "::1");

bool found_port = false;
uint16_t port = ntohs(client_sockaddr.sin6_port);
for (auto lport : GetLocalPorts(records, kHTTPLocalPortIdx, indices)) {
if (lport == port) {
found_port = true;
break;
}
}
EXPECT_TRUE(found_port);
}

// Run a UDP-based client-server system.
class UDPSocketTraceBPFTest : public SocketTraceBPFTest {
protected:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ DEFINE_bool(
stirling_debug_tls_sources, gflags::BoolFromEnv("PX_DEBUG_TLS_SOURCES", false),
"If true, stirling will add additional prometheus metrics regarding the traced tls sources");

DEFINE_uint32(stirling_bpf_loop_limit, 42,
DEFINE_uint32(stirling_bpf_loop_limit, 41,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The iovec syscall return probes that rely on process_implicit_conn (recvmsg, recvmmsg, sendmsg, sendmmsg) had their instruction count increase beyond the 4.14 limit. This was my solution to getting the BPF instruction count on 4.14 kernels to work. The rationale is that cases that are right on the edge of this limit are likely to have problems already.

The other things I considered were the following:

  • Remove match_trace_tgid logic on 4.14 kernels -- did not lower the instruction count enough
  • Make the iovec ret syscalls provide a NULL sock for its submit_new_conn call (e.g. submit_new_conn(ctx, tgid, fd, addr, /*sock*/ NULL, role, source_fn)) -- not viable since it partially handles mid stream connections

I'm open to other suggestions if you have any, but this was the set of things I came up with.

"The maximum number of iovecs to capture for syscalls. "
"Set conservatively for older kernels by default to keep the instruction count below "
"BPF's limit for version 4 kernels (4096 per probe).");
Expand Down Expand Up @@ -342,6 +342,18 @@ const auto kProbeSpecs = MakeArray<bpf_tools::KProbeSpec>({
{"close", ProbeType::kReturn, "syscall__probe_ret_close"},
{"mmap", ProbeType::kEntry, "syscall__probe_entry_mmap"},
{"sock_alloc", ProbeType::kReturn, "probe_ret_sock_alloc", /*is_syscall*/ false},
{"tcp_v4_connect", ProbeType::kEntry, "probe_entry_populate_active_connect_sock",
/*is_syscall*/ false},
{"tcp_v4_connect", ProbeType::kReturn, "probe_ret_populate_active_connect_sock",
/*is_syscall*/ false},
{"tcp_v6_connect", ProbeType::kEntry, "probe_entry_populate_active_connect_sock",
/*is_syscall*/ false},
{"tcp_v6_connect", ProbeType::kReturn, "probe_ret_populate_active_connect_sock",
/*is_syscall*/ false},
{"tcp_sendmsg", ProbeType::kEntry, "probe_entry_populate_active_connect_sock",
/*is_syscall*/ false},
{"tcp_sendmsg", ProbeType::kReturn, "probe_ret_populate_active_connect_sock",
/*is_syscall*/ false},
Comment on lines +353 to +356
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we don't have an easy mechanism for testing this, I tried to test this mid stream case by creating a psql shell prior to the PEM starting and issuing queries after the socket tracer was initialized. I verified that a tcp connection was made and only sendto syscalls were issued when queries were executed (no connect syscalls) to check that this would simulate a mid stream case.

Surprisingly, I saw that the local_addr column was populated before I added the probe on tcp_sendmsg despite my thinking that it should be empty. I went ahead and added this probe since I verified with ftrace on the same psql process that tcp_v4_connect isn't called. Just raising this since I wasn't able to verify that this additional probe adds extra coverage from the testing I did.

{"security_socket_sendmsg", ProbeType::kEntry, "probe_entry_socket_sendmsg",
/*is_syscall*/ false, /* is_optional */ false,
std::make_shared<bpf_tools::KProbeSpec>(bpf_tools::KProbeSpec{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,26 @@ inline std::vector<bool> GetEncrypted(const types::ColumnWrapperRecordBatch& rb,
return encrypted;
}

inline std::vector<std::string> GetLocalAddrs(const types::ColumnWrapperRecordBatch& rb,
const int local_addr_idx,
const std::vector<size_t>& indices) {
std::vector<std::string> laddrs;
for (size_t idx : indices) {
laddrs.push_back(rb[local_addr_idx]->Get<types::StringValue>(idx));
}
return laddrs;
}

inline std::vector<int64_t> GetLocalPorts(const types::ColumnWrapperRecordBatch& rb,
const int local_port_idx,
const std::vector<size_t>& indices) {
std::vector<int64_t> ports;
for (size_t idx : indices) {
ports.push_back(rb[local_port_idx]->Get<types::Int64Value>(idx).val);
}
return ports;
}

inline std::vector<int64_t> GetRemotePorts(const types::ColumnWrapperRecordBatch& rb,
const std::vector<size_t>& indices) {
std::vector<int64_t> addrs;
Expand Down
Loading