diff --git a/Cargo.toml b/Cargo.toml index b3cf73bc..3fb8338a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,7 +47,7 @@ getrandom = "0.2.15" lazy_static = "1.4.0" md-5 = "0.10.6" moka = { version = "0.12.7", default-features = false, features = ["sync"] } -nix = { version = "0.29.0", default-features = false, features = ["fs"] } +nix = { version = "0.29.0", default-features = false, features = ["fs", "net", "socket"] } prometheus = { version = "0.13.4", default-features = false } proxy-protocol = "0.5.0" rustls = "0.23.10" diff --git a/src/lib.rs b/src/lib.rs index f2b210ab..fa9d8ed0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -49,6 +49,9 @@ pub mod notification; pub(crate) mod server; pub mod storage; -pub use crate::server::ftpserver::{error::ServerError, options, Server, ServerBuilder}; +pub use crate::server::{ + ftpserver::{error::ServerError, options, Server, ServerBuilder}, + RETR_SOCKETS, +}; type BoxError = Box; diff --git a/src/server/datachan.rs b/src/server/datachan.rs index 65480916..cc58a043 100644 --- a/src/server/datachan.rs +++ b/src/server/datachan.rs @@ -11,7 +11,15 @@ use crate::{ }; use crate::server::chancomms::DataChanCmd; -use std::{path::PathBuf, sync::Arc}; +use std::{ + net::SocketAddr, + os::fd::{AsRawFd, BorrowedFd, RawFd}, + path::PathBuf, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, +}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}; use tokio::net::TcpStream; use tokio::sync::mpsc::{Receiver, Sender}; @@ -42,7 +50,53 @@ use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Instant; -struct MeasuringWriter { +/// Holds information about a socket processing a RETR command +#[derive(Debug)] +pub struct RetrSocket { + bytes: AtomicU64, + fd: RawFd, + peer: SocketAddr, +} + +impl RetrSocket { + /// How many bytes have been written to the socket so far? + /// + /// Note that this tracks bytes written to the socket, not sent on the wire. + pub fn bytes(&self) -> u64 { + self.bytes.load(Ordering::Relaxed) + } + + pub fn fd(&self) -> BorrowedFd<'_> { + // Safe because we always destroy the RetrSocket when the MeasuringWriter drops + #[allow(unsafe_code)] + unsafe { + BorrowedFd::borrow_raw(self.fd) + } + } + + fn new(w: &W) -> nix::Result { + let fd = w.as_raw_fd(); + let ss: nix::sys::socket::SockaddrStorage = nix::sys::socket::getpeername(fd)?; + let peer = if let Some(sin) = ss.as_sockaddr_in() { + SocketAddr::V4((*sin).into()) + } else if let Some(sin6) = ss.as_sockaddr_in6() { + SocketAddr::V6((*sin6).into()) + } else { + return Err(nix::errno::Errno::EINVAL); + }; + let bytes = Default::default(); + Ok(RetrSocket { bytes, fd, peer }) + } + + pub fn peer(&self) -> &SocketAddr { + &self.peer + } +} + +/// Collection of all sockets currently serving RETR commands +pub static RETR_SOCKETS: std::sync::RwLock> = std::sync::RwLock::new(std::collections::BTreeMap::new()); + +struct MeasuringWriter { writer: W, command: &'static str, } @@ -52,12 +106,20 @@ struct MeasuringReader { command: &'static str, } -impl AsyncWrite for MeasuringWriter { +impl AsyncWrite for MeasuringWriter { fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> std::task::Poll> { let this = self.get_mut(); let result = Pin::new(&mut this.writer).poll_write(cx, buf); if let Poll::Ready(Ok(bytes_written)) = &result { + let bw = *bytes_written as u64; + RETR_SOCKETS + .read() + .unwrap() + .get(&this.writer.as_raw_fd()) + .expect("TODO: better error handling") + .bytes + .fetch_add(bw, Ordering::Relaxed); metrics::inc_sent_bytes(*bytes_written, this.command); } @@ -87,12 +149,22 @@ impl AsyncRead for MeasuringReader { } } -impl MeasuringWriter { +impl MeasuringWriter { fn new(writer: W, command: &'static str) -> MeasuringWriter { + let retr_socket = RetrSocket::new(&writer).expect("TODO: better error handling"); + RETR_SOCKETS.write().unwrap().insert(retr_socket.fd, retr_socket); Self { writer, command } } } +impl Drop for MeasuringWriter { + fn drop(&mut self) { + if let Ok(mut guard) = RETR_SOCKETS.write() { + guard.remove(&self.writer.as_raw_fd()); + } + } +} + impl MeasuringReader { fn new(reader: R, command: &'static str) -> MeasuringReader { Self { reader, command } diff --git a/src/server/mod.rs b/src/server/mod.rs index 8e87d9c0..8c473791 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -17,4 +17,5 @@ pub(crate) use controlchan::reply::{Reply, ReplyCode}; pub(crate) use controlchan::ControlChanMiddleware; pub(crate) use controlchan::Event; pub(crate) use controlchan::{ControlChanError, ControlChanErrorKind}; +pub use datachan::RETR_SOCKETS; use session::{Session, SessionState};