Skip to content

Commit

Permalink
Merge pull request #74 from ikatson/small-refactor
Browse files Browse the repository at this point in the history
Small refactor - introduce ReadBuf
  • Loading branch information
ikatson authored Jan 2, 2024
2 parents 396bacf + 8ee9854 commit 51ed57a
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 97 deletions.
1 change: 1 addition & 0 deletions crates/librqbit/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pub mod http_api;
pub mod http_api_client;
mod peer_connection;
mod peer_info_reader;
mod read_buf;
mod session;
mod spawn_utils;
mod torrent_state;
Expand Down
98 changes: 26 additions & 72 deletions crates/librqbit/src/peer_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,14 @@ use librqbit_core::{id20::Id20, lengths::ChunkInfo, peer_id::try_decode_peer_id}
use parking_lot::RwLock;
use peer_binary_protocol::{
extended::{handshake::ExtendedHandshake, ExtendedMessage},
serialize_piece_preamble, Handshake, Message, MessageBorrowed, MessageDeserializeError,
MessageOwned, PIECE_MESSAGE_DEFAULT_LEN,
serialize_piece_preamble, Handshake, Message, MessageOwned, PIECE_MESSAGE_DEFAULT_LEN,
};
use serde::{Deserialize, Serialize};
use serde_with::serde_as;
use tokio::time::timeout;
use tracing::trace;

use crate::spawn_utils::BlockingSpawner;
use crate::{read_buf::ReadBuf, spawn_utils::BlockingSpawner};

pub trait PeerConnectionHandler {
fn on_connected(&self, _connection_time: Duration) {}
Expand Down Expand Up @@ -76,32 +75,6 @@ where
}
}

macro_rules! read_one {
($conn:ident, $read_buf:ident, $read_so_far:ident, $rwtimeout:ident) => {{
let (extended, size) = loop {
match MessageBorrowed::deserialize(&$read_buf[..$read_so_far]) {
Ok((msg, size)) => break (msg, size),
Err(MessageDeserializeError::NotEnoughData(d, _)) => {
if $read_buf.len() < $read_so_far + d {
$read_buf.reserve(d);
$read_buf.resize($read_buf.capacity(), 0);
}

let size = with_timeout($rwtimeout, $conn.read(&mut $read_buf[$read_so_far..]))
.await
.context("error reading from peer")?;
if size == 0 {
anyhow::bail!("disconnected while reading, read so far: {}", $read_so_far)
}
$read_so_far += size;
}
Err(e) => return Err(e.into()),
}
};
(extended, size)
}};
}

impl<H: PeerConnectionHandler> PeerConnection<H> {
pub fn new(
addr: SocketAddr,
Expand All @@ -126,9 +99,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
pub async fn manage_peer_incoming(
&self,
outgoing_chan: tokio::sync::mpsc::UnboundedReceiver<WriterRequest>,
// How many bytes into read buffer have we read already.
read_so_far: usize,
read_buf: Vec<u8>,
read_buf: ReadBuf,
handshake: Handshake<ByteString>,
mut conn: tokio::net::TcpStream,
) -> anyhow::Result<()> {
Expand Down Expand Up @@ -166,7 +137,6 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {

self.manage_peer(
h_supports_extended,
read_so_far,
read_buf,
write_buf,
conn,
Expand All @@ -179,7 +149,6 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
&self,
outgoing_chan: tokio::sync::mpsc::UnboundedReceiver<WriterRequest>,
) -> anyhow::Result<()> {
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;

let rwtimeout = self
Expand All @@ -206,16 +175,11 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
.context("error writing handshake")?;
write_buf.clear();

let mut read_buf = vec![0u8; PIECE_MESSAGE_DEFAULT_LEN * 2];
let mut read_so_far = with_timeout(rwtimeout, conn.read(&mut read_buf))
let mut read_buf = ReadBuf::new();
let h = read_buf
.read_handshake(&mut conn, rwtimeout)
.await
.context("error reading handshake")?;
if read_so_far == 0 {
anyhow::bail!("bad handshake");
}
let (h, size) = Handshake::deserialize(&read_buf[..read_so_far])
.map_err(|e| anyhow::anyhow!("error deserializing handshake: {:?}", e))?;

let h_supports_extended = h.supports_extended();
trace!("connected: id={:?}", try_decode_peer_id(Id20(h.peer_id)));
if h.info_hash != self.info_hash.0 {
Expand All @@ -228,14 +192,8 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {

self.handler.on_handshake(h)?;

if read_so_far > size {
read_buf.copy_within(size..read_so_far, 0);
}
read_so_far -= size;

self.manage_peer(
h_supports_extended,
read_so_far,
read_buf,
write_buf,
conn,
Expand All @@ -247,14 +205,11 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
async fn manage_peer(
&self,
handshake_supports_extended: bool,
// How many bytes into read_buf is there of peer-sent-data.
mut read_so_far: usize,
mut read_buf: Vec<u8>,
mut read_buf: ReadBuf,
mut write_buf: Vec<u8>,
mut conn: tokio::net::TcpStream,
mut outgoing_chan: tokio::sync::mpsc::UnboundedReceiver<WriterRequest>,
) -> anyhow::Result<()> {
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;

let rwtimeout = self
Expand Down Expand Up @@ -354,31 +309,31 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {

let reader = async move {
loop {
let (message, size) = read_one!(read_half, read_buf, read_so_far, rwtimeout);
trace!("received: {:?}", &message);

if let Message::Extended(ExtendedMessage::Handshake(h)) = &message {
*extended_handshake_ref.write() = Some(h.clone_to_owned());
self.handler.on_extended_handshake(h)?;
trace!("remembered extended handshake for future serializing");
} else {
self.handler
.on_received_message(message)
.context("error in handler.on_received_message()")?;
}

if read_so_far > size {
read_buf.copy_within(size..read_so_far, 0);
}
read_so_far -= size;
read_buf
.read_message(&mut read_half, rwtimeout, |message| {
trace!("received: {:?}", &message);

if let Message::Extended(ExtendedMessage::Handshake(h)) = &message {
*extended_handshake_ref.write() = Some(h.clone_to_owned());
self.handler.on_extended_handshake(h)?;
trace!("remembered extended handshake for future serializing");
} else {
self.handler
.on_received_message(message)
.context("error in handler.on_received_message()")?;
}
Ok(())
})
.await
.context("error reading message")?;
}

// For type inference.
#[allow(unreachable_code)]
Ok::<_, anyhow::Error>(())
};

let r = tokio::select! {
tokio::select! {
r = reader => {
trace!("reader is done, exiting");
r
Expand All @@ -387,7 +342,6 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
trace!("writer is done, exiting");
r
}
};
r
}
}
}
96 changes: 96 additions & 0 deletions crates/librqbit/src/read_buf.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
use std::time::Duration;

use crate::peer_connection::with_timeout;
use anyhow::Context;
use buffers::ByteBuf;
use peer_binary_protocol::{
Handshake, MessageBorrowed, MessageDeserializeError, PIECE_MESSAGE_DEFAULT_LEN,
};
use tokio::io::AsyncReadExt;

pub struct ReadBuf {
buf: Vec<u8>,
// How many bytes into the buffer we have read from the connection.
// New reads should go past this.
filled: usize,
// How many bytes have we successfully deserialized.
processed: usize,
}

impl ReadBuf {
pub fn new() -> Self {
Self {
buf: vec![0; PIECE_MESSAGE_DEFAULT_LEN * 2],
filled: 0,
processed: 0,
}
}

fn prepare_for_read(&mut self, need_additional_bytes: usize) {
// Ensure the buffer starts from the to-be-deserialized message.
if self.processed > 0 {
if self.filled > self.processed {
self.buf.copy_within(self.processed..self.filled, 0);
}
self.filled -= self.processed;
self.processed = 0;
}

// Ensure we have enough capacity to deserialize the message.
if self.buf.len() < self.filled + need_additional_bytes {
self.buf.reserve(need_additional_bytes);
self.buf.resize(self.buf.capacity(), 0);
}
}

// Read the BT handshake.
// This MUST be run as the first operation on the buffer.
pub async fn read_handshake(
&mut self,
mut conn: impl AsyncReadExt + Unpin,
timeout: Duration,
) -> anyhow::Result<Handshake<ByteBuf<'_>>> {
self.filled = with_timeout(timeout, conn.read(&mut self.buf))
.await
.context("error reading handshake")?;
if self.filled == 0 {
anyhow::bail!("peer disconnected while reading handshake");
}
let (h, size) = Handshake::deserialize(&self.buf[..self.filled])
.map_err(|e| anyhow::anyhow!("error deserializing handshake: {:?}", e))?;
self.processed = size;
Ok(h)
}

// Read a message into the buffer, try to deserialize it and call the callback on it.
// We can't return the message because of a borrow checker issue.
pub async fn read_message(
&mut self,
mut conn: impl AsyncReadExt + Unpin,
timeout: Duration,
on_message: impl for<'a> FnOnce(MessageBorrowed<'a>) -> anyhow::Result<()>,
) -> anyhow::Result<()> {
loop {
let need_additional_bytes =
match MessageBorrowed::deserialize(&self.buf[self.processed..self.filled]) {
Err(MessageDeserializeError::NotEnoughData(d, _)) => d,
Ok((msg, size)) => {
self.processed += size;
// Rust's borrow checker can't do this early return. So we are using a callback instead.
// return Ok(msg);
on_message(msg)?;
return Ok(());
}
Err(e) => return Err(e.into()),
};
self.prepare_for_read(need_additional_bytes);
let size = with_timeout(timeout, conn.read(&mut self.buf[self.filled..]))
.await
.context("error reading from peer")?;
if size == 0 {
anyhow::bail!("disconnected while reading, read so far: {}", self.filled)
}
self.filled += size;
}
}
}
44 changes: 20 additions & 24 deletions crates/librqbit/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use std::{

use anyhow::{bail, Context};
use bencode::{bencode_serialize_to_writer, BencodeDeserializer};
use buffers::{ByteBufT, ByteString};
use buffers::{ByteBuf, ByteBufT, ByteString};
use clone_to_owned::CloneToOwned;
use dht::{
Dht, DhtBuilder, DhtConfig, Id20, PersistentDht, PersistentDhtConfig, RequestPeersStream,
Expand All @@ -22,23 +22,23 @@ use librqbit_core::{
magnet::Magnet,
peer_id::generate_peer_id,
spawn_utils::spawn_with_cancel,
torrent_metainfo::{torrent_from_bytes, TorrentMetaV1Info, TorrentMetaV1Owned},
torrent_metainfo::{
torrent_from_bytes as bencode_torrent_from_bytes, TorrentMetaV1Info, TorrentMetaV1Owned,
},
};
use parking_lot::RwLock;
use peer_binary_protocol::{Handshake, PIECE_MESSAGE_DEFAULT_LEN};
use peer_binary_protocol::Handshake;
use reqwest::Url;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde_with::serde_as;
use tokio::{
io::AsyncReadExt,
net::{TcpListener, TcpStream},
};
use tokio::net::{TcpListener, TcpStream};
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, error_span, info, trace, warn, Instrument};

use crate::{
dht_utils::{read_metainfo_from_peer_receiver, ReadMetainfoResult},
peer_connection::{with_timeout, PeerConnectionOptions},
peer_connection::PeerConnectionOptions,
read_buf::ReadBuf,
spawn_utils::BlockingSpawner,
torrent_state::{
ManagedTorrentBuilder, ManagedTorrentHandle, ManagedTorrentState, TorrentStateLive,
Expand All @@ -49,6 +49,14 @@ pub const SUPPORTED_SCHEMES: [&str; 3] = ["http:", "https:", "magnet:"];

pub type TorrentId = usize;

fn torrent_from_bytes(bytes: &[u8]) -> anyhow::Result<TorrentMetaV1Owned> {
debug!(
"all fields in torrent: {:#?}",
bencode::dyn_from_bytes::<ByteBuf>(bytes)
);
bencode_torrent_from_bytes(bytes)
}

#[derive(Default)]
pub struct SessionDatabase {
next_id: TorrentId,
Expand Down Expand Up @@ -361,9 +369,8 @@ async fn create_tcp_listener(
pub(crate) struct CheckedIncomingConnection {
pub addr: SocketAddr,
pub stream: tokio::net::TcpStream,
pub read_buf: Vec<u8>,
pub read_buf: ReadBuf,
pub handshake: Handshake<ByteString>,
pub read_so_far: usize,
}

impl Session {
Expand Down Expand Up @@ -505,16 +512,11 @@ impl Session {
.read_write_timeout
.unwrap_or_else(|| Duration::from_secs(10));

let mut read_buf = vec![0u8; PIECE_MESSAGE_DEFAULT_LEN * 2];
let mut read_so_far = with_timeout(rwtimeout, stream.read(&mut read_buf))
let mut read_buf = ReadBuf::new();
let h = read_buf
.read_handshake(&mut stream, rwtimeout)
.await
.context("error reading handshake")?;
if read_so_far == 0 {
anyhow::bail!("bad handshake");
}
let (h, size) = Handshake::deserialize(&read_buf[..read_so_far])
.map_err(|e| anyhow::anyhow!("error deserializing handshake: {:?}", e))?;

trace!("received handshake from {addr}: {:?}", h);

if h.peer_id == self.peer_id.0 {
Expand All @@ -535,19 +537,13 @@ impl Session {

let handshake = h.clone_to_owned();

if read_so_far > size {
read_buf.copy_within(size..read_so_far, 0);
}
read_so_far -= size;

return Ok((
live,
CheckedIncomingConnection {
addr,
stream,
handshake,
read_buf,
read_so_far,
},
));
}
Expand Down
Loading

0 comments on commit 51ed57a

Please sign in to comment.