diff --git a/crates/librqbit/src/lib.rs b/crates/librqbit/src/lib.rs index 434bf270..38091bf8 100644 --- a/crates/librqbit/src/lib.rs +++ b/crates/librqbit/src/lib.rs @@ -31,6 +31,7 @@ pub mod http_api; pub mod http_api_client; mod peer_connection; mod peer_info_reader; +mod read_buf; mod session; mod spawn_utils; mod torrent_state; diff --git a/crates/librqbit/src/peer_connection.rs b/crates/librqbit/src/peer_connection.rs index 9a429bf4..e565ede4 100644 --- a/crates/librqbit/src/peer_connection.rs +++ b/crates/librqbit/src/peer_connection.rs @@ -10,15 +10,14 @@ use librqbit_core::{id20::Id20, lengths::ChunkInfo, peer_id::try_decode_peer_id} use parking_lot::RwLock; use peer_binary_protocol::{ extended::{handshake::ExtendedHandshake, ExtendedMessage}, - serialize_piece_preamble, Handshake, Message, MessageBorrowed, MessageDeserializeError, - MessageOwned, PIECE_MESSAGE_DEFAULT_LEN, + serialize_piece_preamble, Handshake, Message, MessageOwned, PIECE_MESSAGE_DEFAULT_LEN, }; use serde::{Deserialize, Serialize}; use serde_with::serde_as; use tokio::time::timeout; use tracing::trace; -use crate::spawn_utils::BlockingSpawner; +use crate::{read_buf::ReadBuf, spawn_utils::BlockingSpawner}; pub trait PeerConnectionHandler { fn on_connected(&self, _connection_time: Duration) {} @@ -76,32 +75,6 @@ where } } -macro_rules! read_one { - ($conn:ident, $read_buf:ident, $read_so_far:ident, $rwtimeout:ident) => {{ - let (extended, size) = loop { - match MessageBorrowed::deserialize(&$read_buf[..$read_so_far]) { - Ok((msg, size)) => break (msg, size), - Err(MessageDeserializeError::NotEnoughData(d, _)) => { - if $read_buf.len() < $read_so_far + d { - $read_buf.reserve(d); - $read_buf.resize($read_buf.capacity(), 0); - } - - let size = with_timeout($rwtimeout, $conn.read(&mut $read_buf[$read_so_far..])) - .await - .context("error reading from peer")?; - if size == 0 { - anyhow::bail!("disconnected while reading, read so far: {}", $read_so_far) - } - $read_so_far += size; - } - Err(e) => return Err(e.into()), - } - }; - (extended, size) - }}; -} - impl PeerConnection { pub fn new( addr: SocketAddr, @@ -126,9 +99,7 @@ impl PeerConnection { pub async fn manage_peer_incoming( &self, outgoing_chan: tokio::sync::mpsc::UnboundedReceiver, - // How many bytes into read buffer have we read already. - read_so_far: usize, - read_buf: Vec, + read_buf: ReadBuf, handshake: Handshake, mut conn: tokio::net::TcpStream, ) -> anyhow::Result<()> { @@ -166,7 +137,6 @@ impl PeerConnection { self.manage_peer( h_supports_extended, - read_so_far, read_buf, write_buf, conn, @@ -179,7 +149,6 @@ impl PeerConnection { &self, outgoing_chan: tokio::sync::mpsc::UnboundedReceiver, ) -> anyhow::Result<()> { - use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; let rwtimeout = self @@ -206,16 +175,11 @@ impl PeerConnection { .context("error writing handshake")?; write_buf.clear(); - let mut read_buf = vec![0u8; PIECE_MESSAGE_DEFAULT_LEN * 2]; - let mut read_so_far = with_timeout(rwtimeout, conn.read(&mut read_buf)) + let mut read_buf = ReadBuf::new(); + let h = read_buf + .read_handshake(&mut conn, rwtimeout) .await .context("error reading handshake")?; - if read_so_far == 0 { - anyhow::bail!("bad handshake"); - } - let (h, size) = Handshake::deserialize(&read_buf[..read_so_far]) - .map_err(|e| anyhow::anyhow!("error deserializing handshake: {:?}", e))?; - let h_supports_extended = h.supports_extended(); trace!("connected: id={:?}", try_decode_peer_id(Id20(h.peer_id))); if h.info_hash != self.info_hash.0 { @@ -228,14 +192,8 @@ impl PeerConnection { self.handler.on_handshake(h)?; - if read_so_far > size { - read_buf.copy_within(size..read_so_far, 0); - } - read_so_far -= size; - self.manage_peer( h_supports_extended, - read_so_far, read_buf, write_buf, conn, @@ -247,14 +205,11 @@ impl PeerConnection { async fn manage_peer( &self, handshake_supports_extended: bool, - // How many bytes into read_buf is there of peer-sent-data. - mut read_so_far: usize, - mut read_buf: Vec, + mut read_buf: ReadBuf, mut write_buf: Vec, mut conn: tokio::net::TcpStream, mut outgoing_chan: tokio::sync::mpsc::UnboundedReceiver, ) -> anyhow::Result<()> { - use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; let rwtimeout = self @@ -354,23 +309,23 @@ impl PeerConnection { let reader = async move { loop { - let (message, size) = read_one!(read_half, read_buf, read_so_far, rwtimeout); - trace!("received: {:?}", &message); - - if let Message::Extended(ExtendedMessage::Handshake(h)) = &message { - *extended_handshake_ref.write() = Some(h.clone_to_owned()); - self.handler.on_extended_handshake(h)?; - trace!("remembered extended handshake for future serializing"); - } else { - self.handler - .on_received_message(message) - .context("error in handler.on_received_message()")?; - } - - if read_so_far > size { - read_buf.copy_within(size..read_so_far, 0); - } - read_so_far -= size; + read_buf + .read_message(&mut read_half, rwtimeout, |message| { + trace!("received: {:?}", &message); + + if let Message::Extended(ExtendedMessage::Handshake(h)) = &message { + *extended_handshake_ref.write() = Some(h.clone_to_owned()); + self.handler.on_extended_handshake(h)?; + trace!("remembered extended handshake for future serializing"); + } else { + self.handler + .on_received_message(message) + .context("error in handler.on_received_message()")?; + } + Ok(()) + }) + .await + .context("error reading message")?; } // For type inference. @@ -378,7 +333,7 @@ impl PeerConnection { Ok::<_, anyhow::Error>(()) }; - let r = tokio::select! { + tokio::select! { r = reader => { trace!("reader is done, exiting"); r @@ -387,7 +342,6 @@ impl PeerConnection { trace!("writer is done, exiting"); r } - }; - r + } } } diff --git a/crates/librqbit/src/read_buf.rs b/crates/librqbit/src/read_buf.rs new file mode 100644 index 00000000..5867db78 --- /dev/null +++ b/crates/librqbit/src/read_buf.rs @@ -0,0 +1,96 @@ +use std::time::Duration; + +use crate::peer_connection::with_timeout; +use anyhow::Context; +use buffers::ByteBuf; +use peer_binary_protocol::{ + Handshake, MessageBorrowed, MessageDeserializeError, PIECE_MESSAGE_DEFAULT_LEN, +}; +use tokio::io::AsyncReadExt; + +pub struct ReadBuf { + buf: Vec, + // How many bytes into the buffer we have read from the connection. + // New reads should go past this. + filled: usize, + // How many bytes have we successfully deserialized. + processed: usize, +} + +impl ReadBuf { + pub fn new() -> Self { + Self { + buf: vec![0; PIECE_MESSAGE_DEFAULT_LEN * 2], + filled: 0, + processed: 0, + } + } + + fn prepare_for_read(&mut self, need_additional_bytes: usize) { + // Ensure the buffer starts from the to-be-deserialized message. + if self.processed > 0 { + if self.filled > self.processed { + self.buf.copy_within(self.processed..self.filled, 0); + } + self.filled -= self.processed; + self.processed = 0; + } + + // Ensure we have enough capacity to deserialize the message. + if self.buf.len() < self.filled + need_additional_bytes { + self.buf.reserve(need_additional_bytes); + self.buf.resize(self.buf.capacity(), 0); + } + } + + // Read the BT handshake. + // This MUST be run as the first operation on the buffer. + pub async fn read_handshake( + &mut self, + mut conn: impl AsyncReadExt + Unpin, + timeout: Duration, + ) -> anyhow::Result>> { + self.filled = with_timeout(timeout, conn.read(&mut self.buf)) + .await + .context("error reading handshake")?; + if self.filled == 0 { + anyhow::bail!("peer disconnected while reading handshake"); + } + let (h, size) = Handshake::deserialize(&self.buf[..self.filled]) + .map_err(|e| anyhow::anyhow!("error deserializing handshake: {:?}", e))?; + self.processed = size; + Ok(h) + } + + // Read a message into the buffer, try to deserialize it and call the callback on it. + // We can't return the message because of a borrow checker issue. + pub async fn read_message( + &mut self, + mut conn: impl AsyncReadExt + Unpin, + timeout: Duration, + on_message: impl for<'a> FnOnce(MessageBorrowed<'a>) -> anyhow::Result<()>, + ) -> anyhow::Result<()> { + loop { + let need_additional_bytes = + match MessageBorrowed::deserialize(&self.buf[self.processed..self.filled]) { + Err(MessageDeserializeError::NotEnoughData(d, _)) => d, + Ok((msg, size)) => { + self.processed += size; + // Rust's borrow checker can't do this early return. So we are using a callback instead. + // return Ok(msg); + on_message(msg)?; + return Ok(()); + } + Err(e) => return Err(e.into()), + }; + self.prepare_for_read(need_additional_bytes); + let size = with_timeout(timeout, conn.read(&mut self.buf[self.filled..])) + .await + .context("error reading from peer")?; + if size == 0 { + anyhow::bail!("disconnected while reading, read so far: {}", self.filled) + } + self.filled += size; + } + } +} diff --git a/crates/librqbit/src/session.rs b/crates/librqbit/src/session.rs index c825988f..0f1e7c64 100644 --- a/crates/librqbit/src/session.rs +++ b/crates/librqbit/src/session.rs @@ -11,7 +11,7 @@ use std::{ use anyhow::{bail, Context}; use bencode::{bencode_serialize_to_writer, BencodeDeserializer}; -use buffers::{ByteBufT, ByteString}; +use buffers::{ByteBuf, ByteBufT, ByteString}; use clone_to_owned::CloneToOwned; use dht::{ Dht, DhtBuilder, DhtConfig, Id20, PersistentDht, PersistentDhtConfig, RequestPeersStream, @@ -22,23 +22,23 @@ use librqbit_core::{ magnet::Magnet, peer_id::generate_peer_id, spawn_utils::spawn_with_cancel, - torrent_metainfo::{torrent_from_bytes, TorrentMetaV1Info, TorrentMetaV1Owned}, + torrent_metainfo::{ + torrent_from_bytes as bencode_torrent_from_bytes, TorrentMetaV1Info, TorrentMetaV1Owned, + }, }; use parking_lot::RwLock; -use peer_binary_protocol::{Handshake, PIECE_MESSAGE_DEFAULT_LEN}; +use peer_binary_protocol::Handshake; use reqwest::Url; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use serde_with::serde_as; -use tokio::{ - io::AsyncReadExt, - net::{TcpListener, TcpStream}, -}; +use tokio::net::{TcpListener, TcpStream}; use tokio_util::sync::CancellationToken; use tracing::{debug, error, error_span, info, trace, warn, Instrument}; use crate::{ dht_utils::{read_metainfo_from_peer_receiver, ReadMetainfoResult}, - peer_connection::{with_timeout, PeerConnectionOptions}, + peer_connection::PeerConnectionOptions, + read_buf::ReadBuf, spawn_utils::BlockingSpawner, torrent_state::{ ManagedTorrentBuilder, ManagedTorrentHandle, ManagedTorrentState, TorrentStateLive, @@ -49,6 +49,14 @@ pub const SUPPORTED_SCHEMES: [&str; 3] = ["http:", "https:", "magnet:"]; pub type TorrentId = usize; +fn torrent_from_bytes(bytes: &[u8]) -> anyhow::Result { + debug!( + "all fields in torrent: {:#?}", + bencode::dyn_from_bytes::(bytes) + ); + bencode_torrent_from_bytes(bytes) +} + #[derive(Default)] pub struct SessionDatabase { next_id: TorrentId, @@ -361,9 +369,8 @@ async fn create_tcp_listener( pub(crate) struct CheckedIncomingConnection { pub addr: SocketAddr, pub stream: tokio::net::TcpStream, - pub read_buf: Vec, + pub read_buf: ReadBuf, pub handshake: Handshake, - pub read_so_far: usize, } impl Session { @@ -505,16 +512,11 @@ impl Session { .read_write_timeout .unwrap_or_else(|| Duration::from_secs(10)); - let mut read_buf = vec![0u8; PIECE_MESSAGE_DEFAULT_LEN * 2]; - let mut read_so_far = with_timeout(rwtimeout, stream.read(&mut read_buf)) + let mut read_buf = ReadBuf::new(); + let h = read_buf + .read_handshake(&mut stream, rwtimeout) .await .context("error reading handshake")?; - if read_so_far == 0 { - anyhow::bail!("bad handshake"); - } - let (h, size) = Handshake::deserialize(&read_buf[..read_so_far]) - .map_err(|e| anyhow::anyhow!("error deserializing handshake: {:?}", e))?; - trace!("received handshake from {addr}: {:?}", h); if h.peer_id == self.peer_id.0 { @@ -535,11 +537,6 @@ impl Session { let handshake = h.clone_to_owned(); - if read_so_far > size { - read_buf.copy_within(size..read_so_far, 0); - } - read_so_far -= size; - return Ok(( live, CheckedIncomingConnection { @@ -547,7 +544,6 @@ impl Session { stream, handshake, read_buf, - read_so_far, }, )); } diff --git a/crates/librqbit/src/torrent_state/live/mod.rs b/crates/librqbit/src/torrent_state/live/mod.rs index 72a7db5c..83c459ad 100644 --- a/crates/librqbit/src/torrent_state/live/mod.rs +++ b/crates/librqbit/src/torrent_state/live/mod.rs @@ -457,7 +457,6 @@ impl TorrentStateLive { r = requester => {r} r = peer_connection.manage_peer_incoming( rx, - checked_peer.read_so_far, checked_peer.read_buf, checked_peer.handshake, checked_peer.stream