From 3486c5bd5e11976c9820ee0fdb1ddabf347b68f0 Mon Sep 17 00:00:00 2001 From: Lev Stipakov Date: Thu, 26 Sep 2024 10:37:21 +0300 Subject: [PATCH] buffers: RX_BUFFER refactor and buffer-manipulation functions Make RX_BUFFER compatible with buffer-manipulation templated functions, such as: - OvpnBufferTrim (new) - OvpnBufferPull (new) - OvpnBufferPut (existing, added template) Those functions are now compatible with both TX and RX buffers. This refactor allowed to simplify OvpnEvtRxQueueAdvance() and remove peer access from it, since the plaintext buffers dequeued there already have crypto overhead stripped. Signed-off-by: Lev Stipakov --- Driver.cpp | 2 +- bufferpool.cpp | 15 +++------------ bufferpool.h | 40 ++++++++++++++++++++++++++++++++++++---- rxqueue.cpp | 12 +----------- socket.cpp | 37 ++++++++++++++++++++----------------- timer.cpp | 2 +- txqueue.cpp | 4 ++-- 7 files changed, 64 insertions(+), 48 deletions(-) diff --git a/Driver.cpp b/Driver.cpp index f793fb6..a7bc4a1 100644 --- a/Driver.cpp +++ b/Driver.cpp @@ -240,7 +240,7 @@ OvpnEvtIoWrite(WDFQUEUE queue, WDFREQUEST request, size_t length) GOTO_IF_NOT_NT_SUCCESS(error, status, OvpnTxBufferPoolGet(device->TxBufferPool, &txBuf)); // copy data from request to tx buffer - PUCHAR data = OvpnTxBufferPut(txBuf, bufLen); + PUCHAR data = OvpnBufferPut(txBuf, bufLen); RtlCopyMemory(data, buf, bufLen); txBuf->IoQueue = device->PendingWritesQueue; diff --git a/bufferpool.cpp b/bufferpool.cpp index b28932b..d1653ff 100644 --- a/bufferpool.cpp +++ b/bufferpool.cpp @@ -205,6 +205,8 @@ OvpnRxBufferPoolGet(OVPN_RX_BUFFER_POOL handle, OVPN_RX_BUFFER** buffer) if (*buffer == NULL) return STATUS_INSUFFICIENT_RESOURCES; + (*buffer)->Data = (*buffer)->Head; + (*buffer)->Tail = (*buffer)->Data; (*buffer)->Pool = handle; (*buffer)->Len = 0; @@ -270,17 +272,6 @@ OvpnTxBufferPoolDelete(OVPN_BUFFER_POOL handle) OvpnBufferPoolDelete(handle); } -_Use_decl_annotations_ -UCHAR* -OvpnTxBufferPut(OVPN_TX_BUFFER* buffer, SIZE_T len) -{ - UCHAR* tmp = buffer->Tail; - buffer->Tail += len; - buffer->Len += len; - - return tmp; -} - _Use_decl_annotations_ UCHAR* OvpnTxBufferPush(OVPN_TX_BUFFER* buffer, SIZE_T len) @@ -295,7 +286,7 @@ _Use_decl_annotations_ NTSTATUS OvpnRxBufferPoolCreate(OVPN_RX_BUFFER_POOL* handle) { - return OvpnBufferPoolCreate((OVPN_BUFFER_POOL*)handle, sizeof(OVPN_RX_BUFFER), "rx", NULL); + return OvpnBufferPoolCreate((OVPN_BUFFER_POOL*)handle, sizeof(OVPN_RX_BUFFER) + OVPN_SOCKET_RX_PACKET_BUFFER_SIZE, "rx", NULL); } VOID diff --git a/bufferpool.h b/bufferpool.h index 1ed0721..fa742af 100644 --- a/bufferpool.h +++ b/bufferpool.h @@ -63,19 +63,51 @@ struct OVPN_TX_BUFFER struct OVPN_RX_BUFFER { - LIST_ENTRY PoolListEntry; + // points to the beginning of data + PUCHAR Data; - LIST_ENTRY QueueListEntry; + // points to the end of data + PUCHAR Tail; + // data length SIZE_T Len; + LIST_ENTRY PoolListEntry; + + LIST_ENTRY QueueListEntry; + OVPN_RX_BUFFER_POOL Pool; - UCHAR Data[OVPN_SOCKET_RX_PACKET_BUFFER_SIZE]; + #pragma warning(suppress:4200) //nonstandard extension used: zero-sized array in struct/union + UCHAR Head[]; }; +template UCHAR* -OvpnTxBufferPut(_In_ OVPN_TX_BUFFER* work, SIZE_T len); +OvpnBufferPut(_In_ BUF* buf, SIZE_T len) +{ + UCHAR* tmp = buf->Tail; + buf->Tail += len; + buf->Len += len; + + return tmp; +} + +template +VOID +OvpnBufferTrim(_In_ BUF* buf, SIZE_T len) +{ + buf->Len = len; + buf->Tail = buf->Data + len; +} + +template +VOID +OvpnBufferPull(_In_ BUF* buf, SIZE_T len) +{ + buf->Len -= len; + buf->Data += len; +} UCHAR* OvpnTxBufferPush(_In_ OVPN_TX_BUFFER* work, SIZE_T len); diff --git a/rxqueue.cpp b/rxqueue.cpp index d7687c3..2a85ffd 100644 --- a/rxqueue.cpp +++ b/rxqueue.cpp @@ -102,16 +102,6 @@ OvpnEvtRxQueueAdvance(NETPACKETQUEUE netPacketQueue) POVPN_RXQUEUE queue = OvpnGetRxQueueContext(netPacketQueue); OVPN_DEVICE* device = OvpnGetDeviceContext(queue->Adapter->WdfDevice); - OvpnPeerContext* peer = OvpnGetFirstPeer(&device->Peers); - if (peer == NULL) { - LOG_WARN("No peer"); - return; - } - - BOOLEAN pktId64bit = peer->CryptoContext.CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; - BOOLEAN aeadTagEnd = peer->CryptoContext.CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END; - auto payloadOffset = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN); - NET_RING_FRAGMENT_ITERATOR fi = NetRingGetAllFragments(queue->Rings); NET_RING_PACKET_ITERATOR pi = NetRingGetAllPackets(queue->Rings); while (NetFragmentIteratorHasAny(&fi)) { @@ -126,7 +116,7 @@ OvpnEvtRxQueueAdvance(NETPACKETQUEUE netPacketQueue) fragment->ValidLength = buffer->Len; fragment->Offset = 0; NET_FRAGMENT_VIRTUAL_ADDRESS* virtualAddr = NetExtensionGetFragmentVirtualAddress(&queue->VirtualAddressExtension, NetFragmentIteratorGetIndex(&fi)); - RtlCopyMemory(virtualAddr->VirtualAddress, buffer->Data + payloadOffset, buffer->Len); + RtlCopyMemory(virtualAddr->VirtualAddress, buffer->Data, buffer->Len); InterlockedExchangeAddNoFence64(&device->Stats.TunBytesReceived, buffer->Len); diff --git a/socket.cpp b/socket.cpp index bebf26e..533c0c2 100644 --- a/socket.cpp +++ b/socket.cpp @@ -134,15 +134,14 @@ OvpnSocketControlPacketReceived(_In_ POVPN_DEVICE device, _In_reads_(len) PUCHAR return; } - if (sizeof(buffer->Data) >= totalLen) { + if (totalLen <= OVPN_SOCKET_RX_PACKET_BUFFER_SIZE) { if (hdrLen > 0) { // prepend with sockaddr - RtlCopyMemory(buffer->Data, remote, hdrLen); + RtlCopyMemory(OvpnBufferPut(buffer, hdrLen), remote, hdrLen); } // copy control packet payload - RtlCopyMemory(buffer->Data + hdrLen, buf, totalLen - hdrLen); - buffer->Len = totalLen; + RtlCopyMemory(OvpnBufferPut(buffer, len), buf, len); // enqueue buffer, it will be dequeued when read request arrives OvpnBufferQueueEnqueue(device->ControlRxBufferQueue, &buffer->QueueListEntry); @@ -222,12 +221,22 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_ LOG_ERROR("keyId not found", TraceLoggingValue(keyId, "keyId")); } else { + // extend data area in the buffer for plaintext and crypto overhead + OvpnBufferPut(buffer, len); + // decrypt into plaintext buffer status = cryptoContext->Decrypt(keySlot, cipherTextBuf, len, buffer->Data, cryptoContext->CryptoOptions); + // trim AEAD tag an the end + auto aeadTagEnd = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END; + if (aeadTagEnd) { + OvpnBufferTrim(buffer, len - AEAD_AUTH_TAG_LEN); + } + + // remove crypto overhead in front auto pktId64bit = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; - auto cryptoOverhead = OVPN_DATA_V2_LEN + AEAD_AUTH_TAG_LEN + (pktId64bit ? 8 : 4); - buffer->Len = len - cryptoOverhead; + auto cryptoOverheadFront = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN); + OvpnBufferPull(buffer, cryptoOverheadFront); } } else { @@ -243,24 +252,18 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_ OvpnTimerResetRecv(peer->Timer); - // points to the beginning of plaintext - BOOLEAN pktId64bit = peer->CryptoContext.CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; - BOOLEAN aeadTagEnd = peer->CryptoContext.CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END; - auto payloadOffset = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN); - UCHAR* plaintext = buffer->Data + payloadOffset; - // ping packet? - if (OvpnTimerIsKeepaliveMessage(plaintext, buffer->Len)) { + if (OvpnTimerIsKeepaliveMessage(buffer->Data, buffer->Len)) { LOG_INFO("Ping received"); // no need to inject ping packet into OS, return buffer to the pool OvpnRxBufferPoolPut(buffer); } else { - if (OvpnMssIsIPv4(plaintext, buffer->Len)) { - OvpnMssDoIPv4(plaintext, buffer->Len, device->MSS); - } else if (OvpnMssIsIPv6(plaintext, buffer->Len)) { - OvpnMssDoIPv6(plaintext, buffer->Len, device->MSS); + if (OvpnMssIsIPv4(buffer->Data, buffer->Len)) { + OvpnMssDoIPv4(buffer->Data, buffer->Len, device->MSS); + } else if (OvpnMssIsIPv6(buffer->Data, buffer->Len)) { + OvpnMssDoIPv6(buffer->Data, buffer->Len, device->MSS); } // enqueue plaintext buffer, it will be dequeued by NetAdapter RX datapath diff --git a/timer.cpp b/timer.cpp index 5c15ab9..dc409c9 100644 --- a/timer.cpp +++ b/timer.cpp @@ -69,7 +69,7 @@ static VOID OvpnTimerXmit(WDFTIMER timer) } // copy keepalive magic message to the buffer - RtlCopyMemory(OvpnTxBufferPut(buffer, sizeof(OvpnKeepaliveMessage)), OvpnKeepaliveMessage, sizeof(OvpnKeepaliveMessage)); + RtlCopyMemory(OvpnBufferPut(buffer, sizeof(OvpnKeepaliveMessage)), OvpnKeepaliveMessage, sizeof(OvpnKeepaliveMessage)); OvpnPeerContext* peer = timerCtx->Peer; KIRQL kiqrl = ExAcquireSpinLockShared(&device->SpinLock); diff --git a/txqueue.cpp b/txqueue.cpp index 2ae629f..4f7ba30 100644 --- a/txqueue.cpp +++ b/txqueue.cpp @@ -71,7 +71,7 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET NET_FRAGMENT_VIRTUAL_ADDRESS* virtualAddr = NetExtensionGetFragmentVirtualAddress( &queue->VirtualAddressExtension, NetFragmentIteratorGetIndex(&fi)); - RtlCopyMemory(OvpnTxBufferPut(buffer, fragment->ValidLength), + RtlCopyMemory(OvpnBufferPut(buffer, fragment->ValidLength), (UCHAR const*)virtualAddr->VirtualAddress + fragment->Offset, fragment->ValidLength); NetFragmentIteratorAdvance(&fi); @@ -103,7 +103,7 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET OvpnTxBufferPush(buffer, OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN)); if (aeadTagEnd) { - OvpnTxBufferPut(buffer, AEAD_AUTH_TAG_LEN); + OvpnBufferPut(buffer, AEAD_AUTH_TAG_LEN); } // in-place encrypt, always with primary key