From af5864ee8efc3352281a444630a43e4b143feca1 Mon Sep 17 00:00:00 2001 From: Lev Stipakov Date: Tue, 27 Aug 2024 10:58:58 +0300 Subject: [PATCH 01/28] timer: refactor timers implementation The current implementation uses "relative" WDF timers which are "not ticked" at low power states and on resume they continue to where they were left of. This makes keepalive timeout detection sub-optimal, since in worst case a client has to wait for "ping-restart" seconds to reconnect, which could be several minutes. Refactor timers in a way that we only have single timer ticking every second. At that tick we compare "last" and "now" timestamps and do actions, similar to what openvpn2 is doing. Fixes #64 Backported from b750b2c ("timer: refactor timers implementation") and ed455c4 ("Fix bug check in timer management routines") Signed-off-by: Lev Stipakov --- PropertySheet.props | 2 +- peer.cpp | 27 +++--------- peer.h | 7 +-- socket.cpp | 2 +- timer.cpp | 104 ++++++++++++++++++++++++++++++-------------- timer.h | 15 ++++--- txqueue.cpp | 15 +++++-- 7 files changed, 103 insertions(+), 69 deletions(-) diff --git a/PropertySheet.props b/PropertySheet.props index f63952e..6bcce73 100644 --- a/PropertySheet.props +++ b/PropertySheet.props @@ -3,7 +3,7 @@ 2 - 0 + 2 1 diff --git a/peer.cpp b/peer.cpp index a7cdd07..3314fee 100644 --- a/peer.cpp +++ b/peer.cpp @@ -43,8 +43,7 @@ VOID OvpnPeerCtxFree(OvpnPeerContext* peer) { OvpnCryptoUninit(&peer->CryptoContext); - OvpnTimerDestroy(&peer->KeepaliveXmitTimer); - OvpnTimerDestroy(&peer->KeepaliveRecvTimer); + OvpnTimerDestroy(&peer->Timer); ExFreePoolWithTag(peer, 'ovpn'); } @@ -147,6 +146,8 @@ OvpnPeerNew(POVPN_DEVICE device, WDFREQUEST request) OvpnPeerZeroStats(&device->Stats); + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnTimerCreate(device->WdfDevice, peerCtx, &peerCtx->Timer)); + if (proto_tcp) { LOG_IF_NOT_NT_SUCCESS(status = WdfRequestForwardToIoQueue(request, device->PendingNewPeerQueue)); // start async connect @@ -229,29 +230,15 @@ NTSTATUS OvpnPeerSet(POVPN_DEVICE device, WDFREQUEST request) if (set_peer->KeepaliveInterval != -1) { peer->KeepaliveInterval = set_peer->KeepaliveInterval; - if (peer->KeepaliveInterval > 0) { - // keepalive xmit timer, sends ping packets - GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnTimerXmitCreate(device->WdfDevice, peer, peer->KeepaliveInterval, &peer->KeepaliveXmitTimer)); - OvpnTimerReset(peer->KeepaliveXmitTimer, peer->KeepaliveInterval); - } - else { - LOG_INFO("Destroy xmit timer"); - OvpnTimerDestroy(&peer->KeepaliveXmitTimer); - } + // keepalive xmit timer, sends ping packets + OvpnTimerSetXmitInterval(peer->Timer, peer->KeepaliveInterval); } if (peer->KeepaliveTimeout != -1) { peer->KeepaliveTimeout = set_peer->KeepaliveTimeout; - if (peer->KeepaliveTimeout > 0) { - // keepalive recv timer, detects keepalive timeout - GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnTimerRecvCreate(device->WdfDevice, peer, &peer->KeepaliveRecvTimer)); - OvpnTimerReset(peer->KeepaliveRecvTimer, peer->KeepaliveTimeout); - } - else { - LOG_INFO("Destroy recv timer"); - OvpnTimerDestroy(&peer->KeepaliveRecvTimer); - } + // keepalive recv timer, detects keepalive timeout + OvpnTimerSetRecvTimeout(peer->Timer, peer->KeepaliveTimeout); } done: diff --git a/peer.h b/peer.h index 4495171..e1eb5d4 100644 --- a/peer.h +++ b/peer.h @@ -39,11 +39,8 @@ struct OvpnPeerContext // keepalive timeout in seconds LONG KeepaliveTimeout; - // timer used to send periodic ping messages to the server if no data has been sent within the past KeepaliveInterval seconds - WDFTIMER KeepaliveXmitTimer; - - // timer used to report keepalive timeout error to userspace when no data has been received for KeepaliveTimeout seconds - WDFTIMER KeepaliveRecvTimer; + // 1-sec timer which handles ping intervals and keepalive timeouts + WDFTIMER Timer; }; _Must_inspect_result_ diff --git a/socket.cpp b/socket.cpp index ef57e37..1b3ca42 100644 --- a/socket.cpp +++ b/socket.cpp @@ -203,7 +203,7 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_ return; } - OvpnTimerReset(peer->KeepaliveRecvTimer, peer->KeepaliveTimeout); + OvpnTimerResetRecv(peer->Timer); // points to the beginning of plaintext UCHAR* buf = buffer->Data + device->CryptoOverhead; diff --git a/timer.cpp b/timer.cpp index bac0e9b..6ce4c95 100644 --- a/timer.cpp +++ b/timer.cpp @@ -37,6 +37,13 @@ static const UCHAR OvpnKeepaliveMessage[] = { // Context added to a timer's attributes typedef struct _OVPN_PEER_TIMER_CONTEXT { OvpnPeerContext* Peer; + + LARGE_INTEGER lastXmit; + LARGE_INTEGER lastRecv; + + // 0 means "not set" + LONG recvTimeout; + LONG xmitInterval; } OVPN_PEER_TIMER_CONTEXT, * POVPN_PEER_TIMER_CONTEXT; WDF_DECLARE_CONTEXT_TYPE_WITH_NAME(OVPN_PEER_TIMER_CONTEXT, OvpnGetPeerTimerContext); @@ -47,11 +54,8 @@ BOOLEAN OvpnTimerIsKeepaliveMessage(const PUCHAR buf, SIZE_T len) return RtlCompareMemory(buf, OvpnKeepaliveMessage, len) == sizeof(OvpnKeepaliveMessage); } -_Function_class_(EVT_WDF_TIMER) static VOID OvpnTimerXmit(WDFTIMER timer) { - LOG_ENTER(); - POVPN_DEVICE device = OvpnGetDeviceContext(WdfTimerGetParentObject(timer)); POVPN_PEER_TIMER_CONTEXT timerCtx = OvpnGetPeerTimerContext(timer); OVPN_TX_BUFFER* buffer; @@ -92,30 +96,24 @@ static VOID OvpnTimerXmit(WDFTIMER timer) OvpnTxBufferPoolPut(buffer); } ExReleaseSpinLockShared(&device->SpinLock, kiqrl); - - LOG_EXIT(); } -_Function_class_(EVT_WDF_TIMER) -static VOID OvpnTimerRecv(WDFTIMER timer) +static BOOLEAN OvpnTimerRecv(WDFTIMER timer) { - LOG_ENTER(); - - LOG_WARN("Keepalive timeout"); - POVPN_DEVICE device = OvpnGetDeviceContext(WdfTimerGetParentObject(timer)); WDFREQUEST request; NTSTATUS status = WdfIoQueueRetrieveNextRequest(device->PendingReadsQueue, &request); if (!NT_SUCCESS(status)) { LOG_WARN("No pending request for keepalive timeout notification"); + return FALSE; } else { + LOG_INFO("Notify userspace about keepalive timeout"); ULONG_PTR bytesSent = 0; WdfRequestCompleteWithInformation(request, STATUS_CONNECTION_DISCONNECTED, bytesSent); + return TRUE; } - - LOG_EXIT(); } _Use_decl_annotations_ @@ -129,7 +127,31 @@ VOID OvpnTimerDestroy(WDFTIMER* timer) } } -static NTSTATUS OvpnTimerCreate(WDFOBJECT parent, OvpnPeerContext* peer, ULONG period, PFN_WDF_TIMER func, _Inout_ WDFTIMER* timer) +_Function_class_(EVT_WDF_TIMER) +static VOID OvpnTimerTick(WDFTIMER timer) +{ + LARGE_INTEGER now; + KeQuerySystemTime(&now); + + POVPN_PEER_TIMER_CONTEXT timerCtx = OvpnGetPeerTimerContext(timer); + if ((timerCtx->xmitInterval > 0) && (((now.QuadPart - timerCtx->lastXmit.QuadPart) / WDF_TIMEOUT_TO_SEC) > timerCtx->xmitInterval)) + { + OvpnTimerXmit(timer); + timerCtx->lastXmit = now; + } + + if ((timerCtx->recvTimeout > 0) && (((now.QuadPart - timerCtx->lastRecv.QuadPart) / WDF_TIMEOUT_TO_SEC) > timerCtx->recvTimeout)) + { + // have we have completed pending read request? + if (OvpnTimerRecv(timer)) + { + timerCtx->recvTimeout = 0; // one-off timer + } + } +} + +_Use_decl_annotations_ +NTSTATUS OvpnTimerCreate(WDFOBJECT parent, OvpnPeerContext* peer, _Inout_ WDFTIMER* timer) { LOG_ENTER(); @@ -141,8 +163,9 @@ static NTSTATUS OvpnTimerCreate(WDFOBJECT parent, OvpnPeerContext* peer, ULONG p } WDF_TIMER_CONFIG timerConfig; - WDF_TIMER_CONFIG_INIT(&timerConfig, func); - timerConfig.Period = period * 1000; + WDF_TIMER_CONFIG_INIT(&timerConfig, OvpnTimerTick); + timerConfig.TolerableDelay = TolerableDelayUnlimited; + timerConfig.Period = 1000; WDF_OBJECT_ATTRIBUTES timerAttributes; WDF_OBJECT_ATTRIBUTES_INIT(&timerAttributes); @@ -155,36 +178,51 @@ static NTSTATUS OvpnTimerCreate(WDFOBJECT parent, OvpnPeerContext* peer, ULONG p if (NT_SUCCESS(status)) { POVPN_PEER_TIMER_CONTEXT pTimerContext = OvpnGetPeerTimerContext(*timer); pTimerContext->Peer = peer; + WdfTimerStart(*timer, WDF_REL_TIMEOUT_IN_SEC(1)); } LOG_EXIT(); return status; } -_Use_decl_annotations_ -NTSTATUS OvpnTimerXmitCreate(WDFOBJECT parent, OvpnPeerContext* peer, ULONG period, WDFTIMER* timer) +#define CHECK_TIMER_HANDLE(timer) \ + do { \ + if ((timer) == WDF_NO_HANDLE) { \ + LOG_ERROR("Timer handle is not initialized"); \ + return; \ + } \ + } while (0) + +VOID OvpnTimerSetXmitInterval(WDFTIMER timer, LONG xmitInterval) { - NTSTATUS status; - LOG_INFO("Create xmit timer", TraceLoggingValue(period, "period")); - LOG_IF_NOT_NT_SUCCESS(status = OvpnTimerCreate(parent, peer, period, OvpnTimerXmit, timer)); + CHECK_TIMER_HANDLE(timer); - return status; + POVPN_PEER_TIMER_CONTEXT timerCtx = OvpnGetPeerTimerContext(timer); + timerCtx->xmitInterval = xmitInterval; + KeQuerySystemTime(&timerCtx->lastXmit); } -_Use_decl_annotations_ -NTSTATUS OvpnTimerRecvCreate(WDFOBJECT parent, OvpnPeerContext* peer, WDFTIMER* timer) +VOID OvpnTimerSetRecvTimeout(WDFTIMER timer, LONG recvTimeout) { - NTSTATUS status; - LOG_INFO("Create recv timer"); - LOG_IF_NOT_NT_SUCCESS(status = OvpnTimerCreate(parent, peer, 0, OvpnTimerRecv, timer)); + CHECK_TIMER_HANDLE(timer); - return status; + POVPN_PEER_TIMER_CONTEXT timerCtx = OvpnGetPeerTimerContext(timer); + timerCtx->recvTimeout = recvTimeout; + KeQuerySystemTime(&timerCtx->lastRecv); } -VOID OvpnTimerReset(WDFTIMER timer, ULONG dueTime) +VOID OvpnTimerResetXmit(WDFTIMER timer) { - if (timer != WDF_NO_HANDLE) { - // if timer has already been created this will reset "due time" value to the new one - WdfTimerStart(timer, WDF_REL_TIMEOUT_IN_SEC(dueTime)); - } + CHECK_TIMER_HANDLE(timer); + + POVPN_PEER_TIMER_CONTEXT timerCtx = OvpnGetPeerTimerContext(timer); + KeQuerySystemTime(&timerCtx->lastXmit); +} + +VOID OvpnTimerResetRecv(WDFTIMER timer) +{ + CHECK_TIMER_HANDLE(timer); + + POVPN_PEER_TIMER_CONTEXT timerCtx = OvpnGetPeerTimerContext(timer); + KeQuerySystemTime(&timerCtx->lastRecv); } diff --git a/timer.h b/timer.h index b9f8bfc..88ee019 100644 --- a/timer.h +++ b/timer.h @@ -26,15 +26,20 @@ #include VOID -OvpnTimerReset(WDFTIMER timer, ULONG dueTime); +OvpnTimerResetXmit(WDFTIMER timer); -_Must_inspect_result_ -NTSTATUS -OvpnTimerXmitCreate(WDFOBJECT parent, OvpnPeerContext* peer, ULONG period, _Inout_ WDFTIMER* timer); +VOID +OvpnTimerResetRecv(WDFTIMER timer); _Must_inspect_result_ NTSTATUS -OvpnTimerRecvCreate(WDFOBJECT parent, OvpnPeerContext* peer, _Inout_ WDFTIMER* timer); +OvpnTimerCreate(WDFOBJECT parent, OvpnPeerContext* peer, _Inout_ WDFTIMER* timer); + +VOID +OvpnTimerSetXmitInterval(WDFTIMER timer, LONG xmitInterval); + +VOID +OvpnTimerSetRecvTimeout(WDFTIMER timer, LONG recvTimeout); VOID OvpnTimerDestroy(_Inout_ WDFTIMER* timer); diff --git a/txqueue.cpp b/txqueue.cpp index a93a470..7f49038 100644 --- a/txqueue.cpp +++ b/txqueue.cpp @@ -128,7 +128,7 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET *tail = buffer; } - OvpnTimerReset(peer->KeepaliveXmitTimer, peer->KeepaliveInterval); + OvpnTimerResetXmit(peer->Timer); } else { OvpnTxBufferPoolPut(buffer); @@ -179,9 +179,16 @@ OvpnEvtTxQueueAdvance(NETPACKETQUEUE netPacketQueue) } NetPacketIteratorSet(&pi); - if (packetSent && !device->Socket.Tcp) { - // this will use WskSendMessages to send buffers list which we constructed before - LOG_IF_NOT_NT_SUCCESS(OvpnSocketSend(&device->Socket, txBufferHead)); + if (packetSent) { + OvpnPeerContext* peer = OvpnGetFirstPeer(&device->Peers); + if (peer != NULL) { + OvpnTimerResetXmit(peer->Timer); + + if (!device->Socket.Tcp) { + // this will use WskSendMessages to send buffers list which we constructed before + LOG_IF_NOT_NT_SUCCESS(OvpnSocketSend(&device->Socket, txBufferHead)); + } + } } ExReleaseSpinLockShared(&device->SpinLock, kirql); From d31f0eef4eef72c28c25844def37bbd93f77c623 Mon Sep 17 00:00:00 2001 From: Lev Stipakov Date: Fri, 9 Aug 2024 13:23:40 +0300 Subject: [PATCH 02/28] More robust way to get a version info The current way to get a version information is to open DCO device and make IOCTL call. This has a few issues: - If DCO device is already in use, an another app won't be to get the version, since the device is exclusive - With the multiple DCO devices there is a high chance that \\.\ovpn-dco device, which we use to get version information, is already in use. To open another device, we use via device interface enumeration, which requires a lot of boilerplate code to work. To make it easier for userspace to get the device version, create a non-exclusive control device \\.\ovpn-dco-ver which supports single IOCTL to get the version number. This device is created when the first network device is created and removed with the last network device. Bump version to 2.3.0. https://github.com/OpenVPN/ovpn-dco-win/issues/75 Backported from f4adb2 ("More robust way to get a version info") Signed-off-by: Lev Stipakov --- Driver.cpp | 58 ++++++++++------ Driver.h | 2 + PropertySheet.props | 6 +- control.cpp | 128 +++++++++++++++++++++++++++++++++++ control.h | 31 +++++++++ ovpn-dco-win.vcxproj | 34 +++++----- ovpn-dco-win.vcxproj.filters | 6 ++ 7 files changed, 225 insertions(+), 40 deletions(-) create mode 100644 control.cpp create mode 100644 control.h diff --git a/Driver.cpp b/Driver.cpp index 6175809..4c3ceec 100644 --- a/Driver.cpp +++ b/Driver.cpp @@ -27,6 +27,7 @@ #include #include "bufferpool.h" +#include "control.h" #include "driver.h" #include "trace.h" #include "peer.h" @@ -50,6 +51,9 @@ OvpnEvtDriverUnload(_In_ WDFDRIVER driver) { UNREFERENCED_PARAMETER(driver); + LOG_ENTER(); + LOG_EXIT(); + TraceLoggingUnregister(g_hOvpnEtwProvider); // tail call optimization incorrectly eliminates TraceLoggingUnregister() call @@ -57,6 +61,17 @@ OvpnEvtDriverUnload(_In_ WDFDRIVER driver) __nop(); } +EVT_WDF_OBJECT_CONTEXT_CLEANUP OvpnEvtDriverCleanup; + +_Use_decl_annotations_ +VOID OvpnEvtDriverCleanup(_In_ WDFOBJECT driver) +{ + UNREFERENCED_PARAMETER(driver); + + LOG_ENTER(); + LOG_EXIT(); +} + EXTERN_C DRIVER_INITIALIZE DriverEntry; #ifdef ALLOC_PRAGMA @@ -81,6 +96,7 @@ DriverEntry(_In_ PDRIVER_OBJECT driverObject, _In_ PUNICODE_STRING registryPath) WDF_OBJECT_ATTRIBUTES driverAttrs; WDF_OBJECT_ATTRIBUTES_INIT(&driverAttrs); WDF_OBJECT_ATTRIBUTES_SET_CONTEXT_TYPE(&driverAttrs, OVPN_DRIVER); + driverAttrs.EvtCleanupCallback = OvpnEvtDriverCleanup; WDF_DRIVER_CONFIG driverConfig; WDF_DRIVER_CONFIG_INIT(&driverConfig, OvpnEvtDeviceAdd); @@ -213,26 +229,6 @@ OvpnEvtIoWrite(WDFQUEUE queue, WDFREQUEST request, size_t length) ExReleaseSpinLockShared(&device->SpinLock, kiqrl); } -NTSTATUS -OvpnGetVersion(WDFREQUEST request, _Out_ ULONG_PTR* bytesReturned) -{ - *bytesReturned = 0; - - NTSTATUS status; - POVPN_VERSION version = NULL; - GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveOutputBuffer(request, sizeof(OVPN_VERSION), (PVOID*)&version, NULL)); - - version->Major = OVPN_DCO_VERSION_MAJOR; - version->Minor = OVPN_DCO_VERSION_MINOR; - version->Patch = OVPN_DCO_VERSION_PATCH; - - *bytesReturned = sizeof(OVPN_VERSION); - -done: - return status; -} - - EVT_WDF_IO_QUEUE_IO_DEVICE_CONTROL OvpnEvtIoDeviceControl; _Use_decl_annotations_ @@ -339,6 +335,15 @@ VOID OvpnEvtDeviceCleanup(WDFOBJECT obj) { // it requires PASSIVE_LEVEL. OvpnCryptoUninitAlgHandles(device->AesAlgHandle, device->ChachaAlgHandle); + // delete control device if there are no devices left + POVPN_DRIVER driverCtx = OvpnGetDriverContext(WdfGetDriver()); + LONG deviceCount = InterlockedDecrement(&driverCtx->DeviceCount); + LOG_INFO("Device count", TraceLoggingValue(deviceCount, "deviceCount")); + if ((deviceCount == 0) && (driverCtx->ControlDevice != NULL)) { + LOG_INFO("Delete control device"); + WdfObjectDelete(driverCtx->ControlDevice); + driverCtx->ControlDevice = NULL; + } LOG_EXIT(); } @@ -436,6 +441,17 @@ OvpnEvtDeviceAdd(WDFDRIVER wdfDriver, PWDFDEVICE_INIT deviceInit) { WDFDEVICE wdfDevice; GOTO_IF_NOT_NT_SUCCESS(done, status, WdfDeviceCreate(&deviceInit, &objAttributes, &wdfDevice)); + POVPN_DRIVER driverCtx = OvpnGetDriverContext(WdfGetDriver()); + InterlockedIncrement(&driverCtx->DeviceCount); + + LOG_INFO("Device count", TraceLoggingValue(driverCtx->DeviceCount, "count")); + + if (driverCtx->DeviceCount == 1) + { + // create non-exclusive control device to get the version information + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCreateControlDevice(wdfDriver)); + } + // this will fail if one device has already been created but that's ok, since // openvpn2/3 accesses devices via Device Interface GUID, and symlink is used only by test client. LOG_IF_NOT_NT_SUCCESS(WdfDeviceCreateSymbolicLink(wdfDevice, &symLink)); @@ -530,4 +546,4 @@ OvpnGetFirstPeer(RTL_GENERIC_TABLE* peers) { OvpnPeerContext** ptr = (OvpnPeerContext**)RtlGetElementGenericTable(peers, 0); return ptr ? (OvpnPeerContext*)*ptr : NULL; -} \ No newline at end of file +} diff --git a/Driver.h b/Driver.h index f33fc0e..01df6c9 100644 --- a/Driver.h +++ b/Driver.h @@ -47,6 +47,8 @@ EVT_WDF_IO_QUEUE_IO_DEVICE_CONTROL OvpnEvtIoDeviceControl; typedef struct _OVPN_DRIVER { WSK_PROVIDER_NPI WskProviderNpi; WSK_REGISTRATION WskRegistration; + WDFDEVICE ControlDevice; + LONG DeviceCount; } OVPN_DRIVER, * POVPN_DRIVER; WDF_DECLARE_CONTEXT_TYPE_WITH_NAME(OVPN_DRIVER, OvpnGetDriverContext) diff --git a/PropertySheet.props b/PropertySheet.props index 6bcce73..61bdfe9 100644 --- a/PropertySheet.props +++ b/PropertySheet.props @@ -3,8 +3,8 @@ 2 - 2 - 1 + 3 + 0 @@ -29,4 +29,4 @@ true - \ No newline at end of file + diff --git a/control.cpp b/control.cpp new file mode 100644 index 0000000..32f30a5 --- /dev/null +++ b/control.cpp @@ -0,0 +1,128 @@ +/* + * ovpn-dco-win OpenVPN protocol accelerator for Windows + * + * Copyright (C) 2024- OpenVPN Inc + * + * Author: Lev Stipakov + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 2 + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License along + * with this program; if not, write to the Free Software Foundation, Inc., + * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#include "control.h" +#include "Driver.h" +#include "uapi\ovpn-dco.h" +#include "trace.h" + +_Use_decl_annotations_ +NTSTATUS +OvpnGetVersion(WDFREQUEST request, ULONG_PTR* bytesReturned) +{ + LOG_ENTER(); + + *bytesReturned = 0; + + NTSTATUS status; + POVPN_VERSION version = NULL; + GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveOutputBuffer(request, sizeof(OVPN_VERSION), (PVOID*)&version, NULL)); + + version->Major = OVPN_DCO_VERSION_MAJOR; + version->Minor = OVPN_DCO_VERSION_MINOR; + version->Patch = OVPN_DCO_VERSION_PATCH; + + LOG_INFO("Version", TraceLoggingValue(version->Major, "Major"), TraceLoggingValue(version->Minor, "Minor"), TraceLoggingValue(version->Patch, "Patch")); + + *bytesReturned = sizeof(OVPN_VERSION); + +done: + LOG_EXIT(); + + return status; +} + +EVT_WDF_IO_QUEUE_IO_DEVICE_CONTROL OvpnEvtControlDeviceIOControl; + +VOID +OvpnEvtControlDeviceIOControl(WDFQUEUE queue, WDFREQUEST request, size_t outputBufferLength, size_t inputBufferLength, ULONG ioControlCode) +{ + UNREFERENCED_PARAMETER(queue); + UNREFERENCED_PARAMETER(inputBufferLength); + UNREFERENCED_PARAMETER(outputBufferLength); + + NTSTATUS status = STATUS_SUCCESS; + ULONG_PTR bytesReturned = 0; + + switch (ioControlCode) + { + case OVPN_IOCTL_GET_VERSION: + status = OvpnGetVersion(request, &bytesReturned); + break; + + default: + status = STATUS_INVALID_DEVICE_REQUEST; + break; + } + + WdfRequestCompleteWithInformation(request, status, bytesReturned); +} + +NTSTATUS +OvpnCreateControlDevice(WDFDRIVER wdfDriver) +{ + LOG_ENTER(); + + DECLARE_CONST_UNICODE_STRING(symLink, L"\\DosDevices\\ovpn-dco-ver"); // this will be used by CreateFile + DECLARE_CONST_UNICODE_STRING(deviceName, L"\\Device\\ovpn-dco-ver"); // this is required tp create symlink + + // allocate control device initialization structure + PWDFDEVICE_INIT deviceInit = WdfControlDeviceInitAllocate(wdfDriver, &SDDL_DEVOBJ_SYS_ALL_ADM_RWX_WORLD_R_RES_R); + if (deviceInit == NULL) + { + return STATUS_INSUFFICIENT_RESOURCES; + } + + // create the control device + WDF_OBJECT_ATTRIBUTES deviceAttributes; + WDF_OBJECT_ATTRIBUTES_INIT(&deviceAttributes); + WDFDEVICE controlDevice; + NTSTATUS status; + + GOTO_IF_NOT_NT_SUCCESS(done, status, WdfDeviceInitAssignName(deviceInit, &deviceName)); + GOTO_IF_NOT_NT_SUCCESS(done, status, WdfDeviceCreate(&deviceInit, &deviceAttributes, &controlDevice)); + + POVPN_DRIVER driverCtx = OvpnGetDriverContext(WdfGetDriver()); + driverCtx->ControlDevice = controlDevice; + + // symlink for control device + GOTO_IF_NOT_NT_SUCCESS(done, status, WdfDeviceCreateSymbolicLink(controlDevice, &symLink)); + + // queue to handle IO + WDF_IO_QUEUE_CONFIG queueConfig; + WDF_IO_QUEUE_CONFIG_INIT_DEFAULT_QUEUE(&queueConfig, WdfIoQueueDispatchParallel); + queueConfig.EvtIoDeviceControl = OvpnEvtControlDeviceIOControl; + WDFQUEUE queue; + GOTO_IF_NOT_NT_SUCCESS(done, status, WdfIoQueueCreate(controlDevice, &queueConfig, WDF_NO_OBJECT_ATTRIBUTES, &queue)); + + // Complete the control device initialization + WdfControlFinishInitializing(controlDevice); + + done: + if (deviceInit) + { + WdfDeviceInitFree(deviceInit); + } + + LOG_EXIT(); + + return status; +} diff --git a/control.h b/control.h new file mode 100644 index 0000000..8827ed2 --- /dev/null +++ b/control.h @@ -0,0 +1,31 @@ +/* + * ovpn-dco-win OpenVPN protocol accelerator for Windows + * + * Copyright (C) 2024- OpenVPN Inc + * + * Author: Lev Stipakov + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 2 + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License along + * with this program; if not, write to the Free Software Foundation, Inc., + * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#pragma once + +#include +#include + +NTSTATUS +OvpnGetVersion(WDFREQUEST request, _Out_ ULONG_PTR* bytesReturned); + +NTSTATUS +OvpnCreateControlDevice(WDFDRIVER wdfDriver); diff --git a/ovpn-dco-win.vcxproj b/ovpn-dco-win.vcxproj index 3fc5a88..d3000a6 100644 --- a/ovpn-dco-win.vcxproj +++ b/ovpn-dco-win.vcxproj @@ -69,6 +69,7 @@ + @@ -82,6 +83,7 @@ + @@ -440,7 +442,7 @@ OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) - uuid.lib;Netio.lib;cng.lib;%(AdditionalDependencies) + uuid.lib;Netio.lib;cng.lib;wdmsec.lib%(AdditionalDependencies) @@ -456,7 +458,7 @@ OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) - uuid.lib;Netio.lib;cng.lib;%(AdditionalDependencies) + uuid.lib;Netio.lib;cng.lib;wdmsec.lib%(AdditionalDependencies) @@ -478,7 +480,7 @@ $(ProjectDir)$(Platform)\$(ConfigurationName)\ovpn-dco.DVL.XML - uuid.lib;Netio.lib;cng.lib;%(AdditionalDependencies) + uuid.lib;Netio.lib;cng.lib;wdmsec.lib%(AdditionalDependencies) /Brepro %(AdditionalOptions) DebugFull false @@ -505,7 +507,7 @@ $(ProjectDir)$(Platform)\$(ConfigurationName)\ovpn-dco.DVL.XML - uuid.lib;Netio.lib;cng.lib;%(AdditionalDependencies) + uuid.lib;Netio.lib;cng.lib;wdmsec.lib%(AdditionalDependencies) /Brepro %(AdditionalOptions) DebugFull false @@ -526,7 +528,7 @@ OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) - uuid.lib;Netio.lib;cng.lib;%(AdditionalDependencies) + uuid.lib;Netio.lib;cng.lib;wdmsec.lib%(AdditionalDependencies) @@ -550,7 +552,7 @@ OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) - uuid.lib;Netio.lib;cng.lib;%(AdditionalDependencies) + uuid.lib;Netio.lib;cng.lib;wdmsec.lib%(AdditionalDependencies) @@ -580,7 +582,7 @@ $(ProjectDir)$(Platform)\$(ConfigurationName)\ovpn-dco.DVL.XML - uuid.lib;Netio.lib;cng.lib;%(AdditionalDependencies) + uuid.lib;Netio.lib;cng.lib;wdmsec.lib%(AdditionalDependencies) /Brepro %(AdditionalOptions) DebugFull false @@ -607,7 +609,7 @@ $(ProjectDir)$(Platform)\$(ConfigurationName)\ovpn-dco.DVL.XML - uuid.lib;Netio.lib;cng.lib;%(AdditionalDependencies) + uuid.lib;Netio.lib;cng.lib;wdmsec.lib%(AdditionalDependencies) /Brepro %(AdditionalOptions) DebugFull false @@ -628,7 +630,7 @@ OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) - uuid.lib;Netio.lib;cng.lib;%(AdditionalDependencies) + uuid.lib;Netio.lib;cng.lib;wdmsec.lib%(AdditionalDependencies) @@ -644,7 +646,7 @@ OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) - uuid.lib;Netio.lib;cng.lib;%(AdditionalDependencies) + uuid.lib;Netio.lib;cng.lib;wdmsec.lib%(AdditionalDependencies) @@ -661,7 +663,7 @@ false - uuid.lib;Netio.lib;cng.lib;%(AdditionalDependencies) + uuid.lib;Netio.lib;cng.lib;wdmsec.lib%(AdditionalDependencies) /Brepro %(AdditionalOptions) DebugFull false @@ -683,7 +685,7 @@ false - uuid.lib;Netio.lib;cng.lib;%(AdditionalDependencies) + uuid.lib;Netio.lib;cng.lib;wdmsec.lib%(AdditionalDependencies) /Brepro %(AdditionalOptions) DebugFull false @@ -710,7 +712,7 @@ $(ProjectDir)$(Platform)\$(ConfigurationName)\ovpn-dco.DVL.XML - uuid.lib;Netio.lib;cng.lib;%(AdditionalDependencies) + uuid.lib;Netio.lib;cng.lib;wdmsec.lib%(AdditionalDependencies) /Brepro %(AdditionalOptions) DebugFull false @@ -737,7 +739,7 @@ $(ProjectDir)$(Platform)\$(ConfigurationName)\ovpn-dco.DVL.XML - uuid.lib;Netio.lib;cng.lib;%(AdditionalDependencies) + uuid.lib;Netio.lib;cng.lib;wdmsec.lib%(AdditionalDependencies) /Brepro %(AdditionalOptions) DebugFull false @@ -758,7 +760,7 @@ OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) - uuid.lib;Netio.lib;cng.lib;%(AdditionalDependencies) + uuid.lib;Netio.lib;cng.lib;wdmsec.lib%(AdditionalDependencies) @@ -774,7 +776,7 @@ OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) - uuid.lib;Netio.lib;cng.lib;%(AdditionalDependencies) + uuid.lib;Netio.lib;cng.lib;wdmsec.lib%(AdditionalDependencies) diff --git a/ovpn-dco-win.vcxproj.filters b/ovpn-dco-win.vcxproj.filters index a4603a6..7d0f4cf 100644 --- a/ovpn-dco-win.vcxproj.filters +++ b/ovpn-dco-win.vcxproj.filters @@ -67,6 +67,9 @@ Header Files + + Header Files + @@ -102,6 +105,9 @@ Source Files + + Source Files + From f7877d24b641e0b1289b9eebd86867de5f22cc9b Mon Sep 17 00:00:00 2001 From: Lev Stipakov Date: Wed, 14 Aug 2024 15:53:09 +0300 Subject: [PATCH 03/28] Implement DATA_V3 protocol features Add support for features: - AEAD auth tag at the end - 64bit pktid For that, introduce a new OVPN_IOCTL_NEW_KEY_V2 ioctl, which accepts a new data structure, which includes a field to store bit flags for crypto options for abovementioned features. Refactor encrypt/decrypt routine to accomodate possible wire format changes. Refactor pktid code to support 64bit pktid. While on it, remove unused variable. Bump version to 2.4.0. https://github.com/OpenVPN/ovpn-dco-win/issues/74 Backported from 5acb8f ("Implement DATA_V3 protocol features") Signed-off-by: Lev Stipakov --- Driver.cpp | 6 +++ Driver.h | 1 - PropertySheet.props | 2 +- bufferpool.h | 1 - crypto.cpp | 123 ++++++++++++++++++++++++++++++++------------ crypto.h | 16 ++---- peer.cpp | 82 ++++++++++++++++++++++------- peer.h | 5 ++ pktid.cpp | 37 +++++++------ pktid.h | 10 ++-- rxqueue.cpp | 13 ++++- socket.cpp | 28 ++++++---- timer.cpp | 10 ++-- txqueue.cpp | 17 ++++-- uapi/ovpn-dco.h | 9 ++++ 15 files changed, 254 insertions(+), 106 deletions(-) diff --git a/Driver.cpp b/Driver.cpp index 4c3ceec..4b4f0bc 100644 --- a/Driver.cpp +++ b/Driver.cpp @@ -269,6 +269,12 @@ OvpnEvtIoDeviceControl(WDFQUEUE queue, WDFREQUEST request, size_t outputBufferLe ExReleaseSpinLockExclusive(&device->SpinLock, kirql); break; + case OVPN_IOCTL_NEW_KEY_V2: + kirql = ExAcquireSpinLockExclusive(&device->SpinLock); + status = OvpnPeerNewKeyV2(device, request); + ExReleaseSpinLockExclusive(&device->SpinLock, kirql); + break; + case OVPN_IOCTL_SWAP_KEYS: kirql = ExAcquireSpinLockExclusive(&device->SpinLock); status = OvpnPeerSwapKeys(device); diff --git a/Driver.h b/Driver.h index 01df6c9..bf221ce 100644 --- a/Driver.h +++ b/Driver.h @@ -94,7 +94,6 @@ struct OVPN_DEVICE { _Guarded_by_(SpinLock) RTL_GENERIC_TABLE Peers; - SIZE_T CryptoOverhead; }; typedef OVPN_DEVICE * POVPN_DEVICE; diff --git a/PropertySheet.props b/PropertySheet.props index 61bdfe9..6056166 100644 --- a/PropertySheet.props +++ b/PropertySheet.props @@ -3,7 +3,7 @@ 2 - 3 + 4 0 diff --git a/bufferpool.h b/bufferpool.h index fcde323..1ed0721 100644 --- a/bufferpool.h +++ b/bufferpool.h @@ -74,7 +74,6 @@ struct OVPN_RX_BUFFER UCHAR Data[OVPN_SOCKET_RX_PACKET_BUFFER_SIZE]; }; -_Must_inspect_result_ UCHAR* OvpnTxBufferPut(_In_ OVPN_TX_BUFFER* work, SIZE_T len); diff --git a/crypto.cpp b/crypto.cpp index 75005f4..0f915fa 100644 --- a/crypto.cpp +++ b/crypto.cpp @@ -48,11 +48,14 @@ OvpnProtoOp32Compose(UINT opcode, UINT keyId, UINT opPeerId) OVPN_CRYPTO_DECRYPT OvpnCryptoDecryptNone; _Use_decl_annotations_ -NTSTATUS OvpnCryptoDecryptNone(OvpnCryptoKeySlot* keySlot, UCHAR* bufIn, SIZE_T len, UCHAR* bufOut) +NTSTATUS OvpnCryptoDecryptNone(OvpnCryptoKeySlot* keySlot, UCHAR* bufIn, SIZE_T len, UCHAR* bufOut, INT32 cryptoOptions) { UNREFERENCED_PARAMETER(keySlot); - if (len < NONE_CRYPTO_OVERHEAD) { + BOOLEAN pktId64bit = cryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; + BOOLEAN cryptoOverhead = OVPN_DATA_V2_LEN + pktId64bit ? 8 : 4; + + if (len < cryptoOverhead) { LOG_WARN("Packet too short", TraceLoggingValue(len, "len")); return STATUS_DATA_ERROR; } @@ -66,10 +69,11 @@ OVPN_CRYPTO_ENCRYPT OvpnCryptoEncryptNone; _Use_decl_annotations_ NTSTATUS -OvpnCryptoEncryptNone(OvpnCryptoKeySlot* keySlot, UCHAR* buf, SIZE_T len) +OvpnCryptoEncryptNone(OvpnCryptoKeySlot* keySlot, UCHAR* buf, SIZE_T len, INT32 cryptoOptions) { UNREFERENCED_PARAMETER(keySlot); UNREFERENCED_PARAMETER(len); + UNREFERENCED_PARAMETER(cryptoOptions); // prepend with opcode, key-id and peer-id UINT32 op = OvpnProtoOp32Compose(OVPN_OP_DATA_V2, 0, 0); @@ -121,74 +125,116 @@ OvpnCryptoUninitAlgHandles(_In_ BCRYPT_ALG_HANDLE aesAlgHandle, BCRYPT_ALG_HANDL static NTSTATUS -OvpnCryptoAEADDoWork(BOOLEAN encrypt, OvpnCryptoKeySlot* keySlot, UCHAR *bufIn, SIZE_T len, UCHAR* bufOut) +OvpnCryptoAEADDoWork(BOOLEAN encrypt, OvpnCryptoKeySlot* keySlot, UCHAR *bufIn, SIZE_T len, UCHAR* bufOut, INT32 cryptoOptions) { /* AEAD Nonce : [Packet ID] [HMAC keying material] - [4 bytes ] [8 bytes ] + [4/8 bytes] [8/4 bytes ] [AEAD nonce total : 12 bytes ] TLS wire protocol : + Packet ID is 8 bytes long with CRYPTO_OPTIONS_64BIT_PKTID. + [DATA_V2 opcode] [Packet ID] [AEAD Auth tag] [ciphertext] - [4 bytes ] [4 bytes ] [16 bytes ] + [4 bytes ] [4/8 bytes] [16 bytes ] + [AEAD additional data(AD) ] + + With CRYPTO_OPTIONS_AEAD_TAG_END AEAD Auth tag is placed after ciphertext: + + [DATA_V2 opcode] [Packet ID] [ciphertext] [AEAD Auth tag] + [4 bytes ] [4/8 bytes] [16 bytes ] [AEAD additional data(AD) ] */ NTSTATUS status = STATUS_SUCCESS; - if (len < AEAD_CRYPTO_OVERHEAD) { + BOOLEAN pktId64bit = cryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; + + SIZE_T cryptoOverhead = OVPN_DATA_V2_LEN + AEAD_AUTH_TAG_LEN + (pktId64bit ? 8 : 4); + + if (len < cryptoOverhead) { LOG_WARN("Packet too short", TraceLoggingValue(len, "len")); return STATUS_DATA_ERROR; } - UCHAR nonce[OVPN_PKTID_LEN + OVPN_NONCE_TAIL_LEN]; + UCHAR nonce[12]; if (encrypt) { // prepend with opcode, key-id and peer-id UINT32 op = OvpnProtoOp32Compose(OVPN_OP_DATA_V2, keySlot->KeyId, keySlot->PeerId); op = RtlUlongByteSwap(op); - *(UINT32*)(bufOut) = op; + *reinterpret_cast(bufOut) = op; - // calculate pktid - UINT32 pktid; - GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnPktidXmitNext(&keySlot->PktidXmit, &pktid)); - ULONG pktidNetwork = RtlUlongByteSwap(pktid); + if (pktId64bit) + { + // calculate pktid + UINT64 pktid; + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnPktidXmitNext(&keySlot->PktidXmit, &pktid, true)); + ULONG64 pktidNetwork = RtlUlonglongByteSwap(pktid); + + // calculate nonce, which is pktid + nonce_tail + RtlCopyMemory(nonce, &pktidNetwork, 8); + RtlCopyMemory(nonce + 8, keySlot->EncNonceTail, 4); + + // prepend with pktid + *reinterpret_cast(bufOut + OVPN_DATA_V2_LEN) = pktidNetwork; + } + else + { + // calculate pktid + UINT32 pktid; + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnPktidXmitNext(&keySlot->PktidXmit, &pktid, false)); + ULONG pktidNetwork = RtlUlongByteSwap(pktid); - // calculate nonce, which is pktid + nonce_tail - RtlCopyMemory(nonce, &pktidNetwork, OVPN_PKTID_LEN); - RtlCopyMemory(nonce + OVPN_PKTID_LEN, keySlot->EncNonceTail, OVPN_NONCE_TAIL_LEN); + // calculate nonce, which is pktid + nonce_tail + RtlCopyMemory(nonce, &pktidNetwork, 4); + RtlCopyMemory(nonce + 4, keySlot->EncNonceTail, 8); - // prepend with pktid - *(UINT32*)(bufOut + OVPN_DATA_V2_LEN) = pktidNetwork; + // prepend with pktid + *reinterpret_cast(bufOut + OVPN_DATA_V2_LEN) = pktidNetwork; + } } else { - RtlCopyMemory(nonce, bufIn + OVPN_DATA_V2_LEN, OVPN_PKTID_LEN); - RtlCopyMemory(nonce + OVPN_PKTID_LEN, &keySlot->DecNonceTail, sizeof(keySlot->DecNonceTail)); + ULONG64 pktId; + + RtlCopyMemory(nonce, bufIn + OVPN_DATA_V2_LEN, pktId64bit ? 8 : 4); + RtlCopyMemory(nonce + (pktId64bit ? 8 : 4), &keySlot->DecNonceTail, pktId64bit ? 4 : 8); + if (pktId64bit) + { + pktId = RtlUlonglongByteSwap(*reinterpret_cast(nonce)); + } + else + { + pktId = static_cast(RtlUlongByteSwap(*reinterpret_cast(nonce))); + } - UINT32 pktId = RtlUlongByteSwap(*(UINT32*)nonce); status = OvpnPktidRecvVerify(&keySlot->PktidRecv, pktId); if (!NT_SUCCESS(status)) { - LOG_ERROR("Invalid pktId", TraceLoggingUInt32(pktId, "pktId")); + LOG_ERROR("Invalid pktId", TraceLoggingUInt64(pktId, "pktId")); return STATUS_DATA_ERROR; } } + // we prepended buf with crypto overhead + len -= cryptoOverhead; + + BOOLEAN aeadTagEnd = cryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END; + BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO authInfo; BCRYPT_INIT_AUTH_MODE_INFO(authInfo); authInfo.pbNonce = nonce; authInfo.cbNonce = sizeof(nonce); - authInfo.pbTag = (encrypt ? bufOut : bufIn) + OVPN_DATA_V2_LEN + OVPN_PKTID_LEN; + authInfo.pbTag = (encrypt ? bufOut : bufIn) + OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? len : 0); authInfo.cbTag = AEAD_AUTH_TAG_LEN; authInfo.pbAuthData = (encrypt ? bufOut : bufIn); - authInfo.cbAuthData = OVPN_DATA_V2_LEN + OVPN_PKTID_LEN; - - bufOut += AEAD_CRYPTO_OVERHEAD; - bufIn += AEAD_CRYPTO_OVERHEAD; + authInfo.cbAuthData = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4); - len -= AEAD_CRYPTO_OVERHEAD; + auto payloadOffset = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN); + bufOut += payloadOffset; + bufIn += payloadOffset; // non-chaining mode ULONG bytesDone = 0; @@ -205,27 +251,29 @@ OVPN_CRYPTO_DECRYPT OvpnCryptoDecryptAEAD; _Use_decl_annotations_ NTSTATUS -OvpnCryptoDecryptAEAD(OvpnCryptoKeySlot* keySlot, UCHAR* bufIn, SIZE_T len, UCHAR* bufOut) +OvpnCryptoDecryptAEAD(OvpnCryptoKeySlot* keySlot, UCHAR* bufIn, SIZE_T len, UCHAR* bufOut, INT32 cryptoOptions) { - return OvpnCryptoAEADDoWork(FALSE, keySlot, bufIn, len, bufOut); + return OvpnCryptoAEADDoWork(FALSE, keySlot, bufIn, len, bufOut, cryptoOptions); } OVPN_CRYPTO_ENCRYPT OvpnCryptoEncryptAEAD; _Use_decl_annotations_ NTSTATUS -OvpnCryptoEncryptAEAD(OvpnCryptoKeySlot* keySlot, UCHAR* buf, SIZE_T len) +OvpnCryptoEncryptAEAD(OvpnCryptoKeySlot* keySlot, UCHAR* buf, SIZE_T len, INT32 cryptoOptions) { - return OvpnCryptoAEADDoWork(TRUE, keySlot, buf, len, buf); + return OvpnCryptoAEADDoWork(TRUE, keySlot, buf, len, buf, cryptoOptions); } _Use_decl_annotations_ NTSTATUS -OvpnCryptoNewKey(OvpnCryptoContext* cryptoContext, POVPN_CRYPTO_DATA cryptoData, BCRYPT_ALG_HANDLE algHandle) +OvpnCryptoNewKey(OvpnCryptoContext* cryptoContext, POVPN_CRYPTO_DATA_V2 cryptoDataV2, BCRYPT_ALG_HANDLE algHandle) { OvpnCryptoKeySlot* keySlot = NULL; NTSTATUS status = STATUS_SUCCESS; + POVPN_CRYPTO_DATA cryptoData = &cryptoDataV2->V1; + if (cryptoData->KeySlot == OVPN_KEY_SLOT::OVPN_KEY_SLOT_PRIMARY) { keySlot = &cryptoContext->Primary; } @@ -237,6 +285,15 @@ OvpnCryptoNewKey(OvpnCryptoContext* cryptoContext, POVPN_CRYPTO_DATA cryptoData, return STATUS_INVALID_DEVICE_REQUEST; } + if (cryptoDataV2->CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID) + { + cryptoContext->CryptoOptions |= CRYPTO_OPTIONS_64BIT_PKTID; + } + if (cryptoDataV2->CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END) + { + cryptoContext->CryptoOptions |= CRYPTO_OPTIONS_AEAD_TAG_END; + } + if ((cryptoData->CipherAlg == OVPN_CIPHER_ALG_AES_GCM) || (cryptoData->CipherAlg == OVPN_CIPHER_ALG_CHACHA20_POLY1305)) { // destroy previous keys if (keySlot->EncKey) { diff --git a/crypto.h b/crypto.h index ee35cba..3782d42 100644 --- a/crypto.h +++ b/crypto.h @@ -29,14 +29,8 @@ #include "uapi\ovpn-dco.h" #include "socket.h" -#define AEAD_CRYPTO_OVERHEAD 24 // 4 + 4 + 16 data_v2 + pktid + auth_tag -#define NONE_CRYPTO_OVERHEAD 8 // 4 + 4 data_v2 + pktid -#define OVPN_PKTID_LEN 4 -#define OVPN_NONCE_TAIL_LEN 8 #define OVPN_DATA_V2_LEN 4 #define AEAD_AUTH_TAG_LEN 16 -#define AES_BLOCK_SIZE 16 -#define AES_GCM_NONCE_LEN 12 // packet opcode (high 5 bits) and key-id (low 3 bits) are combined in one byte #define OVPN_OP_DATA_V2 9 @@ -63,7 +57,7 @@ _IRQL_requires_max_(DISPATCH_LEVEL) _Must_inspect_result_ typedef NTSTATUS -OVPN_CRYPTO_ENCRYPT(_In_ OvpnCryptoKeySlot* keySlot, _In_ UCHAR* buf, _In_ SIZE_T len); +OVPN_CRYPTO_ENCRYPT(_In_ OvpnCryptoKeySlot* keySlot, _In_ UCHAR* buf, _In_ SIZE_T len, _In_ INT32 CryptoOptions); typedef OVPN_CRYPTO_ENCRYPT* POVPN_CRYPTO_ENCRYPT; _Function_class_(OVPN_CRYPTO_DECRYPT) @@ -71,7 +65,7 @@ _IRQL_requires_max_(DISPATCH_LEVEL) _Must_inspect_result_ typedef NTSTATUS -OVPN_CRYPTO_DECRYPT(_In_ OvpnCryptoKeySlot* keySlot, _In_ UCHAR* bufIn, _In_ SIZE_T len, _In_ UCHAR* bufOut); +OVPN_CRYPTO_DECRYPT(_In_ OvpnCryptoKeySlot* keySlot, _In_ UCHAR* bufIn, _In_ SIZE_T len, _In_ UCHAR* bufOut, _In_ INT32 CryptoOptions); typedef OVPN_CRYPTO_DECRYPT* POVPN_CRYPTO_DECRYPT; struct OvpnCryptoContext @@ -82,7 +76,7 @@ struct OvpnCryptoContext POVPN_CRYPTO_ENCRYPT Encrypt; POVPN_CRYPTO_DECRYPT Decrypt; - SIZE_T CryptoOverhead; + INT32 CryptoOptions; }; _Must_inspect_result_ @@ -99,7 +93,7 @@ OvpnCryptoUninit(_In_ OvpnCryptoContext* cryptoContext); _Must_inspect_result_ NTSTATUS -OvpnCryptoNewKey(_In_ OvpnCryptoContext* cryptoContext, _In_ POVPN_CRYPTO_DATA cryptoData, _In_opt_ BCRYPT_ALG_HANDLE algHandle); +OvpnCryptoNewKey(_In_ OvpnCryptoContext* cryptoContext, _In_ POVPN_CRYPTO_DATA_V2 cryptoData, _In_opt_ BCRYPT_ALG_HANDLE algHandle); _Must_inspect_result_ OvpnCryptoKeySlot* @@ -119,4 +113,4 @@ static inline UCHAR OvpnCryptoOpcodeExtract(UCHAR op) { return op >> OVPN_OPCODE_SHIFT; -} \ No newline at end of file +} diff --git a/peer.cpp b/peer.cpp index 3314fee..a2fa688 100644 --- a/peer.cpp +++ b/peer.cpp @@ -303,6 +303,31 @@ OvpnPeerStartVPN(POVPN_DEVICE device) return status; } +static NTSTATUS +OvpnPeerGetAlgHandle(POVPN_DEVICE device, OVPN_CIPHER_ALG cipherAlg, BCRYPT_ALG_HANDLE& algHandle) +{ + NTSTATUS status = STATUS_SUCCESS; + + switch (cipherAlg) { + case OVPN_CIPHER_ALG_AES_GCM: + algHandle = device->AesAlgHandle; + break; + + case OVPN_CIPHER_ALG_CHACHA20_POLY1305: + algHandle = device->ChachaAlgHandle; + if (algHandle == NULL) { + LOG_ERROR("CHACHA20-POLY1305 is not available"); + status = STATUS_INVALID_DEVICE_REQUEST; + } + break; + + default: + break; + } + + return status; +} + _Use_decl_annotations_ NTSTATUS OvpnPeerNewKey(POVPN_DEVICE device, WDFREQUEST request) @@ -311,44 +336,63 @@ OvpnPeerNewKey(POVPN_DEVICE device, WDFREQUEST request) NTSTATUS status = STATUS_SUCCESS; + POVPN_CRYPTO_DATA cryptoData = NULL; + OVPN_CRYPTO_DATA_V2 cryptoDataV2{}; + if (!OvpnHasPeers(device)) { LOG_ERROR("Peer not added"); status = STATUS_INVALID_DEVICE_REQUEST; goto done; } - POVPN_CRYPTO_DATA cryptoData = NULL; - GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveInputBuffer(request, sizeof(OVPN_CRYPTO_DATA), (PVOID*)&cryptoData, nullptr)); BCRYPT_ALG_HANDLE algHandle = NULL; - switch (cryptoData->CipherAlg) { - case OVPN_CIPHER_ALG_AES_GCM: - algHandle = device->AesAlgHandle; - device->CryptoOverhead = AEAD_CRYPTO_OVERHEAD; - break; + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnPeerGetAlgHandle(device, cryptoData->CipherAlg, algHandle)); - case OVPN_CIPHER_ALG_CHACHA20_POLY1305: - algHandle = device->ChachaAlgHandle; - if (algHandle == NULL) { - LOG_ERROR("CHACHA20-POLY1305 is not available"); - status = STATUS_INVALID_DEVICE_REQUEST; - goto done; - } - device->CryptoOverhead = AEAD_CRYPTO_OVERHEAD; + OvpnPeerContext* peer = OvpnGetFirstPeer(&device->Peers); + if (peer == NULL) { + status = STATUS_OBJECTID_NOT_FOUND; + goto done; + } - default: - device->CryptoOverhead = NONE_CRYPTO_OVERHEAD; - break; + RtlCopyMemory(&cryptoDataV2.V1, cryptoData, sizeof(OVPN_CRYPTO_DATA)); + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoNewKey(&peer->CryptoContext, &cryptoDataV2, algHandle)); + +done: + LOG_EXIT(); + + return status; +} + +_Use_decl_annotations_ +NTSTATUS +OvpnPeerNewKeyV2(POVPN_DEVICE device, WDFREQUEST request) +{ + LOG_ENTER(); + + NTSTATUS status = STATUS_SUCCESS; + + POVPN_CRYPTO_DATA_V2 cryptoDataV2 = NULL; + + if (!OvpnHasPeers(device)) { + LOG_ERROR("Peer not added"); + status = STATUS_INVALID_DEVICE_REQUEST; + goto done; } + GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveInputBuffer(request, sizeof(OVPN_CRYPTO_DATA_V2), (PVOID*)&cryptoDataV2, nullptr)); + + BCRYPT_ALG_HANDLE algHandle = NULL; + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnPeerGetAlgHandle(device, cryptoDataV2->V1.CipherAlg, algHandle)); + OvpnPeerContext* peer = OvpnGetFirstPeer(&device->Peers); if (peer == NULL) { status = STATUS_OBJECTID_NOT_FOUND; goto done; } - GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoNewKey(&peer->CryptoContext, cryptoData, algHandle)); + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoNewKey(&peer->CryptoContext, cryptoDataV2, algHandle)); done: LOG_EXIT(); diff --git a/peer.h b/peer.h index e1eb5d4..fc780f8 100644 --- a/peer.h +++ b/peer.h @@ -85,6 +85,11 @@ _Requires_exclusive_lock_held_(device->SpinLock) NTSTATUS OvpnPeerNewKey(_In_ POVPN_DEVICE device, WDFREQUEST request); +_Must_inspect_result_ +_Requires_exclusive_lock_held_(device->SpinLock) +NTSTATUS +OvpnPeerNewKeyV2(_In_ POVPN_DEVICE device, WDFREQUEST request); + _Must_inspect_result_ _Requires_exclusive_lock_held_(device->SpinLock) NTSTATUS diff --git a/pktid.cpp b/pktid.cpp index b7f365d..4d94ee3 100644 --- a/pktid.cpp +++ b/pktid.cpp @@ -28,24 +28,29 @@ #define PKTID_WRAP_WARN 0xf0000000ULL _Use_decl_annotations_ -NTSTATUS OvpnPktidXmitNext(OvpnPktidXmit* px, UINT32* pktId) +NTSTATUS OvpnPktidXmitNext(OvpnPktidXmit* px, VOID* pktId, BOOLEAN pktId64bit) { ULONG64 seqNum = InterlockedIncrementNoFence64(&px->SeqNum); - *pktId = (UINT32)seqNum; - if (seqNum < PKTID_WRAP_WARN) { - return STATUS_SUCCESS; - } - else { - LOG_ERROR("Pktid wrapped"); - return STATUS_INTEGER_OVERFLOW; - } + if (pktId64bit) { + *static_cast(pktId) = seqNum; + } + else + { + *static_cast(pktId) = static_cast(seqNum); + if (seqNum >= PKTID_WRAP_WARN) { + LOG_ERROR("Pktid wrapped"); + return STATUS_INTEGER_OVERFLOW; + } + } + + return STATUS_SUCCESS; } #define PKTID_RECV_EXPIRE ((30 * WDF_TIMEOUT_TO_SEC) / KeQueryTimeIncrement()) _Use_decl_annotations_ -NTSTATUS OvpnPktidRecvVerify(OvpnPktidRecv* pr, UINT32 pktId) +NTSTATUS OvpnPktidRecvVerify(OvpnPktidRecv* pr, UINT64 pktId) { LARGE_INTEGER now; KeQueryTickCount(&now); @@ -69,16 +74,16 @@ NTSTATUS OvpnPktidRecvVerify(OvpnPktidRecv* pr, UINT32 pktId) } else if (pktId > pr->Id) { /* ID jumped forward by more than one */ - UINT32 delta = pktId - pr->Id; + const auto delta = pktId - pr->Id; if (delta < REPLAY_WINDOW_SIZE) { pr->Base = REPLAY_INDEX(pr->Base, -(INT32)delta); pr->History[pr->Base / 8] |= (1 << (pr->Base % 8)); - pr->Extent += delta; + pr->Extent += static_cast(delta); if (pr->Extent > REPLAY_WINDOW_SIZE) pr->Extent = REPLAY_WINDOW_SIZE; - for (UINT32 i = 1; i < delta; ++i) { - unsigned int newb = REPLAY_INDEX(pr->Base, i); + for (auto i = 1; i < delta; ++i) { + const auto newb = REPLAY_INDEX(pr->Base, i); pr->History[newb / 8] &= ~BIT(newb % 8); } @@ -93,10 +98,8 @@ NTSTATUS OvpnPktidRecvVerify(OvpnPktidRecv* pr, UINT32 pktId) } else { /* ID backtrack */ - UINT32 delta = pr->Id - pktId; + const auto delta = pr->Id - pktId; - if (delta > pr->MaxBacktrack) - pr->MaxBacktrack = delta; if (delta < pr->Extent) { if (pktId > pr->IdFloor) { UINT32 ri = REPLAY_INDEX(pr->Base, delta); diff --git a/pktid.h b/pktid.h index b0d2325..dcc4be8 100644 --- a/pktid.h +++ b/pktid.h @@ -50,17 +50,17 @@ struct OvpnPktidRecv LARGE_INTEGER Expire; /* highest sequence number received */ - UINT32 Id; + UINT64 Id; /* we will only accept backtrack IDs > id_floor */ - UINT32 IdFloor; - UINT32 MaxBacktrack; + UINT64 IdFloor; }; /* Get the next packet ID for xmit */ -NTSTATUS OvpnPktidXmitNext(_In_ OvpnPktidXmit* px, _Out_ UINT32* pktId); +NTSTATUS OvpnPktidXmitNext(_In_ OvpnPktidXmit* px, _Out_ VOID* pktId, BOOLEAN pktId64bit); + /* Packet replay detection. * Allows ID backtrack of up to REPLAY_WINDOW_SIZE - 1. */ -NTSTATUS OvpnPktidRecvVerify(_In_ OvpnPktidRecv* pid, UINT32 pktId); +NTSTATUS OvpnPktidRecvVerify(_In_ OvpnPktidRecv* pid, UINT64 pktId); diff --git a/rxqueue.cpp b/rxqueue.cpp index ce7d71b..d7687c3 100644 --- a/rxqueue.cpp +++ b/rxqueue.cpp @@ -27,6 +27,7 @@ #include "driver.h" #include "bufferpool.h" +#include "peer.h" #include "rxqueue.h" #include "netringiterator.h" #include "trace.h" @@ -101,6 +102,16 @@ OvpnEvtRxQueueAdvance(NETPACKETQUEUE netPacketQueue) POVPN_RXQUEUE queue = OvpnGetRxQueueContext(netPacketQueue); OVPN_DEVICE* device = OvpnGetDeviceContext(queue->Adapter->WdfDevice); + OvpnPeerContext* peer = OvpnGetFirstPeer(&device->Peers); + if (peer == NULL) { + LOG_WARN("No peer"); + return; + } + + BOOLEAN pktId64bit = peer->CryptoContext.CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; + BOOLEAN aeadTagEnd = peer->CryptoContext.CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END; + auto payloadOffset = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN); + NET_RING_FRAGMENT_ITERATOR fi = NetRingGetAllFragments(queue->Rings); NET_RING_PACKET_ITERATOR pi = NetRingGetAllPackets(queue->Rings); while (NetFragmentIteratorHasAny(&fi)) { @@ -115,7 +126,7 @@ OvpnEvtRxQueueAdvance(NETPACKETQUEUE netPacketQueue) fragment->ValidLength = buffer->Len; fragment->Offset = 0; NET_FRAGMENT_VIRTUAL_ADDRESS* virtualAddr = NetExtensionGetFragmentVirtualAddress(&queue->VirtualAddressExtension, NetFragmentIteratorGetIndex(&fi)); - RtlCopyMemory(virtualAddr->VirtualAddress, buffer->Data + device->CryptoOverhead, buffer->Len); + RtlCopyMemory(virtualAddr->VirtualAddress, buffer->Data + payloadOffset, buffer->Len); InterlockedExchangeAddNoFence64(&device->Stats.TunBytesReceived, buffer->Len); diff --git a/socket.cpp b/socket.cpp index 1b3ca42..2fcf2d7 100644 --- a/socket.cpp +++ b/socket.cpp @@ -178,9 +178,11 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_ return; } - if (peer->CryptoContext.Decrypt) { + OvpnCryptoContext* cryptoContext = &peer->CryptoContext; + + if (cryptoContext->Decrypt) { UCHAR keyId = OvpnCryptoKeyIdExtract(op); - OvpnCryptoKeySlot* keySlot = OvpnCryptoKeySlotFromKeyId(&peer->CryptoContext, keyId); + OvpnCryptoKeySlot* keySlot = OvpnCryptoKeySlotFromKeyId(cryptoContext, keyId); if (!keySlot) { status = STATUS_INVALID_DEVICE_STATE; @@ -188,8 +190,11 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_ } else { // decrypt into plaintext buffer - status = peer->CryptoContext.Decrypt(keySlot, cipherTextBuf, len, buffer->Data); - buffer->Len = len - device->CryptoOverhead; + status = cryptoContext->Decrypt(keySlot, cipherTextBuf, len, buffer->Data, cryptoContext->CryptoOptions); + + auto pktId64bit = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; + auto cryptoOverhead = OVPN_DATA_V2_LEN + AEAD_AUTH_TAG_LEN + (pktId64bit ? 8 : 4); + buffer->Len = len - cryptoOverhead; } } else { @@ -206,20 +211,23 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_ OvpnTimerResetRecv(peer->Timer); // points to the beginning of plaintext - UCHAR* buf = buffer->Data + device->CryptoOverhead; + BOOLEAN pktId64bit = peer->CryptoContext.CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; + BOOLEAN aeadTagEnd = peer->CryptoContext.CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END; + auto payloadOffset = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN); + UCHAR* plaintext = buffer->Data + payloadOffset; // ping packet? - if (OvpnTimerIsKeepaliveMessage(buf, buffer->Len)) { + if (OvpnTimerIsKeepaliveMessage(plaintext, buffer->Len)) { LOG_INFO("Ping received"); // no need to inject ping packet into OS, return buffer to the pool OvpnRxBufferPoolPut(buffer); } else { - if (OvpnMssIsIPv4(buf, buffer->Len)) { - OvpnMssDoIPv4(buf, buffer->Len, device->MSS); - } else if (OvpnMssIsIPv6(buf, buffer->Len)) { - OvpnMssDoIPv6(buf, buffer->Len, device->MSS); + if (OvpnMssIsIPv4(plaintext, buffer->Len)) { + OvpnMssDoIPv4(plaintext, buffer->Len, device->MSS); + } else if (OvpnMssIsIPv6(plaintext, buffer->Len)) { + OvpnMssDoIPv6(plaintext, buffer->Len, device->MSS); } // enqueue plaintext buffer, it will be dequeued by NetAdapter RX datapath diff --git a/timer.cpp b/timer.cpp index 6ce4c95..95eeaa1 100644 --- a/timer.cpp +++ b/timer.cpp @@ -73,12 +73,16 @@ static VOID OvpnTimerXmit(WDFTIMER timer) OvpnPeerContext* peer = timerCtx->Peer; KIRQL kiqrl = ExAcquireSpinLockShared(&device->SpinLock); - if (peer->CryptoContext.Encrypt) { + OvpnCryptoContext* cryptoContext = &peer->CryptoContext; + if (cryptoContext->Encrypt) { // make space to crypto overhead - OvpnTxBufferPush(buffer, device->CryptoOverhead); + BOOLEAN pktId64bit = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; + BOOLEAN aeadTagEnd = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END; + + OvpnTxBufferPush(buffer, OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN)); // in-place encrypt, always with primary key - status = peer->CryptoContext.Encrypt(&peer->CryptoContext.Primary, buffer->Data, buffer->Len); + status = cryptoContext->Encrypt(&cryptoContext->Primary, buffer->Data, buffer->Len, cryptoContext->CryptoOptions); } else { status = STATUS_INVALID_DEVICE_STATE; diff --git a/txqueue.cpp b/txqueue.cpp index 7f49038..feec4c5 100644 --- a/txqueue.cpp +++ b/txqueue.cpp @@ -93,12 +93,21 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET InterlockedExchangeAddNoFence64(&device->Stats.TunBytesSent, buffer->Len); - if (peer->CryptoContext.Encrypt) { + OvpnCryptoContext* cryptoContext = &peer->CryptoContext; + + if (cryptoContext->Encrypt) { + auto aeadTagEnd = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END; + auto pktId64bit = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; + // make space to crypto overhead - OvpnTxBufferPush(buffer, device->CryptoOverhead); + OvpnTxBufferPush(buffer, OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN)); + if (aeadTagEnd) + { + OvpnTxBufferPut(buffer, AEAD_AUTH_TAG_LEN); + } // in-place encrypt, always with primary key - status = peer->CryptoContext.Encrypt(&peer->CryptoContext.Primary, buffer->Data, buffer->Len); + status = cryptoContext->Encrypt(&cryptoContext->Primary, buffer->Data, buffer->Len, cryptoContext->CryptoOptions); } else { status = STATUS_INVALID_DEVICE_STATE; @@ -152,7 +161,7 @@ OvpnEvtTxQueueAdvance(NETPACKETQUEUE netPacketQueue) POVPN_TXQUEUE queue = OvpnGetTxQueueContext(netPacketQueue); NET_RING_PACKET_ITERATOR pi = NetRingGetAllPackets(queue->Rings); POVPN_DEVICE device = OvpnGetDeviceContext(queue->Adapter->WdfDevice); - bool packetSent = false; + BOOLEAN packetSent = false; KIRQL kirql = ExAcquireSpinLockShared(&device->SpinLock); diff --git a/uapi/ovpn-dco.h b/uapi/ovpn-dco.h index ea2a733..9e437f0 100644 --- a/uapi/ovpn-dco.h +++ b/uapi/ovpn-dco.h @@ -94,6 +94,14 @@ typedef struct _OVPN_CRYPTO_DATA { int PeerId; } OVPN_CRYPTO_DATA, * POVPN_CRYPTO_DATA; +#define CRYPTO_OPTIONS_AEAD_TAG_END (1<<1) +#define CRYPTO_OPTIONS_64BIT_PKTID (1<<2) + +typedef struct _OVPN_CRYPTO_DATA_V2 { + OVPN_CRYPTO_DATA V1; + UINT32 CryptoOptions; +} OVPN_CRYPTO_DATA_V2, * POVPN_CRYPTO_DATA_V2; + typedef struct _OVPN_SET_PEER { LONG KeepaliveInterval; LONG KeepaliveTimeout; @@ -114,3 +122,4 @@ typedef struct _OVPN_VERSION { #define OVPN_IOCTL_START_VPN CTL_CODE(FILE_DEVICE_UNKNOWN, 6, METHOD_BUFFERED, FILE_ANY_ACCESS) #define OVPN_IOCTL_DEL_PEER CTL_CODE(FILE_DEVICE_UNKNOWN, 7, METHOD_BUFFERED, FILE_ANY_ACCESS) #define OVPN_IOCTL_GET_VERSION CTL_CODE(FILE_DEVICE_UNKNOWN, 8, METHOD_BUFFERED, FILE_ANY_ACCESS) +#define OVPN_IOCTL_NEW_KEY_V2 CTL_CODE(FILE_DEVICE_UNKNOWN, 9, METHOD_BUFFERED, FILE_ANY_ACCESS) From 46907991cd3da0d50df8b7d5079131929fdca9e8 Mon Sep 17 00:00:00 2001 From: Lev Stipakov Date: Thu, 5 Sep 2024 15:06:14 +0300 Subject: [PATCH 04/28] Add OVPN_IOCTL_SET_MODE THis is to select between client mode (OVPN_MODE_P2P, default) and server mode (OVPN_MODE_MP). At the moment only ioctl is added, no functional changes. Co-authored-by: Leon Dang Signed-off-by: Leon Dang Signed-off-by: Lev Stipakov --- Driver.cpp | 38 ++++++++++++++++++++++++++++++++++++++ Driver.h | 1 + uapi/ovpn-dco.h | 10 ++++++++++ 3 files changed, 49 insertions(+) diff --git a/Driver.cpp b/Driver.cpp index 4b4f0bc..97c0e65 100644 --- a/Driver.cpp +++ b/Driver.cpp @@ -229,6 +229,38 @@ OvpnEvtIoWrite(WDFQUEUE queue, WDFREQUEST request, size_t length) ExReleaseSpinLockShared(&device->SpinLock, kiqrl); } +NTSTATUS +OvpnSetMode(POVPN_DEVICE device, WDFREQUEST request) +{ + POVPN_SET_MODE mode; + NTSTATUS status = WdfRequestRetrieveInputBuffer(request, sizeof(OVPN_SET_MODE), (PVOID*)&mode, NULL); + if (!NT_SUCCESS(status)) { + return status; + } + + if (device->Mode != OVPN_MODE_P2P) { + LOG_ERROR("mode already set"); + return STATUS_ALREADY_INITIALIZED; + } + + status = STATUS_SUCCESS; + + LOG_INFO("Set mode", TraceLoggingValue(static_cast(mode->Mode), "mode")); + + switch (mode->Mode) { + case OVPN_MODE_P2P: + case OVPN_MODE_MP: + device->Mode = mode->Mode; + break; + + default: + status = STATUS_INVALID_PARAMETER; + break; + } + + return status; +} + EVT_WDF_IO_QUEUE_IO_DEVICE_CONTROL OvpnEvtIoDeviceControl; _Use_decl_annotations_ @@ -291,6 +323,12 @@ OvpnEvtIoDeviceControl(WDFQUEUE queue, WDFREQUEST request, size_t outputBufferLe status = OvpnGetVersion(request, &bytesReturned); break; + case OVPN_IOCTL_SET_MODE: + kirql = ExAcquireSpinLockExclusive(&device->SpinLock); + status = OvpnSetMode(device, request); + ExReleaseSpinLockExclusive(&device->SpinLock, kirql); + break; + default: LOG_WARN("Unknown ", TraceLoggingValue(ioControlCode, "ioControlCode")); status = STATUS_INVALID_DEVICE_REQUEST; diff --git a/Driver.h b/Driver.h index bf221ce..0045871 100644 --- a/Driver.h +++ b/Driver.h @@ -94,6 +94,7 @@ struct OVPN_DEVICE { _Guarded_by_(SpinLock) RTL_GENERIC_TABLE Peers; + OVPN_MODE Mode; }; typedef OVPN_DEVICE * POVPN_DEVICE; diff --git a/uapi/ovpn-dco.h b/uapi/ovpn-dco.h index 9e437f0..8163b8b 100644 --- a/uapi/ovpn-dco.h +++ b/uapi/ovpn-dco.h @@ -114,6 +114,15 @@ typedef struct _OVPN_VERSION { LONG Patch; } OVPN_VERSION, * POVPN_VERSION; +typedef enum { + OVPN_MODE_P2P, + OVPN_MODE_MP +} OVPN_MODE; + +typedef struct _OVPN_SET_MODE { + OVPN_MODE Mode; +} OVPN_SET_MODE, * POVPN_SET_MODE; + #define OVPN_IOCTL_NEW_PEER CTL_CODE(FILE_DEVICE_UNKNOWN, 1, METHOD_BUFFERED, FILE_ANY_ACCESS) #define OVPN_IOCTL_GET_STATS CTL_CODE(FILE_DEVICE_UNKNOWN, 2, METHOD_BUFFERED, FILE_ANY_ACCESS) #define OVPN_IOCTL_NEW_KEY CTL_CODE(FILE_DEVICE_UNKNOWN, 3, METHOD_BUFFERED, FILE_ANY_ACCESS) @@ -123,3 +132,4 @@ typedef struct _OVPN_VERSION { #define OVPN_IOCTL_DEL_PEER CTL_CODE(FILE_DEVICE_UNKNOWN, 7, METHOD_BUFFERED, FILE_ANY_ACCESS) #define OVPN_IOCTL_GET_VERSION CTL_CODE(FILE_DEVICE_UNKNOWN, 8, METHOD_BUFFERED, FILE_ANY_ACCESS) #define OVPN_IOCTL_NEW_KEY_V2 CTL_CODE(FILE_DEVICE_UNKNOWN, 9, METHOD_BUFFERED, FILE_ANY_ACCESS) +#define OVPN_IOCTL_SET_MODE CTL_CODE(FILE_DEVICE_UNKNOWN, 10, METHOD_BUFFERED, FILE_ANY_ACCESS) From 1094495df5d2f0b6a873bf4c60228ce4cfefed21 Mon Sep 17 00:00:00 2001 From: Lev Stipakov Date: Wed, 11 Sep 2024 15:49:05 +0300 Subject: [PATCH 05/28] Replace OvpnDelPeer with OvpnStopVPN In preparation for MP support, rename existing OvpnDelPeer() implementation to OvpnStopVPN() - which deletes all peers and closes the device socket. This looks more like "stop VPN" behavior for both P2P and MP modes. Co-authored-by: Leon Dang Signed-off-by: Leon Dang Signed-off-by: Lev Stipakov --- Driver.cpp | 46 +++++++++++++++++++++++++++++++++++++++++++--- peer.cpp | 40 ---------------------------------------- peer.h | 5 ----- 3 files changed, 43 insertions(+), 48 deletions(-) diff --git a/Driver.cpp b/Driver.cpp index 97c0e65..16ef036 100644 --- a/Driver.cpp +++ b/Driver.cpp @@ -261,6 +261,47 @@ OvpnSetMode(POVPN_DEVICE device, WDFREQUEST request) return status; } +static NTSTATUS +OvpnStopVPN(_In_ POVPN_DEVICE device) +{ + LOG_ENTER(); + + KIRQL kirql = ExAcquireSpinLockExclusive(&device->SpinLock); + PWSK_SOCKET socket = device->Socket.Socket; + device->Socket.Socket = NULL; + + OvpnFlushPeers(device); + + device->Mode = OVPN_MODE_P2P; + + RtlZeroMemory(&device->Socket.TcpState, sizeof(OvpnSocketTcpState)); + RtlZeroMemory(&device->Socket.UdpState, sizeof(OvpnSocketUdpState)); + + ExReleaseSpinLockExclusive(&device->SpinLock, kirql); + + if (socket != NULL) { + LOG_IF_NOT_NT_SUCCESS(OvpnSocketClose(socket)); + } + + // flush buffers in control queue so that client won't get control channel messages from previous session + while (LIST_ENTRY* entry = OvpnBufferQueueDequeue(device->ControlRxBufferQueue)) { + OVPN_RX_BUFFER* buffer = CONTAINING_RECORD(entry, OVPN_RX_BUFFER, QueueListEntry); + // return buffer back to pool + OvpnRxBufferPoolPut(buffer); + } + + WDFREQUEST request; + while (NT_SUCCESS(WdfIoQueueRetrieveNextRequest(device->PendingReadsQueue, &request))) { + ULONG_PTR bytesCopied = 0; + LOG_INFO("Cancel IO request from manual queue"); + WdfRequestCompleteWithInformation(request, STATUS_CANCELLED, bytesCopied); + } + + LOG_EXIT(); + + return STATUS_SUCCESS; +} + EVT_WDF_IO_QUEUE_IO_DEVICE_CONTROL OvpnEvtIoDeviceControl; _Use_decl_annotations_ @@ -288,7 +329,7 @@ OvpnEvtIoDeviceControl(WDFQUEUE queue, WDFREQUEST request, size_t outputBufferLe break; case OVPN_IOCTL_DEL_PEER: - status = OvpnPeerDel(device); + status = OvpnStopVPN(device); break; case OVPN_IOCTL_START_VPN: @@ -347,8 +388,7 @@ VOID OvpnEvtFileCleanup(WDFFILEOBJECT fileObject) { POVPN_DEVICE device = OvpnGetDeviceContext(WdfFileObjectGetDevice(fileObject)); - // peer might already be deleted - (VOID)OvpnPeerDel(device); + (VOID)OvpnStopVPN(device); if (device->Adapter != NULL) { OvpnAdapterSetLinkState(OvpnGetAdapterContext(device->Adapter), MediaConnectStateDisconnected); diff --git a/peer.cpp b/peer.cpp index a2fa688..4f36386 100644 --- a/peer.cpp +++ b/peer.cpp @@ -161,46 +161,6 @@ OvpnPeerNew(POVPN_DEVICE device, WDFREQUEST request) return status; } -_Use_decl_annotations_ -NTSTATUS -OvpnPeerDel(POVPN_DEVICE device) -{ - LOG_ENTER(); - - KIRQL kirql = ExAcquireSpinLockExclusive(&device->SpinLock); - - PWSK_SOCKET socket = device->Socket.Socket; - - device->Socket.Socket = NULL; - OvpnFlushPeers(device); - - RtlZeroMemory(&device->Socket.TcpState, sizeof(OvpnSocketTcpState)); - RtlZeroMemory(&device->Socket.UdpState, sizeof(OvpnSocketUdpState)); - - // OvpnSocketClose requires PASSIVE_LEVEL, so must release lock - ExReleaseSpinLockExclusive(&device->SpinLock, kirql); - - LOG_IF_NOT_NT_SUCCESS(OvpnSocketClose(socket)); - - // flush buffers in control queue so that client won't get control channel messages from previous session - while (LIST_ENTRY* entry = OvpnBufferQueueDequeue(device->ControlRxBufferQueue)) { - OVPN_RX_BUFFER* buffer = CONTAINING_RECORD(entry, OVPN_RX_BUFFER, QueueListEntry); - // return buffer back to pool - OvpnRxBufferPoolPut(buffer); - } - - WDFREQUEST request; - while (NT_SUCCESS(WdfIoQueueRetrieveNextRequest(device->PendingReadsQueue, &request))) { - ULONG_PTR bytesCopied = 0; - LOG_INFO("Cancel IO request from manual queue"); - WdfRequestCompleteWithInformation(request, STATUS_CANCELLED, bytesCopied); - } - - LOG_EXIT(); - - return STATUS_SUCCESS; -} - _Use_decl_annotations_ NTSTATUS OvpnPeerSet(POVPN_DEVICE device, WDFREQUEST request) { diff --git a/peer.h b/peer.h index fc780f8..11a80f8 100644 --- a/peer.h +++ b/peer.h @@ -59,11 +59,6 @@ _IRQL_requires_(PASSIVE_LEVEL) NTSTATUS OvpnPeerNew(_In_ POVPN_DEVICE device, WDFREQUEST request); -_Must_inspect_result_ -_IRQL_requires_(PASSIVE_LEVEL) -NTSTATUS -OvpnPeerDel(_In_ POVPN_DEVICE device); - _Must_inspect_result_ _Requires_exclusive_lock_held_(device->SpinLock) NTSTATUS From 8bc221308899747a6affa3bf07c96c07938a7d6d Mon Sep 17 00:00:00 2001 From: Lev Stipakov Date: Thu, 12 Sep 2024 09:56:24 +0300 Subject: [PATCH 06/28] GHA: update upload-artifacts to v4 Signed-off-by: Lev Stipakov --- .github/workflows/msbuild.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/msbuild.yml b/.github/workflows/msbuild.yml index f8dfea1..e63d5e1 100644 --- a/.github/workflows/msbuild.yml +++ b/.github/workflows/msbuild.yml @@ -55,7 +55,7 @@ jobs: Copy-Item -Path ..\${{ matrix.arch }}\${{ matrix.build_conf }}\ovpn-dco\* -Destination dist\${{ matrix.arch }}\win11\ -Recurse .\build.ps1 -Arch ${{ matrix.arch }} -Wix ..\wix - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v4 with: name: ovpn-dco_${{ matrix.arch }}_${{ matrix.build_conf }} path: | From 9010a3040c3c4ff76b223f22a1c918a474af04c4 Mon Sep 17 00:00:00 2001 From: Lev Stipakov Date: Thu, 12 Sep 2024 13:01:19 +0300 Subject: [PATCH 07/28] Driver.cpp: Add some mode checks to OvpnEvtIoDeviceControl Return an error to a userspace if P2P-only IOCTLs are used in MP mode. Signed-off-by: Lev Stipakov --- Driver.cpp | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/Driver.cpp b/Driver.cpp index 16ef036..ae5d930 100644 --- a/Driver.cpp +++ b/Driver.cpp @@ -261,6 +261,27 @@ OvpnSetMode(POVPN_DEVICE device, WDFREQUEST request) return status; } +static BOOLEAN +OvpnDeviceCheckMode(OVPN_MODE mode, ULONG code) +{ + if (mode == OVPN_MODE_MP) + { + switch (code) + { + // all those IOCTLs are only for P2P mode + case OVPN_IOCTL_NEW_PEER: + case OVPN_IOCTL_DEL_PEER: + case OVPN_IOCTL_NEW_KEY: + case OVPN_IOCTL_NEW_KEY_V2: + case OVPN_IOCTL_SWAP_KEYS: + case OVPN_IOCTL_SET_PEER: + return FALSE; + } + } + + return TRUE; +} + static NTSTATUS OvpnStopVPN(_In_ POVPN_DEVICE device) { @@ -316,6 +337,12 @@ OvpnEvtIoDeviceControl(WDFQUEUE queue, WDFREQUEST request, size_t outputBufferLe ULONG_PTR bytesReturned = 0; + if (!OvpnDeviceCheckMode(device->Mode, ioControlCode)) + { + WdfRequestCompleteWithInformation(request, STATUS_INVALID_DEVICE_STATE, bytesReturned); + return; + } + KIRQL kirql = 0; switch ((long)ioControlCode) { case OVPN_IOCTL_GET_STATS: From 83d2ad7c8564212c6362087ee6f9cba9066da75a Mon Sep 17 00:00:00 2001 From: Lev Stipakov Date: Mon, 16 Sep 2024 14:28:18 +0300 Subject: [PATCH 08/28] Add OVPN_MP_START_VPN ioctl This starts listening on a provided local address. https://github.com/OpenVPN/ovpn-dco-win/issues/82 Co-authored-by: Leon Dang Signed-off-by: Leon Dang Signed-off-by: Lev Stipakov --- .gitignore | 7 +++++ Driver.cpp | 75 ++++++++++++++++++++++++++++++++++++++++++++++--- socket.cpp | 17 +++++++---- socket.h | 2 +- uapi/ovpn-dco.h | 9 ++++++ 5 files changed, 100 insertions(+), 10 deletions(-) diff --git a/.gitignore b/.gitignore index 336c1bf..0c87796 100644 --- a/.gitignore +++ b/.gitignore @@ -53,6 +53,7 @@ Thumbs.db Debug Release x64 +x86 Win32 sdv ARM64 @@ -72,3 +73,9 @@ smvstats.txt *.cab *.vcxproj.user + +**/out +**/build +**/.vscode +**/signed +msm/tmp \ No newline at end of file diff --git a/Driver.cpp b/Driver.cpp index ae5d930..80d6368 100644 --- a/Driver.cpp +++ b/Driver.cpp @@ -264,10 +264,8 @@ OvpnSetMode(POVPN_DEVICE device, WDFREQUEST request) static BOOLEAN OvpnDeviceCheckMode(OVPN_MODE mode, ULONG code) { - if (mode == OVPN_MODE_MP) - { - switch (code) - { + if (mode == OVPN_MODE_MP) { + switch (code) { // all those IOCTLs are only for P2P mode case OVPN_IOCTL_NEW_PEER: case OVPN_IOCTL_DEL_PEER: @@ -275,6 +273,14 @@ OvpnDeviceCheckMode(OVPN_MODE mode, ULONG code) case OVPN_IOCTL_NEW_KEY_V2: case OVPN_IOCTL_SWAP_KEYS: case OVPN_IOCTL_SET_PEER: + case OVPN_IOCTL_START_VPN: + return FALSE; + } + } + else if (mode == OVPN_MODE_P2P) { + switch (code) { + // those IOCTLs are for MP mode + case OVPN_IOCTL_MP_START_VPN: return FALSE; } } @@ -323,6 +329,63 @@ OvpnStopVPN(_In_ POVPN_DEVICE device) return STATUS_SUCCESS; } +_Must_inspect_result_ +_IRQL_requires_(PASSIVE_LEVEL) +_IRQL_requires_same_ +NTSTATUS +OvpnMPStartVPN(POVPN_DEVICE device, WDFREQUEST request, ULONG_PTR* bytesReturned) +{ + NTSTATUS status = STATUS_SUCCESS; + + LOG_ENTER(); + + KIRQL kirql = ExAcquireSpinLockExclusive(&device->SpinLock); + if (device->Socket.Socket != NULL) { + ExReleaseSpinLockExclusive(&device->SpinLock, kirql); + + status = STATUS_ALREADY_INITIALIZED; + + goto done; + } + else { + ExReleaseSpinLockExclusive(&device->SpinLock, kirql); + + POVPN_MP_START_VPN addrIn = NULL; + GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveInputBuffer(request, sizeof(OVPN_MP_START_VPN), (PVOID*)&addrIn, NULL)); + + PWSK_SOCKET socket = NULL; + POVPN_DRIVER driver = OvpnGetDriverContext(WdfGetDriver()); + + // Bind to the address provided + status = OvpnSocketInit(&driver->WskProviderNpi, &driver->WskRegistration, + addrIn->ListenAddress.Addr4.sin_family, false, + (PSOCKADDR)&addrIn->ListenAddress, NULL, + 0, device, &socket); + if (!NT_SUCCESS(status)) { + LOG_ERROR("Socket create failed", TraceLoggingValue((UINT32)status), + TraceLoggingHexUInt32(*(UINT32*)(&addrIn->ListenAddress.Addr4.sin_addr), "addr")); + goto done; + } + + kirql = ExAcquireSpinLockExclusive(&device->SpinLock); + device->Socket.Socket = socket; + ExReleaseSpinLockExclusive(&device->SpinLock, kirql); + + // we might bind the socket to port 0 and we want to get actual port back to userspace + POVPN_MP_START_VPN addrOut = NULL; + GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveOutputBuffer(request, sizeof(OVPN_MP_START_VPN), (PVOID*)&addrOut, NULL)); + RtlCopyMemory(addrOut, addrIn, sizeof(OVPN_MP_START_VPN)); + *bytesReturned = sizeof(OVPN_MP_START_VPN); + } + + OvpnAdapterSetLinkState(OvpnGetAdapterContext(device->Adapter), MediaConnectStateConnected); + +done: + LOG_EXIT(); + + return status; +} + EVT_WDF_IO_QUEUE_IO_DEVICE_CONTROL OvpnEvtIoDeviceControl; _Use_decl_annotations_ @@ -397,6 +460,10 @@ OvpnEvtIoDeviceControl(WDFQUEUE queue, WDFREQUEST request, size_t outputBufferLe ExReleaseSpinLockExclusive(&device->SpinLock, kirql); break; + case OVPN_IOCTL_MP_START_VPN: + status = OvpnMPStartVPN(device, request, &bytesReturned); + break; + default: LOG_WARN("Unknown ", TraceLoggingValue(ioControlCode, "ioControlCode")); status = STATUS_INVALID_DEVICE_REQUEST; diff --git a/socket.cpp b/socket.cpp index 2fcf2d7..70987b9 100644 --- a/socket.cpp +++ b/socket.cpp @@ -522,13 +522,20 @@ OvpnSocketInit(WSK_PROVIDER_NPI* wskProviderNpi, WSK_REGISTRATION* wskRegistrati return datagramDispatch->WskBind(*socket, localAddr, 0, irp); }, [](PIRP) {})); - // set remote - PWSK_PROVIDER_BASIC_DISPATCH basicDispatch = (PWSK_PROVIDER_BASIC_DISPATCH)(*socket)->Dispatch; - - GOTO_IF_NOT_NT_SUCCESS(error, status, OvpnSocketSyncOp("SetRemote", [basicDispatch, socket, remoteAddrSize, remoteAddr](PIRP irp) { - return basicDispatch->WskControlSocket(*socket, WskIoctl, SIO_WSK_SET_REMOTE_ADDRESS, 0, remoteAddrSize, remoteAddr, 0, NULL, NULL, irp); + // get the locally bound address for the socket + GOTO_IF_NOT_NT_SUCCESS(error, status, OvpnSocketSyncOp("GetLocalAddr", [datagramDispatch, socket, localAddr](PIRP irp) { + return datagramDispatch->WskGetLocalAddress(*socket, localAddr, irp); }, [](PIRP) {})); + if (remoteAddr != NULL) { + // set remote + PWSK_PROVIDER_BASIC_DISPATCH basicDispatch = (PWSK_PROVIDER_BASIC_DISPATCH)(*socket)->Dispatch; + + GOTO_IF_NOT_NT_SUCCESS(error, status, OvpnSocketSyncOp("SetRemote", [basicDispatch, socket, remoteAddrSize, remoteAddr](PIRP irp) { + return basicDispatch->WskControlSocket(*socket, WskIoctl, SIO_WSK_SET_REMOTE_ADDRESS, 0, remoteAddrSize, remoteAddr, 0, NULL, NULL, irp); + }, [](PIRP) {})); + } + // enable ReceiveFrom event eventCallbackControl.NpiId = &NPI_WSK_INTERFACE_ID; eventCallbackControl.EventMask = WSK_EVENT_RECEIVE_FROM; diff --git a/socket.h b/socket.h index b6dee5a..f1df149 100644 --- a/socket.h +++ b/socket.h @@ -60,7 +60,7 @@ _Must_inspect_result_ _IRQL_requires_(PASSIVE_LEVEL) NTSTATUS OvpnSocketInit(_In_ WSK_PROVIDER_NPI* wskProviderNpi, _In_ WSK_REGISTRATION* wskRegistration, ADDRESS_FAMILY addrFamily, - BOOLEAN tcp, _In_ PSOCKADDR localAddr, _In_ PSOCKADDR remoteAddr, SIZE_T remoteAddrSize, + BOOLEAN tcp, _In_ PSOCKADDR localAddr, _In_opt_ PSOCKADDR remoteAddr, SIZE_T remoteAddrSize, _In_ PVOID deviceContext, _Out_ PWSK_SOCKET* socket); _Must_inspect_result_ diff --git a/uapi/ovpn-dco.h b/uapi/ovpn-dco.h index 8163b8b..095e008 100644 --- a/uapi/ovpn-dco.h +++ b/uapi/ovpn-dco.h @@ -123,6 +123,13 @@ typedef struct _OVPN_SET_MODE { OVPN_MODE Mode; } OVPN_SET_MODE, * POVPN_SET_MODE; +typedef struct _OVPN_MP_START_VPN { + union { + SOCKADDR_IN Addr4; + SOCKADDR_IN6 Addr6; + } ListenAddress; +} OVPN_MP_START_VPN, * POVPN_MP_START_VPN; + #define OVPN_IOCTL_NEW_PEER CTL_CODE(FILE_DEVICE_UNKNOWN, 1, METHOD_BUFFERED, FILE_ANY_ACCESS) #define OVPN_IOCTL_GET_STATS CTL_CODE(FILE_DEVICE_UNKNOWN, 2, METHOD_BUFFERED, FILE_ANY_ACCESS) #define OVPN_IOCTL_NEW_KEY CTL_CODE(FILE_DEVICE_UNKNOWN, 3, METHOD_BUFFERED, FILE_ANY_ACCESS) @@ -133,3 +140,5 @@ typedef struct _OVPN_SET_MODE { #define OVPN_IOCTL_GET_VERSION CTL_CODE(FILE_DEVICE_UNKNOWN, 8, METHOD_BUFFERED, FILE_ANY_ACCESS) #define OVPN_IOCTL_NEW_KEY_V2 CTL_CODE(FILE_DEVICE_UNKNOWN, 9, METHOD_BUFFERED, FILE_ANY_ACCESS) #define OVPN_IOCTL_SET_MODE CTL_CODE(FILE_DEVICE_UNKNOWN, 10, METHOD_BUFFERED, FILE_ANY_ACCESS) + +#define OVPN_IOCTL_MP_START_VPN CTL_CODE(FILE_DEVICE_UNKNOWN, 11, METHOD_BUFFERED, FILE_ANY_ACCESS) From 71e0a20137efa348af8d7c9ca12f69ad5f23bd57 Mon Sep 17 00:00:00 2001 From: Lev Stipakov Date: Tue, 17 Sep 2024 14:21:50 +0300 Subject: [PATCH 09/28] ovpn-dco-cli GUI: initial commit This includes following functionality: - get driver version - set mode - start VPN (p2p/mp) - new peer (p2p) - send/receive CC messages https://github.com/OpenVPN/ovpn-dco-win/issues/83 Signed-off-by: Lev Stipakov --- gui/CMakeLists.txt | 11 ++ gui/gui.cpp | 443 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 454 insertions(+) create mode 100644 gui/CMakeLists.txt create mode 100644 gui/gui.cpp diff --git a/gui/CMakeLists.txt b/gui/CMakeLists.txt new file mode 100644 index 0000000..ffb0b6f --- /dev/null +++ b/gui/CMakeLists.txt @@ -0,0 +1,11 @@ +cmake_minimum_required(VERSION 3.10) + +project(gui) + +set(CMAKE_CXX_STANDARD 17) # Enable C++17 +set(CMAKE_CXX_STANDARD_REQUIRED YES) # Ensure the compiler strictly requires C++17 + +add_executable (gui WIN32 "gui.cpp") + +add_definitions(-DUNICODE -D_UNICODE) +target_link_libraries(gui PRIVATE user32 gdi32 Ws2_32) diff --git a/gui/gui.cpp b/gui/gui.cpp new file mode 100644 index 0000000..c8d04f0 --- /dev/null +++ b/gui/gui.cpp @@ -0,0 +1,443 @@ +#include "..\uapi\ovpn-dco.h" + +#include +#include + +#include +#include +#include +#include +#include +#include + +#define DEV_NAME L"\\\\.\\ovpn-dco" +#define VER_DEV_NAME L"\\\\.\\ovpn-dco-ver" + +#define BTN_SEND_CC 100 + +LRESULT CALLBACK WindowProcedure(HWND, UINT, WPARAM, LPARAM); + +HWND hMPListenAddress, hMPListenPort, + hP2PLocalAddress, hP2PLocalPort, + hP2PRemoteAddress, hP2PRemotePort, + hCCMessage; + +HWND hLogArea; +std::unordered_map buttons = { + {OVPN_IOCTL_NEW_PEER, L"P2P New Peer"}, + {OVPN_IOCTL_GET_STATS, L"Get Stats"}, + {OVPN_IOCTL_NEW_KEY, L"New Key"}, + {OVPN_IOCTL_SWAP_KEYS, L"Swap Keys"}, + {OVPN_IOCTL_SET_PEER, L"Set Peer"}, + {OVPN_IOCTL_START_VPN, L"P2P Start VPN"}, + {OVPN_IOCTL_DEL_PEER, L"Del Peer"}, + {OVPN_IOCTL_GET_VERSION, L"Get Version"}, + {OVPN_IOCTL_NEW_KEY_V2, L"New Key V2"}, + {OVPN_IOCTL_SET_MODE, L"Set Mode"}, + {OVPN_IOCTL_MP_START_VPN, L"MP Start VPN"} +}; + +#define GET_IOCTL_FUNCTION_CODE(ioctl) (((ioctl) >> 2) & 0xFFF) + +unsigned long GetIoctlFromFunctionCode(unsigned long functionCode) { + return CTL_CODE(FILE_DEVICE_UNKNOWN, functionCode, METHOD_BUFFERED, FILE_ANY_ACCESS); +} + +std::vector> modeData = { + {OVPN_MODE_P2P, L"P2P"}, + {OVPN_MODE_MP, L"MP"} +}; + +std::vector hModes; + +template +void Log(Args... args) { + std::wstringstream stream; + + // Using a fold expression to insert all arguments into the stream + (stream << ... << args); + + // Move the caret to the end of the text + int textLength = GetWindowTextLength(hLogArea); + SendMessage(hLogArea, EM_SETSEL, (WPARAM)textLength, (LPARAM)textLength); + + // Add a newline character before the new text (if needed) + std::wstring textToAppend = (textLength > 0 ? L"\r\n" : L"") + stream.str(); + + // Insert the new text at the current caret position + SendMessage(hLogArea, EM_REPLACESEL, FALSE, (LPARAM)textToAppend.c_str()); +} + +HANDLE hDev; +char readBuffer[4096] = {0}; +OVERLAPPED ovRead = {0}, ovWrite = {0}; + +bool StartOverlappedRead() { + ZeroMemory(readBuffer, sizeof(readBuffer)); + BOOL result = ReadFile(hDev, readBuffer, sizeof(readBuffer), NULL, &ovRead); + if (!result && GetLastError() != ERROR_IO_PENDING) { + Log("ReadFile failed: ", GetLastError()); + return false; + } + return true; +} + +int WINAPI WinMain(HINSTANCE hInst, HINSTANCE hPrevInst, LPSTR args, int ncmdshow) +{ + // Create a Window Class + WNDCLASSW wc = {0}; + + wc.hbrBackground = (HBRUSH)(COLOR_WINDOW + 1); + wc.hCursor = LoadCursor(NULL, IDC_ARROW); + wc.hInstance = hInst; + wc.lpszClassName = L"myWindowClass"; + wc.lpfnWndProc = WindowProcedure; + + // Register the Window Class + if (!RegisterClassW(&wc)) + return -1; + + // Create the Window + HWND hwnd = CreateWindowW(L"myWindowClass", L"ovpn-dco-win GUI", WS_OVERLAPPEDWINDOW | WS_VISIBLE, + 100, 100, 800, 600, NULL, NULL, NULL, NULL); + + HANDLE hEvRead = CreateEventW(NULL, FALSE, FALSE, NULL); + ovRead.hEvent = hEvRead; + + HANDLE hEvWrite = CreateEventW(NULL, FALSE, FALSE, NULL); + ovWrite.hEvent = hEvWrite; + + StartOverlappedRead(); + + while (true) { + HANDLE events[] = { hEvRead, hEvWrite }; + DWORD waitResult = MsgWaitForMultipleObjects(2, events, FALSE, INFINITE, QS_ALLINPUT); + + // read completed + if (waitResult == WAIT_OBJECT_0) { + DWORD bytesRead; + if (GetOverlappedResult(hDev, &ovRead, &bytesRead, FALSE)) { + if (bytesRead > 0) { + Log("CC[", bytesRead, "]> ", readBuffer); + } + } else { + Log("Overlapped read failed: ", GetLastError()); + } + + if (!StartOverlappedRead()) { + break; + } + } if (waitResult == WAIT_OBJECT_0 + 1) { + // write completed + DWORD bytesWrote; + if (GetOverlappedResult(hDev, &ovWrite, &bytesWrote, FALSE)) { + if (bytesWrote > 0) { + Log("Wrote ", bytesWrote, " bytes"); + } + } else { + Log("Overlapped write failed: ", GetLastError()); + } + } + else if (waitResult == WAIT_OBJECT_0 + 2) { + MSG msg; + while (PeekMessage(&msg, NULL, 0, 0, PM_REMOVE)) { + if (msg.message == WM_QUIT) { + CloseHandle(hDev); + CloseHandle(hEvRead); + CloseHandle(hEvWrite); + return 0; + } + TranslateMessage(&msg); + DispatchMessage(&msg); + } + } + } + + return 0; +} + +void OpenDevice(const std::wstring& devName) +{ + hDev = CreateFileW(devName.c_str(), GENERIC_READ | GENERIC_WRITE, + 0, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_SYSTEM | FILE_FLAG_OVERLAPPED, NULL); + if (hDev == INVALID_HANDLE_VALUE) { + Log(L"CreateFile(", devName, ") failed with code ", GetLastError()); + } + else{ + Log(L"Device ", devName, " opened: ", hDev); + } +} + +void DcoGetVersion() +{ + // try version device + HANDLE h = CreateFileW(VER_DEV_NAME, GENERIC_READ, 0, NULL, OPEN_EXISTING, 0, NULL); + if (h == INVALID_HANDLE_VALUE) { + h = hDev; + } + + OVPN_VERSION v{}; + DWORD bytesReturned; + + if (!DeviceIoControl(h, OVPN_IOCTL_GET_VERSION, NULL, 0, &v, sizeof(v), &bytesReturned, NULL)) { + Log("DeviceIoControl(OVPN_IOCTL_GET_VERSION) failed with code ", GetLastError()); + } + else { + Log("Version: ", v.Major, ".", v.Minor, ".", v.Patch); + } +} + +void SetMode() +{ + OVPN_MODE m; + + if (SendMessage(hModes[0], BM_GETCHECK, 0, 0) == BST_CHECKED) { + m = OVPN_MODE_P2P; + } else { + m = OVPN_MODE_MP; + } + + DWORD bytesReturned; + if (!DeviceIoControl(hDev, OVPN_IOCTL_SET_MODE, &m, sizeof(m), NULL, 0, &bytesReturned, NULL)) { + Log("DeviceIoControl(OVPN_IOCTL_SET_MODE) failed with code ", GetLastError()); + } + else { + Log("Mode set: ", m); + } +} + +// Function to convert sockaddr_in to a wstring containing IP and port +std::wstring sockAddrToString(const sockaddr_in& addr) { + wchar_t ipAddress[INET_ADDRSTRLEN]; // Buffer to hold the IP address + + // Convert the binary IP address to a string (wide-char) + InetNtopW(AF_INET, &(addr.sin_addr), ipAddress, INET_ADDRSTRLEN); + + // Convert the port number from network byte order to host byte order + int port = ntohs(addr.sin_port); + + // Convert port to wstring and concatenate with IP address + std::wstring result = ipAddress; + result += L":"; + result += std::to_wstring(port); // Append the port number + + return result; +} + +void MPStartVPN() +{ + wchar_t ipAddress[16]; // Buffer to store IP address + wchar_t portNumber[6]; // Buffer to store port number + + // Get the content of the IP address edit box + GetWindowText(hMPListenAddress, ipAddress, 16); + + // Get the content of the port number edit box + GetWindowText(hMPListenPort, portNumber, 6); + + // Initialize sockaddr_in structure + sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; // IPv4 + + // Convert IP address string to binary form using InetPtonW + InetPtonW(AF_INET, ipAddress, &(addr.sin_addr)); + + // Convert port number string to integer and set it + int port = _wtoi(portNumber); // _wtoi for wide character conversion + addr.sin_port = htons(port); // Convert port to network byte order + + OVPN_MP_START_VPN in, out; + in.ListenAddress.Addr4 = addr; + + DWORD bytesReturned; + if (!DeviceIoControl(hDev, OVPN_IOCTL_MP_START_VPN, &in, sizeof(in), &out, sizeof(out), &bytesReturned, NULL)) { + Log("DeviceIoControl(OVPN_IOCTL_MP_START_VPN) failed with code ", GetLastError()); + } + else { + Log("MP Start VPN: Listen on ", sockAddrToString(out.ListenAddress.Addr4)); + } +} + +void P2PNewPeer() +{ + wchar_t localAddress[16], remoteAddress[16]; + wchar_t localPort[6], remotePort[6]; + + GetWindowText(hP2PLocalAddress, localAddress, 16); + GetWindowText(hP2PLocalPort, localPort, 6); + GetWindowText(hP2PRemoteAddress, remoteAddress, 16); + GetWindowText(hP2PRemotePort, remotePort, 6); + + sockaddr_in localAddr; + memset(&localAddr, 0, sizeof(localAddr)); + localAddr.sin_family = AF_INET; + InetPtonW(AF_INET, localAddress, &(localAddr.sin_addr)); + localAddr.sin_port = htons(_wtoi(localPort)); + + sockaddr_in remoteAddr; + memset(&remoteAddr, 0, sizeof(remoteAddr)); + remoteAddr.sin_family = AF_INET; + InetPtonW(AF_INET, remoteAddress, &(remoteAddr.sin_addr)); + remoteAddr.sin_port = htons(_wtoi(remotePort)); + + OVPN_NEW_PEER newPeer; + newPeer.Local.Addr4 = localAddr; + newPeer.Remote.Addr4 = remoteAddr; + newPeer.Proto = OVPN_PROTO_UDP; + + DWORD bytesReturned; + if (!DeviceIoControl(hDev, OVPN_IOCTL_NEW_PEER, &newPeer, sizeof(newPeer), NULL, 0, &bytesReturned, NULL)) { + Log("DeviceIoControl(OVPN_IOCTL_NEW_PEER) failed with code ", GetLastError()); + } + else { + Log("P2P peer added"); + } +} + +void P2PStartVPN() +{ + DWORD bytesReturned; + if (!DeviceIoControl(hDev, OVPN_IOCTL_START_VPN, NULL, 0, NULL, 0, &bytesReturned, NULL)) { + Log("DeviceIoControl(OVPN_IOCTL_START_VPN) failed with code ", GetLastError()); + } + else { + Log("P2P VPN Started"); + } +} + +void +CreatePushButton(HWND hWnd, DWORD ioctl, int x, int y) +{ + CreateWindowW(L"Button", buttons[ioctl].c_str(), WS_VISIBLE | WS_CHILD | BS_PUSHBUTTON, x, y, 100, 30, + hWnd, (HMENU)(INT_PTR)(GET_IOCTL_FUNCTION_CODE(ioctl)), NULL, NULL); +} + +void +CreatePushButton(HWND hWnd, wchar_t* title, HMENU hMenu, int x, int y) +{ + CreateWindowW(L"Button", title, WS_VISIBLE | WS_CHILD | BS_PUSHBUTTON, x, y, 100, 30, + hWnd, hMenu, NULL, NULL); +} + +HWND +CreateEditBox(HWND hWnd, WCHAR* text, int x, int y, int width) +{ + return CreateWindowW(L"Edit", text, WS_VISIBLE | WS_CHILD | WS_BORDER | ES_LEFT, x, y, width, 20, hWnd, NULL, NULL, NULL); +} + +void +SendCC() +{ + char text[1024]; + GetWindowTextA(hCCMessage, text, 1024); + + DWORD bytesWritten = 0; + BOOL res = WriteFile(hDev, text, (DWORD)strlen(text), &bytesWritten, &ovWrite); + if (!res && GetLastError() != ERROR_IO_PENDING) { + Log("WriteFile failed: ", GetLastError()); + } +} + +// Window Procedure Function +LRESULT CALLBACK WindowProcedure(HWND hwnd, UINT msg, WPARAM wp, LPARAM lp) +{ + switch (msg) + { + case WM_CREATE: + { + CreatePushButton(hwnd, OVPN_IOCTL_GET_VERSION, 10, 10); + CreatePushButton(hwnd, OVPN_IOCTL_SET_MODE, 150, 10); + + for (auto i = 0; i < modeData.size(); ++i) { + auto style = WS_VISIBLE | WS_CHILD | BS_AUTORADIOBUTTON; + if (i == 0) style |= WS_GROUP; + auto hMode = CreateWindowW(L"Button", modeData[i].second.c_str(), style, + 270 + 50 * i, 10, 50, 30, hwnd, (HMENU)(INT_PTR)(1000 + modeData[i].first), NULL, NULL); + hModes.push_back(hMode); + } + + CreatePushButton(hwnd, OVPN_IOCTL_MP_START_VPN, 10, 60); + hMPListenAddress = CreateEditBox(hwnd, L"0.0.0.0", 150, 60, 120); + hMPListenPort = CreateEditBox(hwnd, L"1194", 290, 60, 60); + + CreatePushButton(hwnd, OVPN_IOCTL_NEW_PEER, 10, 110); + hP2PLocalAddress = CreateEditBox(hwnd, L"192.168.100.1", 150, 110, 120); + hP2PLocalPort = CreateEditBox(hwnd, L"1194", 290, 110, 60); + hP2PRemoteAddress = CreateEditBox(hwnd, L"192.168.100.2", 400, 110, 120); + hP2PRemotePort = CreateEditBox(hwnd, L"1194", 540, 110, 60); + + CreatePushButton(hwnd, OVPN_IOCTL_START_VPN, 640, 110); + + CreatePushButton(hwnd, L"Send CC", (HMENU)BTN_SEND_CC, 10, 160); + hCCMessage = CreateEditBox(hwnd, L"hello, dco-win", 150, 160, 120); + } + + SendMessage(hModes[0], BM_SETCHECK, BST_CHECKED, 0); + + // Create Read-Only Multiline Edit Box (Log Area) with Scrollbars + hLogArea = CreateWindowW(L"Edit", L"", + WS_VISIBLE | WS_CHILD | WS_BORDER | ES_MULTILINE | ES_AUTOVSCROLL | ES_AUTOHSCROLL | + WS_VSCROLL | WS_HSCROLL | ES_READONLY, + 0, 0, 600, 100, hwnd, (HMENU)3, NULL, NULL); + + OpenDevice(DEV_NAME); + + break; + + case WM_COMMAND: + { + if ((wp >= 1) && (wp < 20)) + { + auto ioctl = GetIoctlFromFunctionCode((ULONG)wp); + + switch (ioctl) { + case OVPN_IOCTL_GET_VERSION: + DcoGetVersion(); + break; + + case OVPN_IOCTL_SET_MODE: + SetMode(); + break; + + case OVPN_IOCTL_MP_START_VPN: + MPStartVPN(); + break; + + case OVPN_IOCTL_NEW_PEER: + P2PNewPeer(); + break; + + case OVPN_IOCTL_START_VPN: + P2PStartVPN(); + break; + } + } + else if ((ULONG)wp == BTN_SEND_CC) { + SendCC(); + } + + } + + break; + + case WM_SIZE: + { + // Get the new width and height of the window + int width = LOWORD(lp); + int height = HIWORD(lp); + + // Resize the edit control (log area) to be at the bottom + SetWindowPos(hLogArea, NULL, 0, height - 110, width, 100, SWP_NOZORDER); + } + break; + + case WM_DESTROY: + PostQuitMessage(0); + break; + + default: + return DefWindowProcW(hwnd, msg, wp, lp); + } + return 0; +} \ No newline at end of file From 72fc2789aeb9030b442a1f30134ae8544494f703 Mon Sep 17 00:00:00 2001 From: Lev Stipakov Date: Wed, 18 Sep 2024 14:30:20 +0300 Subject: [PATCH 10/28] Account for SOCKADDR in control packets In multipeer UDP mode, we expect userspace to prepend CC packets with SOCKADDR to know where to send the control packet. Likewise, when we receive the control packet, we prepend it with remote SOCKADDR before pushing to userspace. https://github.com/OpenVPN/ovpn-dco-win/issues/84 Co-authored-by: Leon Dang Signed-off-by: Leon Dang Signed-off-by: Lev Stipakov --- Driver.cpp | 62 +++++++++++++++++++++++++++++++++++++---------- socket.cpp | 69 +++++++++++++++++++++++++++++++++++++++-------------- socket.h | 2 +- timer.cpp | 2 +- txqueue.cpp | 4 ++-- 5 files changed, 104 insertions(+), 35 deletions(-) diff --git a/Driver.cpp b/Driver.cpp index 80d6368..f793fb6 100644 --- a/Driver.cpp +++ b/Driver.cpp @@ -187,7 +187,7 @@ OvpnEvtIoWrite(WDFQUEUE queue, WDFREQUEST request, size_t length) // acquire spinlock, since we access device->TransportSocket KIRQL kiqrl = ExAcquireSpinLockShared(&device->SpinLock); - OVPN_TX_BUFFER* buffer = NULL; + OVPN_TX_BUFFER* txBuf = NULL; if (device->Socket.Socket == NULL) { status = STATUS_INVALID_DEVICE_STATE; @@ -195,31 +195,67 @@ OvpnEvtIoWrite(WDFQUEUE queue, WDFREQUEST request, size_t length) goto error; } - // fetch tx buffer - GOTO_IF_NOT_NT_SUCCESS(error, status, OvpnTxBufferPoolGet(device->TxBufferPool, &buffer)); - // get request buffer - PVOID requestBuffer; - size_t requestBufferLength; - GOTO_IF_NOT_NT_SUCCESS(error, status, WdfRequestRetrieveInputBuffer(request, 0, &requestBuffer, &requestBufferLength)); + PVOID buf; + size_t bufLen; + GOTO_IF_NOT_NT_SUCCESS(error, status, WdfRequestRetrieveInputBuffer(request, 0, &buf, &bufLen)); + + PSOCKADDR sa = NULL; + + if (device->Mode == OVPN_MODE_MP) { + // buffer is prepended with SOCKADDR + + sa = (PSOCKADDR)buf; + switch (sa->sa_family) { + case AF_INET: + if (bufLen <= sizeof(SOCKADDR_IN)) { + status = STATUS_INVALID_MESSAGE; + LOG_ERROR("Message too short", TraceLoggingValue(bufLen, "msgLen"), TraceLoggingValue(sizeof(SOCKADDR_IN), "minLen")); + goto error; + } + + buf = (char*)buf + sizeof(SOCKADDR_IN); + bufLen -= sizeof(SOCKADDR_IN); + break; + + case AF_INET6: + if (bufLen <= sizeof(SOCKADDR_IN6)) { + status = STATUS_INVALID_MESSAGE; + LOG_ERROR("Message too short", TraceLoggingValue(bufLen, "msgLen"), TraceLoggingValue(sizeof(SOCKADDR_IN6), "minLen")); + goto error; + } + + buf = (char*)buf + sizeof(SOCKADDR_IN6); + bufLen -= sizeof(SOCKADDR_IN6); + break; + + default: + LOG_ERROR("Invalid address family", TraceLoggingValue(sa->sa_family, "AF")); + status = STATUS_INVALID_ADDRESS; + goto error; + } + } + + // fetch tx buffer + GOTO_IF_NOT_NT_SUCCESS(error, status, OvpnTxBufferPoolGet(device->TxBufferPool, &txBuf)); // copy data from request to tx buffer - PUCHAR buf = OvpnTxBufferPut(buffer, requestBufferLength); - RtlCopyMemory(buf, requestBuffer, requestBufferLength); + PUCHAR data = OvpnTxBufferPut(txBuf, bufLen); + RtlCopyMemory(data, buf, bufLen); - buffer->IoQueue = device->PendingWritesQueue; + txBuf->IoQueue = device->PendingWritesQueue; // move request to manual queue GOTO_IF_NOT_NT_SUCCESS(error, status, WdfRequestForwardToIoQueue(request, device->PendingWritesQueue)); // send - LOG_IF_NOT_NT_SUCCESS(status = OvpnSocketSend(&device->Socket, buffer)); + LOG_IF_NOT_NT_SUCCESS(status = OvpnSocketSend(&device->Socket, txBuf, sa)); goto done_not_complete; error: - if (buffer != NULL) { - OvpnTxBufferPoolPut(buffer); + if (txBuf != NULL) { + OvpnTxBufferPoolPut(txBuf); } ULONG_PTR bytesCopied = 0; diff --git a/socket.cpp b/socket.cpp index 70987b9..bebf26e 100644 --- a/socket.cpp +++ b/socket.cpp @@ -97,8 +97,29 @@ OvpnSocketSyncOp(_In_z_ CHAR* opName, OP op, SUCCESS success) static _Requires_shared_lock_held_(device->SpinLock) VOID -OvpnSocketControlPacketReceived(_In_ POVPN_DEVICE device, _In_reads_(len) PUCHAR buf, SIZE_T len) +OvpnSocketControlPacketReceived(_In_ POVPN_DEVICE device, _In_reads_(len) PUCHAR buf, SIZE_T len, _In_opt_ PSOCKADDR remote) { + SIZE_T hdrLen = 0, totalLen = len; + + // in UDP and MP mode we prepend CC packet with remote sockaddr before pushing it to userspace + if (device->Mode == OVPN_MODE_MP && remote != NULL) { + switch (remote->sa_family) { + case AF_INET: + hdrLen = sizeof(SOCKADDR_IN); + break; + + case AF_INET6: + hdrLen = sizeof(SOCKADDR_IN6); + break; + + default: + LOG_ERROR("Invalid remote address family", TraceLoggingValue(remote->sa_family, "AF")); + InterlockedIncrementNoFence(&device->Stats.LostInControlPackets); + return; + } + totalLen += hdrLen; + } + WDFREQUEST request; NTSTATUS status = WdfIoQueueRetrieveNextRequest(device->PendingReadsQueue, &request); if (!NT_SUCCESS(status)) { @@ -113,17 +134,22 @@ OvpnSocketControlPacketReceived(_In_ POVPN_DEVICE device, _In_reads_(len) PUCHAR return; } - if (sizeof(buffer->Data) >= len) { - // copy control packet to buffer - RtlCopyMemory(buffer->Data, buf, len); - buffer->Len = len; + if (sizeof(buffer->Data) >= totalLen) { + if (hdrLen > 0) { + // prepend with sockaddr + RtlCopyMemory(buffer->Data, remote, hdrLen); + } + + // copy control packet payload + RtlCopyMemory(buffer->Data + hdrLen, buf, totalLen - hdrLen); + buffer->Len = totalLen; // enqueue buffer, it will be dequeued when read request arrives OvpnBufferQueueEnqueue(device->ControlRxBufferQueue, &buffer->QueueListEntry); } else { LOG_ERROR("Buffer too small, packet len , buf len ", - TraceLoggingValue(len, "pktlen"), TraceLoggingValue(sizeof(buffer->Data), "buflen")); + TraceLoggingValue(totalLen, "pktlen"), TraceLoggingValue(sizeof(buffer->Data), "buflen")); OvpnRxBufferPoolPut(buffer); } @@ -133,19 +159,26 @@ OvpnSocketControlPacketReceived(_In_ POVPN_DEVICE device, _In_reads_(len) PUCHAR PVOID readBuffer; size_t readBufferLength; - ULONG_PTR bytesSent = len; + ULONG_PTR bytesSent = totalLen; - LOG_IF_NOT_NT_SUCCESS(status = WdfRequestRetrieveOutputBuffer(request, len, &readBuffer, &readBufferLength)); + LOG_IF_NOT_NT_SUCCESS(status = WdfRequestRetrieveOutputBuffer(request, totalLen, &readBuffer, &readBufferLength)); if (NT_SUCCESS(status)) { - // copy control packet to read request buffer - RtlCopyMemory(readBuffer, buf, len); + + if (hdrLen > 0) { + // prepend with sockaddr + RtlCopyMemory(readBuffer, remote, hdrLen); + } + + // copy control packet payload + RtlCopyMemory((PCHAR)readBuffer + hdrLen, buf, totalLen - hdrLen); + InterlockedIncrementNoFence(&device->Stats.ReceivedControlPackets); } else { InterlockedIncrementNoFence(&device->Stats.LostInControlPackets); if (status == STATUS_BUFFER_TOO_SMALL) { LOG_ERROR("Buffer too small, packet len , buf len ", - TraceLoggingValue(len, "pktlen"), TraceLoggingValue(readBufferLength, "buflen")); + TraceLoggingValue(totalLen, "pktlen"), TraceLoggingValue(readBufferLength, "buflen")); } bytesSent = 0; @@ -238,7 +271,7 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_ } VOID -OvpnSocketProcessIncomingPacket(_In_ POVPN_DEVICE device, _In_reads_(packetLength) PUCHAR buf, SIZE_T packetLength, BOOLEAN irqlDispatch) +OvpnSocketProcessIncomingPacket(_In_ POVPN_DEVICE device, _In_reads_(packetLength) PUCHAR buf, SIZE_T packetLength, BOOLEAN irqlDispatch, _In_opt_ PSOCKADDR remoteAddr) { // If we're at dispatch level, we can use a small optimization and use function // which is not calling KeRaiseIRQL to raise the IRQL to DISPATCH_LEVEL before attempting to acquire the lock @@ -255,7 +288,7 @@ OvpnSocketProcessIncomingPacket(_In_ POVPN_DEVICE device, _In_reads_(packetLengt OvpnSocketDataPacketReceived(device, op, buf, packetLength); } else { - OvpnSocketControlPacketReceived(device, buf, packetLength); + OvpnSocketControlPacketReceived(device, buf, packetLength, remoteAddr); } // don't forget to release spinlock @@ -330,7 +363,7 @@ OvpnSocketUdpReceiveFromEvent(_In_ PVOID socketContext, ULONG flags, _In_opt_ PW buf = packetBuf; } - OvpnSocketProcessIncomingPacket(device, buf, dataIndication->Buffer.Length, flags & WSK_FLAG_AT_DISPATCH_LEVEL); + OvpnSocketProcessIncomingPacket(device, buf, dataIndication->Buffer.Length, flags & WSK_FLAG_AT_DISPATCH_LEVEL, dataIndication->RemoteAddress); dataIndication = dataIndication->Next; } @@ -412,7 +445,7 @@ OvpnSocketTcpReceiveEvent(_In_opt_ PVOID socketContext, _In_ ULONG flags, _In_op buf = tcpState->PacketBuf; } - OvpnSocketProcessIncomingPacket(device, buf, tcpState->PacketLength, flags & WSK_FLAG_AT_DISPATCH_LEVEL); + OvpnSocketProcessIncomingPacket(device, buf, tcpState->PacketLength, flags & WSK_FLAG_AT_DISPATCH_LEVEL, NULL); mdlDataLen -= bytesRemained; dataIndicationLen -= bytesRemained; @@ -704,7 +737,7 @@ OvpnSocketSendComplete(_In_ PDEVICE_OBJECT deviceObj, _In_ PIRP irp, _In_ PVOID NTSTATUS _Use_decl_annotations_ -OvpnSocketSend(OvpnSocket* ovpnSocket, OVPN_TX_BUFFER* buffer) { +OvpnSocketSend(OvpnSocket* ovpnSocket, OVPN_TX_BUFFER* buffer, SOCKADDR* sa) { OVPN_DEVICE* device = (OVPN_DEVICE*)OvpnTxBufferPoolGetContext(buffer->Pool); PWSK_SOCKET socket = ovpnSocket->Socket; @@ -742,11 +775,11 @@ OvpnSocketSend(OvpnSocket* ovpnSocket, OVPN_TX_BUFFER* buffer) { } else if (buffer->WskBufList.Buffer.Length != 0) { PWSK_PROVIDER_DATAGRAM_DISPATCH datagramDispatch = (PWSK_PROVIDER_DATAGRAM_DISPATCH)socket->Dispatch; - LOG_IF_NOT_NT_SUCCESS(status = datagramDispatch->WskSendMessages(socket, &buffer->WskBufList, 0, NULL, 0, NULL, irp)); + LOG_IF_NOT_NT_SUCCESS(status = datagramDispatch->WskSendMessages(socket, &buffer->WskBufList, 0, sa, 0, NULL, irp)); } else { WSK_BUF wskBuf{ buffer->Mdl, FIELD_OFFSET(OVPN_TX_BUFFER, Head) + (ULONG)(buffer->Data - buffer->Head), buffer->Len }; PWSK_PROVIDER_DATAGRAM_DISPATCH datagramDispatch = (PWSK_PROVIDER_DATAGRAM_DISPATCH)socket->Dispatch; - LOG_IF_NOT_NT_SUCCESS(status = datagramDispatch->WskSendTo(socket, &wskBuf, 0, NULL, 0, NULL, irp)); + LOG_IF_NOT_NT_SUCCESS(status = datagramDispatch->WskSendTo(socket, &wskBuf, 0, sa, 0, NULL, irp)); } return status; diff --git a/socket.h b/socket.h index f1df149..0a6fc34 100644 --- a/socket.h +++ b/socket.h @@ -70,7 +70,7 @@ OvpnSocketClose(_In_opt_ PWSK_SOCKET socket); _Must_inspect_result_ NTSTATUS -OvpnSocketSend(_In_ OvpnSocket* ovpnSocket, _In_ OVPN_TX_BUFFER* buffer); +OvpnSocketSend(_In_ OvpnSocket* ovpnSocket, _In_ OVPN_TX_BUFFER* buffer, _In_opt_ SOCKADDR* sa); _Must_inspect_result_ NTSTATUS diff --git a/timer.cpp b/timer.cpp index 95eeaa1..5c15ab9 100644 --- a/timer.cpp +++ b/timer.cpp @@ -91,7 +91,7 @@ static VOID OvpnTimerXmit(WDFTIMER timer) if (NT_SUCCESS(status)) { // start async send, completion handler will return ciphertext buffer to the pool - LOG_IF_NOT_NT_SUCCESS(status = OvpnSocketSend(&device->Socket, buffer)); + LOG_IF_NOT_NT_SUCCESS(status = OvpnSocketSend(&device->Socket, buffer, NULL)); if (NT_SUCCESS(status)) { LOG_INFO("Ping sent"); } diff --git a/txqueue.cpp b/txqueue.cpp index feec4c5..2ae629f 100644 --- a/txqueue.cpp +++ b/txqueue.cpp @@ -117,7 +117,7 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET if (NT_SUCCESS(status)) { // start async send, this will return ciphertext buffer to the pool if (device->Socket.Tcp) { - status = OvpnSocketSend(&device->Socket, buffer); + status = OvpnSocketSend(&device->Socket, buffer, NULL); } else { // for UDP we use SendMessages to send multiple datagrams at once @@ -195,7 +195,7 @@ OvpnEvtTxQueueAdvance(NETPACKETQUEUE netPacketQueue) if (!device->Socket.Tcp) { // this will use WskSendMessages to send buffers list which we constructed before - LOG_IF_NOT_NT_SUCCESS(OvpnSocketSend(&device->Socket, txBufferHead)); + LOG_IF_NOT_NT_SUCCESS(OvpnSocketSend(&device->Socket, txBufferHead, NULL)); } } } From c0882eed4e8bd76a20d514890802a386603c4c05 Mon Sep 17 00:00:00 2001 From: Lev Stipakov Date: Wed, 18 Sep 2024 14:32:04 +0300 Subject: [PATCH 11/28] gui: account for SOCKADDR in control packets In multipeer UDP mode, we expect userspace to prepend CC packets with SOCKADDR to know where to send the control packet. Likewise, when we receive the control packet, we prepend it with remote SOCKADDR before pushing to userspace. Signed-off-by: Lev Stipakov --- gui/gui.cpp | 105 +++++++++++++++++++++++++++++++++++----------------- 1 file changed, 72 insertions(+), 33 deletions(-) diff --git a/gui/gui.cpp b/gui/gui.cpp index c8d04f0..ce23a26 100644 --- a/gui/gui.cpp +++ b/gui/gui.cpp @@ -20,7 +20,7 @@ LRESULT CALLBACK WindowProcedure(HWND, UINT, WPARAM, LPARAM); HWND hMPListenAddress, hMPListenPort, hP2PLocalAddress, hP2PLocalPort, hP2PRemoteAddress, hP2PRemotePort, - hCCMessage; + hCCMessage, hCCRemoteAddress, hCCRemotePort; HWND hLogArea; std::unordered_map buttons = { @@ -37,6 +37,9 @@ std::unordered_map buttons = { {OVPN_IOCTL_MP_START_VPN, L"MP Start VPN"} }; +#define MIN_FUNCTION_CODE 1 +#define MAX_FUNCTION_CODE 20 + #define GET_IOCTL_FUNCTION_CODE(ioctl) (((ioctl) >> 2) & 0xFFF) unsigned long GetIoctlFromFunctionCode(unsigned long functionCode) { @@ -118,7 +121,21 @@ int WINAPI WinMain(HINSTANCE hInst, HINSTANCE hPrevInst, LPSTR args, int ncmdsho DWORD bytesRead; if (GetOverlappedResult(hDev, &ovRead, &bytesRead, FALSE)) { if (bytesRead > 0) { - Log("CC[", bytesRead, "]> ", readBuffer); + bool mp = SendMessage(hModes[1], BM_GETCHECK, 0, 0) == BST_CHECKED; + + // if we're in server mode, we've received CC message prepended with sockaddr + if (mp) { + SOCKADDR_IN *sa = (SOCKADDR_IN *)readBuffer; + + char ip[INET_ADDRSTRLEN]; + inet_ntop(AF_INET, &(sa->sin_addr), ip, sizeof(ip)); + + int port = ntohs(sa->sin_port); + + Log("CC[", ip, ":", port, "]> ", readBuffer + sizeof(*sa)); + } else { + Log("CC[]> ", readBuffer); + } } } else { Log("Overlapped read failed: ", GetLastError()); @@ -139,6 +156,7 @@ int WINAPI WinMain(HINSTANCE hInst, HINSTANCE hPrevInst, LPSTR args, int ncmdsho } } else if (waitResult == WAIT_OBJECT_0 + 2) { + // window messaging loop MSG msg; while (PeekMessage(&msg, NULL, 0, 0, PM_REMOVE)) { if (msg.message == WM_QUIT) { @@ -316,8 +334,7 @@ CreatePushButton(HWND hWnd, DWORD ioctl, int x, int y) void CreatePushButton(HWND hWnd, wchar_t* title, HMENU hMenu, int x, int y) { - CreateWindowW(L"Button", title, WS_VISIBLE | WS_CHILD | BS_PUSHBUTTON, x, y, 100, 30, - hWnd, hMenu, NULL, NULL); + CreateWindowW(L"Button", title, WS_VISIBLE | WS_CHILD | BS_PUSHBUTTON, x, y, 100, 30, hWnd, hMenu, NULL, NULL); } HWND @@ -329,11 +346,34 @@ CreateEditBox(HWND hWnd, WCHAR* text, int x, int y, int width) void SendCC() { - char text[1024]; + bool mp = SendMessage(hModes[1], BM_GETCHECK, 0, 0) == BST_CHECKED; + + sockaddr_in sa; + char text[1024], remoteAddress[16], remotePort[6]; GetWindowTextA(hCCMessage, text, 1024); + GetWindowTextA(hCCRemoteAddress, remoteAddress, 16); + GetWindowTextA(hCCRemotePort, remotePort, 6); + + char data[1024]; + DWORD dataLen = (DWORD)strlen(text); + if (mp) { + // in multipeer, we prepend CC message with sockaddr + memset(&sa, 0, sizeof(sa)); + sa.sin_family = AF_INET; + InetPtonA(AF_INET, remoteAddress, &(sa.sin_addr)); + sa.sin_port = htons(atoi(remotePort)); + + // prepend with sockaddr + memcpy(data, &sa, sizeof(sa)); + memcpy(data + sizeof(sa), text, strlen(text)); + + dataLen += sizeof(sa); + } else { + memcpy(data, text, strlen(text)); + } DWORD bytesWritten = 0; - BOOL res = WriteFile(hDev, text, (DWORD)strlen(text), &bytesWritten, &ovWrite); + BOOL res = WriteFile(hDev, data, dataLen, &bytesWritten, &ovWrite); if (!res && GetLastError() != ERROR_IO_PENDING) { Log("WriteFile failed: ", GetLastError()); } @@ -342,40 +382,39 @@ SendCC() // Window Procedure Function LRESULT CALLBACK WindowProcedure(HWND hwnd, UINT msg, WPARAM wp, LPARAM lp) { - switch (msg) - { + switch (msg) { case WM_CREATE: - { - CreatePushButton(hwnd, OVPN_IOCTL_GET_VERSION, 10, 10); - CreatePushButton(hwnd, OVPN_IOCTL_SET_MODE, 150, 10); - - for (auto i = 0; i < modeData.size(); ++i) { - auto style = WS_VISIBLE | WS_CHILD | BS_AUTORADIOBUTTON; - if (i == 0) style |= WS_GROUP; - auto hMode = CreateWindowW(L"Button", modeData[i].second.c_str(), style, - 270 + 50 * i, 10, 50, 30, hwnd, (HMENU)(INT_PTR)(1000 + modeData[i].first), NULL, NULL); - hModes.push_back(hMode); - } + CreatePushButton(hwnd, OVPN_IOCTL_GET_VERSION, 10, 10); + CreatePushButton(hwnd, OVPN_IOCTL_SET_MODE, 150, 10); + + for (auto i = 0; i < modeData.size(); ++i) { + auto style = WS_VISIBLE | WS_CHILD | BS_AUTORADIOBUTTON; + if (i == 0) style |= WS_GROUP; + auto hMode = CreateWindowW(L"Button", modeData[i].second.c_str(), style, + 270 + 50 * i, 10, 50, 30, hwnd, (HMENU)(INT_PTR)(1000 + modeData[i].first), NULL, NULL); + hModes.push_back(hMode); + } - CreatePushButton(hwnd, OVPN_IOCTL_MP_START_VPN, 10, 60); - hMPListenAddress = CreateEditBox(hwnd, L"0.0.0.0", 150, 60, 120); - hMPListenPort = CreateEditBox(hwnd, L"1194", 290, 60, 60); + CreatePushButton(hwnd, OVPN_IOCTL_MP_START_VPN, 10, 60); + hMPListenAddress = CreateEditBox(hwnd, L"0.0.0.0", 150, 60, 120); + hMPListenPort = CreateEditBox(hwnd, L"1194", 290, 60, 60); - CreatePushButton(hwnd, OVPN_IOCTL_NEW_PEER, 10, 110); - hP2PLocalAddress = CreateEditBox(hwnd, L"192.168.100.1", 150, 110, 120); - hP2PLocalPort = CreateEditBox(hwnd, L"1194", 290, 110, 60); - hP2PRemoteAddress = CreateEditBox(hwnd, L"192.168.100.2", 400, 110, 120); - hP2PRemotePort = CreateEditBox(hwnd, L"1194", 540, 110, 60); + CreatePushButton(hwnd, OVPN_IOCTL_NEW_PEER, 10, 110); + hP2PLocalAddress = CreateEditBox(hwnd, L"192.168.100.1", 150, 110, 120); + hP2PLocalPort = CreateEditBox(hwnd, L"1194", 290, 110, 60); + hP2PRemoteAddress = CreateEditBox(hwnd, L"192.168.100.2", 400, 110, 120); + hP2PRemotePort = CreateEditBox(hwnd, L"1194", 540, 110, 60); - CreatePushButton(hwnd, OVPN_IOCTL_START_VPN, 640, 110); + CreatePushButton(hwnd, OVPN_IOCTL_START_VPN, 640, 110); - CreatePushButton(hwnd, L"Send CC", (HMENU)BTN_SEND_CC, 10, 160); - hCCMessage = CreateEditBox(hwnd, L"hello, dco-win", 150, 160, 120); - } + CreatePushButton(hwnd, L"Send CC", (HMENU)BTN_SEND_CC, 10, 160); + hCCMessage = CreateEditBox(hwnd, L"hello, dco-win", 150, 160, 120); + hCCRemoteAddress = CreateEditBox(hwnd, L"192.168.100.1", 290, 160, 120); + hCCRemotePort = CreateEditBox(hwnd, L"1194", 430, 160, 60); SendMessage(hModes[0], BM_SETCHECK, BST_CHECKED, 0); - // Create Read-Only Multiline Edit Box (Log Area) with Scrollbars + // log area hLogArea = CreateWindowW(L"Edit", L"", WS_VISIBLE | WS_CHILD | WS_BORDER | ES_MULTILINE | ES_AUTOVSCROLL | ES_AUTOHSCROLL | WS_VSCROLL | WS_HSCROLL | ES_READONLY, @@ -387,7 +426,7 @@ LRESULT CALLBACK WindowProcedure(HWND hwnd, UINT msg, WPARAM wp, LPARAM lp) case WM_COMMAND: { - if ((wp >= 1) && (wp < 20)) + if ((wp >= MIN_FUNCTION_CODE) && (wp < MAX_FUNCTION_CODE)) { auto ioctl = GetIoctlFromFunctionCode((ULONG)wp); From 55cdbff169d4d99117e47c137dbbca02b351968b Mon Sep 17 00:00:00 2001 From: Lev Stipakov Date: Wed, 25 Sep 2024 20:29:26 +0300 Subject: [PATCH 12/28] buffers: fix max headroom size Commit f7877d ("Implement DATA_V3 protocol features") has added support for 64bit pkitd. This means that headroom for TX buffer should be also increased by 4 bytes. Signed-off-by: Lev Stipakov --- bufferpool.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bufferpool.cpp b/bufferpool.cpp index 926c783..b28932b 100644 --- a/bufferpool.cpp +++ b/bufferpool.cpp @@ -27,7 +27,7 @@ #include "bufferpool.h" #include "trace.h" -#define OVPN_BUFFER_HEADROOM 26 // we prepend TCP packet size (2 bytes) and crypto overhead (24 bytes) +#define OVPN_BUFFER_HEADROOM 30 // we prepend TCP packet size (2 bytes) and max crypto overhead (28 bytes) // good enough limit for in-flight packets constexpr auto MAX_POOL_SIZE = 100'000; From 3486c5bd5e11976c9820ee0fdb1ddabf347b68f0 Mon Sep 17 00:00:00 2001 From: Lev Stipakov Date: Thu, 26 Sep 2024 10:37:21 +0300 Subject: [PATCH 13/28] buffers: RX_BUFFER refactor and buffer-manipulation functions Make RX_BUFFER compatible with buffer-manipulation templated functions, such as: - OvpnBufferTrim (new) - OvpnBufferPull (new) - OvpnBufferPut (existing, added template) Those functions are now compatible with both TX and RX buffers. This refactor allowed to simplify OvpnEvtRxQueueAdvance() and remove peer access from it, since the plaintext buffers dequeued there already have crypto overhead stripped. Signed-off-by: Lev Stipakov --- Driver.cpp | 2 +- bufferpool.cpp | 15 +++------------ bufferpool.h | 40 ++++++++++++++++++++++++++++++++++++---- rxqueue.cpp | 12 +----------- socket.cpp | 37 ++++++++++++++++++++----------------- timer.cpp | 2 +- txqueue.cpp | 4 ++-- 7 files changed, 64 insertions(+), 48 deletions(-) diff --git a/Driver.cpp b/Driver.cpp index f793fb6..a7bc4a1 100644 --- a/Driver.cpp +++ b/Driver.cpp @@ -240,7 +240,7 @@ OvpnEvtIoWrite(WDFQUEUE queue, WDFREQUEST request, size_t length) GOTO_IF_NOT_NT_SUCCESS(error, status, OvpnTxBufferPoolGet(device->TxBufferPool, &txBuf)); // copy data from request to tx buffer - PUCHAR data = OvpnTxBufferPut(txBuf, bufLen); + PUCHAR data = OvpnBufferPut(txBuf, bufLen); RtlCopyMemory(data, buf, bufLen); txBuf->IoQueue = device->PendingWritesQueue; diff --git a/bufferpool.cpp b/bufferpool.cpp index b28932b..d1653ff 100644 --- a/bufferpool.cpp +++ b/bufferpool.cpp @@ -205,6 +205,8 @@ OvpnRxBufferPoolGet(OVPN_RX_BUFFER_POOL handle, OVPN_RX_BUFFER** buffer) if (*buffer == NULL) return STATUS_INSUFFICIENT_RESOURCES; + (*buffer)->Data = (*buffer)->Head; + (*buffer)->Tail = (*buffer)->Data; (*buffer)->Pool = handle; (*buffer)->Len = 0; @@ -270,17 +272,6 @@ OvpnTxBufferPoolDelete(OVPN_BUFFER_POOL handle) OvpnBufferPoolDelete(handle); } -_Use_decl_annotations_ -UCHAR* -OvpnTxBufferPut(OVPN_TX_BUFFER* buffer, SIZE_T len) -{ - UCHAR* tmp = buffer->Tail; - buffer->Tail += len; - buffer->Len += len; - - return tmp; -} - _Use_decl_annotations_ UCHAR* OvpnTxBufferPush(OVPN_TX_BUFFER* buffer, SIZE_T len) @@ -295,7 +286,7 @@ _Use_decl_annotations_ NTSTATUS OvpnRxBufferPoolCreate(OVPN_RX_BUFFER_POOL* handle) { - return OvpnBufferPoolCreate((OVPN_BUFFER_POOL*)handle, sizeof(OVPN_RX_BUFFER), "rx", NULL); + return OvpnBufferPoolCreate((OVPN_BUFFER_POOL*)handle, sizeof(OVPN_RX_BUFFER) + OVPN_SOCKET_RX_PACKET_BUFFER_SIZE, "rx", NULL); } VOID diff --git a/bufferpool.h b/bufferpool.h index 1ed0721..fa742af 100644 --- a/bufferpool.h +++ b/bufferpool.h @@ -63,19 +63,51 @@ struct OVPN_TX_BUFFER struct OVPN_RX_BUFFER { - LIST_ENTRY PoolListEntry; + // points to the beginning of data + PUCHAR Data; - LIST_ENTRY QueueListEntry; + // points to the end of data + PUCHAR Tail; + // data length SIZE_T Len; + LIST_ENTRY PoolListEntry; + + LIST_ENTRY QueueListEntry; + OVPN_RX_BUFFER_POOL Pool; - UCHAR Data[OVPN_SOCKET_RX_PACKET_BUFFER_SIZE]; + #pragma warning(suppress:4200) //nonstandard extension used: zero-sized array in struct/union + UCHAR Head[]; }; +template UCHAR* -OvpnTxBufferPut(_In_ OVPN_TX_BUFFER* work, SIZE_T len); +OvpnBufferPut(_In_ BUF* buf, SIZE_T len) +{ + UCHAR* tmp = buf->Tail; + buf->Tail += len; + buf->Len += len; + + return tmp; +} + +template +VOID +OvpnBufferTrim(_In_ BUF* buf, SIZE_T len) +{ + buf->Len = len; + buf->Tail = buf->Data + len; +} + +template +VOID +OvpnBufferPull(_In_ BUF* buf, SIZE_T len) +{ + buf->Len -= len; + buf->Data += len; +} UCHAR* OvpnTxBufferPush(_In_ OVPN_TX_BUFFER* work, SIZE_T len); diff --git a/rxqueue.cpp b/rxqueue.cpp index d7687c3..2a85ffd 100644 --- a/rxqueue.cpp +++ b/rxqueue.cpp @@ -102,16 +102,6 @@ OvpnEvtRxQueueAdvance(NETPACKETQUEUE netPacketQueue) POVPN_RXQUEUE queue = OvpnGetRxQueueContext(netPacketQueue); OVPN_DEVICE* device = OvpnGetDeviceContext(queue->Adapter->WdfDevice); - OvpnPeerContext* peer = OvpnGetFirstPeer(&device->Peers); - if (peer == NULL) { - LOG_WARN("No peer"); - return; - } - - BOOLEAN pktId64bit = peer->CryptoContext.CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; - BOOLEAN aeadTagEnd = peer->CryptoContext.CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END; - auto payloadOffset = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN); - NET_RING_FRAGMENT_ITERATOR fi = NetRingGetAllFragments(queue->Rings); NET_RING_PACKET_ITERATOR pi = NetRingGetAllPackets(queue->Rings); while (NetFragmentIteratorHasAny(&fi)) { @@ -126,7 +116,7 @@ OvpnEvtRxQueueAdvance(NETPACKETQUEUE netPacketQueue) fragment->ValidLength = buffer->Len; fragment->Offset = 0; NET_FRAGMENT_VIRTUAL_ADDRESS* virtualAddr = NetExtensionGetFragmentVirtualAddress(&queue->VirtualAddressExtension, NetFragmentIteratorGetIndex(&fi)); - RtlCopyMemory(virtualAddr->VirtualAddress, buffer->Data + payloadOffset, buffer->Len); + RtlCopyMemory(virtualAddr->VirtualAddress, buffer->Data, buffer->Len); InterlockedExchangeAddNoFence64(&device->Stats.TunBytesReceived, buffer->Len); diff --git a/socket.cpp b/socket.cpp index bebf26e..533c0c2 100644 --- a/socket.cpp +++ b/socket.cpp @@ -134,15 +134,14 @@ OvpnSocketControlPacketReceived(_In_ POVPN_DEVICE device, _In_reads_(len) PUCHAR return; } - if (sizeof(buffer->Data) >= totalLen) { + if (totalLen <= OVPN_SOCKET_RX_PACKET_BUFFER_SIZE) { if (hdrLen > 0) { // prepend with sockaddr - RtlCopyMemory(buffer->Data, remote, hdrLen); + RtlCopyMemory(OvpnBufferPut(buffer, hdrLen), remote, hdrLen); } // copy control packet payload - RtlCopyMemory(buffer->Data + hdrLen, buf, totalLen - hdrLen); - buffer->Len = totalLen; + RtlCopyMemory(OvpnBufferPut(buffer, len), buf, len); // enqueue buffer, it will be dequeued when read request arrives OvpnBufferQueueEnqueue(device->ControlRxBufferQueue, &buffer->QueueListEntry); @@ -222,12 +221,22 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_ LOG_ERROR("keyId not found", TraceLoggingValue(keyId, "keyId")); } else { + // extend data area in the buffer for plaintext and crypto overhead + OvpnBufferPut(buffer, len); + // decrypt into plaintext buffer status = cryptoContext->Decrypt(keySlot, cipherTextBuf, len, buffer->Data, cryptoContext->CryptoOptions); + // trim AEAD tag an the end + auto aeadTagEnd = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END; + if (aeadTagEnd) { + OvpnBufferTrim(buffer, len - AEAD_AUTH_TAG_LEN); + } + + // remove crypto overhead in front auto pktId64bit = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; - auto cryptoOverhead = OVPN_DATA_V2_LEN + AEAD_AUTH_TAG_LEN + (pktId64bit ? 8 : 4); - buffer->Len = len - cryptoOverhead; + auto cryptoOverheadFront = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN); + OvpnBufferPull(buffer, cryptoOverheadFront); } } else { @@ -243,24 +252,18 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_ OvpnTimerResetRecv(peer->Timer); - // points to the beginning of plaintext - BOOLEAN pktId64bit = peer->CryptoContext.CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; - BOOLEAN aeadTagEnd = peer->CryptoContext.CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END; - auto payloadOffset = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN); - UCHAR* plaintext = buffer->Data + payloadOffset; - // ping packet? - if (OvpnTimerIsKeepaliveMessage(plaintext, buffer->Len)) { + if (OvpnTimerIsKeepaliveMessage(buffer->Data, buffer->Len)) { LOG_INFO("Ping received"); // no need to inject ping packet into OS, return buffer to the pool OvpnRxBufferPoolPut(buffer); } else { - if (OvpnMssIsIPv4(plaintext, buffer->Len)) { - OvpnMssDoIPv4(plaintext, buffer->Len, device->MSS); - } else if (OvpnMssIsIPv6(plaintext, buffer->Len)) { - OvpnMssDoIPv6(plaintext, buffer->Len, device->MSS); + if (OvpnMssIsIPv4(buffer->Data, buffer->Len)) { + OvpnMssDoIPv4(buffer->Data, buffer->Len, device->MSS); + } else if (OvpnMssIsIPv6(buffer->Data, buffer->Len)) { + OvpnMssDoIPv6(buffer->Data, buffer->Len, device->MSS); } // enqueue plaintext buffer, it will be dequeued by NetAdapter RX datapath diff --git a/timer.cpp b/timer.cpp index 5c15ab9..dc409c9 100644 --- a/timer.cpp +++ b/timer.cpp @@ -69,7 +69,7 @@ static VOID OvpnTimerXmit(WDFTIMER timer) } // copy keepalive magic message to the buffer - RtlCopyMemory(OvpnTxBufferPut(buffer, sizeof(OvpnKeepaliveMessage)), OvpnKeepaliveMessage, sizeof(OvpnKeepaliveMessage)); + RtlCopyMemory(OvpnBufferPut(buffer, sizeof(OvpnKeepaliveMessage)), OvpnKeepaliveMessage, sizeof(OvpnKeepaliveMessage)); OvpnPeerContext* peer = timerCtx->Peer; KIRQL kiqrl = ExAcquireSpinLockShared(&device->SpinLock); diff --git a/txqueue.cpp b/txqueue.cpp index 2ae629f..4f7ba30 100644 --- a/txqueue.cpp +++ b/txqueue.cpp @@ -71,7 +71,7 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET NET_FRAGMENT_VIRTUAL_ADDRESS* virtualAddr = NetExtensionGetFragmentVirtualAddress( &queue->VirtualAddressExtension, NetFragmentIteratorGetIndex(&fi)); - RtlCopyMemory(OvpnTxBufferPut(buffer, fragment->ValidLength), + RtlCopyMemory(OvpnBufferPut(buffer, fragment->ValidLength), (UCHAR const*)virtualAddr->VirtualAddress + fragment->Offset, fragment->ValidLength); NetFragmentIteratorAdvance(&fi); @@ -103,7 +103,7 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET OvpnTxBufferPush(buffer, OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN)); if (aeadTagEnd) { - OvpnTxBufferPut(buffer, AEAD_AUTH_TAG_LEN); + OvpnBufferPut(buffer, AEAD_AUTH_TAG_LEN); } // in-place encrypt, always with primary key From 3c91912023b94b0ef359986c7734a278354e0734 Mon Sep 17 00:00:00 2001 From: Lev Stipakov Date: Tue, 8 Oct 2024 16:26:53 +0300 Subject: [PATCH 14/28] Implement OVPN_IOCTL_MP_NEW_PEER Implement adding peer in MP mode. Also add this to gui client. GitHub: https://github.com/OpenVPN/ovpn-dco-win/issues/85 Signed-off-by: Lev Stipakov Signed-off-by: Leon Dang Co-authored-by: Leon Dang --- Driver.cpp | 91 ++++++++++++++++++++++++++++++-- Driver.h | 26 +++++++++ crypto.h | 1 + gui/gui.cpp | 60 +++++++++++++++++++-- peer.cpp | 136 ++++++++++++++++++++++++++++++++++++++++++++++-- peer.h | 27 ++++++++++ socket.cpp | 9 ++-- txqueue.cpp | 4 +- uapi/ovpn-dco.h | 18 +++++++ 9 files changed, 355 insertions(+), 17 deletions(-) diff --git a/Driver.cpp b/Driver.cpp index a7bc4a1..91e1c0c 100644 --- a/Driver.cpp +++ b/Driver.cpp @@ -317,6 +317,7 @@ OvpnDeviceCheckMode(OVPN_MODE mode, ULONG code) switch (code) { // those IOCTLs are for MP mode case OVPN_IOCTL_MP_START_VPN: + case OVPN_IOCTL_MP_NEW_PEER: return FALSE; } } @@ -500,6 +501,10 @@ OvpnEvtIoDeviceControl(WDFQUEUE queue, WDFREQUEST request, size_t outputBufferLe status = OvpnMPStartVPN(device, request, &bytesReturned); break; + case OVPN_IOCTL_MP_NEW_PEER: + status = OvpnMPPeerNew(device, request); + break; + default: LOG_WARN("Unknown ", TraceLoggingValue(ioControlCode, "ioControlCode")); status = STATUS_INVALID_DEVICE_REQUEST; @@ -705,8 +710,10 @@ OvpnEvtDeviceAdd(WDFDRIVER wdfDriver, PWDFDEVICE_INIT deviceInit) { GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoInitAlgHandles(&device->AesAlgHandle, &device->ChachaAlgHandle)); - // Initialize peers tree + // Initialize peers tables RtlInitializeGenericTable(&device->Peers, OvpnPeerCompareByPeerIdRoutine, OvpnPeerAllocateRoutine, OvpnPeerFreeRoutine, NULL); + RtlInitializeGenericTable(&device->PeersByVpn4, OvpnPeerCompareByVPN4Routine, OvpnPeerAllocateRoutine, OvpnPeerFreeRoutine, NULL); + RtlInitializeGenericTable(&device->PeersByVpn6, OvpnPeerCompareByVPN6Routine, OvpnPeerAllocateRoutine, OvpnPeerFreeRoutine, NULL); LOG_IF_NOT_NT_SUCCESS(status = OvpnAdapterCreate(device)); @@ -716,17 +723,17 @@ OvpnEvtDeviceAdd(WDFDRIVER wdfDriver, PWDFDEVICE_INIT deviceInit) { return status; } -_Use_decl_annotations_ NTSTATUS -OvpnAddPeer(POVPN_DEVICE device, OvpnPeerContext* peer) +OvpnAddPeerToTable(_In_ RTL_GENERIC_TABLE* table, _In_ OvpnPeerContext* peer) { NTSTATUS status; BOOLEAN newElem; - RtlInsertElementGenericTable(&device->Peers, (PVOID)&peer, sizeof(OvpnPeerContext*), &newElem); + RtlInsertElementGenericTable(table, (PVOID)&peer, sizeof(OvpnPeerContext*), &newElem); if (newElem) { status = STATUS_SUCCESS; + InterlockedIncrement(&peer->RefCounter); } else { LOG_ERROR("Unable to add new peer"); @@ -735,9 +742,33 @@ OvpnAddPeer(POVPN_DEVICE device, OvpnPeerContext* peer) return status; } + +_Use_decl_annotations_ +NTSTATUS +OvpnAddPeer(POVPN_DEVICE device, OvpnPeerContext* peer) +{ + return OvpnAddPeerToTable(&device->Peers, peer); +} + +_Use_decl_annotations_ +NTSTATUS +OvpnAddPeerVpn4(POVPN_DEVICE device, OvpnPeerContext* peer) +{ + return OvpnAddPeerToTable(&device->PeersByVpn4, peer); +} + +_Use_decl_annotations_ +NTSTATUS +OvpnAddPeerVpn6(POVPN_DEVICE device, OvpnPeerContext* peer) +{ + return OvpnAddPeerToTable(&device->PeersByVpn6, peer); +} + _Use_decl_annotations_ VOID OvpnFlushPeers(POVPN_DEVICE device) { + OvpnCleanupPeerTable(&device->PeersByVpn6); + OvpnCleanupPeerTable(&device->PeersByVpn4); OvpnCleanupPeerTable(&device->Peers); } @@ -750,7 +781,9 @@ OvpnCleanupPeerTable(RTL_GENERIC_TABLE* peers) OvpnPeerContext* peer = *(OvpnPeerContext**)ptr; RtlDeleteElementGenericTable(peers, ptr); - OvpnPeerCtxFree(peer); + if (InterlockedDecrement(&peer->RefCounter) == 0) { + OvpnPeerCtxFree(peer); + } } } @@ -761,3 +794,51 @@ OvpnGetFirstPeer(RTL_GENERIC_TABLE* peers) OvpnPeerContext** ptr = (OvpnPeerContext**)RtlGetElementGenericTable(peers, 0); return ptr ? (OvpnPeerContext*)*ptr : NULL; } + +_Use_decl_annotations_ +OvpnPeerContext* +OvpnFindPeer(POVPN_DEVICE device, INT32 PeerId) +{ + if (device->Mode == OVPN_MODE_P2P) { + return OvpnGetFirstPeer(&device->Peers); + } + + OvpnPeerContext p {}; + p.PeerId = PeerId; + + auto* pp = &p; + OvpnPeerContext** ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->Peers, &pp); + return ptr ? (OvpnPeerContext*)*ptr : NULL; +} + +_Use_decl_annotations_ +OvpnPeerContext* +OvpnFindPeerVPN4(POVPN_DEVICE device, IN_ADDR addr) +{ + if (device->Mode == OVPN_MODE_P2P) { + return OvpnGetFirstPeer(&device->Peers); + } + + OvpnPeerContext p{}; + p.VpnAddrs.IPv4 = addr; + + auto* pp = &p; + OvpnPeerContext** ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->PeersByVpn4, &pp); + return ptr ? (OvpnPeerContext*)*ptr : NULL; +} + +_Use_decl_annotations_ +OvpnPeerContext* +OvpnFindPeerVPN6(POVPN_DEVICE device, IN6_ADDR addr) +{ + if (device->Mode == OVPN_MODE_P2P) { + return OvpnGetFirstPeer(&device->Peers); + } + + OvpnPeerContext p{}; + p.VpnAddrs.IPv6 = addr; + + auto* pp = &p; + OvpnPeerContext** ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->PeersByVpn6, &pp); + return ptr ? (OvpnPeerContext*)*ptr : NULL; +} diff --git a/Driver.h b/Driver.h index 0045871..745854c 100644 --- a/Driver.h +++ b/Driver.h @@ -94,6 +94,12 @@ struct OVPN_DEVICE { _Guarded_by_(SpinLock) RTL_GENERIC_TABLE Peers; + _Guarded_by_(SpinLock) + RTL_GENERIC_TABLE PeersByVpn4; + + _Guarded_by_(SpinLock) + RTL_GENERIC_TABLE PeersByVpn6; + OVPN_MODE Mode; }; @@ -114,6 +120,14 @@ _Must_inspect_result_ NTSTATUS OvpnAddPeer(_In_ POVPN_DEVICE device, _In_ OvpnPeerContext* PeerCtx); +_Must_inspect_result_ +NTSTATUS +OvpnAddPeerVpn4(_In_ POVPN_DEVICE device, _In_ OvpnPeerContext* PeerCtx); + +_Must_inspect_result_ +NTSTATUS +OvpnAddPeerVpn6(_In_ POVPN_DEVICE device, _In_ OvpnPeerContext* PeerCtx); + VOID OvpnFlushPeers(_In_ POVPN_DEVICE device); @@ -123,3 +137,15 @@ OvpnCleanupPeerTable(_In_ RTL_GENERIC_TABLE*); _Must_inspect_result_ OvpnPeerContext* OvpnGetFirstPeer(_In_ RTL_GENERIC_TABLE*); + +_Must_inspect_result_ +OvpnPeerContext* +OvpnFindPeer(_In_ POVPN_DEVICE device, INT32 PeerId); + +_Must_inspect_result_ +OvpnPeerContext* +OvpnFindPeerVPN4(_In_ POVPN_DEVICE device, _In_ IN_ADDR addr); + +_Must_inspect_result_ +OvpnPeerContext* +OvpnFindPeerVPN6(_In_ POVPN_DEVICE device, _In_ IN6_ADDR addr); diff --git a/crypto.h b/crypto.h index 3782d42..73d286b 100644 --- a/crypto.h +++ b/crypto.h @@ -36,6 +36,7 @@ #define OVPN_OP_DATA_V2 9 #define OVPN_KEY_ID_MASK 0x07 #define OVPN_OPCODE_SHIFT 3 +#define OVPN_PEER_ID_MASK 0x00FFFFFF struct OvpnCryptoKeySlot { diff --git a/gui/gui.cpp b/gui/gui.cpp index ce23a26..3eeeb75 100644 --- a/gui/gui.cpp +++ b/gui/gui.cpp @@ -20,7 +20,8 @@ LRESULT CALLBACK WindowProcedure(HWND, UINT, WPARAM, LPARAM); HWND hMPListenAddress, hMPListenPort, hP2PLocalAddress, hP2PLocalPort, hP2PRemoteAddress, hP2PRemotePort, - hCCMessage, hCCRemoteAddress, hCCRemotePort; + hCCMessage, hCCRemoteAddress, hCCRemotePort, + hMPNewPeerLocalIP, hMPNewPeerLocalPort, hMPNewPeerRemoteIP, hMPNewPeerRemotePort, hMPNewPeerVPNIP, hMPNewPeerPeerId; HWND hLogArea; std::unordered_map buttons = { @@ -34,7 +35,8 @@ std::unordered_map buttons = { {OVPN_IOCTL_GET_VERSION, L"Get Version"}, {OVPN_IOCTL_NEW_KEY_V2, L"New Key V2"}, {OVPN_IOCTL_SET_MODE, L"Set Mode"}, - {OVPN_IOCTL_MP_START_VPN, L"MP Start VPN"} + {OVPN_IOCTL_MP_START_VPN, L"MP Start VPN"}, + {OVPN_IOCTL_MP_NEW_PEER, L"MP New Peer"} }; #define MIN_FUNCTION_CODE 1 @@ -102,7 +104,7 @@ int WINAPI WinMain(HINSTANCE hInst, HINSTANCE hPrevInst, LPSTR args, int ncmdsho // Create the Window HWND hwnd = CreateWindowW(L"myWindowClass", L"ovpn-dco-win GUI", WS_OVERLAPPEDWINDOW | WS_VISIBLE, - 100, 100, 800, 600, NULL, NULL, NULL, NULL); + 100, 100, 900, 600, NULL, NULL, NULL, NULL); HANDLE hEvRead = CreateEventW(NULL, FALSE, FALSE, NULL); ovRead.hEvent = hEvRead; @@ -324,6 +326,47 @@ void P2PStartVPN() } } +void MPNewPeer() +{ + wchar_t localIP[16], localPort[6]; + wchar_t remoteIP[16], remotePort[6], vpnIP[16]; + wchar_t peerId[6]; + + GetWindowText(hMPNewPeerLocalIP, localIP, 16); + GetWindowText(hMPNewPeerLocalPort, localPort, 6); + GetWindowText(hMPNewPeerRemoteIP, remoteIP, 16); + GetWindowText(hMPNewPeerRemotePort, remotePort, 6); + GetWindowText(hMPNewPeerVPNIP, vpnIP, 16); + GetWindowText(hMPNewPeerPeerId, peerId, 6); + + sockaddr_in localAddr = {}; + localAddr.sin_family = AF_INET; + InetPtonW(AF_INET, localIP, &(localAddr.sin_addr)); + localAddr.sin_port = htons(_wtoi(localPort)); + + sockaddr_in remoteAddr = {}; + remoteAddr.sin_family = AF_INET; + InetPtonW(AF_INET, remoteIP, &(remoteAddr.sin_addr)); + remoteAddr.sin_port = htons(_wtoi(remotePort)); + + in_addr vpnAddress; + InetPtonW(AF_INET, vpnIP, &vpnAddress); + + OVPN_MP_NEW_PEER newPeer = {}; + newPeer.Local.Addr4 = localAddr; + newPeer.Remote.Addr4 = remoteAddr; + newPeer.VpnAddr4 = vpnAddress; + newPeer.PeerId = _wtoi(peerId); + + DWORD bytesReturned; + if (!DeviceIoControl(hDev, OVPN_IOCTL_MP_NEW_PEER, &newPeer, sizeof(newPeer), NULL, 0, &bytesReturned, NULL)) { + Log("DeviceIoControl(OVPN_IOCTL_MP_NEW_PEER) failed with code ", GetLastError()); + } + else { + Log("MP peer added"); + } +} + void CreatePushButton(HWND hWnd, DWORD ioctl, int x, int y) { @@ -412,6 +455,14 @@ LRESULT CALLBACK WindowProcedure(HWND hwnd, UINT msg, WPARAM wp, LPARAM lp) hCCRemoteAddress = CreateEditBox(hwnd, L"192.168.100.1", 290, 160, 120); hCCRemotePort = CreateEditBox(hwnd, L"1194", 430, 160, 60); + CreatePushButton(hwnd, OVPN_IOCTL_MP_NEW_PEER, 10, 210); + hMPNewPeerLocalIP = CreateEditBox(hwnd, L"192.168.100.2", 150, 210, 120); + hMPNewPeerLocalPort = CreateEditBox(hwnd, L"1194", 290, 210, 60); + hMPNewPeerRemoteIP = CreateEditBox(hwnd, L"192.168.100.1", 400, 210, 120); + hMPNewPeerRemotePort = CreateEditBox(hwnd, L"1194", 540, 210, 60); + hMPNewPeerVPNIP = CreateEditBox(hwnd, L"10.8.0.6", 650, 210, 120); + hMPNewPeerPeerId = CreateEditBox(hwnd, L"1", 790, 210, 60); + SendMessage(hModes[0], BM_SETCHECK, BST_CHECKED, 0); // log area @@ -450,6 +501,9 @@ LRESULT CALLBACK WindowProcedure(HWND hwnd, UINT msg, WPARAM wp, LPARAM lp) case OVPN_IOCTL_START_VPN: P2PStartVPN(); break; + + case OVPN_IOCTL_MP_NEW_PEER: + MPNewPeer(); } } else if ((ULONG)wp == BTN_SEND_CC) { diff --git a/peer.cpp b/peer.cpp index 4f36386..57bccff 100644 --- a/peer.cpp +++ b/peer.cpp @@ -66,13 +66,54 @@ OvpnPeerFreeRoutine(_RTL_GENERIC_TABLE* table, PVOID buffer) ExFreePoolWithTag(buffer, 'ovpn'); } -RTL_GENERIC_COMPARE_RESULTS OvpnPeerCompareByPeerIdRoutine(_RTL_GENERIC_TABLE* table, PVOID first, PVOID second) +RTL_GENERIC_COMPARE_RESULTS +OvpnPeerCompareByPeerIdRoutine(_RTL_GENERIC_TABLE* table, PVOID first, PVOID second) { UNREFERENCED_PARAMETER(table); - UNREFERENCED_PARAMETER(first); - UNREFERENCED_PARAMETER(second); - return GenericEqual; + OvpnPeerContext* peer1 = *(OvpnPeerContext**)first; + OvpnPeerContext* peer2 = *(OvpnPeerContext**)second; + + if (peer1->PeerId == peer2->PeerId) + return GenericEqual; + else if (peer1->PeerId < peer2->PeerId) + return GenericLessThan; + else + return GenericGreaterThan; +} + +RTL_GENERIC_COMPARE_RESULTS +OvpnPeerCompareByVPN4Routine(_RTL_GENERIC_TABLE* table, PVOID first, PVOID second) +{ + UNREFERENCED_PARAMETER(table); + + OvpnPeerContext* peer1 = *(OvpnPeerContext**)first; + OvpnPeerContext* peer2 = *(OvpnPeerContext**)second; + + int n = memcmp(&peer1->VpnAddrs.IPv4, &peer2->VpnAddrs.IPv4, sizeof(IN_ADDR)); + if (n == 0) + return GenericEqual; + else if (n < 0) + return GenericLessThan; + else + return GenericGreaterThan; +} + +RTL_GENERIC_COMPARE_RESULTS +OvpnPeerCompareByVPN6Routine(_RTL_GENERIC_TABLE* table, PVOID first, PVOID second) +{ + UNREFERENCED_PARAMETER(table); + + OvpnPeerContext* peer1 = *(OvpnPeerContext**)first; + OvpnPeerContext* peer2 = *(OvpnPeerContext**)second; + + int n = memcmp(&peer1->VpnAddrs.IPv6, &peer2->VpnAddrs.IPv6, sizeof(IN6_ADDR)); + if (n == 0) + return GenericEqual; + else if (n < 0) + return GenericLessThan; + else + return GenericGreaterThan; } static @@ -161,6 +202,93 @@ OvpnPeerNew(POVPN_DEVICE device, WDFREQUEST request) return status; } +_Use_decl_annotations_ +NTSTATUS +OvpnMPPeerNew(POVPN_DEVICE device, WDFREQUEST request) +{ + LOG_ENTER(); + + const struct in6_addr ovpn_in6addr_any = { { 0 } }; + + NTSTATUS status = STATUS_SUCCESS; + + POVPN_MP_NEW_PEER peer; + + GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveInputBuffer(request, sizeof(OVPN_MP_NEW_PEER), (PVOID*)&peer, nullptr)); + + // check if we already have a peer with the same peer-id + KIRQL kirql = ExAcquireSpinLockExclusive(&device->SpinLock); + OvpnPeerContext* peerCtx = OvpnFindPeer(device, peer->PeerId); + ExReleaseSpinLockExclusive(&device->SpinLock, kirql); + if (peerCtx != NULL) { + status = STATUS_OBJECTID_EXISTS; + goto done; + } + + // ensure local/remote address is AF_INET or AF_INET6 + if ((peer->Local.Addr4.sin_family != AF_INET) && (peer->Local.Addr4.sin_family != AF_INET6)) + { + status = STATUS_INVALID_DEVICE_REQUEST; + LOG_ERROR("Unknown address family in peer->Local", TraceLoggingValue(peer->Local.Addr4.sin_family, "AF")); + goto done; + } + if ((peer->Remote.Addr4.sin_family != AF_INET) && (peer->Remote.Addr4.sin_family != AF_INET6)) + { + status = STATUS_INVALID_DEVICE_REQUEST; + LOG_ERROR("Unknown address family in peer->Remote", TraceLoggingValue(peer->Remote.Addr4.sin_family, "AF")); + goto done; + } + + // allocate peer + peerCtx = OvpnPeerCtxAlloc(); + if (peerCtx == NULL) { + status = STATUS_NO_MEMORY; + goto done; + } + + // assign local transport address + if (peer->Local.Addr4.sin_family == AF_INET) { + peerCtx->TransportAddrs.Local.IPv4 = peer->Local.Addr4.sin_addr; + } + else { + peerCtx->TransportAddrs.Local.IPv6 = peer->Local.Addr6.sin6_addr; + } + + // assign remote transport address + if (peer->Remote.Addr4.sin_family == AF_INET) { + peerCtx->TransportAddrs.Remote.IPv4 = peer->Remote.Addr4; + } + else { + peerCtx->TransportAddrs.Remote.IPv6 = peer->Remote.Addr6; + } + + peerCtx->VpnAddrs.IPv4 = peer->VpnAddr4; + peerCtx->VpnAddrs.IPv6 = peer->VpnAddr6; + + peerCtx->PeerId = peer->PeerId; + + kirql = ExAcquireSpinLockExclusive(&device->SpinLock); + LOG_IF_NOT_NT_SUCCESS(status = OvpnAddPeer(device, peerCtx)); + if (status == STATUS_SUCCESS) { + if (peer->VpnAddr4.S_un.S_addr != INADDR_ANY) { + LOG_IF_NOT_NT_SUCCESS(status = OvpnAddPeerVpn4(device, peerCtx)); + } + + if (RtlCompareMemory(&peer->VpnAddr6, &ovpn_in6addr_any, sizeof(IN6_ADDR)) != sizeof(IN6_ADDR)) { + LOG_IF_NOT_NT_SUCCESS(status = OvpnAddPeerVpn6(device, peerCtx)); + } + } + else { + OvpnPeerCtxFree(peerCtx); + } + ExReleaseSpinLockExclusive(&device->SpinLock, kirql); + +done: + LOG_EXIT(); + + return status; +} + _Use_decl_annotations_ NTSTATUS OvpnPeerSet(POVPN_DEVICE device, WDFREQUEST request) { diff --git a/peer.h b/peer.h index 11a80f8..907d44b 100644 --- a/peer.h +++ b/peer.h @@ -41,6 +41,26 @@ struct OvpnPeerContext // 1-sec timer which handles ping intervals and keepalive timeouts WDFTIMER Timer; + + struct { + IN_ADDR IPv4; + IN6_ADDR IPv6; + } VpnAddrs; + + struct { + union { + IN_ADDR IPv4; + IN6_ADDR IPv6; + } Local; + + union { + SOCKADDR_IN IPv4; + SOCKADDR_IN6 IPv6; + } Remote; + + } TransportAddrs; + + LONG RefCounter; }; _Must_inspect_result_ @@ -53,12 +73,19 @@ OvpnPeerCtxFree(_In_ OvpnPeerContext*); RTL_GENERIC_ALLOCATE_ROUTINE OvpnPeerAllocateRoutine; RTL_GENERIC_FREE_ROUTINE OvpnPeerFreeRoutine; RTL_GENERIC_COMPARE_ROUTINE OvpnPeerCompareByPeerIdRoutine; +RTL_GENERIC_COMPARE_ROUTINE OvpnPeerCompareByVPN4Routine; +RTL_GENERIC_COMPARE_ROUTINE OvpnPeerCompareByVPN6Routine; _Must_inspect_result_ _IRQL_requires_(PASSIVE_LEVEL) NTSTATUS OvpnPeerNew(_In_ POVPN_DEVICE device, WDFREQUEST request); +_Must_inspect_result_ +_IRQL_requires_(PASSIVE_LEVEL) +NTSTATUS +OvpnMPPeerNew(_In_ POVPN_DEVICE device, WDFREQUEST request); + _Must_inspect_result_ _Requires_exclusive_lock_held_(device->SpinLock) NTSTATUS diff --git a/socket.cpp b/socket.cpp index 533c0c2..2b8653c 100644 --- a/socket.cpp +++ b/socket.cpp @@ -189,13 +189,13 @@ OvpnSocketControlPacketReceived(_In_ POVPN_DEVICE device, _In_reads_(len) PUCHAR static _Requires_shared_lock_held_(device->SpinLock) -VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_(len) PUCHAR cipherTextBuf, SIZE_T len) +VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, UINT32 peerId, _In_reads_(len) PUCHAR cipherTextBuf, SIZE_T len) { InterlockedExchangeAddNoFence64(&device->Stats.TransportBytesReceived, len); - OvpnPeerContext* peer = OvpnGetFirstPeer(&device->Peers); + OvpnPeerContext* peer = OvpnFindPeer(device, peerId); if (peer == NULL) { - LOG_WARN("No peer"); + LOG_WARN("Peer not found", TraceLoggingValue(peerId, "peerId")); InterlockedIncrementNoFence(&device->Stats.LostInDataPackets); return; } @@ -288,7 +288,8 @@ OvpnSocketProcessIncomingPacket(_In_ POVPN_DEVICE device, _In_reads_(packetLengt UCHAR op = RtlUlongByteSwap(*(ULONG*)(buf)) >> 24; if (OvpnCryptoOpcodeExtract(op) == OVPN_OP_DATA_V2) { - OvpnSocketDataPacketReceived(device, op, buf, packetLength); + UINT32 peerId = RtlUlongByteSwap(*(ULONG*)(buf)) & OVPN_PEER_ID_MASK; + OvpnSocketDataPacketReceived(device, op, peerId, buf, packetLength); } else { OvpnSocketControlPacketReceived(device, buf, packetLength, remoteAddr); diff --git a/txqueue.cpp b/txqueue.cpp index 4f7ba30..94a86f3 100644 --- a/txqueue.cpp +++ b/txqueue.cpp @@ -77,13 +77,15 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET NetFragmentIteratorAdvance(&fi); } + OvpnPeerContext* peer = NULL; if (OvpnMssIsIPv4(buffer->Data, buffer->Len)) { OvpnMssDoIPv4(buffer->Data, buffer->Len, device->MSS); + peer = OvpnFindPeerVPN4(device, ((IPV4_HEADER*)buffer->Data)->DestinationAddress); } else if (OvpnMssIsIPv6(buffer->Data, buffer->Len)) { OvpnMssDoIPv6(buffer->Data, buffer->Len, device->MSS); + peer = OvpnFindPeerVPN6(device, ((IPV6_HEADER*)buffer->Data)->DestinationAddress); } - OvpnPeerContext* peer = OvpnGetFirstPeer(&device->Peers); if (peer == NULL) { status = STATUS_ADDRESS_NOT_ASSOCIATED; OvpnTxBufferPoolPut(buffer); diff --git a/uapi/ovpn-dco.h b/uapi/ovpn-dco.h index 095e008..60e4a53 100644 --- a/uapi/ovpn-dco.h +++ b/uapi/ovpn-dco.h @@ -47,6 +47,23 @@ typedef struct _OVPN_NEW_PEER { OVPN_PROTO Proto; } OVPN_NEW_PEER, * POVPN_NEW_PEER; +typedef struct _OVPN_MP_NEW_PEER { + union { + SOCKADDR_IN Addr4; + SOCKADDR_IN6 Addr6; + } Local; + + union { + SOCKADDR_IN Addr4; + SOCKADDR_IN6 Addr6; + } Remote; + + IN_ADDR VpnAddr4; + IN6_ADDR VpnAddr6; + + int PeerId; +} OVPN_MP_NEW_PEER, * POVPN_MP_NEW_PEER; + typedef struct _OVPN_STATS { LONG LostInControlPackets; LONG LostOutControlPackets; @@ -142,3 +159,4 @@ typedef struct _OVPN_MP_START_VPN { #define OVPN_IOCTL_SET_MODE CTL_CODE(FILE_DEVICE_UNKNOWN, 10, METHOD_BUFFERED, FILE_ANY_ACCESS) #define OVPN_IOCTL_MP_START_VPN CTL_CODE(FILE_DEVICE_UNKNOWN, 11, METHOD_BUFFERED, FILE_ANY_ACCESS) +#define OVPN_IOCTL_MP_NEW_PEER CTL_CODE(FILE_DEVICE_UNKNOWN, 12, METHOD_BUFFERED, FILE_ANY_ACCESS) From cd610da95cb0c78840245d6ac7294bc86abbded2 Mon Sep 17 00:00:00 2001 From: Lev Stipakov Date: Wed, 9 Oct 2024 14:38:54 +0300 Subject: [PATCH 15/28] gui: implement OVPN_IOCTL_NEW_KEY Key material is read from data64.key. Signed-off-by: Lev Stipakov --- gui/CMakeLists.txt | 2 +- gui/gui.cpp | 67 ++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 66 insertions(+), 3 deletions(-) diff --git a/gui/CMakeLists.txt b/gui/CMakeLists.txt index ffb0b6f..074aa7e 100644 --- a/gui/CMakeLists.txt +++ b/gui/CMakeLists.txt @@ -8,4 +8,4 @@ set(CMAKE_CXX_STANDARD_REQUIRED YES) # Ensure the compiler strictly requires C add_executable (gui WIN32 "gui.cpp") add_definitions(-DUNICODE -D_UNICODE) -target_link_libraries(gui PRIVATE user32 gdi32 Ws2_32) +target_link_libraries(gui PRIVATE user32 gdi32 Ws2_32 Crypt32) diff --git a/gui/gui.cpp b/gui/gui.cpp index 3eeeb75..d1d17cd 100644 --- a/gui/gui.cpp +++ b/gui/gui.cpp @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -21,7 +22,8 @@ HWND hMPListenAddress, hMPListenPort, hP2PLocalAddress, hP2PLocalPort, hP2PRemoteAddress, hP2PRemotePort, hCCMessage, hCCRemoteAddress, hCCRemotePort, - hMPNewPeerLocalIP, hMPNewPeerLocalPort, hMPNewPeerRemoteIP, hMPNewPeerRemotePort, hMPNewPeerVPNIP, hMPNewPeerPeerId; + hMPNewPeerLocalIP, hMPNewPeerLocalPort, hMPNewPeerRemoteIP, hMPNewPeerRemotePort, hMPNewPeerVPNIP, hMPNewPeerPeerId, + hP2PNewKeyPeerId; HWND hLogArea; std::unordered_map buttons = { @@ -36,7 +38,8 @@ std::unordered_map buttons = { {OVPN_IOCTL_NEW_KEY_V2, L"New Key V2"}, {OVPN_IOCTL_SET_MODE, L"Set Mode"}, {OVPN_IOCTL_MP_START_VPN, L"MP Start VPN"}, - {OVPN_IOCTL_MP_NEW_PEER, L"MP New Peer"} + {OVPN_IOCTL_MP_NEW_PEER, L"MP New Peer"}, + {OVPN_IOCTL_NEW_KEY, L"P2P New Key"}, }; #define MIN_FUNCTION_CODE 1 @@ -367,6 +370,58 @@ void MPNewPeer() } } +void +P2PNewKey() +{ + wchar_t peerId[6]; + GetWindowText(hP2PNewKeyPeerId, peerId, 6); + + std::ifstream file("data64.key"); + if (!file) return; + + std::string b64str{(std::istreambuf_iterator(file)), std::istreambuf_iterator()}; + + DWORD binarySize = 0; + if (!CryptStringToBinaryA(b64str.c_str(), 0, CRYPT_STRING_BASE64, nullptr, &binarySize, nullptr, nullptr)) + return; + + std::vector buf(binarySize); + if (!CryptStringToBinaryA(b64str.c_str(), 0, CRYPT_STRING_BASE64, buf.data(), &binarySize, nullptr, nullptr)) + return; + + OVPN_CRYPTO_DATA crypto_data = {}; + constexpr int keyLen = sizeof(crypto_data.Encrypt.Key); + + bool keyDirectory = 0; + if (keyDirectory) { + CopyMemory(crypto_data.Encrypt.Key, buf.data() + keyLen, keyLen); + CopyMemory(crypto_data.Decrypt.Key, buf.data(), keyLen); + } + else { + CopyMemory(crypto_data.Encrypt.Key, buf.data(), keyLen); + CopyMemory(crypto_data.Decrypt.Key, buf.data() + keyLen, keyLen); + } + + crypto_data.Encrypt.KeyLen = keyLen; // hardcode 256bit key size + crypto_data.Decrypt.KeyLen = keyLen; // hardcode 256bit key size + + constexpr int nonceTailLen = sizeof(crypto_data.Encrypt.NonceTail); + // for test purposes decrypt and encrypt nonces are same + CopyMemory(crypto_data.Encrypt.NonceTail, buf.data() + keyLen * 2, nonceTailLen); + CopyMemory(crypto_data.Decrypt.NonceTail, buf.data() + keyLen * 2, nonceTailLen); + + crypto_data.CipherAlg = OVPN_CIPHER_ALG::OVPN_CIPHER_ALG_AES_GCM; + crypto_data.PeerId = _wtoi(peerId); + + DWORD bytesReturned; + if (!DeviceIoControl(hDev, OVPN_IOCTL_NEW_KEY, &crypto_data, sizeof(crypto_data), NULL, 0, &bytesReturned, NULL)) { + Log("DeviceIoControl(OVPN_IOCTL_NEW_KEY) failed with code ", GetLastError()); + } + else { + Log("New key added"); + } +} + void CreatePushButton(HWND hWnd, DWORD ioctl, int x, int y) { @@ -463,6 +518,9 @@ LRESULT CALLBACK WindowProcedure(HWND hwnd, UINT msg, WPARAM wp, LPARAM lp) hMPNewPeerVPNIP = CreateEditBox(hwnd, L"10.8.0.6", 650, 210, 120); hMPNewPeerPeerId = CreateEditBox(hwnd, L"1", 790, 210, 60); + CreatePushButton(hwnd, OVPN_IOCTL_NEW_KEY, 10, 260); + hP2PNewKeyPeerId = CreateEditBox(hwnd, L"1", 150, 260, 60); + SendMessage(hModes[0], BM_SETCHECK, BST_CHECKED, 0); // log area @@ -504,6 +562,11 @@ LRESULT CALLBACK WindowProcedure(HWND hwnd, UINT msg, WPARAM wp, LPARAM lp) case OVPN_IOCTL_MP_NEW_PEER: MPNewPeer(); + break; + + case OVPN_IOCTL_NEW_KEY: + P2PNewKey(); + break; } } else if ((ULONG)wp == BTN_SEND_CC) { From 515d133156e27f0d2302612d45d043c3f2f7594b Mon Sep 17 00:00:00 2001 From: Lev Stipakov Date: Wed, 9 Oct 2024 14:50:07 +0300 Subject: [PATCH 16/28] peer.cpp: make PeerNewKey use the actual peer Instead of getting the first peer, find peer by peer-id passed in crypto data. Signed-off-by: Lev Stipakov --- peer.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/peer.cpp b/peer.cpp index 57bccff..e327617 100644 --- a/peer.cpp +++ b/peer.cpp @@ -438,7 +438,7 @@ OvpnPeerNewKey(POVPN_DEVICE device, WDFREQUEST request) BCRYPT_ALG_HANDLE algHandle = NULL; GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnPeerGetAlgHandle(device, cryptoData->CipherAlg, algHandle)); - OvpnPeerContext* peer = OvpnGetFirstPeer(&device->Peers); + OvpnPeerContext* peer = OvpnFindPeer(device, cryptoData->PeerId); if (peer == NULL) { status = STATUS_OBJECTID_NOT_FOUND; goto done; @@ -474,7 +474,7 @@ OvpnPeerNewKeyV2(POVPN_DEVICE device, WDFREQUEST request) BCRYPT_ALG_HANDLE algHandle = NULL; GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnPeerGetAlgHandle(device, cryptoDataV2->V1.CipherAlg, algHandle)); - OvpnPeerContext* peer = OvpnGetFirstPeer(&device->Peers); + OvpnPeerContext* peer = OvpnFindPeer(device, cryptoDataV2->V1.PeerId); if (peer == NULL) { status = STATUS_OBJECTID_NOT_FOUND; goto done; From b146d945e8bb81c813da9e83f061fe673973052d Mon Sep 17 00:00:00 2001 From: Lev Stipakov Date: Thu, 10 Oct 2024 14:24:17 +0300 Subject: [PATCH 17/28] driver.cpp: enable NEW_KEY(_V2) ioctls in MP mode Those are receiving peer-id so they could work in MP mode too. Signed-off-by: Lev Stipakov --- Driver.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/Driver.cpp b/Driver.cpp index 91e1c0c..51f2df6 100644 --- a/Driver.cpp +++ b/Driver.cpp @@ -305,8 +305,6 @@ OvpnDeviceCheckMode(OVPN_MODE mode, ULONG code) // all those IOCTLs are only for P2P mode case OVPN_IOCTL_NEW_PEER: case OVPN_IOCTL_DEL_PEER: - case OVPN_IOCTL_NEW_KEY: - case OVPN_IOCTL_NEW_KEY_V2: case OVPN_IOCTL_SWAP_KEYS: case OVPN_IOCTL_SET_PEER: case OVPN_IOCTL_START_VPN: From 669fd0dade771177632c47ae3f1cb852525352fb Mon Sep 17 00:00:00 2001 From: Lev Stipakov Date: Thu, 10 Oct 2024 15:48:12 +0300 Subject: [PATCH 18/28] txqueue: pass SOCKADDR* to OvpnSocketSend In MP mode OvpnSocketSend requires destination address to where data should be sent, so we pass peer's remote address. To make code work for both MP and P2P cases, store peer remote address for P2P mode. Signed-off-by: Lev Stipakov --- peer.cpp | 15 +++++++++++++++ txqueue.cpp | 3 ++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/peer.cpp b/peer.cpp index e327617..e02e7ee 100644 --- a/peer.cpp +++ b/peer.cpp @@ -153,6 +153,13 @@ OvpnPeerNew(POVPN_DEVICE device, WDFREQUEST request) goto done; } + if ((peer->Remote.Addr4.sin_family != AF_INET) && (peer->Remote.Addr4.sin_family != AF_INET6)) + { + status = STATUS_INVALID_DEVICE_REQUEST; + LOG_ERROR("Unknown address family in peer->Remote", TraceLoggingValue(peer->Remote.Addr4.sin_family, "AF")); + goto done; + } + POVPN_DRIVER driver = OvpnGetDriverContext(WdfGetDriver()); PWSK_SOCKET socket = NULL; BOOLEAN proto_tcp = peer->Proto == OVPN_PROTO_TCP; @@ -164,6 +171,14 @@ OvpnPeerNew(POVPN_DEVICE device, WDFREQUEST request) goto done; } + // assign remote transport address + if (peer->Remote.Addr4.sin_family == AF_INET) { + peerCtx->TransportAddrs.Remote.IPv4 = peer->Remote.Addr4; + } + else { + peerCtx->TransportAddrs.Remote.IPv6 = peer->Remote.Addr6; + } + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnSocketInit(&driver->WskProviderNpi, &driver->WskRegistration, peer->Local.Addr4.sin_family, proto_tcp, (PSOCKADDR)&peer->Local, diff --git a/txqueue.cpp b/txqueue.cpp index 94a86f3..5a4e5f9 100644 --- a/txqueue.cpp +++ b/txqueue.cpp @@ -191,13 +191,14 @@ OvpnEvtTxQueueAdvance(NETPACKETQUEUE netPacketQueue) NetPacketIteratorSet(&pi); if (packetSent) { + // TODO: get actual peer OvpnPeerContext* peer = OvpnGetFirstPeer(&device->Peers); if (peer != NULL) { OvpnTimerResetXmit(peer->Timer); if (!device->Socket.Tcp) { // this will use WskSendMessages to send buffers list which we constructed before - LOG_IF_NOT_NT_SUCCESS(OvpnSocketSend(&device->Socket, txBufferHead, NULL)); + LOG_IF_NOT_NT_SUCCESS(OvpnSocketSend(&device->Socket, txBufferHead, (SOCKADDR*)&peer->TransportAddrs.Remote)); } } } From c55ce385dac48758d88d68150e1cd9f81a358c21 Mon Sep 17 00:00:00 2001 From: Lev Stipakov Date: Thu, 10 Oct 2024 15:49:10 +0300 Subject: [PATCH 19/28] gui: make "New Key" work for both P2P and MP Signed-off-by: Lev Stipakov --- gui/gui.cpp | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/gui/gui.cpp b/gui/gui.cpp index d1d17cd..2fd76c8 100644 --- a/gui/gui.cpp +++ b/gui/gui.cpp @@ -23,7 +23,7 @@ HWND hMPListenAddress, hMPListenPort, hP2PRemoteAddress, hP2PRemotePort, hCCMessage, hCCRemoteAddress, hCCRemotePort, hMPNewPeerLocalIP, hMPNewPeerLocalPort, hMPNewPeerRemoteIP, hMPNewPeerRemotePort, hMPNewPeerVPNIP, hMPNewPeerPeerId, - hP2PNewKeyPeerId; + hNewKeyPeerId; HWND hLogArea; std::unordered_map buttons = { @@ -39,7 +39,7 @@ std::unordered_map buttons = { {OVPN_IOCTL_SET_MODE, L"Set Mode"}, {OVPN_IOCTL_MP_START_VPN, L"MP Start VPN"}, {OVPN_IOCTL_MP_NEW_PEER, L"MP New Peer"}, - {OVPN_IOCTL_NEW_KEY, L"P2P New Key"}, + {OVPN_IOCTL_NEW_KEY, L"New Key"}, }; #define MIN_FUNCTION_CODE 1 @@ -371,10 +371,10 @@ void MPNewPeer() } void -P2PNewKey() +NewKey() { wchar_t peerId[6]; - GetWindowText(hP2PNewKeyPeerId, peerId, 6); + GetWindowText(hNewKeyPeerId, peerId, 6); std::ifstream file("data64.key"); if (!file) return; @@ -392,8 +392,9 @@ P2PNewKey() OVPN_CRYPTO_DATA crypto_data = {}; constexpr int keyLen = sizeof(crypto_data.Encrypt.Key); - bool keyDirectory = 0; - if (keyDirectory) { + bool mp = SendMessage(hModes[1], BM_GETCHECK, 0, 0) == BST_CHECKED; + bool keyDir = mp ? 1 : 0; + if (keyDir) { CopyMemory(crypto_data.Encrypt.Key, buf.data() + keyLen, keyLen); CopyMemory(crypto_data.Decrypt.Key, buf.data(), keyLen); } @@ -519,7 +520,7 @@ LRESULT CALLBACK WindowProcedure(HWND hwnd, UINT msg, WPARAM wp, LPARAM lp) hMPNewPeerPeerId = CreateEditBox(hwnd, L"1", 790, 210, 60); CreatePushButton(hwnd, OVPN_IOCTL_NEW_KEY, 10, 260); - hP2PNewKeyPeerId = CreateEditBox(hwnd, L"1", 150, 260, 60); + hNewKeyPeerId = CreateEditBox(hwnd, L"1", 150, 260, 60); SendMessage(hModes[0], BM_SETCHECK, BST_CHECKED, 0); @@ -565,7 +566,7 @@ LRESULT CALLBACK WindowProcedure(HWND hwnd, UINT msg, WPARAM wp, LPARAM lp) break; case OVPN_IOCTL_NEW_KEY: - P2PNewKey(); + NewKey(); break; } } From 6dfc04a5e104c8d0594c2efb64ef5bd32cf3015d Mon Sep 17 00:00:00 2001 From: Lev Stipakov Date: Fri, 11 Oct 2024 10:22:43 +0300 Subject: [PATCH 20/28] txqueue.cpp: send packets to actual peer Signed-off-by: Lev Stipakov --- txqueue.cpp | 38 +++++++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/txqueue.cpp b/txqueue.cpp index 5a4e5f9..f0cf628 100644 --- a/txqueue.cpp +++ b/txqueue.cpp @@ -41,7 +41,7 @@ _Requires_shared_lock_held_(device->SpinLock) static NTSTATUS OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET_RING_PACKET_ITERATOR *pi, - _Inout_ OVPN_TX_BUFFER **head, _Inout_ OVPN_TX_BUFFER** tail) + _Inout_ OVPN_TX_BUFFER **head, _Inout_ OVPN_TX_BUFFER** tail, _Inout_ SOCKADDR **headSockaddr) { NET_RING_FRAGMENT_ITERATOR fi = NetPacketIteratorGetFragments(pi); @@ -129,14 +129,27 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET buffer->WskBufList.Buffer.Mdl = buffer->Mdl; buffer->WskBufList.Buffer.Offset = FIELD_OFFSET(OVPN_TX_BUFFER, Head) + (ULONG)(buffer->Data - buffer->Head); - if (*head == NULL) { + // If this peer is different (head sockaddr != peer sockaddr) to the previous buffer chain peers, + // then flush those and restart with a new buffer list. + + if ((*head != NULL) && *headSockaddr != (SOCKADDR*)&(peer->TransportAddrs.Remote)) + { + LOG_IF_NOT_NT_SUCCESS(OvpnSocketSend(&device->Socket, *head, *headSockaddr)); *head = buffer; + *tail = buffer; + *headSockaddr = (SOCKADDR*)&(peer->TransportAddrs.Remote); } else { - (*tail)->WskBufList.Next = &buffer->WskBufList; + if (*head == NULL) { + *head = buffer; + *headSockaddr = (SOCKADDR*)&(peer->TransportAddrs.Remote); + } + else { + (*tail)->WskBufList.Next = &buffer->WskBufList; + } + + *tail = buffer; } - - *tail = buffer; } OvpnTimerResetXmit(peer->Timer); @@ -169,12 +182,13 @@ OvpnEvtTxQueueAdvance(NETPACKETQUEUE netPacketQueue) OVPN_TX_BUFFER* txBufferHead = NULL; OVPN_TX_BUFFER* txBufferTail = NULL; + SOCKADDR* headSockaddr = NULL; while (NetPacketIteratorHasAny(&pi)) { NET_PACKET* packet = NetPacketIteratorGetPacket(&pi); NTSTATUS status = STATUS_SUCCESS; if (!packet->Ignore && !packet->Scratch) { - status = OvpnTxProcessPacket(device, queue, &pi, &txBufferHead, &txBufferTail); + status = OvpnTxProcessPacket(device, queue, &pi, &txBufferHead, &txBufferTail, &headSockaddr); if (!NT_SUCCESS(status)) { InterlockedIncrementNoFence(&device->Stats.LostOutDataPackets); } @@ -191,15 +205,9 @@ OvpnEvtTxQueueAdvance(NETPACKETQUEUE netPacketQueue) NetPacketIteratorSet(&pi); if (packetSent) { - // TODO: get actual peer - OvpnPeerContext* peer = OvpnGetFirstPeer(&device->Peers); - if (peer != NULL) { - OvpnTimerResetXmit(peer->Timer); - - if (!device->Socket.Tcp) { - // this will use WskSendMessages to send buffers list which we constructed before - LOG_IF_NOT_NT_SUCCESS(OvpnSocketSend(&device->Socket, txBufferHead, (SOCKADDR*)&peer->TransportAddrs.Remote)); - } + if (!device->Socket.Tcp) { + // this will use WskSendMessages to send buffers list which we constructed before + LOG_IF_NOT_NT_SUCCESS(OvpnSocketSend(&device->Socket, txBufferHead, headSockaddr)); } } From 70f7155db065285d67ed8b6b7a16d561ec3a28ef Mon Sep 17 00:00:00 2001 From: Lev Stipakov Date: Tue, 15 Oct 2024 15:25:12 +0300 Subject: [PATCH 21/28] multipeer: switch to AVL trees By default Windows uses splay trees, which require exclusive lock on lookup. Switch to AVL trees, since they support concurrent read using shared locks. Signed-off-by: Lev Stipakov --- ovpn-dco-win.vcxproj | 32 ++++++++++++++++---------------- peer.cpp | 10 +++++----- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/ovpn-dco-win.vcxproj b/ovpn-dco-win.vcxproj index d3000a6..ab2e3e3 100644 --- a/ovpn-dco-win.vcxproj +++ b/ovpn-dco-win.vcxproj @@ -439,7 +439,7 @@ true %(AdditionalIncludeDirectories) - OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) + RTL_USE_AVL_TABLES=0;OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) uuid.lib;Netio.lib;cng.lib;wdmsec.lib%(AdditionalDependencies) @@ -455,7 +455,7 @@ true %(AdditionalIncludeDirectories) - OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) + RTL_USE_AVL_TABLES=0;OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) uuid.lib;Netio.lib;cng.lib;wdmsec.lib%(AdditionalDependencies) @@ -471,7 +471,7 @@ true %(AdditionalIncludeDirectories) - OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) + RTL_USE_AVL_TABLES=0;OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) false @@ -498,7 +498,7 @@ true %(AdditionalIncludeDirectories) - OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) + RTL_USE_AVL_TABLES=0;OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) false @@ -525,7 +525,7 @@ true %(AdditionalIncludeDirectories) - OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) + RTL_USE_AVL_TABLES=0;OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) uuid.lib;Netio.lib;cng.lib;wdmsec.lib%(AdditionalDependencies) @@ -549,7 +549,7 @@ true %(AdditionalIncludeDirectories) - OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) + RTL_USE_AVL_TABLES=0;OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) uuid.lib;Netio.lib;cng.lib;wdmsec.lib%(AdditionalDependencies) @@ -573,7 +573,7 @@ true %(AdditionalIncludeDirectories) - OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) + RTL_USE_AVL_TABLES=0;OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) false @@ -600,7 +600,7 @@ true %(AdditionalIncludeDirectories) - OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) + RTL_USE_AVL_TABLES=0;OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) false @@ -627,7 +627,7 @@ true %(AdditionalIncludeDirectories) - OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) + RTL_USE_AVL_TABLES=0;OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) uuid.lib;Netio.lib;cng.lib;wdmsec.lib%(AdditionalDependencies) @@ -643,7 +643,7 @@ true %(AdditionalIncludeDirectories) - OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) + RTL_USE_AVL_TABLES=0;OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) uuid.lib;Netio.lib;cng.lib;wdmsec.lib%(AdditionalDependencies) @@ -659,7 +659,7 @@ true %(AdditionalIncludeDirectories) - OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) + RTL_USE_AVL_TABLES=0;OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) false @@ -681,7 +681,7 @@ true %(AdditionalIncludeDirectories) - OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) + RTL_USE_AVL_TABLES=0;OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) false @@ -703,7 +703,7 @@ true %(AdditionalIncludeDirectories) - OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) + RTL_USE_AVL_TABLES=0;OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) false @@ -730,7 +730,7 @@ true %(AdditionalIncludeDirectories) - OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) + RTL_USE_AVL_TABLES=0;OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) false @@ -757,7 +757,7 @@ true %(AdditionalIncludeDirectories) - OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) + RTL_USE_AVL_TABLES=0;OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) uuid.lib;Netio.lib;cng.lib;wdmsec.lib%(AdditionalDependencies) @@ -773,7 +773,7 @@ true %(AdditionalIncludeDirectories) - OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) + RTL_USE_AVL_TABLES=0;OVPN_DCO_VERSION_MAJOR=$(OVPN_DCO_VERSION_MAJOR);OVPN_DCO_VERSION_MINOR=$(OVPN_DCO_VERSION_MINOR);OVPN_DCO_VERSION_PATCH=$(OVPN_DCO_VERSION_PATCH);OVPN_DCO_VERSION_STR=$(OVPN_DCO_VERSION_MAJOR).$(OVPN_DCO_VERSION_MINOR).$(OVPN_DCO_VERSION_PATCH);NETADAPTER_VERSION_MAJOR=$(NETADAPTER_VERSION_MAJOR);NETADAPTER_VERSION_MINOR=$(NETADAPTER_VERSION_MINOR);%(PreprocessorDefinitions) uuid.lib;Netio.lib;cng.lib;wdmsec.lib%(AdditionalDependencies) diff --git a/peer.cpp b/peer.cpp index e02e7ee..6a140b4 100644 --- a/peer.cpp +++ b/peer.cpp @@ -50,7 +50,7 @@ OvpnPeerCtxFree(OvpnPeerContext* peer) _Use_decl_annotations_ PVOID -OvpnPeerAllocateRoutine(_RTL_GENERIC_TABLE* table, CLONG size) +OvpnPeerAllocateRoutine(RTL_GENERIC_TABLE* table, CLONG size) { UNREFERENCED_PARAMETER(table); @@ -59,7 +59,7 @@ OvpnPeerAllocateRoutine(_RTL_GENERIC_TABLE* table, CLONG size) _Use_decl_annotations_ VOID -OvpnPeerFreeRoutine(_RTL_GENERIC_TABLE* table, PVOID buffer) +OvpnPeerFreeRoutine(RTL_GENERIC_TABLE* table, PVOID buffer) { UNREFERENCED_PARAMETER(table); @@ -67,7 +67,7 @@ OvpnPeerFreeRoutine(_RTL_GENERIC_TABLE* table, PVOID buffer) } RTL_GENERIC_COMPARE_RESULTS -OvpnPeerCompareByPeerIdRoutine(_RTL_GENERIC_TABLE* table, PVOID first, PVOID second) +OvpnPeerCompareByPeerIdRoutine(RTL_GENERIC_TABLE* table, PVOID first, PVOID second) { UNREFERENCED_PARAMETER(table); @@ -83,7 +83,7 @@ OvpnPeerCompareByPeerIdRoutine(_RTL_GENERIC_TABLE* table, PVOID first, PVOID sec } RTL_GENERIC_COMPARE_RESULTS -OvpnPeerCompareByVPN4Routine(_RTL_GENERIC_TABLE* table, PVOID first, PVOID second) +OvpnPeerCompareByVPN4Routine(RTL_GENERIC_TABLE* table, PVOID first, PVOID second) { UNREFERENCED_PARAMETER(table); @@ -100,7 +100,7 @@ OvpnPeerCompareByVPN4Routine(_RTL_GENERIC_TABLE* table, PVOID first, PVOID secon } RTL_GENERIC_COMPARE_RESULTS -OvpnPeerCompareByVPN6Routine(_RTL_GENERIC_TABLE* table, PVOID first, PVOID second) +OvpnPeerCompareByVPN6Routine(RTL_GENERIC_TABLE* table, PVOID first, PVOID second) { UNREFERENCED_PARAMETER(table); From b737eb9b6ee8fea308956818f95962bc825453d0 Mon Sep 17 00:00:00 2001 From: Lev Stipakov Date: Fri, 18 Oct 2024 11:10:18 +0300 Subject: [PATCH 22/28] multipeer: support for SET_PEER Add support for IOCTL_MP_SET_PEER. This is used to set keepalive internal, keepalive timeout and MSS for a specific peer. Note that proper userspace notification for expired peer is stil missing. GitHub: https://github.com/OpenVPN/ovpn-dco-win/issues/85 Signed-off-by: Lev Stipakov --- Driver.cpp | 7 ++++++ Driver.h | 4 --- peer.cpp | 65 ++++++++++++++++++++++++++++++++++++++++--------- peer.h | 7 ++++++ socket.cpp | 6 ++--- timer.cpp | 5 ++-- txqueue.cpp | 16 ++++++++---- uapi/ovpn-dco.h | 14 ++++++++--- 8 files changed, 95 insertions(+), 29 deletions(-) diff --git a/Driver.cpp b/Driver.cpp index 51f2df6..a9e9b5f 100644 --- a/Driver.cpp +++ b/Driver.cpp @@ -316,6 +316,7 @@ OvpnDeviceCheckMode(OVPN_MODE mode, ULONG code) // those IOCTLs are for MP mode case OVPN_IOCTL_MP_START_VPN: case OVPN_IOCTL_MP_NEW_PEER: + case OVPN_IOCTL_MP_SET_PEER: return FALSE; } } @@ -503,6 +504,12 @@ OvpnEvtIoDeviceControl(WDFQUEUE queue, WDFREQUEST request, size_t outputBufferLe status = OvpnMPPeerNew(device, request); break; + case OVPN_IOCTL_MP_SET_PEER: + kirql = ExAcquireSpinLockExclusive(&device->SpinLock); + status = OvpnMPPeerSet(device, request); + ExReleaseSpinLockExclusive(&device->SpinLock, kirql); + break; + default: LOG_WARN("Unknown ", TraceLoggingValue(ioControlCode, "ioControlCode")); status = STATUS_INVALID_DEVICE_REQUEST; diff --git a/Driver.h b/Driver.h index 745854c..0f60608 100644 --- a/Driver.h +++ b/Driver.h @@ -81,10 +81,6 @@ struct OVPN_DEVICE { BCRYPT_ALG_HANDLE AesAlgHandle; BCRYPT_ALG_HANDLE ChachaAlgHandle; - // set from the userspace, defines TCP Maximum Segment Size - _Guarded_by_(SpinLock) - UINT16 MSS; - _Guarded_by_(SpinLock) OvpnSocket Socket; diff --git a/peer.cpp b/peer.cpp index 6a140b4..6ba9a8c 100644 --- a/peer.cpp +++ b/peer.cpp @@ -282,6 +282,13 @@ OvpnMPPeerNew(POVPN_DEVICE device, WDFREQUEST request) peerCtx->PeerId = peer->PeerId; + // create peer-specific timer + LOG_IF_NOT_NT_SUCCESS(status = OvpnTimerCreate(device->WdfDevice, peerCtx, &peerCtx->Timer)); + if (status != STATUS_SUCCESS) { + OvpnPeerCtxFree(peerCtx); + goto done; + } + kirql = ExAcquireSpinLockExclusive(&device->SpinLock); LOG_IF_NOT_NT_SUCCESS(status = OvpnAddPeer(device, peerCtx)); if (status == STATUS_SUCCESS) { @@ -304,6 +311,27 @@ OvpnMPPeerNew(POVPN_DEVICE device, WDFREQUEST request) return status; } +VOID OvpnPeerSetDoWork(OvpnPeerContext *peer, LONG keepaliveInterval, LONG keepaliveTimeout, LONG mss) +{ + if (mss != -1) { + peer->MSS = (UINT16)mss; + } + + if (keepaliveInterval != -1) { + peer->KeepaliveInterval = keepaliveInterval; + + // keepalive xmit timer, sends ping packets + OvpnTimerSetXmitInterval(peer->Timer, peer->KeepaliveInterval); + } + + if (keepaliveTimeout != -1) { + peer->KeepaliveTimeout = keepaliveTimeout; + + // keepalive recv timer, detects keepalive timeout + OvpnTimerSetRecvTimeout(peer->Timer, peer->KeepaliveTimeout); + } +} + _Use_decl_annotations_ NTSTATUS OvpnPeerSet(POVPN_DEVICE device, WDFREQUEST request) { @@ -326,24 +354,37 @@ NTSTATUS OvpnPeerSet(POVPN_DEVICE device, WDFREQUEST request) TraceLoggingValue(set_peer->KeepaliveTimeout, "timeout"), TraceLoggingValue(set_peer->MSS, "MSS")); - if (set_peer->MSS != -1) { - device->MSS = (UINT16)set_peer->MSS; - } + OvpnPeerSetDoWork(peer, set_peer->KeepaliveInterval, set_peer->KeepaliveTimeout, set_peer->MSS); - if (set_peer->KeepaliveInterval != -1) { - peer->KeepaliveInterval = set_peer->KeepaliveInterval; +done: + LOG_EXIT(); + return status; +} - // keepalive xmit timer, sends ping packets - OvpnTimerSetXmitInterval(peer->Timer, peer->KeepaliveInterval); - } +_Use_decl_annotations_ +NTSTATUS OvpnMPPeerSet(POVPN_DEVICE device, WDFREQUEST request) +{ + LOG_ENTER(); + + NTSTATUS status = STATUS_SUCCESS; - if (peer->KeepaliveTimeout != -1) { - peer->KeepaliveTimeout = set_peer->KeepaliveTimeout; + POVPN_MP_SET_PEER set_peer = NULL; + GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveInputBuffer(request, sizeof(OVPN_MP_SET_PEER), (PVOID*)&set_peer, nullptr)); - // keepalive recv timer, detects keepalive timeout - OvpnTimerSetRecvTimeout(peer->Timer, peer->KeepaliveTimeout); + LOG_INFO("MP Set peer", TraceLoggingValue(set_peer->PeerId, "peer-id"), + TraceLoggingValue(set_peer->KeepaliveInterval, "interval"), + TraceLoggingValue(set_peer->KeepaliveTimeout, "timeout"), + TraceLoggingValue(set_peer->MSS, "MSS")); + + OvpnPeerContext* peer = OvpnFindPeer(device, set_peer->PeerId); + if (peer == NULL) { + LOG_ERROR("Peer not found", TraceLoggingValue(set_peer->PeerId, "peer-id")); + status = STATUS_INVALID_DEVICE_REQUEST; + goto done; } + OvpnPeerSetDoWork(peer, set_peer->KeepaliveInterval, set_peer->KeepaliveTimeout, set_peer->MSS); + done: LOG_EXIT(); return status; diff --git a/peer.h b/peer.h index 907d44b..00a1585 100644 --- a/peer.h +++ b/peer.h @@ -42,6 +42,8 @@ struct OvpnPeerContext // 1-sec timer which handles ping intervals and keepalive timeouts WDFTIMER Timer; + UINT16 MSS; + struct { IN_ADDR IPv4; IN6_ADDR IPv6; @@ -91,6 +93,11 @@ _Requires_exclusive_lock_held_(device->SpinLock) NTSTATUS OvpnPeerSet(_In_ POVPN_DEVICE device, WDFREQUEST request); +_Must_inspect_result_ +_Requires_exclusive_lock_held_(device->SpinLock) +NTSTATUS +OvpnMPPeerSet(_In_ POVPN_DEVICE device, WDFREQUEST request); + _Must_inspect_result_ NTSTATUS _Requires_shared_lock_held_(device->SpinLock) diff --git a/socket.cpp b/socket.cpp index 2b8653c..5281207 100644 --- a/socket.cpp +++ b/socket.cpp @@ -254,16 +254,16 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, UINT32 pee // ping packet? if (OvpnTimerIsKeepaliveMessage(buffer->Data, buffer->Len)) { - LOG_INFO("Ping received"); + LOG_INFO("Ping received", TraceLoggingValue(peer->PeerId, "peer-id")); // no need to inject ping packet into OS, return buffer to the pool OvpnRxBufferPoolPut(buffer); } else { if (OvpnMssIsIPv4(buffer->Data, buffer->Len)) { - OvpnMssDoIPv4(buffer->Data, buffer->Len, device->MSS); + OvpnMssDoIPv4(buffer->Data, buffer->Len, peer->MSS); } else if (OvpnMssIsIPv6(buffer->Data, buffer->Len)) { - OvpnMssDoIPv6(buffer->Data, buffer->Len, device->MSS); + OvpnMssDoIPv6(buffer->Data, buffer->Len, peer->MSS); } // enqueue plaintext buffer, it will be dequeued by NetAdapter RX datapath diff --git a/timer.cpp b/timer.cpp index dc409c9..d67e158 100644 --- a/timer.cpp +++ b/timer.cpp @@ -91,9 +91,10 @@ static VOID OvpnTimerXmit(WDFTIMER timer) if (NT_SUCCESS(status)) { // start async send, completion handler will return ciphertext buffer to the pool - LOG_IF_NOT_NT_SUCCESS(status = OvpnSocketSend(&device->Socket, buffer, NULL)); + SOCKADDR* sa = (SOCKADDR*)&(peer->TransportAddrs.Remote); + LOG_IF_NOT_NT_SUCCESS(status = OvpnSocketSend(&device->Socket, buffer, sa)); if (NT_SUCCESS(status)) { - LOG_INFO("Ping sent"); + LOG_INFO("Ping sent", TraceLoggingValue(peer->PeerId, "peer-id")); } } else { diff --git a/txqueue.cpp b/txqueue.cpp index f0cf628..dd1b063 100644 --- a/txqueue.cpp +++ b/txqueue.cpp @@ -78,18 +78,24 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET } OvpnPeerContext* peer = NULL; + if (OvpnMssIsIPv4(buffer->Data, buffer->Len)) { - OvpnMssDoIPv4(buffer->Data, buffer->Len, device->MSS); - peer = OvpnFindPeerVPN4(device, ((IPV4_HEADER*)buffer->Data)->DestinationAddress); + auto addr = ((IPV4_HEADER*)buffer->Data)->DestinationAddress; + peer = OvpnFindPeerVPN4(device, addr); + if (peer != NULL) { + OvpnMssDoIPv4(buffer->Data, buffer->Len, peer->MSS); + } } else if (OvpnMssIsIPv6(buffer->Data, buffer->Len)) { - OvpnMssDoIPv6(buffer->Data, buffer->Len, device->MSS); - peer = OvpnFindPeerVPN6(device, ((IPV6_HEADER*)buffer->Data)->DestinationAddress); + auto addr = ((IPV6_HEADER*)buffer->Data)->DestinationAddress; + peer = OvpnFindPeerVPN6(device, addr); + if (peer != NULL) { + OvpnMssDoIPv6(buffer->Data, buffer->Len, peer->MSS); + } } if (peer == NULL) { status = STATUS_ADDRESS_NOT_ASSOCIATED; OvpnTxBufferPoolPut(buffer); - LOG_WARN("No peer"); goto out; } diff --git a/uapi/ovpn-dco.h b/uapi/ovpn-dco.h index 60e4a53..be4770b 100644 --- a/uapi/ovpn-dco.h +++ b/uapi/ovpn-dco.h @@ -119,10 +119,17 @@ typedef struct _OVPN_CRYPTO_DATA_V2 { UINT32 CryptoOptions; } OVPN_CRYPTO_DATA_V2, * POVPN_CRYPTO_DATA_V2; +typedef struct _OVPN_MP_SET_PEER { + int PeerId; + LONG KeepaliveInterval; + LONG KeepaliveTimeout; + LONG MSS; +} OVPN_MP_SET_PEER, * POVPN_MP_SET_PEER; + typedef struct _OVPN_SET_PEER { - LONG KeepaliveInterval; - LONG KeepaliveTimeout; - LONG MSS; + LONG KeepaliveInterval; + LONG KeepaliveTimeout; + LONG MSS; } OVPN_SET_PEER, * POVPN_SET_PEER; typedef struct _OVPN_VERSION { @@ -160,3 +167,4 @@ typedef struct _OVPN_MP_START_VPN { #define OVPN_IOCTL_MP_START_VPN CTL_CODE(FILE_DEVICE_UNKNOWN, 11, METHOD_BUFFERED, FILE_ANY_ACCESS) #define OVPN_IOCTL_MP_NEW_PEER CTL_CODE(FILE_DEVICE_UNKNOWN, 12, METHOD_BUFFERED, FILE_ANY_ACCESS) +#define OVPN_IOCTL_MP_SET_PEER CTL_CODE(FILE_DEVICE_UNKNOWN, 13, METHOD_BUFFERED, FILE_ANY_ACCESS) From 5a38804f754aab3bcbc5422a1721a5416864fced Mon Sep 17 00:00:00 2001 From: Lev Stipakov Date: Fri, 18 Oct 2024 11:10:47 +0300 Subject: [PATCH 23/28] gui: support fot (MP_)SET_PEER Signed-off-by: Lev Stipakov --- gui/gui.cpp | 57 ++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 56 insertions(+), 1 deletion(-) diff --git a/gui/gui.cpp b/gui/gui.cpp index 2fd76c8..178ae26 100644 --- a/gui/gui.cpp +++ b/gui/gui.cpp @@ -23,7 +23,8 @@ HWND hMPListenAddress, hMPListenPort, hP2PRemoteAddress, hP2PRemotePort, hCCMessage, hCCRemoteAddress, hCCRemotePort, hMPNewPeerLocalIP, hMPNewPeerLocalPort, hMPNewPeerRemoteIP, hMPNewPeerRemotePort, hMPNewPeerVPNIP, hMPNewPeerPeerId, - hNewKeyPeerId; + hNewKeyPeerId, + hSetPeerPeerId, hSetPeerInterval, hSetPeerTimeout, hSetPeerMSS; HWND hLogArea; std::unordered_map buttons = { @@ -40,6 +41,7 @@ std::unordered_map buttons = { {OVPN_IOCTL_MP_START_VPN, L"MP Start VPN"}, {OVPN_IOCTL_MP_NEW_PEER, L"MP New Peer"}, {OVPN_IOCTL_NEW_KEY, L"New Key"}, + {OVPN_IOCTL_SET_PEER, L"Set Peer"} }; #define MIN_FUNCTION_CODE 1 @@ -423,6 +425,50 @@ NewKey() } } +void +SetPeer() +{ + wchar_t peerId[6], interval[6], timeout[6], mss[6]; + + GetWindowText(hSetPeerPeerId, peerId, 16); + GetWindowText(hSetPeerInterval, interval, 16); + GetWindowText(hSetPeerTimeout, timeout, 16); + GetWindowText(hSetPeerMSS, mss, 16); + + bool mp = SendMessage(hModes[1], BM_GETCHECK, 0, 0) == BST_CHECKED; + + if (mp) { + OVPN_MP_SET_PEER set_peer = {}; + set_peer.PeerId = _wtoi(peerId); + set_peer.KeepaliveInterval = _wtoi(interval); + set_peer.KeepaliveTimeout = _wtoi(timeout); + set_peer.MSS = _wtoi(mss); + + DWORD bytesReturned; + if (!DeviceIoControl(hDev, OVPN_IOCTL_MP_SET_PEER, &set_peer, sizeof(set_peer), NULL, 0, &bytesReturned, NULL)) { + Log("DeviceIoControl(OVPN_IOCTL_MP_SET_PEER) failed with code ", GetLastError()); + } + else { + Log("MP Peer set", peerId); + } + } else { + OVPN_SET_PEER set_peer = {}; + set_peer.KeepaliveInterval = _wtoi(interval); + set_peer.KeepaliveTimeout = _wtoi(timeout); + set_peer.MSS = _wtoi(mss); + + DWORD bytesReturned; + if (!DeviceIoControl(hDev, OVPN_IOCTL_SET_PEER, &set_peer, sizeof(set_peer), NULL, 0, &bytesReturned, NULL)) { + Log("DeviceIoControl(OVPN_IOCTL_SET_PEER) failed with code ", GetLastError()); + } + else { + Log("Peer set"); + } + } + + +} + void CreatePushButton(HWND hWnd, DWORD ioctl, int x, int y) { @@ -522,6 +568,12 @@ LRESULT CALLBACK WindowProcedure(HWND hwnd, UINT msg, WPARAM wp, LPARAM lp) CreatePushButton(hwnd, OVPN_IOCTL_NEW_KEY, 10, 260); hNewKeyPeerId = CreateEditBox(hwnd, L"1", 150, 260, 60); + CreatePushButton(hwnd, OVPN_IOCTL_SET_PEER, 10, 310); + hSetPeerPeerId = CreateEditBox(hwnd, L"1", 150, 310, 60); + hSetPeerInterval = CreateEditBox(hwnd, L"5", 240, 310, 60); + hSetPeerTimeout = CreateEditBox(hwnd, L"30", 330, 310, 60); + hSetPeerMSS = CreateEditBox(hwnd, L"-1", 420, 310, 60); + SendMessage(hModes[0], BM_SETCHECK, BST_CHECKED, 0); // log area @@ -568,6 +620,9 @@ LRESULT CALLBACK WindowProcedure(HWND hwnd, UINT msg, WPARAM wp, LPARAM lp) case OVPN_IOCTL_NEW_KEY: NewKey(); break; + + case OVPN_IOCTL_SET_PEER: + SetPeer(); } } else if ((ULONG)wp == BTN_SEND_CC) { From d1c9a6beed7676c4857c68fbeabb5184572955ed Mon Sep 17 00:00:00 2001 From: Lev Stipakov Date: Tue, 22 Oct 2024 12:11:16 +0300 Subject: [PATCH 24/28] Implement userspace notifications Existing mechanism of userspace notifications (by erroring out pending read requests) is not flexible enough to be used by multipeer. This adds a new OVPN_IOCTL_NOTIFY_EVENT ioctl. When request arrives, we check if there are pending notifications. If yes, then we complete request, writing notification command, peer-id and peer delete reason (if applicable). If there are no pending notifications, request is queued. When notification occurs (such as peer keepalive timeout), we check if there is a pending requests in the queue (see above). If yes, we complete request with notification details. If there are no pending requests, we add notification event to a queue. Events queue is implemented with a C++ class and a kernel linked lists. The queue is a member of device context. Since there is no C++ runtime and constructors are not called for context members, we have to use a separate method for initialization. GitHub: https://github.com/OpenVPN/ovpn-dco-win/issues/87 Signed-off-by: Lev Stipakov --- Driver.cpp | 52 ++++++++++++++++++++++++++- Driver.h | 5 +++ notifyqueue.cpp | 86 ++++++++++++++++++++++++++++++++++++++++++++ notifyqueue.h | 53 +++++++++++++++++++++++++++ ovpn-dco-win.vcxproj | 2 ++ timer.cpp | 44 ++++++++++++++++++----- uapi/ovpn-dco.h | 21 +++++++++++ 7 files changed, 254 insertions(+), 9 deletions(-) create mode 100644 notifyqueue.cpp create mode 100644 notifyqueue.h diff --git a/Driver.cpp b/Driver.cpp index a9e9b5f..0cd4a29 100644 --- a/Driver.cpp +++ b/Driver.cpp @@ -356,10 +356,18 @@ OvpnStopVPN(_In_ POVPN_DEVICE device) WDFREQUEST request; while (NT_SUCCESS(WdfIoQueueRetrieveNextRequest(device->PendingReadsQueue, &request))) { ULONG_PTR bytesCopied = 0; - LOG_INFO("Cancel IO request from manual queue"); + LOG_INFO("Cancel pending read requests"); WdfRequestCompleteWithInformation(request, STATUS_CANCELLED, bytesCopied); } + while (NT_SUCCESS(WdfIoQueueRetrieveNextRequest(device->PendingNotificationRequestsQueue, &request))) { + ULONG_PTR bytesCopied = 0; + LOG_INFO("Cancel pending notifications"); + WdfRequestCompleteWithInformation(request, STATUS_CANCELLED, bytesCopied); + } + + device->PendingNotificationsQueue.FlushEvents(); + LOG_EXIT(); return STATUS_SUCCESS; @@ -422,6 +430,37 @@ OvpnMPStartVPN(POVPN_DEVICE device, WDFREQUEST request, ULONG_PTR* bytesReturned return status; } +NTSTATUS +OvpnNotifyEvent(POVPN_DEVICE device, WDFREQUEST request, _Out_ ULONG_PTR* bytesReturned) { + LOG_ENTER(); + + NTSTATUS status = STATUS_SUCCESS; + + *bytesReturned = 0; + + // do we have pending notifications? + auto event = device->PendingNotificationsQueue.GetEvent(); + if (event != nullptr) { + OVPN_NOTIFY_EVENT* evt; + LOG_IF_NOT_NT_SUCCESS(status = WdfRequestRetrieveOutputBuffer(request, sizeof(OVPN_NOTIFY_EVENT), (PVOID*)&evt, nullptr)); + if (NT_SUCCESS(status)) { + evt->Cmd = evt->Cmd; + evt->PeerId = evt->PeerId; + evt->DelPeerReason = evt->DelPeerReason; + *bytesReturned = sizeof(OVPN_NOTIFY_EVENT); + } + device->PendingNotificationsQueue.FreeEvent(event); + } + else { + LOG_IF_NOT_NT_SUCCESS(WdfRequestForwardToIoQueue(request, device->PendingNotificationRequestsQueue)); + status = STATUS_PENDING; + } + + LOG_EXIT(); + + return status; +} + EVT_WDF_IO_QUEUE_IO_DEVICE_CONTROL OvpnEvtIoDeviceControl; _Use_decl_annotations_ @@ -510,6 +549,10 @@ OvpnEvtIoDeviceControl(WDFQUEUE queue, WDFREQUEST request, size_t outputBufferLe ExReleaseSpinLockExclusive(&device->SpinLock, kirql); break; + case OVPN_IOCTL_NOTIFY_EVENT: + status = OvpnNotifyEvent(device, request, &bytesReturned); + break; + default: LOG_WARN("Unknown ", TraceLoggingValue(ioControlCode, "ioControlCode")); status = STATUS_INVALID_DEVICE_REQUEST; @@ -707,12 +750,19 @@ OvpnEvtDeviceAdd(WDFDRIVER wdfDriver, PWDFDEVICE_INIT deviceInit) { WDF_IO_QUEUE_CONFIG_INIT(&queueConfig, WdfIoQueueDispatchManual); GOTO_IF_NOT_NT_SUCCESS(done, status, WdfIoQueueCreate(wdfDevice, &queueConfig, WDF_NO_OBJECT_ATTRIBUTES, &device->PendingNewPeerQueue)); + // create manual queue which handles userspace notification requests + WDF_IO_QUEUE_CONFIG_INIT(&queueConfig, WdfIoQueueDispatchManual); + GOTO_IF_NOT_NT_SUCCESS(done, status, WdfIoQueueCreate(wdfDevice, &queueConfig, WDF_NO_OBJECT_ATTRIBUTES, &device->PendingNotificationRequestsQueue)); + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnTxBufferPoolCreate(&device->TxBufferPool, device)); GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnRxBufferPoolCreate(&device->RxBufferPool)); GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnBufferQueueCreate(&device->ControlRxBufferQueue)); GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnBufferQueueCreate(&device->DataRxBufferQueue)); + // constructors are not called for the members of WDF object context, so we use Init() method + device->PendingNotificationsQueue.Init(); + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoInitAlgHandles(&device->AesAlgHandle, &device->ChachaAlgHandle)); // Initialize peers tables diff --git a/Driver.h b/Driver.h index 0f60608..f9f2dc1 100644 --- a/Driver.h +++ b/Driver.h @@ -31,6 +31,7 @@ #include "adapter.h" #include "bufferpool.h" #include "crypto.h" +#include "notifyqueue.h" #include "socket.h" #include "uapi\ovpn-dco.h" @@ -60,6 +61,7 @@ struct OVPN_DEVICE { WDFQUEUE PendingReadsQueue; WDFQUEUE PendingWritesQueue; + WDFQUEUE PendingNotificationRequestsQueue; // NEW_PEER request may be enqueued here if TCP connect doesn't finish immediatelly WDFQUEUE PendingNewPeerQueue; @@ -76,6 +78,9 @@ struct OVPN_DEVICE { // buffer pool for encrypted data channel and control channel packets to be sent OVPN_TX_BUFFER_POOL TxBufferPool; + // queue to store pending userspace notifications + NotifyQueue PendingNotificationsQueue; + OVPN_STATS Stats; BCRYPT_ALG_HANDLE AesAlgHandle; diff --git a/notifyqueue.cpp b/notifyqueue.cpp new file mode 100644 index 0000000..95d9d93 --- /dev/null +++ b/notifyqueue.cpp @@ -0,0 +1,86 @@ +/* + * ovpn-dco-win OpenVPN protocol accelerator for Windows + * + * Copyright (C) 2024- OpenVPN Inc + * + * Author: Lev Stipakov + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 2 + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License along + * with this program; if not, write to the Free Software Foundation, Inc., + * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#include "notifyqueue.h" + +#include "trace.h" + +VOID +NotifyQueue::Init() +{ + LOG_ENTER(); + + InitializeListHead(&Head); + KeInitializeSpinLock(&Lock); + + LOG_EXIT(); +} + +NTSTATUS +NotifyQueue::AddEvent(OVPN_NOTIFY_CMD cmd, int peerId, OVPN_DEL_PEER_REASON delPeerReason) +{ + NotifyEvent* event = (NotifyEvent*)ExAllocatePool2(POOL_FLAG_NON_PAGED, sizeof(NotifyEvent), 'ovpn'); + if (!event) { + return STATUS_MEMORY_NOT_ALLOCATED; + } + + RtlZeroMemory(event, sizeof(NotifyEvent)); + + event->Cmd = cmd; + event->PeerId = peerId; + event->DelPeerReason = delPeerReason; + + ExInterlockedInsertTailList(&Head, &event->ListEntry, &Lock); + + return STATUS_SUCCESS; +} + +NotifyEvent* +NotifyQueue::GetEvent() +{ + PLIST_ENTRY entry = ExInterlockedRemoveHeadList(&Head, &Lock); + if (entry == nullptr) { + return nullptr; + } + + return CONTAINING_RECORD(entry, NotifyEvent, ListEntry); +} + +VOID +NotifyQueue::FreeEvent(NotifyEvent* event) +{ + if (event != nullptr) { + ExFreePoolWithTag(event, 'ovpn'); + } +} + +VOID +NotifyQueue::FlushEvents() +{ + LOG_ENTER(); + + NotifyEvent* event = nullptr; + while ((event = GetEvent()) != nullptr) { + FreeEvent(event); + } + + LOG_EXIT(); +} diff --git a/notifyqueue.h b/notifyqueue.h new file mode 100644 index 0000000..f1b0733 --- /dev/null +++ b/notifyqueue.h @@ -0,0 +1,53 @@ +/* + * ovpn-dco-win OpenVPN protocol accelerator for Windows + * + * Copyright (C) 2024- OpenVPN Inc + * + * Author: Lev Stipakov + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 2 + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License along + * with this program; if not, write to the Free Software Foundation, Inc., + * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#pragma once + +#include + +#include "uapi/ovpn-dco.h" + +struct NotifyEvent { + LIST_ENTRY ListEntry; + + OVPN_NOTIFY_CMD Cmd; + int PeerId; + OVPN_DEL_PEER_REASON DelPeerReason; +}; + +class NotifyQueue { +private: + LIST_ENTRY Head; + KSPIN_LOCK Lock; + +public: + NotifyQueue() = delete; + + VOID Init(); + + NTSTATUS AddEvent(OVPN_NOTIFY_CMD cmd, int peerId, OVPN_DEL_PEER_REASON delPeerReason=OVPN_DEL_PEER_REASON_EXPIRED); + + NotifyEvent* GetEvent(); + + VOID FreeEvent(NotifyEvent* event); + + VOID FlushEvents(); +}; diff --git a/ovpn-dco-win.vcxproj b/ovpn-dco-win.vcxproj index ab2e3e3..42c635a 100644 --- a/ovpn-dco-win.vcxproj +++ b/ovpn-dco-win.vcxproj @@ -73,6 +73,7 @@ + @@ -88,6 +89,7 @@ + diff --git a/timer.cpp b/timer.cpp index d67e158..2d93495 100644 --- a/timer.cpp +++ b/timer.cpp @@ -107,18 +107,46 @@ static BOOLEAN OvpnTimerRecv(WDFTIMER timer) { POVPN_DEVICE device = OvpnGetDeviceContext(WdfTimerGetParentObject(timer)); + POVPN_PEER_TIMER_CONTEXT timerCtx = OvpnGetPeerTimerContext(timer); + auto peerId = timerCtx->Peer->PeerId; + LOG_INFO("Keepalive timeout", TraceLoggingValue(peerId, "peer-id")); + WDFREQUEST request; - NTSTATUS status = WdfIoQueueRetrieveNextRequest(device->PendingReadsQueue, &request); - if (!NT_SUCCESS(status)) { - LOG_WARN("No pending request for keepalive timeout notification"); - return FALSE; - } - else { - LOG_INFO("Notify userspace about keepalive timeout"); + NTSTATUS status = STATUS_SUCCESS; + + if (device->Mode == OVPN_MODE_P2P) { + status = WdfIoQueueRetrieveNextRequest(device->PendingReadsQueue, &request); + if (!NT_SUCCESS(status)) { + LOG_WARN("No pending request for keepalive timeout notification"); + return FALSE; + } + ULONG_PTR bytesSent = 0; WdfRequestCompleteWithInformation(request, STATUS_CONNECTION_DISCONNECTED, bytesSent); - return TRUE; } + else { + status = WdfIoQueueRetrieveNextRequest(device->PendingNotificationRequestsQueue, &request); + if (!NT_SUCCESS(status)) { + LOG_WARN("Adding keepalive timeout notification to the queue"); + return NT_SUCCESS(device->PendingNotificationsQueue.AddEvent(OVPN_NOTIFY_DEL_PEER, peerId, OVPN_DEL_PEER_REASON_EXPIRED)); + } + else { + OVPN_NOTIFY_EVENT *evt; + ULONG_PTR bytesSent = 0; + LOG_IF_NOT_NT_SUCCESS(status = WdfRequestRetrieveOutputBuffer(request, sizeof(OVPN_NOTIFY_EVENT), (PVOID*)&evt, nullptr)); + if (NT_SUCCESS(status)) { + evt->Cmd = OVPN_NOTIFY_DEL_PEER; + evt->PeerId = peerId; + evt->DelPeerReason = OVPN_DEL_PEER_REASON_EXPIRED; + bytesSent = sizeof(OVPN_NOTIFY_EVENT); + } + WdfRequestCompleteWithInformation(request, status, bytesSent); + + // TODO: remove peer + } + } + + return NT_SUCCESS(status); } _Use_decl_annotations_ diff --git a/uapi/ovpn-dco.h b/uapi/ovpn-dco.h index be4770b..5a8201c 100644 --- a/uapi/ovpn-dco.h +++ b/uapi/ovpn-dco.h @@ -154,6 +154,25 @@ typedef struct _OVPN_MP_START_VPN { } ListenAddress; } OVPN_MP_START_VPN, * POVPN_MP_START_VPN; +typedef enum { + OVPN_NOTIFY_DEL_PEER, + OVPN_NOTIFY_ROTATE_KEY +} OVPN_NOTIFY_CMD; + +typedef enum { + OVPN_DEL_PEER_REASON_TEARDOWN, + OVPN_DEL_PEER_REASON_USERSPACE, + OVPN_DEL_PEER_REASON_EXPIRED, + OVPN_DEL_PEER_REASON_TRANSPORT_ERROR, + OVPN_DEL_PEER_REASON_TRANSPORT_DISCONNECT +} OVPN_DEL_PEER_REASON; + +typedef struct _OVPN_NOTIFY_EVENT { + OVPN_NOTIFY_CMD Cmd; + int PeerId; + OVPN_DEL_PEER_REASON DelPeerReason; +} OVPN_NOTIFY_EVENT, * POVPN_NOTIFY_EVENT; + #define OVPN_IOCTL_NEW_PEER CTL_CODE(FILE_DEVICE_UNKNOWN, 1, METHOD_BUFFERED, FILE_ANY_ACCESS) #define OVPN_IOCTL_GET_STATS CTL_CODE(FILE_DEVICE_UNKNOWN, 2, METHOD_BUFFERED, FILE_ANY_ACCESS) #define OVPN_IOCTL_NEW_KEY CTL_CODE(FILE_DEVICE_UNKNOWN, 3, METHOD_BUFFERED, FILE_ANY_ACCESS) @@ -168,3 +187,5 @@ typedef struct _OVPN_MP_START_VPN { #define OVPN_IOCTL_MP_START_VPN CTL_CODE(FILE_DEVICE_UNKNOWN, 11, METHOD_BUFFERED, FILE_ANY_ACCESS) #define OVPN_IOCTL_MP_NEW_PEER CTL_CODE(FILE_DEVICE_UNKNOWN, 12, METHOD_BUFFERED, FILE_ANY_ACCESS) #define OVPN_IOCTL_MP_SET_PEER CTL_CODE(FILE_DEVICE_UNKNOWN, 13, METHOD_BUFFERED, FILE_ANY_ACCESS) + +#define OVPN_IOCTL_NOTIFY_EVENT CTL_CODE(FILE_DEVICE_UNKNOWN, 14, METHOD_BUFFERED, FILE_ANY_ACCESS) From ad8c81d26f5f72ad577d9d8807c9923c33f0c77e Mon Sep 17 00:00:00 2001 From: Lev Stipakov Date: Tue, 22 Oct 2024 12:20:22 +0300 Subject: [PATCH 25/28] gui: support for notifications ioctl Use overlapped I/O with OVPN_IOCTL_NOTIFY_EVENT ioctl to get notifications from the kernel. Factor out read/write/notification callbacks into own functions for better readability. GitHub: https://github.com/OpenVPN/ovpn-dco-win/issues/87 Signed-off-by: Lev Stipakov --- gui/gui.cpp | 133 +++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 96 insertions(+), 37 deletions(-) diff --git a/gui/gui.cpp b/gui/gui.cpp index 178ae26..1120574 100644 --- a/gui/gui.cpp +++ b/gui/gui.cpp @@ -60,6 +60,19 @@ std::vector> modeData = { std::vector hModes; +std::vector notifCmds = { + "peer deleted", + "key rotation" +}; + +std::vector delPeerReasons = { + "teardown", + "userspace", + "expired", + "trtansport error", + "transport disconnect" +}; + template void Log(Args... args) { std::wstringstream stream; @@ -80,7 +93,7 @@ void Log(Args... args) { HANDLE hDev; char readBuffer[4096] = {0}; -OVERLAPPED ovRead = {0}, ovWrite = {0}; +OVERLAPPED ovRead = {0}, ovWrite = {0}, ovNotif = {0}; bool StartOverlappedRead() { ZeroMemory(readBuffer, sizeof(readBuffer)); @@ -92,6 +105,75 @@ bool StartOverlappedRead() { return true; } +OVPN_NOTIFY_EVENT notifEvent = {0}; + +bool StartOverlappedNotif() { + if (!DeviceIoControl(hDev, OVPN_IOCTL_NOTIFY_EVENT, NULL, 0, ¬ifEvent, sizeof(notifEvent), NULL, &ovNotif)) { + if (GetLastError() != ERROR_IO_PENDING) { + Log("Failed to start notification read: ", GetLastError()); + return false; + } + } + return true; +} + +bool OnReadCompleted() +{ + DWORD bytesRead; + if (GetOverlappedResult(hDev, &ovRead, &bytesRead, FALSE)) { + if (bytesRead > 0) { + bool mp = SendMessage(hModes[1], BM_GETCHECK, 0, 0) == BST_CHECKED; + + // if we're in server mode, we've received CC message prepended with sockaddr + if (mp) { + SOCKADDR_IN *sa = (SOCKADDR_IN *)readBuffer; + + char ip[INET_ADDRSTRLEN]; + inet_ntop(AF_INET, &(sa->sin_addr), ip, sizeof(ip)); + + int port = ntohs(sa->sin_port); + + Log("CC[", ip, ":", port, "]> ", readBuffer + sizeof(*sa)); + } else { + Log("CC[]> ", readBuffer); + } + } + } else { + Log("Overlapped read failed: ", GetLastError()); + } + + return StartOverlappedRead(); +} + +bool OnNotifyCompleted() +{ + DWORD bytesRead; + if (GetOverlappedResult(hDev, &ovNotif, &bytesRead, FALSE)) { + if (bytesRead > 0) { + Log("Notification: ", + "Cmd: ", notifCmds[notifEvent.Cmd], + ", peer-id: ", notifEvent.PeerId, + ", del reason: ", delPeerReasons[notifEvent.DelPeerReason]); + } + } else { + Log("Notif read failed: ", GetLastError()); + } + + return StartOverlappedNotif(); +} + +void OnWriteCompleted() +{ + DWORD bytesWrote; + if (GetOverlappedResult(hDev, &ovWrite, &bytesWrote, FALSE)) { + if (bytesWrote > 0) { + Log("Wrote ", bytesWrote, " bytes"); + } + } else { + Log("Overlapped write failed: ", GetLastError()); + } +} + int WINAPI WinMain(HINSTANCE hInst, HINSTANCE hPrevInst, LPSTR args, int ncmdshow) { // Create a Window Class @@ -117,52 +199,29 @@ int WINAPI WinMain(HINSTANCE hInst, HINSTANCE hPrevInst, LPSTR args, int ncmdsho HANDLE hEvWrite = CreateEventW(NULL, FALSE, FALSE, NULL); ovWrite.hEvent = hEvWrite; + HANDLE hEvNotif = CreateEventW(NULL, FALSE, FALSE, NULL); + ovNotif.hEvent = hEvNotif; + StartOverlappedRead(); + StartOverlappedNotif(); + while (true) { - HANDLE events[] = { hEvRead, hEvWrite }; - DWORD waitResult = MsgWaitForMultipleObjects(2, events, FALSE, INFINITE, QS_ALLINPUT); + HANDLE events[] = { hEvRead, hEvWrite, hEvNotif }; + DWORD waitResult = MsgWaitForMultipleObjects(3, events, FALSE, INFINITE, QS_ALLINPUT); - // read completed if (waitResult == WAIT_OBJECT_0) { - DWORD bytesRead; - if (GetOverlappedResult(hDev, &ovRead, &bytesRead, FALSE)) { - if (bytesRead > 0) { - bool mp = SendMessage(hModes[1], BM_GETCHECK, 0, 0) == BST_CHECKED; - - // if we're in server mode, we've received CC message prepended with sockaddr - if (mp) { - SOCKADDR_IN *sa = (SOCKADDR_IN *)readBuffer; - - char ip[INET_ADDRSTRLEN]; - inet_ntop(AF_INET, &(sa->sin_addr), ip, sizeof(ip)); - - int port = ntohs(sa->sin_port); - - Log("CC[", ip, ":", port, "]> ", readBuffer + sizeof(*sa)); - } else { - Log("CC[]> ", readBuffer); - } - } - } else { - Log("Overlapped read failed: ", GetLastError()); - } - - if (!StartOverlappedRead()) { + if (!OnReadCompleted()) { break; } } if (waitResult == WAIT_OBJECT_0 + 1) { - // write completed - DWORD bytesWrote; - if (GetOverlappedResult(hDev, &ovWrite, &bytesWrote, FALSE)) { - if (bytesWrote > 0) { - Log("Wrote ", bytesWrote, " bytes"); - } - } else { - Log("Overlapped write failed: ", GetLastError()); + OnWriteCompleted(); + } if (waitResult == WAIT_OBJECT_0 + 2) { + if (!OnNotifyCompleted()) { + break; } } - else if (waitResult == WAIT_OBJECT_0 + 2) { + else if (waitResult == WAIT_OBJECT_0 + 3) { // window messaging loop MSG msg; while (PeekMessage(&msg, NULL, 0, 0, PM_REMOVE)) { From 3b74322652354becffb0befc95a8cdbcf5c2145a Mon Sep 17 00:00:00 2001 From: Lev Stipakov Date: Tue, 22 Oct 2024 15:18:28 +0300 Subject: [PATCH 26/28] gui: add a button to subscribe for notifications Instead of subscribe for notifications at startup, do it on button click. This is because current driver version doesn't yet support notifications. Signed-off-by: Lev Stipakov --- gui/gui.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/gui/gui.cpp b/gui/gui.cpp index 1120574..453737e 100644 --- a/gui/gui.cpp +++ b/gui/gui.cpp @@ -15,6 +15,7 @@ #define VER_DEV_NAME L"\\\\.\\ovpn-dco-ver" #define BTN_SEND_CC 100 +#define BTN_SUBSCRIBE_NOTIF 101 LRESULT CALLBACK WindowProcedure(HWND, UINT, WPARAM, LPARAM); @@ -204,8 +205,6 @@ int WINAPI WinMain(HINSTANCE hInst, HINSTANCE hPrevInst, LPSTR args, int ncmdsho StartOverlappedRead(); - StartOverlappedNotif(); - while (true) { HANDLE events[] = { hEvRead, hEvWrite, hEvNotif }; DWORD waitResult = MsgWaitForMultipleObjects(3, events, FALSE, INFINITE, QS_ALLINPUT); @@ -616,6 +615,8 @@ LRESULT CALLBACK WindowProcedure(HWND hwnd, UINT msg, WPARAM wp, LPARAM lp) hCCRemoteAddress = CreateEditBox(hwnd, L"192.168.100.1", 290, 160, 120); hCCRemotePort = CreateEditBox(hwnd, L"1194", 430, 160, 60); + CreatePushButton(hwnd, L"Subscribe notif", (HMENU)BTN_SUBSCRIBE_NOTIF, 640, 160); + CreatePushButton(hwnd, OVPN_IOCTL_MP_NEW_PEER, 10, 210); hMPNewPeerLocalIP = CreateEditBox(hwnd, L"192.168.100.2", 150, 210, 120); hMPNewPeerLocalPort = CreateEditBox(hwnd, L"1194", 290, 210, 60); @@ -686,6 +687,8 @@ LRESULT CALLBACK WindowProcedure(HWND hwnd, UINT msg, WPARAM wp, LPARAM lp) } else if ((ULONG)wp == BTN_SEND_CC) { SendCC(); + } else if ((ULONG)wp == BTN_SUBSCRIBE_NOTIF) { + StartOverlappedNotif(); } } From c0a1dcc88154ff7c2bc5e123c9f13c55bcbad429 Mon Sep 17 00:00:00 2001 From: Lev Stipakov Date: Tue, 22 Oct 2024 16:37:32 +0300 Subject: [PATCH 27/28] multipeer: delete expired peer On keepalive timeout, delete peer context from all tables, unititialize crypto/timer and and free the peer context memory. Github: https://github.com/OpenVPN/ovpn-dco-win/issues/88 Signed-off-by: Lev Stipakov --- Driver.cpp | 46 ++++++++++++++++++++++++++++++++++++++++++++++ Driver.h | 4 ++++ timer.cpp | 9 +++++---- 3 files changed, 55 insertions(+), 4 deletions(-) diff --git a/Driver.cpp b/Driver.cpp index 0cd4a29..a10e28c 100644 --- a/Driver.cpp +++ b/Driver.cpp @@ -897,3 +897,49 @@ OvpnFindPeerVPN6(POVPN_DEVICE device, IN6_ADDR addr) OvpnPeerContext** ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->PeersByVpn6, &pp); return ptr ? (OvpnPeerContext*)*ptr : NULL; } + +VOID +OvpnDeletePeerFromTable(RTL_GENERIC_TABLE *table, OvpnPeerContext *peer, char* tableName) +{ + auto peerId = peer->PeerId; + auto pp = &peer; + + if (RtlDeleteElementGenericTable(table, pp)) { + LOG_INFO("Peer deleted", TraceLoggingValue(tableName, "table"), TraceLoggingValue(peerId, "peer-id")); + + if (InterlockedDecrement(&peer->RefCounter) == 0) { + OvpnPeerCtxFree(peer); + LOG_INFO("Peer freed", TraceLoggingValue(peerId, "peer-id")); + } + } + else { + LOG_INFO("Peer not found", TraceLoggingValue(tableName, "table"), TraceLoggingValue(peerId, "peer-id")); + } +} + +NTSTATUS +OvpnDeletePeer(POVPN_DEVICE device, INT32 peerId) +{ + NTSTATUS status = STATUS_SUCCESS; + + KIRQL kirql = ExAcquireSpinLockExclusive(&device->SpinLock); + + LOG_INFO("Deleting peer", TraceLoggingValue(peerId, "peer-id")); + + // get peer from main table + OvpnPeerContext* peerCtx = OvpnFindPeer(device, peerId); + if (peerCtx == NULL) { + status = STATUS_NOT_FOUND; + LOG_WARN("Peer not found", TraceLoggingValue(peerId, "peer-id")); + } + else { + OvpnDeletePeerFromTable(&device->PeersByVpn4, peerCtx, "vpn4"); + OvpnDeletePeerFromTable(&device->PeersByVpn6, peerCtx, "vpn6"); + + OvpnDeletePeerFromTable(&device->Peers, peerCtx, "peers"); + } + + ExReleaseSpinLockExclusive(&device->SpinLock, kirql); + + return status; +} diff --git a/Driver.h b/Driver.h index f9f2dc1..3ac667e 100644 --- a/Driver.h +++ b/Driver.h @@ -150,3 +150,7 @@ OvpnFindPeerVPN4(_In_ POVPN_DEVICE device, _In_ IN_ADDR addr); _Must_inspect_result_ OvpnPeerContext* OvpnFindPeerVPN6(_In_ POVPN_DEVICE device, _In_ IN6_ADDR addr); + +_Must_inspect_result_ +NTSTATUS +OvpnDeletePeer(_In_ POVPN_DEVICE device, INT32 peerId); diff --git a/timer.cpp b/timer.cpp index 2d93495..29034a0 100644 --- a/timer.cpp +++ b/timer.cpp @@ -117,7 +117,7 @@ static BOOLEAN OvpnTimerRecv(WDFTIMER timer) if (device->Mode == OVPN_MODE_P2P) { status = WdfIoQueueRetrieveNextRequest(device->PendingReadsQueue, &request); if (!NT_SUCCESS(status)) { - LOG_WARN("No pending request for keepalive timeout notification"); + LOG_INFO("No pending request for keepalive timeout notification"); return FALSE; } @@ -125,12 +125,15 @@ static BOOLEAN OvpnTimerRecv(WDFTIMER timer) WdfRequestCompleteWithInformation(request, STATUS_CONNECTION_DISCONNECTED, bytesSent); } else { + (VOID)OvpnDeletePeer(device, peerId); + status = WdfIoQueueRetrieveNextRequest(device->PendingNotificationRequestsQueue, &request); if (!NT_SUCCESS(status)) { - LOG_WARN("Adding keepalive timeout notification to the queue"); + LOG_INFO("Adding keepalive timeout notification to the queue"); return NT_SUCCESS(device->PendingNotificationsQueue.AddEvent(OVPN_NOTIFY_DEL_PEER, peerId, OVPN_DEL_PEER_REASON_EXPIRED)); } else { + LOG_INFO("Notify userspace about expired peer"); OVPN_NOTIFY_EVENT *evt; ULONG_PTR bytesSent = 0; LOG_IF_NOT_NT_SUCCESS(status = WdfRequestRetrieveOutputBuffer(request, sizeof(OVPN_NOTIFY_EVENT), (PVOID*)&evt, nullptr)); @@ -141,8 +144,6 @@ static BOOLEAN OvpnTimerRecv(WDFTIMER timer) bytesSent = sizeof(OVPN_NOTIFY_EVENT); } WdfRequestCompleteWithInformation(request, status, bytesSent); - - // TODO: remove peer } } From 66bda5095820064328372b625800b4d72106b8d5 Mon Sep 17 00:00:00 2001 From: Lev Stipakov Date: Mon, 28 Oct 2024 15:43:47 +0200 Subject: [PATCH 28/28] Improve locking and peer lifetime management The current locking mechanism (global read-write spinlock) works fine for client (since almost all operations are read-only, such as encrypt/decrypt) but doesn't scale well for multipeer, since amount of state-modifying operations (add/remove peer, modify keys) increases a lot. Instead of a single lock: - Use a global spinlock when accessing peers tables (by peer-id, by vpn4 and vpn6 addresses). Acquire for shared access for lookup and for exclusive access when modifying the tables. After lookup, increment peer reference counter and return a pointer to a peer context. - When doing read-only operation on a peer, acquire peer-specific spinlock for shared access. For state-modifying operation, acquire peer-specific spinlock for exclusive access. - After finishing operations on a peer, decrement peer reference counter. If after decrement peer reference counter is zero, free peer context. This approach reduces contention and ensures that peer context won't be deleted during peer-specific operation. GitHub: https://github.com/OpenVPN/ovpn-dco-win/issues/89 Signed-off-by: Lev Stipakov --- Driver.cpp | 184 +++++++++++++++++++++++++++++--------------------- Driver.h | 21 +----- peer.cpp | 188 ++++++++++++++++++++++++++++------------------------ peer.h | 5 ++ socket.cpp | 89 ++++++++++++++----------- timer.cpp | 14 ++-- txqueue.cpp | 86 ++++++++++++++++++------ 7 files changed, 342 insertions(+), 245 deletions(-) diff --git a/Driver.cpp b/Driver.cpp index a10e28c..8540e5d 100644 --- a/Driver.cpp +++ b/Driver.cpp @@ -329,12 +329,12 @@ OvpnStopVPN(_In_ POVPN_DEVICE device) { LOG_ENTER(); + OvpnFlushPeers(device); + KIRQL kirql = ExAcquireSpinLockExclusive(&device->SpinLock); PWSK_SOCKET socket = device->Socket.Socket; device->Socket.Socket = NULL; - OvpnFlushPeers(device); - device->Mode = OVPN_MODE_P2P; RtlZeroMemory(&device->Socket.TcpState, sizeof(OvpnSocketTcpState)); @@ -481,12 +481,9 @@ OvpnEvtIoDeviceControl(WDFQUEUE queue, WDFREQUEST request, size_t outputBufferLe return; } - KIRQL kirql = 0; switch ((long)ioControlCode) { case OVPN_IOCTL_GET_STATS: - kirql = ExAcquireSpinLockShared(&device->SpinLock); status = OvpnPeerGetStats(device, request, &bytesReturned); - ExReleaseSpinLockShared(&device->SpinLock, kirql); break; case OVPN_IOCTL_NEW_PEER: @@ -502,27 +499,19 @@ OvpnEvtIoDeviceControl(WDFQUEUE queue, WDFREQUEST request, size_t outputBufferLe break; case OVPN_IOCTL_NEW_KEY: - kirql = ExAcquireSpinLockExclusive(&device->SpinLock); status = OvpnPeerNewKey(device, request); - ExReleaseSpinLockExclusive(&device->SpinLock, kirql); break; case OVPN_IOCTL_NEW_KEY_V2: - kirql = ExAcquireSpinLockExclusive(&device->SpinLock); status = OvpnPeerNewKeyV2(device, request); - ExReleaseSpinLockExclusive(&device->SpinLock, kirql); break; case OVPN_IOCTL_SWAP_KEYS: - kirql = ExAcquireSpinLockExclusive(&device->SpinLock); status = OvpnPeerSwapKeys(device); - ExReleaseSpinLockExclusive(&device->SpinLock, kirql); break; case OVPN_IOCTL_SET_PEER: - kirql = ExAcquireSpinLockExclusive(&device->SpinLock); status = OvpnPeerSet(device, request); - ExReleaseSpinLockExclusive(&device->SpinLock, kirql); break; case OVPN_IOCTL_GET_VERSION: @@ -530,9 +519,7 @@ OvpnEvtIoDeviceControl(WDFQUEUE queue, WDFREQUEST request, size_t outputBufferLe break; case OVPN_IOCTL_SET_MODE: - kirql = ExAcquireSpinLockExclusive(&device->SpinLock); status = OvpnSetMode(device, request); - ExReleaseSpinLockExclusive(&device->SpinLock, kirql); break; case OVPN_IOCTL_MP_START_VPN: @@ -544,9 +531,7 @@ OvpnEvtIoDeviceControl(WDFQUEUE queue, WDFREQUEST request, size_t outputBufferLe break; case OVPN_IOCTL_MP_SET_PEER: - kirql = ExAcquireSpinLockExclusive(&device->SpinLock); status = OvpnMPPeerSet(device, request); - ExReleaseSpinLockExclusive(&device->SpinLock, kirql); break; case OVPN_IOCTL_NOTIFY_EVENT: @@ -778,12 +763,15 @@ OvpnEvtDeviceAdd(WDFDRIVER wdfDriver, PWDFDEVICE_INIT deviceInit) { return status; } +_Use_decl_annotations_ NTSTATUS -OvpnAddPeerToTable(_In_ RTL_GENERIC_TABLE* table, _In_ OvpnPeerContext* peer) +OvpnAddPeerToTable(POVPN_DEVICE device, RTL_GENERIC_TABLE* table, OvpnPeerContext* peer) { NTSTATUS status; BOOLEAN newElem; + auto irql = ExAcquireSpinLockExclusive(&device->SpinLock); + RtlInsertElementGenericTable(table, (PVOID)&peer, sizeof(OvpnPeerContext*), &newElem); if (newElem) { @@ -794,116 +782,154 @@ OvpnAddPeerToTable(_In_ RTL_GENERIC_TABLE* table, _In_ OvpnPeerContext* peer) LOG_ERROR("Unable to add new peer"); status = STATUS_NO_MEMORY; } - return status; -} - -_Use_decl_annotations_ -NTSTATUS -OvpnAddPeer(POVPN_DEVICE device, OvpnPeerContext* peer) -{ - return OvpnAddPeerToTable(&device->Peers, peer); -} - -_Use_decl_annotations_ -NTSTATUS -OvpnAddPeerVpn4(POVPN_DEVICE device, OvpnPeerContext* peer) -{ - return OvpnAddPeerToTable(&device->PeersByVpn4, peer); -} + ExReleaseSpinLockExclusive(&device->SpinLock, irql); -_Use_decl_annotations_ -NTSTATUS -OvpnAddPeerVpn6(POVPN_DEVICE device, OvpnPeerContext* peer) -{ - return OvpnAddPeerToTable(&device->PeersByVpn6, peer); + return status; } _Use_decl_annotations_ VOID OvpnFlushPeers(POVPN_DEVICE device) { - OvpnCleanupPeerTable(&device->PeersByVpn6); - OvpnCleanupPeerTable(&device->PeersByVpn4); - OvpnCleanupPeerTable(&device->Peers); + OvpnCleanupPeerTable(device, &device->PeersByVpn6); + OvpnCleanupPeerTable(device, &device->PeersByVpn4); + OvpnCleanupPeerTable(device, &device->Peers); } _Use_decl_annotations_ VOID -OvpnCleanupPeerTable(RTL_GENERIC_TABLE* peers) +OvpnCleanupPeerTable(POVPN_DEVICE device, RTL_GENERIC_TABLE* peers) { + auto irql = ExAcquireSpinLockExclusive(&device->SpinLock); + while (!RtlIsGenericTableEmpty(peers)) { PVOID ptr = RtlGetElementGenericTable(peers, 0); OvpnPeerContext* peer = *(OvpnPeerContext**)ptr; RtlDeleteElementGenericTable(peers, ptr); - if (InterlockedDecrement(&peer->RefCounter) == 0) { - OvpnPeerCtxFree(peer); - } + OvpnPeerCtxRelease(peer); } + + ExReleaseSpinLockExclusive(&device->SpinLock, irql); } _Use_decl_annotations_ OvpnPeerContext* -OvpnGetFirstPeer(RTL_GENERIC_TABLE* peers) +OvpnGetFirstPeer(POVPN_DEVICE device) { - OvpnPeerContext** ptr = (OvpnPeerContext**)RtlGetElementGenericTable(peers, 0); - return ptr ? (OvpnPeerContext*)*ptr : NULL; + auto irql = ExAcquireSpinLockShared(&device->SpinLock); + + OvpnPeerContext** ptr = (OvpnPeerContext**)RtlGetElementGenericTable(&device->Peers, 0); + OvpnPeerContext* peer = ptr ? (OvpnPeerContext*)*ptr : nullptr; + + if (peer != nullptr) { + InterlockedIncrement(&peer->RefCounter); + } + + ExReleaseSpinLockShared(&device->SpinLock, irql); + + return peer; } _Use_decl_annotations_ OvpnPeerContext* OvpnFindPeer(POVPN_DEVICE device, INT32 PeerId) { + OvpnPeerContext* peer = nullptr; + OvpnPeerContext** ptr = nullptr; + + auto kirql = ExAcquireSpinLockShared(&device->SpinLock); + if (device->Mode == OVPN_MODE_P2P) { - return OvpnGetFirstPeer(&device->Peers); + ptr = (OvpnPeerContext**)RtlGetElementGenericTable(&device->Peers, 0); + } + else { + OvpnPeerContext p{}; + p.PeerId = PeerId; + + auto* pp = &p; + ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->Peers, &pp); + } + + peer = ptr ? (OvpnPeerContext*)*ptr : nullptr; + + if (peer) { + InterlockedIncrement(&peer->RefCounter); } - OvpnPeerContext p {}; - p.PeerId = PeerId; + ExReleaseSpinLockShared(&device->SpinLock, kirql); - auto* pp = &p; - OvpnPeerContext** ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->Peers, &pp); - return ptr ? (OvpnPeerContext*)*ptr : NULL; + return peer; } _Use_decl_annotations_ OvpnPeerContext* OvpnFindPeerVPN4(POVPN_DEVICE device, IN_ADDR addr) { + OvpnPeerContext* peer = nullptr; + OvpnPeerContext** ptr = nullptr; + + auto kirql = ExAcquireSpinLockShared(&device->SpinLock); + if (device->Mode == OVPN_MODE_P2P) { - return OvpnGetFirstPeer(&device->Peers); + ptr = (OvpnPeerContext**)RtlGetElementGenericTable(&device->Peers, 0); + } + else { + OvpnPeerContext p{}; + p.VpnAddrs.IPv4 = addr; + + auto* pp = &p; + ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->PeersByVpn4, &pp); } - OvpnPeerContext p{}; - p.VpnAddrs.IPv4 = addr; + peer = ptr ? (OvpnPeerContext*)*ptr : nullptr; + if (peer) { + InterlockedIncrement(&peer->RefCounter); + } + + ExReleaseSpinLockShared(&device->SpinLock, kirql); - auto* pp = &p; - OvpnPeerContext** ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->PeersByVpn4, &pp); - return ptr ? (OvpnPeerContext*)*ptr : NULL; + return peer; } _Use_decl_annotations_ OvpnPeerContext* OvpnFindPeerVPN6(POVPN_DEVICE device, IN6_ADDR addr) { + OvpnPeerContext* peer = nullptr; + OvpnPeerContext** ptr = nullptr; + + auto kirql = ExAcquireSpinLockShared(&device->SpinLock); + if (device->Mode == OVPN_MODE_P2P) { - return OvpnGetFirstPeer(&device->Peers); + ptr = (OvpnPeerContext**)RtlGetElementGenericTable(&device->Peers, 0); + } + else { + OvpnPeerContext p{}; + p.VpnAddrs.IPv6 = addr; + + auto* pp = &p; + ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->PeersByVpn6, &pp); + } + + peer = ptr ? (OvpnPeerContext*)*ptr : nullptr; + if (peer) { + InterlockedIncrement(&peer->RefCounter); } - OvpnPeerContext p{}; - p.VpnAddrs.IPv6 = addr; + ExReleaseSpinLockShared(&device->SpinLock, kirql); - auto* pp = &p; - OvpnPeerContext** ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->PeersByVpn6, &pp); - return ptr ? (OvpnPeerContext*)*ptr : NULL; + return peer; } VOID -OvpnDeletePeerFromTable(RTL_GENERIC_TABLE *table, OvpnPeerContext *peer, char* tableName) +OvpnDeletePeerFromTable(POVPN_DEVICE device, RTL_GENERIC_TABLE *table, OvpnPeerContext *peer, char* tableName) { auto peerId = peer->PeerId; auto pp = &peer; + auto kirql = ExAcquireSpinLockShared(&device->SpinLock); + if (RtlDeleteElementGenericTable(table, pp)) { LOG_INFO("Peer deleted", TraceLoggingValue(tableName, "table"), TraceLoggingValue(peerId, "peer-id")); @@ -915,31 +941,35 @@ OvpnDeletePeerFromTable(RTL_GENERIC_TABLE *table, OvpnPeerContext *peer, char* t else { LOG_INFO("Peer not found", TraceLoggingValue(tableName, "table"), TraceLoggingValue(peerId, "peer-id")); } + + ExReleaseSpinLockShared(&device->SpinLock, kirql); } +_Use_decl_annotations_ NTSTATUS OvpnDeletePeer(POVPN_DEVICE device, INT32 peerId) { NTSTATUS status = STATUS_SUCCESS; - KIRQL kirql = ExAcquireSpinLockExclusive(&device->SpinLock); - LOG_INFO("Deleting peer", TraceLoggingValue(peerId, "peer-id")); // get peer from main table - OvpnPeerContext* peerCtx = OvpnFindPeer(device, peerId); - if (peerCtx == NULL) { + OvpnPeerContext* peer = OvpnFindPeer(device, peerId); + if (peer == NULL) { status = STATUS_NOT_FOUND; LOG_WARN("Peer not found", TraceLoggingValue(peerId, "peer-id")); } else { - OvpnDeletePeerFromTable(&device->PeersByVpn4, peerCtx, "vpn4"); - OvpnDeletePeerFromTable(&device->PeersByVpn6, peerCtx, "vpn6"); + auto irql = ExAcquireSpinLockShared(&peer->SpinLock); - OvpnDeletePeerFromTable(&device->Peers, peerCtx, "peers"); - } + OvpnDeletePeerFromTable(device, &device->PeersByVpn4, peer, "vpn4"); + OvpnDeletePeerFromTable(device, &device->PeersByVpn6, peer, "vpn6"); + OvpnDeletePeerFromTable(device, &device->Peers, peer, "peers"); - ExReleaseSpinLockExclusive(&device->SpinLock, kirql); + ExReleaseSpinLockShared(&peer->SpinLock, irql); + + OvpnPeerCtxRelease(peer); + } return status; } diff --git a/Driver.h b/Driver.h index 3ac667e..881cb7e 100644 --- a/Driver.h +++ b/Driver.h @@ -108,36 +108,21 @@ typedef OVPN_DEVICE * POVPN_DEVICE; WDF_DECLARE_CONTEXT_TYPE_WITH_NAME(OVPN_DEVICE, OvpnGetDeviceContext) -static inline -BOOLEAN -OvpnHasPeers(_In_ POVPN_DEVICE device) -{ - return !RtlIsGenericTableEmpty(&device->Peers); -} - struct OvpnPeerContext; _Must_inspect_result_ NTSTATUS -OvpnAddPeer(_In_ POVPN_DEVICE device, _In_ OvpnPeerContext* PeerCtx); - -_Must_inspect_result_ -NTSTATUS -OvpnAddPeerVpn4(_In_ POVPN_DEVICE device, _In_ OvpnPeerContext* PeerCtx); - -_Must_inspect_result_ -NTSTATUS -OvpnAddPeerVpn6(_In_ POVPN_DEVICE device, _In_ OvpnPeerContext* PeerCtx); +OvpnAddPeerToTable(POVPN_DEVICE device, _In_ RTL_GENERIC_TABLE* table, _In_ OvpnPeerContext* peer); VOID OvpnFlushPeers(_In_ POVPN_DEVICE device); VOID -OvpnCleanupPeerTable(_In_ RTL_GENERIC_TABLE*); +OvpnCleanupPeerTable(_In_ POVPN_DEVICE device, _In_ RTL_GENERIC_TABLE*); _Must_inspect_result_ OvpnPeerContext* -OvpnGetFirstPeer(_In_ RTL_GENERIC_TABLE*); +OvpnGetFirstPeer(_In_ POVPN_DEVICE device); _Must_inspect_result_ OvpnPeerContext* diff --git a/peer.cpp b/peer.cpp index 6ba9a8c..fddcb7c 100644 --- a/peer.cpp +++ b/peer.cpp @@ -35,16 +35,32 @@ OvpnPeerCtxAlloc() if (peer != NULL) { RtlZeroMemory(peer, sizeof(OvpnPeerContext)); } + InterlockedIncrement(&peer->RefCounter); return peer; } +_Use_decl_annotations_ +VOID +OvpnPeerCtxRelease(OvpnPeerContext* peer) +{ + if (InterlockedDecrement(&peer->RefCounter) == 0) { + auto peerId = peer->PeerId; + OvpnPeerCtxFree(peer); + LOG_INFO("Peer freed", TraceLoggingValue(peerId, "peer-id")); + } +} + _Use_decl_annotations_ VOID OvpnPeerCtxFree(OvpnPeerContext* peer) { + auto irql = ExAcquireSpinLockExclusive(&peer->SpinLock); + OvpnCryptoUninit(&peer->CryptoContext); OvpnTimerDestroy(&peer->Timer); + ExReleaseSpinLockExclusive(&peer->SpinLock, irql); + ExFreePoolWithTag(peer, 'ovpn'); } @@ -142,17 +158,16 @@ OvpnPeerNew(POVPN_DEVICE device, WDFREQUEST request) POVPN_NEW_PEER peer = NULL; NTSTATUS status; - GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveInputBuffer(request, sizeof(OVPN_NEW_PEER), (PVOID*)&peer, nullptr)); - KIRQL kirql = ExAcquireSpinLockExclusive(&device->SpinLock); - const BOOLEAN peerExists = OvpnHasPeers(device); - ExReleaseSpinLockExclusive(&device->SpinLock, kirql); - if (peerExists) { + auto peerCtx = OvpnGetFirstPeer(device); + if (peerCtx != nullptr) { LOG_WARN("Peer already exists"); status = STATUS_OBJECTID_EXISTS; goto done; } + GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveInputBuffer(request, sizeof(OVPN_NEW_PEER), (PVOID*)&peer, nullptr)); + if ((peer->Remote.Addr4.sin_family != AF_INET) && (peer->Remote.Addr4.sin_family != AF_INET6)) { status = STATUS_INVALID_DEVICE_REQUEST; @@ -165,7 +180,7 @@ OvpnPeerNew(POVPN_DEVICE device, WDFREQUEST request) BOOLEAN proto_tcp = peer->Proto == OVPN_PROTO_TCP; SIZE_T remoteAddrSize = peer->Remote.Addr4.sin_family == AF_INET ? sizeof(peer->Remote.Addr4) : sizeof(peer->Remote.Addr6); - OvpnPeerContext* peerCtx = OvpnPeerCtxAlloc(); + peerCtx = OvpnPeerCtxAlloc(); if (peerCtx == NULL) { status = STATUS_NO_MEMORY; goto done; @@ -185,33 +200,28 @@ OvpnPeerNew(POVPN_DEVICE device, WDFREQUEST request) (PSOCKADDR)&peer->Remote, remoteAddrSize, device, &socket)); - kirql = ExAcquireSpinLockExclusive(&device->SpinLock); + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnAddPeerToTable(device, &device->Peers, peerCtx)); - LOG_IF_NOT_NT_SUCCESS(status = OvpnAddPeer(device, peerCtx)); - if (status != STATUS_SUCCESS) { - ExReleaseSpinLockExclusive(&device->SpinLock, kirql); - OvpnPeerCtxFree(peerCtx); - LOG_IF_NOT_NT_SUCCESS(OvpnSocketClose(socket)); - } - else { - device->Socket.Socket = socket; - device->Socket.Tcp = proto_tcp; - RtlZeroMemory(&device->Socket.TcpState, sizeof(OvpnSocketTcpState)); - RtlZeroMemory(&device->Socket.UdpState, sizeof(OvpnSocketUdpState)); - ExReleaseSpinLockExclusive(&device->SpinLock, kirql); + device->Socket.Socket = socket; + device->Socket.Tcp = proto_tcp; + RtlZeroMemory(&device->Socket.TcpState, sizeof(OvpnSocketTcpState)); + RtlZeroMemory(&device->Socket.UdpState, sizeof(OvpnSocketUdpState)); - OvpnPeerZeroStats(&device->Stats); + OvpnPeerZeroStats(&device->Stats); - GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnTimerCreate(device->WdfDevice, peerCtx, &peerCtx->Timer)); + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnTimerCreate(device->WdfDevice, peerCtx, &peerCtx->Timer)); - if (proto_tcp) { - LOG_IF_NOT_NT_SUCCESS(status = WdfRequestForwardToIoQueue(request, device->PendingNewPeerQueue)); - // start async connect - status = OvpnSocketTcpConnect(socket, device, (PSOCKADDR)&peer->Remote); - } + if (proto_tcp) { + LOG_IF_NOT_NT_SUCCESS(status = WdfRequestForwardToIoQueue(request, device->PendingNewPeerQueue)); + // start async connect + status = OvpnSocketTcpConnect(socket, device, (PSOCKADDR)&peer->Remote); } done: + if (peerCtx != nullptr) { + OvpnPeerCtxRelease(peerCtx); + } + LOG_EXIT(); return status; @@ -228,14 +238,13 @@ OvpnMPPeerNew(POVPN_DEVICE device, WDFREQUEST request) NTSTATUS status = STATUS_SUCCESS; POVPN_MP_NEW_PEER peer; + OvpnPeerContext* peerCtx = nullptr; GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveInputBuffer(request, sizeof(OVPN_MP_NEW_PEER), (PVOID*)&peer, nullptr)); // check if we already have a peer with the same peer-id - KIRQL kirql = ExAcquireSpinLockExclusive(&device->SpinLock); - OvpnPeerContext* peerCtx = OvpnFindPeer(device, peer->PeerId); - ExReleaseSpinLockExclusive(&device->SpinLock, kirql); - if (peerCtx != NULL) { + peerCtx = OvpnFindPeer(device, peer->PeerId); + if (peerCtx != nullptr) { status = STATUS_OBJECTID_EXISTS; goto done; } @@ -283,29 +292,23 @@ OvpnMPPeerNew(POVPN_DEVICE device, WDFREQUEST request) peerCtx->PeerId = peer->PeerId; // create peer-specific timer - LOG_IF_NOT_NT_SUCCESS(status = OvpnTimerCreate(device->WdfDevice, peerCtx, &peerCtx->Timer)); - if (status != STATUS_SUCCESS) { - OvpnPeerCtxFree(peerCtx); - goto done; - } - - kirql = ExAcquireSpinLockExclusive(&device->SpinLock); - LOG_IF_NOT_NT_SUCCESS(status = OvpnAddPeer(device, peerCtx)); - if (status == STATUS_SUCCESS) { - if (peer->VpnAddr4.S_un.S_addr != INADDR_ANY) { - LOG_IF_NOT_NT_SUCCESS(status = OvpnAddPeerVpn4(device, peerCtx)); - } + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnTimerCreate(device->WdfDevice, peerCtx, &peerCtx->Timer)); - if (RtlCompareMemory(&peer->VpnAddr6, &ovpn_in6addr_any, sizeof(IN6_ADDR)) != sizeof(IN6_ADDR)) { - LOG_IF_NOT_NT_SUCCESS(status = OvpnAddPeerVpn6(device, peerCtx)); - } + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnAddPeerToTable(device, &device->Peers, peerCtx)); + + if (peer->VpnAddr4.S_un.S_addr != INADDR_ANY) { + LOG_IF_NOT_NT_SUCCESS(status = OvpnAddPeerToTable(device, &device->PeersByVpn4, peerCtx)); } - else { - OvpnPeerCtxFree(peerCtx); + + if (RtlCompareMemory(&peer->VpnAddr6, &ovpn_in6addr_any, sizeof(IN6_ADDR)) != sizeof(IN6_ADDR)) { + LOG_IF_NOT_NT_SUCCESS(status = OvpnAddPeerToTable(device, &device->PeersByVpn6, peerCtx)); } - ExReleaseSpinLockExclusive(&device->SpinLock, kirql); done: + if (peerCtx != nullptr) { + OvpnPeerCtxRelease(peerCtx); + } + LOG_EXIT(); return status; @@ -313,6 +316,8 @@ OvpnMPPeerNew(POVPN_DEVICE device, WDFREQUEST request) VOID OvpnPeerSetDoWork(OvpnPeerContext *peer, LONG keepaliveInterval, LONG keepaliveTimeout, LONG mss) { + auto irql = ExAcquireSpinLockExclusive(&peer->SpinLock); + if (mss != -1) { peer->MSS = (UINT16)mss; } @@ -330,6 +335,8 @@ VOID OvpnPeerSetDoWork(OvpnPeerContext *peer, LONG keepaliveInterval, LONG keepa // keepalive recv timer, detects keepalive timeout OvpnTimerSetRecvTimeout(peer->Timer, peer->KeepaliveTimeout); } + + ExReleaseSpinLockExclusive(&peer->SpinLock, irql); } _Use_decl_annotations_ @@ -339,9 +346,8 @@ NTSTATUS OvpnPeerSet(POVPN_DEVICE device, WDFREQUEST request) NTSTATUS status = STATUS_SUCCESS; - OvpnPeerContext* peer = OvpnGetFirstPeer(&device->Peers); - - if (peer == NULL) { + OvpnPeerContext* peer = OvpnGetFirstPeer(device); + if (peer == nullptr) { LOG_ERROR("Peer not added"); status = STATUS_INVALID_DEVICE_REQUEST; goto done; @@ -357,6 +363,10 @@ NTSTATUS OvpnPeerSet(POVPN_DEVICE device, WDFREQUEST request) OvpnPeerSetDoWork(peer, set_peer->KeepaliveInterval, set_peer->KeepaliveTimeout, set_peer->MSS); done: + if (peer != nullptr) { + OvpnPeerCtxRelease(peer); + } + LOG_EXIT(); return status; } @@ -368,6 +378,8 @@ NTSTATUS OvpnMPPeerSet(POVPN_DEVICE device, WDFREQUEST request) NTSTATUS status = STATUS_SUCCESS; + OvpnPeerContext* peer = nullptr; + POVPN_MP_SET_PEER set_peer = NULL; GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveInputBuffer(request, sizeof(OVPN_MP_SET_PEER), (PVOID*)&set_peer, nullptr)); @@ -376,8 +388,8 @@ NTSTATUS OvpnMPPeerSet(POVPN_DEVICE device, WDFREQUEST request) TraceLoggingValue(set_peer->KeepaliveTimeout, "timeout"), TraceLoggingValue(set_peer->MSS, "MSS")); - OvpnPeerContext* peer = OvpnFindPeer(device, set_peer->PeerId); - if (peer == NULL) { + peer = OvpnFindPeer(device, set_peer->PeerId); + if (peer == nullptr) { LOG_ERROR("Peer not found", TraceLoggingValue(set_peer->PeerId, "peer-id")); status = STATUS_INVALID_DEVICE_REQUEST; goto done; @@ -386,6 +398,10 @@ NTSTATUS OvpnMPPeerSet(POVPN_DEVICE device, WDFREQUEST request) OvpnPeerSetDoWork(peer, set_peer->KeepaliveInterval, set_peer->KeepaliveTimeout, set_peer->MSS); done: + if (peer != nullptr) { + OvpnPeerCtxRelease(peer); + } + LOG_EXIT(); return status; } @@ -396,13 +412,6 @@ OvpnPeerGetStats(POVPN_DEVICE device, WDFREQUEST request, ULONG_PTR* bytesReturn { NTSTATUS status = STATUS_SUCCESS; - OvpnPeerContext* peer = OvpnGetFirstPeer(&device->Peers); - if (peer == NULL) { - LOG_ERROR("Peer not added"); - status = STATUS_INVALID_DEVICE_REQUEST; - goto done; - } - POVPN_STATS stats = NULL; GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveOutputBuffer(request, sizeof(OVPN_STATS), (PVOID*)&stats, NULL)); @@ -433,7 +442,8 @@ OvpnPeerStartVPN(POVPN_DEVICE device) NTSTATUS status = STATUS_SUCCESS; - if (!OvpnHasPeers(device)) { + auto peer = OvpnGetFirstPeer(device); + if (peer == nullptr) { LOG_ERROR("Peer not added"); status = STATUS_INVALID_DEVICE_REQUEST; goto done; @@ -442,6 +452,10 @@ OvpnPeerStartVPN(POVPN_DEVICE device) OvpnAdapterSetLinkState(OvpnGetAdapterContext(device->Adapter), MediaConnectStateConnected); done: + if (peer != nullptr) { + OvpnPeerCtxRelease(peer); + } + LOG_EXIT(); return status; @@ -482,20 +496,15 @@ OvpnPeerNewKey(POVPN_DEVICE device, WDFREQUEST request) POVPN_CRYPTO_DATA cryptoData = NULL; OVPN_CRYPTO_DATA_V2 cryptoDataV2{}; - - if (!OvpnHasPeers(device)) { - LOG_ERROR("Peer not added"); - status = STATUS_INVALID_DEVICE_REQUEST; - goto done; - } + OvpnPeerContext* peer = nullptr; GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveInputBuffer(request, sizeof(OVPN_CRYPTO_DATA), (PVOID*)&cryptoData, nullptr)); BCRYPT_ALG_HANDLE algHandle = NULL; GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnPeerGetAlgHandle(device, cryptoData->CipherAlg, algHandle)); - OvpnPeerContext* peer = OvpnFindPeer(device, cryptoData->PeerId); - if (peer == NULL) { + peer = OvpnFindPeer(device, cryptoData->PeerId); + if (peer == nullptr) { status = STATUS_OBJECTID_NOT_FOUND; goto done; } @@ -504,6 +513,10 @@ OvpnPeerNewKey(POVPN_DEVICE device, WDFREQUEST request) GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoNewKey(&peer->CryptoContext, &cryptoDataV2, algHandle)); done: + if (peer != nullptr) { + OvpnPeerCtxRelease(peer); + } + LOG_EXIT(); return status; @@ -516,29 +529,30 @@ OvpnPeerNewKeyV2(POVPN_DEVICE device, WDFREQUEST request) LOG_ENTER(); NTSTATUS status = STATUS_SUCCESS; + OvpnPeerContext* peer = nullptr; POVPN_CRYPTO_DATA_V2 cryptoDataV2 = NULL; - - if (!OvpnHasPeers(device)) { - LOG_ERROR("Peer not added"); - status = STATUS_INVALID_DEVICE_REQUEST; - goto done; - } - GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveInputBuffer(request, sizeof(OVPN_CRYPTO_DATA_V2), (PVOID*)&cryptoDataV2, nullptr)); BCRYPT_ALG_HANDLE algHandle = NULL; GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnPeerGetAlgHandle(device, cryptoDataV2->V1.CipherAlg, algHandle)); - OvpnPeerContext* peer = OvpnFindPeer(device, cryptoDataV2->V1.PeerId); - if (peer == NULL) { + peer = OvpnFindPeer(device, cryptoDataV2->V1.PeerId); + + if (peer == nullptr) { status = STATUS_OBJECTID_NOT_FOUND; goto done; } - GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoNewKey(&peer->CryptoContext, cryptoDataV2, algHandle)); + KIRQL irql = ExAcquireSpinLockExclusive(&peer->SpinLock); + LOG_IF_NOT_NT_SUCCESS(status = OvpnCryptoNewKey(&peer->CryptoContext, cryptoDataV2, algHandle)); + ExReleaseSpinLockExclusive(&peer->SpinLock, irql); done: + if (peer != nullptr) { + OvpnPeerCtxRelease(peer); + } + LOG_EXIT(); return status; @@ -552,22 +566,22 @@ OvpnPeerSwapKeys(POVPN_DEVICE device) NTSTATUS status = STATUS_SUCCESS; - if (!OvpnHasPeers(device)) { - LOG_ERROR("Peer not added"); - status = STATUS_INVALID_DEVICE_REQUEST; - goto done; - } - - OvpnPeerContext* peer = OvpnGetFirstPeer(&device->Peers); - if (peer == NULL) { + OvpnPeerContext* peer = OvpnGetFirstPeer(device); + if (peer == nullptr) { LOG_ERROR("Peer not found"); status = STATUS_INVALID_DEVICE_REQUEST; goto done; } + KIRQL irql = ExAcquireSpinLockExclusive(&peer->SpinLock); OvpnCryptoSwapKeys(&peer->CryptoContext); + ExReleaseSpinLockExclusive(&peer->SpinLock, irql); done: + if (peer != nullptr) { + OvpnPeerCtxRelease(peer); + } + LOG_EXIT(); return status; diff --git a/peer.h b/peer.h index 00a1585..800649a 100644 --- a/peer.h +++ b/peer.h @@ -29,6 +29,8 @@ struct OvpnPeerContext { + EX_SPIN_LOCK SpinLock; + OvpnCryptoContext CryptoContext; INT32 PeerId; @@ -72,6 +74,9 @@ OvpnPeerCtxAlloc(); VOID OvpnPeerCtxFree(_In_ OvpnPeerContext*); +VOID +OvpnPeerCtxRelease(_In_ OvpnPeerContext*); + RTL_GENERIC_ALLOCATE_ROUTINE OvpnPeerAllocateRoutine; RTL_GENERIC_FREE_ROUTINE OvpnPeerFreeRoutine; RTL_GENERIC_COMPARE_ROUTINE OvpnPeerCompareByPeerIdRoutine; diff --git a/socket.cpp b/socket.cpp index 5281207..9c2f515 100644 --- a/socket.cpp +++ b/socket.cpp @@ -95,7 +95,6 @@ OvpnSocketSyncOp(_In_z_ CHAR* opName, OP op, SUCCESS success) } static -_Requires_shared_lock_held_(device->SpinLock) VOID OvpnSocketControlPacketReceived(_In_ POVPN_DEVICE device, _In_reads_(len) PUCHAR buf, SIZE_T len, _In_opt_ PSOCKADDR remote) { @@ -188,13 +187,12 @@ OvpnSocketControlPacketReceived(_In_ POVPN_DEVICE device, _In_reads_(len) PUCHAR } static -_Requires_shared_lock_held_(device->SpinLock) -VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, UINT32 peerId, _In_reads_(len) PUCHAR cipherTextBuf, SIZE_T len) +VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, UINT32 peerId, _In_reads_(len) PUCHAR cipherTextBuf, SIZE_T len, BOOLEAN irqlDispatch) { InterlockedExchangeAddNoFence64(&device->Stats.TransportBytesReceived, len); OvpnPeerContext* peer = OvpnFindPeer(device, peerId); - if (peer == NULL) { + if (peer == nullptr) { LOG_WARN("Peer not found", TraceLoggingValue(peerId, "peerId")); InterlockedIncrementNoFence(&device->Stats.LostInDataPackets); return; @@ -207,9 +205,20 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, UINT32 pee if (!NT_SUCCESS(status)) { LOG_ERROR("RxBufferPool exhausted"); InterlockedIncrementNoFence(&device->Stats.LostInDataPackets); + OvpnPeerCtxRelease(peer); return; } + // If we're at dispatch level, we can use a small optimization and use function + // which is not calling KeRaiseIRQL to raise the IRQL to DISPATCH_LEVEL before attempting to acquire the lock + KIRQL kirql = 0; + if (irqlDispatch) { + ExAcquireSpinLockSharedAtDpcLevel(&peer->SpinLock); + } + else { + kirql = ExAcquireSpinLockShared(&peer->SpinLock); + } + OvpnCryptoContext* cryptoContext = &peer->CryptoContext; if (cryptoContext->Decrypt) { @@ -245,63 +254,60 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, UINT32 pee // LOG_WARN("CryptoContext not yet initialized"); } - if (!NT_SUCCESS(status)) { + if (NT_SUCCESS(status)) { + OvpnTimerResetRecv(peer->Timer); + } + else { OvpnRxBufferPoolPut(buffer); - return; } - OvpnTimerResetRecv(peer->Timer); - - // ping packet? - if (OvpnTimerIsKeepaliveMessage(buffer->Data, buffer->Len)) { - LOG_INFO("Ping received", TraceLoggingValue(peer->PeerId, "peer-id")); + auto mss = peer->MSS; - // no need to inject ping packet into OS, return buffer to the pool - OvpnRxBufferPoolPut(buffer); + // don't forget to release spinlock + if (irqlDispatch) { + ExReleaseSpinLockSharedFromDpcLevel(&peer->SpinLock); } else { - if (OvpnMssIsIPv4(buffer->Data, buffer->Len)) { - OvpnMssDoIPv4(buffer->Data, buffer->Len, peer->MSS); - } else if (OvpnMssIsIPv6(buffer->Data, buffer->Len)) { - OvpnMssDoIPv6(buffer->Data, buffer->Len, peer->MSS); + ExReleaseSpinLockShared(&peer->SpinLock, kirql); + } + + OvpnPeerCtxRelease(peer); + + if (NT_SUCCESS(status)) { + // ping packet? + if (OvpnTimerIsKeepaliveMessage(buffer->Data, buffer->Len)) { + LOG_INFO("Ping received", TraceLoggingValue(peerId, "peer-id")); + + // no need to inject ping packet into OS, return buffer to the pool + OvpnRxBufferPoolPut(buffer); } + else { + if (OvpnMssIsIPv4(buffer->Data, buffer->Len)) { + OvpnMssDoIPv4(buffer->Data, buffer->Len, mss); + } + else if (OvpnMssIsIPv6(buffer->Data, buffer->Len)) { + OvpnMssDoIPv6(buffer->Data, buffer->Len, mss); + } - // enqueue plaintext buffer, it will be dequeued by NetAdapter RX datapath - OvpnBufferQueueEnqueue(device->DataRxBufferQueue, &buffer->QueueListEntry); + // enqueue plaintext buffer, it will be dequeued by NetAdapter RX datapath + OvpnBufferQueueEnqueue(device->DataRxBufferQueue, &buffer->QueueListEntry); - OvpnAdapterNotifyRx(device->Adapter); + OvpnAdapterNotifyRx(device->Adapter); + } } } VOID OvpnSocketProcessIncomingPacket(_In_ POVPN_DEVICE device, _In_reads_(packetLength) PUCHAR buf, SIZE_T packetLength, BOOLEAN irqlDispatch, _In_opt_ PSOCKADDR remoteAddr) { - // If we're at dispatch level, we can use a small optimization and use function - // which is not calling KeRaiseIRQL to raise the IRQL to DISPATCH_LEVEL before attempting to acquire the lock - KIRQL kirql = 0; - if (irqlDispatch) { - ExAcquireSpinLockSharedAtDpcLevel(&device->SpinLock); - } - else { - kirql = ExAcquireSpinLockShared(&device->SpinLock); - } - UCHAR op = RtlUlongByteSwap(*(ULONG*)(buf)) >> 24; if (OvpnCryptoOpcodeExtract(op) == OVPN_OP_DATA_V2) { UINT32 peerId = RtlUlongByteSwap(*(ULONG*)(buf)) & OVPN_PEER_ID_MASK; - OvpnSocketDataPacketReceived(device, op, peerId, buf, packetLength); + OvpnSocketDataPacketReceived(device, op, peerId, buf, packetLength, irqlDispatch); } else { OvpnSocketControlPacketReceived(device, buf, packetLength, remoteAddr); } - - // don't forget to release spinlock - if (irqlDispatch) { - ExReleaseSpinLockSharedFromDpcLevel(&device->SpinLock); - } - else { - ExReleaseSpinLockShared(&device->SpinLock, kirql); - } } _Must_inspect_result_ @@ -550,6 +556,11 @@ OvpnSocketInit(WSK_PROVIDER_NPI* wskProviderNpi, WSK_REGISTRATION* wskRegistrati }, [](PIRP) {})); // connect will be done later + + BOOLEAN tcpNoDelay = TRUE; + SIZE_T outputSizeReturned = 0; + GOTO_IF_NOT_NT_SUCCESS(error, status, connectionDispatch->Basic.WskControlSocket(*socket, WskSetOption, TCP_NODELAY, IPPROTO_TCP, + sizeof(tcpNoDelay), &tcpNoDelay, 0, NULL, &outputSizeReturned, NULL)); } else { // bind diff --git a/timer.cpp b/timer.cpp index 29034a0..e1839e5 100644 --- a/timer.cpp +++ b/timer.cpp @@ -72,7 +72,12 @@ static VOID OvpnTimerXmit(WDFTIMER timer) RtlCopyMemory(OvpnBufferPut(buffer, sizeof(OvpnKeepaliveMessage)), OvpnKeepaliveMessage, sizeof(OvpnKeepaliveMessage)); OvpnPeerContext* peer = timerCtx->Peer; - KIRQL kiqrl = ExAcquireSpinLockShared(&device->SpinLock); + + ExAcquireSpinLockSharedAtDpcLevel(&peer->SpinLock); + + auto peerId = peer->PeerId; + auto addr = peer->TransportAddrs.Remote; + OvpnCryptoContext* cryptoContext = &peer->CryptoContext; if (cryptoContext->Encrypt) { // make space to crypto overhead @@ -88,19 +93,19 @@ static VOID OvpnTimerXmit(WDFTIMER timer) status = STATUS_INVALID_DEVICE_STATE; // LOG_WARN("CryptoContext not initialized"); } + ExReleaseSpinLockSharedFromDpcLevel(&peer->SpinLock); if (NT_SUCCESS(status)) { // start async send, completion handler will return ciphertext buffer to the pool - SOCKADDR* sa = (SOCKADDR*)&(peer->TransportAddrs.Remote); + SOCKADDR* sa = (SOCKADDR*)&(addr); LOG_IF_NOT_NT_SUCCESS(status = OvpnSocketSend(&device->Socket, buffer, sa)); if (NT_SUCCESS(status)) { - LOG_INFO("Ping sent", TraceLoggingValue(peer->PeerId, "peer-id")); + LOG_INFO("Ping sent", TraceLoggingValue(peerId, "peer-id")); } } else { OvpnTxBufferPoolPut(buffer); } - ExReleaseSpinLockShared(&device->SpinLock, kiqrl); } static BOOLEAN OvpnTimerRecv(WDFTIMER timer) @@ -168,6 +173,7 @@ static VOID OvpnTimerTick(WDFTIMER timer) KeQuerySystemTime(&now); POVPN_PEER_TIMER_CONTEXT timerCtx = OvpnGetPeerTimerContext(timer); + if ((timerCtx->xmitInterval > 0) && (((now.QuadPart - timerCtx->lastXmit.QuadPart) / WDF_TIMEOUT_TO_SEC) > timerCtx->xmitInterval)) { OvpnTimerXmit(timer); diff --git a/txqueue.cpp b/txqueue.cpp index dd1b063..e1bc3a2 100644 --- a/txqueue.cpp +++ b/txqueue.cpp @@ -36,15 +36,57 @@ #include "socket.h" #include "peer.h" +template +static +VOID +OvpnTxCopyRemoteToSockaddr(T& remote, SOCKADDR* sockaddr) { + // Copy the appropriate address based on the family + if (remote.IPv4.sin_family == AF_INET) { + RtlCopyMemory(sockaddr, &remote.IPv4, sizeof(SOCKADDR_IN)); + } + else if (remote.IPv6.sin6_family == AF_INET6) { + RtlCopyMemory(sockaddr, &remote.IPv6, sizeof(SOCKADDR_IN6)); + } +} + +static +BOOLEAN +OvpnTxAreSockaddrEqual(const SOCKADDR* addr1, const SOCKADDR* addr2) { + // First, check if the address families are the same + if (addr1->sa_family != addr2->sa_family) { + return 0; // Not equal if the families are different + } + + if (addr1->sa_family == AF_INET) { + // Compare IPv4 addresses + SOCKADDR_IN* ipv4_1 = (SOCKADDR_IN*)addr1; + SOCKADDR_IN* ipv4_2 = (SOCKADDR_IN*)addr2; + return (ipv4_1->sin_addr.s_addr == ipv4_2->sin_addr.s_addr && + ipv4_1->sin_port == ipv4_2->sin_port); + } + else if (addr1->sa_family == AF_INET6) { + // Compare IPv6 addresses + SOCKADDR_IN6* ipv6_1 = (SOCKADDR_IN6*)addr1; + SOCKADDR_IN6* ipv6_2 = (SOCKADDR_IN6*)addr2; + SIZE_T result = RtlCompareMemory(&ipv6_1->sin6_addr, &ipv6_2->sin6_addr, sizeof(ipv6_1->sin6_addr)); + return (result == sizeof(ipv6_1->sin6_addr) && + ipv6_1->sin6_port == ipv6_2->sin6_port); + } + + // If the address family is neither AF_INET nor AF_INET6, return not equal + return 0; +} + _Must_inspect_result_ -_Requires_shared_lock_held_(device->SpinLock) static NTSTATUS OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET_RING_PACKET_ITERATOR *pi, - _Inout_ OVPN_TX_BUFFER **head, _Inout_ OVPN_TX_BUFFER** tail, _Inout_ SOCKADDR **headSockaddr) + _Inout_ OVPN_TX_BUFFER **head, _Inout_ OVPN_TX_BUFFER** tail, _Inout_ SOCKADDR *headSockaddr) { NET_RING_FRAGMENT_ITERATOR fi = NetPacketIteratorGetFragments(pi); + OvpnPeerContext* peer = NULL; + // get buffer into which we gather plaintext fragments and do in-place encryption OVPN_TX_BUFFER* buffer; NTSTATUS status; @@ -77,23 +119,23 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET NetFragmentIteratorAdvance(&fi); } - OvpnPeerContext* peer = NULL; - if (OvpnMssIsIPv4(buffer->Data, buffer->Len)) { auto addr = ((IPV4_HEADER*)buffer->Data)->DestinationAddress; + peer = OvpnFindPeerVPN4(device, addr); - if (peer != NULL) { + if (peer != nullptr) { OvpnMssDoIPv4(buffer->Data, buffer->Len, peer->MSS); } } else if (OvpnMssIsIPv6(buffer->Data, buffer->Len)) { auto addr = ((IPV6_HEADER*)buffer->Data)->DestinationAddress; + peer = OvpnFindPeerVPN6(device, addr); - if (peer != NULL) { + if (peer != nullptr) { OvpnMssDoIPv6(buffer->Data, buffer->Len, peer->MSS); } } - if (peer == NULL) { + if (peer == nullptr) { status = STATUS_ADDRESS_NOT_ASSOCIATED; OvpnTxBufferPoolPut(buffer); goto out; @@ -101,7 +143,11 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET InterlockedExchangeAddNoFence64(&device->Stats.TunBytesSent, buffer->Len); + auto irql = ExAcquireSpinLockShared(&peer->SpinLock); + OvpnCryptoContext* cryptoContext = &peer->CryptoContext; + auto remoteAddr = peer->TransportAddrs.Remote; + auto timer = peer->Timer; if (cryptoContext->Encrypt) { auto aeadTagEnd = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END; @@ -121,6 +167,11 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET status = STATUS_INVALID_DEVICE_STATE; // LOG_WARN("CryptoContext not initialized"); } + ExReleaseSpinLockShared(&peer->SpinLock, irql); + + if (peer != nullptr) { + OvpnPeerCtxRelease(peer); + } if (NT_SUCCESS(status)) { // start async send, this will return ciphertext buffer to the pool @@ -138,17 +189,16 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET // If this peer is different (head sockaddr != peer sockaddr) to the previous buffer chain peers, // then flush those and restart with a new buffer list. - if ((*head != NULL) && *headSockaddr != (SOCKADDR*)&(peer->TransportAddrs.Remote)) + if ((*head != NULL) && !(OvpnTxAreSockaddrEqual(headSockaddr, (const SOCKADDR*)&remoteAddr))) { - LOG_IF_NOT_NT_SUCCESS(OvpnSocketSend(&device->Socket, *head, *headSockaddr)); + LOG_IF_NOT_NT_SUCCESS(OvpnSocketSend(&device->Socket, *head, headSockaddr)); *head = buffer; *tail = buffer; - *headSockaddr = (SOCKADDR*)&(peer->TransportAddrs.Remote); - } - else { + OvpnTxCopyRemoteToSockaddr(remoteAddr, headSockaddr); + } else { if (*head == NULL) { *head = buffer; - *headSockaddr = (SOCKADDR*)&(peer->TransportAddrs.Remote); + OvpnTxCopyRemoteToSockaddr(remoteAddr, headSockaddr); } else { (*tail)->WskBufList.Next = &buffer->WskBufList; @@ -158,7 +208,7 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET } } - OvpnTimerResetXmit(peer->Timer); + OvpnTimerResetXmit(timer); } else { OvpnTxBufferPoolPut(buffer); @@ -184,11 +234,9 @@ OvpnEvtTxQueueAdvance(NETPACKETQUEUE netPacketQueue) POVPN_DEVICE device = OvpnGetDeviceContext(queue->Adapter->WdfDevice); BOOLEAN packetSent = false; - KIRQL kirql = ExAcquireSpinLockShared(&device->SpinLock); - OVPN_TX_BUFFER* txBufferHead = NULL; OVPN_TX_BUFFER* txBufferTail = NULL; - SOCKADDR* headSockaddr = NULL; + SOCKADDR headSockaddr = {0}; while (NetPacketIteratorHasAny(&pi)) { NET_PACKET* packet = NetPacketIteratorGetPacket(&pi); @@ -213,11 +261,9 @@ OvpnEvtTxQueueAdvance(NETPACKETQUEUE netPacketQueue) if (packetSent) { if (!device->Socket.Tcp) { // this will use WskSendMessages to send buffers list which we constructed before - LOG_IF_NOT_NT_SUCCESS(OvpnSocketSend(&device->Socket, txBufferHead, headSockaddr)); + LOG_IF_NOT_NT_SUCCESS(OvpnSocketSend(&device->Socket, txBufferHead, &headSockaddr)); } } - - ExReleaseSpinLockShared(&device->SpinLock, kirql); } _Use_decl_annotations_