diff --git a/Cargo.toml b/Cargo.toml
index 712c16fc..026c6ddf 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -11,6 +11,7 @@ members = [
     "dash/controller",
     "dash/gateway",
     "dash/openapi",
+    "dash/pipe/connectors/storage",
     "dash/pipe/connectors/webcam", # exclude(alpine)
     "dash/pipe/functions/identity",
     "dash/pipe/functions/python", # exclude(alpine)
@@ -57,6 +58,7 @@ actix-web = { version = "=4.4", default-features = false, features = [
 anyhow = { version = "=1.0", features = ["backtrace"] }
 argon2 = { version = "=0.5" }
 async-recursion = { version = "=1.0" }
+async-stream = { version = "=0.3" }
 async-trait = { version = "=0.1" }
 base64 = { version = "=0.21" }
 byteorder = { version = "=1.4" }
@@ -67,7 +69,7 @@ chrono = { version = "=0.4", features = ["serde"] }
 clap = { version = "=4.4", features = ["env", "derive"] }
 csv = { version = "=1.2" }
 ctrlc = { version = "=3.4" }
-deltalake = { version = "0.16", default-features = false }
+deltalake = { version = "=0.16", default-features = false }
 email_address = { version = "=0.2" }
 futures = { version = "=0.3" }
 gethostname = { version = "=0.4" }
@@ -87,7 +89,7 @@ k8s-openapi = { version = "=0.20", features = ["schemars", "v1_26"] }
 kube = { version = "=0.86", default-features = false }
 language-tags = { version = "=0.3", features = ["serde"] }
 tracing = { version = "=0.1" }
-tracing-subscriber = { version = "0.3" }
+tracing-subscriber = { version = "=0.3" }
 mime = { version = "=0.3" }
 # FIXME: push a PR: rustls-tls feature support
 minio = { git = "https://github.com/ulagbulag/minio-rs.git", default-features = false, rev = "5be4686e307b058aa4190134a555c925301c59b2", features = [
@@ -99,8 +101,8 @@ num-traits = { version = "=0.2" }
 octocrab = { git = "https://github.com/ulagbulag/octocrab.git", default-features = false, features = [
     "rustls-tls",
 ] }
-opencv = { version = "0.84", default-features = false }
-ordered-float = { version = "4.0", default-features = false, features = [
+opencv = { version = "=0.84", default-features = false }
+ordered-float = { version = "=4.0", default-features = false, features = [
     "bytemuck",
     "schemars",
     "serde",
@@ -147,6 +149,7 @@ sio = { git = "https://github.com/ulagbulag/sio-rs.git" }
 strum = { version = "=0.25", features = ["derive"] }
 tera = { version = "=1.19" }
 tokio = { version = "=1.32", features = ["macros", "rt"] }
+tokio-stream = { version = "=0.1" }
 url = { version = "=2.4", features = ["serde"] }
 uuid = { version = "=1.4", features = ["js", "serde", "v4"] }
 which = { version = "=4.4" }
diff --git a/dash/pipe/connectors/storage/Cargo.toml b/dash/pipe/connectors/storage/Cargo.toml
new file mode 100644
index 00000000..064abaa5
--- /dev/null
+++ b/dash/pipe/connectors/storage/Cargo.toml
@@ -0,0 +1,24 @@
+[package]
+name = "dash-pipe-connector-storage"
+version = "0.1.0"
+edition = "2021"
+
+authors = ["Ho Kim <ho.kim@ulagbulag.io>"]
+description = "Kubernetes Is Simple, Stupid which is a part of OpenARK"
+documentation = "https://docs.rs/kiss-api"
+license = "GPL-3.0-or-later WITH Classpath-exception-2.0"
+readme = "../../README.md"
+homepage = "https://github.com/ulagbulag/OpenARK"
+repository = "https://github.com/ulagbulag/OpenARK"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+dash-pipe-provider = { path = "../../provider" }
+
+anyhow = { workspace = true }
+async-trait = { workspace = true }
+clap = { workspace = true }
+futures = { workspace = true }
+serde = { workspace = true }
+tokio = { workspace = true, features = ["time"] }
diff --git a/dash/pipe/connectors/storage/src/main.rs b/dash/pipe/connectors/storage/src/main.rs
new file mode 100644
index 00000000..cfabfef7
--- /dev/null
+++ b/dash/pipe/connectors/storage/src/main.rs
@@ -0,0 +1,65 @@
+use std::sync::Arc;
+
+use anyhow::{bail, Result};
+use async_trait::async_trait;
+use clap::{ArgAction, Parser};
+use dash_pipe_provider::{
+    FunctionContext, PipeArgs, PipeMessage, PipeMessages, PipePayload, StorageSet, StorageType,
+    Stream,
+};
+use futures::StreamExt;
+use serde::{Deserialize, Serialize};
+
+fn main() {
+    PipeArgs::<Function>::from_env().loop_forever()
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize, Parser)]
+pub struct FunctionArgs {
+    #[arg(long, env = "PIPE_PERSISTENCE", action = ArgAction::SetTrue)]
+    #[serde(default)]
+    persistence: Option<bool>,
+}
+
+pub struct Function {
+    ctx: FunctionContext,
+    items: Stream,
+}
+
+#[async_trait(?Send)]
+impl ::dash_pipe_provider::Function for Function {
+    type Args = FunctionArgs;
+    type Input = ();
+    type Output = usize;
+
+    async fn try_new(
+        args: &<Self as ::dash_pipe_provider::Function>::Args,
+        ctx: &mut FunctionContext,
+        storage: &Arc<StorageSet>,
+    ) -> Result<Self> {
+        let storage_type = match args.persistence {
+            Some(true) => StorageType::PERSISTENT,
+            Some(false) | None => StorageType::TEMPORARY,
+        };
+
+        Ok(Self {
+            ctx: ctx.clone(),
+            items: storage.get(storage_type).list().await?,
+        })
+    }
+
+    async fn tick(
+        &mut self,
+        _inputs: PipeMessages<<Self as ::dash_pipe_provider::Function>::Input>,
+    ) -> Result<PipeMessages<<Self as ::dash_pipe_provider::Function>::Output>> {
+        match self.items.next().await {
+            // TODO: have the stream emit PipeMessage objects that carry the JSON metadata
+            Some(Ok((path, value))) => Ok(PipeMessages::Single(PipeMessage {
+                payloads: vec![PipePayload::new(path.to_string(), value)],
+                value: Default::default(),
+            })),
+            Some(Err(error)) => bail!("failed to load data: {error}"),
+            None => self.ctx.terminate_ok(),
+        }
+    }
+}
diff --git a/dash/pipe/connectors/webcam/Cargo.toml b/dash/pipe/connectors/webcam/Cargo.toml
index 519ee4d7..baeede82 100644
--- a/dash/pipe/connectors/webcam/Cargo.toml
+++ b/dash/pipe/connectors/webcam/Cargo.toml
@@ -23,4 +23,3 @@ image = { workspace = true, features = ["png"] }
 opencv = { workspace = true, features = ["imgcodecs", "videoio"] }
 serde = { workspace = true }
 tokio = { workspace = true, features = ["time"] }
-tracing = { workspace = true }
diff --git a/dash/pipe/connectors/webcam/src/main.rs b/dash/pipe/connectors/webcam/src/main.rs
index 1953a6e8..a26c35bb 100644
--- a/dash/pipe/connectors/webcam/src/main.rs
+++ b/dash/pipe/connectors/webcam/src/main.rs
@@ -1,9 +1,11 @@
-use std::time::Duration;
+use std::sync::Arc;
 
 use anyhow::{anyhow, bail, Result};
 use async_trait::async_trait;
 use clap::{Parser, ValueEnum};
-use dash_pipe_provider::{PipeArgs, PipeMessage, PipeMessages, PipePayload};
+use dash_pipe_provider::{
+    FunctionContext, PipeArgs, PipeMessage, PipeMessages, PipePayload, StorageSet,
+};
 use image::{codecs, RgbImage};
 use opencv::{
     core::{Mat, MatTraitConst, MatTraitConstManual, Vec3b, Vector},
@@ -11,8 +13,6 @@ use opencv::{
     videoio::{self, VideoCapture, VideoCaptureTrait, VideoCaptureTraitConst},
 };
 use serde::{Deserialize, Serialize};
-use tokio::time::sleep;
-use tracing::error;
 
 fn main() {
     PipeArgs::<Function>::from_env().loop_forever()
 }
@@ -69,6 +69,7 @@ impl CameraEncoder {
 pub struct Function {
     camera_encoder: CameraEncoder,
     capture: VideoCapture,
+    ctx: FunctionContext,
     frame: Mat,
     frame_counter: FrameCounter,
     frame_size: FrameSize,
@@ -81,7 +82,11 @@ impl ::dash_pipe_provider::Function for Function {
     type Input = ();
     type Output = usize;
 
-    async fn try_new(args: &<Self as ::dash_pipe_provider::Function>::Args) -> Result<Self> {
+    async fn try_new(
+        args: &<Self as ::dash_pipe_provider::Function>::Args,
+        ctx: &mut FunctionContext,
+        _storage: &Arc<StorageSet>,
+    ) -> Result<Self> {
         let FunctionArgs {
             camera_device,
             camera_encoder,
@@ -96,6 +101,7 @@ impl ::dash_pipe_provider::Function for Function {
         Ok(Self {
             camera_encoder,
             capture,
+            ctx: ctx.clone(),
             frame: Default::default(),
             frame_counter: Default::default(),
             frame_size: Default::default(),
@@ -158,9 +164,9 @@ impl ::dash_pipe_provider::Function for Function {
                 }
             }
             Ok(false) => {
-                error!("video capture is disconnected!");
-                sleep(Duration::from_millis(u64::MAX)).await;
-                return Ok(PipeMessages::None);
+                return self
+                    .ctx
+                    .terminate_err(anyhow!("video capture is disconnected!"))
             }
             Err(error) => bail!("failed to capture a frame: {error}"),
         };
diff --git a/dash/pipe/functions/identity/src/main.rs b/dash/pipe/functions/identity/src/main.rs
index c5318eb5..7a13e1a7 100644
--- a/dash/pipe/functions/identity/src/main.rs
+++ b/dash/pipe/functions/identity/src/main.rs
@@ -1,7 +1,11 @@
+use std::sync::Arc;
+
 use anyhow::Result;
 use async_trait::async_trait;
 use clap::{ArgAction, Parser};
-use dash_pipe_provider::{PipeArgs, PipeMessage, PipeMessages, PipePayload};
+use dash_pipe_provider::{
+    FunctionContext, PipeArgs, PipeMessage, PipeMessages, PipePayload, StorageSet,
+};
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
 
@@ -26,7 +30,11 @@ impl ::dash_pipe_provider::Function for Function {
     type Input = Value;
     type Output = Value;
 
-    async fn try_new(args: &<Self as ::dash_pipe_provider::Function>::Args) -> Result<Self> {
+    async fn try_new(
+        args: &<Self as ::dash_pipe_provider::Function>::Args,
+        _ctx: &mut FunctionContext,
+        _storage: &Arc<StorageSet>,
+    ) -> Result<Self> {
         Ok(Self { args: args.clone() })
     }
 
diff --git a/dash/pipe/functions/python/src/main.rs b/dash/pipe/functions/python/src/main.rs
index 0f100c7f..1e531975 100644
--- a/dash/pipe/functions/python/src/main.rs
+++ b/dash/pipe/functions/python/src/main.rs
@@ -1,9 +1,9 @@
-use std::path::PathBuf;
+use std::{path::PathBuf, sync::Arc};
 
 use anyhow::{anyhow, Error, Result};
 use async_trait::async_trait;
 use clap::Parser;
-use dash_pipe_provider::{PipeArgs, PipeMessages, PyPipeMessage};
+use dash_pipe_provider::{FunctionContext, PipeArgs, PipeMessages, PyPipeMessage, StorageSet};
 use pyo3::{types::PyModule, PyObject, Python};
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
@@ -29,7 +29,11 @@ impl ::dash_pipe_provider::Function for Function {
     type Input = Value;
     type Output = Value;
 
-    async fn try_new(args: &<Self as ::dash_pipe_provider::Function>::Args) -> Result<Self> {
+    async fn try_new(
+        args: &<Self as ::dash_pipe_provider::Function>::Args,
+        _ctx: &mut FunctionContext,
+        _storage: &Arc<StorageSet>,
+    ) -> Result<Self> {
         let FunctionArgs {
             python_script: file_path,
         } = args;
diff --git a/dash/pipe/provider/Cargo.toml b/dash/pipe/provider/Cargo.toml
index d44b0ec3..09496f63 100644
--- a/dash/pipe/provider/Cargo.toml
+++ b/dash/pipe/provider/Cargo.toml
@@ -25,9 +25,11 @@ s3 = ["minio"]
 ark-core = { path = "../../../ark/core" }
 
 anyhow = { workspace = true }
+async-stream = { workspace = true }
 async-trait = { workspace = true }
 bytes = { workspace = true }
 clap = { workspace = true }
+ctrlc = { workspace = true }
 deltalake = { workspace = true }
 futures = { workspace = true }
 minio = { workspace = true, optional = true }
@@ -36,5 +38,6 @@ pyo3 = { workspace = true, optional = true }
 serde = { workspace = true, features = ["derive"] }
 serde_json = { workspace = true }
 tokio = { workspace = true, features = ["full"] }
+tokio-stream = { workspace = true }
 tracing = { workspace = true }
 url = { workspace = true }
diff --git a/dash/pipe/provider/src/function.rs b/dash/pipe/provider/src/function.rs
index 1f380d34..eaccfdc5 100644
--- a/dash/pipe/provider/src/function.rs
+++ b/dash/pipe/provider/src/function.rs
@@ -1,11 +1,17 @@
-use std::fmt;
+use std::{
+    fmt,
+    sync::{
+        atomic::{AtomicBool, Ordering},
+        Arc,
+    },
+};
 
-use anyhow::Result;
+use anyhow::{anyhow, Error, Result};
 use async_trait::async_trait;
 use clap::Args;
 use serde::{de::DeserializeOwned, Serialize};
 
-use crate::PipeMessages;
+use crate::{PipeMessages, StorageSet};
 
 #[async_trait(?Send)]
 pub trait Function {
@@ -13,7 +19,11 @@ pub trait Function {
     type Input: 'static + Send + Sync + DeserializeOwned;
     type Output: 'static + Send + Serialize;
 
-    async fn try_new(args: &<Self as Function>::Args) -> Result<Self>
+    async fn try_new(
+        args: &<Self as Function>::Args,
+        ctx: &mut FunctionContext,
+        storage: &Arc<StorageSet>,
+    ) -> Result<Self>
     where
         Self: Sized;
 
@@ -22,3 +32,33 @@
         inputs: PipeMessages<<Self as Function>::Input>,
     ) -> Result<PipeMessages<<Self as Function>::Output>>;
 }
+
+#[derive(Clone, Debug, Default)]
+pub struct FunctionContext {
+    is_terminating: Arc<AtomicBool>,
+}
+
+impl FunctionContext {
+    pub(crate) fn trap_on_sigint(self) -> Result<()> {
+        ::ctrlc::set_handler(move || self.terminate())
+            .map_err(|error| anyhow!("failed to set SIGINT handler: {error}"))
+    }
+
+    pub(crate) fn terminate(&self) {
+        self.is_terminating.store(true, Ordering::SeqCst)
+    }
+
+    pub fn terminate_ok<T>(&self) -> Result<PipeMessages<T>> {
+        self.terminate();
+        Ok(PipeMessages::None)
+    }
+
+    pub fn terminate_err<T>(&self, error: impl Into<Error>) -> Result<PipeMessages<T>> {
+        self.terminate();
+        Err(error.into())
+    }
+
+    pub(crate) fn is_terminating(&self) -> bool {
+        self.is_terminating.load(Ordering::SeqCst)
+    }
+}
diff --git a/dash/pipe/provider/src/lib.rs b/dash/pipe/provider/src/lib.rs
index 5885c077..b4e43fba 100644
--- a/dash/pipe/provider/src/lib.rs
+++ b/dash/pipe/provider/src/lib.rs
@@ -3,8 +3,9 @@ mod message;
 mod pipe;
 mod storage;
 
-pub use self::function::Function;
+pub use self::function::{Function, FunctionContext};
 #[cfg(feature = "pyo3")]
 pub use self::message::PyPipeMessage;
 pub use self::message::{PipeMessage, PipeMessages, PipePayload};
 pub use self::pipe::PipeArgs;
+pub use self::storage::{Storage, StorageSet, StorageType, Stream};
diff --git a/dash/pipe/provider/src/message.rs b/dash/pipe/provider/src/message.rs
index 232dd0d4..12b95c4c 100644
--- a/dash/pipe/provider/src/message.rs
+++ b/dash/pipe/provider/src/message.rs
@@ -273,7 +273,7 @@ where
 }
 
 impl PipeMessage {
-    pub(crate) async fn dump_payloads(
+    async fn dump_payloads(
         self,
         storage: &StorageSet,
         input_payloads: &HashMap<String, PipePayload<()>>,
@@ -355,7 +355,7 @@ impl PipePayload {
         }
     }
 
-    pub(crate) fn get_ref<T>(&self) -> PipePayload<T>
+    fn get_ref<T>(&self) -> PipePayload<T>
     where
         T: Default,
     {
@@ -366,7 +366,7 @@ impl PipePayload {
         }
     }
 
-    pub(crate) async fn load(self, storage: &StorageSet) -> Result<PipePayload> {
+    async fn load(self, storage: &StorageSet) -> Result<PipePayload> {
         Ok(PipePayload {
             value: match self.storage {
                 Some(type_) => storage.get(type_).get_with_str(&self.key).await?,
@@ -377,7 +377,7 @@ impl PipePayload {
         }
     }
 
-    pub(crate) fn load_as_empty<T>(self) -> PipePayload<T>
+    fn load_as_empty<T>(self) -> PipePayload<T>
     where
         T: Default,
     {
@@ -394,7 +394,7 @@ impl PipePayload {
-    pub(crate) async fn dump(
+    async fn dump(
         self,
         storage: &StorageSet,
         input_payloads: &HashMap<String, PipePayload<()>>,
diff --git a/dash/pipe/provider/src/pipe.rs b/dash/pipe/provider/src/pipe.rs
index a79ad703..a13da6df 100644
--- a/dash/pipe/provider/src/pipe.rs
+++ b/dash/pipe/provider/src/pipe.rs
@@ -1,5 +1,6 @@
 use std::{
     collections::HashMap,
+    process::exit,
     sync::{
         atomic::{AtomicUsize, Ordering},
         Arc,
     },
@@ -13,14 +14,15 @@ use futures::{Future, StreamExt};
 use nats::{Client, ServerAddr, Subscriber, ToServerAddrs};
 use serde::{de::DeserializeOwned, Deserialize, Serialize};
 use tokio::{
-    spawn,
+    select, spawn,
     sync::mpsc::{self, Receiver, Sender},
     task::{yield_now, JoinHandle},
+    time::sleep,
 };
 use tracing::{error, warn};
 
 use crate::{
-    function::Function,
+    function::{Function, FunctionContext},
     message::PipeMessages,
     storage::{StorageSet, StorageType},
     PipeMessage, PipePayload,
@@ -151,13 +153,21 @@ where
             StorageSet::try_new(&self.storage, &client, default_output).await?
         });
 
+        let mut function_context = FunctionContext::default();
+        function_context.clone().trap_on_sigint()?;
+
         Ok(Context {
             batch_size: self.batch_size,
             batch_timeout: self.batch_timeout_ms.map(Duration::from_millis),
-            function: <F as Function>::try_new(&self.function_args)
-                .await
-                .map(Into::into)
-                .map_err(|error| anyhow!("failed to init function: {error}"))?,
+            function: <F as Function>::try_new(
+                &self.function_args,
+                &mut function_context,
+                &storage,
+            )
+            .await
+            .map(Into::into)
+            .map_err(|error| anyhow!("failed to init function: {error}"))?,
+            function_context,
             reader: match &self.stream_in {
                 Some(stream) => {
                     let (tx, rx) = mpsc::channel(max_tasks);
@@ -203,7 +213,8 @@ where
             .expect("failed to init tokio runtime")
             .block_on(self.loop_forever_async())
         {
-            panic!("{error}")
+            error!("{error}");
+            exit(1)
         }
     }
 
@@ -214,8 +225,15 @@ where
             // yield per every loop
             yield_now().await;
 
-            if let Err(error) = tick_async(&mut ctx).await {
-                warn!("{error}")
+            if ctx.function_context.is_terminating() {
+                break Ok(());
+            }
+
+            let response = tick_async(&mut ctx).await;
+            if ctx.function_context.is_terminating() {
+                break response;
+            } else if let Err(error) = response {
+                warn!("{error}");
             }
         }
     }
@@ -225,35 +243,52 @@
 async fn tick_async<F>(ctx: &mut Context<F>) -> Result<()>
 where
     F: Function,
 {
-    async fn recv_one<Value>(reader: &mut ReadContext<Value>) -> Result<PipeMessage<Value>> {
-        reader
-            .rx
-            .recv()
-            .await
-            .ok_or_else(|| anyhow!("reader job connection closed"))
+    async fn recv_one<Value>(
+        function_context: &FunctionContext,
+        reader: &mut ReadContext<Value>,
+    ) -> Result<Option<PipeMessage<Value>>> {
+        loop {
+            select! {
+                input = reader
+                    .rx
+                    .recv() => break Ok(input),
+                () = sleep(Duration::from_millis(100)) => if function_context.is_terminating() {
+                    break Ok(None)
+                },
+            }
+        }
     }
 
     let inputs = match &mut ctx.reader {
-        Some(reader) => match ctx.batch_size {
-            Some(batch_size) => {
-                let timer = ctx.batch_timeout.map(Timer::new);
-
-                let mut inputs = vec![recv_one(reader).await?];
-                for _ in 1..batch_size {
-                    if timer
-                        .as_ref()
-                        .map(|timer| timer.is_outdated())
-                        .unwrap_or_default()
-                    {
-                        break;
-                    } else {
-                        inputs.push(recv_one(reader).await?)
+        Some(reader) => {
+            let input = match recv_one(&ctx.function_context, reader).await? {
+                Some(input) => input,
+                None => return Ok(()),
+            };
+            match ctx.batch_size {
+                Some(batch_size) => {
+                    let timer = ctx.batch_timeout.map(Timer::new);
+
+                    let mut inputs = vec![input];
+                    for _ in 1..batch_size {
+                        if timer
+                            .as_ref()
+                            .map(|timer| timer.is_outdated())
+                            .unwrap_or_default()
+                        {
+                            break;
+                        } else {
+                            inputs.push(match recv_one(&ctx.function_context, reader).await? {
+                                Some(input) => input,
+                                None => return Ok(()),
+                            })
+                        }
                     }
+                    PipeMessages::Batch(inputs)
                 }
-                PipeMessages::Batch(inputs)
+                None => PipeMessages::Single(input),
             }
-            None => PipeMessages::Single(recv_one(reader).await?),
-        },
+        }
         None => PipeMessages::None,
     };
@@ -295,6 +330,7 @@ where
     batch_size: Option<usize>,
     batch_timeout: Option<Duration>,
     function: F,
+    function_context: FunctionContext,
     reader: Option<ReadContext<<F as Function>::Input>>,
     writer: WriteContext,
 }
diff --git a/dash/pipe/provider/src/storage/lakehouse.rs b/dash/pipe/provider/src/storage/lakehouse.rs
index ca1fcadb..720c8e7f 100644
--- a/dash/pipe/provider/src/storage/lakehouse.rs
+++ b/dash/pipe/provider/src/storage/lakehouse.rs
@@ -1,25 +1,25 @@
-use std::collections::HashMap;
+use std::{collections::HashMap, sync::Arc};
 
 use anyhow::{anyhow, bail, Error, Result};
+use async_stream::try_stream;
 use async_trait::async_trait;
 use bytes::Bytes;
 use deltalake::{DeltaTable, DeltaTableBuilder, ObjectStore, Path};
-use futures::TryFutureExt;
-
-use super::s3::StorageS3Args;
+use futures::{StreamExt, TryFutureExt};
 
+#[derive(Clone)]
 pub struct Storage {
-    table: DeltaTable,
+    table: Arc<DeltaTable>,
 }
 
 impl Storage {
     pub async fn try_new(
-        StorageS3Args {
+        super::StorageS3Args {
             access_key,
             s3_endpoint,
             region,
             secret_key,
-        }: &StorageS3Args,
+        }: &super::StorageS3Args,
         bucket_name: &str,
     ) -> Result<Self> {
         Ok(Self {
@@ -39,6 +39,7 @@ impl Storage {
                 .with_storage_options(backend_config)
                 .build()
                 .unwrap()
+                .into()
             },
         })
     }
@@ -50,6 +51,18 @@ impl super::Storage for Storage {
         super::StorageType::LakeHouse
     }
 
+    async fn list(&self) -> Result<super::Stream> {
+        let storage = self.clone();
+        Ok(try_stream! {
+            let list = storage.table.get_files_iter();
+            for path in list {
+                let value = storage.get(&path).await?;
+                yield (path, value);
+            }
+        }
+        .boxed())
+    }
+
     async fn get(&self, path: &Path) -> Result<Bytes> {
         self.table
             .object_store()
diff --git a/dash/pipe/provider/src/storage/mod.rs b/dash/pipe/provider/src/storage/mod.rs
index 95eedd79..64388bb5 100644
--- a/dash/pipe/provider/src/storage/mod.rs
+++ b/dash/pipe/provider/src/storage/mod.rs
@@ -5,6 +5,8 @@ mod nats;
 #[cfg(feature = "s3")]
 mod s3;
 
+use std::pin::Pin;
+
 use anyhow::{anyhow, Result};
 use async_trait::async_trait;
 use bytes::Bytes;
@@ -85,6 +87,8 @@
 pub trait Storage {
     fn storage_type(&self) -> StorageType;
 
+    async fn list(&self) -> Result<Stream>;
+
     async fn get(&self, path: &Path) -> Result<Bytes>;
 
     async fn get_with_str(&self, path: &str) -> Result<Bytes> {
@@ -111,9 +115,32 @@ pub struct StorageArgs {
 
     #[cfg(any(feature = "lakehouse", feature = "s3"))]
     #[command(flatten)]
-    s3: self::s3::StorageS3Args,
+    s3: StorageS3Args,
+}
+
+#[cfg(any(feature = "lakehouse", feature = "s3"))]
+#[derive(Clone, Debug, Serialize, Deserialize, Parser)]
+pub struct StorageS3Args {
+    #[arg(long, env = "AWS_ACCESS_KEY_ID", value_name = "VALUE")]
+    pub(super) access_key: String,
+
+    #[arg(
+        long,
+        env = "AWS_REGION",
+        value_name = "REGION",
+        default_value = "us-east-1"
+    )]
+    pub(super) region: String,
+
+    #[arg(long, env = "AWS_ENDPOINT_URL", value_name = "URL")]
+    pub(super) s3_endpoint: ::url::Url,
+
+    #[arg(long, env = "AWS_SECRET_ACCESS_KEY", value_name = "VALUE")]
+    pub(super) secret_key: String,
 }
 
+pub type Stream = Pin<Box<dyn Send + ::futures::Stream<Item = Result<(Path, Bytes)>>>>;
+
 fn parse_path(path: impl AsRef<str>) -> Result<Path> {
     Path::parse(path).map_err(|error| anyhow!("failed to parse storage path: {error}"))
 }
diff --git a/dash/pipe/provider/src/storage/nats.rs b/dash/pipe/provider/src/storage/nats.rs
index ea5941b2..14d34642 100644
--- a/dash/pipe/provider/src/storage/nats.rs
+++ b/dash/pipe/provider/src/storage/nats.rs
@@ -1,15 +1,17 @@
 use std::io::Cursor;
 
 use anyhow::{anyhow, bail, Error, Result};
+use async_stream::try_stream;
 use async_trait::async_trait;
 use bytes::Bytes;
 use clap::Parser;
 use deltalake::Path;
-use futures::TryFutureExt;
+use futures::{StreamExt, TryFutureExt};
 use nats::jetstream::object_store::ObjectStore;
 use serde::{Deserialize, Serialize};
 use tokio::io::AsyncReadExt;
 
+#[derive(Clone)]
 pub struct Storage {
     store: ObjectStore,
 }
@@ -38,6 +40,26 @@ impl super::Storage for Storage {
         super::StorageType::Nats
     }
 
+    async fn list(&self) -> Result<super::Stream> {
+        let storage = self.clone();
+        Ok(try_stream! {
+            let mut list = storage.store.list()
+                .map_err(|error| anyhow!("failed to list objects from NATS object store: {error}"))
+                .await?;
+            while let Some(item) = list.next().await
+            {
+                if let Ok(path) = item
+                    .map_err(Into::into)
+                    .and_then(|item| super::parse_path(item.name))
+                {
+                    let value = storage.get(&path).await?;
+                    yield (path, value);
+                }
+            }
+        }
+        .boxed())
+    }
+
     async fn get(&self, path: &Path) -> Result<Bytes> {
         self.store
             .get(path.as_ref())
diff --git a/dash/pipe/provider/src/storage/s3.rs b/dash/pipe/provider/src/storage/s3.rs
index 858b0dee..8741b06b 100644
--- a/dash/pipe/provider/src/storage/s3.rs
+++ b/dash/pipe/provider/src/storage/s3.rs
@@ -1,18 +1,17 @@
 use anyhow::{anyhow, bail, Error, Result};
+use async_stream::try_stream;
 use async_trait::async_trait;
 use bytes::Bytes;
-use clap::Parser;
 use deltalake::Path;
-use futures::TryFutureExt;
+use futures::{StreamExt, TryFutureExt};
 use minio::s3::{
-    args::{GetObjectArgs, PutObjectApiArgs, RemoveObjectArgs},
+    args::{GetObjectArgs, ListObjectsV2Args, PutObjectApiArgs, RemoveObjectArgs},
     client::Client,
     creds::StaticProvider,
     http::BaseUrl,
 };
-use serde::{Deserialize, Serialize};
-use url::Url;
 
+#[derive(Clone)]
 pub struct Storage {
     base_url: BaseUrl,
     bucket_name: String,
@@ -21,12 +20,12 @@ pub struct Storage {
 
 impl Storage {
     pub async fn try_new(
-        StorageS3Args {
+        super::StorageS3Args {
             access_key,
             region: _,
             s3_endpoint,
             secret_key,
-        }: &StorageS3Args,
+        }: &super::StorageS3Args,
         bucket_name: &str,
     ) -> Result<Self> {
         Ok(Self {
@@ -44,6 +43,26 @@ impl super::Storage for Storage {
         super::StorageType::S3
     }
 
+    async fn list(&self) -> Result<super::Stream> {
+        let storage = self.clone();
+        Ok(try_stream! {
+            let args = ListObjectsV2Args::new(&storage.bucket_name)?;
+            let list = Client::new(storage.base_url.clone(), Some(&storage.provider))
+                .list_objects_v2(&args)
+                .map_err(|error| anyhow!("failed to list objects from S3 object store: {error}"))
+                .await?
+                .contents;
+            for item in list
+            {
+                if let Ok(path) = super::parse_path(item.name) {
+                    let value = storage.get(&path).await?;
+                    yield (path, value);
+                }
+            }
+        }
+        .boxed())
+    }
+
     async fn get(&self, path: &Path) -> Result<Bytes> {
         let args = GetObjectArgs::new(&self.bucket_name, path.as_ref())?;
 
@@ -54,12 +73,12 @@ impl super::Storage for Storage {
                 match object.bytes().await {
                     Ok(bytes) => Ok(bytes),
                     Err(error) => {
-                        bail!("failed to get object data from DeltaLake object store: {error}")
+                        bail!("failed to get object data from S3 object store: {error}")
                     }
                 }
             })
             .await
-            .map_err(|error| anyhow!("failed to get object from DeltaLake object store: {error}"))
+            .map_err(|error| anyhow!("failed to get object from S3 object store: {error}"))
@@ -69,7 +88,7 @@ impl super::Storage for Storage {
             .put_object_api(&args)
             .await
             .map(|_| ())
-            .map_err(|error| anyhow!("failed to put object into DeltaLake object store: {error}"))
+            .map_err(|error| anyhow!("failed to put object into S3 object store: {error}"))
     }
 
     async fn delete(&self, path: &Path) -> Result<()> {
@@ -79,28 +98,6 @@ impl super::Storage for Storage {
             .remove_object(&args)
             .await
             .map(|_| ())
-            .map_err(|error| {
-                anyhow!("failed to delete object from DeltaLake object store: {error}")
-            })
+            .map_err(|error| anyhow!("failed to delete object from S3 object store: {error}"))
     }
 }
-
-#[derive(Clone, Debug, Serialize, Deserialize, Parser)]
-pub struct StorageS3Args {
-    #[arg(long, env = "AWS_ACCESS_KEY_ID", value_name = "VALUE")]
-    pub(super) access_key: String,
-
-    #[arg(
-        long,
-        env = "AWS_REGION",
-        value_name = "REGION",
-        default_value = "us-east-1"
-    )]
-    pub(super) region: String,
-
-    #[arg(long, env = "AWS_ENDPOINT_URL", value_name = "URL")]
-    pub(super) s3_endpoint: Url,
-
-    #[arg(long, env = "AWS_SECRET_ACCESS_KEY", value_name = "VALUE")]
-    pub(super) secret_key: String,
-}
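
For reference, this is the shape a downstream `Function` implementation takes after the signature change above. It is a minimal sketch, not part of this patch: the `Ticker` type, `TickerArgs`, and the `PIPE_MAX_TICKS` flag are hypothetical, while the trait items, `FunctionContext`, and `terminate_ok` mirror the diff (compare the storage connector's `main.rs`).

use std::sync::Arc;

use anyhow::Result;
use async_trait::async_trait;
use clap::Parser;
use dash_pipe_provider::{FunctionContext, PipeArgs, PipeMessage, PipeMessages, StorageSet};
use serde::{Deserialize, Serialize};

fn main() {
    PipeArgs::<Ticker>::from_env().loop_forever()
}

#[derive(Clone, Debug, Serialize, Deserialize, Parser)]
pub struct TickerArgs {
    /// Hypothetical flag: stop gracefully after this many ticks.
    #[arg(long, env = "PIPE_MAX_TICKS", default_value_t = 10)]
    max_ticks: usize,
}

pub struct Ticker {
    ctx: FunctionContext,
    remaining: usize,
}

#[async_trait(?Send)]
impl ::dash_pipe_provider::Function for Ticker {
    type Args = TickerArgs;
    type Input = ();
    type Output = usize;

    async fn try_new(
        args: &<Self as ::dash_pipe_provider::Function>::Args,
        ctx: &mut FunctionContext,
        _storage: &Arc<StorageSet>, // now part of the signature, even when unused
    ) -> Result<Self> {
        Ok(Self {
            ctx: ctx.clone(),
            remaining: args.max_ticks,
        })
    }

    async fn tick(
        &mut self,
        _inputs: PipeMessages<<Self as ::dash_pipe_provider::Function>::Input>,
    ) -> Result<PipeMessages<<Self as ::dash_pipe_provider::Function>::Output>> {
        match self.remaining.checked_sub(1) {
            Some(remaining) => {
                self.remaining = remaining;
                Ok(PipeMessages::Single(PipeMessage {
                    payloads: vec![],
                    value: remaining,
                }))
            }
            // Once exhausted, flag the FunctionContext so loop_forever_async
            // breaks out with Ok(()) instead of ticking forever.
            None => self.ctx.terminate_ok(),
        }
    }
}

The same termination pattern covers error paths: `ctx.terminate_err(error)` (as the webcam connector now does on a disconnected capture device) marks the context terminating and propagates the error, so the pipe loop exits with it instead of logging and spinning.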