From 99a7dc4bbba75e1e85481226a4bfe6505002e200 Mon Sep 17 00:00:00 2001 From: Brian Caswell Date: Thu, 7 Sep 2023 16:52:01 -0400 Subject: [PATCH] provide `SeekableStream` impl for `tokio::fs::File` Addresses #1219 This builds upon #1359, #1358, and #1357. --- sdk/core/Cargo.toml | 4 +- sdk/core/src/lib.rs | 3 + sdk/core/src/tokio/fs.rs | 161 +++++++++++++++++++ sdk/core/src/tokio/mod.rs | 1 + sdk/storage_blobs/Cargo.toml | 4 +- sdk/storage_blobs/examples/stream_blob_02.rs | 92 +++++++++++ 6 files changed, 263 insertions(+), 2 deletions(-) create mode 100644 sdk/core/src/tokio/fs.rs create mode 100644 sdk/core/src/tokio/mod.rs create mode 100644 sdk/storage_blobs/examples/stream_blob_02.rs diff --git a/sdk/core/Cargo.toml b/sdk/core/Cargo.toml index 79eb0e7b0a..ee179e0aae 100644 --- a/sdk/core/Cargo.toml +++ b/sdk/core/Cargo.toml @@ -33,6 +33,7 @@ url = "2.2" uuid = { version = "1.0" } pin-project = "1.0" paste = "1.0" +tokio = {version="1.0", optional=true} # Add dependency to getrandom to enable WASM support [target.'cfg(target_arch = "wasm32")'.dependencies] @@ -43,7 +44,7 @@ rustc_version = "0.4" [dev-dependencies] env_logger = "0.10" -tokio = { version = "1", features = ["default"] } +tokio = { version = "1.0", features = ["default"] } thiserror = "1.0" [features] @@ -54,3 +55,4 @@ enable_reqwest_rustls = ["reqwest/rustls-tls"] test_e2e = [] azurite_workaround = [] xml = ["quick-xml"] +tokio-fs = ["tokio/fs", "tokio/io-util"] diff --git a/sdk/core/src/lib.rs b/sdk/core/src/lib.rs index 210d2571ed..8bdf752815 100644 --- a/sdk/core/src/lib.rs +++ b/sdk/core/src/lib.rs @@ -43,6 +43,9 @@ use uuid::Uuid; #[cfg(feature = "xml")] pub mod xml; +#[cfg(feature = "tokio")] +pub mod tokio; + pub mod base64; pub use bytes_stream::*; pub use constants::*; diff --git a/sdk/core/src/tokio/fs.rs b/sdk/core/src/tokio/fs.rs new file mode 100644 index 0000000000..0b522bde44 --- /dev/null +++ b/sdk/core/src/tokio/fs.rs @@ -0,0 +1,161 @@ +use crate::{request::Body, seekable_stream::SeekableStream, setters}; +use futures::{task::Poll, Future}; +use std::{cmp::min, io::SeekFrom, pin::Pin, sync::Arc, task::Context}; +use tokio::{ + fs::File, + io::{AsyncReadExt, AsyncSeekExt, Take}, + sync::Mutex, +}; + +#[derive(Debug)] +pub struct FileStreamBuilder { + handle: File, + /// Offset into the file to start reading from + offset: Option, + /// Amount of data to read from the file + buffer_size: Option, + /// How much to buffer in memory during streaming reads + block_size: Option, +} + +impl FileStreamBuilder { + pub fn new(handle: File) -> Self { + Self { + handle, + offset: None, + buffer_size: None, + block_size: None, + } + } + + setters! { + // #[doc = "Offset into the file to start reading from"] + offset: u64 => Some(offset), + // #[doc = "Amount of data to read from the file"] + block_size: u64 => Some(block_size), + // #[doc = "Amount of data to buffer in memory during streaming reads"] + buffer_size: usize => Some(buffer_size), + } + + pub async fn build(mut self) -> crate::Result { + let stream_size = self.handle.metadata().await?.len(); + + let buffer_size = self.buffer_size.unwrap_or(1024 * 64); + + let offset = if let Some(offset) = self.offset { + self.handle.seek(SeekFrom::Start(offset)).await?; + offset + } else { + 0 + }; + + let block_size = if let Some(block_size) = self.block_size { + block_size + } else { + stream_size - offset + }; + + let handle = Arc::new(Mutex::new(self.handle.take(block_size))); + + Ok(FileStream { + handle, + buffer_size, + block_size, + stream_size, + offset, + }) + } +} + +#[derive(Debug, Clone)] +#[pin_project::pin_project] +pub struct FileStream { + #[pin] + handle: Arc>>, + pub stream_size: u64, + pub block_size: u64, + buffer_size: usize, + pub offset: u64, +} + +impl FileStream { + async fn read(&mut self, slice: &mut [u8]) -> std::io::Result { + let mut handle = self.handle.clone().lock_owned().await; + handle.read(slice).await + } + + /// Resets the number of bytes that will be read from this instance to the + /// `stream_size` + /// + /// This is useful if you want to read the stream in mutliple blocks + pub async fn next_block(&mut self) -> crate::Result<()> { + log::info!("setting limit to {}", self.block_size); + let mut handle = self.handle.clone().lock_owned().await; + { + let inner = handle.get_mut(); + self.offset = inner.stream_position().await?; + } + handle.set_limit(self.block_size); + Ok(()) + } +} + +#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] +impl SeekableStream for FileStream { + /// Seek to the specified offset into the file and reset the number of bytes to read + /// + /// This is useful upon encountering an error to reset the stream to the last + async fn reset(&mut self) -> crate::Result<()> { + log::info!( + "resetting stream to offset {} and limit to {}", + self.offset, + self.block_size + ); + let mut handle = self.handle.clone().lock_owned().await; + { + let inner = handle.get_mut(); + inner.seek(SeekFrom::Start(self.offset)).await?; + } + handle.set_limit(self.block_size); + Ok(()) + } + + fn len(&self) -> usize { + log::info!( + "stream len: {} - {} ... {}", + self.stream_size, + self.offset, + self.block_size + ); + min(self.stream_size - self.offset, self.block_size) as usize + } + + /* + fn buffer_size(&self) -> usize { + self.buffer_size + } + */ +} + +impl futures::io::AsyncRead for FileStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + slice: &mut [u8], + ) -> Poll> { + std::pin::pin!(self.read(slice)).poll(cx) + } +} + +impl From<&FileStream> for Body { + fn from(stream: &FileStream) -> Self { + Body::SeekableStream(Box::new(stream.clone())) + } +} + +impl From for Body { + fn from(stream: FileStream) -> Self { + Body::SeekableStream(Box::new(stream)) + } +} diff --git a/sdk/core/src/tokio/mod.rs b/sdk/core/src/tokio/mod.rs new file mode 100644 index 0000000000..d521fbd77e --- /dev/null +++ b/sdk/core/src/tokio/mod.rs @@ -0,0 +1 @@ +pub mod fs; diff --git a/sdk/storage_blobs/Cargo.toml b/sdk/storage_blobs/Cargo.toml index 29d9b9a42d..7f40e8c625 100644 --- a/sdk/storage_blobs/Cargo.toml +++ b/sdk/storage_blobs/Cargo.toml @@ -28,13 +28,15 @@ uuid = { version = "1.0", features = ["v4"] } url = "2.2" [dev-dependencies] -tokio = { version = "1.0", features = ["macros", "rt-multi-thread"] } +tokio = {version = "1.0", features = ["macros", "rt-multi-thread", "io-util"]} env_logger = "0.10" azure_identity = { path = "../identity", default-features = false } reqwest = "0.11" mock_transport = { path = "../../eng/test/mock_transport" } md5 = "0.7" async-trait = "0.1" +clap = { version = "4.0", features = ["derive", "env"] } +azure_core = {path = "../core", version = "0.14", features = ["tokio-fs"]} [features] default = ["enable_reqwest"] diff --git a/sdk/storage_blobs/examples/stream_blob_02.rs b/sdk/storage_blobs/examples/stream_blob_02.rs new file mode 100644 index 0000000000..03f4b401f7 --- /dev/null +++ b/sdk/storage_blobs/examples/stream_blob_02.rs @@ -0,0 +1,92 @@ +use azure_core::{ + error::{ErrorKind, ResultExt}, + tokio::fs::FileStreamBuilder, +}; +use azure_storage::prelude::*; +use azure_storage_blobs::prelude::*; +use clap::Parser; +use std::path::PathBuf; +use tokio::fs::File; + +#[derive(Debug, Parser)] +struct Args { + /// Name of the container to upload + container_name: String, + /// Blob name + blob_name: String, + /// File path to upload + file_path: PathBuf, + + /// Offset to start uploading from + #[clap(long)] + offset: Option, + + /// how much to buffer in memory during streaming reads + #[clap(long)] + buffer_size: Option, + + #[clap(long)] + block_size: Option, + + /// storage account name + #[clap(env = "STORAGE_ACCOUNT")] + account: String, + + /// storage account access key + #[clap(env = "STORAGE_ACCESS_KEY")] + access_key: String, +} + +#[tokio::main] +async fn main() -> azure_core::Result<()> { + env_logger::init(); + let args = Args::parse(); + + let storage_credentials = + StorageCredentials::Key(args.account.clone(), args.access_key.clone()); + let blob_client = BlobServiceClient::new(&args.account, storage_credentials) + .container_client(&args.container_name) + .blob_client(&args.blob_name); + + let file = File::open(&args.file_path).await?; + + let mut builder = FileStreamBuilder::new(file); + + if let Some(buffer_size) = args.buffer_size { + builder = builder.buffer_size(buffer_size); + } + + if let Some(offset) = args.offset { + builder = builder.offset(offset); + } + + if let Some(block_size) = args.block_size { + builder = builder.block_size(block_size); + } + + let mut handle = builder.build().await?; + + if let Some(block_size) = args.block_size { + let mut block_list = BlockList::default(); + for offset in (handle.offset..handle.stream_size).step_by(block_size as usize) { + log::info!("trying to upload at offset {offset} - {block_size}"); + let block_id = format!("{:08X}", offset); + blob_client.put_block(block_id.clone(), &handle).await?; + log::info!("uploaded block {block_id}"); + block_list + .blocks + .push(BlobBlockType::new_uncommitted(block_id)); + handle.next_block().await?; + } + blob_client.put_block_list(block_list).await?; + } else { + // upload as one large block + blob_client.put_block_blob(handle).await?; + } + + let blob = blob_client.get_content().await?; + let s = String::from_utf8(blob).map_kind(ErrorKind::DataConversion)?; + println!("retrieved contents == {s:?}"); + + Ok(()) +}