Skip to content

Commit

Permalink
Fix lost packets if channel was closed, Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
RoDmitry committed Apr 8, 2024
1 parent 409692f commit b2fd8b4
Showing 1 changed file with 164 additions and 107 deletions.
271 changes: 164 additions & 107 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,117 +82,12 @@ pub struct IpStack {
}

impl IpStack {
pub fn new<D>(config: IpStackConfig, mut device: D) -> IpStack
pub fn new<D>(config: IpStackConfig, device: D) -> IpStack
where
D: AsyncRead + AsyncWrite + std::marker::Unpin + std::marker::Send + 'static,
{
let (accept_sender, accept_receiver) = mpsc::unbounded_channel::<IpStackStream>();

let handle = tokio::spawn(async move {
let mut streams: AHashMap<NetworkTuple, UnboundedSender<NetworkPacket>> =
AHashMap::new();
let offset = if config.packet_information && cfg!(unix) {
4
} else {
0
};
let mut buffer = [0u8; u16::MAX as usize + 4];

let (pkt_sender, mut pkt_receiver) = mpsc::unbounded_channel::<NetworkPacket>();
loop {
select! {
Ok(n) = device.read(&mut buffer) => {
let Ok(packet) = NetworkPacket::parse(&buffer[offset..n]) else {
accept_sender.send(IpStackStream::UnknownNetwork(buffer[offset..n].to_vec()))?;
continue;
};
if let IpStackPacketProtocol::Unknown = packet.transport_protocol() {
accept_sender.send(
IpStackStream::UnknownTransport(IpStackUnknownTransport::new(
packet.src_addr().ip(),
packet.dst_addr().ip(),
packet.payload,
&packet.ip,
config.mtu,
pkt_sender.clone()
))
)?;
continue;
}

match streams.entry(packet.network_tuple()){
Occupied(entry) =>{
if let Err(e) = entry.get().send(packet){
trace!("Send packet error \"{}\"", e);
}
}
Vacant(entry) => {
match packet.transport_protocol(){
IpStackPacketProtocol::Tcp(h) => {
match IpStackTcpStream::new(
packet.src_addr(),
packet.dst_addr(),
h,
pkt_sender.clone(),
config.mtu,
config.tcp_timeout
){
Ok(stream) => {
entry.insert(stream.stream_sender());
accept_sender.send(IpStackStream::Tcp(stream))?;
}
Err(e) => {
if matches!(e,IpStackError::InvalidTcpPacket){
trace!("Invalid TCP packet");
continue;
}
error!("IpStackTcpStream::new failed \"{}\"", e);
}
}
}
IpStackPacketProtocol::Udp => {
let stream = IpStackUdpStream::new(
packet.src_addr(),
packet.dst_addr(),
packet.payload,
pkt_sender.clone(),
config.mtu,
config.udp_timeout
);
entry.insert(stream.stream_sender());
accept_sender.send(IpStackStream::Udp(stream))?;
}
IpStackPacketProtocol::Unknown => {
unreachable!()
}
}
}
}
}
Some(packet) = pkt_receiver.recv() => {
if packet.ttl() == 0{
streams.remove(&packet.reverse_network_tuple());
continue;
}
#[allow(unused_mut)]
let Ok(mut packet_byte) = packet.to_bytes() else{
trace!("to_bytes error");
continue;
};
#[cfg(unix)]
if config.packet_information {
if packet.src_addr().is_ipv4(){
packet_byte.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP4].concat());
} else{
packet_byte.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP6].concat());
}
}
device.write_all(&packet_byte).await?;
// device.flush().await.unwrap();
}
}
}
});
let handle = run(config, device, accept_sender);

IpStack {
accept_receiver,
Expand All @@ -207,3 +102,165 @@ impl IpStack {
.ok_or(IpStackError::AcceptError)
}
}

fn run<D>(
config: IpStackConfig,
mut device: D,
accept_sender: UnboundedSender<IpStackStream>,
) -> JoinHandle<Result<()>>
where
D: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let mut streams: AHashMap<NetworkTuple, UnboundedSender<NetworkPacket>> = AHashMap::new();
let offset = if config.packet_information && cfg!(unix) {
4
} else {
0
};
let mut buffer = [0_u8; u16::MAX as usize + 4];
let (pkt_sender, mut pkt_receiver) = mpsc::unbounded_channel::<NetworkPacket>();

tokio::spawn(async move {
loop {
select! {
Ok(n) = device.read(&mut buffer) => {
if let Some(stream) = process_read(
&buffer[offset..n],
&mut streams,
&pkt_sender,
&config,
)? {
accept_sender.send(stream)?;
}
}
Some(packet) = pkt_receiver.recv() => {
process_recv(
packet,
&mut streams,
&mut device,
#[cfg(unix)]
config.packet_information,
)
.await?;
}
}
}
})
}

fn process_read(
data: &[u8],
streams: &mut AHashMap<NetworkTuple, UnboundedSender<NetworkPacket>>,
pkt_sender: &UnboundedSender<NetworkPacket>,
config: &IpStackConfig,
) -> Result<Option<IpStackStream>> {
let Ok(packet) = NetworkPacket::parse(data) else {
return Ok(Some(IpStackStream::UnknownNetwork(data.to_owned())));
};

if let IpStackPacketProtocol::Unknown = packet.transport_protocol() {
return Ok(Some(IpStackStream::UnknownTransport(
IpStackUnknownTransport::new(
packet.src_addr().ip(),
packet.dst_addr().ip(),
packet.payload,
&packet.ip,
config.mtu,
pkt_sender.clone(),
),
)));
}

Ok(match streams.entry(packet.network_tuple()) {
Occupied(mut entry) => {
if let Err(e) = entry.get().send(packet) {
trace!("New stream because: {}", e);
create_stream(e.0, config, pkt_sender)?.map(|s| {
entry.insert(s.0);
s.1
})
} else {
None
}
}
Vacant(entry) => create_stream(packet, config, pkt_sender)?.map(|s| {
entry.insert(s.0);
s.1
}),
})
}

fn create_stream(
packet: NetworkPacket,
config: &IpStackConfig,
pkt_sender: &UnboundedSender<NetworkPacket>,
) -> Result<Option<(UnboundedSender<NetworkPacket>, IpStackStream)>> {
match packet.transport_protocol() {
IpStackPacketProtocol::Tcp(h) => {
match IpStackTcpStream::new(
packet.src_addr(),
packet.dst_addr(),
h,
pkt_sender.clone(),
config.mtu,
config.tcp_timeout,
) {
Ok(stream) => Ok(Some((stream.stream_sender(), IpStackStream::Tcp(stream)))),
Err(e) => {
if matches!(e, IpStackError::InvalidTcpPacket) {
trace!("Invalid TCP packet");
} else {
error!("IpStackTcpStream::new failed \"{}\"", e);
}
Ok(None)
}
}
}
IpStackPacketProtocol::Udp => {
let stream = IpStackUdpStream::new(
packet.src_addr(),
packet.dst_addr(),
packet.payload,
pkt_sender.clone(),
config.mtu,
config.udp_timeout,
);
Ok(Some((stream.stream_sender(), IpStackStream::Udp(stream))))
}
IpStackPacketProtocol::Unknown => {
unreachable!()
}
}
}

async fn process_recv<D>(
packet: NetworkPacket,
streams: &mut AHashMap<NetworkTuple, UnboundedSender<NetworkPacket>>,
device: &mut D,
#[cfg(unix)] packet_information: bool,
) -> Result<()>
where
D: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
if packet.ttl() == 0 {
streams.remove(&packet.reverse_network_tuple());
return Ok(());
}
#[allow(unused_mut)]
let Ok(mut packet_byte) = packet.to_bytes() else {
trace!("to_bytes error");
return Ok(());
};
#[cfg(unix)]
if packet_information {
if packet.src_addr().is_ipv4() {
packet_byte.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP4].concat());
} else {
packet_byte.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP6].concat());
}
}
device.write_all(&packet_byte).await?;
// device.flush().await.unwrap();

Ok(())
}

0 comments on commit b2fd8b4

Please sign in to comment.