diff --git a/libsql-server/src/database.rs b/libsql-server/src/database.rs
index 291f361908..7193868804 100644
--- a/libsql-server/src/database.rs
+++ b/libsql-server/src/database.rs
@@ -4,13 +4,20 @@ use crate::connection::libsql::LibSqlConnection;
 use crate::connection::write_proxy::{RpcStream, WriteProxyConnection};
 use crate::connection::{Connection, MakeConnection, TrackedConnection};
 use crate::replication::{ReplicationLogger, ReplicationLoggerHook};
+use async_trait::async_trait;
 
+pub type Result<T> = anyhow::Result<T>;
+
+#[async_trait]
 pub trait Database: Sync + Send + 'static {
     /// The connection type of the database
     type Connection: Connection;
 
     fn connection_maker(&self) -> Arc<dyn MakeConnection<Connection = Self::Connection>>;
-    fn shutdown(&self);
+
+    fn destroy(self);
+
+    async fn shutdown(self) -> Result<()>;
 }
 
 pub struct ReplicaDatabase {
@@ -18,6 +25,7 @@ pub struct ReplicaDatabase {
         Arc<dyn MakeConnection<Connection = TrackedConnection<WriteProxyConnection<RpcStream>>>>,
 }
 
+#[async_trait]
 impl Database for ReplicaDatabase {
     type Connection = TrackedConnection<WriteProxyConnection<RpcStream>>;
 
@@ -25,7 +33,11 @@ impl Database for ReplicaDatabase {
         self.connection_maker.clone()
     }
 
-    fn shutdown(&self) {}
+    fn destroy(self) {}
+
+    async fn shutdown(self) -> Result<()> {
+        Ok(())
+    }
 }
 
 pub type PrimaryConnection = TrackedConnection<LibSqlConnection<ReplicationLoggerHook>>;
@@ -35,6 +47,7 @@ pub struct PrimaryDatabase {
     pub connection_maker: Arc<dyn MakeConnection<Connection = PrimaryConnection>>,
 }
 
+#[async_trait]
 impl Database for PrimaryDatabase {
     type Connection = PrimaryConnection;
 
@@ -42,7 +55,18 @@ impl Database for PrimaryDatabase {
         self.connection_maker.clone()
     }
 
-    fn shutdown(&self) {
+    fn destroy(self) {
+        self.logger.closed_signal.send_replace(true);
+    }
+
+    async fn shutdown(self) -> Result<()> {
         self.logger.closed_signal.send_replace(true);
+        if let Some(replicator) = &self.logger.bottomless_replicator {
+            let replicator = replicator.lock().unwrap().take();
+            if let Some(mut replicator) = replicator {
+                replicator.wait_until_snapshotted().await?;
+            }
+        }
+        Ok(())
     }
 }
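Note: `PrimaryDatabase::shutdown` above takes the replicator out of the `Mutex<Option<Replicator>>` in a single statement, so the `std::sync` guard is dropped before the `.await` on `wait_until_snapshotted`. A minimal sketch of that take-then-await pattern, assuming tokio and anyhow are available; `FakeReplicator` and `finish` are stand-ins, not the bottomless API:

```rust
use std::sync::{Arc, Mutex};

// Stand-in for bottomless::replicator::Replicator; `finish` plays the role of
// `wait_until_snapshotted` in the diff above (assumed name, not the real API).
struct FakeReplicator;

impl FakeReplicator {
    async fn finish(&mut self) -> anyhow::Result<()> {
        Ok(())
    }
}

// Take the replicator out of the slot *before* awaiting: the guard returned by
// `lock()` is dropped at the end of the statement, so it never lives across an
// await point and the resulting future stays `Send`.
async fn shutdown(slot: Arc<Mutex<Option<FakeReplicator>>>) -> anyhow::Result<()> {
    let taken = slot.lock().unwrap().take();
    if let Some(mut replicator) = taken {
        replicator.finish().await?;
    }
    Ok(())
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let slot = Arc::new(Mutex::new(Some(FakeReplicator)));
    shutdown(slot).await
}
```

Once the value has been taken, the WAL hooks in logger.rs see `None` in the slot, which is why they now have to handle that case explicitly.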
diff --git a/libsql-server/src/error.rs b/libsql-server/src/error.rs
index 344c3e01d3..110a24cf04 100644
--- a/libsql-server/src/error.rs
+++ b/libsql-server/src/error.rs
@@ -82,7 +82,6 @@ pub enum Error {
     Fork(#[from] ForkError),
     #[error("Fatal replication error")]
     FatalReplicationError,
-
     #[error("Connection with primary broken")]
     PrimaryStreamDisconnect,
     #[error("Proxy protocal misuse")]
@@ -91,6 +90,8 @@ pub enum Error {
     PrimaryStreamInterupted,
     #[error("Wrong URL: {0}")]
     UrlParseError(#[from] url::ParseError),
+    #[error("Namespace store has shutdown")]
+    NamespaceStoreShutdown,
 }
 
 trait ResponseError: std::error::Error {
@@ -148,6 +149,7 @@ impl IntoResponse for Error {
             PrimaryStreamMisuse => self.format_err(StatusCode::INTERNAL_SERVER_ERROR),
             PrimaryStreamInterupted => self.format_err(StatusCode::INTERNAL_SERVER_ERROR),
             UrlParseError(_) => self.format_err(StatusCode::BAD_REQUEST),
+            NamespaceStoreShutdown => self.format_err(StatusCode::SERVICE_UNAVAILABLE),
         }
     }
 }
diff --git a/libsql-server/src/lib.rs b/libsql-server/src/lib.rs
index a66b95fecd..8311e1240d 100644
--- a/libsql-server/src/lib.rs
+++ b/libsql-server/src/lib.rs
@@ -1,6 +1,8 @@
 #![allow(clippy::type_complexity, clippy::too_many_arguments)]
 
+use std::future::Future;
 use std::path::{Path, PathBuf};
+use std::pin::Pin;
 use std::process::Command;
 use std::str::FromStr;
 use std::sync::{Arc, Weak};
@@ -370,6 +372,7 @@ where
         let snapshot_callback = self.make_snapshot_callback();
         let auth = self.user_api_config.get_auth().map(Arc::new)?;
         let extensions = self.db_config.validate_extensions()?;
+        let namespace_store_shutdown_fut: Pin<Box<dyn Future<Output = crate::Result<()>> + Send>>;
 
         match self.rpc_client_config {
             Some(rpc_config) => {
@@ -385,6 +388,10 @@ where
                 let (namespaces, proxy_service, replication_service) = replica.configure().await?;
                 self.rpc_client_config = None;
                 self.spawn_monitoring_tasks(&mut join_set, stats_receiver, namespaces.clone())?;
+                namespace_store_shutdown_fut = {
+                    let namespaces = namespaces.clone();
+                    Box::pin(async move { namespaces.shutdown().await })
+                };
 
                 let services = Services {
                     namespaces,
@@ -420,6 +427,10 @@ where
                 let (namespaces, proxy_service, replication_service) = primary.configure().await?;
                 self.rpc_server_config = None;
                 self.spawn_monitoring_tasks(&mut join_set, stats_receiver, namespaces.clone())?;
+                namespace_store_shutdown_fut = {
+                    let namespaces = namespaces.clone();
+                    Box::pin(async move { namespaces.shutdown().await })
+                };
 
                 let services = Services {
                     namespaces,
@@ -442,6 +453,7 @@ where
         tokio::select! {
            _ = self.shutdown.notified() => {
                 join_set.shutdown().await;
+                namespace_store_shutdown_fut.await?;
                 // clean shutdown, remove sentinel file
                 std::fs::remove_file(sentinel_file_path(&self.path))?;
             }
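Note: the two `match` arms in lib.rs build different `async move` blocks, so the shared `namespace_store_shutdown_fut` variable is declared as a pinned, boxed, `Send` future rather than as a concrete type. A self-contained sketch of the same idea, assuming tokio and anyhow as dependencies (`shutdown_future` and `noop` are made-up names):

```rust
use std::future::Future;
use std::pin::Pin;

async fn noop() -> anyhow::Result<()> {
    Ok(())
}

// Each branch yields an anonymous future with its own concrete type; boxing
// them behind `dyn Future` gives both branches one common type to assign to.
fn shutdown_future(primary: bool) -> Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send>> {
    if primary {
        Box::pin(noop())
    } else {
        Box::pin(async {
            tokio::task::yield_now().await;
            noop().await
        })
    }
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    shutdown_future(true).await?;
    shutdown_future(false).await
}
```

In the server this future is only awaited on the clean-shutdown path, after `join_set.shutdown().await`, so the namespaces are torn down before the sentinel file is removed.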
diff --git a/libsql-server/src/namespace/mod.rs b/libsql-server/src/namespace/mod.rs
index 4689324673..205a2340ec 100644
--- a/libsql-server/src/namespace/mod.rs
+++ b/libsql-server/src/namespace/mod.rs
@@ -22,12 +22,14 @@ use tokio::task::JoinSet;
 use tokio::time::{Duration, Instant};
 use tokio_util::io::StreamReader;
 use tonic::transport::Channel;
+use tracing::trace;
 use uuid::Uuid;
 
 use crate::auth::Authenticated;
 use crate::connection::config::DatabaseConfigStore;
 use crate::connection::libsql::{open_conn, MakeLibSqlConn};
 use crate::connection::write_proxy::MakeWriteProxyConn;
+use crate::connection::Connection;
 use crate::connection::MakeConnection;
 use crate::database::{Database, PrimaryDatabase, ReplicaDatabase};
 use crate::error::{Error, LoadDumpError};
@@ -281,6 +283,7 @@ struct NamespaceStoreInner<M: MakeNamespace> {
     /// The namespace factory, to create new namespaces.
     make_namespace: M,
     allow_lazy_creation: bool,
+    has_shutdown: RwLock<bool>,
 }
 
 impl<M: MakeNamespace> NamespaceStore<M> {
@@ -290,11 +293,15 @@ impl<M: MakeNamespace> NamespaceStore<M> {
                 store: Default::default(),
                 make_namespace,
                 allow_lazy_creation,
+                has_shutdown: RwLock::new(false),
             }),
         }
     }
 
     pub async fn destroy(&self, namespace: NamespaceName) -> crate::Result<()> {
+        if *self.inner.has_shutdown.read().await {
+            return Err(Error::NamespaceStoreShutdown);
+        }
         let mut lock = self.inner.store.write().await;
         if let Some(ns) = lock.remove(&namespace) {
             // FIXME: when destroying, we are waiting for all the tasks associated with the
@@ -320,7 +327,10 @@
         &self,
         namespace: NamespaceName,
         restore_option: RestoreOption,
-    ) -> anyhow::Result<()> {
+    ) -> crate::Result<()> {
+        if *self.inner.has_shutdown.read().await {
+            return Err(Error::NamespaceStoreShutdown);
+        }
         let mut lock = self.inner.store.write().await;
         if let Some(ns) = lock.remove(&namespace) {
             // FIXME: when destroying, we are waiting for all the tasks associated with the
@@ -379,6 +389,9 @@
         to: NamespaceName,
         timestamp: Option<NaiveDateTime>,
     ) -> crate::Result<()> {
+        if *self.inner.has_shutdown.read().await {
+            return Err(Error::NamespaceStoreShutdown);
+        }
         let mut lock = self.inner.store.write().await;
         if lock.contains_key(&to) {
             return Err(crate::error::Error::NamespaceAlreadyExist(
@@ -424,6 +437,9 @@
     where
         Fun: FnOnce(&Namespace<M::Database>) -> R,
     {
+        if *self.inner.has_shutdown.read().await {
+            return Err(Error::NamespaceStoreShutdown);
+        }
         if !auth.is_namespace_authorized(&namespace) {
             return Err(Error::NamespaceDoesntExist(namespace.to_string()));
         }
@@ -435,6 +451,9 @@
     where
         Fun: FnOnce(&Namespace<M::Database>) -> R,
     {
+        if *self.inner.has_shutdown.read().await {
+            return Err(Error::NamespaceStoreShutdown);
+        }
         let before_load = Instant::now();
         let lock = self.inner.store.upgradable_read().await;
         if let Some(ns) = lock.get(&namespace) {
@@ -466,6 +485,9 @@
         namespace: NamespaceName,
         restore_option: RestoreOption,
     ) -> crate::Result<()> {
+        if *self.inner.has_shutdown.read().await {
+            return Err(Error::NamespaceStoreShutdown);
+        }
         let lock = self.inner.store.upgradable_read().await;
         if lock.contains_key(&namespace) {
             return Err(crate::error::Error::NamespaceAlreadyExist(
@@ -491,6 +513,17 @@
         Ok(())
     }
 
+    pub async fn shutdown(self) -> crate::Result<()> {
+        let mut has_shutdown = self.inner.has_shutdown.write().await;
+        *has_shutdown = true;
+        let mut lock = self.inner.store.write().await;
+        for (name, ns) in lock.drain() {
+            ns.shutdown().await?;
+            trace!("shutdown namespace: `{}`", name);
+        }
+        Ok(())
+    }
+
     pub(crate) async fn stats(&self, namespace: NamespaceName) -> crate::Result<Arc<Stats>> {
         self.with(namespace, |ns| ns.stats.clone()).await
     }
@@ -520,9 +553,22 @@ impl<T: Database> Namespace<T> {
     }
 
     async fn destroy(mut self) -> anyhow::Result<()> {
-        self.db.shutdown();
         self.tasks.shutdown().await;
+        self.db.destroy();
+        Ok(())
+    }
 
+    async fn checkpoint(&self) -> anyhow::Result<()> {
+        let conn = self.db.connection_maker().create().await?;
+        conn.vacuum_if_needed().await?;
+        conn.checkpoint().await?;
+        Ok(())
+    }
+
+    async fn shutdown(mut self) -> anyhow::Result<()> {
+        self.tasks.shutdown().await;
+        self.checkpoint().await?;
+        self.db.shutdown().await?;
         Ok(())
     }
 }
@@ -763,7 +809,7 @@ impl Namespace<PrimaryDatabase> {
             }
             is_dirty |= did_recover;
-            Some(Arc::new(std::sync::Mutex::new(replicator)))
+            Some(Arc::new(std::sync::Mutex::new(Some(replicator))))
         } else {
             None
         };
@@ -787,6 +833,7 @@ impl Namespace<PrimaryDatabase> {
                 let cb = config.snapshot_callback.clone();
                 move |path: &Path| cb(path, &name)
             }),
+            bottomless_replicator.clone(),
         )?);
 
         let ctx_builder = {
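Note: the store-level shutdown above is two steps: flip `has_shutdown` so every subsequent entry point returns `NamespaceStoreShutdown`, then drain the map and shut each namespace down in turn. A toy model of that flow; the names, the `String` payload, and the use of `tokio::sync::RwLock` are assumptions for the sketch, not the store's real types:

```rust
use std::collections::HashMap;
use tokio::sync::RwLock;

#[derive(Default)]
struct Store {
    has_shutdown: RwLock<bool>,
    entries: RwLock<HashMap<String, String>>,
}

impl Store {
    // Every entry point checks the flag first, mirroring the guards added above.
    async fn get(&self, name: &str) -> Result<Option<String>, &'static str> {
        if *self.has_shutdown.read().await {
            return Err("namespace store has shut down");
        }
        Ok(self.entries.read().await.get(name).cloned())
    }

    // Flip the flag, then drain and shut down every namespace.
    async fn shutdown(&self) {
        *self.has_shutdown.write().await = true;
        for (name, _state) in self.entries.write().await.drain() {
            println!("shutdown namespace: `{name}`");
        }
    }
}

#[tokio::main]
async fn main() {
    let store = Store::default();
    store.entries.write().await.insert("db1".into(), "state".into());
    store.shutdown().await;
    assert!(store.get("db1").await.is_err());
}
```

Because `shutdown` takes the write lock on the flag before draining, any call that starts afterwards observes `true` and bails out before touching the store.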
diff --git a/libsql-server/src/replication/primary/logger.rs b/libsql-server/src/replication/primary/logger.rs
index 772ef80ba5..5ee92f074b 100644
--- a/libsql-server/src/replication/primary/logger.rs
+++ b/libsql-server/src/replication/primary/logger.rs
@@ -4,9 +4,10 @@ use std::io::Write;
 use std::mem::size_of;
 use std::os::unix::prelude::FileExt;
 use std::path::{Path, PathBuf};
-use std::sync::Arc;
+use std::sync::{Arc, MutexGuard};
 
 use anyhow::{bail, ensure};
+use bottomless::replicator::Replicator;
 use bytemuck::{bytes_of, pod_read_unaligned, Pod, Zeroable};
 use bytes::{Bytes, BytesMut};
 use libsql_replication::frame::{Frame, FrameHeader, FrameMut};
@@ -51,7 +52,7 @@ pub enum ReplicationLoggerHook {}
 pub struct ReplicationLoggerHookCtx {
     buffer: Vec<WalPage>,
     logger: Arc<ReplicationLogger>,
-    bottomless_replicator: Option<Arc<std::sync::Mutex<bottomless::replicator::Replicator>>>,
+    bottomless_replicator: Option<Arc<std::sync::Mutex<Option<Replicator>>>>,
 }
 
 /// This implementation of WalHook intercepts calls to `on_frame`, and writes them to a
@@ -84,6 +85,21 @@ unsafe impl WalHook for ReplicationLoggerHook {
         tracing::trace!("Last valid frame before applying: {last_valid_frame}");
         let ctx = Self::wal_extract_ctx(wal);
 
+        let mut binding = ctx.bottomless_replicator.clone();
+        let replicator_mutex: Option<&mut Replicator>;
+        let mut replicator_lock: MutexGuard<Option<Replicator>>;
+        if let Some(replicator) = binding.as_mut() {
+            replicator_lock = replicator.lock().unwrap();
+            if replicator_lock.is_none() {
+                tracing::error!("fatal error: expected a replicator, exiting");
+                std::process::abort()
+            } else {
+                replicator_mutex = Option::from(replicator_lock.as_mut().unwrap());
+            }
+        } else {
+            replicator_mutex = None;
+        }
+
         let mut frame_count = 0;
         for (page_no, data) in PageHdrIter::new(page_headers, page_size as _) {
             ctx.write_frame(page_no, data);
@@ -118,8 +134,7 @@ unsafe impl WalHook for ReplicationLoggerHook {
 
         // do backup after log replication as we don't want to replicate potentially
         // inconsistent frames
-        if let Some(replicator) = ctx.bottomless_replicator.as_mut() {
-            let mut replicator = replicator.lock().unwrap();
+        if let Some(replicator) = replicator_mutex {
             replicator.register_last_valid_frame(last_valid_frame);
             if let Err(e) = replicator.set_page_size(page_size as usize) {
                 tracing::error!("fatal error during backup: {e}, exiting");
@@ -162,12 +177,16 @@ unsafe impl WalHook for ReplicationLoggerHook {
         let ctx = Self::wal_extract_ctx(wal);
         if let Some(replicator) = ctx.bottomless_replicator.as_mut() {
             let last_valid_frame = unsafe { *wal_data };
-            let mut replicator = replicator.lock().unwrap();
-            let prev_valid_frame = replicator.peek_last_valid_frame();
-            tracing::trace!(
+            if let Some(replicator) = replicator.lock().unwrap().as_mut() {
+                let prev_valid_frame = replicator.peek_last_valid_frame();
+                tracing::trace!(
                 "Savepoint: rolling back from frame {prev_valid_frame} to {last_valid_frame}",
             );
-            replicator.rollback_to_frame(last_valid_frame);
+                replicator.rollback_to_frame(last_valid_frame);
+            } else {
+                tracing::error!("fatal error: expected a replicator, exiting");
+                std::process::abort()
+            }
         }
     }
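Note: with the slot now holding `Option<Replicator>`, each hook locks it, checks for `None` (which only happens once shutdown has taken the replicator out), and only then uses the value; the hooks above treat an empty slot as fatal and abort. A small illustration of that lock-and-check shape, with a plain `u64` standing in for the replicator and a hypothetical `with_replicator` helper:

```rust
use std::sync::{Arc, Mutex};

// Lock the slot, then borrow the inner value if it is still there. The real
// hooks abort the process on an empty slot; here we just report it.
fn with_replicator<T, R>(
    slot: &Arc<Mutex<Option<T>>>,
    f: impl FnOnce(&mut T) -> R,
) -> Option<R> {
    let mut guard = slot.lock().unwrap();
    match guard.as_mut() {
        Some(value) => Some(f(value)),
        None => {
            eprintln!("expected a replicator, but the slot was empty");
            None
        }
    }
}

fn main() {
    let slot = Arc::new(Mutex::new(Some(0u64)));
    with_replicator(&slot, |frames| *frames += 1);
    slot.lock().unwrap().take(); // what shutdown does before awaiting the snapshot
    assert!(with_replicator(&slot, |frames| *frames).is_none());
}
```

In `on_frames` the diff hoists this lock to the top of the hook, so a missing replicator is detected before any frames are written rather than after.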
@@ -215,29 +234,36 @@ unsafe impl WalHook for ReplicationLoggerHook {
         let ctx = Self::wal_extract_ctx(wal);
         let runtime = tokio::runtime::Handle::current();
         if let Some(replicator) = ctx.bottomless_replicator.as_mut() {
-            let mut replicator = replicator.lock().unwrap();
-            let last_known_frame = replicator.last_known_frame();
-            replicator.request_flush();
-            if last_known_frame == 0 {
-                tracing::debug!("No committed changes in this generation, not snapshotting");
-                replicator.skip_snapshot_for_current_generation();
-                return SQLITE_OK;
-            }
-            if let Err(e) = runtime.block_on(replicator.wait_until_committed(last_known_frame))
-            {
-                tracing::error!(
-                    "Failed to wait for S3 replicator to confirm {} frames backup: {}",
-                    last_known_frame,
-                    e
-                );
-                return SQLITE_IOERR_WRITE;
-            }
-            if let Err(e) = runtime.block_on(replicator.wait_until_snapshotted()) {
-                tracing::error!(
+            if let Some(replicator) = replicator.lock().unwrap().as_mut() {
+                let last_known_frame = replicator.last_known_frame();
+                replicator.request_flush();
+                if last_known_frame == 0 {
+                    tracing::debug!(
+                        "No committed changes in this generation, not snapshotting"
+                    );
+                    replicator.skip_snapshot_for_current_generation();
+                    return SQLITE_OK;
+                }
+                if let Err(e) =
+                    runtime.block_on(replicator.wait_until_committed(last_known_frame))
+                {
+                    tracing::error!(
+                        "Failed to wait for S3 replicator to confirm {} frames backup: {}",
+                        last_known_frame,
+                        e
+                    );
+                    return SQLITE_IOERR_WRITE;
+                }
+                if let Err(e) = runtime.block_on(replicator.wait_until_snapshotted()) {
+                    tracing::error!(
                 "Failed to wait for S3 replicator to confirm database snapshot backup: {}",
                 e
             );
-                return SQLITE_IOERR_WRITE;
+                    return SQLITE_IOERR_WRITE;
+                }
+            } else {
+                tracing::error!("fatal error: expected a replicator, exiting");
+                std::process::abort()
             }
         }
     }
@@ -266,13 +292,19 @@ unsafe impl WalHook for ReplicationLoggerHook {
         let ctx = Self::wal_extract_ctx(wal);
         let runtime = tokio::runtime::Handle::current();
         if let Some(replicator) = ctx.bottomless_replicator.as_mut() {
-            let mut replicator = replicator.lock().unwrap();
-            let _prev = replicator.new_generation();
-            if let Err(e) =
-                runtime.block_on(async move { replicator.snapshot_main_db_file().await })
-            {
-                tracing::error!("Failed to snapshot the main db file during checkpoint: {e}");
-                return SQLITE_IOERR_WRITE;
+            if let Some(replicator) = replicator.lock().unwrap().as_mut() {
+                let _prev = replicator.new_generation();
+                if let Err(e) =
+                    runtime.block_on(async move { replicator.snapshot_main_db_file().await })
+                {
+                    tracing::error!(
+                        "Failed to snapshot the main db file during checkpoint: {e}"
+                    );
+                    return SQLITE_IOERR_WRITE;
+                }
+            } else {
+                tracing::error!("fatal error: expected a replicator, exiting");
+                std::process::abort()
             }
         }
     }
@@ -291,7 +323,7 @@ pub struct WalPage {
 impl ReplicationLoggerHookCtx {
     pub fn new(
         logger: Arc<ReplicationLogger>,
-        bottomless_replicator: Option<Arc<std::sync::Mutex<bottomless::replicator::Replicator>>>,
+        bottomless_replicator: Option<Arc<std::sync::Mutex<Option<Replicator>>>>,
     ) -> Self {
         if bottomless_replicator.is_some() {
             tracing::trace!("bottomless replication enabled");
@@ -783,6 +815,7 @@ pub struct ReplicationLogger {
     pub new_frame_notifier: watch::Sender<Option<FrameNo>>,
     pub closed_signal: watch::Sender<bool>,
     pub auto_checkpoint: u32,
+    pub bottomless_replicator: Option<Arc<std::sync::Mutex<Option<Replicator>>>>,
 }
 
 impl ReplicationLogger {
@@ -793,6 +826,7 @@ impl ReplicationLogger {
         dirty: bool,
         auto_checkpoint: u32,
         callback: SnapshotCallback,
+        bottomless_replicator: Option<Arc<std::sync::Mutex<Option<Replicator>>>>,
     ) -> anyhow::Result<Self> {
         let log_path = db_path.join("wallog");
         let data_path = db_path.join("data");
@@ -823,9 +857,21 @@ impl ReplicationLogger {
         };
 
         if should_recover {
-            Self::recover(log_file, data_path, callback, auto_checkpoint)
+            Self::recover(
+                log_file,
+                data_path,
+                callback,
+                auto_checkpoint,
+                bottomless_replicator,
+            )
         } else {
-            Self::from_log_file(db_path.to_path_buf(), log_file, callback, auto_checkpoint)
+            Self::from_log_file(
+                db_path.to_path_buf(),
+                log_file,
+                callback,
+                auto_checkpoint,
+                bottomless_replicator,
+            )
         }
     }
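Note: the checkpoint path above asks the replicator to flush, waits until the last known frame is confirmed durable, and only then waits for the snapshot before letting the checkpoint proceed. As a rough mental model only (not the bottomless implementation), the "wait until frame N is committed" step can be pictured as blocking on a watch channel that publishes the highest durable frame:

```rust
use tokio::sync::watch;

// Assumed model: the uploader publishes the highest durable frame on a watch
// channel and waiters block until it reaches their target. This mirrors the
// role of wait_until_committed in the hunk above, not its real implementation.
async fn wait_until_committed(rx: &mut watch::Receiver<u64>, target: u64) {
    while *rx.borrow() < target {
        rx.changed().await.expect("sender dropped");
    }
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = watch::channel(0u64);
    let waiter = tokio::spawn(async move {
        wait_until_committed(&mut rx, 3).await;
        println!("frame 3 durable, safe to checkpoint");
    });
    for frame in 1..=3u64 {
        tx.send(frame).unwrap();
    }
    waiter.await.unwrap();
}
```

`PrimaryDatabase::shutdown` reuses the same `wait_until_snapshotted` call, which is why the replicator is now also stored on `ReplicationLogger`, where the database can reach it.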
@@ -834,6 +880,7 @@ impl ReplicationLogger {
         log_file: LogFile,
         callback: SnapshotCallback,
         auto_checkpoint: u32,
+        bottomless_replicator: Option<Arc<std::sync::Mutex<Option<Replicator>>>>,
     ) -> anyhow::Result<Self> {
         let header = log_file.header();
         let generation_start_frame_no = header.last_frame_no();
@@ -867,6 +914,7 @@ impl ReplicationLogger {
             closed_signal,
             new_frame_notifier,
             auto_checkpoint,
+            bottomless_replicator,
         })
     }
 
@@ -875,6 +923,7 @@ impl ReplicationLogger {
         mut data_path: PathBuf,
         callback: SnapshotCallback,
         auto_checkpoint: u32,
+        bottomless_replicator: Option<Arc<std::sync::Mutex<Option<Replicator>>>>,
     ) -> anyhow::Result<Self> {
         // It is necessary to checkpoint before we restore the replication log, since the WAL may
         // contain pages that are not in the database file.
@@ -907,7 +956,13 @@ impl ReplicationLogger {
 
         assert!(data_path.pop());
 
-        Self::from_log_file(data_path, log_file, callback, auto_checkpoint)
+        Self::from_log_file(
+            data_path,
+            log_file,
+            callback,
+            auto_checkpoint,
+            bottomless_replicator,
+        )
     }
 
     pub fn log_id(&self) -> Uuid {
@@ -1052,6 +1107,7 @@ mod test {
             false,
             DEFAULT_AUTO_CHECKPOINT,
             Box::new(|_| Ok(())),
+            None,
         )
         .unwrap();
 
@@ -1088,6 +1144,7 @@ mod test {
             false,
             DEFAULT_AUTO_CHECKPOINT,
             Box::new(|_| Ok(())),
+            None,
         )
         .unwrap();
         let log_file = logger.log_file.write();
@@ -1105,6 +1162,7 @@ mod test {
             false,
             DEFAULT_AUTO_CHECKPOINT,
             Box::new(|_| Ok(())),
+            None,
         )
         .unwrap();
         let entry = WalPage {