Skip to content

Commit

Permalink
Improve locking and peer lifetime management
Browse files Browse the repository at this point in the history
The current locking mechanism (global read-write spinlock)
works fine for client (since almost all operations are read-only,
such as encrypt/decrypt) but doesn't scale well for multipeer,
since amount of state-modifying operations (add/remove peer,
modify keys) increases a lot. Instead of a single lock:

 - Use a global spinlock when accessing peers tables
 (by peer-id, by vpn4 and vpn6 addresses). Acquire for
 shared access for lookup and for exclusive access
 when modifying the tables. After lookup, increment
 peer reference counter and return a pointer to a peer context.

 - When doing read-only operation on a peer, acquire
 peer-specific spinlock for shared access. For state-modifying
 operation, acquire peer-specific spinlock for exclusive access.

 - After finishing operations on a peer, decrement peer
 reference counter. If after decrement peer reference counter
 is zero, free peer context.

This approach reduces contention and ensures that peer context
won't be deleted during peer-specific operation.

GitHub: #89

Signed-off-by: Lev Stipakov <[email protected]>
  • Loading branch information
lstipakov committed Oct 31, 2024
1 parent c0a1dcc commit 66bda50
Show file tree
Hide file tree
Showing 7 changed files with 342 additions and 245 deletions.
184 changes: 107 additions & 77 deletions Driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,12 +329,12 @@ OvpnStopVPN(_In_ POVPN_DEVICE device)
{
LOG_ENTER();

OvpnFlushPeers(device);

KIRQL kirql = ExAcquireSpinLockExclusive(&device->SpinLock);
PWSK_SOCKET socket = device->Socket.Socket;
device->Socket.Socket = NULL;

OvpnFlushPeers(device);

device->Mode = OVPN_MODE_P2P;

RtlZeroMemory(&device->Socket.TcpState, sizeof(OvpnSocketTcpState));
Expand Down Expand Up @@ -481,12 +481,9 @@ OvpnEvtIoDeviceControl(WDFQUEUE queue, WDFREQUEST request, size_t outputBufferLe
return;
}

KIRQL kirql = 0;
switch ((long)ioControlCode) {
case OVPN_IOCTL_GET_STATS:
kirql = ExAcquireSpinLockShared(&device->SpinLock);
status = OvpnPeerGetStats(device, request, &bytesReturned);
ExReleaseSpinLockShared(&device->SpinLock, kirql);
break;

case OVPN_IOCTL_NEW_PEER:
Expand All @@ -502,37 +499,27 @@ OvpnEvtIoDeviceControl(WDFQUEUE queue, WDFREQUEST request, size_t outputBufferLe
break;

case OVPN_IOCTL_NEW_KEY:
kirql = ExAcquireSpinLockExclusive(&device->SpinLock);
status = OvpnPeerNewKey(device, request);
ExReleaseSpinLockExclusive(&device->SpinLock, kirql);
break;

case OVPN_IOCTL_NEW_KEY_V2:
kirql = ExAcquireSpinLockExclusive(&device->SpinLock);
status = OvpnPeerNewKeyV2(device, request);
ExReleaseSpinLockExclusive(&device->SpinLock, kirql);
break;

case OVPN_IOCTL_SWAP_KEYS:
kirql = ExAcquireSpinLockExclusive(&device->SpinLock);
status = OvpnPeerSwapKeys(device);
ExReleaseSpinLockExclusive(&device->SpinLock, kirql);
break;

case OVPN_IOCTL_SET_PEER:
kirql = ExAcquireSpinLockExclusive(&device->SpinLock);
status = OvpnPeerSet(device, request);
ExReleaseSpinLockExclusive(&device->SpinLock, kirql);
break;

case OVPN_IOCTL_GET_VERSION:
status = OvpnGetVersion(request, &bytesReturned);
break;

case OVPN_IOCTL_SET_MODE:
kirql = ExAcquireSpinLockExclusive(&device->SpinLock);
status = OvpnSetMode(device, request);
ExReleaseSpinLockExclusive(&device->SpinLock, kirql);
break;

case OVPN_IOCTL_MP_START_VPN:
Expand All @@ -544,9 +531,7 @@ OvpnEvtIoDeviceControl(WDFQUEUE queue, WDFREQUEST request, size_t outputBufferLe
break;

case OVPN_IOCTL_MP_SET_PEER:
kirql = ExAcquireSpinLockExclusive(&device->SpinLock);
status = OvpnMPPeerSet(device, request);
ExReleaseSpinLockExclusive(&device->SpinLock, kirql);
break;

case OVPN_IOCTL_NOTIFY_EVENT:
Expand Down Expand Up @@ -778,12 +763,15 @@ OvpnEvtDeviceAdd(WDFDRIVER wdfDriver, PWDFDEVICE_INIT deviceInit) {
return status;
}

_Use_decl_annotations_
NTSTATUS
OvpnAddPeerToTable(_In_ RTL_GENERIC_TABLE* table, _In_ OvpnPeerContext* peer)
OvpnAddPeerToTable(POVPN_DEVICE device, RTL_GENERIC_TABLE* table, OvpnPeerContext* peer)
{
NTSTATUS status;
BOOLEAN newElem;

auto irql = ExAcquireSpinLockExclusive(&device->SpinLock);

RtlInsertElementGenericTable(table, (PVOID)&peer, sizeof(OvpnPeerContext*), &newElem);

if (newElem) {
Expand All @@ -794,116 +782,154 @@ OvpnAddPeerToTable(_In_ RTL_GENERIC_TABLE* table, _In_ OvpnPeerContext* peer)
LOG_ERROR("Unable to add new peer");
status = STATUS_NO_MEMORY;
}
return status;
}


_Use_decl_annotations_
NTSTATUS
OvpnAddPeer(POVPN_DEVICE device, OvpnPeerContext* peer)
{
return OvpnAddPeerToTable(&device->Peers, peer);
}

_Use_decl_annotations_
NTSTATUS
OvpnAddPeerVpn4(POVPN_DEVICE device, OvpnPeerContext* peer)
{
return OvpnAddPeerToTable(&device->PeersByVpn4, peer);
}
ExReleaseSpinLockExclusive(&device->SpinLock, irql);

_Use_decl_annotations_
NTSTATUS
OvpnAddPeerVpn6(POVPN_DEVICE device, OvpnPeerContext* peer)
{
return OvpnAddPeerToTable(&device->PeersByVpn6, peer);
return status;
}

_Use_decl_annotations_
VOID
OvpnFlushPeers(POVPN_DEVICE device) {
OvpnCleanupPeerTable(&device->PeersByVpn6);
OvpnCleanupPeerTable(&device->PeersByVpn4);
OvpnCleanupPeerTable(&device->Peers);
OvpnCleanupPeerTable(device, &device->PeersByVpn6);
OvpnCleanupPeerTable(device, &device->PeersByVpn4);
OvpnCleanupPeerTable(device, &device->Peers);
}

_Use_decl_annotations_
VOID
OvpnCleanupPeerTable(RTL_GENERIC_TABLE* peers)
OvpnCleanupPeerTable(POVPN_DEVICE device, RTL_GENERIC_TABLE* peers)
{
auto irql = ExAcquireSpinLockExclusive(&device->SpinLock);

while (!RtlIsGenericTableEmpty(peers)) {
PVOID ptr = RtlGetElementGenericTable(peers, 0);
OvpnPeerContext* peer = *(OvpnPeerContext**)ptr;
RtlDeleteElementGenericTable(peers, ptr);

if (InterlockedDecrement(&peer->RefCounter) == 0) {
OvpnPeerCtxFree(peer);
}
OvpnPeerCtxRelease(peer);
}

ExReleaseSpinLockExclusive(&device->SpinLock, irql);
}

_Use_decl_annotations_
OvpnPeerContext*
OvpnGetFirstPeer(RTL_GENERIC_TABLE* peers)
OvpnGetFirstPeer(POVPN_DEVICE device)
{
OvpnPeerContext** ptr = (OvpnPeerContext**)RtlGetElementGenericTable(peers, 0);
return ptr ? (OvpnPeerContext*)*ptr : NULL;
auto irql = ExAcquireSpinLockShared(&device->SpinLock);

OvpnPeerContext** ptr = (OvpnPeerContext**)RtlGetElementGenericTable(&device->Peers, 0);
OvpnPeerContext* peer = ptr ? (OvpnPeerContext*)*ptr : nullptr;

if (peer != nullptr) {
InterlockedIncrement(&peer->RefCounter);
}

ExReleaseSpinLockShared(&device->SpinLock, irql);

return peer;
}

_Use_decl_annotations_
OvpnPeerContext*
OvpnFindPeer(POVPN_DEVICE device, INT32 PeerId)
{
OvpnPeerContext* peer = nullptr;
OvpnPeerContext** ptr = nullptr;

auto kirql = ExAcquireSpinLockShared(&device->SpinLock);

if (device->Mode == OVPN_MODE_P2P) {
return OvpnGetFirstPeer(&device->Peers);
ptr = (OvpnPeerContext**)RtlGetElementGenericTable(&device->Peers, 0);
}
else {
OvpnPeerContext p{};
p.PeerId = PeerId;

auto* pp = &p;
ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->Peers, &pp);
}

peer = ptr ? (OvpnPeerContext*)*ptr : nullptr;

if (peer) {
InterlockedIncrement(&peer->RefCounter);
}

OvpnPeerContext p {};
p.PeerId = PeerId;
ExReleaseSpinLockShared(&device->SpinLock, kirql);

auto* pp = &p;
OvpnPeerContext** ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->Peers, &pp);
return ptr ? (OvpnPeerContext*)*ptr : NULL;
return peer;
}

_Use_decl_annotations_
OvpnPeerContext*
OvpnFindPeerVPN4(POVPN_DEVICE device, IN_ADDR addr)
{
OvpnPeerContext* peer = nullptr;
OvpnPeerContext** ptr = nullptr;

auto kirql = ExAcquireSpinLockShared(&device->SpinLock);

if (device->Mode == OVPN_MODE_P2P) {
return OvpnGetFirstPeer(&device->Peers);
ptr = (OvpnPeerContext**)RtlGetElementGenericTable(&device->Peers, 0);
}
else {
OvpnPeerContext p{};
p.VpnAddrs.IPv4 = addr;

auto* pp = &p;
ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->PeersByVpn4, &pp);
}

OvpnPeerContext p{};
p.VpnAddrs.IPv4 = addr;
peer = ptr ? (OvpnPeerContext*)*ptr : nullptr;
if (peer) {
InterlockedIncrement(&peer->RefCounter);
}

ExReleaseSpinLockShared(&device->SpinLock, kirql);

auto* pp = &p;
OvpnPeerContext** ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->PeersByVpn4, &pp);
return ptr ? (OvpnPeerContext*)*ptr : NULL;
return peer;
}

_Use_decl_annotations_
OvpnPeerContext*
OvpnFindPeerVPN6(POVPN_DEVICE device, IN6_ADDR addr)
{
OvpnPeerContext* peer = nullptr;
OvpnPeerContext** ptr = nullptr;

auto kirql = ExAcquireSpinLockShared(&device->SpinLock);

if (device->Mode == OVPN_MODE_P2P) {
return OvpnGetFirstPeer(&device->Peers);
ptr = (OvpnPeerContext**)RtlGetElementGenericTable(&device->Peers, 0);
}
else {
OvpnPeerContext p{};
p.VpnAddrs.IPv6 = addr;

auto* pp = &p;
ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->PeersByVpn6, &pp);
}

peer = ptr ? (OvpnPeerContext*)*ptr : nullptr;
if (peer) {
InterlockedIncrement(&peer->RefCounter);
}

OvpnPeerContext p{};
p.VpnAddrs.IPv6 = addr;
ExReleaseSpinLockShared(&device->SpinLock, kirql);

auto* pp = &p;
OvpnPeerContext** ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->PeersByVpn6, &pp);
return ptr ? (OvpnPeerContext*)*ptr : NULL;
return peer;
}

VOID
OvpnDeletePeerFromTable(RTL_GENERIC_TABLE *table, OvpnPeerContext *peer, char* tableName)
OvpnDeletePeerFromTable(POVPN_DEVICE device, RTL_GENERIC_TABLE *table, OvpnPeerContext *peer, char* tableName)
{
auto peerId = peer->PeerId;
auto pp = &peer;

auto kirql = ExAcquireSpinLockShared(&device->SpinLock);

if (RtlDeleteElementGenericTable(table, pp)) {
LOG_INFO("Peer deleted", TraceLoggingValue(tableName, "table"), TraceLoggingValue(peerId, "peer-id"));

Expand All @@ -915,31 +941,35 @@ OvpnDeletePeerFromTable(RTL_GENERIC_TABLE *table, OvpnPeerContext *peer, char* t
else {
LOG_INFO("Peer not found", TraceLoggingValue(tableName, "table"), TraceLoggingValue(peerId, "peer-id"));
}

ExReleaseSpinLockShared(&device->SpinLock, kirql);
}

_Use_decl_annotations_
NTSTATUS
OvpnDeletePeer(POVPN_DEVICE device, INT32 peerId)
{
NTSTATUS status = STATUS_SUCCESS;

KIRQL kirql = ExAcquireSpinLockExclusive(&device->SpinLock);

LOG_INFO("Deleting peer", TraceLoggingValue(peerId, "peer-id"));

// get peer from main table
OvpnPeerContext* peerCtx = OvpnFindPeer(device, peerId);
if (peerCtx == NULL) {
OvpnPeerContext* peer = OvpnFindPeer(device, peerId);
if (peer == NULL) {
status = STATUS_NOT_FOUND;
LOG_WARN("Peer not found", TraceLoggingValue(peerId, "peer-id"));
}
else {
OvpnDeletePeerFromTable(&device->PeersByVpn4, peerCtx, "vpn4");
OvpnDeletePeerFromTable(&device->PeersByVpn6, peerCtx, "vpn6");
auto irql = ExAcquireSpinLockShared(&peer->SpinLock);

OvpnDeletePeerFromTable(&device->Peers, peerCtx, "peers");
}
OvpnDeletePeerFromTable(device, &device->PeersByVpn4, peer, "vpn4");
OvpnDeletePeerFromTable(device, &device->PeersByVpn6, peer, "vpn6");
OvpnDeletePeerFromTable(device, &device->Peers, peer, "peers");

ExReleaseSpinLockExclusive(&device->SpinLock, kirql);
ExReleaseSpinLockShared(&peer->SpinLock, irql);

OvpnPeerCtxRelease(peer);
}

return status;
}
21 changes: 3 additions & 18 deletions Driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,36 +108,21 @@ typedef OVPN_DEVICE * POVPN_DEVICE;

WDF_DECLARE_CONTEXT_TYPE_WITH_NAME(OVPN_DEVICE, OvpnGetDeviceContext)

static inline
BOOLEAN
OvpnHasPeers(_In_ POVPN_DEVICE device)
{
return !RtlIsGenericTableEmpty(&device->Peers);
}

struct OvpnPeerContext;

_Must_inspect_result_
NTSTATUS
OvpnAddPeer(_In_ POVPN_DEVICE device, _In_ OvpnPeerContext* PeerCtx);

_Must_inspect_result_
NTSTATUS
OvpnAddPeerVpn4(_In_ POVPN_DEVICE device, _In_ OvpnPeerContext* PeerCtx);

_Must_inspect_result_
NTSTATUS
OvpnAddPeerVpn6(_In_ POVPN_DEVICE device, _In_ OvpnPeerContext* PeerCtx);
OvpnAddPeerToTable(POVPN_DEVICE device, _In_ RTL_GENERIC_TABLE* table, _In_ OvpnPeerContext* peer);

VOID
OvpnFlushPeers(_In_ POVPN_DEVICE device);

VOID
OvpnCleanupPeerTable(_In_ RTL_GENERIC_TABLE*);
OvpnCleanupPeerTable(_In_ POVPN_DEVICE device, _In_ RTL_GENERIC_TABLE*);

_Must_inspect_result_
OvpnPeerContext*
OvpnGetFirstPeer(_In_ RTL_GENERIC_TABLE*);
OvpnGetFirstPeer(_In_ POVPN_DEVICE device);

_Must_inspect_result_
OvpnPeerContext*
Expand Down
Loading

0 comments on commit 66bda50

Please sign in to comment.