diff --git a/src/server/datachan.rs b/src/server/datachan.rs index 7a5e6b84..d440f600 100644 --- a/src/server/datachan.rs +++ b/src/server/datachan.rs @@ -12,15 +12,15 @@ use crate::{ use crate::server::chancomms::DataChanCmd; use std::{ - net::SocketAddr, path::PathBuf, - sync::{ - atomic::{AtomicU64, Ordering}, - Arc, - }, + sync::Arc, }; #[cfg(unix)] -use std::os::fd::{AsRawFd, BorrowedFd, RawFd}; +use std::{ + net::SocketAddr, + os::fd::{AsRawFd, BorrowedFd, RawFd}, + sync::atomic::{AtomicU64, Ordering}, +}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}; use tokio::net::TcpStream; use tokio::sync::mpsc::{Receiver, Sender}; @@ -100,16 +100,23 @@ impl RetrSocket { #[cfg(unix)] pub static RETR_SOCKETS: std::sync::RwLock> = std::sync::RwLock::new(std::collections::BTreeMap::new()); +#[cfg(unix)] struct MeasuringWriter { writer: W, command: &'static str, } +#[cfg(not(unix))] +struct MeasuringWriter { + writer: W, + command: &'static str, +} struct MeasuringReader { reader: R, command: &'static str, } +#[cfg(unix)] impl AsyncWrite for MeasuringWriter { fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> std::task::Poll> { let this = self.get_mut(); @@ -117,16 +124,38 @@ impl AsyncWrite for MeasuringWriter { let result = Pin::new(&mut this.writer).poll_write(cx, buf); if let Poll::Ready(Ok(bytes_written)) = &result { let bw = *bytes_written as u64; - #[cfg(unix)] - { - RETR_SOCKETS - .read() - .unwrap() - .get(&this.writer.as_raw_fd()) - .expect("TODO: better error handling") - .bytes - .fetch_add(bw, Ordering::Relaxed); - } + RETR_SOCKETS + .read() + .unwrap() + .get(&this.writer.as_raw_fd()) + .expect("TODO: better error handling") + .bytes + .fetch_add(bw, Ordering::Relaxed); + metrics::inc_sent_bytes(*bytes_written, this.command); + } + + result + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + Pin::new(&mut this.writer).poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + Pin::new(&mut this.writer).poll_shutdown(cx) + } +} + +#[cfg(not(unix))] +impl AsyncWrite for MeasuringWriter { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> std::task::Poll> { + let this = self.get_mut(); + + let result = Pin::new(&mut this.writer).poll_write(cx, buf); + if let Poll::Ready(Ok(bytes_written)) = &result { + let bw = *bytes_written as u64; metrics::inc_sent_bytes(*bytes_written, this.command); } diff --git a/src/server/mod.rs b/src/server/mod.rs index 8c473791..2892f2ee 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -17,5 +17,6 @@ pub(crate) use controlchan::reply::{Reply, ReplyCode}; pub(crate) use controlchan::ControlChanMiddleware; pub(crate) use controlchan::Event; pub(crate) use controlchan::{ControlChanError, ControlChanErrorKind}; +#[cfg(unix)] pub use datachan::RETR_SOCKETS; use session::{Session, SessionState};