From c3c6b1b0d98403d22fa6e3176078e6e96b45fd80 Mon Sep 17 00:00:00 2001 From: Lev Stipakov Date: Tue, 8 Oct 2024 16:26:53 +0300 Subject: [PATCH] Implement OVPN_IOCTL_MP_NEW_PEER Implement adding peer in MP mode. Also add this to gui client. GitHub: https://github.com/OpenVPN/ovpn-dco-win/issues/85 Signed-off-by: Lev Stipakov --- Driver.cpp | 91 ++++++++++++++++++++++++++++++-- Driver.h | 26 +++++++++ crypto.h | 1 + gui/gui.cpp | 60 +++++++++++++++++++-- peer.cpp | 136 ++++++++++++++++++++++++++++++++++++++++++++++-- peer.h | 27 ++++++++++ socket.cpp | 9 ++-- txqueue.cpp | 4 +- uapi/ovpn-dco.h | 18 +++++++ 9 files changed, 355 insertions(+), 17 deletions(-) diff --git a/Driver.cpp b/Driver.cpp index a7bc4a1..91e1c0c 100644 --- a/Driver.cpp +++ b/Driver.cpp @@ -317,6 +317,7 @@ OvpnDeviceCheckMode(OVPN_MODE mode, ULONG code) switch (code) { // those IOCTLs are for MP mode case OVPN_IOCTL_MP_START_VPN: + case OVPN_IOCTL_MP_NEW_PEER: return FALSE; } } @@ -500,6 +501,10 @@ OvpnEvtIoDeviceControl(WDFQUEUE queue, WDFREQUEST request, size_t outputBufferLe status = OvpnMPStartVPN(device, request, &bytesReturned); break; + case OVPN_IOCTL_MP_NEW_PEER: + status = OvpnMPPeerNew(device, request); + break; + default: LOG_WARN("Unknown ", TraceLoggingValue(ioControlCode, "ioControlCode")); status = STATUS_INVALID_DEVICE_REQUEST; @@ -705,8 +710,10 @@ OvpnEvtDeviceAdd(WDFDRIVER wdfDriver, PWDFDEVICE_INIT deviceInit) { GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoInitAlgHandles(&device->AesAlgHandle, &device->ChachaAlgHandle)); - // Initialize peers tree + // Initialize peers tables RtlInitializeGenericTable(&device->Peers, OvpnPeerCompareByPeerIdRoutine, OvpnPeerAllocateRoutine, OvpnPeerFreeRoutine, NULL); + RtlInitializeGenericTable(&device->PeersByVpn4, OvpnPeerCompareByVPN4Routine, OvpnPeerAllocateRoutine, OvpnPeerFreeRoutine, NULL); + RtlInitializeGenericTable(&device->PeersByVpn6, OvpnPeerCompareByVPN6Routine, OvpnPeerAllocateRoutine, OvpnPeerFreeRoutine, NULL); LOG_IF_NOT_NT_SUCCESS(status = OvpnAdapterCreate(device)); @@ -716,17 +723,17 @@ OvpnEvtDeviceAdd(WDFDRIVER wdfDriver, PWDFDEVICE_INIT deviceInit) { return status; } -_Use_decl_annotations_ NTSTATUS -OvpnAddPeer(POVPN_DEVICE device, OvpnPeerContext* peer) +OvpnAddPeerToTable(_In_ RTL_GENERIC_TABLE* table, _In_ OvpnPeerContext* peer) { NTSTATUS status; BOOLEAN newElem; - RtlInsertElementGenericTable(&device->Peers, (PVOID)&peer, sizeof(OvpnPeerContext*), &newElem); + RtlInsertElementGenericTable(table, (PVOID)&peer, sizeof(OvpnPeerContext*), &newElem); if (newElem) { status = STATUS_SUCCESS; + InterlockedIncrement(&peer->RefCounter); } else { LOG_ERROR("Unable to add new peer"); @@ -735,9 +742,33 @@ OvpnAddPeer(POVPN_DEVICE device, OvpnPeerContext* peer) return status; } + +_Use_decl_annotations_ +NTSTATUS +OvpnAddPeer(POVPN_DEVICE device, OvpnPeerContext* peer) +{ + return OvpnAddPeerToTable(&device->Peers, peer); +} + +_Use_decl_annotations_ +NTSTATUS +OvpnAddPeerVpn4(POVPN_DEVICE device, OvpnPeerContext* peer) +{ + return OvpnAddPeerToTable(&device->PeersByVpn4, peer); +} + +_Use_decl_annotations_ +NTSTATUS +OvpnAddPeerVpn6(POVPN_DEVICE device, OvpnPeerContext* peer) +{ + return OvpnAddPeerToTable(&device->PeersByVpn6, peer); +} + _Use_decl_annotations_ VOID OvpnFlushPeers(POVPN_DEVICE device) { + OvpnCleanupPeerTable(&device->PeersByVpn6); + OvpnCleanupPeerTable(&device->PeersByVpn4); OvpnCleanupPeerTable(&device->Peers); } @@ -750,7 +781,9 @@ OvpnCleanupPeerTable(RTL_GENERIC_TABLE* peers) OvpnPeerContext* peer = *(OvpnPeerContext**)ptr; RtlDeleteElementGenericTable(peers, ptr); - OvpnPeerCtxFree(peer); + if (InterlockedDecrement(&peer->RefCounter) == 0) { + OvpnPeerCtxFree(peer); + } } } @@ -761,3 +794,51 @@ OvpnGetFirstPeer(RTL_GENERIC_TABLE* peers) OvpnPeerContext** ptr = (OvpnPeerContext**)RtlGetElementGenericTable(peers, 0); return ptr ? (OvpnPeerContext*)*ptr : NULL; } + +_Use_decl_annotations_ +OvpnPeerContext* +OvpnFindPeer(POVPN_DEVICE device, INT32 PeerId) +{ + if (device->Mode == OVPN_MODE_P2P) { + return OvpnGetFirstPeer(&device->Peers); + } + + OvpnPeerContext p {}; + p.PeerId = PeerId; + + auto* pp = &p; + OvpnPeerContext** ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->Peers, &pp); + return ptr ? (OvpnPeerContext*)*ptr : NULL; +} + +_Use_decl_annotations_ +OvpnPeerContext* +OvpnFindPeerVPN4(POVPN_DEVICE device, IN_ADDR addr) +{ + if (device->Mode == OVPN_MODE_P2P) { + return OvpnGetFirstPeer(&device->Peers); + } + + OvpnPeerContext p{}; + p.VpnAddrs.IPv4 = addr; + + auto* pp = &p; + OvpnPeerContext** ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->PeersByVpn4, &pp); + return ptr ? (OvpnPeerContext*)*ptr : NULL; +} + +_Use_decl_annotations_ +OvpnPeerContext* +OvpnFindPeerVPN6(POVPN_DEVICE device, IN6_ADDR addr) +{ + if (device->Mode == OVPN_MODE_P2P) { + return OvpnGetFirstPeer(&device->Peers); + } + + OvpnPeerContext p{}; + p.VpnAddrs.IPv6 = addr; + + auto* pp = &p; + OvpnPeerContext** ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->PeersByVpn6, &pp); + return ptr ? (OvpnPeerContext*)*ptr : NULL; +} diff --git a/Driver.h b/Driver.h index 0045871..745854c 100644 --- a/Driver.h +++ b/Driver.h @@ -94,6 +94,12 @@ struct OVPN_DEVICE { _Guarded_by_(SpinLock) RTL_GENERIC_TABLE Peers; + _Guarded_by_(SpinLock) + RTL_GENERIC_TABLE PeersByVpn4; + + _Guarded_by_(SpinLock) + RTL_GENERIC_TABLE PeersByVpn6; + OVPN_MODE Mode; }; @@ -114,6 +120,14 @@ _Must_inspect_result_ NTSTATUS OvpnAddPeer(_In_ POVPN_DEVICE device, _In_ OvpnPeerContext* PeerCtx); +_Must_inspect_result_ +NTSTATUS +OvpnAddPeerVpn4(_In_ POVPN_DEVICE device, _In_ OvpnPeerContext* PeerCtx); + +_Must_inspect_result_ +NTSTATUS +OvpnAddPeerVpn6(_In_ POVPN_DEVICE device, _In_ OvpnPeerContext* PeerCtx); + VOID OvpnFlushPeers(_In_ POVPN_DEVICE device); @@ -123,3 +137,15 @@ OvpnCleanupPeerTable(_In_ RTL_GENERIC_TABLE*); _Must_inspect_result_ OvpnPeerContext* OvpnGetFirstPeer(_In_ RTL_GENERIC_TABLE*); + +_Must_inspect_result_ +OvpnPeerContext* +OvpnFindPeer(_In_ POVPN_DEVICE device, INT32 PeerId); + +_Must_inspect_result_ +OvpnPeerContext* +OvpnFindPeerVPN4(_In_ POVPN_DEVICE device, _In_ IN_ADDR addr); + +_Must_inspect_result_ +OvpnPeerContext* +OvpnFindPeerVPN6(_In_ POVPN_DEVICE device, _In_ IN6_ADDR addr); diff --git a/crypto.h b/crypto.h index 3782d42..73d286b 100644 --- a/crypto.h +++ b/crypto.h @@ -36,6 +36,7 @@ #define OVPN_OP_DATA_V2 9 #define OVPN_KEY_ID_MASK 0x07 #define OVPN_OPCODE_SHIFT 3 +#define OVPN_PEER_ID_MASK 0x00FFFFFF struct OvpnCryptoKeySlot { diff --git a/gui/gui.cpp b/gui/gui.cpp index ce23a26..3eeeb75 100644 --- a/gui/gui.cpp +++ b/gui/gui.cpp @@ -20,7 +20,8 @@ LRESULT CALLBACK WindowProcedure(HWND, UINT, WPARAM, LPARAM); HWND hMPListenAddress, hMPListenPort, hP2PLocalAddress, hP2PLocalPort, hP2PRemoteAddress, hP2PRemotePort, - hCCMessage, hCCRemoteAddress, hCCRemotePort; + hCCMessage, hCCRemoteAddress, hCCRemotePort, + hMPNewPeerLocalIP, hMPNewPeerLocalPort, hMPNewPeerRemoteIP, hMPNewPeerRemotePort, hMPNewPeerVPNIP, hMPNewPeerPeerId; HWND hLogArea; std::unordered_map buttons = { @@ -34,7 +35,8 @@ std::unordered_map buttons = { {OVPN_IOCTL_GET_VERSION, L"Get Version"}, {OVPN_IOCTL_NEW_KEY_V2, L"New Key V2"}, {OVPN_IOCTL_SET_MODE, L"Set Mode"}, - {OVPN_IOCTL_MP_START_VPN, L"MP Start VPN"} + {OVPN_IOCTL_MP_START_VPN, L"MP Start VPN"}, + {OVPN_IOCTL_MP_NEW_PEER, L"MP New Peer"} }; #define MIN_FUNCTION_CODE 1 @@ -102,7 +104,7 @@ int WINAPI WinMain(HINSTANCE hInst, HINSTANCE hPrevInst, LPSTR args, int ncmdsho // Create the Window HWND hwnd = CreateWindowW(L"myWindowClass", L"ovpn-dco-win GUI", WS_OVERLAPPEDWINDOW | WS_VISIBLE, - 100, 100, 800, 600, NULL, NULL, NULL, NULL); + 100, 100, 900, 600, NULL, NULL, NULL, NULL); HANDLE hEvRead = CreateEventW(NULL, FALSE, FALSE, NULL); ovRead.hEvent = hEvRead; @@ -324,6 +326,47 @@ void P2PStartVPN() } } +void MPNewPeer() +{ + wchar_t localIP[16], localPort[6]; + wchar_t remoteIP[16], remotePort[6], vpnIP[16]; + wchar_t peerId[6]; + + GetWindowText(hMPNewPeerLocalIP, localIP, 16); + GetWindowText(hMPNewPeerLocalPort, localPort, 6); + GetWindowText(hMPNewPeerRemoteIP, remoteIP, 16); + GetWindowText(hMPNewPeerRemotePort, remotePort, 6); + GetWindowText(hMPNewPeerVPNIP, vpnIP, 16); + GetWindowText(hMPNewPeerPeerId, peerId, 6); + + sockaddr_in localAddr = {}; + localAddr.sin_family = AF_INET; + InetPtonW(AF_INET, localIP, &(localAddr.sin_addr)); + localAddr.sin_port = htons(_wtoi(localPort)); + + sockaddr_in remoteAddr = {}; + remoteAddr.sin_family = AF_INET; + InetPtonW(AF_INET, remoteIP, &(remoteAddr.sin_addr)); + remoteAddr.sin_port = htons(_wtoi(remotePort)); + + in_addr vpnAddress; + InetPtonW(AF_INET, vpnIP, &vpnAddress); + + OVPN_MP_NEW_PEER newPeer = {}; + newPeer.Local.Addr4 = localAddr; + newPeer.Remote.Addr4 = remoteAddr; + newPeer.VpnAddr4 = vpnAddress; + newPeer.PeerId = _wtoi(peerId); + + DWORD bytesReturned; + if (!DeviceIoControl(hDev, OVPN_IOCTL_MP_NEW_PEER, &newPeer, sizeof(newPeer), NULL, 0, &bytesReturned, NULL)) { + Log("DeviceIoControl(OVPN_IOCTL_MP_NEW_PEER) failed with code ", GetLastError()); + } + else { + Log("MP peer added"); + } +} + void CreatePushButton(HWND hWnd, DWORD ioctl, int x, int y) { @@ -412,6 +455,14 @@ LRESULT CALLBACK WindowProcedure(HWND hwnd, UINT msg, WPARAM wp, LPARAM lp) hCCRemoteAddress = CreateEditBox(hwnd, L"192.168.100.1", 290, 160, 120); hCCRemotePort = CreateEditBox(hwnd, L"1194", 430, 160, 60); + CreatePushButton(hwnd, OVPN_IOCTL_MP_NEW_PEER, 10, 210); + hMPNewPeerLocalIP = CreateEditBox(hwnd, L"192.168.100.2", 150, 210, 120); + hMPNewPeerLocalPort = CreateEditBox(hwnd, L"1194", 290, 210, 60); + hMPNewPeerRemoteIP = CreateEditBox(hwnd, L"192.168.100.1", 400, 210, 120); + hMPNewPeerRemotePort = CreateEditBox(hwnd, L"1194", 540, 210, 60); + hMPNewPeerVPNIP = CreateEditBox(hwnd, L"10.8.0.6", 650, 210, 120); + hMPNewPeerPeerId = CreateEditBox(hwnd, L"1", 790, 210, 60); + SendMessage(hModes[0], BM_SETCHECK, BST_CHECKED, 0); // log area @@ -450,6 +501,9 @@ LRESULT CALLBACK WindowProcedure(HWND hwnd, UINT msg, WPARAM wp, LPARAM lp) case OVPN_IOCTL_START_VPN: P2PStartVPN(); break; + + case OVPN_IOCTL_MP_NEW_PEER: + MPNewPeer(); } } else if ((ULONG)wp == BTN_SEND_CC) { diff --git a/peer.cpp b/peer.cpp index 4f36386..57bccff 100644 --- a/peer.cpp +++ b/peer.cpp @@ -66,13 +66,54 @@ OvpnPeerFreeRoutine(_RTL_GENERIC_TABLE* table, PVOID buffer) ExFreePoolWithTag(buffer, 'ovpn'); } -RTL_GENERIC_COMPARE_RESULTS OvpnPeerCompareByPeerIdRoutine(_RTL_GENERIC_TABLE* table, PVOID first, PVOID second) +RTL_GENERIC_COMPARE_RESULTS +OvpnPeerCompareByPeerIdRoutine(_RTL_GENERIC_TABLE* table, PVOID first, PVOID second) { UNREFERENCED_PARAMETER(table); - UNREFERENCED_PARAMETER(first); - UNREFERENCED_PARAMETER(second); - return GenericEqual; + OvpnPeerContext* peer1 = *(OvpnPeerContext**)first; + OvpnPeerContext* peer2 = *(OvpnPeerContext**)second; + + if (peer1->PeerId == peer2->PeerId) + return GenericEqual; + else if (peer1->PeerId < peer2->PeerId) + return GenericLessThan; + else + return GenericGreaterThan; +} + +RTL_GENERIC_COMPARE_RESULTS +OvpnPeerCompareByVPN4Routine(_RTL_GENERIC_TABLE* table, PVOID first, PVOID second) +{ + UNREFERENCED_PARAMETER(table); + + OvpnPeerContext* peer1 = *(OvpnPeerContext**)first; + OvpnPeerContext* peer2 = *(OvpnPeerContext**)second; + + int n = memcmp(&peer1->VpnAddrs.IPv4, &peer2->VpnAddrs.IPv4, sizeof(IN_ADDR)); + if (n == 0) + return GenericEqual; + else if (n < 0) + return GenericLessThan; + else + return GenericGreaterThan; +} + +RTL_GENERIC_COMPARE_RESULTS +OvpnPeerCompareByVPN6Routine(_RTL_GENERIC_TABLE* table, PVOID first, PVOID second) +{ + UNREFERENCED_PARAMETER(table); + + OvpnPeerContext* peer1 = *(OvpnPeerContext**)first; + OvpnPeerContext* peer2 = *(OvpnPeerContext**)second; + + int n = memcmp(&peer1->VpnAddrs.IPv6, &peer2->VpnAddrs.IPv6, sizeof(IN6_ADDR)); + if (n == 0) + return GenericEqual; + else if (n < 0) + return GenericLessThan; + else + return GenericGreaterThan; } static @@ -161,6 +202,93 @@ OvpnPeerNew(POVPN_DEVICE device, WDFREQUEST request) return status; } +_Use_decl_annotations_ +NTSTATUS +OvpnMPPeerNew(POVPN_DEVICE device, WDFREQUEST request) +{ + LOG_ENTER(); + + const struct in6_addr ovpn_in6addr_any = { { 0 } }; + + NTSTATUS status = STATUS_SUCCESS; + + POVPN_MP_NEW_PEER peer; + + GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveInputBuffer(request, sizeof(OVPN_MP_NEW_PEER), (PVOID*)&peer, nullptr)); + + // check if we already have a peer with the same peer-id + KIRQL kirql = ExAcquireSpinLockExclusive(&device->SpinLock); + OvpnPeerContext* peerCtx = OvpnFindPeer(device, peer->PeerId); + ExReleaseSpinLockExclusive(&device->SpinLock, kirql); + if (peerCtx != NULL) { + status = STATUS_OBJECTID_EXISTS; + goto done; + } + + // ensure local/remote address is AF_INET or AF_INET6 + if ((peer->Local.Addr4.sin_family != AF_INET) && (peer->Local.Addr4.sin_family != AF_INET6)) + { + status = STATUS_INVALID_DEVICE_REQUEST; + LOG_ERROR("Unknown address family in peer->Local", TraceLoggingValue(peer->Local.Addr4.sin_family, "AF")); + goto done; + } + if ((peer->Remote.Addr4.sin_family != AF_INET) && (peer->Remote.Addr4.sin_family != AF_INET6)) + { + status = STATUS_INVALID_DEVICE_REQUEST; + LOG_ERROR("Unknown address family in peer->Remote", TraceLoggingValue(peer->Remote.Addr4.sin_family, "AF")); + goto done; + } + + // allocate peer + peerCtx = OvpnPeerCtxAlloc(); + if (peerCtx == NULL) { + status = STATUS_NO_MEMORY; + goto done; + } + + // assign local transport address + if (peer->Local.Addr4.sin_family == AF_INET) { + peerCtx->TransportAddrs.Local.IPv4 = peer->Local.Addr4.sin_addr; + } + else { + peerCtx->TransportAddrs.Local.IPv6 = peer->Local.Addr6.sin6_addr; + } + + // assign remote transport address + if (peer->Remote.Addr4.sin_family == AF_INET) { + peerCtx->TransportAddrs.Remote.IPv4 = peer->Remote.Addr4; + } + else { + peerCtx->TransportAddrs.Remote.IPv6 = peer->Remote.Addr6; + } + + peerCtx->VpnAddrs.IPv4 = peer->VpnAddr4; + peerCtx->VpnAddrs.IPv6 = peer->VpnAddr6; + + peerCtx->PeerId = peer->PeerId; + + kirql = ExAcquireSpinLockExclusive(&device->SpinLock); + LOG_IF_NOT_NT_SUCCESS(status = OvpnAddPeer(device, peerCtx)); + if (status == STATUS_SUCCESS) { + if (peer->VpnAddr4.S_un.S_addr != INADDR_ANY) { + LOG_IF_NOT_NT_SUCCESS(status = OvpnAddPeerVpn4(device, peerCtx)); + } + + if (RtlCompareMemory(&peer->VpnAddr6, &ovpn_in6addr_any, sizeof(IN6_ADDR)) != sizeof(IN6_ADDR)) { + LOG_IF_NOT_NT_SUCCESS(status = OvpnAddPeerVpn6(device, peerCtx)); + } + } + else { + OvpnPeerCtxFree(peerCtx); + } + ExReleaseSpinLockExclusive(&device->SpinLock, kirql); + +done: + LOG_EXIT(); + + return status; +} + _Use_decl_annotations_ NTSTATUS OvpnPeerSet(POVPN_DEVICE device, WDFREQUEST request) { diff --git a/peer.h b/peer.h index 11a80f8..907d44b 100644 --- a/peer.h +++ b/peer.h @@ -41,6 +41,26 @@ struct OvpnPeerContext // 1-sec timer which handles ping intervals and keepalive timeouts WDFTIMER Timer; + + struct { + IN_ADDR IPv4; + IN6_ADDR IPv6; + } VpnAddrs; + + struct { + union { + IN_ADDR IPv4; + IN6_ADDR IPv6; + } Local; + + union { + SOCKADDR_IN IPv4; + SOCKADDR_IN6 IPv6; + } Remote; + + } TransportAddrs; + + LONG RefCounter; }; _Must_inspect_result_ @@ -53,12 +73,19 @@ OvpnPeerCtxFree(_In_ OvpnPeerContext*); RTL_GENERIC_ALLOCATE_ROUTINE OvpnPeerAllocateRoutine; RTL_GENERIC_FREE_ROUTINE OvpnPeerFreeRoutine; RTL_GENERIC_COMPARE_ROUTINE OvpnPeerCompareByPeerIdRoutine; +RTL_GENERIC_COMPARE_ROUTINE OvpnPeerCompareByVPN4Routine; +RTL_GENERIC_COMPARE_ROUTINE OvpnPeerCompareByVPN6Routine; _Must_inspect_result_ _IRQL_requires_(PASSIVE_LEVEL) NTSTATUS OvpnPeerNew(_In_ POVPN_DEVICE device, WDFREQUEST request); +_Must_inspect_result_ +_IRQL_requires_(PASSIVE_LEVEL) +NTSTATUS +OvpnMPPeerNew(_In_ POVPN_DEVICE device, WDFREQUEST request); + _Must_inspect_result_ _Requires_exclusive_lock_held_(device->SpinLock) NTSTATUS diff --git a/socket.cpp b/socket.cpp index 533c0c2..2b8653c 100644 --- a/socket.cpp +++ b/socket.cpp @@ -189,13 +189,13 @@ OvpnSocketControlPacketReceived(_In_ POVPN_DEVICE device, _In_reads_(len) PUCHAR static _Requires_shared_lock_held_(device->SpinLock) -VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_(len) PUCHAR cipherTextBuf, SIZE_T len) +VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, UINT32 peerId, _In_reads_(len) PUCHAR cipherTextBuf, SIZE_T len) { InterlockedExchangeAddNoFence64(&device->Stats.TransportBytesReceived, len); - OvpnPeerContext* peer = OvpnGetFirstPeer(&device->Peers); + OvpnPeerContext* peer = OvpnFindPeer(device, peerId); if (peer == NULL) { - LOG_WARN("No peer"); + LOG_WARN("Peer not found", TraceLoggingValue(peerId, "peerId")); InterlockedIncrementNoFence(&device->Stats.LostInDataPackets); return; } @@ -288,7 +288,8 @@ OvpnSocketProcessIncomingPacket(_In_ POVPN_DEVICE device, _In_reads_(packetLengt UCHAR op = RtlUlongByteSwap(*(ULONG*)(buf)) >> 24; if (OvpnCryptoOpcodeExtract(op) == OVPN_OP_DATA_V2) { - OvpnSocketDataPacketReceived(device, op, buf, packetLength); + UINT32 peerId = RtlUlongByteSwap(*(ULONG*)(buf)) & OVPN_PEER_ID_MASK; + OvpnSocketDataPacketReceived(device, op, peerId, buf, packetLength); } else { OvpnSocketControlPacketReceived(device, buf, packetLength, remoteAddr); diff --git a/txqueue.cpp b/txqueue.cpp index 4f7ba30..94a86f3 100644 --- a/txqueue.cpp +++ b/txqueue.cpp @@ -77,13 +77,15 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET NetFragmentIteratorAdvance(&fi); } + OvpnPeerContext* peer = NULL; if (OvpnMssIsIPv4(buffer->Data, buffer->Len)) { OvpnMssDoIPv4(buffer->Data, buffer->Len, device->MSS); + peer = OvpnFindPeerVPN4(device, ((IPV4_HEADER*)buffer->Data)->DestinationAddress); } else if (OvpnMssIsIPv6(buffer->Data, buffer->Len)) { OvpnMssDoIPv6(buffer->Data, buffer->Len, device->MSS); + peer = OvpnFindPeerVPN6(device, ((IPV6_HEADER*)buffer->Data)->DestinationAddress); } - OvpnPeerContext* peer = OvpnGetFirstPeer(&device->Peers); if (peer == NULL) { status = STATUS_ADDRESS_NOT_ASSOCIATED; OvpnTxBufferPoolPut(buffer); diff --git a/uapi/ovpn-dco.h b/uapi/ovpn-dco.h index 095e008..60e4a53 100644 --- a/uapi/ovpn-dco.h +++ b/uapi/ovpn-dco.h @@ -47,6 +47,23 @@ typedef struct _OVPN_NEW_PEER { OVPN_PROTO Proto; } OVPN_NEW_PEER, * POVPN_NEW_PEER; +typedef struct _OVPN_MP_NEW_PEER { + union { + SOCKADDR_IN Addr4; + SOCKADDR_IN6 Addr6; + } Local; + + union { + SOCKADDR_IN Addr4; + SOCKADDR_IN6 Addr6; + } Remote; + + IN_ADDR VpnAddr4; + IN6_ADDR VpnAddr6; + + int PeerId; +} OVPN_MP_NEW_PEER, * POVPN_MP_NEW_PEER; + typedef struct _OVPN_STATS { LONG LostInControlPackets; LONG LostOutControlPackets; @@ -142,3 +159,4 @@ typedef struct _OVPN_MP_START_VPN { #define OVPN_IOCTL_SET_MODE CTL_CODE(FILE_DEVICE_UNKNOWN, 10, METHOD_BUFFERED, FILE_ANY_ACCESS) #define OVPN_IOCTL_MP_START_VPN CTL_CODE(FILE_DEVICE_UNKNOWN, 11, METHOD_BUFFERED, FILE_ANY_ACCESS) +#define OVPN_IOCTL_MP_NEW_PEER CTL_CODE(FILE_DEVICE_UNKNOWN, 12, METHOD_BUFFERED, FILE_ANY_ACCESS)