Skip to content

Commit

Permalink
fix(virtio-net): prepare checksum correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
cagatay-y committed Dec 2, 2024
1 parent 064998e commit 16bfa51
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 30 deletions.
178 changes: 148 additions & 30 deletions src/drivers/net/virtio/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,38 +263,12 @@ impl NetworkDriver for VirtioNetDriver {
};

let mut header = Box::new_in(<Hdr as Default>::default(), DeviceAlloc);
// If a checksum isn't necessary, we have inform the host within the header
// see Virtio specification 5.1.6.2
if !self.checksums.tcp.tx() || !self.checksums.udp.tx() {

if let Some((ip_header_len, csum_offset)) = self.should_request_checksum(&mut packet) {
header.flags = HdrF::NEEDS_CSUM;
let ethernet_frame: smoltcp::wire::EthernetFrame<&[u8]> =
EthernetFrame::new_unchecked(&packet);
let packet_header_len: u16;
let protocol;
match ethernet_frame.ethertype() {
smoltcp::wire::EthernetProtocol::Ipv4 => {
let packet = Ipv4Packet::new_unchecked(ethernet_frame.payload());
packet_header_len = packet.header_len().into();
protocol = Some(packet.next_header());
}
smoltcp::wire::EthernetProtocol::Ipv6 => {
let packet = Ipv6Packet::new_unchecked(ethernet_frame.payload());
packet_header_len = packet.header_len().try_into().unwrap();
protocol = Some(packet.next_header());
}
_ => {
packet_header_len = 0;
protocol = None;
}
}
header.csum_start =
(u16::try_from(ETHERNET_HEADER_LEN).unwrap() + packet_header_len).into();
header.csum_offset = match protocol {
Some(smoltcp::wire::IpProtocol::Tcp) => 16,
Some(smoltcp::wire::IpProtocol::Udp) => 6,
_ => 0,
}
.into();
(u16::try_from(ETHERNET_HEADER_LEN).unwrap() + ip_header_len).into();
header.csum_offset = csum_offset.into();
}

let buff_tkn = AvailBufferToken::new(
Expand Down Expand Up @@ -784,6 +758,87 @@ impl VirtioNetDriver {

Ok(())
}

/// Sets the TCP or UDP checksum field to the checksum of the pseudo-header if necessary or returns None otherwise.
fn should_request_checksum<T: AsRef<[u8]> + AsMut<[u8]>>(
&self,
frame: T,
) -> Option<(u16, u16)> {
if !self.checksums.tcp.tx() || !self.checksums.udp.tx() {
// If a checksum calculation by the host is necessary, we have to inform the host within the header
// see Virtio specification 5.1.6.2
let mut ethernet_frame = EthernetFrame::new_unchecked(frame);
// If the Ethernet protocol is not one of these two, we default to not asking for checksum,
// as otherwise the frame will be corrupted by the device trying to write the checksum.
if let ip @ (smoltcp::wire::EthernetProtocol::Ipv4
| smoltcp::wire::EthernetProtocol::Ipv6) = ethernet_frame.ethertype()
{
let ip_header_len: u16;
let ip_packet_len: usize;
let protocol;
let pseudo_header_checksum;
match ip {
smoltcp::wire::EthernetProtocol::Ipv4 => {
let ip_packet = Ipv4Packet::new_unchecked(&*ethernet_frame.payload_mut());
ip_header_len = ip_packet.header_len().into();
ip_packet_len = ip_packet.total_len().into();
protocol = ip_packet.next_header();
pseudo_header_checksum =
partial_checksum::ipv4_pseudo_header_partial_checksum(&ip_packet);
}
smoltcp::wire::EthernetProtocol::Ipv6 => {
let ip_packet = Ipv6Packet::new_unchecked(&*ethernet_frame.payload_mut());
ip_header_len = ip_packet.header_len().try_into().expect(
"VIRTIO does not support IP headers that are longer than u16::MAX bytes.",
);
ip_packet_len = ip_packet.total_len();
protocol = ip_packet.next_header();
pseudo_header_checksum =
partial_checksum::ipv6_pseudo_header_partial_checksum(&ip_packet);
}
_ => unreachable!(),
}
// Like the Ethernet protocol check, we check for IP protocols for which we know the location of the checksum field.
if let smoltcp::wire::IpProtocol::Tcp | smoltcp::wire::IpProtocol::Udp = protocol {
let ip_payload =
&mut ethernet_frame.payload_mut()[ip_header_len.into()..ip_packet_len];

// We do not care about the offset of the checksum for the protocol if we don't require checksum
// from the host, so we use None to signal that checksum from the host is not needed.
let csum_offset = match protocol {
smoltcp::wire::IpProtocol::Tcp => {
if !self.checksums.tcp.tx() {
let mut tcp_packet =
smoltcp::wire::TcpPacket::new_unchecked(ip_payload);
tcp_packet.set_checksum(pseudo_header_checksum);
Some(16)
} else {
None
}
}
smoltcp::wire::IpProtocol::Udp => {
if !self.checksums.tcp.tx() {
let mut udp_packet =
smoltcp::wire::UdpPacket::new_unchecked(ip_payload);
udp_packet.set_checksum(pseudo_header_checksum);
Some(6)
} else {
None
}
}
_ => None,
};
csum_offset.map(|csum_offset| (ip_header_len, csum_offset))
} else {
None
}
} else {
None
}
} else {
None
}
}
}

pub mod constants {
Expand All @@ -808,3 +863,66 @@ pub mod error {
IncompatibleFeatureSets(virtio::net::F, virtio::net::F),
}
}

/// The checksum functions in this module only calculate the one's complement sum for the pseudo-header
/// and their results are meant to be combined with the TCP payload to calculate the real checksum.
/// They are only useful for the VIRTIO driver with the checksum offloading feature.
///
/// The calculations here can theoretically be made faster by exploiting the properties described in
/// [RFC 1071 section 2](https://www.rfc-editor.org/rfc/rfc1071).
mod partial_checksum {
use core::iter;

use smoltcp::wire::{Ipv4Packet, Ipv6Packet};

/// Calculates the checksum for the IPv4 pseudo-header as described in
/// [RFC 9293 subsection 3.1](https://www.rfc-editor.org/rfc/rfc9293.html#section-3.1-6.18.1) WITHOUT the final inversion.
pub(super) fn ipv4_pseudo_header_partial_checksum<T: AsRef<[u8]>>(
packet: &Ipv4Packet<T>,
) -> u16 {
let src_addr = packet.src_addr();
let dst_addr = packet.dst_addr();
let address_words = src_addr
.as_bytes()
.iter()
.chain(dst_addr.as_bytes())
.copied()
.array_chunks::<{ size_of::<u16>() }>()
.map(u16::from_be_bytes);
let padded_protocol = u16::from(u8::from(packet.next_header()));
let payload_len = packet.total_len() - u16::from(packet.header_len());
address_words
.chain(iter::once(padded_protocol))
.chain(iter::once(payload_len))
.fold(0u16, ones_complement_add)
}

/// Calculates the checksum for the IPv6 pseudo-header as described in
/// [RFC 8200 subsection 8.1](https://www.rfc-editor.org/rfc/rfc8200.html#section-8.1) WITHOUT the final inversion.
pub(super) fn ipv6_pseudo_header_partial_checksum<T: AsRef<[u8]>>(
packet: &Ipv6Packet<T>,
) -> u16 {
warn!("The IPv6 partial checksum implementation is untested!");
let src_addr = packet.src_addr();
let dst_addr = packet.dst_addr();
let payload_len = packet.payload_len();
let padded_protocol = u16::from(u8::from(packet.next_header()));

src_addr
.as_bytes()
.iter()
.chain(dst_addr.as_bytes())
.copied()
.array_chunks::<{ size_of::<u16>() }>()
.map(u16::from_be_bytes)
.chain(iter::once(payload_len))
.chain(iter::once(padded_protocol))
.fold(0u16, ones_complement_add)
}

/// Implements one's complement checksum as described in [RFC 1071 section 1](https://www.rfc-editor.org/rfc/rfc1071#section-1).
fn ones_complement_add(lhs: u16, rhs: u16) -> u16 {
let (sum, overflow) = u16::overflowing_add(lhs, rhs);
sum + u16::from(overflow)
}
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)]
#![cfg_attr(target_arch = "x86_64", feature(abi_x86_interrupt))]
#![feature(allocator_api)]
#![feature(iter_array_chunks)]
#![feature(linked_list_cursors)]
#![feature(map_try_insert)]
#![feature(maybe_uninit_as_bytes)]
Expand Down

0 comments on commit 16bfa51

Please sign in to comment.