Skip to content

Commit

Permalink
[build] support mbedtls 3.x
Browse files Browse the repository at this point in the history
Support building with mbedtls 3.x which is required for different
platforms.
  • Loading branch information
wgtdkp committed Apr 2, 2024
1 parent f63ec73 commit dab686d
Show file tree
Hide file tree
Showing 15 changed files with 331 additions and 90 deletions.
11 changes: 11 additions & 0 deletions include/commissioner/error.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#define OT_COMM_ERROR_HPP_

#include <memory>
#include <ostream>
#include <string>

#include <commissioner/defines.hpp>
Expand Down Expand Up @@ -272,6 +273,16 @@ inline bool operator!=(const ErrorCode &aErrorCode, const Error &aError)
return !(aErrorCode == aError);
}

/**
* Allows pretty-print in unit tests.
*
* See https://google.github.io/googletest/advanced.html#teaching-googletest-how-to-print-your-values
*/
inline void PrintTo(const Error &error, std::ostream *os)
{
*os << error.ToString();
}

} // namespace commissioner

} // namespace ot
Expand Down
72 changes: 66 additions & 6 deletions src/library/coap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,7 @@ Error Message::AppendOption(OptionType aNumber, const OptionValue &aValue)
{
Error error;

VerifyOrExit(IsValidOption(aNumber, aValue),
error = ERROR_INVALID_ARGS("invalid CoAP option (number={})", aNumber));
VerifyOrExit(IsValidOption(aNumber, aValue), error = ERROR_INVALID_ARGS("invalid CoAP option {}", aNumber));

if (aNumber == OptionType::kUriPath)
{
Expand All @@ -219,7 +218,7 @@ Error Message::GetOption(std::string &aValue, OptionType aNumber) const
Error error;
auto option = GetOption(aNumber);

VerifyOrExit(option != nullptr, error = ERROR_NOT_FOUND("CoAP option (number={}) not found", aNumber));
VerifyOrExit(option != nullptr, error = ERROR_NOT_FOUND("CoAP option {} not found", aNumber));

aValue = option->GetStringValue();

Expand All @@ -232,7 +231,7 @@ Error Message::GetOption(uint32_t &aValue, OptionType aNumber) const
Error error;
auto option = GetOption(aNumber);

VerifyOrExit(option != nullptr, error = ERROR_NOT_FOUND("CoAP option (number={}) not found", aNumber));
VerifyOrExit(option != nullptr, error = ERROR_NOT_FOUND("CoAP option {} not found", aNumber));

aValue = option->GetUint32Value();

Expand All @@ -245,7 +244,7 @@ Error Message::GetOption(ByteArray &aValue, OptionType aNumber) const
Error error;
auto option = GetOption(aNumber);

VerifyOrExit(option != nullptr, error = ERROR_NOT_FOUND("CoAP option (number={}) not found", aNumber));
VerifyOrExit(option != nullptr, error = ERROR_NOT_FOUND("CoAP option {} not found", aNumber));

aValue = option->GetOpaqueValue();

Expand Down Expand Up @@ -355,7 +354,7 @@ Error Message::Serialize(OptionType aOptionNumber,
VerifyOrDie(utils::to_underlying(aOptionNumber) >= aLastOptionNumber);

VerifyOrExit(IsValidOption(aOptionNumber, aOptionValue),
error = ERROR_INVALID_ARGS("option (number={}) is not valid", aOptionNumber));
error = ERROR_INVALID_ARGS("option {} is not valid", aOptionNumber));

length = 1;
length += delta < kOption1ByteExtension ? 0 : (delta < kOption2ByteExtension ? 1 : 2);
Expand Down Expand Up @@ -1287,3 +1286,64 @@ std::string Message::GetRequestUri(void) const
} // namespace commissioner

} // namespace ot

auto fmt::formatter<ot::commissioner::coap::OptionType>::format(ot::commissioner::coap::OptionType optionType,
format_context &ctx) -> decltype(ctx.out())
{
using ot::commissioner::coap::OptionType;
string_view name;
switch (optionType)
{
case OptionType::kIfMatch:
name = "kIfMatch";
break;
case OptionType::kUriHost:
name = "kUriHost";
break;
case OptionType::kETag:
name = "kETag";
break;
case OptionType::kIfNonMatch:
name = "kIfNonMatch";
break;
case OptionType::kObserve:
name = "kObserve";
break;
case OptionType::kUriPort:
name = "kUriPort";
break;
case OptionType::kLocationPath:
name = "kLocationPath";
break;
case OptionType::kUriPath:
name = "kUriPath";
break;
case OptionType::kContentFormat:
name = "kContentFormat";
break;
case OptionType::kMaxAge:
name = "kMaxAge";
break;
case OptionType::kUriQuery:
name = "kUriQuery";
break;
case OptionType::kAccept:
name = "kAccept";
break;
case OptionType::kLocationQuery:
name = "kLocationQuery";
break;
case OptionType::kProxyUri:
name = "kProxyUri";
break;
case OptionType::kProxyScheme:
name = "kProxyScheme";
break;
case OptionType::kSize1:
name = "kSize1";
break;
default:
name = "unknown";
}
return formatter<string_view>::format(name, ctx);
}
8 changes: 8 additions & 0 deletions src/library/coap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
#include <queue>
#include <set>

#include <fmt/format.h>

#include <commissioner/defines.hpp>
#include <commissioner/error.hpp>

Expand Down Expand Up @@ -781,4 +783,10 @@ class Coap

} // namespace ot

/** Makes `OptionType` formattable as a string. */
template <> struct fmt::formatter<ot::commissioner::coap::OptionType> : formatter<string_view>
{
auto format(ot::commissioner::coap::OptionType optionType, format_context &ctx) -> decltype(ctx.out());
};

#endif // OT_COMM_LIBRARY_COAP_HPP_
1 change: 1 addition & 0 deletions src/library/cose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
*/

#if OT_COMM_CONFIG_CCM_ENABLE
#define MBEDTLS_ALLOW_PRIVATE_ACCESS

#include "library/cose.hpp"

Expand Down
29 changes: 20 additions & 9 deletions src/library/cose_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,20 @@ static const char kPrivateKey[] = "-----BEGIN PRIVATE KEY-----\r\n"
"Xvr27euqi54WCMXJEMk6IIaPyFBNNw8bJvqXWfZ5g7t4hj7amsvqUST2\r\n"
"-----END PRIVATE KEY-----\r\n";

static Error ParsePublicKey(mbedtls_pk_context &aPublicKey, const ByteArray &aCert)
{
TokenManager tokenManager{event_base_new()};

return tokenManager.ParsePublicKey(aPublicKey, aCert);
}

static Error ParsePrivateKey(mbedtls_pk_context &aPrivateKey, const ByteArray &aPrivateKeyRaw)
{
TokenManager tokenManager{event_base_new()};

return tokenManager.ParsePrivateKey(aPrivateKey, aPrivateKeyRaw);
}

TEST(CoseTest, CoseSignAndVerify_SignWithoutExternalData)
{
ByteArray content{1, 2, 3, 4, 5, 6};
Expand All @@ -76,10 +90,9 @@ TEST(CoseTest, CoseSignAndVerify_SignWithoutExternalData)
mbedtls_pk_init(&publicKey);
mbedtls_pk_init(&privateKey);

EXPECT_EQ(TokenManager::ParsePublicKey(publicKey, ByteArray{kCertificate, kCertificate + sizeof(kCertificate)}),
ErrorCode::kNone);
EXPECT_EQ(TokenManager::ParsePrivateKey(privateKey, ByteArray{kPrivateKey, kPrivateKey + sizeof(kPrivateKey)}),
EXPECT_EQ(ParsePublicKey(publicKey, ByteArray{kCertificate, kCertificate + sizeof(kCertificate)}),
ErrorCode::kNone);
EXPECT_EQ(ParsePrivateKey(privateKey, ByteArray{kPrivateKey, kPrivateKey + sizeof(kPrivateKey)}), ErrorCode::kNone);

ByteArray signature;
Sign1Message msg;
Expand Down Expand Up @@ -107,10 +120,9 @@ TEST(CoseTest, CoseSignAndVerify_SignWithExternalData)
mbedtls_pk_init(&publicKey);
mbedtls_pk_init(&privateKey);

EXPECT_EQ(TokenManager::ParsePublicKey(publicKey, ByteArray{kCertificate, kCertificate + sizeof(kCertificate)}),
ErrorCode::kNone);
EXPECT_EQ(TokenManager::ParsePrivateKey(privateKey, ByteArray{kPrivateKey, kPrivateKey + sizeof(kPrivateKey)}),
EXPECT_EQ(ParsePublicKey(publicKey, ByteArray{kCertificate, kCertificate + sizeof(kCertificate)}),
ErrorCode::kNone);
EXPECT_EQ(ParsePrivateKey(privateKey, ByteArray{kPrivateKey, kPrivateKey + sizeof(kPrivateKey)}), ErrorCode::kNone);

ByteArray signature;
Sign1Message msg;
Expand Down Expand Up @@ -140,10 +152,9 @@ TEST(CoseTest, CoseSignAndVerify_KeyConstruction)
mbedtls_pk_init(&publicKey);
mbedtls_pk_init(&privateKey);

EXPECT_EQ(TokenManager::ParsePublicKey(publicKey, ByteArray{kCertificate, kCertificate + sizeof(kCertificate)}),
ErrorCode::kNone);
EXPECT_EQ(TokenManager::ParsePrivateKey(privateKey, ByteArray{kPrivateKey, kPrivateKey + sizeof(kPrivateKey)}),
EXPECT_EQ(ParsePublicKey(publicKey, ByteArray{kCertificate, kCertificate + sizeof(kCertificate)}),
ErrorCode::kNone);
EXPECT_EQ(ParsePrivateKey(privateKey, ByteArray{kPrivateKey, kPrivateKey + sizeof(kPrivateKey)}), ErrorCode::kNone);

ByteArray keyId = {};
ByteArray encodedCoseKey;
Expand Down
74 changes: 48 additions & 26 deletions src/library/dtls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
* This file implements wrapper of mbedtls.
*/

#define MBEDTLS_ALLOW_PRIVATE_ACCESS

#include "library/dtls.hpp"

#include <mbedtls/debug.h>
Expand All @@ -48,7 +50,7 @@ namespace ot {
namespace commissioner {

static const int kAuthMode = MBEDTLS_SSL_VERIFY_REQUIRED;
static const size_t kMaxContentLength = MBEDTLS_SSL_MAX_CONTENT_LEN;
static const size_t kMaxContentLength = MBEDTLS_SSL_IN_CONTENT_LEN;
static const size_t KMaxFragmentLengthCode = MBEDTLS_SSL_MAX_FRAG_LEN_1024;
static const size_t kMaxTransmissionUnit = 1280;

Expand Down Expand Up @@ -110,8 +112,9 @@ void DtlsSession::InitMbedtls()
{
mbedtls_ssl_config_init(&mConfig);
mbedtls_ssl_cookie_init(&mCookie);
mbedtls_ctr_drbg_init(&mCtrDrbg);
mbedtls_entropy_init(&mEntropy);
mbedtls_ctr_drbg_init(&mCtrDrbg);
mbedtls_ctr_drbg_seed(&mCtrDrbg, mbedtls_entropy_func, &mEntropy, nullptr, 0);
mbedtls_ssl_init(&mSsl);

mbedtls_x509_crt_init(&mCaChain);
Expand Down Expand Up @@ -175,7 +178,8 @@ Error DtlsSession::Init(const DtlsConfig &aConfig)
{
ExitNow(error = ERROR_INVALID_ARGS("bad certificate; {}", ErrorFromMbedtlsError(fail).GetMessage()));
}
if (int fail = mbedtls_pk_parse_key(&mOwnKey, &aConfig.mOwnKey[0], aConfig.mOwnKey.size(), nullptr, 0))
if (int fail = mbedtls_pk_parse_key(&mOwnKey, &aConfig.mOwnKey[0], aConfig.mOwnKey.size(), nullptr, 0,
mbedtls_ctr_drbg_random, &mCtrDrbg))
{
ExitNow(error = ERROR_INVALID_ARGS("bad private key; {}", ErrorFromMbedtlsError(fail).GetMessage()));
}
Expand All @@ -192,13 +196,7 @@ Error DtlsSession::Init(const DtlsConfig &aConfig)
mCipherSuites.push_back(0);
mbedtls_ssl_conf_ciphersuites(&mConfig, &mCipherSuites[0]);

mbedtls_ssl_conf_export_keys_cb(&mConfig, HandleMbedtlsExportKeys, this);

// RNG & Entropy
if (int fail = mbedtls_ctr_drbg_seed(&mCtrDrbg, mbedtls_entropy_func, &mEntropy, nullptr, 0))
{
ExitNow(error = ErrorFromMbedtlsError(fail));
}
mbedtls_ssl_set_export_keys_cb(&mSsl, HandleMbedtlsExportKeys, this);
mbedtls_ssl_conf_rng(&mConfig, mbedtls_ctr_drbg_random, &mCtrDrbg);

// Cookie
Expand Down Expand Up @@ -335,35 +333,59 @@ std::string DtlsSession::GetStateString() const
return stateString;
}

int DtlsSession::HandleMbedtlsExportKeys(void *aDtlsSession,
const unsigned char *aMasterSecret,
const unsigned char *aKeyBlock,
size_t aMacLength,
size_t aKeyLength,
size_t aIvLength)
#if OT_COMM_CONFIG_CCM_ENABLE
const mbedtls_x509_crt *DtlsSession::GetPeerCertificate() const
{
return mSsl.session ? mSsl.session->peer_cert : nullptr;
}
#endif

void DtlsSession::HandleMbedtlsExportKeys(void *aDtlsSession,
mbedtls_ssl_key_export_type aType,
const unsigned char *aMasterSecret,
size_t aMasterSecretLen,
const unsigned char aClientRandom[32],
const unsigned char aServerRandom[32],
mbedtls_tls_prf_types aTlsPrfType)
{
auto dtlsSession = reinterpret_cast<DtlsSession *>(aDtlsSession);
return dtlsSession->HandleMbedtlsExportKeys(aMasterSecret, aKeyBlock, aMacLength, aKeyLength, aIvLength);
dtlsSession->HandleMbedtlsExportKeys(aType, aMasterSecret, aMasterSecretLen, aClientRandom, aServerRandom,
aTlsPrfType);
}

int DtlsSession::HandleMbedtlsExportKeys(const unsigned char *,
const unsigned char *aKeyBlock,
size_t aMacLength,
size_t aKeyLength,
size_t aIvLength)
void DtlsSession::HandleMbedtlsExportKeys(mbedtls_ssl_key_export_type aType,
const unsigned char *aMasterSecret,
size_t aMasterSecretLen,
const unsigned char aClientRandom[32],
const unsigned char aServerRandom[32],
mbedtls_tls_prf_types aTlsPrfType)
{
Sha256 sha256;
Sha256 sha256;
unsigned char keyBlock[kKeyBlockSize];
unsigned char randBytes[2 * kRandomBufferSize];

VerifyOrExit(!mPSK.empty());
VerifyOrExit(aType == MBEDTLS_SSL_KEY_EXPORT_TLS12_MASTER_SECRET);

memcpy(randBytes, aServerRandom, kRandomBufferSize);
memcpy(randBytes + kRandomBufferSize, aClientRandom, kRandomBufferSize);

// Retrieve the Key block from Master secret
mbedtls_ssl_tls_prf(aTlsPrfType, aMasterSecret, aMasterSecretLen, "key expansion", randBytes, sizeof(randBytes),
keyBlock, sizeof(keyBlock));

sha256.Start();
sha256.Update(aKeyBlock, 2 * static_cast<uint16_t>(aMacLength + aKeyLength + aIvLength));
sha256.Update(keyBlock, kKeyBlockSize);

mKek.resize(Sha256::kHashSize);
sha256.Finish(&mKek[0]);
sha256.Finish(mKek.data());

static_assert(Sha256::kHashSize >= kJoinerRouterKekLength, "Sha256::kHashSize >= kJoinerRouterKekLength");
mKek.resize(kJoinerRouterKekLength);

return 0;
LOG_DEBUG(LOG_REGION_DTLS, "Generated KEK");

exit:;
}

void DtlsSession::HandleEvent(short aFlags)
Expand Down
Loading

0 comments on commit dab686d

Please sign in to comment.