diff --git a/src/agent/handle.rs b/src/agent/handle.rs index 32ecf31..0369f50 100644 --- a/src/agent/handle.rs +++ b/src/agent/handle.rs @@ -5,6 +5,8 @@ use anyhow::Result; use byteorder::{BigEndian, ReadBytesExt}; use futures::channel::mpsc::UnboundedSender; use futures::SinkExt; +use std::convert::TryInto; +use std::io::Write; use std::io::{Cursor, Read}; use std::time::{SystemTime, UNIX_EPOCH}; @@ -17,6 +19,9 @@ pub enum SSHAgentPacket { SignRequest(Vec, Vec, u32), // hostname, socket_path, signature, key HostName(String, String, Vec, Vec), + // extension type, extension data + ExtensionRequest(String, Vec), + Unkown(u8), } pub fn parse_packet(packet: &Vec, _socket: &mut UnixStream) -> SSHAgentPacket { @@ -42,6 +47,19 @@ pub fn parse_packet(packet: &Vec, _socket: &mut UnixStream) -> SSHAgentPacke return SSHAgentPacket::SignRequest(key_blob, data, flags); } + if typ == 27 { + let extension_type_length = cursor.read_u32::().unwrap(); + let mut extension_type_data = vec![0u8; extension_type_length as usize]; + cursor.read_exact(&mut extension_type_data).unwrap(); + let extension_type = String::from_utf8(extension_type_data).unwrap(); + + let extension_data_length = cursor.read_u32::().unwrap(); + let mut extension_data = vec![0u8; extension_data_length as usize]; + cursor.read_exact(&mut extension_data).unwrap(); + + return SSHAgentPacket::ExtensionRequest(extension_type, extension_data); + } + if typ == 254 { let data_length = cursor.read_u32::().unwrap(); let mut data = vec![0u8; data_length as usize]; @@ -68,7 +86,19 @@ pub fn parse_packet(packet: &Vec, _socket: &mut UnixStream) -> SSHAgentPacke ); } - panic!("unknown packet") + return SSHAgentPacket::Unkown(typ); +} + +async fn reply_general_failure(socket: &mut UnixStream) -> Result<()> { + let typ = 5u8; + let mut msg_payload = vec![]; + msg_payload.write(&[typ])?; + let length = msg_payload.len() as u32; + + tokio::io::AsyncWriteExt::write_u32(socket, length).await?; + tokio::io::AsyncWriteExt::write_all(socket, &msg_payload).await?; + + Ok(()) } pub async fn read_and_handle_packet( @@ -78,7 +108,14 @@ pub async fn read_and_handle_packet( remove_proxy_send: UnboundedSender, ) -> Result<()> { loop { - let length_bytes = tokio::io::AsyncReadExt::read_i32(socket).await?; + socket.readable().await?; + let mut length_bytes_vec = [0u8; 4]; + let bytes_read = socket.try_read(&mut length_bytes_vec).unwrap_or(0); + if bytes_read != 4 { + // yes this will loop, but the readable call will fail if the socket goes away + continue; + } + let length_bytes = u32::from_be_bytes(length_bytes_vec); let mut msg = vec![0u8; length_bytes as usize]; tokio::io::AsyncReadExt::read_exact(socket, &mut msg).await?; @@ -111,6 +148,17 @@ pub async fn read_and_handle_packet( }) .await?; } + SSHAgentPacket::ExtensionRequest(extension_type, _) => { + println!("Received extension request: {}", extension_type); + reply_general_failure(socket).await?; + } + SSHAgentPacket::Unkown(unknown_type) => { + println!( + "Received unknown/unsupported message type: {}", + unknown_type + ); + reply_general_failure(socket).await?; + } } } }