Skip to content

Commit

Permalink
Reconnect to peers; peer & torrent status
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanjermakov committed Oct 25, 2023
1 parent 3268df5 commit b33dda7
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 43 deletions.
115 changes: 98 additions & 17 deletions src/peer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ use tokio::{
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpStream,
},
select,
select, spawn,
sync::Mutex,
time::{sleep, timeout},
};

use crate::{
hex::hex,
sha1,
state::{Block, Peer, PeerInfo, State, BLOCK_SIZE},
state::{Block, Peer, PeerInfo, PeerStatus, State, TorrentStatus, BLOCK_SIZE},
types::ByteString,
};

Expand Down Expand Up @@ -159,7 +159,7 @@ pub async fn handshake(
) -> Result<TcpStream> {
debug!("connecting to peer {peer:?}");
let mut stream = timeout(
Duration::new(4, 0),
Duration::from_secs(4),
TcpStream::connect(format!("{}:{}", peer.ip, peer.port)),
)
.await??;
Expand All @@ -169,7 +169,7 @@ pub async fn handshake(
}
.into();

debug!("writing handshake {}", hex(&handshake.to_vec()));
trace!("writing handshake {}", hex(&handshake.to_vec()));
stream.write_all(&handshake).await.context("write error")?;
stream.flush().await?;

Expand Down Expand Up @@ -259,25 +259,101 @@ pub async fn read_message(stream: &mut OwnedReadHalf) -> Result<Message> {
}
}
}?;
debug!("<<< read message: {:?}", msg);
trace!("<<< read message: {:?}", msg);
Ok(msg)
}

pub async fn send_message(stream: &mut OwnedWriteHalf, message: Message) -> Result<()> {
debug!(">>> sending message: {:?}", message);
trace!(">>> sending message: {:?}", message);
let msg_p: Vec<u8> = message.into();
trace!("raw message: {}", hex(&msg_p));
stream.write_all(&msg_p).await?;
stream.flush().await?;
Ok(())
}

pub async fn peer_loop(state: Arc<Mutex<State>>) -> Result<()> {
let mut handles = vec![];
loop {
debug!("reconnecting peers");
let peers: Vec<PeerInfo> = state
.lock()
.await
.peers
.values()
.filter(|p| p.status == PeerStatus::Disconnected)
.map(|p| p.info.clone())
.collect();
trace!("disconnected peers: {}", peers.len());
peers.into_iter().for_each(|p| {
let state = state.clone();
handles.push(spawn(async {
if let Err(e) = handle_peer(p, state).await.context("peer error") {
debug!("{e:#}");
};
}));
});
select!(
_ = async {
loop {
if state.lock().await.status == TorrentStatus::Downloaded {
return;
}
sleep(Duration::from_millis(1000)).await
}
} => {
// this is important to ensure that no tasks hold Arc<State> reference
trace!("closing {} peer connections", handles.len());
for h in handles {
h.abort();
let _ = h.await;
}
trace!("peer connections closed");
return Ok(())
},
_ = sleep(Duration::from_secs(10)) => ()
);
}
}

pub async fn handle_peer(peer: PeerInfo, state: Arc<Mutex<State>>) -> Result<()> {
let (info_hash, peer_id) = {
{
debug!("connecting to peer: {:?}", peer);
let mut state = state.lock().await;
state
.peers
.insert(peer.peer_id.clone(), Peer::new(peer.clone()));
match state.peers.get_mut(&peer.peer_id) {
Some(p) if p.status == PeerStatus::Connected => {
return Err(Error::msg("peer is already connected"))
}
Some(p) => p.status = PeerStatus::Connected,
None => {
let mut p = Peer::new(peer.clone());
p.status = PeerStatus::Connected;
state.peers.insert(peer.peer_id.clone(), p);
}
};
};

let res = do_handle_peer(peer.clone(), state.clone()).await;

debug!("peer disconnected: {:?}", peer);
state
.lock()
.await
.peers
.get_mut(&peer.peer_id)
.context("no peer")?
.status = if res.is_err() {
PeerStatus::Disconnected
} else {
PeerStatus::Done
};

res
}

pub async fn do_handle_peer(peer: PeerInfo, state: Arc<Mutex<State>>) -> Result<()> {
let (info_hash, peer_id) = {
let state = state.lock().await;
(state.info_hash.clone(), state.peer_id.clone())
};
let stream = handshake(&peer, &info_hash, &peer_id)
Expand All @@ -286,7 +362,7 @@ pub async fn handle_peer(peer: PeerInfo, state: Arc<Mutex<State>>) -> Result<()>
info!("successfull handshake with peer {:?}", peer);

if let Some(p) = state.lock().await.peers.get_mut(&peer.peer_id) {
p.connected = true;
p.status = PeerStatus::Connected;
}

let (r_stream, mut w_stream) = stream.into_split();
Expand Down Expand Up @@ -323,13 +399,15 @@ pub async fn write_loop(
}
_ => debug!("no peer {:?}", peer),
}
let piece = match { state.lock().await.next_piece() } {

let piece = match state.lock().await.next_piece() {
Some(p) => p,
None => {
_ => {
debug!("no more pieces to request, disconnecting");
return Ok(());
}
};

debug!("next request piece: {:?}", piece);
let total_blocks = piece.total_blocks();

Expand All @@ -345,7 +423,7 @@ pub async fn write_loop(
};
send_message(&mut stream, request_msg).await?;
}
sleep(Duration::new(0, 100e6 as u32)).await;
sleep(Duration::from_millis(100)).await;
}
}

Expand Down Expand Up @@ -383,15 +461,18 @@ async fn read_loop(
}
};
if piece.completed {
debug!("downloaded block of already completed piece, loss");
continue;
}
let total_blocks = piece.total_blocks();
if block_index != total_blocks - 1 && block.0.len() != BLOCK_SIZE as usize {
warn!("block of unexpected size: {}", block.0.len());
debug!("block of unexpected size: {}", block.0.len());
continue;
}
piece.blocks.insert(block_index, block);
debug!("got block {}/{}", piece.blocks.len(), total_blocks);
if piece.blocks.insert(block_index, block).is_some() {
debug!("repeaded block download, loss");
};
trace!("got block {}/{}", piece.blocks.len(), total_blocks);
if piece.blocks.len() as u32 == total_blocks {
let piece_data: Vec<u8> = piece
.blocks
Expand Down
31 changes: 26 additions & 5 deletions src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,32 @@ pub struct State {
pub peer_id: Vec<u8>,
pub pieces: BTreeMap<u32, Piece>,
pub peers: BTreeMap<ByteString, Peer>,
pub status: TorrentStatus,
}

impl State {
pub fn next_piece(&self) -> Option<Piece> {
self.pieces
pub fn next_piece(&mut self) -> Option<Piece> {
let piece = self
.pieces
.values()
.filter(|p| !p.completed)
.choose(&mut thread_rng())
.cloned()
.cloned();
if piece.is_none() {
debug!("torrent is completed");
self.status = TorrentStatus::Downloaded;
}
piece
}
}

#[derive(Clone, Debug, PartialEq, PartialOrd, Hash)]
pub enum TorrentStatus {
Started,
Downloaded,
Saved,
}

#[derive(Clone, Debug, PartialEq, PartialOrd, Hash)]
pub struct Piece {
pub hash: PieceHash,
Expand Down Expand Up @@ -67,7 +81,7 @@ impl fmt::Debug for Block {
#[derive(Clone, Debug, PartialEq, PartialOrd, Hash)]
pub struct Peer {
pub info: PeerInfo,
pub connected: bool,
pub status: PeerStatus,
pub am_choked: bool,
pub am_interested: bool,
pub choked: bool,
Expand All @@ -79,7 +93,7 @@ impl Peer {
pub fn new(info: PeerInfo) -> Peer {
Peer {
info,
connected: false,
status: PeerStatus::Disconnected,
am_choked: true,
am_interested: false,
choked: true,
Expand All @@ -89,6 +103,13 @@ impl Peer {
}
}

#[derive(Clone, Debug, PartialEq, PartialOrd, Hash)]
pub enum PeerStatus {
Disconnected,
Connected,
Done,
}

#[derive(Clone, PartialEq, PartialOrd, Hash)]
pub struct PeerInfo {
pub peer_id: ByteString,
Expand Down
48 changes: 27 additions & 21 deletions src/torrent.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
use anyhow::{anyhow, ensure, Context, Error, Result};
use futures::future;
use std::path::Path;
use std::{collections::BTreeMap, fs, path::PathBuf, sync::Arc};
use std::{fs, path::PathBuf, sync::Arc};
use tokio::{spawn, sync::Mutex};

use crate::{
bencode::{parse_bencoded, BencodeValue},
metainfo::{FileInfo, Metainfo, PathInfo},
peer::handle_peer,
peer::peer_loop,
sha1,
state::{init_pieces, State},
state::{Peer, TorrentStatus},
tracker::{tracker_request, TrackerEvent, TrackerRequest, TrackerResponse},
types::ByteString,
};
Expand Down Expand Up @@ -45,25 +46,27 @@ pub async fn download_torrent(path: &Path, peer_id: &ByteString) -> Result<()> {
.context("request failed")?;
info!("tracker response: {tracker_response:?}");

let resp = match tracker_response {
TrackerResponse::Success(r) => r,
TrackerResponse::Failure { failure_reason } => return Err(Error::msg(failure_reason)),
};

let state = Arc::new(Mutex::new(State {
metainfo: metainfo.clone(),
info_hash,
peer_id: peer_id.to_vec(),
pieces: init_pieces(&metainfo.info),
peers: BTreeMap::new(),
peers: resp
.peers
.into_iter()
.map(|p| (p.peer_id.clone(), Peer::new(p)))
.collect(),
status: TorrentStatus::Started,
}));
trace!("init state: {:?}", state);

let resp = match tracker_response {
TrackerResponse::Success(r) => r,
TrackerResponse::Failure { failure_reason } => return Err(Error::msg(failure_reason)),
};
future::join_all(
resp.peers
.into_iter()
.map(|p| spawn(handle_peer(p, state.clone())))
.collect::<Vec<_>>(),
)
.await;
debug!("connecting to peers");
peer_loop(state.clone()).await?;

trace!("unwrapping state");
let state = Arc::try_unwrap(state)
Expand All @@ -80,33 +83,35 @@ pub async fn download_torrent(path: &Path, peer_id: &ByteString) -> Result<()> {
"incomplete pieces"
);

write_to_disk(state, metainfo).await?;
info!("writing files to disk");
write_to_disk(state).await?;
Ok(())
}

async fn write_to_disk(state: State, metainfo: Metainfo) -> Result<()> {
info!("partitioning pieces into files");
async fn write_to_disk(mut state: State) -> Result<()> {
debug!("partitioning pieces into files");
let mut data: Vec<u8> = state
.pieces
.into_values()
.flat_map(|p| p.blocks.into_values().flat_map(|b| b.0))
.collect();

let files = match metainfo.info.file_info {
let files = match state.metainfo.info.file_info {
FileInfo::Single { length, md5_sum } => vec![PathInfo {
length,
path: PathBuf::from(&metainfo.info.name),
path: PathBuf::from(&state.metainfo.info.name),
md5_sum,
}],
FileInfo::Multi { files } => files,
};

// TODO: check files md5_sum
info!("writing files");
let mut write_handles = vec![];
for file in files {
let file_data = data.drain(0..file.length as usize).collect();
let path = PathBuf::from("download")
.join(&metainfo.info.name)
.join(&state.metainfo.info.name)
.join(file.path.clone());
write_handles.push(spawn(write_file(path, file_data)))
}
Expand All @@ -115,7 +120,8 @@ async fn write_to_disk(state: State, metainfo: Metainfo) -> Result<()> {
return Err(Error::msg("file write errors"));
}

info!("torrent downloaded: {}", metainfo.info.name);
state.status = TorrentStatus::Saved;
info!("torrent downloaded: {}", state.metainfo.info.name);
Ok(())
}

Expand Down

0 comments on commit b33dda7

Please sign in to comment.