diff --git a/Driver.cpp b/Driver.cpp index a10e28c..8540e5d 100644 --- a/Driver.cpp +++ b/Driver.cpp @@ -329,12 +329,12 @@ OvpnStopVPN(_In_ POVPN_DEVICE device) { LOG_ENTER(); + OvpnFlushPeers(device); + KIRQL kirql = ExAcquireSpinLockExclusive(&device->SpinLock); PWSK_SOCKET socket = device->Socket.Socket; device->Socket.Socket = NULL; - OvpnFlushPeers(device); - device->Mode = OVPN_MODE_P2P; RtlZeroMemory(&device->Socket.TcpState, sizeof(OvpnSocketTcpState)); @@ -481,12 +481,9 @@ OvpnEvtIoDeviceControl(WDFQUEUE queue, WDFREQUEST request, size_t outputBufferLe return; } - KIRQL kirql = 0; switch ((long)ioControlCode) { case OVPN_IOCTL_GET_STATS: - kirql = ExAcquireSpinLockShared(&device->SpinLock); status = OvpnPeerGetStats(device, request, &bytesReturned); - ExReleaseSpinLockShared(&device->SpinLock, kirql); break; case OVPN_IOCTL_NEW_PEER: @@ -502,27 +499,19 @@ OvpnEvtIoDeviceControl(WDFQUEUE queue, WDFREQUEST request, size_t outputBufferLe break; case OVPN_IOCTL_NEW_KEY: - kirql = ExAcquireSpinLockExclusive(&device->SpinLock); status = OvpnPeerNewKey(device, request); - ExReleaseSpinLockExclusive(&device->SpinLock, kirql); break; case OVPN_IOCTL_NEW_KEY_V2: - kirql = ExAcquireSpinLockExclusive(&device->SpinLock); status = OvpnPeerNewKeyV2(device, request); - ExReleaseSpinLockExclusive(&device->SpinLock, kirql); break; case OVPN_IOCTL_SWAP_KEYS: - kirql = ExAcquireSpinLockExclusive(&device->SpinLock); status = OvpnPeerSwapKeys(device); - ExReleaseSpinLockExclusive(&device->SpinLock, kirql); break; case OVPN_IOCTL_SET_PEER: - kirql = ExAcquireSpinLockExclusive(&device->SpinLock); status = OvpnPeerSet(device, request); - ExReleaseSpinLockExclusive(&device->SpinLock, kirql); break; case OVPN_IOCTL_GET_VERSION: @@ -530,9 +519,7 @@ OvpnEvtIoDeviceControl(WDFQUEUE queue, WDFREQUEST request, size_t outputBufferLe break; case OVPN_IOCTL_SET_MODE: - kirql = ExAcquireSpinLockExclusive(&device->SpinLock); status = OvpnSetMode(device, request); - ExReleaseSpinLockExclusive(&device->SpinLock, kirql); break; case OVPN_IOCTL_MP_START_VPN: @@ -544,9 +531,7 @@ OvpnEvtIoDeviceControl(WDFQUEUE queue, WDFREQUEST request, size_t outputBufferLe break; case OVPN_IOCTL_MP_SET_PEER: - kirql = ExAcquireSpinLockExclusive(&device->SpinLock); status = OvpnMPPeerSet(device, request); - ExReleaseSpinLockExclusive(&device->SpinLock, kirql); break; case OVPN_IOCTL_NOTIFY_EVENT: @@ -778,12 +763,15 @@ OvpnEvtDeviceAdd(WDFDRIVER wdfDriver, PWDFDEVICE_INIT deviceInit) { return status; } +_Use_decl_annotations_ NTSTATUS -OvpnAddPeerToTable(_In_ RTL_GENERIC_TABLE* table, _In_ OvpnPeerContext* peer) +OvpnAddPeerToTable(POVPN_DEVICE device, RTL_GENERIC_TABLE* table, OvpnPeerContext* peer) { NTSTATUS status; BOOLEAN newElem; + auto irql = ExAcquireSpinLockExclusive(&device->SpinLock); + RtlInsertElementGenericTable(table, (PVOID)&peer, sizeof(OvpnPeerContext*), &newElem); if (newElem) { @@ -794,116 +782,154 @@ OvpnAddPeerToTable(_In_ RTL_GENERIC_TABLE* table, _In_ OvpnPeerContext* peer) LOG_ERROR("Unable to add new peer"); status = STATUS_NO_MEMORY; } - return status; -} - -_Use_decl_annotations_ -NTSTATUS -OvpnAddPeer(POVPN_DEVICE device, OvpnPeerContext* peer) -{ - return OvpnAddPeerToTable(&device->Peers, peer); -} - -_Use_decl_annotations_ -NTSTATUS -OvpnAddPeerVpn4(POVPN_DEVICE device, OvpnPeerContext* peer) -{ - return OvpnAddPeerToTable(&device->PeersByVpn4, peer); -} + ExReleaseSpinLockExclusive(&device->SpinLock, irql); -_Use_decl_annotations_ -NTSTATUS -OvpnAddPeerVpn6(POVPN_DEVICE device, OvpnPeerContext* peer) -{ - return OvpnAddPeerToTable(&device->PeersByVpn6, peer); + return status; } _Use_decl_annotations_ VOID OvpnFlushPeers(POVPN_DEVICE device) { - OvpnCleanupPeerTable(&device->PeersByVpn6); - OvpnCleanupPeerTable(&device->PeersByVpn4); - OvpnCleanupPeerTable(&device->Peers); + OvpnCleanupPeerTable(device, &device->PeersByVpn6); + OvpnCleanupPeerTable(device, &device->PeersByVpn4); + OvpnCleanupPeerTable(device, &device->Peers); } _Use_decl_annotations_ VOID -OvpnCleanupPeerTable(RTL_GENERIC_TABLE* peers) +OvpnCleanupPeerTable(POVPN_DEVICE device, RTL_GENERIC_TABLE* peers) { + auto irql = ExAcquireSpinLockExclusive(&device->SpinLock); + while (!RtlIsGenericTableEmpty(peers)) { PVOID ptr = RtlGetElementGenericTable(peers, 0); OvpnPeerContext* peer = *(OvpnPeerContext**)ptr; RtlDeleteElementGenericTable(peers, ptr); - if (InterlockedDecrement(&peer->RefCounter) == 0) { - OvpnPeerCtxFree(peer); - } + OvpnPeerCtxRelease(peer); } + + ExReleaseSpinLockExclusive(&device->SpinLock, irql); } _Use_decl_annotations_ OvpnPeerContext* -OvpnGetFirstPeer(RTL_GENERIC_TABLE* peers) +OvpnGetFirstPeer(POVPN_DEVICE device) { - OvpnPeerContext** ptr = (OvpnPeerContext**)RtlGetElementGenericTable(peers, 0); - return ptr ? (OvpnPeerContext*)*ptr : NULL; + auto irql = ExAcquireSpinLockShared(&device->SpinLock); + + OvpnPeerContext** ptr = (OvpnPeerContext**)RtlGetElementGenericTable(&device->Peers, 0); + OvpnPeerContext* peer = ptr ? (OvpnPeerContext*)*ptr : nullptr; + + if (peer != nullptr) { + InterlockedIncrement(&peer->RefCounter); + } + + ExReleaseSpinLockShared(&device->SpinLock, irql); + + return peer; } _Use_decl_annotations_ OvpnPeerContext* OvpnFindPeer(POVPN_DEVICE device, INT32 PeerId) { + OvpnPeerContext* peer = nullptr; + OvpnPeerContext** ptr = nullptr; + + auto kirql = ExAcquireSpinLockShared(&device->SpinLock); + if (device->Mode == OVPN_MODE_P2P) { - return OvpnGetFirstPeer(&device->Peers); + ptr = (OvpnPeerContext**)RtlGetElementGenericTable(&device->Peers, 0); + } + else { + OvpnPeerContext p{}; + p.PeerId = PeerId; + + auto* pp = &p; + ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->Peers, &pp); + } + + peer = ptr ? (OvpnPeerContext*)*ptr : nullptr; + + if (peer) { + InterlockedIncrement(&peer->RefCounter); } - OvpnPeerContext p {}; - p.PeerId = PeerId; + ExReleaseSpinLockShared(&device->SpinLock, kirql); - auto* pp = &p; - OvpnPeerContext** ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->Peers, &pp); - return ptr ? (OvpnPeerContext*)*ptr : NULL; + return peer; } _Use_decl_annotations_ OvpnPeerContext* OvpnFindPeerVPN4(POVPN_DEVICE device, IN_ADDR addr) { + OvpnPeerContext* peer = nullptr; + OvpnPeerContext** ptr = nullptr; + + auto kirql = ExAcquireSpinLockShared(&device->SpinLock); + if (device->Mode == OVPN_MODE_P2P) { - return OvpnGetFirstPeer(&device->Peers); + ptr = (OvpnPeerContext**)RtlGetElementGenericTable(&device->Peers, 0); + } + else { + OvpnPeerContext p{}; + p.VpnAddrs.IPv4 = addr; + + auto* pp = &p; + ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->PeersByVpn4, &pp); } - OvpnPeerContext p{}; - p.VpnAddrs.IPv4 = addr; + peer = ptr ? (OvpnPeerContext*)*ptr : nullptr; + if (peer) { + InterlockedIncrement(&peer->RefCounter); + } + + ExReleaseSpinLockShared(&device->SpinLock, kirql); - auto* pp = &p; - OvpnPeerContext** ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->PeersByVpn4, &pp); - return ptr ? (OvpnPeerContext*)*ptr : NULL; + return peer; } _Use_decl_annotations_ OvpnPeerContext* OvpnFindPeerVPN6(POVPN_DEVICE device, IN6_ADDR addr) { + OvpnPeerContext* peer = nullptr; + OvpnPeerContext** ptr = nullptr; + + auto kirql = ExAcquireSpinLockShared(&device->SpinLock); + if (device->Mode == OVPN_MODE_P2P) { - return OvpnGetFirstPeer(&device->Peers); + ptr = (OvpnPeerContext**)RtlGetElementGenericTable(&device->Peers, 0); + } + else { + OvpnPeerContext p{}; + p.VpnAddrs.IPv6 = addr; + + auto* pp = &p; + ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->PeersByVpn6, &pp); + } + + peer = ptr ? (OvpnPeerContext*)*ptr : nullptr; + if (peer) { + InterlockedIncrement(&peer->RefCounter); } - OvpnPeerContext p{}; - p.VpnAddrs.IPv6 = addr; + ExReleaseSpinLockShared(&device->SpinLock, kirql); - auto* pp = &p; - OvpnPeerContext** ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->PeersByVpn6, &pp); - return ptr ? (OvpnPeerContext*)*ptr : NULL; + return peer; } VOID -OvpnDeletePeerFromTable(RTL_GENERIC_TABLE *table, OvpnPeerContext *peer, char* tableName) +OvpnDeletePeerFromTable(POVPN_DEVICE device, RTL_GENERIC_TABLE *table, OvpnPeerContext *peer, char* tableName) { auto peerId = peer->PeerId; auto pp = &peer; + auto kirql = ExAcquireSpinLockShared(&device->SpinLock); + if (RtlDeleteElementGenericTable(table, pp)) { LOG_INFO("Peer deleted", TraceLoggingValue(tableName, "table"), TraceLoggingValue(peerId, "peer-id")); @@ -915,31 +941,35 @@ OvpnDeletePeerFromTable(RTL_GENERIC_TABLE *table, OvpnPeerContext *peer, char* t else { LOG_INFO("Peer not found", TraceLoggingValue(tableName, "table"), TraceLoggingValue(peerId, "peer-id")); } + + ExReleaseSpinLockShared(&device->SpinLock, kirql); } +_Use_decl_annotations_ NTSTATUS OvpnDeletePeer(POVPN_DEVICE device, INT32 peerId) { NTSTATUS status = STATUS_SUCCESS; - KIRQL kirql = ExAcquireSpinLockExclusive(&device->SpinLock); - LOG_INFO("Deleting peer", TraceLoggingValue(peerId, "peer-id")); // get peer from main table - OvpnPeerContext* peerCtx = OvpnFindPeer(device, peerId); - if (peerCtx == NULL) { + OvpnPeerContext* peer = OvpnFindPeer(device, peerId); + if (peer == NULL) { status = STATUS_NOT_FOUND; LOG_WARN("Peer not found", TraceLoggingValue(peerId, "peer-id")); } else { - OvpnDeletePeerFromTable(&device->PeersByVpn4, peerCtx, "vpn4"); - OvpnDeletePeerFromTable(&device->PeersByVpn6, peerCtx, "vpn6"); + auto irql = ExAcquireSpinLockShared(&peer->SpinLock); - OvpnDeletePeerFromTable(&device->Peers, peerCtx, "peers"); - } + OvpnDeletePeerFromTable(device, &device->PeersByVpn4, peer, "vpn4"); + OvpnDeletePeerFromTable(device, &device->PeersByVpn6, peer, "vpn6"); + OvpnDeletePeerFromTable(device, &device->Peers, peer, "peers"); - ExReleaseSpinLockExclusive(&device->SpinLock, kirql); + ExReleaseSpinLockShared(&peer->SpinLock, irql); + + OvpnPeerCtxRelease(peer); + } return status; } diff --git a/Driver.h b/Driver.h index 3ac667e..881cb7e 100644 --- a/Driver.h +++ b/Driver.h @@ -108,36 +108,21 @@ typedef OVPN_DEVICE * POVPN_DEVICE; WDF_DECLARE_CONTEXT_TYPE_WITH_NAME(OVPN_DEVICE, OvpnGetDeviceContext) -static inline -BOOLEAN -OvpnHasPeers(_In_ POVPN_DEVICE device) -{ - return !RtlIsGenericTableEmpty(&device->Peers); -} - struct OvpnPeerContext; _Must_inspect_result_ NTSTATUS -OvpnAddPeer(_In_ POVPN_DEVICE device, _In_ OvpnPeerContext* PeerCtx); - -_Must_inspect_result_ -NTSTATUS -OvpnAddPeerVpn4(_In_ POVPN_DEVICE device, _In_ OvpnPeerContext* PeerCtx); - -_Must_inspect_result_ -NTSTATUS -OvpnAddPeerVpn6(_In_ POVPN_DEVICE device, _In_ OvpnPeerContext* PeerCtx); +OvpnAddPeerToTable(POVPN_DEVICE device, _In_ RTL_GENERIC_TABLE* table, _In_ OvpnPeerContext* peer); VOID OvpnFlushPeers(_In_ POVPN_DEVICE device); VOID -OvpnCleanupPeerTable(_In_ RTL_GENERIC_TABLE*); +OvpnCleanupPeerTable(_In_ POVPN_DEVICE device, _In_ RTL_GENERIC_TABLE*); _Must_inspect_result_ OvpnPeerContext* -OvpnGetFirstPeer(_In_ RTL_GENERIC_TABLE*); +OvpnGetFirstPeer(_In_ POVPN_DEVICE device); _Must_inspect_result_ OvpnPeerContext* diff --git a/peer.cpp b/peer.cpp index 6ba9a8c..fddcb7c 100644 --- a/peer.cpp +++ b/peer.cpp @@ -35,16 +35,32 @@ OvpnPeerCtxAlloc() if (peer != NULL) { RtlZeroMemory(peer, sizeof(OvpnPeerContext)); } + InterlockedIncrement(&peer->RefCounter); return peer; } +_Use_decl_annotations_ +VOID +OvpnPeerCtxRelease(OvpnPeerContext* peer) +{ + if (InterlockedDecrement(&peer->RefCounter) == 0) { + auto peerId = peer->PeerId; + OvpnPeerCtxFree(peer); + LOG_INFO("Peer freed", TraceLoggingValue(peerId, "peer-id")); + } +} + _Use_decl_annotations_ VOID OvpnPeerCtxFree(OvpnPeerContext* peer) { + auto irql = ExAcquireSpinLockExclusive(&peer->SpinLock); + OvpnCryptoUninit(&peer->CryptoContext); OvpnTimerDestroy(&peer->Timer); + ExReleaseSpinLockExclusive(&peer->SpinLock, irql); + ExFreePoolWithTag(peer, 'ovpn'); } @@ -142,17 +158,16 @@ OvpnPeerNew(POVPN_DEVICE device, WDFREQUEST request) POVPN_NEW_PEER peer = NULL; NTSTATUS status; - GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveInputBuffer(request, sizeof(OVPN_NEW_PEER), (PVOID*)&peer, nullptr)); - KIRQL kirql = ExAcquireSpinLockExclusive(&device->SpinLock); - const BOOLEAN peerExists = OvpnHasPeers(device); - ExReleaseSpinLockExclusive(&device->SpinLock, kirql); - if (peerExists) { + auto peerCtx = OvpnGetFirstPeer(device); + if (peerCtx != nullptr) { LOG_WARN("Peer already exists"); status = STATUS_OBJECTID_EXISTS; goto done; } + GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveInputBuffer(request, sizeof(OVPN_NEW_PEER), (PVOID*)&peer, nullptr)); + if ((peer->Remote.Addr4.sin_family != AF_INET) && (peer->Remote.Addr4.sin_family != AF_INET6)) { status = STATUS_INVALID_DEVICE_REQUEST; @@ -165,7 +180,7 @@ OvpnPeerNew(POVPN_DEVICE device, WDFREQUEST request) BOOLEAN proto_tcp = peer->Proto == OVPN_PROTO_TCP; SIZE_T remoteAddrSize = peer->Remote.Addr4.sin_family == AF_INET ? sizeof(peer->Remote.Addr4) : sizeof(peer->Remote.Addr6); - OvpnPeerContext* peerCtx = OvpnPeerCtxAlloc(); + peerCtx = OvpnPeerCtxAlloc(); if (peerCtx == NULL) { status = STATUS_NO_MEMORY; goto done; @@ -185,33 +200,28 @@ OvpnPeerNew(POVPN_DEVICE device, WDFREQUEST request) (PSOCKADDR)&peer->Remote, remoteAddrSize, device, &socket)); - kirql = ExAcquireSpinLockExclusive(&device->SpinLock); + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnAddPeerToTable(device, &device->Peers, peerCtx)); - LOG_IF_NOT_NT_SUCCESS(status = OvpnAddPeer(device, peerCtx)); - if (status != STATUS_SUCCESS) { - ExReleaseSpinLockExclusive(&device->SpinLock, kirql); - OvpnPeerCtxFree(peerCtx); - LOG_IF_NOT_NT_SUCCESS(OvpnSocketClose(socket)); - } - else { - device->Socket.Socket = socket; - device->Socket.Tcp = proto_tcp; - RtlZeroMemory(&device->Socket.TcpState, sizeof(OvpnSocketTcpState)); - RtlZeroMemory(&device->Socket.UdpState, sizeof(OvpnSocketUdpState)); - ExReleaseSpinLockExclusive(&device->SpinLock, kirql); + device->Socket.Socket = socket; + device->Socket.Tcp = proto_tcp; + RtlZeroMemory(&device->Socket.TcpState, sizeof(OvpnSocketTcpState)); + RtlZeroMemory(&device->Socket.UdpState, sizeof(OvpnSocketUdpState)); - OvpnPeerZeroStats(&device->Stats); + OvpnPeerZeroStats(&device->Stats); - GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnTimerCreate(device->WdfDevice, peerCtx, &peerCtx->Timer)); + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnTimerCreate(device->WdfDevice, peerCtx, &peerCtx->Timer)); - if (proto_tcp) { - LOG_IF_NOT_NT_SUCCESS(status = WdfRequestForwardToIoQueue(request, device->PendingNewPeerQueue)); - // start async connect - status = OvpnSocketTcpConnect(socket, device, (PSOCKADDR)&peer->Remote); - } + if (proto_tcp) { + LOG_IF_NOT_NT_SUCCESS(status = WdfRequestForwardToIoQueue(request, device->PendingNewPeerQueue)); + // start async connect + status = OvpnSocketTcpConnect(socket, device, (PSOCKADDR)&peer->Remote); } done: + if (peerCtx != nullptr) { + OvpnPeerCtxRelease(peerCtx); + } + LOG_EXIT(); return status; @@ -228,14 +238,13 @@ OvpnMPPeerNew(POVPN_DEVICE device, WDFREQUEST request) NTSTATUS status = STATUS_SUCCESS; POVPN_MP_NEW_PEER peer; + OvpnPeerContext* peerCtx = nullptr; GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveInputBuffer(request, sizeof(OVPN_MP_NEW_PEER), (PVOID*)&peer, nullptr)); // check if we already have a peer with the same peer-id - KIRQL kirql = ExAcquireSpinLockExclusive(&device->SpinLock); - OvpnPeerContext* peerCtx = OvpnFindPeer(device, peer->PeerId); - ExReleaseSpinLockExclusive(&device->SpinLock, kirql); - if (peerCtx != NULL) { + peerCtx = OvpnFindPeer(device, peer->PeerId); + if (peerCtx != nullptr) { status = STATUS_OBJECTID_EXISTS; goto done; } @@ -283,29 +292,23 @@ OvpnMPPeerNew(POVPN_DEVICE device, WDFREQUEST request) peerCtx->PeerId = peer->PeerId; // create peer-specific timer - LOG_IF_NOT_NT_SUCCESS(status = OvpnTimerCreate(device->WdfDevice, peerCtx, &peerCtx->Timer)); - if (status != STATUS_SUCCESS) { - OvpnPeerCtxFree(peerCtx); - goto done; - } - - kirql = ExAcquireSpinLockExclusive(&device->SpinLock); - LOG_IF_NOT_NT_SUCCESS(status = OvpnAddPeer(device, peerCtx)); - if (status == STATUS_SUCCESS) { - if (peer->VpnAddr4.S_un.S_addr != INADDR_ANY) { - LOG_IF_NOT_NT_SUCCESS(status = OvpnAddPeerVpn4(device, peerCtx)); - } + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnTimerCreate(device->WdfDevice, peerCtx, &peerCtx->Timer)); - if (RtlCompareMemory(&peer->VpnAddr6, &ovpn_in6addr_any, sizeof(IN6_ADDR)) != sizeof(IN6_ADDR)) { - LOG_IF_NOT_NT_SUCCESS(status = OvpnAddPeerVpn6(device, peerCtx)); - } + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnAddPeerToTable(device, &device->Peers, peerCtx)); + + if (peer->VpnAddr4.S_un.S_addr != INADDR_ANY) { + LOG_IF_NOT_NT_SUCCESS(status = OvpnAddPeerToTable(device, &device->PeersByVpn4, peerCtx)); } - else { - OvpnPeerCtxFree(peerCtx); + + if (RtlCompareMemory(&peer->VpnAddr6, &ovpn_in6addr_any, sizeof(IN6_ADDR)) != sizeof(IN6_ADDR)) { + LOG_IF_NOT_NT_SUCCESS(status = OvpnAddPeerToTable(device, &device->PeersByVpn6, peerCtx)); } - ExReleaseSpinLockExclusive(&device->SpinLock, kirql); done: + if (peerCtx != nullptr) { + OvpnPeerCtxRelease(peerCtx); + } + LOG_EXIT(); return status; @@ -313,6 +316,8 @@ OvpnMPPeerNew(POVPN_DEVICE device, WDFREQUEST request) VOID OvpnPeerSetDoWork(OvpnPeerContext *peer, LONG keepaliveInterval, LONG keepaliveTimeout, LONG mss) { + auto irql = ExAcquireSpinLockExclusive(&peer->SpinLock); + if (mss != -1) { peer->MSS = (UINT16)mss; } @@ -330,6 +335,8 @@ VOID OvpnPeerSetDoWork(OvpnPeerContext *peer, LONG keepaliveInterval, LONG keepa // keepalive recv timer, detects keepalive timeout OvpnTimerSetRecvTimeout(peer->Timer, peer->KeepaliveTimeout); } + + ExReleaseSpinLockExclusive(&peer->SpinLock, irql); } _Use_decl_annotations_ @@ -339,9 +346,8 @@ NTSTATUS OvpnPeerSet(POVPN_DEVICE device, WDFREQUEST request) NTSTATUS status = STATUS_SUCCESS; - OvpnPeerContext* peer = OvpnGetFirstPeer(&device->Peers); - - if (peer == NULL) { + OvpnPeerContext* peer = OvpnGetFirstPeer(device); + if (peer == nullptr) { LOG_ERROR("Peer not added"); status = STATUS_INVALID_DEVICE_REQUEST; goto done; @@ -357,6 +363,10 @@ NTSTATUS OvpnPeerSet(POVPN_DEVICE device, WDFREQUEST request) OvpnPeerSetDoWork(peer, set_peer->KeepaliveInterval, set_peer->KeepaliveTimeout, set_peer->MSS); done: + if (peer != nullptr) { + OvpnPeerCtxRelease(peer); + } + LOG_EXIT(); return status; } @@ -368,6 +378,8 @@ NTSTATUS OvpnMPPeerSet(POVPN_DEVICE device, WDFREQUEST request) NTSTATUS status = STATUS_SUCCESS; + OvpnPeerContext* peer = nullptr; + POVPN_MP_SET_PEER set_peer = NULL; GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveInputBuffer(request, sizeof(OVPN_MP_SET_PEER), (PVOID*)&set_peer, nullptr)); @@ -376,8 +388,8 @@ NTSTATUS OvpnMPPeerSet(POVPN_DEVICE device, WDFREQUEST request) TraceLoggingValue(set_peer->KeepaliveTimeout, "timeout"), TraceLoggingValue(set_peer->MSS, "MSS")); - OvpnPeerContext* peer = OvpnFindPeer(device, set_peer->PeerId); - if (peer == NULL) { + peer = OvpnFindPeer(device, set_peer->PeerId); + if (peer == nullptr) { LOG_ERROR("Peer not found", TraceLoggingValue(set_peer->PeerId, "peer-id")); status = STATUS_INVALID_DEVICE_REQUEST; goto done; @@ -386,6 +398,10 @@ NTSTATUS OvpnMPPeerSet(POVPN_DEVICE device, WDFREQUEST request) OvpnPeerSetDoWork(peer, set_peer->KeepaliveInterval, set_peer->KeepaliveTimeout, set_peer->MSS); done: + if (peer != nullptr) { + OvpnPeerCtxRelease(peer); + } + LOG_EXIT(); return status; } @@ -396,13 +412,6 @@ OvpnPeerGetStats(POVPN_DEVICE device, WDFREQUEST request, ULONG_PTR* bytesReturn { NTSTATUS status = STATUS_SUCCESS; - OvpnPeerContext* peer = OvpnGetFirstPeer(&device->Peers); - if (peer == NULL) { - LOG_ERROR("Peer not added"); - status = STATUS_INVALID_DEVICE_REQUEST; - goto done; - } - POVPN_STATS stats = NULL; GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveOutputBuffer(request, sizeof(OVPN_STATS), (PVOID*)&stats, NULL)); @@ -433,7 +442,8 @@ OvpnPeerStartVPN(POVPN_DEVICE device) NTSTATUS status = STATUS_SUCCESS; - if (!OvpnHasPeers(device)) { + auto peer = OvpnGetFirstPeer(device); + if (peer == nullptr) { LOG_ERROR("Peer not added"); status = STATUS_INVALID_DEVICE_REQUEST; goto done; @@ -442,6 +452,10 @@ OvpnPeerStartVPN(POVPN_DEVICE device) OvpnAdapterSetLinkState(OvpnGetAdapterContext(device->Adapter), MediaConnectStateConnected); done: + if (peer != nullptr) { + OvpnPeerCtxRelease(peer); + } + LOG_EXIT(); return status; @@ -482,20 +496,15 @@ OvpnPeerNewKey(POVPN_DEVICE device, WDFREQUEST request) POVPN_CRYPTO_DATA cryptoData = NULL; OVPN_CRYPTO_DATA_V2 cryptoDataV2{}; - - if (!OvpnHasPeers(device)) { - LOG_ERROR("Peer not added"); - status = STATUS_INVALID_DEVICE_REQUEST; - goto done; - } + OvpnPeerContext* peer = nullptr; GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveInputBuffer(request, sizeof(OVPN_CRYPTO_DATA), (PVOID*)&cryptoData, nullptr)); BCRYPT_ALG_HANDLE algHandle = NULL; GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnPeerGetAlgHandle(device, cryptoData->CipherAlg, algHandle)); - OvpnPeerContext* peer = OvpnFindPeer(device, cryptoData->PeerId); - if (peer == NULL) { + peer = OvpnFindPeer(device, cryptoData->PeerId); + if (peer == nullptr) { status = STATUS_OBJECTID_NOT_FOUND; goto done; } @@ -504,6 +513,10 @@ OvpnPeerNewKey(POVPN_DEVICE device, WDFREQUEST request) GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoNewKey(&peer->CryptoContext, &cryptoDataV2, algHandle)); done: + if (peer != nullptr) { + OvpnPeerCtxRelease(peer); + } + LOG_EXIT(); return status; @@ -516,29 +529,30 @@ OvpnPeerNewKeyV2(POVPN_DEVICE device, WDFREQUEST request) LOG_ENTER(); NTSTATUS status = STATUS_SUCCESS; + OvpnPeerContext* peer = nullptr; POVPN_CRYPTO_DATA_V2 cryptoDataV2 = NULL; - - if (!OvpnHasPeers(device)) { - LOG_ERROR("Peer not added"); - status = STATUS_INVALID_DEVICE_REQUEST; - goto done; - } - GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveInputBuffer(request, sizeof(OVPN_CRYPTO_DATA_V2), (PVOID*)&cryptoDataV2, nullptr)); BCRYPT_ALG_HANDLE algHandle = NULL; GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnPeerGetAlgHandle(device, cryptoDataV2->V1.CipherAlg, algHandle)); - OvpnPeerContext* peer = OvpnFindPeer(device, cryptoDataV2->V1.PeerId); - if (peer == NULL) { + peer = OvpnFindPeer(device, cryptoDataV2->V1.PeerId); + + if (peer == nullptr) { status = STATUS_OBJECTID_NOT_FOUND; goto done; } - GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoNewKey(&peer->CryptoContext, cryptoDataV2, algHandle)); + KIRQL irql = ExAcquireSpinLockExclusive(&peer->SpinLock); + LOG_IF_NOT_NT_SUCCESS(status = OvpnCryptoNewKey(&peer->CryptoContext, cryptoDataV2, algHandle)); + ExReleaseSpinLockExclusive(&peer->SpinLock, irql); done: + if (peer != nullptr) { + OvpnPeerCtxRelease(peer); + } + LOG_EXIT(); return status; @@ -552,22 +566,22 @@ OvpnPeerSwapKeys(POVPN_DEVICE device) NTSTATUS status = STATUS_SUCCESS; - if (!OvpnHasPeers(device)) { - LOG_ERROR("Peer not added"); - status = STATUS_INVALID_DEVICE_REQUEST; - goto done; - } - - OvpnPeerContext* peer = OvpnGetFirstPeer(&device->Peers); - if (peer == NULL) { + OvpnPeerContext* peer = OvpnGetFirstPeer(device); + if (peer == nullptr) { LOG_ERROR("Peer not found"); status = STATUS_INVALID_DEVICE_REQUEST; goto done; } + KIRQL irql = ExAcquireSpinLockExclusive(&peer->SpinLock); OvpnCryptoSwapKeys(&peer->CryptoContext); + ExReleaseSpinLockExclusive(&peer->SpinLock, irql); done: + if (peer != nullptr) { + OvpnPeerCtxRelease(peer); + } + LOG_EXIT(); return status; diff --git a/peer.h b/peer.h index 00a1585..800649a 100644 --- a/peer.h +++ b/peer.h @@ -29,6 +29,8 @@ struct OvpnPeerContext { + EX_SPIN_LOCK SpinLock; + OvpnCryptoContext CryptoContext; INT32 PeerId; @@ -72,6 +74,9 @@ OvpnPeerCtxAlloc(); VOID OvpnPeerCtxFree(_In_ OvpnPeerContext*); +VOID +OvpnPeerCtxRelease(_In_ OvpnPeerContext*); + RTL_GENERIC_ALLOCATE_ROUTINE OvpnPeerAllocateRoutine; RTL_GENERIC_FREE_ROUTINE OvpnPeerFreeRoutine; RTL_GENERIC_COMPARE_ROUTINE OvpnPeerCompareByPeerIdRoutine; diff --git a/socket.cpp b/socket.cpp index 5281207..9c2f515 100644 --- a/socket.cpp +++ b/socket.cpp @@ -95,7 +95,6 @@ OvpnSocketSyncOp(_In_z_ CHAR* opName, OP op, SUCCESS success) } static -_Requires_shared_lock_held_(device->SpinLock) VOID OvpnSocketControlPacketReceived(_In_ POVPN_DEVICE device, _In_reads_(len) PUCHAR buf, SIZE_T len, _In_opt_ PSOCKADDR remote) { @@ -188,13 +187,12 @@ OvpnSocketControlPacketReceived(_In_ POVPN_DEVICE device, _In_reads_(len) PUCHAR } static -_Requires_shared_lock_held_(device->SpinLock) -VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, UINT32 peerId, _In_reads_(len) PUCHAR cipherTextBuf, SIZE_T len) +VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, UINT32 peerId, _In_reads_(len) PUCHAR cipherTextBuf, SIZE_T len, BOOLEAN irqlDispatch) { InterlockedExchangeAddNoFence64(&device->Stats.TransportBytesReceived, len); OvpnPeerContext* peer = OvpnFindPeer(device, peerId); - if (peer == NULL) { + if (peer == nullptr) { LOG_WARN("Peer not found", TraceLoggingValue(peerId, "peerId")); InterlockedIncrementNoFence(&device->Stats.LostInDataPackets); return; @@ -207,9 +205,20 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, UINT32 pee if (!NT_SUCCESS(status)) { LOG_ERROR("RxBufferPool exhausted"); InterlockedIncrementNoFence(&device->Stats.LostInDataPackets); + OvpnPeerCtxRelease(peer); return; } + // If we're at dispatch level, we can use a small optimization and use function + // which is not calling KeRaiseIRQL to raise the IRQL to DISPATCH_LEVEL before attempting to acquire the lock + KIRQL kirql = 0; + if (irqlDispatch) { + ExAcquireSpinLockSharedAtDpcLevel(&peer->SpinLock); + } + else { + kirql = ExAcquireSpinLockShared(&peer->SpinLock); + } + OvpnCryptoContext* cryptoContext = &peer->CryptoContext; if (cryptoContext->Decrypt) { @@ -245,63 +254,60 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, UINT32 pee // LOG_WARN("CryptoContext not yet initialized"); } - if (!NT_SUCCESS(status)) { + if (NT_SUCCESS(status)) { + OvpnTimerResetRecv(peer->Timer); + } + else { OvpnRxBufferPoolPut(buffer); - return; } - OvpnTimerResetRecv(peer->Timer); - - // ping packet? - if (OvpnTimerIsKeepaliveMessage(buffer->Data, buffer->Len)) { - LOG_INFO("Ping received", TraceLoggingValue(peer->PeerId, "peer-id")); + auto mss = peer->MSS; - // no need to inject ping packet into OS, return buffer to the pool - OvpnRxBufferPoolPut(buffer); + // don't forget to release spinlock + if (irqlDispatch) { + ExReleaseSpinLockSharedFromDpcLevel(&peer->SpinLock); } else { - if (OvpnMssIsIPv4(buffer->Data, buffer->Len)) { - OvpnMssDoIPv4(buffer->Data, buffer->Len, peer->MSS); - } else if (OvpnMssIsIPv6(buffer->Data, buffer->Len)) { - OvpnMssDoIPv6(buffer->Data, buffer->Len, peer->MSS); + ExReleaseSpinLockShared(&peer->SpinLock, kirql); + } + + OvpnPeerCtxRelease(peer); + + if (NT_SUCCESS(status)) { + // ping packet? + if (OvpnTimerIsKeepaliveMessage(buffer->Data, buffer->Len)) { + LOG_INFO("Ping received", TraceLoggingValue(peerId, "peer-id")); + + // no need to inject ping packet into OS, return buffer to the pool + OvpnRxBufferPoolPut(buffer); } + else { + if (OvpnMssIsIPv4(buffer->Data, buffer->Len)) { + OvpnMssDoIPv4(buffer->Data, buffer->Len, mss); + } + else if (OvpnMssIsIPv6(buffer->Data, buffer->Len)) { + OvpnMssDoIPv6(buffer->Data, buffer->Len, mss); + } - // enqueue plaintext buffer, it will be dequeued by NetAdapter RX datapath - OvpnBufferQueueEnqueue(device->DataRxBufferQueue, &buffer->QueueListEntry); + // enqueue plaintext buffer, it will be dequeued by NetAdapter RX datapath + OvpnBufferQueueEnqueue(device->DataRxBufferQueue, &buffer->QueueListEntry); - OvpnAdapterNotifyRx(device->Adapter); + OvpnAdapterNotifyRx(device->Adapter); + } } } VOID OvpnSocketProcessIncomingPacket(_In_ POVPN_DEVICE device, _In_reads_(packetLength) PUCHAR buf, SIZE_T packetLength, BOOLEAN irqlDispatch, _In_opt_ PSOCKADDR remoteAddr) { - // If we're at dispatch level, we can use a small optimization and use function - // which is not calling KeRaiseIRQL to raise the IRQL to DISPATCH_LEVEL before attempting to acquire the lock - KIRQL kirql = 0; - if (irqlDispatch) { - ExAcquireSpinLockSharedAtDpcLevel(&device->SpinLock); - } - else { - kirql = ExAcquireSpinLockShared(&device->SpinLock); - } - UCHAR op = RtlUlongByteSwap(*(ULONG*)(buf)) >> 24; if (OvpnCryptoOpcodeExtract(op) == OVPN_OP_DATA_V2) { UINT32 peerId = RtlUlongByteSwap(*(ULONG*)(buf)) & OVPN_PEER_ID_MASK; - OvpnSocketDataPacketReceived(device, op, peerId, buf, packetLength); + OvpnSocketDataPacketReceived(device, op, peerId, buf, packetLength, irqlDispatch); } else { OvpnSocketControlPacketReceived(device, buf, packetLength, remoteAddr); } - - // don't forget to release spinlock - if (irqlDispatch) { - ExReleaseSpinLockSharedFromDpcLevel(&device->SpinLock); - } - else { - ExReleaseSpinLockShared(&device->SpinLock, kirql); - } } _Must_inspect_result_ @@ -550,6 +556,11 @@ OvpnSocketInit(WSK_PROVIDER_NPI* wskProviderNpi, WSK_REGISTRATION* wskRegistrati }, [](PIRP) {})); // connect will be done later + + BOOLEAN tcpNoDelay = TRUE; + SIZE_T outputSizeReturned = 0; + GOTO_IF_NOT_NT_SUCCESS(error, status, connectionDispatch->Basic.WskControlSocket(*socket, WskSetOption, TCP_NODELAY, IPPROTO_TCP, + sizeof(tcpNoDelay), &tcpNoDelay, 0, NULL, &outputSizeReturned, NULL)); } else { // bind diff --git a/timer.cpp b/timer.cpp index 29034a0..e1839e5 100644 --- a/timer.cpp +++ b/timer.cpp @@ -72,7 +72,12 @@ static VOID OvpnTimerXmit(WDFTIMER timer) RtlCopyMemory(OvpnBufferPut(buffer, sizeof(OvpnKeepaliveMessage)), OvpnKeepaliveMessage, sizeof(OvpnKeepaliveMessage)); OvpnPeerContext* peer = timerCtx->Peer; - KIRQL kiqrl = ExAcquireSpinLockShared(&device->SpinLock); + + ExAcquireSpinLockSharedAtDpcLevel(&peer->SpinLock); + + auto peerId = peer->PeerId; + auto addr = peer->TransportAddrs.Remote; + OvpnCryptoContext* cryptoContext = &peer->CryptoContext; if (cryptoContext->Encrypt) { // make space to crypto overhead @@ -88,19 +93,19 @@ static VOID OvpnTimerXmit(WDFTIMER timer) status = STATUS_INVALID_DEVICE_STATE; // LOG_WARN("CryptoContext not initialized"); } + ExReleaseSpinLockSharedFromDpcLevel(&peer->SpinLock); if (NT_SUCCESS(status)) { // start async send, completion handler will return ciphertext buffer to the pool - SOCKADDR* sa = (SOCKADDR*)&(peer->TransportAddrs.Remote); + SOCKADDR* sa = (SOCKADDR*)&(addr); LOG_IF_NOT_NT_SUCCESS(status = OvpnSocketSend(&device->Socket, buffer, sa)); if (NT_SUCCESS(status)) { - LOG_INFO("Ping sent", TraceLoggingValue(peer->PeerId, "peer-id")); + LOG_INFO("Ping sent", TraceLoggingValue(peerId, "peer-id")); } } else { OvpnTxBufferPoolPut(buffer); } - ExReleaseSpinLockShared(&device->SpinLock, kiqrl); } static BOOLEAN OvpnTimerRecv(WDFTIMER timer) @@ -168,6 +173,7 @@ static VOID OvpnTimerTick(WDFTIMER timer) KeQuerySystemTime(&now); POVPN_PEER_TIMER_CONTEXT timerCtx = OvpnGetPeerTimerContext(timer); + if ((timerCtx->xmitInterval > 0) && (((now.QuadPart - timerCtx->lastXmit.QuadPart) / WDF_TIMEOUT_TO_SEC) > timerCtx->xmitInterval)) { OvpnTimerXmit(timer); diff --git a/txqueue.cpp b/txqueue.cpp index dd1b063..e1bc3a2 100644 --- a/txqueue.cpp +++ b/txqueue.cpp @@ -36,15 +36,57 @@ #include "socket.h" #include "peer.h" +template +static +VOID +OvpnTxCopyRemoteToSockaddr(T& remote, SOCKADDR* sockaddr) { + // Copy the appropriate address based on the family + if (remote.IPv4.sin_family == AF_INET) { + RtlCopyMemory(sockaddr, &remote.IPv4, sizeof(SOCKADDR_IN)); + } + else if (remote.IPv6.sin6_family == AF_INET6) { + RtlCopyMemory(sockaddr, &remote.IPv6, sizeof(SOCKADDR_IN6)); + } +} + +static +BOOLEAN +OvpnTxAreSockaddrEqual(const SOCKADDR* addr1, const SOCKADDR* addr2) { + // First, check if the address families are the same + if (addr1->sa_family != addr2->sa_family) { + return 0; // Not equal if the families are different + } + + if (addr1->sa_family == AF_INET) { + // Compare IPv4 addresses + SOCKADDR_IN* ipv4_1 = (SOCKADDR_IN*)addr1; + SOCKADDR_IN* ipv4_2 = (SOCKADDR_IN*)addr2; + return (ipv4_1->sin_addr.s_addr == ipv4_2->sin_addr.s_addr && + ipv4_1->sin_port == ipv4_2->sin_port); + } + else if (addr1->sa_family == AF_INET6) { + // Compare IPv6 addresses + SOCKADDR_IN6* ipv6_1 = (SOCKADDR_IN6*)addr1; + SOCKADDR_IN6* ipv6_2 = (SOCKADDR_IN6*)addr2; + SIZE_T result = RtlCompareMemory(&ipv6_1->sin6_addr, &ipv6_2->sin6_addr, sizeof(ipv6_1->sin6_addr)); + return (result == sizeof(ipv6_1->sin6_addr) && + ipv6_1->sin6_port == ipv6_2->sin6_port); + } + + // If the address family is neither AF_INET nor AF_INET6, return not equal + return 0; +} + _Must_inspect_result_ -_Requires_shared_lock_held_(device->SpinLock) static NTSTATUS OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET_RING_PACKET_ITERATOR *pi, - _Inout_ OVPN_TX_BUFFER **head, _Inout_ OVPN_TX_BUFFER** tail, _Inout_ SOCKADDR **headSockaddr) + _Inout_ OVPN_TX_BUFFER **head, _Inout_ OVPN_TX_BUFFER** tail, _Inout_ SOCKADDR *headSockaddr) { NET_RING_FRAGMENT_ITERATOR fi = NetPacketIteratorGetFragments(pi); + OvpnPeerContext* peer = NULL; + // get buffer into which we gather plaintext fragments and do in-place encryption OVPN_TX_BUFFER* buffer; NTSTATUS status; @@ -77,23 +119,23 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET NetFragmentIteratorAdvance(&fi); } - OvpnPeerContext* peer = NULL; - if (OvpnMssIsIPv4(buffer->Data, buffer->Len)) { auto addr = ((IPV4_HEADER*)buffer->Data)->DestinationAddress; + peer = OvpnFindPeerVPN4(device, addr); - if (peer != NULL) { + if (peer != nullptr) { OvpnMssDoIPv4(buffer->Data, buffer->Len, peer->MSS); } } else if (OvpnMssIsIPv6(buffer->Data, buffer->Len)) { auto addr = ((IPV6_HEADER*)buffer->Data)->DestinationAddress; + peer = OvpnFindPeerVPN6(device, addr); - if (peer != NULL) { + if (peer != nullptr) { OvpnMssDoIPv6(buffer->Data, buffer->Len, peer->MSS); } } - if (peer == NULL) { + if (peer == nullptr) { status = STATUS_ADDRESS_NOT_ASSOCIATED; OvpnTxBufferPoolPut(buffer); goto out; @@ -101,7 +143,11 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET InterlockedExchangeAddNoFence64(&device->Stats.TunBytesSent, buffer->Len); + auto irql = ExAcquireSpinLockShared(&peer->SpinLock); + OvpnCryptoContext* cryptoContext = &peer->CryptoContext; + auto remoteAddr = peer->TransportAddrs.Remote; + auto timer = peer->Timer; if (cryptoContext->Encrypt) { auto aeadTagEnd = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END; @@ -121,6 +167,11 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET status = STATUS_INVALID_DEVICE_STATE; // LOG_WARN("CryptoContext not initialized"); } + ExReleaseSpinLockShared(&peer->SpinLock, irql); + + if (peer != nullptr) { + OvpnPeerCtxRelease(peer); + } if (NT_SUCCESS(status)) { // start async send, this will return ciphertext buffer to the pool @@ -138,17 +189,16 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET // If this peer is different (head sockaddr != peer sockaddr) to the previous buffer chain peers, // then flush those and restart with a new buffer list. - if ((*head != NULL) && *headSockaddr != (SOCKADDR*)&(peer->TransportAddrs.Remote)) + if ((*head != NULL) && !(OvpnTxAreSockaddrEqual(headSockaddr, (const SOCKADDR*)&remoteAddr))) { - LOG_IF_NOT_NT_SUCCESS(OvpnSocketSend(&device->Socket, *head, *headSockaddr)); + LOG_IF_NOT_NT_SUCCESS(OvpnSocketSend(&device->Socket, *head, headSockaddr)); *head = buffer; *tail = buffer; - *headSockaddr = (SOCKADDR*)&(peer->TransportAddrs.Remote); - } - else { + OvpnTxCopyRemoteToSockaddr(remoteAddr, headSockaddr); + } else { if (*head == NULL) { *head = buffer; - *headSockaddr = (SOCKADDR*)&(peer->TransportAddrs.Remote); + OvpnTxCopyRemoteToSockaddr(remoteAddr, headSockaddr); } else { (*tail)->WskBufList.Next = &buffer->WskBufList; @@ -158,7 +208,7 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET } } - OvpnTimerResetXmit(peer->Timer); + OvpnTimerResetXmit(timer); } else { OvpnTxBufferPoolPut(buffer); @@ -184,11 +234,9 @@ OvpnEvtTxQueueAdvance(NETPACKETQUEUE netPacketQueue) POVPN_DEVICE device = OvpnGetDeviceContext(queue->Adapter->WdfDevice); BOOLEAN packetSent = false; - KIRQL kirql = ExAcquireSpinLockShared(&device->SpinLock); - OVPN_TX_BUFFER* txBufferHead = NULL; OVPN_TX_BUFFER* txBufferTail = NULL; - SOCKADDR* headSockaddr = NULL; + SOCKADDR headSockaddr = {0}; while (NetPacketIteratorHasAny(&pi)) { NET_PACKET* packet = NetPacketIteratorGetPacket(&pi); @@ -213,11 +261,9 @@ OvpnEvtTxQueueAdvance(NETPACKETQUEUE netPacketQueue) if (packetSent) { if (!device->Socket.Tcp) { // this will use WskSendMessages to send buffers list which we constructed before - LOG_IF_NOT_NT_SUCCESS(OvpnSocketSend(&device->Socket, txBufferHead, headSockaddr)); + LOG_IF_NOT_NT_SUCCESS(OvpnSocketSend(&device->Socket, txBufferHead, &headSockaddr)); } } - - ExReleaseSpinLockShared(&device->SpinLock, kirql); } _Use_decl_annotations_