Skip to content

Commit

Permalink
Remove unused oneshot channel from USR2 sdf work
Browse files Browse the repository at this point in the history
This commit removes an unused oneshot channel from the recent USR2
signal handling work in sdf (#4368).

Additional changes:
- Move the maintenance mode message into a constant in order to make the
  code more readable
- Rename "Result" to "ServerResult" for clarity in "server/server.rs"
- Re-order imports to match style

Signed-off-by: Nick Gerace <[email protected]>
  • Loading branch information
nickgerace committed Aug 20, 2024
1 parent d5c7376 commit f4b7278
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 66 deletions.
13 changes: 8 additions & 5 deletions lib/sdf-server/src/server/routes.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
use super::{
server::ServerError,
state::{AppState, ApplicationRuntimeMode},
};
use axum::{
extract::State,
http::{HeaderValue, Request, StatusCode},
Expand All @@ -20,6 +16,13 @@ use thiserror::Error;
use tower_http::cors::CorsLayer;
use tower_http::{compression::CompressionLayer, cors::AllowOrigin};

use super::{
server::ServerError,
state::{AppState, ApplicationRuntimeMode},
};

const MAINTENANCE_MODE_MESSAGE: &str = " SI is currently in maintenance mode. Please try again later. Reach out to [email protected] or in the SI Discord for more information if this problem persists";

async fn app_state_middeware<B>(
State(state): State<AppState>,
request: Request<B>,
Expand All @@ -28,7 +31,7 @@ async fn app_state_middeware<B>(
match *state.application_runtime_mode.read().await {
ApplicationRuntimeMode::Maintenance => {
// Return a 503 when the server is in maintenance/offline
(StatusCode::SERVICE_UNAVAILABLE, " SI is currently in maintenance mode. Please try again later. Reach out to [email protected] or in the SI Discord for more information if this problem persists").into_response()
(StatusCode::SERVICE_UNAVAILABLE, MAINTENANCE_MODE_MESSAGE).into_response()
}
ApplicationRuntimeMode::Running => next.run(request).await,
}
Expand Down
111 changes: 50 additions & 61 deletions lib/sdf-server/src/server/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ use std::{io, net::SocketAddr, path::Path, path::PathBuf};
use telemetry::prelude::*;
use telemetry_http::{HttpMakeSpan, HttpOnResponse};
use thiserror::Error;

use tokio::{
io::{AsyncRead, AsyncWrite},
signal,
Expand Down Expand Up @@ -111,7 +110,7 @@ impl From<PgPoolError> for ServerError {
}
}

pub type Result<T, E = ServerError> = std::result::Result<T, E>;
pub type ServerResult<T, E = ServerError> = std::result::Result<T, E>;

pub struct Server<I, S> {
config: Config,
Expand All @@ -131,19 +130,18 @@ impl Server<(), ()> {
ws_multiplexer_client: MultiplexerClient,
crdt_multiplexer: Multiplexer,
crdt_multiplexer_client: MultiplexerClient,
) -> Result<(Server<AddrIncoming, SocketAddr>, broadcast::Receiver<()>)> {
) -> ServerResult<(Server<AddrIncoming, SocketAddr>, broadcast::Receiver<()>)> {
match config.incoming_stream() {
IncomingStream::HTTPSocket(socket_addr) => {
let (service, shutdown_rx, _application_mode_rx, shutdown_broadcast_rx) =
build_service(
services_context,
jwt_public_signing_key,
posthog_client,
ws_multiplexer_client,
crdt_multiplexer_client,
*config.create_workspace_permissions(),
config.create_workspace_allowlist().to_vec(),
)?;
let (service, shutdown_rx, shutdown_broadcast_rx) = build_service(
services_context,
jwt_public_signing_key,
posthog_client,
ws_multiplexer_client,
crdt_multiplexer_client,
*config.create_workspace_permissions(),
config.create_workspace_allowlist().to_vec(),
)?;

tokio::spawn(ws_multiplexer.run(shutdown_broadcast_rx.resubscribe()));
tokio::spawn(crdt_multiplexer.run(shutdown_broadcast_rx.resubscribe()));
Expand Down Expand Up @@ -178,19 +176,18 @@ impl Server<(), ()> {
ws_multiplexer_client: MultiplexerClient,
crdt_multiplexer: Multiplexer,
crdt_multiplexer_client: MultiplexerClient,
) -> Result<(Server<UdsIncomingStream, PathBuf>, broadcast::Receiver<()>)> {
) -> ServerResult<(Server<UdsIncomingStream, PathBuf>, broadcast::Receiver<()>)> {
match config.incoming_stream() {
IncomingStream::UnixDomainSocket(path) => {
let (service, shutdown_rx, _application_mode_rx, shutdown_broadcast_rx) =
build_service(
services_context,
jwt_public_signing_key,
posthog_client,
ws_multiplexer_client,
crdt_multiplexer_client,
*config.create_workspace_permissions(),
config.create_workspace_allowlist().to_vec(),
)?;
let (service, shutdown_rx, shutdown_broadcast_rx) = build_service(
services_context,
jwt_public_signing_key,
posthog_client,
ws_multiplexer_client,
crdt_multiplexer_client,
*config.create_workspace_permissions(),
config.create_workspace_allowlist().to_vec(),
)?;

tokio::spawn(ws_multiplexer.run(shutdown_broadcast_rx.resubscribe()));
tokio::spawn(crdt_multiplexer.run(shutdown_broadcast_rx.resubscribe()));
Expand All @@ -216,11 +213,11 @@ impl Server<(), ()> {
}
}

pub fn init() -> Result<()> {
pub fn init() -> ServerResult<()> {
Ok(dal::init()?)
}

pub async fn start_posthog(config: &PosthogConfig) -> Result<PosthogClient> {
pub async fn start_posthog(config: &PosthogConfig) -> ServerResult<PosthogClient> {
let (posthog_client, posthog_sender) = si_posthog::from_config(config)?;

drop(tokio::spawn(posthog_sender.run()));
Expand All @@ -232,14 +229,14 @@ impl Server<(), ()> {
pub async fn generate_veritech_key_pair(
secret_key_path: impl AsRef<Path>,
public_key_path: impl AsRef<Path>,
) -> Result<()> {
) -> ServerResult<()> {
VeritechKeyPair::create_and_write_files(secret_key_path, public_key_path)
.await
.map_err(Into::into)
}

#[instrument(name = "sdf.init.generate_symmetric_key", level = "info", skip_all)]
pub async fn generate_symmetric_key(symmetric_key_path: impl AsRef<Path>) -> Result<()> {
pub async fn generate_symmetric_key(symmetric_key_path: impl AsRef<Path>) -> ServerResult<()> {
SymmetricCryptoService::generate_key()
.save(symmetric_key_path.as_ref())
.await
Expand All @@ -251,7 +248,9 @@ impl Server<(), ()> {
level = "info",
skip_all
)]
pub async fn load_jwt_public_signing_key(config: JwtConfig) -> Result<JwtPublicSigningKey> {
pub async fn load_jwt_public_signing_key(
config: JwtConfig,
) -> ServerResult<JwtPublicSigningKey> {
Ok(JwtPublicSigningKey::from_config(config).await?)
}

Expand All @@ -260,20 +259,22 @@ impl Server<(), ()> {
level = "info",
skip_all
)]
pub async fn decode_jwt_public_signing_key(key_string: String) -> Result<JwtPublicSigningKey> {
pub async fn decode_jwt_public_signing_key(
key_string: String,
) -> ServerResult<JwtPublicSigningKey> {
Ok(JwtPublicSigningKey::decode(key_string).await?)
}

#[instrument(name = "sdf.init.load_encryption_key", level = "info", skip_all)]
pub async fn load_encryption_key(
crypto_config: VeritechCryptoConfig,
) -> Result<Arc<VeritechEncryptionKey>> {
) -> ServerResult<Arc<VeritechEncryptionKey>> {
Ok(Arc::new(
VeritechEncryptionKey::from_config(crypto_config).await?,
))
}

pub async fn migrate_snapshots(services_context: &ServicesContext) -> Result<()> {
pub async fn migrate_snapshots(services_context: &ServicesContext) -> ServerResult<()> {
let dal_context = services_context.clone().into_builder(true);
let ctx = dal_context.build_default().await?;

Expand All @@ -285,7 +286,7 @@ impl Server<(), ()> {
}

#[instrument(name = "sdf.init.migrate_database", level = "info", skip_all)]
pub async fn migrate_database(services_context: &ServicesContext) -> Result<()> {
pub async fn migrate_database(services_context: &ServicesContext) -> ServerResult<()> {
services_context.layer_db().pg_migrate().await?;
dal::migrate_all_with_progress(services_context).await?;

Expand All @@ -296,14 +297,14 @@ impl Server<(), ()> {
}

#[instrument(name = "sdf.init.create_pg_pool", level = "info", skip_all)]
pub async fn create_pg_pool(pg_pool_config: &PgPoolConfig) -> Result<PgPool> {
pub async fn create_pg_pool(pg_pool_config: &PgPoolConfig) -> ServerResult<PgPool> {
let pool = PgPool::new(pg_pool_config).await?;
debug!("successfully started pg pool (note that not all connections may be healthy)");
Ok(pool)
}

#[instrument(name = "sdf.init.connect_to_nats", level = "info", skip_all)]
pub async fn connect_to_nats(nats_config: &NatsConfig) -> Result<NatsClient> {
pub async fn connect_to_nats(nats_config: &NatsConfig) -> ServerResult<NatsClient> {
let client = NatsClient::new(nats_config).await?;
debug!("successfully connected nats client");
Ok(client)
Expand All @@ -320,7 +321,7 @@ impl Server<(), ()> {
)]
pub async fn create_symmetric_crypto_service(
config: &SymmetricCryptoServiceConfig,
) -> Result<SymmetricCryptoService> {
) -> ServerResult<SymmetricCryptoService> {
SymmetricCryptoService::from_config(config)
.await
.map_err(Into::into)
Expand All @@ -333,7 +334,7 @@ where
IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
IE: Into<Box<dyn std::error::Error + Send + Sync>>,
{
pub async fn run(self) -> Result<()> {
pub async fn run(self) -> ServerResult<()> {
let shutdown_rx = self.shutdown_rx;

self.inner
Expand All @@ -355,7 +356,9 @@ where
}
}

pub async fn migrate_builtins_from_module_index(services_context: &ServicesContext) -> Result<()> {
pub async fn migrate_builtins_from_module_index(
services_context: &ServicesContext,
) -> ServerResult<()> {
let mut interval = time::interval(Duration::from_secs(5));
let instant = Instant::now();

Expand Down Expand Up @@ -402,7 +405,7 @@ async fn install_builtins(
ctx: DalContext,
module_list: BuiltinsDetailsResponse,
module_index_client: ModuleIndexClient,
) -> Result<()> {
) -> ServerResult<()> {
let dal = &ctx;
let client = &module_index_client.clone();
let modules: Vec<ModuleDetailsResponse> = module_list.modules;
Expand Down Expand Up @@ -473,21 +476,14 @@ async fn install_builtins(
async fn fetch_builtin(
module: &ModuleDetailsResponse,
module_index_client: &ModuleIndexClient,
) -> Result<SiPkg> {
) -> ServerResult<SiPkg> {
let module = module_index_client
.get_builtin(Ulid::from_string(&module.id).unwrap_or_default())
.await?;

Ok(SiPkg::load_from_bytes(module)?)
}

type BuildServiceResult = Result<(
Router,
oneshot::Receiver<()>,
oneshot::Receiver<()>,
broadcast::Receiver<()>,
)>;

pub fn build_service_for_tests(
services_context: ServicesContext,
jwt_public_signing_key: JwtPublicSigningKey,
Expand All @@ -496,7 +492,7 @@ pub fn build_service_for_tests(
crdt_multiplexer_client: MultiplexerClient,
create_workspace_permissions: WorkspacePermissionsMode,
create_workspace_allowlist: Vec<WorkspacePermissions>,
) -> BuildServiceResult {
) -> ServerResult<(Router, oneshot::Receiver<()>, broadcast::Receiver<()>)> {
build_service_inner(
services_context,
jwt_public_signing_key,
Expand All @@ -517,7 +513,7 @@ pub fn build_service(
crdt_multiplexer_client: MultiplexerClient,
create_workspace_permissions: WorkspacePermissionsMode,
create_workspace_allowlist: Vec<WorkspacePermissions>,
) -> BuildServiceResult {
) -> ServerResult<(Router, oneshot::Receiver<()>, broadcast::Receiver<()>)> {
build_service_inner(
services_context,
jwt_public_signing_key,
Expand All @@ -540,7 +536,7 @@ fn build_service_inner(
crdt_multiplexer_client: MultiplexerClient,
create_workspace_permissions: WorkspacePermissionsMode,
create_workspace_allowlist: Vec<WorkspacePermissions>,
) -> BuildServiceResult {
) -> ServerResult<(Router, oneshot::Receiver<()>, broadcast::Receiver<()>)> {
let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
let (shutdown_broadcast_tx, shutdown_broadcast_rx) = broadcast::channel(1);

Expand Down Expand Up @@ -575,24 +571,17 @@ fn build_service_inner(
.on_response(HttpOnResponse::new().level(Level::DEBUG)),
);

let (graceful_shutdown_rx, application_mode_rx) =
prepare_signal_handlers(shutdown_rx, shutdown_broadcast_tx, mode)?;
let graceful_shutdown_rx = prepare_signal_handlers(shutdown_rx, shutdown_broadcast_tx, mode)?;

Ok((
routes,
graceful_shutdown_rx,
application_mode_rx,
shutdown_broadcast_rx,
))
Ok((routes, graceful_shutdown_rx, shutdown_broadcast_rx))
}

fn prepare_signal_handlers(
mut shutdown_rx: mpsc::Receiver<ShutdownSource>,
shutdown_broadcast_tx: broadcast::Sender<()>,
mode: Arc<RwLock<ApplicationRuntimeMode>>,
) -> Result<(oneshot::Receiver<()>, oneshot::Receiver<()>)> {
) -> ServerResult<oneshot::Receiver<()>> {
let (graceful_shutdown_tx, graceful_shutdown_rx) = oneshot::channel::<()>();
let (_application_mode_tx, application_mode_rx) = oneshot::channel::<()>();

let mut sigterm_watcher =
signal::unix::signal(signal::unix::SignalKind::terminate()).map_err(ServerError::Signal)?;
Expand Down Expand Up @@ -654,7 +643,7 @@ fn prepare_signal_handlers(
}
});

Ok((graceful_shutdown_rx, application_mode_rx))
Ok(graceful_shutdown_rx)
}

#[remain::sorted]
Expand Down

0 comments on commit f4b7278

Please sign in to comment.