diff --git a/Driver.cpp b/Driver.cpp index f793fb6..a7bc4a1 100644 --- a/Driver.cpp +++ b/Driver.cpp @@ -240,7 +240,7 @@ OvpnEvtIoWrite(WDFQUEUE queue, WDFREQUEST request, size_t length) GOTO_IF_NOT_NT_SUCCESS(error, status, OvpnTxBufferPoolGet(device->TxBufferPool, &txBuf)); // copy data from request to tx buffer - PUCHAR data = OvpnTxBufferPut(txBuf, bufLen); + PUCHAR data = OvpnBufferPut(txBuf, bufLen); RtlCopyMemory(data, buf, bufLen); txBuf->IoQueue = device->PendingWritesQueue; diff --git a/bufferpool.cpp b/bufferpool.cpp index b28932b..d1653ff 100644 --- a/bufferpool.cpp +++ b/bufferpool.cpp @@ -205,6 +205,8 @@ OvpnRxBufferPoolGet(OVPN_RX_BUFFER_POOL handle, OVPN_RX_BUFFER** buffer) if (*buffer == NULL) return STATUS_INSUFFICIENT_RESOURCES; + (*buffer)->Data = (*buffer)->Head; + (*buffer)->Tail = (*buffer)->Data; (*buffer)->Pool = handle; (*buffer)->Len = 0; @@ -270,17 +272,6 @@ OvpnTxBufferPoolDelete(OVPN_BUFFER_POOL handle) OvpnBufferPoolDelete(handle); } -_Use_decl_annotations_ -UCHAR* -OvpnTxBufferPut(OVPN_TX_BUFFER* buffer, SIZE_T len) -{ - UCHAR* tmp = buffer->Tail; - buffer->Tail += len; - buffer->Len += len; - - return tmp; -} - _Use_decl_annotations_ UCHAR* OvpnTxBufferPush(OVPN_TX_BUFFER* buffer, SIZE_T len) @@ -295,7 +286,7 @@ _Use_decl_annotations_ NTSTATUS OvpnRxBufferPoolCreate(OVPN_RX_BUFFER_POOL* handle) { - return OvpnBufferPoolCreate((OVPN_BUFFER_POOL*)handle, sizeof(OVPN_RX_BUFFER), "rx", NULL); + return OvpnBufferPoolCreate((OVPN_BUFFER_POOL*)handle, sizeof(OVPN_RX_BUFFER) + OVPN_SOCKET_RX_PACKET_BUFFER_SIZE, "rx", NULL); } VOID diff --git a/bufferpool.h b/bufferpool.h index 1ed0721..fa742af 100644 --- a/bufferpool.h +++ b/bufferpool.h @@ -63,19 +63,51 @@ struct OVPN_TX_BUFFER struct OVPN_RX_BUFFER { - LIST_ENTRY PoolListEntry; + // points to the beginning of data + PUCHAR Data; - LIST_ENTRY QueueListEntry; + // points to the end of data + PUCHAR Tail; + // data length SIZE_T Len; + LIST_ENTRY PoolListEntry; + + LIST_ENTRY QueueListEntry; + OVPN_RX_BUFFER_POOL Pool; - UCHAR Data[OVPN_SOCKET_RX_PACKET_BUFFER_SIZE]; + #pragma warning(suppress:4200) //nonstandard extension used: zero-sized array in struct/union + UCHAR Head[]; }; +template UCHAR* -OvpnTxBufferPut(_In_ OVPN_TX_BUFFER* work, SIZE_T len); +OvpnBufferPut(_In_ BUF* buf, SIZE_T len) +{ + UCHAR* tmp = buf->Tail; + buf->Tail += len; + buf->Len += len; + + return tmp; +} + +template +VOID +OvpnBufferTrim(_In_ BUF* buf, SIZE_T len) +{ + buf->Len = len; + buf->Tail = buf->Data + len; +} + +template +VOID +OvpnBufferPull(_In_ BUF* buf, SIZE_T len) +{ + buf->Len -= len; + buf->Data += len; +} UCHAR* OvpnTxBufferPush(_In_ OVPN_TX_BUFFER* work, SIZE_T len); diff --git a/rxqueue.cpp b/rxqueue.cpp index d7687c3..2a85ffd 100644 --- a/rxqueue.cpp +++ b/rxqueue.cpp @@ -102,16 +102,6 @@ OvpnEvtRxQueueAdvance(NETPACKETQUEUE netPacketQueue) POVPN_RXQUEUE queue = OvpnGetRxQueueContext(netPacketQueue); OVPN_DEVICE* device = OvpnGetDeviceContext(queue->Adapter->WdfDevice); - OvpnPeerContext* peer = OvpnGetFirstPeer(&device->Peers); - if (peer == NULL) { - LOG_WARN("No peer"); - return; - } - - BOOLEAN pktId64bit = peer->CryptoContext.CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; - BOOLEAN aeadTagEnd = peer->CryptoContext.CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END; - auto payloadOffset = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN); - NET_RING_FRAGMENT_ITERATOR fi = NetRingGetAllFragments(queue->Rings); NET_RING_PACKET_ITERATOR pi = NetRingGetAllPackets(queue->Rings); while (NetFragmentIteratorHasAny(&fi)) { @@ -126,7 +116,7 @@ OvpnEvtRxQueueAdvance(NETPACKETQUEUE netPacketQueue) fragment->ValidLength = buffer->Len; fragment->Offset = 0; NET_FRAGMENT_VIRTUAL_ADDRESS* virtualAddr = NetExtensionGetFragmentVirtualAddress(&queue->VirtualAddressExtension, NetFragmentIteratorGetIndex(&fi)); - RtlCopyMemory(virtualAddr->VirtualAddress, buffer->Data + payloadOffset, buffer->Len); + RtlCopyMemory(virtualAddr->VirtualAddress, buffer->Data, buffer->Len); InterlockedExchangeAddNoFence64(&device->Stats.TunBytesReceived, buffer->Len); diff --git a/socket.cpp b/socket.cpp index bebf26e..533c0c2 100644 --- a/socket.cpp +++ b/socket.cpp @@ -134,15 +134,14 @@ OvpnSocketControlPacketReceived(_In_ POVPN_DEVICE device, _In_reads_(len) PUCHAR return; } - if (sizeof(buffer->Data) >= totalLen) { + if (totalLen <= OVPN_SOCKET_RX_PACKET_BUFFER_SIZE) { if (hdrLen > 0) { // prepend with sockaddr - RtlCopyMemory(buffer->Data, remote, hdrLen); + RtlCopyMemory(OvpnBufferPut(buffer, hdrLen), remote, hdrLen); } // copy control packet payload - RtlCopyMemory(buffer->Data + hdrLen, buf, totalLen - hdrLen); - buffer->Len = totalLen; + RtlCopyMemory(OvpnBufferPut(buffer, len), buf, len); // enqueue buffer, it will be dequeued when read request arrives OvpnBufferQueueEnqueue(device->ControlRxBufferQueue, &buffer->QueueListEntry); @@ -222,12 +221,22 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_ LOG_ERROR("keyId not found", TraceLoggingValue(keyId, "keyId")); } else { + // extend data area in the buffer for plaintext and crypto overhead + OvpnBufferPut(buffer, len); + // decrypt into plaintext buffer status = cryptoContext->Decrypt(keySlot, cipherTextBuf, len, buffer->Data, cryptoContext->CryptoOptions); + // trim AEAD tag an the end + auto aeadTagEnd = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END; + if (aeadTagEnd) { + OvpnBufferTrim(buffer, len - AEAD_AUTH_TAG_LEN); + } + + // remove crypto overhead in front auto pktId64bit = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; - auto cryptoOverhead = OVPN_DATA_V2_LEN + AEAD_AUTH_TAG_LEN + (pktId64bit ? 8 : 4); - buffer->Len = len - cryptoOverhead; + auto cryptoOverheadFront = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN); + OvpnBufferPull(buffer, cryptoOverheadFront); } } else { @@ -243,24 +252,18 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_ OvpnTimerResetRecv(peer->Timer); - // points to the beginning of plaintext - BOOLEAN pktId64bit = peer->CryptoContext.CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; - BOOLEAN aeadTagEnd = peer->CryptoContext.CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END; - auto payloadOffset = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN); - UCHAR* plaintext = buffer->Data + payloadOffset; - // ping packet? - if (OvpnTimerIsKeepaliveMessage(plaintext, buffer->Len)) { + if (OvpnTimerIsKeepaliveMessage(buffer->Data, buffer->Len)) { LOG_INFO("Ping received"); // no need to inject ping packet into OS, return buffer to the pool OvpnRxBufferPoolPut(buffer); } else { - if (OvpnMssIsIPv4(plaintext, buffer->Len)) { - OvpnMssDoIPv4(plaintext, buffer->Len, device->MSS); - } else if (OvpnMssIsIPv6(plaintext, buffer->Len)) { - OvpnMssDoIPv6(plaintext, buffer->Len, device->MSS); + if (OvpnMssIsIPv4(buffer->Data, buffer->Len)) { + OvpnMssDoIPv4(buffer->Data, buffer->Len, device->MSS); + } else if (OvpnMssIsIPv6(buffer->Data, buffer->Len)) { + OvpnMssDoIPv6(buffer->Data, buffer->Len, device->MSS); } // enqueue plaintext buffer, it will be dequeued by NetAdapter RX datapath diff --git a/timer.cpp b/timer.cpp index 5c15ab9..dc409c9 100644 --- a/timer.cpp +++ b/timer.cpp @@ -69,7 +69,7 @@ static VOID OvpnTimerXmit(WDFTIMER timer) } // copy keepalive magic message to the buffer - RtlCopyMemory(OvpnTxBufferPut(buffer, sizeof(OvpnKeepaliveMessage)), OvpnKeepaliveMessage, sizeof(OvpnKeepaliveMessage)); + RtlCopyMemory(OvpnBufferPut(buffer, sizeof(OvpnKeepaliveMessage)), OvpnKeepaliveMessage, sizeof(OvpnKeepaliveMessage)); OvpnPeerContext* peer = timerCtx->Peer; KIRQL kiqrl = ExAcquireSpinLockShared(&device->SpinLock); diff --git a/txqueue.cpp b/txqueue.cpp index 2ae629f..4f7ba30 100644 --- a/txqueue.cpp +++ b/txqueue.cpp @@ -71,7 +71,7 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET NET_FRAGMENT_VIRTUAL_ADDRESS* virtualAddr = NetExtensionGetFragmentVirtualAddress( &queue->VirtualAddressExtension, NetFragmentIteratorGetIndex(&fi)); - RtlCopyMemory(OvpnTxBufferPut(buffer, fragment->ValidLength), + RtlCopyMemory(OvpnBufferPut(buffer, fragment->ValidLength), (UCHAR const*)virtualAddr->VirtualAddress + fragment->Offset, fragment->ValidLength); NetFragmentIteratorAdvance(&fi); @@ -103,7 +103,7 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET OvpnTxBufferPush(buffer, OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN)); if (aeadTagEnd) { - OvpnTxBufferPut(buffer, AEAD_AUTH_TAG_LEN); + OvpnBufferPut(buffer, AEAD_AUTH_TAG_LEN); } // in-place encrypt, always with primary key