diff --git a/PropertySheet.props b/PropertySheet.props
index f63952e..6bcce73 100644
--- a/PropertySheet.props
+++ b/PropertySheet.props
@@ -3,7 +3,7 @@
2
- 0
+ 2
1
diff --git a/peer.cpp b/peer.cpp
index a7cdd07..3314fee 100644
--- a/peer.cpp
+++ b/peer.cpp
@@ -43,8 +43,7 @@ VOID
OvpnPeerCtxFree(OvpnPeerContext* peer)
{
OvpnCryptoUninit(&peer->CryptoContext);
- OvpnTimerDestroy(&peer->KeepaliveXmitTimer);
- OvpnTimerDestroy(&peer->KeepaliveRecvTimer);
+ OvpnTimerDestroy(&peer->Timer);
ExFreePoolWithTag(peer, 'ovpn');
}
@@ -147,6 +146,8 @@ OvpnPeerNew(POVPN_DEVICE device, WDFREQUEST request)
OvpnPeerZeroStats(&device->Stats);
+ GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnTimerCreate(device->WdfDevice, peerCtx, &peerCtx->Timer));
+
if (proto_tcp) {
LOG_IF_NOT_NT_SUCCESS(status = WdfRequestForwardToIoQueue(request, device->PendingNewPeerQueue));
// start async connect
@@ -229,29 +230,15 @@ NTSTATUS OvpnPeerSet(POVPN_DEVICE device, WDFREQUEST request)
if (set_peer->KeepaliveInterval != -1) {
peer->KeepaliveInterval = set_peer->KeepaliveInterval;
- if (peer->KeepaliveInterval > 0) {
- // keepalive xmit timer, sends ping packets
- GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnTimerXmitCreate(device->WdfDevice, peer, peer->KeepaliveInterval, &peer->KeepaliveXmitTimer));
- OvpnTimerReset(peer->KeepaliveXmitTimer, peer->KeepaliveInterval);
- }
- else {
- LOG_INFO("Destroy xmit timer");
- OvpnTimerDestroy(&peer->KeepaliveXmitTimer);
- }
+ // keepalive xmit timer, sends ping packets
+ OvpnTimerSetXmitInterval(peer->Timer, peer->KeepaliveInterval);
}
if (peer->KeepaliveTimeout != -1) {
peer->KeepaliveTimeout = set_peer->KeepaliveTimeout;
- if (peer->KeepaliveTimeout > 0) {
- // keepalive recv timer, detects keepalive timeout
- GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnTimerRecvCreate(device->WdfDevice, peer, &peer->KeepaliveRecvTimer));
- OvpnTimerReset(peer->KeepaliveRecvTimer, peer->KeepaliveTimeout);
- }
- else {
- LOG_INFO("Destroy recv timer");
- OvpnTimerDestroy(&peer->KeepaliveRecvTimer);
- }
+ // keepalive recv timer, detects keepalive timeout
+ OvpnTimerSetRecvTimeout(peer->Timer, peer->KeepaliveTimeout);
}
done:
diff --git a/peer.h b/peer.h
index 4495171..e1eb5d4 100644
--- a/peer.h
+++ b/peer.h
@@ -39,11 +39,8 @@ struct OvpnPeerContext
// keepalive timeout in seconds
LONG KeepaliveTimeout;
- // timer used to send periodic ping messages to the server if no data has been sent within the past KeepaliveInterval seconds
- WDFTIMER KeepaliveXmitTimer;
-
- // timer used to report keepalive timeout error to userspace when no data has been received for KeepaliveTimeout seconds
- WDFTIMER KeepaliveRecvTimer;
+ // 1-sec timer which handles ping intervals and keepalive timeouts
+ WDFTIMER Timer;
};
_Must_inspect_result_
diff --git a/socket.cpp b/socket.cpp
index ef57e37..1b3ca42 100644
--- a/socket.cpp
+++ b/socket.cpp
@@ -203,7 +203,7 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_
return;
}
- OvpnTimerReset(peer->KeepaliveRecvTimer, peer->KeepaliveTimeout);
+ OvpnTimerResetRecv(peer->Timer);
// points to the beginning of plaintext
UCHAR* buf = buffer->Data + device->CryptoOverhead;
diff --git a/timer.cpp b/timer.cpp
index bac0e9b..2aa9fbb 100644
--- a/timer.cpp
+++ b/timer.cpp
@@ -37,6 +37,13 @@ static const UCHAR OvpnKeepaliveMessage[] = {
// Context added to a timer's attributes
typedef struct _OVPN_PEER_TIMER_CONTEXT {
OvpnPeerContext* Peer;
+
+ LARGE_INTEGER lastXmit;
+ LARGE_INTEGER lastRecv;
+
+ // 0 means "not set"
+ LONG recvTimeout;
+ LONG xmitInterval;
} OVPN_PEER_TIMER_CONTEXT, * POVPN_PEER_TIMER_CONTEXT;
WDF_DECLARE_CONTEXT_TYPE_WITH_NAME(OVPN_PEER_TIMER_CONTEXT, OvpnGetPeerTimerContext);
@@ -47,11 +54,8 @@ BOOLEAN OvpnTimerIsKeepaliveMessage(const PUCHAR buf, SIZE_T len)
return RtlCompareMemory(buf, OvpnKeepaliveMessage, len) == sizeof(OvpnKeepaliveMessage);
}
-_Function_class_(EVT_WDF_TIMER)
static VOID OvpnTimerXmit(WDFTIMER timer)
{
- LOG_ENTER();
-
POVPN_DEVICE device = OvpnGetDeviceContext(WdfTimerGetParentObject(timer));
POVPN_PEER_TIMER_CONTEXT timerCtx = OvpnGetPeerTimerContext(timer);
OVPN_TX_BUFFER* buffer;
@@ -92,30 +96,24 @@ static VOID OvpnTimerXmit(WDFTIMER timer)
OvpnTxBufferPoolPut(buffer);
}
ExReleaseSpinLockShared(&device->SpinLock, kiqrl);
-
- LOG_EXIT();
}
-_Function_class_(EVT_WDF_TIMER)
-static VOID OvpnTimerRecv(WDFTIMER timer)
+static BOOLEAN OvpnTimerRecv(WDFTIMER timer)
{
- LOG_ENTER();
-
- LOG_WARN("Keepalive timeout");
-
POVPN_DEVICE device = OvpnGetDeviceContext(WdfTimerGetParentObject(timer));
WDFREQUEST request;
NTSTATUS status = WdfIoQueueRetrieveNextRequest(device->PendingReadsQueue, &request);
if (!NT_SUCCESS(status)) {
LOG_WARN("No pending request for keepalive timeout notification");
+ return FALSE;
}
else {
+ LOG_INFO("Notify userspace about keepalive timeout");
ULONG_PTR bytesSent = 0;
WdfRequestCompleteWithInformation(request, STATUS_CONNECTION_DISCONNECTED, bytesSent);
+ return TRUE;
}
-
- LOG_EXIT();
}
_Use_decl_annotations_
@@ -129,7 +127,31 @@ VOID OvpnTimerDestroy(WDFTIMER* timer)
}
}
-static NTSTATUS OvpnTimerCreate(WDFOBJECT parent, OvpnPeerContext* peer, ULONG period, PFN_WDF_TIMER func, _Inout_ WDFTIMER* timer)
+_Function_class_(EVT_WDF_TIMER)
+static VOID OvpnTimerTick(WDFTIMER timer)
+{
+ LARGE_INTEGER now;
+ KeQuerySystemTime(&now);
+
+ POVPN_PEER_TIMER_CONTEXT timerCtx = OvpnGetPeerTimerContext(timer);
+ if ((timerCtx->xmitInterval > 0) && (((now.QuadPart - timerCtx->lastXmit.QuadPart) / WDF_TIMEOUT_TO_SEC) > timerCtx->xmitInterval))
+ {
+ OvpnTimerXmit(timer);
+ timerCtx->lastXmit = now;
+ }
+
+ if ((timerCtx->recvTimeout > 0) && (((now.QuadPart - timerCtx->lastRecv.QuadPart) / WDF_TIMEOUT_TO_SEC) > timerCtx->recvTimeout))
+ {
+ // have we have completed pending read request?
+ if (OvpnTimerRecv(timer))
+ {
+ timerCtx->recvTimeout = 0; // one-off timer
+ }
+ }
+}
+
+_Use_decl_annotations_
+NTSTATUS OvpnTimerCreate(WDFOBJECT parent, OvpnPeerContext* peer, _Inout_ WDFTIMER* timer)
{
LOG_ENTER();
@@ -141,8 +163,9 @@ static NTSTATUS OvpnTimerCreate(WDFOBJECT parent, OvpnPeerContext* peer, ULONG p
}
WDF_TIMER_CONFIG timerConfig;
- WDF_TIMER_CONFIG_INIT(&timerConfig, func);
- timerConfig.Period = period * 1000;
+ WDF_TIMER_CONFIG_INIT(&timerConfig, OvpnTimerTick);
+ timerConfig.TolerableDelay = TolerableDelayUnlimited;
+ timerConfig.Period = 1000;
WDF_OBJECT_ATTRIBUTES timerAttributes;
WDF_OBJECT_ATTRIBUTES_INIT(&timerAttributes);
@@ -155,36 +178,49 @@ static NTSTATUS OvpnTimerCreate(WDFOBJECT parent, OvpnPeerContext* peer, ULONG p
if (NT_SUCCESS(status)) {
POVPN_PEER_TIMER_CONTEXT pTimerContext = OvpnGetPeerTimerContext(*timer);
pTimerContext->Peer = peer;
+ WdfTimerStart(*timer, WDF_REL_TIMEOUT_IN_SEC(1));
}
LOG_EXIT();
return status;
}
-_Use_decl_annotations_
-NTSTATUS OvpnTimerXmitCreate(WDFOBJECT parent, OvpnPeerContext* peer, ULONG period, WDFTIMER* timer)
+#define CHECK_TIMER_HANDLE(timer) \
+ do { \
+ if ((timer) == WDF_NO_HANDLE) { \
+ LOG_ERROR("Timer handle is not initialized"); \
+ return; \
+ } \
+ } while (0)
+
+VOID OvpnTimerSetXmitInterval(WDFTIMER timer, LONG xmitInterval)
{
- NTSTATUS status;
- LOG_INFO("Create xmit timer", TraceLoggingValue(period, "period"));
- LOG_IF_NOT_NT_SUCCESS(status = OvpnTimerCreate(parent, peer, period, OvpnTimerXmit, timer));
+ CHECK_TIMER_HANDLE(timer);
- return status;
+ POVPN_PEER_TIMER_CONTEXT timerCtx = OvpnGetPeerTimerContext(timer);
+ timerCtx->xmitInterval = xmitInterval;
+ KeQuerySystemTime(&timerCtx->lastXmit);
}
-_Use_decl_annotations_
-NTSTATUS OvpnTimerRecvCreate(WDFOBJECT parent, OvpnPeerContext* peer, WDFTIMER* timer)
+VOID OvpnTimerSetRecvTimeout(WDFTIMER timer, LONG recvTimeout)
{
- NTSTATUS status;
- LOG_INFO("Create recv timer");
- LOG_IF_NOT_NT_SUCCESS(status = OvpnTimerCreate(parent, peer, 0, OvpnTimerRecv, timer));
+ CHECK_TIMER_HANDLE(timer);
- return status;
+ POVPN_PEER_TIMER_CONTEXT timerCtx = OvpnGetPeerTimerContext(timer);
+ timerCtx->recvTimeout = recvTimeout;
+ KeQuerySystemTime(&timerCtx->lastRecv);
}
-VOID OvpnTimerReset(WDFTIMER timer, ULONG dueTime)
+VOID OvpnTimerResetXmit(WDFTIMER timer)
{
- if (timer != WDF_NO_HANDLE) {
- // if timer has already been created this will reset "due time" value to the new one
- WdfTimerStart(timer, WDF_REL_TIMEOUT_IN_SEC(dueTime));
- }
+ CHECK_TIMER_HANDLE(timer);
+
+ POVPN_PEER_TIMER_CONTEXT timerCtx = OvpnGetPeerTimerContext(timer);
+ KeQuerySystemTime(&timerCtx->lastXmit);
+}
+
+VOID OvpnTimerResetRecv(WDFTIMER timer)
+{
+ POVPN_PEER_TIMER_CONTEXT timerCtx = OvpnGetPeerTimerContext(timer);
+ KeQuerySystemTime(&timerCtx->lastRecv);
}
diff --git a/timer.h b/timer.h
index b9f8bfc..88ee019 100644
--- a/timer.h
+++ b/timer.h
@@ -26,15 +26,20 @@
#include
VOID
-OvpnTimerReset(WDFTIMER timer, ULONG dueTime);
+OvpnTimerResetXmit(WDFTIMER timer);
-_Must_inspect_result_
-NTSTATUS
-OvpnTimerXmitCreate(WDFOBJECT parent, OvpnPeerContext* peer, ULONG period, _Inout_ WDFTIMER* timer);
+VOID
+OvpnTimerResetRecv(WDFTIMER timer);
_Must_inspect_result_
NTSTATUS
-OvpnTimerRecvCreate(WDFOBJECT parent, OvpnPeerContext* peer, _Inout_ WDFTIMER* timer);
+OvpnTimerCreate(WDFOBJECT parent, OvpnPeerContext* peer, _Inout_ WDFTIMER* timer);
+
+VOID
+OvpnTimerSetXmitInterval(WDFTIMER timer, LONG xmitInterval);
+
+VOID
+OvpnTimerSetRecvTimeout(WDFTIMER timer, LONG recvTimeout);
VOID
OvpnTimerDestroy(_Inout_ WDFTIMER* timer);
diff --git a/txqueue.cpp b/txqueue.cpp
index a93a470..5325c87 100644
--- a/txqueue.cpp
+++ b/txqueue.cpp
@@ -128,7 +128,7 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET
*tail = buffer;
}
- OvpnTimerReset(peer->KeepaliveXmitTimer, peer->KeepaliveInterval);
+ OvpnTimerResetXmit(peer->Timer);
}
else {
OvpnTxBufferPoolPut(buffer);