From 83d2ad7c8564212c6362087ee6f9cba9066da75a Mon Sep 17 00:00:00 2001 From: Lev Stipakov Date: Mon, 16 Sep 2024 14:28:18 +0300 Subject: [PATCH] Add OVPN_MP_START_VPN ioctl This starts listening on a provided local address. https://github.com/OpenVPN/ovpn-dco-win/issues/82 Co-authored-by: Leon Dang Signed-off-by: Leon Dang Signed-off-by: Lev Stipakov --- .gitignore | 7 +++++ Driver.cpp | 75 ++++++++++++++++++++++++++++++++++++++++++++++--- socket.cpp | 17 +++++++---- socket.h | 2 +- uapi/ovpn-dco.h | 9 ++++++ 5 files changed, 100 insertions(+), 10 deletions(-) diff --git a/.gitignore b/.gitignore index 336c1bf..0c87796 100644 --- a/.gitignore +++ b/.gitignore @@ -53,6 +53,7 @@ Thumbs.db Debug Release x64 +x86 Win32 sdv ARM64 @@ -72,3 +73,9 @@ smvstats.txt *.cab *.vcxproj.user + +**/out +**/build +**/.vscode +**/signed +msm/tmp \ No newline at end of file diff --git a/Driver.cpp b/Driver.cpp index ae5d930..80d6368 100644 --- a/Driver.cpp +++ b/Driver.cpp @@ -264,10 +264,8 @@ OvpnSetMode(POVPN_DEVICE device, WDFREQUEST request) static BOOLEAN OvpnDeviceCheckMode(OVPN_MODE mode, ULONG code) { - if (mode == OVPN_MODE_MP) - { - switch (code) - { + if (mode == OVPN_MODE_MP) { + switch (code) { // all those IOCTLs are only for P2P mode case OVPN_IOCTL_NEW_PEER: case OVPN_IOCTL_DEL_PEER: @@ -275,6 +273,14 @@ OvpnDeviceCheckMode(OVPN_MODE mode, ULONG code) case OVPN_IOCTL_NEW_KEY_V2: case OVPN_IOCTL_SWAP_KEYS: case OVPN_IOCTL_SET_PEER: + case OVPN_IOCTL_START_VPN: + return FALSE; + } + } + else if (mode == OVPN_MODE_P2P) { + switch (code) { + // those IOCTLs are for MP mode + case OVPN_IOCTL_MP_START_VPN: return FALSE; } } @@ -323,6 +329,63 @@ OvpnStopVPN(_In_ POVPN_DEVICE device) return STATUS_SUCCESS; } +_Must_inspect_result_ +_IRQL_requires_(PASSIVE_LEVEL) +_IRQL_requires_same_ +NTSTATUS +OvpnMPStartVPN(POVPN_DEVICE device, WDFREQUEST request, ULONG_PTR* bytesReturned) +{ + NTSTATUS status = STATUS_SUCCESS; + + LOG_ENTER(); + + KIRQL kirql = ExAcquireSpinLockExclusive(&device->SpinLock); + if (device->Socket.Socket != NULL) { + ExReleaseSpinLockExclusive(&device->SpinLock, kirql); + + status = STATUS_ALREADY_INITIALIZED; + + goto done; + } + else { + ExReleaseSpinLockExclusive(&device->SpinLock, kirql); + + POVPN_MP_START_VPN addrIn = NULL; + GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveInputBuffer(request, sizeof(OVPN_MP_START_VPN), (PVOID*)&addrIn, NULL)); + + PWSK_SOCKET socket = NULL; + POVPN_DRIVER driver = OvpnGetDriverContext(WdfGetDriver()); + + // Bind to the address provided + status = OvpnSocketInit(&driver->WskProviderNpi, &driver->WskRegistration, + addrIn->ListenAddress.Addr4.sin_family, false, + (PSOCKADDR)&addrIn->ListenAddress, NULL, + 0, device, &socket); + if (!NT_SUCCESS(status)) { + LOG_ERROR("Socket create failed", TraceLoggingValue((UINT32)status), + TraceLoggingHexUInt32(*(UINT32*)(&addrIn->ListenAddress.Addr4.sin_addr), "addr")); + goto done; + } + + kirql = ExAcquireSpinLockExclusive(&device->SpinLock); + device->Socket.Socket = socket; + ExReleaseSpinLockExclusive(&device->SpinLock, kirql); + + // we might bind the socket to port 0 and we want to get actual port back to userspace + POVPN_MP_START_VPN addrOut = NULL; + GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveOutputBuffer(request, sizeof(OVPN_MP_START_VPN), (PVOID*)&addrOut, NULL)); + RtlCopyMemory(addrOut, addrIn, sizeof(OVPN_MP_START_VPN)); + *bytesReturned = sizeof(OVPN_MP_START_VPN); + } + + OvpnAdapterSetLinkState(OvpnGetAdapterContext(device->Adapter), MediaConnectStateConnected); + +done: + LOG_EXIT(); + + return status; +} + EVT_WDF_IO_QUEUE_IO_DEVICE_CONTROL OvpnEvtIoDeviceControl; _Use_decl_annotations_ @@ -397,6 +460,10 @@ OvpnEvtIoDeviceControl(WDFQUEUE queue, WDFREQUEST request, size_t outputBufferLe ExReleaseSpinLockExclusive(&device->SpinLock, kirql); break; + case OVPN_IOCTL_MP_START_VPN: + status = OvpnMPStartVPN(device, request, &bytesReturned); + break; + default: LOG_WARN("Unknown ", TraceLoggingValue(ioControlCode, "ioControlCode")); status = STATUS_INVALID_DEVICE_REQUEST; diff --git a/socket.cpp b/socket.cpp index 2fcf2d7..70987b9 100644 --- a/socket.cpp +++ b/socket.cpp @@ -522,13 +522,20 @@ OvpnSocketInit(WSK_PROVIDER_NPI* wskProviderNpi, WSK_REGISTRATION* wskRegistrati return datagramDispatch->WskBind(*socket, localAddr, 0, irp); }, [](PIRP) {})); - // set remote - PWSK_PROVIDER_BASIC_DISPATCH basicDispatch = (PWSK_PROVIDER_BASIC_DISPATCH)(*socket)->Dispatch; - - GOTO_IF_NOT_NT_SUCCESS(error, status, OvpnSocketSyncOp("SetRemote", [basicDispatch, socket, remoteAddrSize, remoteAddr](PIRP irp) { - return basicDispatch->WskControlSocket(*socket, WskIoctl, SIO_WSK_SET_REMOTE_ADDRESS, 0, remoteAddrSize, remoteAddr, 0, NULL, NULL, irp); + // get the locally bound address for the socket + GOTO_IF_NOT_NT_SUCCESS(error, status, OvpnSocketSyncOp("GetLocalAddr", [datagramDispatch, socket, localAddr](PIRP irp) { + return datagramDispatch->WskGetLocalAddress(*socket, localAddr, irp); }, [](PIRP) {})); + if (remoteAddr != NULL) { + // set remote + PWSK_PROVIDER_BASIC_DISPATCH basicDispatch = (PWSK_PROVIDER_BASIC_DISPATCH)(*socket)->Dispatch; + + GOTO_IF_NOT_NT_SUCCESS(error, status, OvpnSocketSyncOp("SetRemote", [basicDispatch, socket, remoteAddrSize, remoteAddr](PIRP irp) { + return basicDispatch->WskControlSocket(*socket, WskIoctl, SIO_WSK_SET_REMOTE_ADDRESS, 0, remoteAddrSize, remoteAddr, 0, NULL, NULL, irp); + }, [](PIRP) {})); + } + // enable ReceiveFrom event eventCallbackControl.NpiId = &NPI_WSK_INTERFACE_ID; eventCallbackControl.EventMask = WSK_EVENT_RECEIVE_FROM; diff --git a/socket.h b/socket.h index b6dee5a..f1df149 100644 --- a/socket.h +++ b/socket.h @@ -60,7 +60,7 @@ _Must_inspect_result_ _IRQL_requires_(PASSIVE_LEVEL) NTSTATUS OvpnSocketInit(_In_ WSK_PROVIDER_NPI* wskProviderNpi, _In_ WSK_REGISTRATION* wskRegistration, ADDRESS_FAMILY addrFamily, - BOOLEAN tcp, _In_ PSOCKADDR localAddr, _In_ PSOCKADDR remoteAddr, SIZE_T remoteAddrSize, + BOOLEAN tcp, _In_ PSOCKADDR localAddr, _In_opt_ PSOCKADDR remoteAddr, SIZE_T remoteAddrSize, _In_ PVOID deviceContext, _Out_ PWSK_SOCKET* socket); _Must_inspect_result_ diff --git a/uapi/ovpn-dco.h b/uapi/ovpn-dco.h index 8163b8b..095e008 100644 --- a/uapi/ovpn-dco.h +++ b/uapi/ovpn-dco.h @@ -123,6 +123,13 @@ typedef struct _OVPN_SET_MODE { OVPN_MODE Mode; } OVPN_SET_MODE, * POVPN_SET_MODE; +typedef struct _OVPN_MP_START_VPN { + union { + SOCKADDR_IN Addr4; + SOCKADDR_IN6 Addr6; + } ListenAddress; +} OVPN_MP_START_VPN, * POVPN_MP_START_VPN; + #define OVPN_IOCTL_NEW_PEER CTL_CODE(FILE_DEVICE_UNKNOWN, 1, METHOD_BUFFERED, FILE_ANY_ACCESS) #define OVPN_IOCTL_GET_STATS CTL_CODE(FILE_DEVICE_UNKNOWN, 2, METHOD_BUFFERED, FILE_ANY_ACCESS) #define OVPN_IOCTL_NEW_KEY CTL_CODE(FILE_DEVICE_UNKNOWN, 3, METHOD_BUFFERED, FILE_ANY_ACCESS) @@ -133,3 +140,5 @@ typedef struct _OVPN_SET_MODE { #define OVPN_IOCTL_GET_VERSION CTL_CODE(FILE_DEVICE_UNKNOWN, 8, METHOD_BUFFERED, FILE_ANY_ACCESS) #define OVPN_IOCTL_NEW_KEY_V2 CTL_CODE(FILE_DEVICE_UNKNOWN, 9, METHOD_BUFFERED, FILE_ANY_ACCESS) #define OVPN_IOCTL_SET_MODE CTL_CODE(FILE_DEVICE_UNKNOWN, 10, METHOD_BUFFERED, FILE_ANY_ACCESS) + +#define OVPN_IOCTL_MP_START_VPN CTL_CODE(FILE_DEVICE_UNKNOWN, 11, METHOD_BUFFERED, FILE_ANY_ACCESS)