From b33dda7bdc897a63984f8e3e8dfafc9e0ee0d801 Mon Sep 17 00:00:00 2001 From: ivanjermakov Date: Thu, 26 Oct 2023 01:21:44 +0200 Subject: [PATCH] Reconnect to peers; peer & torrent status --- src/peer.rs | 115 +++++++++++++++++++++++++++++++++++++++++-------- src/state.rs | 31 ++++++++++--- src/torrent.rs | 48 ++++++++++++--------- 3 files changed, 151 insertions(+), 43 deletions(-) diff --git a/src/peer.rs b/src/peer.rs index 4002ccb..8bfe613 100644 --- a/src/peer.rs +++ b/src/peer.rs @@ -8,7 +8,7 @@ use tokio::{ tcp::{OwnedReadHalf, OwnedWriteHalf}, TcpStream, }, - select, + select, spawn, sync::Mutex, time::{sleep, timeout}, }; @@ -16,7 +16,7 @@ use tokio::{ use crate::{ hex::hex, sha1, - state::{Block, Peer, PeerInfo, State, BLOCK_SIZE}, + state::{Block, Peer, PeerInfo, PeerStatus, State, TorrentStatus, BLOCK_SIZE}, types::ByteString, }; @@ -159,7 +159,7 @@ pub async fn handshake( ) -> Result { debug!("connecting to peer {peer:?}"); let mut stream = timeout( - Duration::new(4, 0), + Duration::from_secs(4), TcpStream::connect(format!("{}:{}", peer.ip, peer.port)), ) .await??; @@ -169,7 +169,7 @@ pub async fn handshake( } .into(); - debug!("writing handshake {}", hex(&handshake.to_vec())); + trace!("writing handshake {}", hex(&handshake.to_vec())); stream.write_all(&handshake).await.context("write error")?; stream.flush().await?; @@ -259,12 +259,12 @@ pub async fn read_message(stream: &mut OwnedReadHalf) -> Result { } } }?; - debug!("<<< read message: {:?}", msg); + trace!("<<< read message: {:?}", msg); Ok(msg) } pub async fn send_message(stream: &mut OwnedWriteHalf, message: Message) -> Result<()> { - debug!(">>> sending message: {:?}", message); + trace!(">>> sending message: {:?}", message); let msg_p: Vec = message.into(); trace!("raw message: {}", hex(&msg_p)); stream.write_all(&msg_p).await?; @@ -272,12 +272,88 @@ pub async fn send_message(stream: &mut OwnedWriteHalf, message: Message) -> Resu Ok(()) } +pub async fn peer_loop(state: Arc>) -> Result<()> { + let mut handles = vec![]; + loop { + debug!("reconnecting peers"); + let peers: Vec = state + .lock() + .await + .peers + .values() + .filter(|p| p.status == PeerStatus::Disconnected) + .map(|p| p.info.clone()) + .collect(); + trace!("disconnected peers: {}", peers.len()); + peers.into_iter().for_each(|p| { + let state = state.clone(); + handles.push(spawn(async { + if let Err(e) = handle_peer(p, state).await.context("peer error") { + debug!("{e:#}"); + }; + })); + }); + select!( + _ = async { + loop { + if state.lock().await.status == TorrentStatus::Downloaded { + return; + } + sleep(Duration::from_millis(1000)).await + } + } => { + // this is important to ensure that no tasks hold Arc reference + trace!("closing {} peer connections", handles.len()); + for h in handles { + h.abort(); + let _ = h.await; + } + trace!("peer connections closed"); + return Ok(()) + }, + _ = sleep(Duration::from_secs(10)) => () + ); + } +} + pub async fn handle_peer(peer: PeerInfo, state: Arc>) -> Result<()> { - let (info_hash, peer_id) = { + { + debug!("connecting to peer: {:?}", peer); let mut state = state.lock().await; - state - .peers - .insert(peer.peer_id.clone(), Peer::new(peer.clone())); + match state.peers.get_mut(&peer.peer_id) { + Some(p) if p.status == PeerStatus::Connected => { + return Err(Error::msg("peer is already connected")) + } + Some(p) => p.status = PeerStatus::Connected, + None => { + let mut p = Peer::new(peer.clone()); + p.status = PeerStatus::Connected; + state.peers.insert(peer.peer_id.clone(), p); + } + }; + }; + + let res = do_handle_peer(peer.clone(), state.clone()).await; + + debug!("peer disconnected: {:?}", peer); + state + .lock() + .await + .peers + .get_mut(&peer.peer_id) + .context("no peer")? + .status = if res.is_err() { + PeerStatus::Disconnected + } else { + PeerStatus::Done + }; + + res +} + +pub async fn do_handle_peer(peer: PeerInfo, state: Arc>) -> Result<()> { + let (info_hash, peer_id) = { + let state = state.lock().await; (state.info_hash.clone(), state.peer_id.clone()) }; let stream = handshake(&peer, &info_hash, &peer_id) @@ -286,7 +362,7 @@ pub async fn handle_peer(peer: PeerInfo, state: Arc>) -> Result<()> info!("successfull handshake with peer {:?}", peer); if let Some(p) = state.lock().await.peers.get_mut(&peer.peer_id) { - p.connected = true; + p.status = PeerStatus::Connected; } let (r_stream, mut w_stream) = stream.into_split(); @@ -323,13 +399,15 @@ pub async fn write_loop( } _ => debug!("no peer {:?}", peer), } - let piece = match { state.lock().await.next_piece() } { + + let piece = match state.lock().await.next_piece() { Some(p) => p, - None => { + _ => { debug!("no more pieces to request, disconnecting"); return Ok(()); } }; + debug!("next request piece: {:?}", piece); let total_blocks = piece.total_blocks(); @@ -345,7 +423,7 @@ pub async fn write_loop( }; send_message(&mut stream, request_msg).await?; } - sleep(Duration::new(0, 100e6 as u32)).await; + sleep(Duration::from_millis(100)).await; } } @@ -383,15 +461,18 @@ async fn read_loop( } }; if piece.completed { + debug!("downloaded block of already completed piece, loss"); continue; } let total_blocks = piece.total_blocks(); if block_index != total_blocks - 1 && block.0.len() != BLOCK_SIZE as usize { - warn!("block of unexpected size: {}", block.0.len()); + debug!("block of unexpected size: {}", block.0.len()); continue; } - piece.blocks.insert(block_index, block); - debug!("got block {}/{}", piece.blocks.len(), total_blocks); + if piece.blocks.insert(block_index, block).is_some() { + debug!("repeaded block download, loss"); + }; + trace!("got block {}/{}", piece.blocks.len(), total_blocks); if piece.blocks.len() as u32 == total_blocks { let piece_data: Vec = piece .blocks diff --git a/src/state.rs b/src/state.rs index e192a6b..13abc6b 100644 --- a/src/state.rs +++ b/src/state.rs @@ -18,18 +18,32 @@ pub struct State { pub peer_id: Vec, pub pieces: BTreeMap, pub peers: BTreeMap, + pub status: TorrentStatus, } impl State { - pub fn next_piece(&self) -> Option { - self.pieces + pub fn next_piece(&mut self) -> Option { + let piece = self + .pieces .values() .filter(|p| !p.completed) .choose(&mut thread_rng()) - .cloned() + .cloned(); + if piece.is_none() { + debug!("torrent is completed"); + self.status = TorrentStatus::Downloaded; + } + piece } } +#[derive(Clone, Debug, PartialEq, PartialOrd, Hash)] +pub enum TorrentStatus { + Started, + Downloaded, + Saved, +} + #[derive(Clone, Debug, PartialEq, PartialOrd, Hash)] pub struct Piece { pub hash: PieceHash, @@ -67,7 +81,7 @@ impl fmt::Debug for Block { #[derive(Clone, Debug, PartialEq, PartialOrd, Hash)] pub struct Peer { pub info: PeerInfo, - pub connected: bool, + pub status: PeerStatus, pub am_choked: bool, pub am_interested: bool, pub choked: bool, @@ -79,7 +93,7 @@ impl Peer { pub fn new(info: PeerInfo) -> Peer { Peer { info, - connected: false, + status: PeerStatus::Disconnected, am_choked: true, am_interested: false, choked: true, @@ -89,6 +103,13 @@ impl Peer { } } +#[derive(Clone, Debug, PartialEq, PartialOrd, Hash)] +pub enum PeerStatus { + Disconnected, + Connected, + Done, +} + #[derive(Clone, PartialEq, PartialOrd, Hash)] pub struct PeerInfo { pub peer_id: ByteString, diff --git a/src/torrent.rs b/src/torrent.rs index c23ca12..8c79098 100644 --- a/src/torrent.rs +++ b/src/torrent.rs @@ -1,15 +1,16 @@ use anyhow::{anyhow, ensure, Context, Error, Result}; use futures::future; use std::path::Path; -use std::{collections::BTreeMap, fs, path::PathBuf, sync::Arc}; +use std::{fs, path::PathBuf, sync::Arc}; use tokio::{spawn, sync::Mutex}; use crate::{ bencode::{parse_bencoded, BencodeValue}, metainfo::{FileInfo, Metainfo, PathInfo}, - peer::handle_peer, + peer::peer_loop, sha1, state::{init_pieces, State}, + state::{Peer, TorrentStatus}, tracker::{tracker_request, TrackerEvent, TrackerRequest, TrackerResponse}, types::ByteString, }; @@ -45,25 +46,27 @@ pub async fn download_torrent(path: &Path, peer_id: &ByteString) -> Result<()> { .context("request failed")?; info!("tracker response: {tracker_response:?}"); + let resp = match tracker_response { + TrackerResponse::Success(r) => r, + TrackerResponse::Failure { failure_reason } => return Err(Error::msg(failure_reason)), + }; + let state = Arc::new(Mutex::new(State { metainfo: metainfo.clone(), info_hash, peer_id: peer_id.to_vec(), pieces: init_pieces(&metainfo.info), - peers: BTreeMap::new(), + peers: resp + .peers + .into_iter() + .map(|p| (p.peer_id.clone(), Peer::new(p))) + .collect(), + status: TorrentStatus::Started, })); + trace!("init state: {:?}", state); - let resp = match tracker_response { - TrackerResponse::Success(r) => r, - TrackerResponse::Failure { failure_reason } => return Err(Error::msg(failure_reason)), - }; - future::join_all( - resp.peers - .into_iter() - .map(|p| spawn(handle_peer(p, state.clone()))) - .collect::>(), - ) - .await; + debug!("connecting to peers"); + peer_loop(state.clone()).await?; trace!("unwrapping state"); let state = Arc::try_unwrap(state) @@ -80,33 +83,35 @@ pub async fn download_torrent(path: &Path, peer_id: &ByteString) -> Result<()> { "incomplete pieces" ); - write_to_disk(state, metainfo).await?; + info!("writing files to disk"); + write_to_disk(state).await?; Ok(()) } -async fn write_to_disk(state: State, metainfo: Metainfo) -> Result<()> { - info!("partitioning pieces into files"); +async fn write_to_disk(mut state: State) -> Result<()> { + debug!("partitioning pieces into files"); let mut data: Vec = state .pieces .into_values() .flat_map(|p| p.blocks.into_values().flat_map(|b| b.0)) .collect(); - let files = match metainfo.info.file_info { + let files = match state.metainfo.info.file_info { FileInfo::Single { length, md5_sum } => vec![PathInfo { length, - path: PathBuf::from(&metainfo.info.name), + path: PathBuf::from(&state.metainfo.info.name), md5_sum, }], FileInfo::Multi { files } => files, }; + // TODO: check files md5_sum info!("writing files"); let mut write_handles = vec![]; for file in files { let file_data = data.drain(0..file.length as usize).collect(); let path = PathBuf::from("download") - .join(&metainfo.info.name) + .join(&state.metainfo.info.name) .join(file.path.clone()); write_handles.push(spawn(write_file(path, file_data))) } @@ -115,7 +120,8 @@ async fn write_to_disk(state: State, metainfo: Metainfo) -> Result<()> { return Err(Error::msg("file write errors")); } - info!("torrent downloaded: {}", metainfo.info.name); + state.status = TorrentStatus::Saved; + info!("torrent downloaded: {}", state.metainfo.info.name); Ok(()) }