Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes and refactor #37

Merged
merged 3 commits into from
Apr 9, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
279 changes: 172 additions & 107 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
select,
sync::mpsc::{self, UnboundedReceiver, UnboundedSender},
task::JoinHandle,
};

mod error;
Expand Down Expand Up @@ -77,120 +78,21 @@ impl IpStackConfig {

pub struct IpStack {
accept_receiver: UnboundedReceiver<IpStackStream>,
pub handle: JoinHandle<Result<()>>,
}

impl IpStack {
pub fn new<D>(config: IpStackConfig, mut device: D) -> IpStack
pub fn new<D>(config: IpStackConfig, device: D) -> IpStack
where
D: AsyncRead + AsyncWrite + std::marker::Unpin + std::marker::Send + 'static,
D: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let (accept_sender, accept_receiver) = mpsc::unbounded_channel::<IpStackStream>();
let handle = run(config, device, accept_sender);

tokio::spawn(async move {
let mut streams: AHashMap<NetworkTuple, UnboundedSender<NetworkPacket>> =
AHashMap::new();
let mut buffer = [0u8; u16::MAX as usize];

let (pkt_sender, mut pkt_receiver) = mpsc::unbounded_channel::<NetworkPacket>();
loop {
select! {
Ok(n) = device.read(&mut buffer) => {
let offset = if config.packet_information && cfg!(unix) {4} else {0};
let Ok(packet) = NetworkPacket::parse(&buffer[offset..n]) else {
accept_sender.send(IpStackStream::UnknownNetwork(buffer[offset..n].to_vec()))?;
continue;
};
if let IpStackPacketProtocol::Unknown = packet.transport_protocol() {
accept_sender.send(
IpStackStream::UnknownTransport(IpStackUnknownTransport::new(
packet.src_addr().ip(),
packet.dst_addr().ip(),
packet.payload,
&packet.ip,
config.mtu,
pkt_sender.clone()
))
)?;
continue;
}

match streams.entry(packet.network_tuple()){
Occupied(entry) =>{
if let Err(e) = entry.get().send(packet){
trace!("Send packet error \"{}\"", e);
}
}
Vacant(entry) => {
match packet.transport_protocol(){
IpStackPacketProtocol::Tcp(h) => {
match IpStackTcpStream::new(
packet.src_addr(),
packet.dst_addr(),
h,
pkt_sender.clone(),
config.mtu,
config.tcp_timeout
){
Ok(stream) => {
entry.insert(stream.stream_sender());
accept_sender.send(IpStackStream::Tcp(stream))?;
}
Err(e) => {
if matches!(e,IpStackError::InvalidTcpPacket){
trace!("Invalid TCP packet");
continue;
}
error!("IpStackTcpStream::new failed \"{}\"", e);
}
}
}
IpStackPacketProtocol::Udp => {
let stream = IpStackUdpStream::new(
packet.src_addr(),
packet.dst_addr(),
packet.payload,
pkt_sender.clone(),
config.mtu,
config.udp_timeout
);
entry.insert(stream.stream_sender());
accept_sender.send(IpStackStream::Udp(stream))?;
}
IpStackPacketProtocol::Unknown => {
unreachable!()
}
}
}
}
}
Some(packet) = pkt_receiver.recv() => {
if packet.ttl() == 0{
streams.remove(&packet.reverse_network_tuple());
continue;
}
#[allow(unused_mut)]
let Ok(mut packet_byte) = packet.to_bytes() else{
trace!("to_bytes error");
continue;
};
#[cfg(unix)]
if config.packet_information {
if packet.src_addr().is_ipv4(){
packet_byte.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP4].concat());
} else{
packet_byte.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP6].concat());
}
}
device.write_all(&packet_byte).await?;
// device.flush().await.unwrap();
}
}
}
#[allow(unreachable_code)]
Ok::<(), IpStackError>(())
});

IpStack { accept_receiver }
IpStack {
accept_receiver,
handle,
}
}

pub async fn accept(&mut self) -> Result<IpStackStream, IpStackError> {
Expand All @@ -200,3 +102,166 @@ impl IpStack {
.ok_or(IpStackError::AcceptError)
}
}

fn run<D>(
config: IpStackConfig,
mut device: D,
accept_sender: UnboundedSender<IpStackStream>,
) -> JoinHandle<Result<()>>
where
D: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let mut streams: AHashMap<NetworkTuple, UnboundedSender<NetworkPacket>> = AHashMap::new();
let offset = if config.packet_information && cfg!(unix) {
4
} else {
0
};
let mut buffer = [0_u8; u16::MAX as usize + 4];
let (pkt_sender, mut pkt_receiver) = mpsc::unbounded_channel::<NetworkPacket>();

tokio::spawn(async move {
loop {
select! {
Ok(n) = device.read(&mut buffer) => {
if let Some(stream) = process_read(
&buffer[offset..n],
&mut streams,
&pkt_sender,
&config,
)? {
accept_sender.send(stream)?;
}
}
Some(packet) = pkt_receiver.recv() => {
process_recv(
packet,
&mut streams,
&mut device,
#[cfg(unix)]
config.packet_information,
)
.await?;
}
}
}
})
}

fn process_read(
data: &[u8],
streams: &mut AHashMap<NetworkTuple, UnboundedSender<NetworkPacket>>,
pkt_sender: &UnboundedSender<NetworkPacket>,
config: &IpStackConfig,
) -> Result<Option<IpStackStream>> {
let Ok(packet) = NetworkPacket::parse(data) else {
return Ok(Some(IpStackStream::UnknownNetwork(data.to_owned())));
};

if let IpStackPacketProtocol::Unknown = packet.transport_protocol() {
return Ok(Some(IpStackStream::UnknownTransport(
IpStackUnknownTransport::new(
packet.src_addr().ip(),
packet.dst_addr().ip(),
packet.payload,
&packet.ip,
config.mtu,
pkt_sender.clone(),
),
)));
}

Ok(match streams.entry(packet.network_tuple()) {
Occupied(mut entry) => {
if let Err(e) = entry.get().send(packet) {
trace!("New stream because: {}", e);
create_stream(e.0, config, pkt_sender)?.map(|s| {
entry.insert(s.0);
s.1
})
} else {
None
}
}
Vacant(entry) => create_stream(packet, config, pkt_sender)?.map(|s| {
entry.insert(s.0);
s.1
}),
})
}

fn create_stream(
packet: NetworkPacket,
config: &IpStackConfig,
pkt_sender: &UnboundedSender<NetworkPacket>,
) -> Result<Option<(UnboundedSender<NetworkPacket>, IpStackStream)>> {
match packet.transport_protocol() {
IpStackPacketProtocol::Tcp(h) => {
match IpStackTcpStream::new(
packet.src_addr(),
packet.dst_addr(),
h,
pkt_sender.clone(),
config.mtu,
config.tcp_timeout,
) {
Ok(stream) => Ok(Some((stream.stream_sender(), IpStackStream::Tcp(stream)))),
Err(e) => {
if matches!(e, IpStackError::InvalidTcpPacket) {
trace!("Invalid TCP packet");
Ok(None)
} else {
error!("IpStackTcpStream::new failed \"{}\"", e);
Err(e)
}
}
}
}
IpStackPacketProtocol::Udp => {
let stream = IpStackUdpStream::new(
packet.src_addr(),
packet.dst_addr(),
packet.payload,
pkt_sender.clone(),
config.mtu,
config.udp_timeout,
);
Ok(Some((stream.stream_sender(), IpStackStream::Udp(stream))))
}
IpStackPacketProtocol::Unknown => {
unreachable!()
}
}
}

async fn process_recv<D>(
packet: NetworkPacket,
streams: &mut AHashMap<NetworkTuple, UnboundedSender<NetworkPacket>>,
device: &mut D,
#[cfg(unix)] packet_information: bool,
) -> Result<()>
where
D: AsyncWrite + Unpin + 'static,
{
if packet.ttl() == 0 {
streams.remove(&packet.reverse_network_tuple());
return Ok(());
}
#[allow(unused_mut)]
let Ok(mut packet_byte) = packet.to_bytes() else {
trace!("to_bytes error");
return Ok(());
};
#[cfg(unix)]
if packet_information {
if packet.src_addr().is_ipv4() {
packet_byte.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP4].concat());
} else {
packet_byte.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP6].concat());
}
}
device.write_all(&packet_byte).await?;
// device.flush().await.unwrap();

Ok(())
}
Loading