From 19f02b237c0cfb896ee0bae22ef92c34f600ab29 Mon Sep 17 00:00:00 2001 From: knutaf Date: Wed, 4 Oct 2023 23:30:40 -0700 Subject: [PATCH] Remove mpsc channel between local input and router Before this change, to read from stdin or the terminal, the program would spin up a thread and send local input through an mpsc channel. Then something would do a `send_all` to drive the input into the router. But since the router's sink can be cloned, it could be possible to pass this directly to the stdin reader thread and bypass the CPU and memory cost of the channel. To make this change, all the places that were previously referencing a `LocalInputStream` need to change to use a future that exits when the local input is done. Internally the future will drive input to the router sink. This allows the stdin reader thread to call `router_sink.send()` and other input modes like random generation to call `router_sink.send_all()`. At program exit it's possible that the stdin reader task blocks on reading from input, so we have to shut down the Tokio runtime rudely to kill the task. Also fix a bug where the result of the last outbound connection wasn't being surfaced as the program's exit code. Also remove an incorrect assert in the codepaths the router send path and the socket close path notice the socket closure and both call `cleanup_route`. Also update to latest crate versions. --- Cargo.lock | 70 ++++----- Cargo.toml | 2 +- src/main.rs | 400 ++++++++++++++++++++++++++++------------------------ 3 files changed, 248 insertions(+), 224 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7ce1be1..cb47690 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -30,18 +30,18 @@ dependencies = [ [[package]] name = "aho-corasick" -version = "1.0.5" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c378d78423fdad8089616f827526ee33c19f2fddbd5de1629152c9593ba4783" +checksum = "ea5d730647d4fadd988536d06fecce94b7b4f2a7efdae548f1cf4b63205518ab" dependencies = [ "memchr", ] [[package]] name = "anstream" -version = "0.5.0" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1f58811cfac344940f1a400b6e6231ce35171f614f26439e80f8c1465c5cc0c" +checksum = "2ab91ebe16eb252986481c5b62f6098f3b698a45e34b5b98200cf20dd2484a44" dependencies = [ "anstyle", "anstyle-parse", @@ -53,15 +53,15 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.3" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b84bf0a05bbb2a83e5eb6fa36bb6e87baa08193c35ff52bbf6b38d8af2890e46" +checksum = "7079075b41f533b8c61d2a4d073c4676e1f8b249ff94a393b0595db304e0dd87" [[package]] name = "anstyle-parse" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "938874ff5980b03a87c5524b3ae5b59cf99b1d6bc836848df7bc5ada9643c333" +checksum = "317b9a89c1868f5ea6ff1d9539a69f45dffc21ce321ac1fd1160dfa48c8e2140" dependencies = [ "utf8parse", ] @@ -77,9 +77,9 @@ dependencies = [ [[package]] name = "anstyle-wincon" -version = "2.1.0" +version = "3.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58f54d10c6dfa51283a066ceab3ec1ab78d13fae00aa49243a45e4571fb79dfd" +checksum = "f0699d10d2f4d628a98ee7b57b289abbc98ff3bad977cb3152709d4bf2330628" dependencies = [ "anstyle", "windows-sys 0.48.0", @@ -147,9 +147,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "clap" -version = "4.4.3" +version = "4.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84ed82781cea27b43c9b106a979fe450a13a31aab0500595fb3fc06616de08e6" +checksum = "d04704f56c2cde07f43e8e2c154b43f216dc5c92fc98ada720177362f953b956" dependencies = [ "clap_builder", "clap_derive", @@ -157,9 +157,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.4.2" +version = "4.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2bb9faaa7c2ef94b2743a21f5a29e6f0010dff4caa69ac8e9d6cf8b6fa74da08" +checksum = "0e231faeaca65ebd1ea3c737966bf858971cd38c3849107aa3ea7de90a804e45" dependencies = [ "anstream", "anstyle", @@ -424,9 +424,9 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" [[package]] name = "hermit-abi" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "443144c8cdadd93ebf52ddb4056d257f5b52c04d3c804e657d19eb73fc33668b" +checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" [[package]] name = "lazy_static" @@ -464,9 +464,9 @@ checksum = "df39d232f5c40b0891c10216992c2f250c054105cb1e56f0fc9032db6203ecc1" [[package]] name = "memchr" -version = "2.6.3" +version = "2.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f232d6ef707e1956a43342693d2a31e72989554d58299d7a88738cc95b0d35c" +checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167" [[package]] name = "miniz_oxide" @@ -490,7 +490,7 @@ dependencies = [ [[package]] name = "netcrab" -version = "0.9.0" +version = "0.9.1" dependencies = [ "bytes", "clap", @@ -686,9 +686,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.9.5" +version = "1.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "697061221ea1b4a94a624f67d0ae2bfe4e22b8a17b6a192afb11046542cc8c47" +checksum = "ebee201405406dbf528b8b672104ae6d6d63e6d118cb10e4d51abbc7b58044ff" dependencies = [ "aho-corasick", "memchr", @@ -698,9 +698,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.3.8" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2f401f4955220693b56f8ec66ee9c78abffd8d1c4f23dc41a23839eb88f0795" +checksum = "59b23e92ee4318893fa3fe3e6fb365258efbfe6ac6ab30f090cdcbb7aa37efa9" dependencies = [ "aho-corasick", "memchr", @@ -745,9 +745,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.11.0" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9" +checksum = "942b4a808e05215192e39f4ab80813e599068285906cc91aa64f923db842bd5a" [[package]] name = "socket2" @@ -767,9 +767,9 @@ checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" [[package]] name = "syn" -version = "2.0.35" +version = "2.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59bf04c28bee9043ed9ea1e41afc0552288d3aba9c6efdd78903b802926f4879" +checksum = "7303ef2c05cd654186cb250d29049a24840ca25d2747c25c0381c8d9e2f582e8" dependencies = [ "proc-macro2", "quote", @@ -799,18 +799,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.48" +version = "1.0.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d6d7a740b8a666a7e828dd00da9c0dc290dff53154ea77ac109281de90589b7" +checksum = "1177e8c6d7ede7afde3585fd2513e611227efd6481bd78d2e82ba1ce16557ed4" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.48" +version = "1.0.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49922ecae66cc8a249b77e68d1d0623c1b2c514f0060c27cdc68bd62a1219d35" +checksum = "10712f02019e9288794769fba95cd6847df9874d49d871d062172f9dd41bc4cc" dependencies = [ "proc-macro2", "quote", @@ -860,9 +860,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.8" +version = "0.7.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "806fe8c2c87eccc8b3267cbae29ed3ab2d0bd37fca70ab622e46aaa9375ddb7d" +checksum = "1d68074620f57a0b21594d9735eb2e98ab38b17f80d3fcb189fca266771ca60d" dependencies = [ "bytes", "futures-core", @@ -910,9 +910,9 @@ checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" [[package]] name = "unicode-width" -version = "0.1.10" +version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b" +checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85" [[package]] name = "utf8parse" diff --git a/Cargo.toml b/Cargo.toml index 06006cc..f8ac1b6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ name = "netcrab" description = "A multi-purpose TCP/UDP listener and connector" keywords = ["tcp", "udp", "networking", "sockets"] categories = ["network-programming", "command-line-utilities"] -version = "0.9.0" +version = "0.9.1" authors = ["knutaf"] edition = "2021" repository = "https://github.com/knutaf/netcrab" diff --git a/src/main.rs b/src/main.rs index 4b1c0b8..bca3909 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,7 +6,10 @@ extern crate regex; use bytes::Bytes; use clap::{Args, CommandFactory, Parser, ValueEnum}; -use futures::{channel::mpsc, future, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt}; +use futures::{ + channel::mpsc, future, future::FusedFuture, stream::FuturesUnordered, FutureExt, SinkExt, + StreamExt, +}; use rand::{distributions::Distribution, Rng}; use regex::Regex; use std::{ @@ -98,42 +101,42 @@ struct SourcedBytes { impl SourcedBytes { // Wrap bytes that were produced by the local machine with the special local route address that marks them as - // originating from the local machine. As a convenience, also wrap in a Result, which is what the various streams - // and sinks need. - fn ok_from_local(data: Bytes) -> std::io::Result { - Ok(Self { + // originating from the local machine. + fn create_with_local_source(data: Bytes) -> Self { + Self { data, route: LOCAL_IO_ROUTE_ADDR, - }) + } } } type SockAddrSet = HashSet; type RouteAddrSet = HashSet; -// A stream of bytes produced from the local machine. In contrast with bytes that come from the network, it has no -// source address, though that is faked later in order to make it be treated just like other sockets by the router for -// purposes of forwarding. -// -// Lifetime specifier is needed because in some places the local stream incorporates an object that references function -// parameters (i.e. '_). -type LocalIoStream<'a> = Pin> + 'a>>; +// A future that represents work that drives the local input to completion. It is used with any `InputMode`, regardless +// of how input is obtained (stdin, random generation, etc.). +type LocalInputDriver = Pin>>>; // A sink that accepts byte buffers and sends them to the local IO function (stdout, echo, null, etc.). -type LocalIoSink = Pin>>; +type LocalOutputSink = Pin>>; // When setting up local IO, it's common to set up both the way input enters the program and where output from the // program should go. -type LocalIoSinkAndStream<'a> = (LocalIoSink, LocalIoStream<'a>); +type LocalOutputSinkAndInputDriver = (LocalOutputSink, LocalInputDriver); // A sink that accepts Bytes to be sent to the network. When the router determines that data should be sent to a socket, // it sends it into this sink, where there is one per remote peer. type RouterToNetSink<'a> = Pin + 'a>>; -// A sink of bytes originating from some remote peer. Each socket drives data received on it to the router using this -// sink, supplying where the data came from as well as the local address it arrives at. The router uses the data's -// origin and local destination to decide where the data should be forwarded to. -type NetToRouterSink = Pin>>; +// A sink of bytes originating from some remote peer and sent to the router, which uses an mpsc::channel to collect +// data from all sockets. Each socket drives data received on it to the router using this sink, supplying where the data +// came from as well as the local address it arrives at. The router uses the data's origin and local destination to +// decide where the data should be forwarded to. The type looks strange because the associated error has to be mapped to +// `std::io::Error` to fit with what the rest of the program uses. +type RouterSink = futures::sink::SinkMapErr< + mpsc::UnboundedSender, + fn(mpsc::SendError) -> std::io::Error, +>; // A grouping of connection information for a user-specified target, something passed as a command line arg. The // original argument value is stored as well as all the addresses that the name resolved to. @@ -352,21 +355,18 @@ struct TcpRouter<'a> { // not to other channels. channels: ChannelMap, - // A sink where all sockets send data to the router for forwarding. Normally this would just be an UnboundedSender, - // but since we map the send error, it gets stored as a complicated type. Thanks, Rust. - net_collector_sink: futures::sink::SinkMapErr< - mpsc::UnboundedSender, - fn(mpsc::SendError) -> std::io::Error, - >, + // A sink where all sockets send data to the router for forwarding. It is an mpsc::channel, so it gets cloned and + // handed out to each socket. + router_sink: RouterSink, // An internal stream that produces all of the data from all sockets that come into the router. inbound_net_traffic_stream: mpsc::UnboundedReceiver, - // A stream of data produced from input to the program. - local_io_stream: LocalIoStream<'a>, + // A future that drives the local input to completion. + local_input_driver: LocalInputDriver, // A sink where the router can send data to be printed out. - local_io_sink: LocalIoSink, + local_output_sink: LocalOutputSink, // Used to figure out when to hook up the local input. lifetime_client_count: u32, @@ -377,10 +377,10 @@ struct TcpRouter<'a> { impl<'a> TcpRouter<'a> { pub fn new(args: &'a NcArgs) -> TcpRouter<'a> { - let (net_collector_sink, inbound_net_traffic_stream) = mpsc::unbounded(); + let (router_sink, inbound_net_traffic_stream) = mpsc::unbounded(); // This funny syntax is required to coerce a function pointer to the fn type required by the field on the struct - let net_collector_sink = net_collector_sink.sink_map_err( + let router_sink = router_sink.sink_map_err( map_unbounded_sink_err_to_io_err as fn(mpsc::SendError) -> std::io::Error, ); @@ -388,10 +388,10 @@ impl<'a> TcpRouter<'a> { args, routes: HashMap::new(), channels: ChannelMap::new(), - net_collector_sink, + router_sink, inbound_net_traffic_stream, - local_io_stream: Box::pin(futures::stream::pending()), - local_io_sink: Box::pin( + local_input_driver: Box::pin(futures::future::pending()), + local_output_sink: Box::pin( futures::sink::drain().sink_map_err(map_drain_sink_err_to_io_err), ), lifetime_client_count: 0, @@ -402,7 +402,7 @@ impl<'a> TcpRouter<'a> { // Callers use this to add a new destination to the router for forwarding. The caller passes in the part of the TCP // socket that the router can use to write data to the socket, and it returns a sink where the caller can write data // to the router. - pub fn add_route(&mut self, tx_socket: tokio::net::tcp::OwnedWriteHalf) -> NetToRouterSink { + pub fn add_route(&mut self, tx_socket: tokio::net::tcp::OwnedWriteHalf) -> RouterSink { let new_route = RouteAddr { local: tx_socket.local_addr().unwrap(), peer: tx_socket.peer_addr().unwrap(), @@ -422,12 +422,13 @@ impl<'a> TcpRouter<'a> { // out of, say, a redirected input stream, and forward it to nobody, because there are no other clients. self.lifetime_client_count += 1; if self.lifetime_client_count == 1 { - (self.local_io_sink, self.local_io_stream) = setup_local_io(self.args); + (self.local_output_sink, self.local_input_driver) = + setup_local_io(self.args, self.router_sink.clone()); } - // The input end of the router (`net_collector_sink`) can be cloned to allow multiple callers to pass data into - // the same channel. - Box::pin(self.net_collector_sink.clone()) + // The input end of the router (`router_sink`) can be cloned to allow multiple callers to pass data into the + // same channel. + self.router_sink.clone() } pub fn remove_route(&mut self, route: &RouteAddr) { @@ -444,11 +445,9 @@ impl<'a> TcpRouter<'a> { ) { let removed = routes.remove(route); - // In channels mode we might have torn down an extra route here, which would cause cleanup_route to be called - // to notify us that the route had closed, but that's because we already did it. Permit it to happen in this - // mode. + // It's possible this route has already been cleaned up because the router noticed when forwarding to it that it + // was closed. if removed.is_none() { - assert!(args.forwarding_mode == ForwardingMode::Channels); return; } @@ -473,17 +472,8 @@ impl<'a> TcpRouter<'a> { // Start asynchronously processing data from all managed sockets. pub async fn service(&mut self) -> std::io::Result<()> { - // Since there is only one local input and output (i.e. stdin/stdout), don't create a new channel to add it to - // the router. Instead just feed data into the router directly by tagging it as originating from the local - // input. - let mut local_io_to_router_sink = Box::pin( - self.net_collector_sink - .clone() - .with(|b| async { SourcedBytes::ok_from_local(b) }), - ); - // If the local output (i.e. stdout) fails, we can set this to None to save perf on sending to it further. - let mut local_io_sink_opt = Some(&mut self.local_io_sink); + let mut local_output_sink_opt = Some(&mut self.local_output_sink); let mut stats_tracker = StatsTracker::new(); @@ -580,12 +570,12 @@ impl<'a> TcpRouter<'a> { // Came from a remote endpoint, so also send to local output if it hasn't failed yet. if sb.route.peer != LOCAL_IO_PEER_ADDR { stats_tracker.record_recv(&sb.data); - if let Some(ref mut local_io_sink) = local_io_sink_opt { + if let Some(ref mut local_output_sink) = local_output_sink_opt { // If we hit an error emitting output, clear out the local output sink so we don't bother // trying to output more. - if let Err(e) = local_io_sink.send(sb.data.clone()).await { + if let Err(e) = local_output_sink.send(sb.data.clone()).await { eprintln!("Local output closed. {}", e); - local_io_sink_opt = None; + local_output_sink_opt = None; } } } @@ -610,13 +600,14 @@ impl<'a> TcpRouter<'a> { } }, - // Send in all data from the local input to the router. - _result = local_io_to_router_sink.send_all(&mut self.local_io_stream).fuse() => { + // Drive local input to the router until it completes. `local_input_driver` was previously already + // passed the `router_sink` so it knows to send it to the router. + _result = &mut self.local_input_driver => { eprintln!("End of outbound data from local machine reached."); // Send an empty message as a special signal that the local stream has finished. - local_io_to_router_sink.send(Bytes::new()).await.expect("Local IO sink should not be closed early!"); - self.local_io_stream = Box::pin(futures::stream::pending()); + self.router_sink.send(SourcedBytes::create_with_local_source(Bytes::new())).await.expect("Router sink should not be closed early!"); + self.local_input_driver = Box::pin(futures::future::pending()); }, } } @@ -646,18 +637,26 @@ fn map_drain_sink_err_to_io_err(_err: std::convert::Infallible) -> std::io::Erro // For iterators, we shouldn't yield items as quickly as possible, or else the runtime sometimes doesn't give enough // processing time to other tasks. This creates a stream out of an iterator but also makes sure to yield after every // element to make sure it can flow through the rest of the system. -fn local_io_stream_from_iter<'a, TIter>(iter: TIter) -> LocalIoStream<'a> +fn local_input_driver_from_iter(iter: TIter, mut router_sink: RouterSink) -> LocalInputDriver where - TIter: Iterator> + 'a, + TIter: Iterator> + 'static, { - Box::pin(futures::stream::iter(iter).then(|e| async move { - tokio::task::yield_now().await; - e - })) + // Async block moves in the iterator and router sink. + Box::pin( + async move { + let mut stream = Box::pin(futures::stream::iter(iter).then(|e| async { + tokio::task::yield_now().await; + e.map(SourcedBytes::create_with_local_source) + })); + + router_sink.send_all(&mut stream).await + } + .fuse(), + ) } // Setup an output sink that either goes to stdout or to the void, depending on the user's selection. -fn setup_local_output(args: &NcArgs) -> LocalIoSink { +fn setup_local_output(args: &NcArgs) -> LocalOutputSink { let codec = BytesCodec::new(); match args.output_mode { OutputMode::Stdout => Box::pin(FramedWrite::new(tokio::io::stdout(), codec)), @@ -676,23 +675,20 @@ fn setup_local_output(args: &NcArgs) -> LocalIoSink { // process stdout. fn setup_async_reader_thread( mut input_reader: TInput, + mut router_sink: RouterSink, args: &NcArgs, -) -> mpsc::UnboundedReceiver> +) -> LocalInputDriver where TInput: std::io::Read + std::marker::Send + 'static, { - // Unbounded is OK because we'd rather prioritize faster throughput at the cost of more memory. - let (mut tx_input, rx_input) = mpsc::unbounded(); - // If the user is interactively using the program, by default use "character mode", which inputs a character on // every keystroke rather than a full line when hitting enter. However, the user can request to use normal stdin // mode, and if this is reading from a child program execution, that definitely can't use character mode. let is_interactive_mode = args.is_interactive_input_mode(); let chunk_size = args.send_size; - let _thread = std::thread::Builder::new() - .name("stdin_reader".to_string()) - .spawn(move || { + Box::pin( + tokio::task::spawn(async move { // Create storage for the input from stdin. It needs to be big enough to store the user's desired chunk size // It will accumulate data until it's filled an entire chunk, at which point it will be sent and repurposed // for the next chunk. @@ -701,7 +697,7 @@ where // The available buffer (pointer to next byte to write and remaining available space in the chunk). let mut available_buf = &mut read_buf[..]; - if is_interactive_mode { + let result = if is_interactive_mode { // The console crate has a function called stdout that gives you, uh, a Term object that also services // input. OK dude. let mut stdio = console::Term::stdout(); @@ -710,55 +706,61 @@ where let mut char_buf = [0u8; std::mem::size_of::()]; let char_buf_len = char_buf.len(); - while let Ok(ch) = stdio.read_char() { - // Encode the char from input as a series of bytes. - let encoded_str = ch.encode_utf8(&mut char_buf); - let num_bytes_read = encoded_str.len(); - assert!(num_bytes_read <= char_buf_len); - - // Echo the char back out, because `read_char` doesn't. - let _ = stdio.write(encoded_str.as_bytes()); - - // Track a slice of the remaining bytes of the just-read char that haven't been sent yet. - let mut char_bytes_remaining = &mut char_buf[..num_bytes_read]; - while char_bytes_remaining.len() > 0 { - // There might not be enough space left in the chunk to fit the whole char. - let num_bytes_copyable = - std::cmp::min(available_buf.len(), char_bytes_remaining.len()); - assert!(num_bytes_copyable <= available_buf.len()); - assert!(num_bytes_copyable <= char_bytes_remaining.len()); - assert!(num_bytes_copyable > 0); - - // Copy as much of the char as possible into the chunk. Note that for UTF-8 characters sent in - // UDP datagrams, if we fail to copy the whole character into a given datagram, the resultant - // traffic might be... strange. Imagine having a UTF-8 character split across two datagrams. - // Not great, but also not lossy, so who is to say what is correct? It's not possible to fully - // fill every UDP datagram of arbitrary size with varying size UTF-8 characters and not - // sometimes slice them. - available_buf[..num_bytes_copyable] - .copy_from_slice(&char_bytes_remaining[..num_bytes_copyable]); - - // Update the available buffer to the remaining space in the chunk after copying in this char. - available_buf = &mut available_buf[num_bytes_copyable..]; - - // Advance the remaining slice of the char to the uncopied portion. - char_bytes_remaining = &mut char_bytes_remaining[num_bytes_copyable..]; - - // There's no more available buffer in this chunk, meaning we've accumulated a full chunk, so - // send it now. - if available_buf.len() == 0 { - // Stop borrowing read_buf for a hot second so it can be sent. - available_buf = &mut []; - - if let Err(_) = - tx_input.unbounded_send(Ok(Bytes::copy_from_slice(&read_buf))) - { - break; + 'outer: loop { + let read_result = stdio.read_char(); + if let Ok(ch) = read_result { + // Encode the char from input as a series of bytes. + let encoded_str = ch.encode_utf8(&mut char_buf); + let num_bytes_read = encoded_str.len(); + assert!(num_bytes_read <= char_buf_len); + + // Echo the char back out, because `read_char` doesn't. + let _ = stdio.write(encoded_str.as_bytes()); + + // Track a slice of the remaining bytes of the just-read char that haven't been sent yet. + let mut char_bytes_remaining = &mut char_buf[..num_bytes_read]; + while char_bytes_remaining.len() > 0 { + // There might not be enough space left in the chunk to fit the whole char. + let num_bytes_copyable = + std::cmp::min(available_buf.len(), char_bytes_remaining.len()); + assert!(num_bytes_copyable <= available_buf.len()); + assert!(num_bytes_copyable <= char_bytes_remaining.len()); + assert!(num_bytes_copyable > 0); + + // Copy as much of the char as possible into the chunk. Note that for UTF-8 characters sent + // in UDP datagrams, if we fail to copy the whole character into a given datagram, the + // resultant traffic might be... strange. Imagine having a UTF-8 character split across two + // datagrams. Not great, but also not lossy, so who is to say what is correct? It's not + // possible to fully fill every UDP datagram of arbitrary size with varying size UTF-8 + // characters and not sometimes slice them. + available_buf[..num_bytes_copyable] + .copy_from_slice(&char_bytes_remaining[..num_bytes_copyable]); + + // Update the available buffer to the remaining space in the chunk after copying in this + // char. + available_buf = &mut available_buf[num_bytes_copyable..]; + + // Advance the remaining slice of the char to the uncopied portion. + char_bytes_remaining = &mut char_bytes_remaining[num_bytes_copyable..]; + + // There's no more available buffer in this chunk, meaning we've accumulated a full chunk, + // so send it now. + if available_buf.len() == 0 { + let send_result = router_sink + .send(SourcedBytes::create_with_local_source( + Bytes::copy_from_slice(&read_buf), + )) + .await; + if send_result.is_err() { + break 'outer send_result; + } + + // The chunk was sent. Reset the available buffer to allow storing the next chunk. + available_buf = &mut read_buf[..]; } - - // The chunk was sent. Reset the available buffer to allow storing the next chunk. - available_buf = &mut read_buf[..]; } + } else { + break read_result.map(|_| ()); } } } else { @@ -766,39 +768,46 @@ where if let Ok(num_bytes_read) = input_reader.read(available_buf) { if num_bytes_read == 0 { // EOF. Disconnect from the channel too. - tx_input.disconnect(); - break; + if let Err(e) = router_sink.close().await { + eprintln!("Failed to close router sink. Error {}", e); + } + + break Ok(()); } assert!(num_bytes_read <= available_buf.len()); // We've accumulated a full chunk, so send it now. if num_bytes_read == available_buf.len() { - if let Err(_) = - tx_input.unbounded_send(Ok(Bytes::copy_from_slice(&read_buf))) - { - eprintln!("breaking 1"); - break; + let send_result = router_sink + .send(SourcedBytes::create_with_local_source( + Bytes::copy_from_slice(&read_buf), + )) + .await; + if send_result.is_err() { + break send_result; } // The chunk was sent. Reset the buffer to allow storing the next chunk. available_buf = &mut read_buf[..]; } else { - // Read buffer isn't full yet. Set the available buffer to the rest of the buffer just past the - // portion that was written to. + // Read buffer isn't full yet. Set the available buffer to the rest of the buffer just past + // the portion that was written to. available_buf = &mut available_buf[num_bytes_read..]; } } } - } - }) - .expect("Failed to create input reader thread"); + }; - rx_input + result + }) + .map(|res| res.unwrap()) + .fuse(), + ) } // Return a sink and create a thread that writes everything from that sink to the specified writer. -fn setup_async_writer_thread(mut output_sink: TOutput) -> LocalIoSink +fn setup_async_writer_thread(mut output_writer: TOutput) -> LocalOutputSink where TOutput: std::io::Write + std::marker::Send + 'static, { @@ -813,7 +822,7 @@ where loop { match stream.next().await { Some(bytes) => { - output_sink.write_all(&bytes)?; + output_writer.write_all(&bytes)?; } None => { eprintln!("Output closed."); @@ -910,7 +919,7 @@ impl Iterator for FixedBytesIter { } } -fn setup_local_io(args: &NcArgs) -> LocalIoSinkAndStream { +fn setup_local_io(args: &NcArgs, router_sink: RouterSink) -> LocalOutputSinkAndInputDriver { if let Some(command) = args.exec_command_opt.as_ref() { // Pipe stdin and stdout from the child process here so it can be wired up to the network. let mut prog = execute::shell(command) @@ -922,30 +931,33 @@ fn setup_local_io(args: &NcArgs) -> LocalIoSinkAndStream { return ( Box::pin(setup_async_writer_thread(prog.stdin.take().unwrap())), - Box::pin(setup_async_reader_thread(prog.stdout.take().unwrap(), args)), + Box::pin(setup_async_reader_thread( + prog.stdout.take().unwrap(), + router_sink, + args, + )), ); } match args.input_mode { - InputMode::Null => ( + // Echo mode doesn't have a stream that produces data for sending. That will come directly from the sockets that + // send data to this machine. It's handled in the routing code. + InputMode::Null | InputMode::Echo => ( setup_local_output(args), - Box::pin(futures::stream::pending()), + Box::pin(futures::future::pending()), ), InputMode::Stdin | InputMode::StdinNoCharMode => ( setup_local_output(args), // Set up a thread to read from stdin. It will produce only chunks of the required size to send. - Box::pin(setup_async_reader_thread(std::io::stdin(), args)), - ), - - // Echo mode doesn't have a stream that produces data for sending. That will come directly from the sockets that - // send data to this machine. It's handled in the routing code. - InputMode::Echo => ( - setup_local_output(args), - Box::pin(futures::stream::pending()), + Box::pin(setup_async_reader_thread( + std::io::stdin(), + router_sink, + args, + )), ), InputMode::Random => ( setup_local_output(args), - local_io_stream_from_iter(RandBytesIter::new(&args.rand_config)), + local_input_driver_from_iter(RandBytesIter::new(&args.rand_config), router_sink), ), InputMode::Fixed => (setup_local_output(args), { // Create a random buffer of the size requested by send_size and containing data that matches the @@ -959,8 +971,8 @@ fn setup_local_io(args: &NcArgs) -> LocalIoSinkAndStream { let fixed_buf = RandBytesIter::new(&conf).next().unwrap().unwrap(); assert!(fixed_buf.len() == args.send_size as usize); - // Return a stream that will just keep producing that same fixed buffer forever. - local_io_stream_from_iter(FixedBytesIter::new(fixed_buf)) + // This will just keep sending that same fixed buffer to the router forever. + local_input_driver_from_iter(FixedBytesIter::new(fixed_buf), router_sink) }), } } @@ -1259,6 +1271,7 @@ async fn do_tcp( } } + let mut last_result = Ok(()); loop { // If we ever aren't at maximum clients accepted, start listening on all the specified addresses in order to // accept new clients. Only do this if we aren't currently listening. @@ -1361,6 +1374,8 @@ async fn do_tcp( } }; + last_result = result; + // Every time a socket closes, it's possible it was because the router finished, so check to // make sure the router is still active before doing things like reconnecting. if !router.is_done { @@ -1403,7 +1418,7 @@ async fn do_tcp( && listeners.is_empty() && accepts.is_empty() { - return Ok(()); + return last_result; } } } @@ -1411,7 +1426,7 @@ async fn do_tcp( async fn handle_tcp_stream( rx_socket: tokio::net::tcp::OwnedReadHalf, args: &NcArgs, - mut net_to_router_sink: NetToRouterSink, + mut net_to_router_sink: RouterSink, ) -> (std::io::Result<()>, RouteAddr) { let route_addr = RouteAddr::from_tcp_stream(&rx_socket); @@ -1595,22 +1610,14 @@ async fn handle_udp_sockets( } let (router_sink, mut inbound_net_traffic_stream) = mpsc::unbounded(); - let router_sink = router_sink.sink_map_err(map_unbounded_sink_err_to_io_err); + let mut router_sink = router_sink + .sink_map_err(map_unbounded_sink_err_to_io_err as fn(mpsc::SendError) -> std::io::Error); // If the local output (i.e. stdout) fails, we can set this to None to save perf on sending to it further. - let mut local_io_sink_opt = None; + let mut local_output_sink_opt = None; // Until the first peer is known, don't start pulling from local input, or else it will get consumed too early. - let mut local_io_stream: LocalIoStream = Box::pin(futures::stream::pending()); - - // Since there is only one local input and output (i.e. stdin/stdout), don't create a new channel to add it to - // the router. Instead just feed data into the router directly by tagging it as originating from the local - // input. - let mut local_io_to_router_sink = Box::pin( - router_sink - .clone() - .with(|b| async { SourcedBytes::ok_from_local(b) }), - ); + let mut local_input_driver: LocalInputDriver = Box::pin(futures::future::pending()); // A collection of all inbound traffic going to the router. let mut net_to_router_flows = FuturesUnordered::new(); @@ -1678,11 +1685,12 @@ async fn handle_udp_sockets( if lifetime_client_count == 1 { // Since we have a remote peer hooked up, start processing local IO. - let (local_io_sink, local_io_stream2) = setup_local_io(args); - local_io_stream = local_io_stream2; + let (local_output_sink, local_input_driver2) = + setup_local_io(args, router_sink.clone()); + local_input_driver = local_input_driver2; - assert!(local_io_sink_opt.is_none()); - local_io_sink_opt = Some(local_io_sink); + assert!(local_output_sink_opt.is_none()); + local_output_sink_opt = Some(local_output_sink); } } @@ -1731,10 +1739,10 @@ async fn handle_udp_sockets( // Don't add the local IO hookup until the first client is added, otherwise the router will pull all // the data out of, say, a redirected input stream, and forward it to nobody, because there are no // other clients. - if lifetime_client_count == 1 && local_io_sink_opt.is_none() { - let (local_io_sink, local_io_stream2) = setup_local_io(args); - local_io_stream = local_io_stream2; - local_io_sink_opt = Some(local_io_sink); + if lifetime_client_count == 1 && local_output_sink_opt.is_none() { + let (local_output_sink, local_input_driver2) = setup_local_io(args, router_sink.clone()); + local_input_driver = local_input_driver2; + local_output_sink_opt = Some(local_output_sink); } } @@ -1797,12 +1805,12 @@ async fn handle_udp_sockets( // Came from a remote endpoint, so also send to local IO. if sb.route.peer != LOCAL_IO_PEER_ADDR { stats_tracker.record_recv(&sb.data); - if let Some(ref mut local_io_sink) = local_io_sink_opt { + if let Some(ref mut local_output_sink) = local_output_sink_opt { // If we hit an error emitting output, clear out the local output sink so we don't bother // trying to output more. - if let Err(e) = local_io_sink.send(sb.data.clone()).await { + if let Err(e) = local_output_sink.send(sb.data.clone()).await { eprintln!("Local output closed. {}", e); - local_io_sink_opt = None; + local_output_sink_opt = None; } } } @@ -1830,11 +1838,12 @@ async fn handle_udp_sockets( } } }, - _result = local_io_to_router_sink.send_all(&mut local_io_stream).fuse() => { + _result = &mut local_input_driver => { eprintln!("End of outbound data from local machine reached."); - local_io_stream = Box::pin(futures::stream::pending()); + local_input_driver = Box::pin(futures::future::pending()); + // Send an empty message as a special signal that the local stream has finished. - local_io_to_router_sink.send(Bytes::new()).await.expect("Local IO sink should not be closed early!"); + router_sink.send(SourcedBytes::create_with_local_source(Bytes::new())).await.expect("Local IO sink should not be closed early!"); }, } } @@ -2098,8 +2107,7 @@ fn usage(msg: &str) -> ! { std::process::exit(1) } -#[tokio::main] -async fn main() -> Result<(), String> { +async fn async_main() -> Result<(), String> { let mut args = NcArgs::parse(); if args.help_more { @@ -2135,8 +2143,8 @@ async fn main() -> Result<(), String> { // Check and see if the user appended x123 or whatever as a multiplier at the end of the target string. if let Some(captures) = TARGET_MULTIPLIER_REGEX.captures_iter(&target).next() { - // Capture 0 is the entire matched text, so the part from the start up to the first captured - // character is the "before" portion. + // Capture 0 is the entire matched text, so the part from the start up to the first captured character + // is the "before" portion. let before_match = &target[..captures.get(0).unwrap().start()]; // Get the first capture, which should be the multiplier string. @@ -2175,9 +2183,9 @@ async fn main() -> Result<(), String> { } } - // Option::iter() makes an iterator that yields either 0 or 1 item, depending on if it's None or Some. - // The `true` param for get_local_addrs tells it to automatically include the wildcard local address if no source - // addresses were explicitly specified. + // Option::iter() makes an iterator that yields either 0 or 1 item, depending on if it's None or Some. The `true` + // param for get_local_addrs tells it to automatically include the wildcard local address if no source addresses + // were explicitly specified. let outbound_source_addrs = get_local_addrs( args.outbound_source_host_opt .iter() @@ -2204,8 +2212,8 @@ async fn main() -> Result<(), String> { .await .map_err(format_io_err)?; - // If max_inbound_connections wasn't specified explicitly, set its value automatically. If in hub or channel - // mode, you generally want more than one incoming client at a time, or else why are you in a forwarding mode?? + // If max_inbound_connections wasn't specified explicitly, set its value automatically. If in hub or channel mode, + // you generally want more than one incoming client at a time, or else why are you in a forwarding mode?? // Otherwise, safely limit to just one per user-specified listen address at a time. if args.max_inbound_connections.is_none() { args.max_inbound_connections = Some( @@ -2227,3 +2235,19 @@ async fn main() -> Result<(), String> { result.map_err(format_io_err) } + +fn main() -> Result<(), String> { + let runtime = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); + + let result = runtime.block_on(async_main()); + + // At this point there may be a task blocked on reading from stdin or a child process's stdout. It won't return + // until the next read completes, but we don't want to wait for that. Given that we've already decided not to do + // anything further (the main task above is done), just shut down those lingering tasks so the program can exit. + runtime.shutdown_background(); + + result +}