From 708e12aec63a94d354c40648044f29bec4d3fdb3 Mon Sep 17 00:00:00 2001 From: ivanjermakov Date: Tue, 24 Oct 2023 15:30:11 +0200 Subject: [PATCH] Download pieces --- src/metainfo.rs | 4 +-- src/peer.rs | 93 +++++++++++++++++++++++++++++++++++++++---------- src/state.rs | 54 +++++++++++++++++++--------- src/tracker.rs | 2 +- 4 files changed, 116 insertions(+), 37 deletions(-) diff --git a/src/metainfo.rs b/src/metainfo.rs index 3cd0c70..87c9286 100644 --- a/src/metainfo.rs +++ b/src/metainfo.rs @@ -16,7 +16,7 @@ pub struct Metainfo { #[derive(Clone, PartialEq, PartialOrd, Hash)] pub struct Info { - pub piece_length: i64, + pub piece_length: u32, pub pieces: Vec, pub name: String, pub file_info: FileInfo, @@ -103,7 +103,7 @@ impl TryFrom for Metainfo { let metainfo = Metainfo { info: Info { piece_length: match info_dict.get("piece length") { - Some(BencodeValue::Int(v)) => *v, + Some(BencodeValue::Int(v)) => *v as u32, _ => return Err("'piece length' missing".into()), }, pieces, diff --git a/src/peer.rs b/src/peer.rs index d9733ec..49dd7fb 100644 --- a/src/peer.rs +++ b/src/peer.rs @@ -7,7 +7,8 @@ use tokio::sync::Mutex; use tokio::time::timeout; use crate::hex::hex; -use crate::state::{Block, PeerInfo, State}; +use crate::sha1; +use crate::state::{Block, PeerInfo, State, BLOCK_SIZE}; use crate::types::ByteString; #[derive(Debug)] @@ -263,29 +264,85 @@ pub async fn handle_peer(peer: PeerInfo, state: Arc>) -> Result<()> info!("successfull handshake with peer {:?}", peer); send_message(&mut stream, Message::Unchoke).await?; send_message(&mut stream, Message::Interested).await?; + loop { - match read_message(&mut stream).await { - Ok(Message::Choke) => { - continue; - } - Ok(msg) => { - if matches!(msg, Message::Unchoke) { - for i in 0..16 { - let block_size = 1 << 14; + let mut piece = match { state.lock().await.next_piece() } { + Some(p) => p, + None => return Ok(()), + }; + info!("next piece to request: {:?}", piece); + let total_blocks = (piece.length as f64 / BLOCK_SIZE as f64).ceil() as u32; + + loop { + match read_message(&mut stream).await { + Ok(Message::Choke) => continue, + Ok(Message::Unchoke) => { + for i in 0..total_blocks { let request_msg = Message::Request { - piece_index: 0, - begin: i * block_size, - length: block_size, + piece_index: piece.index, + begin: i * BLOCK_SIZE, + length: if i == total_blocks - 1 + && piece.length % BLOCK_SIZE != 0 + { + piece.length % BLOCK_SIZE + } else { + BLOCK_SIZE + }, }; send_message(&mut stream, request_msg).await?; } } - } - Err(e) => { - warn!("{}", e); - break; - } - }; + Ok(Message::Piece { + piece_index, + begin, + block, + }) => { + if piece_index != piece.index as u32 { + debug!("block for another piece, ignoring"); + continue; + } + if begin % BLOCK_SIZE != 0 { + warn!("block begin is not a multiple of block size"); + continue; + } + let block_index = begin / BLOCK_SIZE; + if block_index != total_blocks - 1 + && block.0.len() != BLOCK_SIZE as usize + { + warn!("block of unexpected size: {}", block.0.len()); + continue; + } + piece.blocks.insert(block_index, block); + debug!("got block {}/{}", piece.blocks.len(), total_blocks); + if piece.blocks.len() as u32 == total_blocks { + let piece_data: Vec = piece + .blocks + .values() + .flat_map(|b| b.0.as_slice()) + .copied() + .collect(); + let piece_hash = sha1::encode(piece_data); + if piece_hash != piece.hash.0 { + warn!("piece hash does not match: {:?}", piece); + trace!("{}", hex(&piece_hash)); + trace!("{}", hex(&piece.hash.0)); + continue; + } + info!("piece completed: {:?}", piece); + piece.completed = true; + state.lock().await.pieces.insert(piece.index, piece.clone()); + break; + } + } + Ok(msg) => { + debug!("no handler for message, skipping: {:?}", msg); + } + Err(e) => { + warn!("{}", e); + break; + } + }; + } } } Err(e) => warn!("handshake error: {}", e), diff --git a/src/state.rs b/src/state.rs index 393bd40..197329d 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,27 +1,43 @@ use core::fmt; use std::collections::BTreeMap; +use rand::{seq::IteratorRandom, thread_rng}; + use crate::{ hex::hex, metainfo::{Info, Metainfo}, types::ByteString, }; +pub const BLOCK_SIZE: u32 = 1 << 14; + #[derive(Clone, Debug, PartialEq, PartialOrd, Hash)] pub struct State { pub metainfo: Metainfo, pub info_hash: Vec, pub peer_id: Vec, - pub pieces: Vec, + pub pieces: BTreeMap, pub peers: BTreeMap, } +impl State { + pub fn next_piece(&self) -> Option { + self.pieces + .values() + .filter(|p| !p.completed) + .choose(&mut thread_rng()) + .cloned() + } +} + #[derive(Clone, Debug, PartialEq, PartialOrd, Hash)] pub struct Piece { pub hash: PieceHash, - pub index: i64, - pub length: i64, - pub blocks: Vec, + pub index: u32, + pub length: u32, + /// Map of blocks -> + pub blocks: BTreeMap, + pub completed: bool, } #[derive(Clone, PartialEq, PartialOrd, Hash)] @@ -38,7 +54,7 @@ pub struct Block(pub Vec); impl fmt::Debug for Block { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str("") + write!(f, "", self.0.len()) } } @@ -75,22 +91,28 @@ impl fmt::Debug for PeerInfo { } } -pub fn init_pieces(info: &Info) -> Vec { - let total_len = info.file_info.total_length(); +pub fn init_pieces(info: &Info) -> BTreeMap { + let total_len = info.file_info.total_length() as u32; assert!(info.pieces.len() == (total_len as f64 / info.piece_length as f64).ceil() as usize); info.pieces .iter() .cloned() .enumerate() - .map(|(i, p)| Piece { - hash: p, - index: i as i64, - length: if i == info.pieces.len() - 1 { - total_len % info.piece_length - } else { - info.piece_length - }, - blocks: vec![], + .map(|(i, p)| { + ( + i as u32, + Piece { + hash: p, + index: i as u32, + length: if i == info.pieces.len() - 1 { + total_len % info.piece_length + } else { + info.piece_length + }, + blocks: BTreeMap::new(), + completed: false, + }, + ) }) .collect() } diff --git a/src/tracker.rs b/src/tracker.rs index d57c468..bc2595e 100644 --- a/src/tracker.rs +++ b/src/tracker.rs @@ -202,7 +202,7 @@ pub async fn tracker_request(announce: String, request: TrackerRequest) -> Resul debug!("raw response: {}", String::from_utf8_lossy(&resp)); let resp_dict = parse_bencoded(resp.to_vec()) .0 - .context("Malformed response")?; + .context("malformed response")?; debug!("response: {resp_dict:?}"); TrackerResponse::try_from(resp_dict).map_err(Error::msg) }