diff --git a/src/api_server/src/lib.rs b/src/api_server/src/lib.rs index 7959203e8ee..a9d6ce3ca82 100644 --- a/src/api_server/src/lib.rs +++ b/src/api_server/src/lib.rs @@ -38,8 +38,6 @@ pub struct ApiServer { /// FD on which we notify the VMM that we have sent at least one /// `VmmRequest`. to_vmm_fd: EventFd, - /// If this flag is set, the API thread will go down. - shutdown_flag: bool, } impl ApiServer { @@ -55,7 +53,6 @@ impl ApiServer { api_request_sender, vmm_response_receiver, to_vmm_fd, - shutdown_flag: false, } } @@ -121,14 +118,6 @@ impl ApiServer { let delta_us = utils::time::get_time_us(utils::time::ClockType::Monotonic) - request_processing_start_us; debug!("Total previous API call duration: {} us.", delta_us); - - if self.shutdown_flag { - server.flush_outgoing_writes(); - debug!( - "/shutdown-internal request received, API server thread now ending itself" - ); - return; - } } } } @@ -145,10 +134,6 @@ impl ApiServer { RequestAction::Sync(vmm_action) => { self.serve_vmm_action_request(vmm_action, request_processing_start_us) } - RequestAction::ShutdownInternal => { - self.shutdown_flag = true; - Response::new(Version::Http11, StatusCode::NoContent) - } }; if let Some(message) = parsing_info.take_deprecation_message() { warn!("{}", message); @@ -483,4 +468,38 @@ mod tests { unanswered requests will be dropped.\" }"; assert_eq!(&buf[..], &error_message[..]); } + + #[test] + fn test_kill_switch() { + let mut tmp_socket = TempFile::new().unwrap(); + tmp_socket.remove().unwrap(); + let path_to_socket = tmp_socket.as_path().to_str().unwrap().to_owned(); + + let to_vmm_fd = EventFd::new(libc::EFD_NONBLOCK).unwrap(); + let (api_request_sender, _from_api) = channel(); + let (_to_api, vmm_response_receiver) = channel(); + let seccomp_filters = get_empty_filters(); + + let api_kill_switch = EventFd::new(libc::EFD_NONBLOCK).unwrap(); + let kill_switch = api_kill_switch.try_clone().unwrap(); + + let mut server = HttpServer::new(PathBuf::from(path_to_socket)).unwrap(); + server.add_kill_switch(kill_switch).unwrap(); + + let api_thread = thread::Builder::new() + .name("fc_api_test".to_owned()) + .spawn(move || { + ApiServer::new(api_request_sender, vmm_response_receiver, to_vmm_fd).run( + server, + ProcessTimeReporter::new(Some(1), Some(1), Some(1)), + seccomp_filters.get("api").unwrap(), + vmm::HTTP_MAX_PAYLOAD_SIZE, + ) + }) + .unwrap(); + // Signal the API thread it should shut down. + api_kill_switch.write(1).unwrap(); + // Verify API thread was brought down. + api_thread.join().unwrap(); + } } diff --git a/src/api_server/src/parsed_request.rs b/src/api_server/src/parsed_request.rs index 12eacb24bc0..4dddaf87b86 100644 --- a/src/api_server/src/parsed_request.rs +++ b/src/api_server/src/parsed_request.rs @@ -32,7 +32,6 @@ use crate::ApiServer; #[derive(Debug)] pub(crate) enum RequestAction { Sync(Box), - ShutdownInternal, // !!! not an API, used by shutdown to thread::join the API thread } #[derive(Debug, Default, PartialEq)] @@ -97,9 +96,6 @@ impl TryFrom<&Request> for ParsedRequest { (Method::Put, "network-interfaces", Some(body)) => { parse_put_net(body, path_tokens.next()) } - (Method::Put, "shutdown-internal", None) => { - Ok(ParsedRequest::new(RequestAction::ShutdownInternal)) - } (Method::Put, "snapshot", Some(body)) => parse_put_snapshot(body, path_tokens.next()), (Method::Put, "vsock", Some(body)) => parse_put_vsock(body), (Method::Put, "entropy", Some(body)) => parse_put_entropy(body), @@ -350,7 +346,6 @@ pub mod tests { (RequestAction::Sync(ref sync_req), RequestAction::Sync(ref other_sync_req)) => { sync_req == other_sync_req } - _ => false, } } } @@ -358,7 +353,6 @@ pub mod tests { pub(crate) fn vmm_action_from_request(req: ParsedRequest) -> VmmAction { match req.action { RequestAction::Sync(vmm_action) => *vmm_action, - _ => panic!("Invalid request"), } } @@ -371,7 +365,6 @@ pub mod tests { assert_eq!(req_msg, msg); *vmm_action } - _ => panic!("Invalid request"), } } @@ -883,21 +876,6 @@ pub mod tests { assert!(ParsedRequest::try_from(&req).is_ok()); } - #[test] - fn test_try_from_put_shutdown() { - let (mut sender, receiver) = UnixStream::pair().unwrap(); - let mut connection = HttpConnection::new(receiver); - sender - .write_all(http_request("PUT", "/shutdown-internal", None).as_bytes()) - .unwrap(); - assert!(connection.try_read().is_ok()); - let req = connection.pop_parsed_request().unwrap(); - match ParsedRequest::try_from(&req).unwrap().into_parts() { - (RequestAction::ShutdownInternal, _) => (), - _ => panic!("wrong parsed request"), - }; - } - #[test] fn test_try_from_patch_vm() { let (mut sender, receiver) = UnixStream::pair().unwrap(); diff --git a/src/firecracker/src/api_server_adapter.rs b/src/firecracker/src/api_server_adapter.rs index 3e1e1ef5e45..73e504503cd 100644 --- a/src/firecracker/src/api_server_adapter.rs +++ b/src/firecracker/src/api_server_adapter.rs @@ -1,9 +1,7 @@ // Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -use std::io::Write; use std::os::unix::io::AsRawFd; -use std::os::unix::net::UnixStream; use std::path::PathBuf; use std::sync::mpsc::{channel, Receiver, Sender, TryRecvError}; use std::sync::{Arc, Mutex}; @@ -241,19 +239,7 @@ pub(crate) fn run_with_api( ) }); - // We want to tell the API thread to shut down for a clean exit. But this is after - // the Vmm.stop() has been called, so it's a moment of internal finalization (as - // opposed to be something the client might call to shut the Vm down). Since it's - // an internal signal implementing it with an HTTP request is probably not the ideal - // way to do it...but having another way would involve multiplexing micro-http server - // with some other communication mechanism, or enhancing micro-http with exit - // conditions. - - // We also need to make sure the socket path is ready. - let mut sock = UnixStream::connect(bind_path).unwrap(); - sock.write_all(b"PUT /shutdown-internal HTTP/1.1\r\n\r\n") - .unwrap(); - + api_kill_switch.write(1).unwrap(); // This call to thread::join() should block until the API thread has processed the // shutdown-internal and returns from its function. api_thread.join().expect("Api thread should join");