Skip to content

Commit

Permalink
Support cmd execution timeout in service mode
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianM27 authored and mkozlowski committed Jun 6, 2024
1 parent 2e07ae1 commit 3c367c5
Showing 1 changed file with 65 additions and 9 deletions.
74 changes: 65 additions & 9 deletions memcr.c
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ static int rss_file;
static int compress;
static int checksum;
static int service;
static unsigned int timeout;


#define BIT(x) (1ULL << x)
Expand Down Expand Up @@ -923,6 +924,20 @@ static void clear_pid_on_worker_exit_non_blocking(pid_t worker)
}
}

static int get_pid_worker(pid_t pid)
{
int worker = PID_INVALID;
pthread_mutex_lock(&checkpoint_service_data_lock);
for (int i=0; i<CHECKPOINTED_PIDS_LIMIT; ++i) {
if (checkpoint_service_data[i].pid == pid) {
worker = checkpoint_service_data[i].worker;
break;
}
}
pthread_mutex_unlock(&checkpoint_service_data_lock);
return worker;
}

static int can_checkpoint_pid(pid_t pid)
{
pthread_mutex_lock(&checkpoint_service_data_lock);
Expand Down Expand Up @@ -2470,24 +2485,37 @@ static void try_to_abort_checkpoint(pid_t pid)
}
}

static void checkpoint_procedure_service(int checkpointSocket, int cd)
static int checkpoint_procedure_service(int checkpointSocket, int cd, int pid, int worker_pid)
{
int ret;
struct service_response svc_resp;

fprintf(stdout, "[+] Service waiting for worker checkpoint...\n");
if (timeout) {
fprintf(stdout, "[+] Service waiting for worker checkpoint with timeout %d[s]...\n", timeout);
struct timeval rcv_timeout = { .tv_sec = timeout, .tv_usec = 0 };
ret = setsockopt(checkpointSocket, SOL_SOCKET, SO_RCVTIMEO, &rcv_timeout, sizeof(rcv_timeout));
if (ret < 0)
fprintf(stderr, "[-] Error setting socket timeout: %m, waiting forever!\n");
} else
fprintf(stdout, "[+] Service waiting for worker checkpoint...\n");

ret = _read(checkpointSocket, &svc_resp, sizeof(svc_resp)); // receive resp from child

if (ret == sizeof(svc_resp)) {
fprintf(stdout, "[+] Service received checkpoint response, informing client...\n");
send_response_to_client(cd, svc_resp.resp_code);
return svc_resp.resp_code;
} else {
fprintf(stderr, "[!] Error reading checkpoint response from worker!\n");
// unnable to read response from worker, kill both
kill(pid, SIGKILL);
kill(worker_pid, SIGKILL);
send_response_to_client(cd, MEMCR_ERROR_GENERAL);
return MEMCR_ERROR_GENERAL;
}
}

static void restore_procedure_service(int cd, struct service_command svc_cmd)
static void restore_procedure_service(int cd, struct service_command svc_cmd, int worker_pid)
{
int rd, ret = 0;
struct service_response svc_resp;
Expand All @@ -2504,12 +2532,23 @@ static void restore_procedure_service(int cd, struct service_command svc_cmd)
ret = -1;
}

fprintf(stdout, "[+] Service waiting for worker to restore... \n");
if (timeout) {
fprintf(stdout, "[+] Service waiting for worker to restore with timeout %d[s]...\n", timeout);
struct timeval rcv_timeout = { .tv_sec = timeout, .tv_usec = 0 };
ret = setsockopt(rd, SOL_SOCKET, SO_RCVTIMEO, &rcv_timeout, sizeof(rcv_timeout));
if (ret < 0)
fprintf(stderr, "[-] Error setting socket timeout: %m, waiting forever!\n");
} else
fprintf(stdout, "[+] Service waiting for worker to restore... \n");

ret = _read(rd, &svc_resp, sizeof(struct service_response)); // read response from service
close(rd);

if (ret != sizeof(struct service_response)) {
fprintf(stderr, "[-] %s() read() svc_resp failed: ret %d\n", __func__, ret);
// unnable to read response from worker, kill both
kill(svc_cmd.pid, SIGKILL);
kill(worker_pid, SIGKILL);
ret = -1;
}

Expand Down Expand Up @@ -2573,18 +2612,30 @@ static void *service_command_thread(void *ptr)
} else if (forkpid > 0) {
close(checkpoint_resp_sockets[1]);
set_pid_checkpointing(svc_ctx.svc_cmd.pid, checkpoint_resp_sockets[0]);
checkpoint_procedure_service(checkpoint_resp_sockets[0], svc_ctx.cd);
set_pid_checkpointed(svc_ctx.svc_cmd.pid, forkpid);
if (checkpoint_procedure_service(checkpoint_resp_sockets[0], svc_ctx.cd,
svc_ctx.svc_cmd.pid, forkpid))
clear_pid_checkpoint_data(svc_ctx.svc_cmd.pid);
else
set_pid_checkpointed(svc_ctx.svc_cmd.pid, forkpid);

close(checkpoint_resp_sockets[0]);
} else {
fprintf(stderr, "%s(): Fork error!\n", __func__);
clear_pid_checkpoint_data(svc_ctx.svc_cmd.pid);
}

break;
}
case MEMCR_RESTORE: {
fprintf(stdout, "[+] handling MEMCR_RESTORE for %d.\n", svc_ctx.svc_cmd.pid);
restore_procedure_service(svc_ctx.cd, svc_ctx.svc_cmd);
int worker_pid = get_pid_worker(svc_ctx.svc_cmd.pid);
if (worker_pid == PID_INVALID) {
fprintf(stderr, "%s(): Error, worker pid not found for %d!\n", __func__, svc_ctx.svc_cmd.pid);
send_response_to_client(svc_ctx.cd, MEMCR_ERROR_GENERAL);
close(svc_ctx.cd);
break;
}
restore_procedure_service(svc_ctx.cd, svc_ctx.svc_cmd, worker_pid);
clear_pid_checkpoint_data(svc_ctx.svc_cmd.pid);
break;
}
Expand Down Expand Up @@ -2800,7 +2851,8 @@ static void usage(const char *name, int status)
" -f --rss-file include file mapped memory\n" \
" -z --compress compress memory dump\n" \
" -c --checksum enable md5 checksum for memory dump\n" \
" -e --encrypt enable encryption of memory dump\n",
" -e --encrypt enable encryption of memory dump\n" \
" -t --timeout timeout in seconds for checkpoint/restore execution in service mode\n",
name);

exit(status);
Expand Down Expand Up @@ -2840,14 +2892,15 @@ int main(int argc, char *argv[])
{ "compress", 0, NULL, 'z'},
{ "checksum", 0, NULL, 'c'},
{ "encrypt", 2, 0, 'e'},
{ "timeout", 1, 0, 't'},
{ NULL, 0, NULL, 0 }
};

dump_dir = "/tmp";
parasite_socket_dir = NULL;
parasite_socket_use_netns = 0;

while ((opt = getopt_long(argc, argv, "hp:d:S:Nl:nmfzce::", long_options, &option_index)) != -1) {
while ((opt = getopt_long(argc, argv, "hp:d:S:Nl:nmfzce::t:", long_options, &option_index)) != -1) {
switch (opt) {
case 'h':
usage(argv[0], 0);
Expand Down Expand Up @@ -2896,6 +2949,9 @@ int main(int argc, char *argv[])
else if (optind < argc && argv[optind][0] != '-')
encrypt_arg = argv[optind++];
break;
case 't':
timeout = atoi(optarg);
break;
default: /* '?' */
usage(argv[0], 1);
}
Expand Down

0 comments on commit 3c367c5

Please sign in to comment.