diff --git a/PropertySheet.props b/PropertySheet.props index f63952e..6bcce73 100644 --- a/PropertySheet.props +++ b/PropertySheet.props @@ -3,7 +3,7 @@ 2 - 0 + 2 1 diff --git a/peer.cpp b/peer.cpp index a7cdd07..3314fee 100644 --- a/peer.cpp +++ b/peer.cpp @@ -43,8 +43,7 @@ VOID OvpnPeerCtxFree(OvpnPeerContext* peer) { OvpnCryptoUninit(&peer->CryptoContext); - OvpnTimerDestroy(&peer->KeepaliveXmitTimer); - OvpnTimerDestroy(&peer->KeepaliveRecvTimer); + OvpnTimerDestroy(&peer->Timer); ExFreePoolWithTag(peer, 'ovpn'); } @@ -147,6 +146,8 @@ OvpnPeerNew(POVPN_DEVICE device, WDFREQUEST request) OvpnPeerZeroStats(&device->Stats); + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnTimerCreate(device->WdfDevice, peerCtx, &peerCtx->Timer)); + if (proto_tcp) { LOG_IF_NOT_NT_SUCCESS(status = WdfRequestForwardToIoQueue(request, device->PendingNewPeerQueue)); // start async connect @@ -229,29 +230,15 @@ NTSTATUS OvpnPeerSet(POVPN_DEVICE device, WDFREQUEST request) if (set_peer->KeepaliveInterval != -1) { peer->KeepaliveInterval = set_peer->KeepaliveInterval; - if (peer->KeepaliveInterval > 0) { - // keepalive xmit timer, sends ping packets - GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnTimerXmitCreate(device->WdfDevice, peer, peer->KeepaliveInterval, &peer->KeepaliveXmitTimer)); - OvpnTimerReset(peer->KeepaliveXmitTimer, peer->KeepaliveInterval); - } - else { - LOG_INFO("Destroy xmit timer"); - OvpnTimerDestroy(&peer->KeepaliveXmitTimer); - } + // keepalive xmit timer, sends ping packets + OvpnTimerSetXmitInterval(peer->Timer, peer->KeepaliveInterval); } if (peer->KeepaliveTimeout != -1) { peer->KeepaliveTimeout = set_peer->KeepaliveTimeout; - if (peer->KeepaliveTimeout > 0) { - // keepalive recv timer, detects keepalive timeout - GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnTimerRecvCreate(device->WdfDevice, peer, &peer->KeepaliveRecvTimer)); - OvpnTimerReset(peer->KeepaliveRecvTimer, peer->KeepaliveTimeout); - } - else { - LOG_INFO("Destroy recv timer"); - OvpnTimerDestroy(&peer->KeepaliveRecvTimer); - } + // keepalive recv timer, detects keepalive timeout + OvpnTimerSetRecvTimeout(peer->Timer, peer->KeepaliveTimeout); } done: diff --git a/peer.h b/peer.h index 4495171..e1eb5d4 100644 --- a/peer.h +++ b/peer.h @@ -39,11 +39,8 @@ struct OvpnPeerContext // keepalive timeout in seconds LONG KeepaliveTimeout; - // timer used to send periodic ping messages to the server if no data has been sent within the past KeepaliveInterval seconds - WDFTIMER KeepaliveXmitTimer; - - // timer used to report keepalive timeout error to userspace when no data has been received for KeepaliveTimeout seconds - WDFTIMER KeepaliveRecvTimer; + // 1-sec timer which handles ping intervals and keepalive timeouts + WDFTIMER Timer; }; _Must_inspect_result_ diff --git a/socket.cpp b/socket.cpp index ef57e37..1b3ca42 100644 --- a/socket.cpp +++ b/socket.cpp @@ -203,7 +203,7 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_ return; } - OvpnTimerReset(peer->KeepaliveRecvTimer, peer->KeepaliveTimeout); + OvpnTimerResetRecv(peer->Timer); // points to the beginning of plaintext UCHAR* buf = buffer->Data + device->CryptoOverhead; diff --git a/timer.cpp b/timer.cpp index bac0e9b..2aa9fbb 100644 --- a/timer.cpp +++ b/timer.cpp @@ -37,6 +37,13 @@ static const UCHAR OvpnKeepaliveMessage[] = { // Context added to a timer's attributes typedef struct _OVPN_PEER_TIMER_CONTEXT { OvpnPeerContext* Peer; + + LARGE_INTEGER lastXmit; + LARGE_INTEGER lastRecv; + + // 0 means "not set" + LONG recvTimeout; + LONG xmitInterval; } OVPN_PEER_TIMER_CONTEXT, * POVPN_PEER_TIMER_CONTEXT; WDF_DECLARE_CONTEXT_TYPE_WITH_NAME(OVPN_PEER_TIMER_CONTEXT, OvpnGetPeerTimerContext); @@ -47,11 +54,8 @@ BOOLEAN OvpnTimerIsKeepaliveMessage(const PUCHAR buf, SIZE_T len) return RtlCompareMemory(buf, OvpnKeepaliveMessage, len) == sizeof(OvpnKeepaliveMessage); } -_Function_class_(EVT_WDF_TIMER) static VOID OvpnTimerXmit(WDFTIMER timer) { - LOG_ENTER(); - POVPN_DEVICE device = OvpnGetDeviceContext(WdfTimerGetParentObject(timer)); POVPN_PEER_TIMER_CONTEXT timerCtx = OvpnGetPeerTimerContext(timer); OVPN_TX_BUFFER* buffer; @@ -92,30 +96,24 @@ static VOID OvpnTimerXmit(WDFTIMER timer) OvpnTxBufferPoolPut(buffer); } ExReleaseSpinLockShared(&device->SpinLock, kiqrl); - - LOG_EXIT(); } -_Function_class_(EVT_WDF_TIMER) -static VOID OvpnTimerRecv(WDFTIMER timer) +static BOOLEAN OvpnTimerRecv(WDFTIMER timer) { - LOG_ENTER(); - - LOG_WARN("Keepalive timeout"); - POVPN_DEVICE device = OvpnGetDeviceContext(WdfTimerGetParentObject(timer)); WDFREQUEST request; NTSTATUS status = WdfIoQueueRetrieveNextRequest(device->PendingReadsQueue, &request); if (!NT_SUCCESS(status)) { LOG_WARN("No pending request for keepalive timeout notification"); + return FALSE; } else { + LOG_INFO("Notify userspace about keepalive timeout"); ULONG_PTR bytesSent = 0; WdfRequestCompleteWithInformation(request, STATUS_CONNECTION_DISCONNECTED, bytesSent); + return TRUE; } - - LOG_EXIT(); } _Use_decl_annotations_ @@ -129,7 +127,31 @@ VOID OvpnTimerDestroy(WDFTIMER* timer) } } -static NTSTATUS OvpnTimerCreate(WDFOBJECT parent, OvpnPeerContext* peer, ULONG period, PFN_WDF_TIMER func, _Inout_ WDFTIMER* timer) +_Function_class_(EVT_WDF_TIMER) +static VOID OvpnTimerTick(WDFTIMER timer) +{ + LARGE_INTEGER now; + KeQuerySystemTime(&now); + + POVPN_PEER_TIMER_CONTEXT timerCtx = OvpnGetPeerTimerContext(timer); + if ((timerCtx->xmitInterval > 0) && (((now.QuadPart - timerCtx->lastXmit.QuadPart) / WDF_TIMEOUT_TO_SEC) > timerCtx->xmitInterval)) + { + OvpnTimerXmit(timer); + timerCtx->lastXmit = now; + } + + if ((timerCtx->recvTimeout > 0) && (((now.QuadPart - timerCtx->lastRecv.QuadPart) / WDF_TIMEOUT_TO_SEC) > timerCtx->recvTimeout)) + { + // have we have completed pending read request? + if (OvpnTimerRecv(timer)) + { + timerCtx->recvTimeout = 0; // one-off timer + } + } +} + +_Use_decl_annotations_ +NTSTATUS OvpnTimerCreate(WDFOBJECT parent, OvpnPeerContext* peer, _Inout_ WDFTIMER* timer) { LOG_ENTER(); @@ -141,8 +163,9 @@ static NTSTATUS OvpnTimerCreate(WDFOBJECT parent, OvpnPeerContext* peer, ULONG p } WDF_TIMER_CONFIG timerConfig; - WDF_TIMER_CONFIG_INIT(&timerConfig, func); - timerConfig.Period = period * 1000; + WDF_TIMER_CONFIG_INIT(&timerConfig, OvpnTimerTick); + timerConfig.TolerableDelay = TolerableDelayUnlimited; + timerConfig.Period = 1000; WDF_OBJECT_ATTRIBUTES timerAttributes; WDF_OBJECT_ATTRIBUTES_INIT(&timerAttributes); @@ -155,36 +178,49 @@ static NTSTATUS OvpnTimerCreate(WDFOBJECT parent, OvpnPeerContext* peer, ULONG p if (NT_SUCCESS(status)) { POVPN_PEER_TIMER_CONTEXT pTimerContext = OvpnGetPeerTimerContext(*timer); pTimerContext->Peer = peer; + WdfTimerStart(*timer, WDF_REL_TIMEOUT_IN_SEC(1)); } LOG_EXIT(); return status; } -_Use_decl_annotations_ -NTSTATUS OvpnTimerXmitCreate(WDFOBJECT parent, OvpnPeerContext* peer, ULONG period, WDFTIMER* timer) +#define CHECK_TIMER_HANDLE(timer) \ + do { \ + if ((timer) == WDF_NO_HANDLE) { \ + LOG_ERROR("Timer handle is not initialized"); \ + return; \ + } \ + } while (0) + +VOID OvpnTimerSetXmitInterval(WDFTIMER timer, LONG xmitInterval) { - NTSTATUS status; - LOG_INFO("Create xmit timer", TraceLoggingValue(period, "period")); - LOG_IF_NOT_NT_SUCCESS(status = OvpnTimerCreate(parent, peer, period, OvpnTimerXmit, timer)); + CHECK_TIMER_HANDLE(timer); - return status; + POVPN_PEER_TIMER_CONTEXT timerCtx = OvpnGetPeerTimerContext(timer); + timerCtx->xmitInterval = xmitInterval; + KeQuerySystemTime(&timerCtx->lastXmit); } -_Use_decl_annotations_ -NTSTATUS OvpnTimerRecvCreate(WDFOBJECT parent, OvpnPeerContext* peer, WDFTIMER* timer) +VOID OvpnTimerSetRecvTimeout(WDFTIMER timer, LONG recvTimeout) { - NTSTATUS status; - LOG_INFO("Create recv timer"); - LOG_IF_NOT_NT_SUCCESS(status = OvpnTimerCreate(parent, peer, 0, OvpnTimerRecv, timer)); + CHECK_TIMER_HANDLE(timer); - return status; + POVPN_PEER_TIMER_CONTEXT timerCtx = OvpnGetPeerTimerContext(timer); + timerCtx->recvTimeout = recvTimeout; + KeQuerySystemTime(&timerCtx->lastRecv); } -VOID OvpnTimerReset(WDFTIMER timer, ULONG dueTime) +VOID OvpnTimerResetXmit(WDFTIMER timer) { - if (timer != WDF_NO_HANDLE) { - // if timer has already been created this will reset "due time" value to the new one - WdfTimerStart(timer, WDF_REL_TIMEOUT_IN_SEC(dueTime)); - } + CHECK_TIMER_HANDLE(timer); + + POVPN_PEER_TIMER_CONTEXT timerCtx = OvpnGetPeerTimerContext(timer); + KeQuerySystemTime(&timerCtx->lastXmit); +} + +VOID OvpnTimerResetRecv(WDFTIMER timer) +{ + POVPN_PEER_TIMER_CONTEXT timerCtx = OvpnGetPeerTimerContext(timer); + KeQuerySystemTime(&timerCtx->lastRecv); } diff --git a/timer.h b/timer.h index b9f8bfc..88ee019 100644 --- a/timer.h +++ b/timer.h @@ -26,15 +26,20 @@ #include VOID -OvpnTimerReset(WDFTIMER timer, ULONG dueTime); +OvpnTimerResetXmit(WDFTIMER timer); -_Must_inspect_result_ -NTSTATUS -OvpnTimerXmitCreate(WDFOBJECT parent, OvpnPeerContext* peer, ULONG period, _Inout_ WDFTIMER* timer); +VOID +OvpnTimerResetRecv(WDFTIMER timer); _Must_inspect_result_ NTSTATUS -OvpnTimerRecvCreate(WDFOBJECT parent, OvpnPeerContext* peer, _Inout_ WDFTIMER* timer); +OvpnTimerCreate(WDFOBJECT parent, OvpnPeerContext* peer, _Inout_ WDFTIMER* timer); + +VOID +OvpnTimerSetXmitInterval(WDFTIMER timer, LONG xmitInterval); + +VOID +OvpnTimerSetRecvTimeout(WDFTIMER timer, LONG recvTimeout); VOID OvpnTimerDestroy(_Inout_ WDFTIMER* timer); diff --git a/txqueue.cpp b/txqueue.cpp index a93a470..5325c87 100644 --- a/txqueue.cpp +++ b/txqueue.cpp @@ -128,7 +128,7 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET *tail = buffer; } - OvpnTimerReset(peer->KeepaliveXmitTimer, peer->KeepaliveInterval); + OvpnTimerResetXmit(peer->Timer); } else { OvpnTxBufferPoolPut(buffer);