diff --git a/Driver.cpp b/Driver.cpp
index 803fa62..5a6b48f 100644
--- a/Driver.cpp
+++ b/Driver.cpp
@@ -268,6 +268,12 @@ OvpnEvtIoDeviceControl(WDFQUEUE queue, WDFREQUEST request, size_t outputBufferLe
ExReleaseSpinLockExclusive(&device->SpinLock, kirql);
break;
+ case OVPN_IOCTL_NEW_KEY_V2:
+ kirql = ExAcquireSpinLockExclusive(&device->SpinLock);
+ status = OvpnPeerNewKeyV2(device, request);
+ ExReleaseSpinLockExclusive(&device->SpinLock, kirql);
+ break;
+
case OVPN_IOCTL_SWAP_KEYS:
kirql = ExAcquireSpinLockExclusive(&device->SpinLock);
status = OvpnPeerSwapKeys(device);
diff --git a/PropertySheet.props b/PropertySheet.props
index f839915..76c4a62 100644
--- a/PropertySheet.props
+++ b/PropertySheet.props
@@ -3,7 +3,7 @@
1
- 3
+ 4
0
diff --git a/bufferpool.h b/bufferpool.h
index fcde323..1ed0721 100644
--- a/bufferpool.h
+++ b/bufferpool.h
@@ -74,7 +74,6 @@ struct OVPN_RX_BUFFER
UCHAR Data[OVPN_SOCKET_RX_PACKET_BUFFER_SIZE];
};
-_Must_inspect_result_
UCHAR*
OvpnTxBufferPut(_In_ OVPN_TX_BUFFER* work, SIZE_T len);
diff --git a/crypto.cpp b/crypto.cpp
index b6a881a..18f837b 100644
--- a/crypto.cpp
+++ b/crypto.cpp
@@ -48,11 +48,14 @@ OvpnProtoOp32Compose(UINT opcode, UINT keyId, UINT opPeerId)
OVPN_CRYPTO_DECRYPT OvpnCryptoDecryptNone;
_Use_decl_annotations_
-NTSTATUS OvpnCryptoDecryptNone(OvpnCryptoKeySlot* keySlot, UCHAR* bufIn, SIZE_T len, UCHAR* bufOut)
+NTSTATUS OvpnCryptoDecryptNone(OvpnCryptoKeySlot* keySlot, UCHAR* bufIn, SIZE_T len, UCHAR* bufOut, INT32 cryptoOptions)
{
UNREFERENCED_PARAMETER(keySlot);
- if (len < NONE_CRYPTO_OVERHEAD) {
+ BOOLEAN pktId64bit = cryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID;
+ BOOLEAN cryptoOverhead = OVPN_DATA_V2_LEN + pktId64bit ? 8 : 4;
+
+ if (len < cryptoOverhead) {
LOG_WARN("Packet too short", TraceLoggingValue(len, "len"));
return STATUS_DATA_ERROR;
}
@@ -66,10 +69,11 @@ OVPN_CRYPTO_ENCRYPT OvpnCryptoEncryptNone;
_Use_decl_annotations_
NTSTATUS
-OvpnCryptoEncryptNone(OvpnCryptoKeySlot* keySlot, UCHAR* buf, SIZE_T len)
+OvpnCryptoEncryptNone(OvpnCryptoKeySlot* keySlot, UCHAR* buf, SIZE_T len, INT32 cryptoOptions)
{
UNREFERENCED_PARAMETER(keySlot);
UNREFERENCED_PARAMETER(len);
+ UNREFERENCED_PARAMETER(cryptoOptions);
// prepend with opcode, key-id and peer-id
UINT32 op = OvpnProtoOp32Compose(OVPN_OP_DATA_V2, 0, 0);
@@ -121,74 +125,116 @@ OvpnCryptoUninitAlgHandles(_In_ BCRYPT_ALG_HANDLE aesAlgHandle, BCRYPT_ALG_HANDL
static
NTSTATUS
-OvpnCryptoAEADDoWork(BOOLEAN encrypt, OvpnCryptoKeySlot* keySlot, UCHAR *bufIn, SIZE_T len, UCHAR* bufOut)
+OvpnCryptoAEADDoWork(BOOLEAN encrypt, OvpnCryptoKeySlot* keySlot, UCHAR *bufIn, SIZE_T len, UCHAR* bufOut, INT32 cryptoOptions)
{
/*
AEAD Nonce :
[Packet ID] [HMAC keying material]
- [4 bytes ] [8 bytes ]
+ [4/8 bytes] [8/4 bytes ]
[AEAD nonce total : 12 bytes ]
TLS wire protocol :
+ Packet ID is 8 bytes long with CRYPTO_OPTIONS_64BIT_PKTID.
+
[DATA_V2 opcode] [Packet ID] [AEAD Auth tag] [ciphertext]
- [4 bytes ] [4 bytes ] [16 bytes ]
+ [4 bytes ] [4/8 bytes] [16 bytes ]
+ [AEAD additional data(AD) ]
+
+ With CRYPTO_OPTIONS_AEAD_TAG_END AEAD Auth tag is placed after ciphertext:
+
+ [DATA_V2 opcode] [Packet ID] [ciphertext] [AEAD Auth tag]
+ [4 bytes ] [4/8 bytes] [16 bytes ]
[AEAD additional data(AD) ]
*/
NTSTATUS status = STATUS_SUCCESS;
- if (len < AEAD_CRYPTO_OVERHEAD) {
+ BOOLEAN pktId64bit = cryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID;
+
+ auto cryptoOverhead = OVPN_DATA_V2_LEN + AEAD_AUTH_TAG_LEN + (pktId64bit ? 8 : 4);
+
+ if (len < cryptoOverhead) {
LOG_WARN("Packet too short", TraceLoggingValue(len, "len"));
return STATUS_DATA_ERROR;
}
- UCHAR nonce[OVPN_PKTID_LEN + OVPN_NONCE_TAIL_LEN];
+ UCHAR nonce[12];
if (encrypt) {
// prepend with opcode, key-id and peer-id
UINT32 op = OvpnProtoOp32Compose(OVPN_OP_DATA_V2, keySlot->KeyId, keySlot->PeerId);
op = RtlUlongByteSwap(op);
- *(UINT32*)(bufOut) = op;
+ *reinterpret_cast(bufOut) = op;
- // calculate pktid
- UINT32 pktid;
- GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnPktidXmitNext(&keySlot->PktidXmit, &pktid));
- ULONG pktidNetwork = RtlUlongByteSwap(pktid);
+ if (pktId64bit)
+ {
+ // calculate pktid
+ UINT64 pktid;
+ GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnPktidXmitNext(&keySlot->PktidXmit, &pktid, true));
+ ULONG64 pktidNetwork = RtlUlonglongByteSwap(pktid);
+
+ // calculate nonce, which is pktid + nonce_tail
+ RtlCopyMemory(nonce, &pktidNetwork, 8);
+ RtlCopyMemory(nonce + 8, keySlot->EncNonceTail, 4);
- // calculate nonce, which is pktid + nonce_tail
- RtlCopyMemory(nonce, &pktidNetwork, OVPN_PKTID_LEN);
- RtlCopyMemory(nonce + OVPN_PKTID_LEN, keySlot->EncNonceTail, OVPN_NONCE_TAIL_LEN);
+ // prepend with pktid
+ *reinterpret_cast(bufOut + OVPN_DATA_V2_LEN) = pktidNetwork;
+ }
+ else
+ {
+ // calculate pktid
+ UINT32 pktid;
+ GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnPktidXmitNext(&keySlot->PktidXmit, &pktid, false));
+ ULONG pktidNetwork = RtlUlongByteSwap(pktid);
- // prepend with pktid
- *(UINT32*)(bufOut + OVPN_DATA_V2_LEN) = pktidNetwork;
+ // calculate nonce, which is pktid + nonce_tail
+ RtlCopyMemory(nonce, &pktidNetwork, 4);
+ RtlCopyMemory(nonce + 4, keySlot->EncNonceTail, 8);
+
+ // prepend with pktid
+ *reinterpret_cast(bufOut + OVPN_DATA_V2_LEN) = pktidNetwork;
+ }
}
else {
- RtlCopyMemory(nonce, bufIn + OVPN_DATA_V2_LEN, OVPN_PKTID_LEN);
- RtlCopyMemory(nonce + OVPN_PKTID_LEN, &keySlot->DecNonceTail, sizeof(keySlot->DecNonceTail));
+ ULONG64 pktId;
+
+ RtlCopyMemory(nonce, bufIn + OVPN_DATA_V2_LEN, pktId64bit ? 8 : 4);
+ RtlCopyMemory(nonce + (pktId64bit ? 8 : 4), &keySlot->DecNonceTail, pktId64bit ? 4 : 8);
+ if (pktId64bit)
+ {
+ pktId = RtlUlonglongByteSwap(*reinterpret_cast(nonce));
+ }
+ else
+ {
+ pktId = static_cast(RtlUlongByteSwap(*reinterpret_cast(nonce)));
+ }
- UINT32 pktId = RtlUlongByteSwap(*(UINT32*)nonce);
status = OvpnPktidRecvVerify(&keySlot->PktidRecv, pktId);
if (!NT_SUCCESS(status)) {
- LOG_ERROR("Invalid pktId", TraceLoggingUInt32(pktId, "pktId"));
+ LOG_ERROR("Invalid pktId", TraceLoggingUInt64(pktId, "pktId"));
return STATUS_DATA_ERROR;
}
}
+ // we prepended buf with crypto overhead
+ len -= cryptoOverhead;
+
+ BOOLEAN aeadTagEnd = cryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END;
+
BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO authInfo;
BCRYPT_INIT_AUTH_MODE_INFO(authInfo);
authInfo.pbNonce = nonce;
authInfo.cbNonce = sizeof(nonce);
- authInfo.pbTag = (encrypt ? bufOut : bufIn) + OVPN_DATA_V2_LEN + OVPN_PKTID_LEN;
+ authInfo.pbTag = (encrypt ? bufOut : bufIn) + OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? len : 0);
authInfo.cbTag = AEAD_AUTH_TAG_LEN;
authInfo.pbAuthData = (encrypt ? bufOut : bufIn);
- authInfo.cbAuthData = OVPN_DATA_V2_LEN + OVPN_PKTID_LEN;
-
- bufOut += AEAD_CRYPTO_OVERHEAD;
- bufIn += AEAD_CRYPTO_OVERHEAD;
+ authInfo.cbAuthData = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4);
- len -= AEAD_CRYPTO_OVERHEAD;
+ auto payloadOffset = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN);
+ bufOut += payloadOffset;
+ bufIn += payloadOffset;
// non-chaining mode
ULONG bytesDone = 0;
@@ -205,27 +251,29 @@ OVPN_CRYPTO_DECRYPT OvpnCryptoDecryptAEAD;
_Use_decl_annotations_
NTSTATUS
-OvpnCryptoDecryptAEAD(OvpnCryptoKeySlot* keySlot, UCHAR* bufIn, SIZE_T len, UCHAR* bufOut)
+OvpnCryptoDecryptAEAD(OvpnCryptoKeySlot* keySlot, UCHAR* bufIn, SIZE_T len, UCHAR* bufOut, INT32 cryptoOptions)
{
- return OvpnCryptoAEADDoWork(FALSE, keySlot, bufIn, len, bufOut);
+ return OvpnCryptoAEADDoWork(FALSE, keySlot, bufIn, len, bufOut, cryptoOptions);
}
OVPN_CRYPTO_ENCRYPT OvpnCryptoEncryptAEAD;
_Use_decl_annotations_
NTSTATUS
-OvpnCryptoEncryptAEAD(OvpnCryptoKeySlot* keySlot, UCHAR* buf, SIZE_T len)
+OvpnCryptoEncryptAEAD(OvpnCryptoKeySlot* keySlot, UCHAR* buf, SIZE_T len, INT32 cryptoOptions)
{
- return OvpnCryptoAEADDoWork(TRUE, keySlot, buf, len, buf);
+ return OvpnCryptoAEADDoWork(TRUE, keySlot, buf, len, buf, cryptoOptions);
}
_Use_decl_annotations_
NTSTATUS
-OvpnCryptoNewKey(OvpnCryptoContext* cryptoContext, POVPN_CRYPTO_DATA cryptoData)
+OvpnCryptoNewKey(OvpnCryptoContext* cryptoContext, POVPN_CRYPTO_DATA_V2 cryptoDataV2)
{
OvpnCryptoKeySlot* keySlot = NULL;
NTSTATUS status = STATUS_SUCCESS;
+ POVPN_CRYPTO_DATA cryptoData = &cryptoDataV2->V1;
+
if (cryptoData->KeySlot == OVPN_KEY_SLOT::OVPN_KEY_SLOT_PRIMARY) {
keySlot = &cryptoContext->Primary;
}
@@ -237,6 +285,15 @@ OvpnCryptoNewKey(OvpnCryptoContext* cryptoContext, POVPN_CRYPTO_DATA cryptoData)
return STATUS_INVALID_DEVICE_REQUEST;
}
+ if (cryptoDataV2->CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID)
+ {
+ cryptoContext->CryptoOptions |= CRYPTO_OPTIONS_64BIT_PKTID;
+ }
+ if (cryptoDataV2->CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END)
+ {
+ cryptoContext->CryptoOptions |= CRYPTO_OPTIONS_AEAD_TAG_END;
+ }
+
if ((cryptoData->CipherAlg == OVPN_CIPHER_ALG_AES_GCM) || (cryptoData->CipherAlg == OVPN_CIPHER_ALG_CHACHA20_POLY1305)) {
// destroy previous keys
if (keySlot->EncKey) {
@@ -284,8 +341,6 @@ OvpnCryptoNewKey(OvpnCryptoContext* cryptoContext, POVPN_CRYPTO_DATA cryptoData)
keySlot->KeyId = cryptoData->KeyId;
keySlot->PeerId = cryptoData->PeerId;
- cryptoContext->CryptoOverhead = AEAD_CRYPTO_OVERHEAD;
-
LOG_INFO("New key", TraceLoggingValue(cryptoData->CipherAlg == OVPN_CIPHER_ALG_AES_GCM ? "aes-gcm" : "chacha20-poly1305", "alg"),
TraceLoggingValue(cryptoData->KeyId, "KeyId"), TraceLoggingValue(cryptoData->KeyId, "PeerId"));
}
@@ -293,8 +348,6 @@ OvpnCryptoNewKey(OvpnCryptoContext* cryptoContext, POVPN_CRYPTO_DATA cryptoData)
cryptoContext->Encrypt = OvpnCryptoEncryptNone;
cryptoContext->Decrypt = OvpnCryptoDecryptNone;
- cryptoContext->CryptoOverhead = NONE_CRYPTO_OVERHEAD;
-
LOG_INFO("Using cipher none");
}
else {
diff --git a/crypto.h b/crypto.h
index dfbc688..def22d7 100644
--- a/crypto.h
+++ b/crypto.h
@@ -28,14 +28,8 @@
#include "uapi\ovpn-dco.h"
#include "socket.h"
-#define AEAD_CRYPTO_OVERHEAD 24 // 4 + 4 + 16 data_v2 + pktid + auth_tag
-#define NONE_CRYPTO_OVERHEAD 8 // 4 + 4 data_v2 + pktid
-#define OVPN_PKTID_LEN 4
-#define OVPN_NONCE_TAIL_LEN 8
#define OVPN_DATA_V2_LEN 4
#define AEAD_AUTH_TAG_LEN 16
-#define AES_BLOCK_SIZE 16
-#define AES_GCM_NONCE_LEN 12
// packet opcode (high 5 bits) and key-id (low 3 bits) are combined in one byte
#define OVPN_OP_DATA_V2 9
@@ -62,7 +56,7 @@ _IRQL_requires_max_(DISPATCH_LEVEL)
_Must_inspect_result_
typedef
NTSTATUS
-OVPN_CRYPTO_ENCRYPT(_In_ OvpnCryptoKeySlot* keySlot, _In_ UCHAR* buf, _In_ SIZE_T len);
+OVPN_CRYPTO_ENCRYPT(_In_ OvpnCryptoKeySlot* keySlot, _In_ UCHAR* buf, _In_ SIZE_T len, _In_ INT32 CryptoOptions);
typedef OVPN_CRYPTO_ENCRYPT* POVPN_CRYPTO_ENCRYPT;
_Function_class_(OVPN_CRYPTO_DECRYPT)
@@ -70,7 +64,7 @@ _IRQL_requires_max_(DISPATCH_LEVEL)
_Must_inspect_result_
typedef
NTSTATUS
-OVPN_CRYPTO_DECRYPT(_In_ OvpnCryptoKeySlot* keySlot, _In_ UCHAR* bufIn, _In_ SIZE_T len, _In_ UCHAR* bufOut);
+OVPN_CRYPTO_DECRYPT(_In_ OvpnCryptoKeySlot* keySlot, _In_ UCHAR* bufIn, _In_ SIZE_T len, _In_ UCHAR* bufOut, _In_ INT32 CryptoOptions);
typedef OVPN_CRYPTO_DECRYPT* POVPN_CRYPTO_DECRYPT;
struct OvpnCryptoContext
@@ -84,7 +78,7 @@ struct OvpnCryptoContext
POVPN_CRYPTO_ENCRYPT Encrypt;
POVPN_CRYPTO_DECRYPT Decrypt;
- SIZE_T CryptoOverhead;
+ INT32 CryptoOptions;
};
_Must_inspect_result_
@@ -101,7 +95,7 @@ OvpnCryptoUninit(_In_ OvpnCryptoContext* cryptoContext);
_Must_inspect_result_
NTSTATUS
-OvpnCryptoNewKey(_In_ OvpnCryptoContext* cryptoContext, _In_ POVPN_CRYPTO_DATA cryptoData);
+OvpnCryptoNewKey(_In_ OvpnCryptoContext* cryptoContext, _In_ POVPN_CRYPTO_DATA_V2 cryptoData);
_Must_inspect_result_
OvpnCryptoKeySlot*
diff --git a/peer.cpp b/peer.cpp
index e84612c..d8e0748 100644
--- a/peer.cpp
+++ b/peer.cpp
@@ -262,9 +262,35 @@ OvpnPeerNewKey(POVPN_DEVICE device, WDFREQUEST request)
}
POVPN_CRYPTO_DATA cryptoData = NULL;
+ OVPN_CRYPTO_DATA_V2 cryptoDataV2{};
NTSTATUS status;
GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveInputBuffer(request, sizeof(OVPN_CRYPTO_DATA), (PVOID*)&cryptoData, nullptr));
+
+ RtlCopyMemory(&cryptoDataV2.V1, cryptoData, sizeof(OVPN_CRYPTO_DATA));
+ GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoNewKey(&device->CryptoContext, &cryptoDataV2));
+
+done:
+ LOG_EXIT();
+
+ return status;
+}
+
+_Use_decl_annotations_
+NTSTATUS
+OvpnPeerNewKeyV2(POVPN_DEVICE device, WDFREQUEST request)
+{
+ LOG_ENTER();
+
+ if (InterlockedCompareExchange(&device->UserspacePid, 0, 0) == 0) {
+ LOG_ERROR("Peer not added");
+ return STATUS_INVALID_DEVICE_REQUEST;
+ }
+
+ POVPN_CRYPTO_DATA_V2 cryptoData = NULL;
+ NTSTATUS status;
+
+ GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveInputBuffer(request, sizeof(OVPN_CRYPTO_DATA_V2), (PVOID*)&cryptoData, nullptr));
GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoNewKey(&device->CryptoContext, cryptoData));
done:
diff --git a/peer.h b/peer.h
index 4eecca6..8906b0c 100644
--- a/peer.h
+++ b/peer.h
@@ -57,6 +57,11 @@ _Requires_exclusive_lock_held_(device->SpinLock)
NTSTATUS
OvpnPeerNewKey(_In_ POVPN_DEVICE device, WDFREQUEST request);
+_Must_inspect_result_
+_Requires_exclusive_lock_held_(device->SpinLock)
+NTSTATUS
+OvpnPeerNewKeyV2(_In_ POVPN_DEVICE device, WDFREQUEST request);
+
_Must_inspect_result_
_Requires_exclusive_lock_held_(device->SpinLock)
NTSTATUS
diff --git a/pktid.cpp b/pktid.cpp
index b7f365d..4d94ee3 100644
--- a/pktid.cpp
+++ b/pktid.cpp
@@ -28,24 +28,29 @@
#define PKTID_WRAP_WARN 0xf0000000ULL
_Use_decl_annotations_
-NTSTATUS OvpnPktidXmitNext(OvpnPktidXmit* px, UINT32* pktId)
+NTSTATUS OvpnPktidXmitNext(OvpnPktidXmit* px, VOID* pktId, BOOLEAN pktId64bit)
{
ULONG64 seqNum = InterlockedIncrementNoFence64(&px->SeqNum);
- *pktId = (UINT32)seqNum;
- if (seqNum < PKTID_WRAP_WARN) {
- return STATUS_SUCCESS;
- }
- else {
- LOG_ERROR("Pktid wrapped");
- return STATUS_INTEGER_OVERFLOW;
- }
+ if (pktId64bit) {
+ *static_cast(pktId) = seqNum;
+ }
+ else
+ {
+ *static_cast(pktId) = static_cast(seqNum);
+ if (seqNum >= PKTID_WRAP_WARN) {
+ LOG_ERROR("Pktid wrapped");
+ return STATUS_INTEGER_OVERFLOW;
+ }
+ }
+
+ return STATUS_SUCCESS;
}
#define PKTID_RECV_EXPIRE ((30 * WDF_TIMEOUT_TO_SEC) / KeQueryTimeIncrement())
_Use_decl_annotations_
-NTSTATUS OvpnPktidRecvVerify(OvpnPktidRecv* pr, UINT32 pktId)
+NTSTATUS OvpnPktidRecvVerify(OvpnPktidRecv* pr, UINT64 pktId)
{
LARGE_INTEGER now;
KeQueryTickCount(&now);
@@ -69,16 +74,16 @@ NTSTATUS OvpnPktidRecvVerify(OvpnPktidRecv* pr, UINT32 pktId)
}
else if (pktId > pr->Id) {
/* ID jumped forward by more than one */
- UINT32 delta = pktId - pr->Id;
+ const auto delta = pktId - pr->Id;
if (delta < REPLAY_WINDOW_SIZE) {
pr->Base = REPLAY_INDEX(pr->Base, -(INT32)delta);
pr->History[pr->Base / 8] |= (1 << (pr->Base % 8));
- pr->Extent += delta;
+ pr->Extent += static_cast(delta);
if (pr->Extent > REPLAY_WINDOW_SIZE)
pr->Extent = REPLAY_WINDOW_SIZE;
- for (UINT32 i = 1; i < delta; ++i) {
- unsigned int newb = REPLAY_INDEX(pr->Base, i);
+ for (auto i = 1; i < delta; ++i) {
+ const auto newb = REPLAY_INDEX(pr->Base, i);
pr->History[newb / 8] &= ~BIT(newb % 8);
}
@@ -93,10 +98,8 @@ NTSTATUS OvpnPktidRecvVerify(OvpnPktidRecv* pr, UINT32 pktId)
}
else {
/* ID backtrack */
- UINT32 delta = pr->Id - pktId;
+ const auto delta = pr->Id - pktId;
- if (delta > pr->MaxBacktrack)
- pr->MaxBacktrack = delta;
if (delta < pr->Extent) {
if (pktId > pr->IdFloor) {
UINT32 ri = REPLAY_INDEX(pr->Base, delta);
diff --git a/pktid.h b/pktid.h
index b0d2325..dcc4be8 100644
--- a/pktid.h
+++ b/pktid.h
@@ -50,17 +50,17 @@ struct OvpnPktidRecv
LARGE_INTEGER Expire;
/* highest sequence number received */
- UINT32 Id;
+ UINT64 Id;
/* we will only accept backtrack IDs > id_floor */
- UINT32 IdFloor;
- UINT32 MaxBacktrack;
+ UINT64 IdFloor;
};
/* Get the next packet ID for xmit */
-NTSTATUS OvpnPktidXmitNext(_In_ OvpnPktidXmit* px, _Out_ UINT32* pktId);
+NTSTATUS OvpnPktidXmitNext(_In_ OvpnPktidXmit* px, _Out_ VOID* pktId, BOOLEAN pktId64bit);
+
/* Packet replay detection.
* Allows ID backtrack of up to REPLAY_WINDOW_SIZE - 1.
*/
-NTSTATUS OvpnPktidRecvVerify(_In_ OvpnPktidRecv* pid, UINT32 pktId);
+NTSTATUS OvpnPktidRecvVerify(_In_ OvpnPktidRecv* pid, UINT64 pktId);
diff --git a/rxqueue.cpp b/rxqueue.cpp
index bc0f8c0..794fa89 100644
--- a/rxqueue.cpp
+++ b/rxqueue.cpp
@@ -101,6 +101,10 @@ OvpnEvtRxQueueAdvance(NETPACKETQUEUE netPacketQueue)
POVPN_RXQUEUE queue = OvpnGetRxQueueContext(netPacketQueue);
OVPN_DEVICE* device = OvpnGetDeviceContext(queue->Adapter->WdfDevice);
+ BOOLEAN pktId64bit = device->CryptoContext.CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID;
+ BOOLEAN aeadTagEnd = device->CryptoContext.CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END;
+ auto payloadOffset = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN);
+
NET_RING_FRAGMENT_ITERATOR fi = NetRingGetAllFragments(queue->Rings);
NET_RING_PACKET_ITERATOR pi = NetRingGetAllPackets(queue->Rings);
while (NetFragmentIteratorHasAny(&fi)) {
@@ -115,7 +119,7 @@ OvpnEvtRxQueueAdvance(NETPACKETQUEUE netPacketQueue)
fragment->ValidLength = buffer->Len;
fragment->Offset = 0;
NET_FRAGMENT_VIRTUAL_ADDRESS* virtualAddr = NetExtensionGetFragmentVirtualAddress(&queue->VirtualAddressExtension, NetFragmentIteratorGetIndex(&fi));
- RtlCopyMemory(virtualAddr->VirtualAddress, buffer->Data + device->CryptoContext.CryptoOverhead, buffer->Len);
+ RtlCopyMemory(virtualAddr->VirtualAddress, buffer->Data + payloadOffset, buffer->Len);
InterlockedExchangeAddNoFence64(&device->Stats.TunBytesReceived, buffer->Len);
diff --git a/socket.cpp b/socket.cpp
index 765152f..a41c624 100644
--- a/socket.cpp
+++ b/socket.cpp
@@ -169,9 +169,11 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_
return;
}
- if (device->CryptoContext.Decrypt) {
+ OvpnCryptoContext* cryptoContext = &device->CryptoContext;
+
+ if (cryptoContext->Decrypt) {
UCHAR keyId = OvpnCryptoKeyIdExtract(op);
- OvpnCryptoKeySlot* keySlot = OvpnCryptoKeySlotFromKeyId(&device->CryptoContext, keyId);
+ OvpnCryptoKeySlot* keySlot = OvpnCryptoKeySlotFromKeyId(cryptoContext, keyId);
if (!keySlot) {
status = STATUS_INVALID_DEVICE_STATE;
@@ -179,8 +181,11 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_
}
else {
// decrypt into plaintext buffer
- status = device->CryptoContext.Decrypt(keySlot, cipherTextBuf, len, buffer->Data);
- buffer->Len = len - device->CryptoContext.CryptoOverhead;
+ status = cryptoContext->Decrypt(keySlot, cipherTextBuf, len, buffer->Data, cryptoContext->CryptoOptions);
+
+ auto pktId64bit = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID;
+ auto cryptoOverhead = OVPN_DATA_V2_LEN + AEAD_AUTH_TAG_LEN + (pktId64bit ? 8 : 4);
+ buffer->Len = len - cryptoOverhead;
}
}
else {
@@ -197,20 +202,23 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_
OvpnTimerResetRecv(device->Timer);
// points to the beginning of plaintext
- UCHAR* buf = buffer->Data + device->CryptoContext.CryptoOverhead;
+ BOOLEAN pktId64bit = device->CryptoContext.CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID;
+ BOOLEAN aeadTagEnd = device->CryptoContext.CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END;
+ auto payloadOffset = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN);
+ UCHAR* plaintext = buffer->Data + payloadOffset;
// ping packet?
- if (OvpnTimerIsKeepaliveMessage(buf, buffer->Len)) {
+ if (OvpnTimerIsKeepaliveMessage(plaintext, buffer->Len)) {
LOG_INFO("Ping received");
// no need to inject ping packet into OS, return buffer to the pool
OvpnRxBufferPoolPut(buffer);
}
else {
- if (OvpnMssIsIPv4(buf, buffer->Len)) {
- OvpnMssDoIPv4(buf, buffer->Len, device->MSS);
- } else if (OvpnMssIsIPv6(buf, buffer->Len)) {
- OvpnMssDoIPv6(buf, buffer->Len, device->MSS);
+ if (OvpnMssIsIPv4(plaintext, buffer->Len)) {
+ OvpnMssDoIPv4(plaintext, buffer->Len, device->MSS);
+ } else if (OvpnMssIsIPv6(plaintext, buffer->Len)) {
+ OvpnMssDoIPv6(plaintext, buffer->Len, device->MSS);
}
// enqueue plaintext buffer, it will be dequeued by NetAdapter RX datapath
diff --git a/timer.cpp b/timer.cpp
index 215659a..2f5020a 100644
--- a/timer.cpp
+++ b/timer.cpp
@@ -63,12 +63,19 @@ static VOID OvpnTimerXmit(WDFTIMER timer)
RtlCopyMemory(OvpnTxBufferPut(buffer, sizeof(OvpnKeepaliveMessage)), OvpnKeepaliveMessage, sizeof(OvpnKeepaliveMessage));
KIRQL kiqrl = ExAcquireSpinLockShared(&device->SpinLock);
- if (device->CryptoContext.Encrypt) {
+ OvpnCryptoContext* cryptoContext = &device->CryptoContext;
+ if (cryptoContext->Encrypt) {
// make space to crypto overhead
- OvpnTxBufferPush(buffer, device->CryptoContext.CryptoOverhead);
+ BOOLEAN pktId64bit = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID;
+ BOOLEAN aeadTagEnd = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END;
+
+ OvpnTxBufferPush(buffer, OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN));
+ if (aeadTagEnd) {
+ OvpnTxBufferPut(buffer, AEAD_AUTH_TAG_LEN);
+ }
// in-place encrypt, always with primary key
- status = device->CryptoContext.Encrypt(&device->CryptoContext.Primary, buffer->Data, buffer->Len);
+ status = cryptoContext->Encrypt(&cryptoContext->Primary, buffer->Data, buffer->Len, cryptoContext->CryptoOptions);
}
else {
status = STATUS_INVALID_DEVICE_STATE;
diff --git a/txqueue.cpp b/txqueue.cpp
index d3887dd..d097cd7 100644
--- a/txqueue.cpp
+++ b/txqueue.cpp
@@ -83,12 +83,21 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET
InterlockedExchangeAddNoFence64(&device->Stats.TunBytesSent, buffer->Len);
- if (device->CryptoContext.Encrypt) {
+ OvpnCryptoContext* cryptoContext = &device->CryptoContext;
+
+ if (cryptoContext->Encrypt) {
+ auto aeadTagEnd = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END;
+ auto pktId64bit = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID;
+
// make space to crypto overhead
- OvpnTxBufferPush(buffer, device->CryptoContext.CryptoOverhead);
+ OvpnTxBufferPush(buffer, OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN));
+ if (aeadTagEnd)
+ {
+ OvpnTxBufferPut(buffer, AEAD_AUTH_TAG_LEN);
+ }
// in-place encrypt, always with primary key
- status = device->CryptoContext.Encrypt(&device->CryptoContext.Primary, buffer->Data, buffer->Len);
+ status = cryptoContext->Encrypt(&cryptoContext->Primary, buffer->Data, buffer->Len, cryptoContext->CryptoOptions);
}
else {
status = STATUS_INVALID_DEVICE_STATE;
@@ -140,7 +149,7 @@ OvpnEvtTxQueueAdvance(NETPACKETQUEUE netPacketQueue)
POVPN_TXQUEUE queue = OvpnGetTxQueueContext(netPacketQueue);
NET_RING_PACKET_ITERATOR pi = NetRingGetAllPackets(queue->Rings);
POVPN_DEVICE device = OvpnGetDeviceContext(queue->Adapter->WdfDevice);
- bool packetSent = false;
+ BOOLEAN packetSent = false;
KIRQL kirql = ExAcquireSpinLockShared(&device->SpinLock);
diff --git a/uapi/ovpn-dco.h b/uapi/ovpn-dco.h
index ea2a733..9e437f0 100644
--- a/uapi/ovpn-dco.h
+++ b/uapi/ovpn-dco.h
@@ -94,6 +94,14 @@ typedef struct _OVPN_CRYPTO_DATA {
int PeerId;
} OVPN_CRYPTO_DATA, * POVPN_CRYPTO_DATA;
+#define CRYPTO_OPTIONS_AEAD_TAG_END (1<<1)
+#define CRYPTO_OPTIONS_64BIT_PKTID (1<<2)
+
+typedef struct _OVPN_CRYPTO_DATA_V2 {
+ OVPN_CRYPTO_DATA V1;
+ UINT32 CryptoOptions;
+} OVPN_CRYPTO_DATA_V2, * POVPN_CRYPTO_DATA_V2;
+
typedef struct _OVPN_SET_PEER {
LONG KeepaliveInterval;
LONG KeepaliveTimeout;
@@ -114,3 +122,4 @@ typedef struct _OVPN_VERSION {
#define OVPN_IOCTL_START_VPN CTL_CODE(FILE_DEVICE_UNKNOWN, 6, METHOD_BUFFERED, FILE_ANY_ACCESS)
#define OVPN_IOCTL_DEL_PEER CTL_CODE(FILE_DEVICE_UNKNOWN, 7, METHOD_BUFFERED, FILE_ANY_ACCESS)
#define OVPN_IOCTL_GET_VERSION CTL_CODE(FILE_DEVICE_UNKNOWN, 8, METHOD_BUFFERED, FILE_ANY_ACCESS)
+#define OVPN_IOCTL_NEW_KEY_V2 CTL_CODE(FILE_DEVICE_UNKNOWN, 9, METHOD_BUFFERED, FILE_ANY_ACCESS)