diff --git a/libsql/src/replication/connection.rs b/libsql/src/replication/connection.rs index ed09b09539..588c29389d 100644 --- a/libsql/src/replication/connection.rs +++ b/libsql/src/replication/connection.rs @@ -57,6 +57,15 @@ impl State { (State::TxnReadOnly, StmtKind::TxnEnd) | (State::Txn, StmtKind::TxnEnd) => State::Init, + // Savepoint only makes sense within a transaction and doesn't change the transaction kind + (State::TxnReadOnly, StmtKind::Savepoint) => State::TxnReadOnly, + (State::Txn, StmtKind::Savepoint) => State::Txn, + (_, StmtKind::Savepoint) => State::Invalid, + // Releasing a savepoint only makes sense inside a transaction and it doesn't change its state + (State::TxnReadOnly, StmtKind::Release) => State::TxnReadOnly, + (State::Txn, StmtKind::Release) => State::Txn, + (_, StmtKind::Release) => State::Invalid, + (state, StmtKind::Other | StmtKind::Write | StmtKind::Read) => state, (State::Invalid, _) => State::Invalid, diff --git a/libsql/src/replication/parser.rs b/libsql/src/replication/parser.rs index 2c72b4763f..5cd4651f5c 100644 --- a/libsql/src/replication/parser.rs +++ b/libsql/src/replication/parser.rs @@ -26,6 +26,8 @@ pub enum StmtKind { TxnEnd, Read, Write, + Savepoint, + Release, Other, } @@ -52,7 +54,13 @@ impl StmtKind { Some(Self::TxnBeginReadOnly) } Cmd::Stmt(Stmt::Begin { .. }) => Some(Self::TxnBegin), - Cmd::Stmt(Stmt::Commit { .. } | Stmt::Rollback { .. }) => Some(Self::TxnEnd), + Cmd::Stmt( + Stmt::Commit { .. } + | Stmt::Rollback { + savepoint_name: None, + .. + }, + ) => Some(Self::TxnEnd), Cmd::Stmt( Stmt::CreateVirtualTable { tbl_name, .. } | Stmt::CreateTable { @@ -100,6 +108,12 @@ impl StmtKind { temporary: false, .. }) => Some(Self::Write), Cmd::Stmt(Stmt::DropView { .. }) => Some(Self::Write), + Cmd::Stmt(Stmt::Savepoint(_)) => Some(Self::Savepoint), + Cmd::Stmt(Stmt::Release(_)) + | Cmd::Stmt(Stmt::Rollback { + savepoint_name: Some(_), + .. + }) => Some(Self::Release), _ => None, } } @@ -168,6 +182,22 @@ impl StmtKind { }, } } + + /// Returns `true` if the stmt kind is [`Savepoint`]. + /// + /// [`Savepoint`]: StmtKind::Savepoint + #[must_use] + pub fn is_savepoint(&self) -> bool { + matches!(self, Self::Savepoint) + } + + /// Returns `true` if the stmt kind is [`Release`]. + /// + /// [`Release`]: StmtKind::Release + #[must_use] + pub fn is_release(&self) -> bool { + matches!(self, Self::Release) + } } impl Statement {