diff --git a/Cargo.lock b/Cargo.lock
index b38026655d..4f8e0b727f 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2227,6 +2227,7 @@ version = "0.1.11"
 dependencies = [
  "anyhow",
  "arbitrary",
+ "async-stream",
  "async-trait",
  "base64 0.21.4",
  "bincode",
@@ -2241,7 +2242,6 @@ dependencies = [
  "hyper-rustls 0.24.1 (registry+https://github.com/rust-lang/crates.io-index)",
  "libsql-replication",
  "libsql-sys",
- "memmap",
  "once_cell",
  "parking_lot",
  "pprof",
@@ -2285,6 +2285,7 @@ name = "libsql-replication"
 version = "0.1.0"
 dependencies = [
  "arbitrary",
+ "async-stream",
  "async-trait",
  "bincode",
  "bytemuck",
diff --git a/libsql-replication/Cargo.toml b/libsql-replication/Cargo.toml
index 5c1467834b..e348f9c0ef 100644
--- a/libsql-replication/Cargo.toml
+++ b/libsql-replication/Cargo.toml
@@ -20,6 +20,7 @@ tokio-stream = "0.1.14"
 async-trait = "0.1.74"
 uuid = { version = "1.5.0", features = ["v4"] }
 tokio-util = "0.7.9"
+async-stream = "0.3.5"
 
 [dev-dependencies]
 arbitrary = { version = "1.3.0", features = ["derive_arbitrary"] }
diff --git a/libsql-replication/src/lib.rs b/libsql-replication/src/lib.rs
index 45883b0eee..9ed97713f6 100644
--- a/libsql-replication/src/lib.rs
+++ b/libsql-replication/src/lib.rs
@@ -3,6 +3,7 @@ pub mod injector;
 pub mod meta;
 pub mod replicator;
 pub mod rpc;
+pub mod snapshot;
 
 mod error;
 
diff --git a/libsql-replication/src/snapshot.rs b/libsql-replication/src/snapshot.rs
new file mode 100644
index 0000000000..da54cf5a6f
--- /dev/null
+++ b/libsql-replication/src/snapshot.rs
@@ -0,0 +1,118 @@
+//! Snapshot file reader: parses the file header and streams the frames that follow it.
+
+use std::io::SeekFrom;
+use std::mem::MaybeUninit;
+use std::mem::size_of;
+use std::path::Path;
+
+use tokio::fs::File;
+use bytemuck::{pod_read_unaligned, Pod, Zeroable};
+use tokio::io::AsyncReadExt;
+use tokio::io::AsyncSeekExt;
+use tokio_stream::Stream;
+use tokio_stream::StreamExt;
+
+use crate::frame::Frame;
+use crate::frame::FrameBorrowed;
+use crate::frame::FrameNo;
+
+/// Errors that can occur while opening or reading a snapshot file.
+#[derive(Debug, thiserror::Error)]
+pub enum Error {
+    /// An I/O operation on the snapshot file failed.
+    #[error("IO error: {0}")]
+    Io(#[from] std::io::Error),
+    /// The frames in the file are not in strictly decreasing frame-number order.
+    #[error("Invalid snapshot file")]
+    InvalidSnapshot,
+}
+
+/// Fixed-size header stored at the very beginning of a snapshot file.
+#[derive(Debug, Copy, Clone, Zeroable, Pod, PartialEq, Eq)]
+#[repr(C)]
+pub struct SnapshotFileHeader {
+    /// id of the database
+    pub log_id: u128,
+    /// first frame in the snapshot
+    pub start_frame_no: u64,
+    /// end frame in the snapshot
+    pub end_frame_no: u64,
+    /// number of frames in the snapshot
+    pub frame_count: u64,
+    /// size of the database after applying the snapshot
+    pub size_after: u32,
+    /// explicit padding, so that the struct contains no implicit padding bytes
+    pub _pad: u32,
+}
+
+/// An opened snapshot file together with its decoded header.
+pub struct SnapshotFile {
+    file: File,
+    header: SnapshotFileHeader,
+}
+
+impl SnapshotFile {
+    /// Opens the snapshot file at `path` and reads its header.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`Error::Io`] if the file cannot be opened or is too short to
+    /// contain a complete header.
+    pub async fn open(path: &Path) -> Result<Self, Error> {
+        let mut file = File::open(path).await?;
+        let mut header_buf = [0; size_of::<SnapshotFileHeader>()];
+        file.read_exact(&mut header_buf).await?;
+        let header: SnapshotFileHeader = pod_read_unaligned(&header_buf);
+
+        Ok(Self { file, header })
+    }
+
+    /// Consumes the snapshot and returns a stream over all of its frames.
+    ///
+    /// Frames are yielded in file order, which must be strictly decreasing
+    /// frame-number order. The stream ends with [`Error::InvalidSnapshot`] as
+    /// soon as a frame breaks that ordering, and with [`Error::Io`] if a frame
+    /// cannot be read in full.
+    pub fn into_stream(mut self) -> impl Stream<Item = Result<Frame, Error>> {
+        async_stream::try_stream! {
+            let mut previous_frame_no = None;
+            self.file.seek(SeekFrom::Start(size_of::<SnapshotFileHeader>() as _)).await?;
+            for _ in 0..self.header.frame_count {
+                // Zero-initialized (rather than uninitialized) so that the byte view
+                // handed to `read_exact` below never exposes uninitialized memory.
+                let mut frame: MaybeUninit<FrameBorrowed> = MaybeUninit::zeroed();
+                // SAFETY: the pointer comes from a live `MaybeUninit<FrameBorrowed>`, so it
+                // is non-null, aligned for `u8`, and valid for reads and writes of
+                // `size_of::<FrameBorrowed>()` bytes; nothing else accesses `frame` while
+                // `buf` is in use.
+                let buf = unsafe { std::slice::from_raw_parts_mut(frame.as_mut_ptr() as *mut u8, size_of::<FrameBorrowed>()) };
+                self.file.read_exact(buf).await?;
+                // SAFETY: `read_exact` succeeded, so every byte of `frame` has been written.
+                // NOTE(review): this relies on `FrameBorrowed` being padding-free and valid
+                // for any bit pattern; confirm against its definition.
+                let frame = unsafe { frame.assume_init() };
+
+                let frame_no = frame.header().frame_no;
+                // frames in snapshot must be in reverse ordering
+                if let Some(previous) = previous_frame_no.replace(frame_no) {
+                    if previous <= frame_no {
+                        Err(Error::InvalidSnapshot)?;
+                    }
+                }
+
+                yield Frame::from(frame)
+            }
+        }
+    }
+
+    /// Like [`Self::into_stream`], but stops before the first frame whose frame
+    /// number is lower than `from`.
+    ///
+    /// Errors are passed through to the caller rather than ending the stream early.
+    pub fn into_stream_from(self, from: FrameNo) -> impl Stream<Item = Result<Frame, Error>> {
+        self.into_stream().take_while(move |f| match f {
+            Ok(f) => f.header().frame_no >= from,
+            Err(_) => true,
+        })
+    }
+}