Skip to content

Commit

Permalink
simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
codabrink committed Jan 13, 2025
1 parent 6c72fa0 commit 7c4b89a
Showing 1 changed file with 38 additions and 42 deletions.
80 changes: 38 additions & 42 deletions xmtp_mls/src/groups/device_sync/backup/backup_exporter.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,20 @@
use super::{backup_stream::BackupStream, BackupOptions};
use crate::{groups::device_sync::DeviceSyncError, XmtpOpenMlsProvider};
use futures::StreamExt;
use prost::Message;
use std::{
io::{BufWriter, Read},
io::{Read, Write},
path::Path,
sync::Arc,
};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, ReadBuf};
use xmtp_proto::xmtp::device_sync::BackupMetadata;
use zstd::stream::Encoder;

pub(super) struct BackupExporter<'a> {
stage: Stage,
buffer: Option<Vec<u8>>,
metadata: BackupMetadata,
stream: BackupStream,
position: usize,
deflate_encoder: Encoder<'a, BufWriter<Vec<u8>>>,
deflate_encoder: Encoder<'a, Vec<u8>>,
}

#[derive(Default)]
Expand All @@ -29,72 +26,71 @@ pub(super) enum Stage {

impl<'a> BackupExporter<'a> {
pub(super) fn new(opts: BackupOptions, provider: &Arc<XmtpOpenMlsProvider>) -> Self {
let buffer = BufWriter::new(Vec::new());
Self {
position: 0,
stage: Stage::default(),
stream: BackupStream::new(&opts, provider),
metadata: opts.into(),
buffer: Some(Vec::new()),
deflate_encoder: Encoder::new(buffer, 0).unwrap(),
deflate_encoder: Encoder::new(Vec::new(), 0).unwrap(),
}
}

pub async fn write_to_file(&mut self, path: impl AsRef<Path>) -> Result<(), DeviceSyncError> {
let mut file = tokio::fs::File::create(path.as_ref()).await?;
pub fn write_to_file(&mut self, path: impl AsRef<Path>) -> Result<(), DeviceSyncError> {
let mut file = std::fs::File::create(path.as_ref())?;
let mut buffer = [0u8; 1024];

let mut amount = self.read(&mut buffer)?;
while amount != 0 {
file.write_all(&buffer[..amount]).await?;
file.write_all(&buffer[..amount])?;
amount = self.read(&mut buffer)?;
}

file.flush().await?;
file.flush()?;

Ok(())
}
}

impl<'a> Read for BackupExporter<'a> {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
let mut buffer_inner = self.buffer.take().expect("This should always be here.");
if self.position < buffer_inner.len() {
let available = &buffer_inner[self.position..];
let amount = available.len().min(buf.len());
buf[..amount].clone_from_slice(&available[..amount]);
self.position += amount;
self.buffer = Some(buffer_inner);
return Ok(amount);
{
let buffer_inner = self.deflate_encoder.get_ref();
if self.position < buffer_inner.len() {
let available = &buffer_inner[self.position..];
let amount = available.len().min(buf.len());

buf[..amount].clone_from_slice(&available[..amount]);
self.position += amount;
return Ok(amount);
}
}

// The buffer is consumed. Reset.
self.position = 0;
buffer_inner.clear();
self.deflate_encoder.get_mut().clear();

// Time to fill the buffer with more data.
let mut buffer = ReadBuf::new(&mut buffer_inner);

match self.stage {
Stage::Metadata => {
buffer.put_slice(&serde_json::to_vec(&self.metadata)?);
self.stage = Stage::Elements;
}
Stage::Elements => match self.stream.next() {
Some(element) => {
element.encode(&mut buffer)?;
let mut byte_count = 0;
while byte_count < 8_000 {
let bytes = match self.stage {
Stage::Metadata => {
self.stage = Stage::Elements;
serde_json::to_vec(&self.metadata)?
}
None => {}
},
};

let filled = buffer.filled();
let amount = filled.len().min(buf.len());
buf[..amount].clone_from_slice(&filled[..amount]);
self.position = amount;

self.buffer = Some(buffer_inner);
Stage::Elements => match self.stream.next() {
Some(element) => element.encode_to_vec(),
None => break,
},
};
byte_count += bytes.len();
self.deflate_encoder.write(&bytes)?;
}
self.deflate_encoder.flush()?;

Ok(amount)
if byte_count > 0 {
self.read(buf)
} else {
Ok(0)
}
}
}

0 comments on commit 7c4b89a

Please sign in to comment.