diff --git a/Driver.cpp b/Driver.cpp
index 4c3ceec..4b4f0bc 100644
--- a/Driver.cpp
+++ b/Driver.cpp
@@ -269,6 +269,12 @@ OvpnEvtIoDeviceControl(WDFQUEUE queue, WDFREQUEST request, size_t outputBufferLe
ExReleaseSpinLockExclusive(&device->SpinLock, kirql);
break;
+ case OVPN_IOCTL_NEW_KEY_V2:
+ kirql = ExAcquireSpinLockExclusive(&device->SpinLock);
+ status = OvpnPeerNewKeyV2(device, request);
+ ExReleaseSpinLockExclusive(&device->SpinLock, kirql);
+ break;
+
case OVPN_IOCTL_SWAP_KEYS:
kirql = ExAcquireSpinLockExclusive(&device->SpinLock);
status = OvpnPeerSwapKeys(device);
diff --git a/Driver.h b/Driver.h
index 01df6c9..bf221ce 100644
--- a/Driver.h
+++ b/Driver.h
@@ -94,7 +94,6 @@ struct OVPN_DEVICE {
_Guarded_by_(SpinLock)
RTL_GENERIC_TABLE Peers;
- SIZE_T CryptoOverhead;
};
typedef OVPN_DEVICE * POVPN_DEVICE;
diff --git a/PropertySheet.props b/PropertySheet.props
index 61bdfe9..6056166 100644
--- a/PropertySheet.props
+++ b/PropertySheet.props
@@ -3,7 +3,7 @@
2
- 3
+ 4
0
diff --git a/bufferpool.h b/bufferpool.h
index fcde323..1ed0721 100644
--- a/bufferpool.h
+++ b/bufferpool.h
@@ -74,7 +74,6 @@ struct OVPN_RX_BUFFER
UCHAR Data[OVPN_SOCKET_RX_PACKET_BUFFER_SIZE];
};
-_Must_inspect_result_
UCHAR*
OvpnTxBufferPut(_In_ OVPN_TX_BUFFER* work, SIZE_T len);
diff --git a/crypto.cpp b/crypto.cpp
index 75005f4..0f915fa 100644
--- a/crypto.cpp
+++ b/crypto.cpp
@@ -48,11 +48,14 @@ OvpnProtoOp32Compose(UINT opcode, UINT keyId, UINT opPeerId)
OVPN_CRYPTO_DECRYPT OvpnCryptoDecryptNone;
_Use_decl_annotations_
-NTSTATUS OvpnCryptoDecryptNone(OvpnCryptoKeySlot* keySlot, UCHAR* bufIn, SIZE_T len, UCHAR* bufOut)
+NTSTATUS OvpnCryptoDecryptNone(OvpnCryptoKeySlot* keySlot, UCHAR* bufIn, SIZE_T len, UCHAR* bufOut, INT32 cryptoOptions)
{
UNREFERENCED_PARAMETER(keySlot);
- if (len < NONE_CRYPTO_OVERHEAD) {
+ BOOLEAN pktId64bit = cryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID;
+ BOOLEAN cryptoOverhead = OVPN_DATA_V2_LEN + pktId64bit ? 8 : 4;
+
+ if (len < cryptoOverhead) {
LOG_WARN("Packet too short", TraceLoggingValue(len, "len"));
return STATUS_DATA_ERROR;
}
@@ -66,10 +69,11 @@ OVPN_CRYPTO_ENCRYPT OvpnCryptoEncryptNone;
_Use_decl_annotations_
NTSTATUS
-OvpnCryptoEncryptNone(OvpnCryptoKeySlot* keySlot, UCHAR* buf, SIZE_T len)
+OvpnCryptoEncryptNone(OvpnCryptoKeySlot* keySlot, UCHAR* buf, SIZE_T len, INT32 cryptoOptions)
{
UNREFERENCED_PARAMETER(keySlot);
UNREFERENCED_PARAMETER(len);
+ UNREFERENCED_PARAMETER(cryptoOptions);
// prepend with opcode, key-id and peer-id
UINT32 op = OvpnProtoOp32Compose(OVPN_OP_DATA_V2, 0, 0);
@@ -121,74 +125,116 @@ OvpnCryptoUninitAlgHandles(_In_ BCRYPT_ALG_HANDLE aesAlgHandle, BCRYPT_ALG_HANDL
static
NTSTATUS
-OvpnCryptoAEADDoWork(BOOLEAN encrypt, OvpnCryptoKeySlot* keySlot, UCHAR *bufIn, SIZE_T len, UCHAR* bufOut)
+OvpnCryptoAEADDoWork(BOOLEAN encrypt, OvpnCryptoKeySlot* keySlot, UCHAR *bufIn, SIZE_T len, UCHAR* bufOut, INT32 cryptoOptions)
{
/*
AEAD Nonce :
[Packet ID] [HMAC keying material]
- [4 bytes ] [8 bytes ]
+ [4/8 bytes] [8/4 bytes ]
[AEAD nonce total : 12 bytes ]
TLS wire protocol :
+ Packet ID is 8 bytes long with CRYPTO_OPTIONS_64BIT_PKTID.
+
[DATA_V2 opcode] [Packet ID] [AEAD Auth tag] [ciphertext]
- [4 bytes ] [4 bytes ] [16 bytes ]
+ [4 bytes ] [4/8 bytes] [16 bytes ]
+ [AEAD additional data(AD) ]
+
+ With CRYPTO_OPTIONS_AEAD_TAG_END AEAD Auth tag is placed after ciphertext:
+
+ [DATA_V2 opcode] [Packet ID] [ciphertext] [AEAD Auth tag]
+ [4 bytes ] [4/8 bytes] [16 bytes ]
[AEAD additional data(AD) ]
*/
NTSTATUS status = STATUS_SUCCESS;
- if (len < AEAD_CRYPTO_OVERHEAD) {
+ BOOLEAN pktId64bit = cryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID;
+
+ SIZE_T cryptoOverhead = OVPN_DATA_V2_LEN + AEAD_AUTH_TAG_LEN + (pktId64bit ? 8 : 4);
+
+ if (len < cryptoOverhead) {
LOG_WARN("Packet too short", TraceLoggingValue(len, "len"));
return STATUS_DATA_ERROR;
}
- UCHAR nonce[OVPN_PKTID_LEN + OVPN_NONCE_TAIL_LEN];
+ UCHAR nonce[12];
if (encrypt) {
// prepend with opcode, key-id and peer-id
UINT32 op = OvpnProtoOp32Compose(OVPN_OP_DATA_V2, keySlot->KeyId, keySlot->PeerId);
op = RtlUlongByteSwap(op);
- *(UINT32*)(bufOut) = op;
+ *reinterpret_cast(bufOut) = op;
- // calculate pktid
- UINT32 pktid;
- GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnPktidXmitNext(&keySlot->PktidXmit, &pktid));
- ULONG pktidNetwork = RtlUlongByteSwap(pktid);
+ if (pktId64bit)
+ {
+ // calculate pktid
+ UINT64 pktid;
+ GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnPktidXmitNext(&keySlot->PktidXmit, &pktid, true));
+ ULONG64 pktidNetwork = RtlUlonglongByteSwap(pktid);
+
+ // calculate nonce, which is pktid + nonce_tail
+ RtlCopyMemory(nonce, &pktidNetwork, 8);
+ RtlCopyMemory(nonce + 8, keySlot->EncNonceTail, 4);
+
+ // prepend with pktid
+ *reinterpret_cast(bufOut + OVPN_DATA_V2_LEN) = pktidNetwork;
+ }
+ else
+ {
+ // calculate pktid
+ UINT32 pktid;
+ GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnPktidXmitNext(&keySlot->PktidXmit, &pktid, false));
+ ULONG pktidNetwork = RtlUlongByteSwap(pktid);
- // calculate nonce, which is pktid + nonce_tail
- RtlCopyMemory(nonce, &pktidNetwork, OVPN_PKTID_LEN);
- RtlCopyMemory(nonce + OVPN_PKTID_LEN, keySlot->EncNonceTail, OVPN_NONCE_TAIL_LEN);
+ // calculate nonce, which is pktid + nonce_tail
+ RtlCopyMemory(nonce, &pktidNetwork, 4);
+ RtlCopyMemory(nonce + 4, keySlot->EncNonceTail, 8);
- // prepend with pktid
- *(UINT32*)(bufOut + OVPN_DATA_V2_LEN) = pktidNetwork;
+ // prepend with pktid
+ *reinterpret_cast(bufOut + OVPN_DATA_V2_LEN) = pktidNetwork;
+ }
}
else {
- RtlCopyMemory(nonce, bufIn + OVPN_DATA_V2_LEN, OVPN_PKTID_LEN);
- RtlCopyMemory(nonce + OVPN_PKTID_LEN, &keySlot->DecNonceTail, sizeof(keySlot->DecNonceTail));
+ ULONG64 pktId;
+
+ RtlCopyMemory(nonce, bufIn + OVPN_DATA_V2_LEN, pktId64bit ? 8 : 4);
+ RtlCopyMemory(nonce + (pktId64bit ? 8 : 4), &keySlot->DecNonceTail, pktId64bit ? 4 : 8);
+ if (pktId64bit)
+ {
+ pktId = RtlUlonglongByteSwap(*reinterpret_cast(nonce));
+ }
+ else
+ {
+ pktId = static_cast(RtlUlongByteSwap(*reinterpret_cast(nonce)));
+ }
- UINT32 pktId = RtlUlongByteSwap(*(UINT32*)nonce);
status = OvpnPktidRecvVerify(&keySlot->PktidRecv, pktId);
if (!NT_SUCCESS(status)) {
- LOG_ERROR("Invalid pktId", TraceLoggingUInt32(pktId, "pktId"));
+ LOG_ERROR("Invalid pktId", TraceLoggingUInt64(pktId, "pktId"));
return STATUS_DATA_ERROR;
}
}
+ // we prepended buf with crypto overhead
+ len -= cryptoOverhead;
+
+ BOOLEAN aeadTagEnd = cryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END;
+
BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO authInfo;
BCRYPT_INIT_AUTH_MODE_INFO(authInfo);
authInfo.pbNonce = nonce;
authInfo.cbNonce = sizeof(nonce);
- authInfo.pbTag = (encrypt ? bufOut : bufIn) + OVPN_DATA_V2_LEN + OVPN_PKTID_LEN;
+ authInfo.pbTag = (encrypt ? bufOut : bufIn) + OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? len : 0);
authInfo.cbTag = AEAD_AUTH_TAG_LEN;
authInfo.pbAuthData = (encrypt ? bufOut : bufIn);
- authInfo.cbAuthData = OVPN_DATA_V2_LEN + OVPN_PKTID_LEN;
-
- bufOut += AEAD_CRYPTO_OVERHEAD;
- bufIn += AEAD_CRYPTO_OVERHEAD;
+ authInfo.cbAuthData = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4);
- len -= AEAD_CRYPTO_OVERHEAD;
+ auto payloadOffset = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN);
+ bufOut += payloadOffset;
+ bufIn += payloadOffset;
// non-chaining mode
ULONG bytesDone = 0;
@@ -205,27 +251,29 @@ OVPN_CRYPTO_DECRYPT OvpnCryptoDecryptAEAD;
_Use_decl_annotations_
NTSTATUS
-OvpnCryptoDecryptAEAD(OvpnCryptoKeySlot* keySlot, UCHAR* bufIn, SIZE_T len, UCHAR* bufOut)
+OvpnCryptoDecryptAEAD(OvpnCryptoKeySlot* keySlot, UCHAR* bufIn, SIZE_T len, UCHAR* bufOut, INT32 cryptoOptions)
{
- return OvpnCryptoAEADDoWork(FALSE, keySlot, bufIn, len, bufOut);
+ return OvpnCryptoAEADDoWork(FALSE, keySlot, bufIn, len, bufOut, cryptoOptions);
}
OVPN_CRYPTO_ENCRYPT OvpnCryptoEncryptAEAD;
_Use_decl_annotations_
NTSTATUS
-OvpnCryptoEncryptAEAD(OvpnCryptoKeySlot* keySlot, UCHAR* buf, SIZE_T len)
+OvpnCryptoEncryptAEAD(OvpnCryptoKeySlot* keySlot, UCHAR* buf, SIZE_T len, INT32 cryptoOptions)
{
- return OvpnCryptoAEADDoWork(TRUE, keySlot, buf, len, buf);
+ return OvpnCryptoAEADDoWork(TRUE, keySlot, buf, len, buf, cryptoOptions);
}
_Use_decl_annotations_
NTSTATUS
-OvpnCryptoNewKey(OvpnCryptoContext* cryptoContext, POVPN_CRYPTO_DATA cryptoData, BCRYPT_ALG_HANDLE algHandle)
+OvpnCryptoNewKey(OvpnCryptoContext* cryptoContext, POVPN_CRYPTO_DATA_V2 cryptoDataV2, BCRYPT_ALG_HANDLE algHandle)
{
OvpnCryptoKeySlot* keySlot = NULL;
NTSTATUS status = STATUS_SUCCESS;
+ POVPN_CRYPTO_DATA cryptoData = &cryptoDataV2->V1;
+
if (cryptoData->KeySlot == OVPN_KEY_SLOT::OVPN_KEY_SLOT_PRIMARY) {
keySlot = &cryptoContext->Primary;
}
@@ -237,6 +285,15 @@ OvpnCryptoNewKey(OvpnCryptoContext* cryptoContext, POVPN_CRYPTO_DATA cryptoData,
return STATUS_INVALID_DEVICE_REQUEST;
}
+ if (cryptoDataV2->CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID)
+ {
+ cryptoContext->CryptoOptions |= CRYPTO_OPTIONS_64BIT_PKTID;
+ }
+ if (cryptoDataV2->CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END)
+ {
+ cryptoContext->CryptoOptions |= CRYPTO_OPTIONS_AEAD_TAG_END;
+ }
+
if ((cryptoData->CipherAlg == OVPN_CIPHER_ALG_AES_GCM) || (cryptoData->CipherAlg == OVPN_CIPHER_ALG_CHACHA20_POLY1305)) {
// destroy previous keys
if (keySlot->EncKey) {
diff --git a/crypto.h b/crypto.h
index ee35cba..3782d42 100644
--- a/crypto.h
+++ b/crypto.h
@@ -29,14 +29,8 @@
#include "uapi\ovpn-dco.h"
#include "socket.h"
-#define AEAD_CRYPTO_OVERHEAD 24 // 4 + 4 + 16 data_v2 + pktid + auth_tag
-#define NONE_CRYPTO_OVERHEAD 8 // 4 + 4 data_v2 + pktid
-#define OVPN_PKTID_LEN 4
-#define OVPN_NONCE_TAIL_LEN 8
#define OVPN_DATA_V2_LEN 4
#define AEAD_AUTH_TAG_LEN 16
-#define AES_BLOCK_SIZE 16
-#define AES_GCM_NONCE_LEN 12
// packet opcode (high 5 bits) and key-id (low 3 bits) are combined in one byte
#define OVPN_OP_DATA_V2 9
@@ -63,7 +57,7 @@ _IRQL_requires_max_(DISPATCH_LEVEL)
_Must_inspect_result_
typedef
NTSTATUS
-OVPN_CRYPTO_ENCRYPT(_In_ OvpnCryptoKeySlot* keySlot, _In_ UCHAR* buf, _In_ SIZE_T len);
+OVPN_CRYPTO_ENCRYPT(_In_ OvpnCryptoKeySlot* keySlot, _In_ UCHAR* buf, _In_ SIZE_T len, _In_ INT32 CryptoOptions);
typedef OVPN_CRYPTO_ENCRYPT* POVPN_CRYPTO_ENCRYPT;
_Function_class_(OVPN_CRYPTO_DECRYPT)
@@ -71,7 +65,7 @@ _IRQL_requires_max_(DISPATCH_LEVEL)
_Must_inspect_result_
typedef
NTSTATUS
-OVPN_CRYPTO_DECRYPT(_In_ OvpnCryptoKeySlot* keySlot, _In_ UCHAR* bufIn, _In_ SIZE_T len, _In_ UCHAR* bufOut);
+OVPN_CRYPTO_DECRYPT(_In_ OvpnCryptoKeySlot* keySlot, _In_ UCHAR* bufIn, _In_ SIZE_T len, _In_ UCHAR* bufOut, _In_ INT32 CryptoOptions);
typedef OVPN_CRYPTO_DECRYPT* POVPN_CRYPTO_DECRYPT;
struct OvpnCryptoContext
@@ -82,7 +76,7 @@ struct OvpnCryptoContext
POVPN_CRYPTO_ENCRYPT Encrypt;
POVPN_CRYPTO_DECRYPT Decrypt;
- SIZE_T CryptoOverhead;
+ INT32 CryptoOptions;
};
_Must_inspect_result_
@@ -99,7 +93,7 @@ OvpnCryptoUninit(_In_ OvpnCryptoContext* cryptoContext);
_Must_inspect_result_
NTSTATUS
-OvpnCryptoNewKey(_In_ OvpnCryptoContext* cryptoContext, _In_ POVPN_CRYPTO_DATA cryptoData, _In_opt_ BCRYPT_ALG_HANDLE algHandle);
+OvpnCryptoNewKey(_In_ OvpnCryptoContext* cryptoContext, _In_ POVPN_CRYPTO_DATA_V2 cryptoData, _In_opt_ BCRYPT_ALG_HANDLE algHandle);
_Must_inspect_result_
OvpnCryptoKeySlot*
@@ -119,4 +113,4 @@ static inline
UCHAR OvpnCryptoOpcodeExtract(UCHAR op)
{
return op >> OVPN_OPCODE_SHIFT;
-}
\ No newline at end of file
+}
diff --git a/peer.cpp b/peer.cpp
index 3314fee..a2fa688 100644
--- a/peer.cpp
+++ b/peer.cpp
@@ -303,6 +303,31 @@ OvpnPeerStartVPN(POVPN_DEVICE device)
return status;
}
+static NTSTATUS
+OvpnPeerGetAlgHandle(POVPN_DEVICE device, OVPN_CIPHER_ALG cipherAlg, BCRYPT_ALG_HANDLE& algHandle)
+{
+ NTSTATUS status = STATUS_SUCCESS;
+
+ switch (cipherAlg) {
+ case OVPN_CIPHER_ALG_AES_GCM:
+ algHandle = device->AesAlgHandle;
+ break;
+
+ case OVPN_CIPHER_ALG_CHACHA20_POLY1305:
+ algHandle = device->ChachaAlgHandle;
+ if (algHandle == NULL) {
+ LOG_ERROR("CHACHA20-POLY1305 is not available");
+ status = STATUS_INVALID_DEVICE_REQUEST;
+ }
+ break;
+
+ default:
+ break;
+ }
+
+ return status;
+}
+
_Use_decl_annotations_
NTSTATUS
OvpnPeerNewKey(POVPN_DEVICE device, WDFREQUEST request)
@@ -311,44 +336,63 @@ OvpnPeerNewKey(POVPN_DEVICE device, WDFREQUEST request)
NTSTATUS status = STATUS_SUCCESS;
+ POVPN_CRYPTO_DATA cryptoData = NULL;
+ OVPN_CRYPTO_DATA_V2 cryptoDataV2{};
+
if (!OvpnHasPeers(device)) {
LOG_ERROR("Peer not added");
status = STATUS_INVALID_DEVICE_REQUEST;
goto done;
}
- POVPN_CRYPTO_DATA cryptoData = NULL;
-
GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveInputBuffer(request, sizeof(OVPN_CRYPTO_DATA), (PVOID*)&cryptoData, nullptr));
BCRYPT_ALG_HANDLE algHandle = NULL;
- switch (cryptoData->CipherAlg) {
- case OVPN_CIPHER_ALG_AES_GCM:
- algHandle = device->AesAlgHandle;
- device->CryptoOverhead = AEAD_CRYPTO_OVERHEAD;
- break;
+ GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnPeerGetAlgHandle(device, cryptoData->CipherAlg, algHandle));
- case OVPN_CIPHER_ALG_CHACHA20_POLY1305:
- algHandle = device->ChachaAlgHandle;
- if (algHandle == NULL) {
- LOG_ERROR("CHACHA20-POLY1305 is not available");
- status = STATUS_INVALID_DEVICE_REQUEST;
- goto done;
- }
- device->CryptoOverhead = AEAD_CRYPTO_OVERHEAD;
+ OvpnPeerContext* peer = OvpnGetFirstPeer(&device->Peers);
+ if (peer == NULL) {
+ status = STATUS_OBJECTID_NOT_FOUND;
+ goto done;
+ }
- default:
- device->CryptoOverhead = NONE_CRYPTO_OVERHEAD;
- break;
+ RtlCopyMemory(&cryptoDataV2.V1, cryptoData, sizeof(OVPN_CRYPTO_DATA));
+ GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoNewKey(&peer->CryptoContext, &cryptoDataV2, algHandle));
+
+done:
+ LOG_EXIT();
+
+ return status;
+}
+
+_Use_decl_annotations_
+NTSTATUS
+OvpnPeerNewKeyV2(POVPN_DEVICE device, WDFREQUEST request)
+{
+ LOG_ENTER();
+
+ NTSTATUS status = STATUS_SUCCESS;
+
+ POVPN_CRYPTO_DATA_V2 cryptoDataV2 = NULL;
+
+ if (!OvpnHasPeers(device)) {
+ LOG_ERROR("Peer not added");
+ status = STATUS_INVALID_DEVICE_REQUEST;
+ goto done;
}
+ GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveInputBuffer(request, sizeof(OVPN_CRYPTO_DATA_V2), (PVOID*)&cryptoDataV2, nullptr));
+
+ BCRYPT_ALG_HANDLE algHandle = NULL;
+ GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnPeerGetAlgHandle(device, cryptoDataV2->V1.CipherAlg, algHandle));
+
OvpnPeerContext* peer = OvpnGetFirstPeer(&device->Peers);
if (peer == NULL) {
status = STATUS_OBJECTID_NOT_FOUND;
goto done;
}
- GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoNewKey(&peer->CryptoContext, cryptoData, algHandle));
+ GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoNewKey(&peer->CryptoContext, cryptoDataV2, algHandle));
done:
LOG_EXIT();
diff --git a/peer.h b/peer.h
index e1eb5d4..fc780f8 100644
--- a/peer.h
+++ b/peer.h
@@ -85,6 +85,11 @@ _Requires_exclusive_lock_held_(device->SpinLock)
NTSTATUS
OvpnPeerNewKey(_In_ POVPN_DEVICE device, WDFREQUEST request);
+_Must_inspect_result_
+_Requires_exclusive_lock_held_(device->SpinLock)
+NTSTATUS
+OvpnPeerNewKeyV2(_In_ POVPN_DEVICE device, WDFREQUEST request);
+
_Must_inspect_result_
_Requires_exclusive_lock_held_(device->SpinLock)
NTSTATUS
diff --git a/pktid.cpp b/pktid.cpp
index b7f365d..4d94ee3 100644
--- a/pktid.cpp
+++ b/pktid.cpp
@@ -28,24 +28,29 @@
#define PKTID_WRAP_WARN 0xf0000000ULL
_Use_decl_annotations_
-NTSTATUS OvpnPktidXmitNext(OvpnPktidXmit* px, UINT32* pktId)
+NTSTATUS OvpnPktidXmitNext(OvpnPktidXmit* px, VOID* pktId, BOOLEAN pktId64bit)
{
ULONG64 seqNum = InterlockedIncrementNoFence64(&px->SeqNum);
- *pktId = (UINT32)seqNum;
- if (seqNum < PKTID_WRAP_WARN) {
- return STATUS_SUCCESS;
- }
- else {
- LOG_ERROR("Pktid wrapped");
- return STATUS_INTEGER_OVERFLOW;
- }
+ if (pktId64bit) {
+ *static_cast(pktId) = seqNum;
+ }
+ else
+ {
+ *static_cast(pktId) = static_cast(seqNum);
+ if (seqNum >= PKTID_WRAP_WARN) {
+ LOG_ERROR("Pktid wrapped");
+ return STATUS_INTEGER_OVERFLOW;
+ }
+ }
+
+ return STATUS_SUCCESS;
}
#define PKTID_RECV_EXPIRE ((30 * WDF_TIMEOUT_TO_SEC) / KeQueryTimeIncrement())
_Use_decl_annotations_
-NTSTATUS OvpnPktidRecvVerify(OvpnPktidRecv* pr, UINT32 pktId)
+NTSTATUS OvpnPktidRecvVerify(OvpnPktidRecv* pr, UINT64 pktId)
{
LARGE_INTEGER now;
KeQueryTickCount(&now);
@@ -69,16 +74,16 @@ NTSTATUS OvpnPktidRecvVerify(OvpnPktidRecv* pr, UINT32 pktId)
}
else if (pktId > pr->Id) {
/* ID jumped forward by more than one */
- UINT32 delta = pktId - pr->Id;
+ const auto delta = pktId - pr->Id;
if (delta < REPLAY_WINDOW_SIZE) {
pr->Base = REPLAY_INDEX(pr->Base, -(INT32)delta);
pr->History[pr->Base / 8] |= (1 << (pr->Base % 8));
- pr->Extent += delta;
+ pr->Extent += static_cast(delta);
if (pr->Extent > REPLAY_WINDOW_SIZE)
pr->Extent = REPLAY_WINDOW_SIZE;
- for (UINT32 i = 1; i < delta; ++i) {
- unsigned int newb = REPLAY_INDEX(pr->Base, i);
+ for (auto i = 1; i < delta; ++i) {
+ const auto newb = REPLAY_INDEX(pr->Base, i);
pr->History[newb / 8] &= ~BIT(newb % 8);
}
@@ -93,10 +98,8 @@ NTSTATUS OvpnPktidRecvVerify(OvpnPktidRecv* pr, UINT32 pktId)
}
else {
/* ID backtrack */
- UINT32 delta = pr->Id - pktId;
+ const auto delta = pr->Id - pktId;
- if (delta > pr->MaxBacktrack)
- pr->MaxBacktrack = delta;
if (delta < pr->Extent) {
if (pktId > pr->IdFloor) {
UINT32 ri = REPLAY_INDEX(pr->Base, delta);
diff --git a/pktid.h b/pktid.h
index b0d2325..dcc4be8 100644
--- a/pktid.h
+++ b/pktid.h
@@ -50,17 +50,17 @@ struct OvpnPktidRecv
LARGE_INTEGER Expire;
/* highest sequence number received */
- UINT32 Id;
+ UINT64 Id;
/* we will only accept backtrack IDs > id_floor */
- UINT32 IdFloor;
- UINT32 MaxBacktrack;
+ UINT64 IdFloor;
};
/* Get the next packet ID for xmit */
-NTSTATUS OvpnPktidXmitNext(_In_ OvpnPktidXmit* px, _Out_ UINT32* pktId);
+NTSTATUS OvpnPktidXmitNext(_In_ OvpnPktidXmit* px, _Out_ VOID* pktId, BOOLEAN pktId64bit);
+
/* Packet replay detection.
* Allows ID backtrack of up to REPLAY_WINDOW_SIZE - 1.
*/
-NTSTATUS OvpnPktidRecvVerify(_In_ OvpnPktidRecv* pid, UINT32 pktId);
+NTSTATUS OvpnPktidRecvVerify(_In_ OvpnPktidRecv* pid, UINT64 pktId);
diff --git a/rxqueue.cpp b/rxqueue.cpp
index ce7d71b..d7687c3 100644
--- a/rxqueue.cpp
+++ b/rxqueue.cpp
@@ -27,6 +27,7 @@
#include "driver.h"
#include "bufferpool.h"
+#include "peer.h"
#include "rxqueue.h"
#include "netringiterator.h"
#include "trace.h"
@@ -101,6 +102,16 @@ OvpnEvtRxQueueAdvance(NETPACKETQUEUE netPacketQueue)
POVPN_RXQUEUE queue = OvpnGetRxQueueContext(netPacketQueue);
OVPN_DEVICE* device = OvpnGetDeviceContext(queue->Adapter->WdfDevice);
+ OvpnPeerContext* peer = OvpnGetFirstPeer(&device->Peers);
+ if (peer == NULL) {
+ LOG_WARN("No peer");
+ return;
+ }
+
+ BOOLEAN pktId64bit = peer->CryptoContext.CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID;
+ BOOLEAN aeadTagEnd = peer->CryptoContext.CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END;
+ auto payloadOffset = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN);
+
NET_RING_FRAGMENT_ITERATOR fi = NetRingGetAllFragments(queue->Rings);
NET_RING_PACKET_ITERATOR pi = NetRingGetAllPackets(queue->Rings);
while (NetFragmentIteratorHasAny(&fi)) {
@@ -115,7 +126,7 @@ OvpnEvtRxQueueAdvance(NETPACKETQUEUE netPacketQueue)
fragment->ValidLength = buffer->Len;
fragment->Offset = 0;
NET_FRAGMENT_VIRTUAL_ADDRESS* virtualAddr = NetExtensionGetFragmentVirtualAddress(&queue->VirtualAddressExtension, NetFragmentIteratorGetIndex(&fi));
- RtlCopyMemory(virtualAddr->VirtualAddress, buffer->Data + device->CryptoOverhead, buffer->Len);
+ RtlCopyMemory(virtualAddr->VirtualAddress, buffer->Data + payloadOffset, buffer->Len);
InterlockedExchangeAddNoFence64(&device->Stats.TunBytesReceived, buffer->Len);
diff --git a/socket.cpp b/socket.cpp
index 1b3ca42..2fcf2d7 100644
--- a/socket.cpp
+++ b/socket.cpp
@@ -178,9 +178,11 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_
return;
}
- if (peer->CryptoContext.Decrypt) {
+ OvpnCryptoContext* cryptoContext = &peer->CryptoContext;
+
+ if (cryptoContext->Decrypt) {
UCHAR keyId = OvpnCryptoKeyIdExtract(op);
- OvpnCryptoKeySlot* keySlot = OvpnCryptoKeySlotFromKeyId(&peer->CryptoContext, keyId);
+ OvpnCryptoKeySlot* keySlot = OvpnCryptoKeySlotFromKeyId(cryptoContext, keyId);
if (!keySlot) {
status = STATUS_INVALID_DEVICE_STATE;
@@ -188,8 +190,11 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_
}
else {
// decrypt into plaintext buffer
- status = peer->CryptoContext.Decrypt(keySlot, cipherTextBuf, len, buffer->Data);
- buffer->Len = len - device->CryptoOverhead;
+ status = cryptoContext->Decrypt(keySlot, cipherTextBuf, len, buffer->Data, cryptoContext->CryptoOptions);
+
+ auto pktId64bit = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID;
+ auto cryptoOverhead = OVPN_DATA_V2_LEN + AEAD_AUTH_TAG_LEN + (pktId64bit ? 8 : 4);
+ buffer->Len = len - cryptoOverhead;
}
}
else {
@@ -206,20 +211,23 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_
OvpnTimerResetRecv(peer->Timer);
// points to the beginning of plaintext
- UCHAR* buf = buffer->Data + device->CryptoOverhead;
+ BOOLEAN pktId64bit = peer->CryptoContext.CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID;
+ BOOLEAN aeadTagEnd = peer->CryptoContext.CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END;
+ auto payloadOffset = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN);
+ UCHAR* plaintext = buffer->Data + payloadOffset;
// ping packet?
- if (OvpnTimerIsKeepaliveMessage(buf, buffer->Len)) {
+ if (OvpnTimerIsKeepaliveMessage(plaintext, buffer->Len)) {
LOG_INFO("Ping received");
// no need to inject ping packet into OS, return buffer to the pool
OvpnRxBufferPoolPut(buffer);
}
else {
- if (OvpnMssIsIPv4(buf, buffer->Len)) {
- OvpnMssDoIPv4(buf, buffer->Len, device->MSS);
- } else if (OvpnMssIsIPv6(buf, buffer->Len)) {
- OvpnMssDoIPv6(buf, buffer->Len, device->MSS);
+ if (OvpnMssIsIPv4(plaintext, buffer->Len)) {
+ OvpnMssDoIPv4(plaintext, buffer->Len, device->MSS);
+ } else if (OvpnMssIsIPv6(plaintext, buffer->Len)) {
+ OvpnMssDoIPv6(plaintext, buffer->Len, device->MSS);
}
// enqueue plaintext buffer, it will be dequeued by NetAdapter RX datapath
diff --git a/timer.cpp b/timer.cpp
index 6ce4c95..95eeaa1 100644
--- a/timer.cpp
+++ b/timer.cpp
@@ -73,12 +73,16 @@ static VOID OvpnTimerXmit(WDFTIMER timer)
OvpnPeerContext* peer = timerCtx->Peer;
KIRQL kiqrl = ExAcquireSpinLockShared(&device->SpinLock);
- if (peer->CryptoContext.Encrypt) {
+ OvpnCryptoContext* cryptoContext = &peer->CryptoContext;
+ if (cryptoContext->Encrypt) {
// make space to crypto overhead
- OvpnTxBufferPush(buffer, device->CryptoOverhead);
+ BOOLEAN pktId64bit = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID;
+ BOOLEAN aeadTagEnd = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END;
+
+ OvpnTxBufferPush(buffer, OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN));
// in-place encrypt, always with primary key
- status = peer->CryptoContext.Encrypt(&peer->CryptoContext.Primary, buffer->Data, buffer->Len);
+ status = cryptoContext->Encrypt(&cryptoContext->Primary, buffer->Data, buffer->Len, cryptoContext->CryptoOptions);
}
else {
status = STATUS_INVALID_DEVICE_STATE;
diff --git a/txqueue.cpp b/txqueue.cpp
index 7f49038..feec4c5 100644
--- a/txqueue.cpp
+++ b/txqueue.cpp
@@ -93,12 +93,21 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET
InterlockedExchangeAddNoFence64(&device->Stats.TunBytesSent, buffer->Len);
- if (peer->CryptoContext.Encrypt) {
+ OvpnCryptoContext* cryptoContext = &peer->CryptoContext;
+
+ if (cryptoContext->Encrypt) {
+ auto aeadTagEnd = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END;
+ auto pktId64bit = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID;
+
// make space to crypto overhead
- OvpnTxBufferPush(buffer, device->CryptoOverhead);
+ OvpnTxBufferPush(buffer, OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN));
+ if (aeadTagEnd)
+ {
+ OvpnTxBufferPut(buffer, AEAD_AUTH_TAG_LEN);
+ }
// in-place encrypt, always with primary key
- status = peer->CryptoContext.Encrypt(&peer->CryptoContext.Primary, buffer->Data, buffer->Len);
+ status = cryptoContext->Encrypt(&cryptoContext->Primary, buffer->Data, buffer->Len, cryptoContext->CryptoOptions);
}
else {
status = STATUS_INVALID_DEVICE_STATE;
@@ -152,7 +161,7 @@ OvpnEvtTxQueueAdvance(NETPACKETQUEUE netPacketQueue)
POVPN_TXQUEUE queue = OvpnGetTxQueueContext(netPacketQueue);
NET_RING_PACKET_ITERATOR pi = NetRingGetAllPackets(queue->Rings);
POVPN_DEVICE device = OvpnGetDeviceContext(queue->Adapter->WdfDevice);
- bool packetSent = false;
+ BOOLEAN packetSent = false;
KIRQL kirql = ExAcquireSpinLockShared(&device->SpinLock);
diff --git a/uapi/ovpn-dco.h b/uapi/ovpn-dco.h
index ea2a733..9e437f0 100644
--- a/uapi/ovpn-dco.h
+++ b/uapi/ovpn-dco.h
@@ -94,6 +94,14 @@ typedef struct _OVPN_CRYPTO_DATA {
int PeerId;
} OVPN_CRYPTO_DATA, * POVPN_CRYPTO_DATA;
+#define CRYPTO_OPTIONS_AEAD_TAG_END (1<<1)
+#define CRYPTO_OPTIONS_64BIT_PKTID (1<<2)
+
+typedef struct _OVPN_CRYPTO_DATA_V2 {
+ OVPN_CRYPTO_DATA V1;
+ UINT32 CryptoOptions;
+} OVPN_CRYPTO_DATA_V2, * POVPN_CRYPTO_DATA_V2;
+
typedef struct _OVPN_SET_PEER {
LONG KeepaliveInterval;
LONG KeepaliveTimeout;
@@ -114,3 +122,4 @@ typedef struct _OVPN_VERSION {
#define OVPN_IOCTL_START_VPN CTL_CODE(FILE_DEVICE_UNKNOWN, 6, METHOD_BUFFERED, FILE_ANY_ACCESS)
#define OVPN_IOCTL_DEL_PEER CTL_CODE(FILE_DEVICE_UNKNOWN, 7, METHOD_BUFFERED, FILE_ANY_ACCESS)
#define OVPN_IOCTL_GET_VERSION CTL_CODE(FILE_DEVICE_UNKNOWN, 8, METHOD_BUFFERED, FILE_ANY_ACCESS)
+#define OVPN_IOCTL_NEW_KEY_V2 CTL_CODE(FILE_DEVICE_UNKNOWN, 9, METHOD_BUFFERED, FILE_ANY_ACCESS)