diff --git a/Cargo.lock b/Cargo.lock index 9eb89ea..4b9def3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -61,6 +61,21 @@ dependencies = [ "memchr", ] +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "anstream" version = "0.6.4" @@ -500,6 +515,7 @@ dependencies = [ "clap", "russh", "russh-keys", + "russh-sftp", "ssh-key", "tokio", "tracing", @@ -595,6 +611,9 @@ name = "bitflags" version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" +dependencies = [ + "serde", +] [[package]] name = "block-buffer" @@ -687,6 +706,20 @@ dependencies = [ "cpufeatures", ] +[[package]] +name = "chrono" +version = "0.4.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f2c685bad3eb3d45a01354cedb7d5faa66194d1d58ba6e267a8de788f79db38" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "js-sys", + "num-traits", + "wasm-bindgen", + "windows-targets", +] + [[package]] name = "cipher" version = "0.4.4" @@ -1311,6 +1344,29 @@ dependencies = [ "tokio-rustls", ] +[[package]] +name = "iana-time-zone" +version = "0.1.58" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8326b86b6cff230b97d0d312a6c40a60726df3332e721f72a1b035f451663b20" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + [[package]] name = "indexmap" version = "1.9.3" @@ -2037,6 +2093,22 @@ dependencies = [ "yasna", ] +[[package]] +name = "russh-sftp" +version = "2.0.0-beta.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74d3b2a6990ae72682c590323b9bc2c9edffc63a4362f4b96f3f8de4117d6e8d" +dependencies = [ + "async-trait", + "bitflags 2.4.1", + "bytes", + "chrono", + "log", + "serde", + "thiserror", + "tokio", +] + [[package]] name = "rustc-demangle" version = "0.1.23" @@ -2784,6 +2856,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-core" +version = "0.51.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1f8cf84f35d2db49a46868f947758c7a1138116f7fac3bc844f43ade1292e64" +dependencies = [ + "windows-targets", +] + [[package]] name = "windows-sys" version = "0.48.0" diff --git a/aws-throwaway/Cargo.toml b/aws-throwaway/Cargo.toml index 28569c7..9f41914 100644 --- a/aws-throwaway/Cargo.toml +++ b/aws-throwaway/Cargo.toml @@ -21,6 +21,7 @@ anyhow = "1.0.42" uuid = { version = "1.0.0", features = ["serde", "v4"] } tracing = "0.1.15" async-trait = "0.1.30" +russh-sftp = "^2.0.0-beta.3" [dev-dependencies] tracing-subscriber = { version = "0.3.1", features = ["env-filter", "json"] } diff --git a/aws-throwaway/examples/aws-throwaway-test-large-file.rs b/aws-throwaway/examples/aws-throwaway-test-large-file.rs new file mode 100644 index 0000000..cd11987 --- /dev/null +++ b/aws-throwaway/examples/aws-throwaway-test-large-file.rs @@ -0,0 +1,38 @@ +use aws_throwaway::{Aws, CleanupResources, Ec2InstanceDefinition, InstanceType}; +use std::{path::Path, time::Instant}; +use tracing_subscriber::EnvFilter; + +#[tokio::main] +async fn main() { + let (non_blocking, _guard) = tracing_appender::non_blocking(std::io::stdout()); + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .with_writer(non_blocking) + .init(); + + let aws = Aws::builder(CleanupResources::AllResources).build().await; + let instance = aws + .create_ec2_instance(Ec2InstanceDefinition::new(InstanceType::T2Micro)) + .await; + + let start = Instant::now(); + std::fs::write("some_local_file", vec![0; 1024 * 1024 * 1]).unwrap(); // create 100MB file + println!("Time to create 100MB file locally {:?}", start.elapsed()); + + let start = Instant::now(); + instance + .ssh() + .push_file(Path::new("some_local_file"), Path::new("some_remote_file")) + .await; + println!("Time to push 100MB file {:?}", start.elapsed()); + + let start = Instant::now(); + instance + .ssh() + .pull_file(Path::new("some_remote_file"), Path::new("some_local_file")) + .await; + println!("Time to pull 100MB file {:?}", start.elapsed()); + + aws.cleanup_resources().await; + println!("\nAll AWS throwaway resources have been deleted") +} diff --git a/aws-throwaway/src/ssh.rs b/aws-throwaway/src/ssh.rs index a6e3a1c..274f54c 100644 --- a/aws-throwaway/src/ssh.rs +++ b/aws-throwaway/src/ssh.rs @@ -5,12 +5,9 @@ use russh::{ ChannelMsg, Sig, }; use russh_keys::{key::PublicKey, PublicKeyBase64}; +use russh_sftp::{client::SftpSession, protocol::OpenFlags}; use std::{fmt::Display, io::Write, net::IpAddr, path::Path, sync::Arc}; -use tokio::{ - fs::File, - io::{AsyncReadExt, BufReader}, - net::TcpStream, -}; +use tokio::{fs::File, io::AsyncReadExt, net::TcpStream}; pub struct SshConnection { address: IpAddr, @@ -192,11 +189,35 @@ impl SshConnection { let task = format!("pushing file from {source:?} to {}:{dest:?}", self.address); tracing::info!("{task}"); - let source = File::open(source) + let mut channel = self.session.channel_open_session().await.unwrap(); + channel.request_subsystem(true, "sftp").await.unwrap(); + let sftp = SftpSession::new(channel.into_stream()).await.unwrap(); + let mut file = sftp + .open_with_flags( + dest.to_str().unwrap(), + OpenFlags::WRITE | OpenFlags::TRUNCATE | OpenFlags::CREATE, + ) + .await + .unwrap(); + + let mut source = File::open(source) .await .map_err(|e| anyhow!(e).context(format!("Failed to read from {source:?}"))) .unwrap(); - self.push_file_impl(&task, source, dest).await; + + let mut bytes = vec![0u8; 1024 * 1024]; + loop { + let read_count = source + .read(&mut bytes[..]) + .await + .unwrap_or_else(|e| panic!("{task} failed to read from local disk with {e:?}")); + if read_count == 0 { + break; + } + tokio::io::AsyncWriteExt::write_all(&mut file, &bytes[0..read_count]) + .await + .unwrap_or_else(|e| panic!("{task} failed to write to remote server with {e:?}")); + } } /// Create a file on the remote machine at `dest` with the provided bytes. @@ -204,54 +225,19 @@ impl SshConnection { let task = format!("pushing raw bytes to {}:{dest:?}", self.address); tracing::info!("{task}"); - let source = BufReader::new(bytes); - self.push_file_impl(&task, source, dest).await; - } - - async fn push_file_impl(&self, task: &str, source: R, dest: &Path) { let mut channel = self.session.channel_open_session().await.unwrap(); - let command = format!("dd of='{0}'\nchmod 777 {0}", dest.to_str().unwrap()); - channel.exec(true, command).await.unwrap(); - - let mut stdout = vec![]; - let mut stderr = vec![]; - let mut status = None; - let mut failed = None; - channel.data(source).await.unwrap(); - channel.eof().await.unwrap(); - while let Some(msg) = channel.wait().await { - match msg { - ChannelMsg::Data { data } => stdout.write_all(&data).unwrap(), - ChannelMsg::ExtendedData { data, ext } => { - if ext == 1 { - stderr.write_all(&data).unwrap() - } else { - tracing::warn!("received unknown extended data with extension type {ext} containing: {:?}", data.to_vec()) - } - } - ChannelMsg::ExitStatus { exit_status } => { - status = Some(exit_status); - // cant exit immediately, there might be more data still - } - ChannelMsg::ExitSignal { - signal_name, - core_dumped, - error_message, - .. - } => { - failed = Some(format!( - "killed via signal {signal_name:?} core_dumped={core_dumped} {error_message:?}" - )) - } - _ => {} - } - } - let output = CommandOutput { - stdout: String::from_utf8(stdout).unwrap(), - stderr: String::from_utf8(stderr).unwrap(), - }; - - check_results(task, failed, status, &output); + channel.request_subsystem(true, "sftp").await.unwrap(); + let sftp = SftpSession::new(channel.into_stream()).await.unwrap(); + let mut file = sftp + .open_with_flags( + dest.to_str().unwrap(), + OpenFlags::WRITE | OpenFlags::TRUNCATE | OpenFlags::CREATE, + ) + .await + .unwrap(); + tokio::io::AsyncWriteExt::write_all(&mut file, bytes) + .await + .unwrap_or_else(|e| panic!("{task} failed to write to remote server with {e:?}")); } /// Pull a file from the remote machine to the local machine @@ -260,46 +246,28 @@ impl SshConnection { tracing::info!("{task}"); let mut channel = self.session.channel_open_session().await.unwrap(); - let command = format!("dd if='{0}'\nchmod 777 {0}", source.to_str().unwrap()); - channel.exec(true, command).await.unwrap(); + channel.request_subsystem(true, "sftp").await.unwrap(); + let sftp = SftpSession::new(channel.into_stream()).await.unwrap(); + let mut file = sftp.open(source.to_str().unwrap()).await.unwrap(); - let mut out = File::create(dest).await.unwrap(); - let mut stderr = vec![]; - let mut status = None; - let mut failed = None; - channel.eof().await.unwrap(); - while let Some(msg) = channel.wait().await { - match msg { - ChannelMsg::Data { data } => tokio::io::AsyncWriteExt::write_all(&mut out, &data) - .await - .unwrap(), - ChannelMsg::ExtendedData { data, ext } => { - if ext == 1 { - stderr.write_all(&data).unwrap() - } else { - tracing::warn!("received unknown extended data with extension type {ext} containing: {:?}", data.to_vec()) - } - } - ChannelMsg::ExitStatus { exit_status } => { - status = Some(exit_status); - // cant exit immediately, there might be more data still - } - ChannelMsg::ExitSignal { - signal_name, - core_dumped, - error_message, - .. - } => { - failed = Some(format!( - "killed via signal {signal_name:?} core_dumped={core_dumped} {error_message:?}" - )) - } - _ => {} + let mut dest = File::create(dest) + .await + .map_err(|e| anyhow!(e).context(format!("Failed to read from {source:?}"))) + .unwrap(); + + let mut bytes = vec![0u8; 1024 * 1024]; + loop { + let read_count = file + .read(&mut bytes[..]) + .await + .unwrap_or_else(|e| panic!("{task} failed to read from local disk with {e:?}")); + if read_count == 0 { + break; } + tokio::io::AsyncWriteExt::write_all(&mut dest, &bytes[0..read_count]) + .await + .unwrap_or_else(|e| panic!("{task} failed to write to remote server with {e:?}")); } - - let output = String::from_utf8(stderr).unwrap(); - check_results(&task, failed, status, &output); } }