Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[build] support mbedtls 3.x #249

Merged
merged 1 commit into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions include/commissioner/error.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#define OT_COMM_ERROR_HPP_

#include <memory>
#include <ostream>
#include <string>

#include <commissioner/defines.hpp>
Expand Down Expand Up @@ -272,6 +273,16 @@ inline bool operator!=(const ErrorCode &aErrorCode, const Error &aError)
return !(aErrorCode == aError);
}

/**
* Allows pretty-print in unit tests.
*
* See https://google.github.io/googletest/advanced.html#teaching-googletest-how-to-print-your-values
*/
inline void PrintTo(const Error &error, std::ostream *os)
{
*os << error.ToString();
}

} // namespace commissioner

} // namespace ot
Expand Down
72 changes: 66 additions & 6 deletions src/library/coap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,7 @@ Error Message::AppendOption(OptionType aNumber, const OptionValue &aValue)
{
Error error;

VerifyOrExit(IsValidOption(aNumber, aValue),
error = ERROR_INVALID_ARGS("invalid CoAP option (number={})", aNumber));
VerifyOrExit(IsValidOption(aNumber, aValue), error = ERROR_INVALID_ARGS("invalid CoAP option {}", aNumber));

if (aNumber == OptionType::kUriPath)
{
Expand All @@ -219,7 +218,7 @@ Error Message::GetOption(std::string &aValue, OptionType aNumber) const
Error error;
auto option = GetOption(aNumber);

VerifyOrExit(option != nullptr, error = ERROR_NOT_FOUND("CoAP option (number={}) not found", aNumber));
VerifyOrExit(option != nullptr, error = ERROR_NOT_FOUND("CoAP option {} not found", aNumber));

aValue = option->GetStringValue();

Expand All @@ -232,7 +231,7 @@ Error Message::GetOption(uint32_t &aValue, OptionType aNumber) const
Error error;
auto option = GetOption(aNumber);

VerifyOrExit(option != nullptr, error = ERROR_NOT_FOUND("CoAP option (number={}) not found", aNumber));
VerifyOrExit(option != nullptr, error = ERROR_NOT_FOUND("CoAP option {} not found", aNumber));

aValue = option->GetUint32Value();

Expand All @@ -245,7 +244,7 @@ Error Message::GetOption(ByteArray &aValue, OptionType aNumber) const
Error error;
auto option = GetOption(aNumber);

VerifyOrExit(option != nullptr, error = ERROR_NOT_FOUND("CoAP option (number={}) not found", aNumber));
VerifyOrExit(option != nullptr, error = ERROR_NOT_FOUND("CoAP option {} not found", aNumber));

aValue = option->GetOpaqueValue();

Expand Down Expand Up @@ -355,7 +354,7 @@ Error Message::Serialize(OptionType aOptionNumber,
VerifyOrDie(utils::to_underlying(aOptionNumber) >= aLastOptionNumber);

VerifyOrExit(IsValidOption(aOptionNumber, aOptionValue),
error = ERROR_INVALID_ARGS("option (number={}) is not valid", aOptionNumber));
error = ERROR_INVALID_ARGS("option {} is not valid", aOptionNumber));

length = 1;
length += delta < kOption1ByteExtension ? 0 : (delta < kOption2ByteExtension ? 1 : 2);
Expand Down Expand Up @@ -1287,3 +1286,64 @@ std::string Message::GetRequestUri(void) const
} // namespace commissioner

} // namespace ot

auto fmt::formatter<ot::commissioner::coap::OptionType>::format(ot::commissioner::coap::OptionType optionType,
format_context &ctx) -> decltype(ctx.out())
{
using ot::commissioner::coap::OptionType;
string_view name;
switch (optionType)
{
case OptionType::kIfMatch:
name = "kIfMatch";
break;
case OptionType::kUriHost:
name = "kUriHost";
break;
case OptionType::kETag:
name = "kETag";
break;
case OptionType::kIfNonMatch:
name = "kIfNonMatch";
break;
case OptionType::kObserve:
name = "kObserve";
break;
case OptionType::kUriPort:
name = "kUriPort";
break;
case OptionType::kLocationPath:
name = "kLocationPath";
break;
case OptionType::kUriPath:
name = "kUriPath";
break;
case OptionType::kContentFormat:
name = "kContentFormat";
break;
case OptionType::kMaxAge:
name = "kMaxAge";
break;
case OptionType::kUriQuery:
name = "kUriQuery";
break;
case OptionType::kAccept:
name = "kAccept";
break;
case OptionType::kLocationQuery:
name = "kLocationQuery";
break;
case OptionType::kProxyUri:
name = "kProxyUri";
break;
case OptionType::kProxyScheme:
name = "kProxyScheme";
break;
case OptionType::kSize1:
name = "kSize1";
break;
default:
name = "unknown";
}
return formatter<string_view>::format(name, ctx);
}
8 changes: 8 additions & 0 deletions src/library/coap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
#include <queue>
#include <set>

#include <fmt/format.h>

#include <commissioner/defines.hpp>
#include <commissioner/error.hpp>

Expand Down Expand Up @@ -781,4 +783,10 @@ class Coap

} // namespace ot

/** Makes `OptionType` formattable as a string. */
template <> struct fmt::formatter<ot::commissioner::coap::OptionType> : formatter<string_view>
{
auto format(ot::commissioner::coap::OptionType optionType, format_context &ctx) -> decltype(ctx.out());
};

#endif // OT_COMM_LIBRARY_COAP_HPP_
5 changes: 5 additions & 0 deletions src/library/cose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@
*/

#if OT_COMM_CONFIG_CCM_ENABLE
// Private fields in mbedtls structs are not stable API after 3.x.
// Enabling this flag to allow us continue accessing those fields.
#ifndef MBEDTLS_ALLOW_PRIVATE_ACCESS
#define MBEDTLS_ALLOW_PRIVATE_ACCESS
#endif

#include "library/cose.hpp"

Expand Down
29 changes: 20 additions & 9 deletions src/library/cose_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,20 @@ static const char kPrivateKey[] = "-----BEGIN PRIVATE KEY-----\r\n"
"Xvr27euqi54WCMXJEMk6IIaPyFBNNw8bJvqXWfZ5g7t4hj7amsvqUST2\r\n"
"-----END PRIVATE KEY-----\r\n";

static Error ParsePublicKey(mbedtls_pk_context &aPublicKey, const ByteArray &aCert)
{
TokenManager tokenManager{event_base_new()};

return tokenManager.ParsePublicKey(aPublicKey, aCert);
}

static Error ParsePrivateKey(mbedtls_pk_context &aPrivateKey, const ByteArray &aPrivateKeyRaw)
{
TokenManager tokenManager{event_base_new()};

return tokenManager.ParsePrivateKey(aPrivateKey, aPrivateKeyRaw);
}

TEST(CoseTest, CoseSignAndVerify_SignWithoutExternalData)
{
ByteArray content{1, 2, 3, 4, 5, 6};
Expand All @@ -76,10 +90,9 @@ TEST(CoseTest, CoseSignAndVerify_SignWithoutExternalData)
mbedtls_pk_init(&publicKey);
mbedtls_pk_init(&privateKey);

EXPECT_EQ(TokenManager::ParsePublicKey(publicKey, ByteArray{kCertificate, kCertificate + sizeof(kCertificate)}),
ErrorCode::kNone);
EXPECT_EQ(TokenManager::ParsePrivateKey(privateKey, ByteArray{kPrivateKey, kPrivateKey + sizeof(kPrivateKey)}),
EXPECT_EQ(ParsePublicKey(publicKey, ByteArray{kCertificate, kCertificate + sizeof(kCertificate)}),
ErrorCode::kNone);
EXPECT_EQ(ParsePrivateKey(privateKey, ByteArray{kPrivateKey, kPrivateKey + sizeof(kPrivateKey)}), ErrorCode::kNone);

ByteArray signature;
Sign1Message msg;
Expand Down Expand Up @@ -107,10 +120,9 @@ TEST(CoseTest, CoseSignAndVerify_SignWithExternalData)
mbedtls_pk_init(&publicKey);
mbedtls_pk_init(&privateKey);

EXPECT_EQ(TokenManager::ParsePublicKey(publicKey, ByteArray{kCertificate, kCertificate + sizeof(kCertificate)}),
ErrorCode::kNone);
EXPECT_EQ(TokenManager::ParsePrivateKey(privateKey, ByteArray{kPrivateKey, kPrivateKey + sizeof(kPrivateKey)}),
EXPECT_EQ(ParsePublicKey(publicKey, ByteArray{kCertificate, kCertificate + sizeof(kCertificate)}),
ErrorCode::kNone);
EXPECT_EQ(ParsePrivateKey(privateKey, ByteArray{kPrivateKey, kPrivateKey + sizeof(kPrivateKey)}), ErrorCode::kNone);

ByteArray signature;
Sign1Message msg;
Expand Down Expand Up @@ -140,10 +152,9 @@ TEST(CoseTest, CoseSignAndVerify_KeyConstruction)
mbedtls_pk_init(&publicKey);
mbedtls_pk_init(&privateKey);

EXPECT_EQ(TokenManager::ParsePublicKey(publicKey, ByteArray{kCertificate, kCertificate + sizeof(kCertificate)}),
ErrorCode::kNone);
EXPECT_EQ(TokenManager::ParsePrivateKey(privateKey, ByteArray{kPrivateKey, kPrivateKey + sizeof(kPrivateKey)}),
EXPECT_EQ(ParsePublicKey(publicKey, ByteArray{kCertificate, kCertificate + sizeof(kCertificate)}),
ErrorCode::kNone);
EXPECT_EQ(ParsePrivateKey(privateKey, ByteArray{kPrivateKey, kPrivateKey + sizeof(kPrivateKey)}), ErrorCode::kNone);

ByteArray keyId = {};
ByteArray encodedCoseKey;
Expand Down
78 changes: 52 additions & 26 deletions src/library/dtls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@
* This file implements wrapper of mbedtls.
*/

// Private fields in mbedtls structs are not stable API after 3.x.
// Enabling this flag to allow us continue accessing those fields.
#ifndef MBEDTLS_ALLOW_PRIVATE_ACCESS
#define MBEDTLS_ALLOW_PRIVATE_ACCESS
#endif

#include "library/dtls.hpp"

#include <mbedtls/debug.h>
Expand All @@ -48,7 +54,7 @@ namespace ot {
namespace commissioner {

static const int kAuthMode = MBEDTLS_SSL_VERIFY_REQUIRED;
static const size_t kMaxContentLength = MBEDTLS_SSL_MAX_CONTENT_LEN;
static const size_t kMaxContentLength = MBEDTLS_SSL_IN_CONTENT_LEN;
static const size_t KMaxFragmentLengthCode = MBEDTLS_SSL_MAX_FRAG_LEN_1024;
static const size_t kMaxTransmissionUnit = 1280;

Expand Down Expand Up @@ -110,8 +116,9 @@ void DtlsSession::InitMbedtls()
{
mbedtls_ssl_config_init(&mConfig);
mbedtls_ssl_cookie_init(&mCookie);
mbedtls_ctr_drbg_init(&mCtrDrbg);
mbedtls_entropy_init(&mEntropy);
mbedtls_ctr_drbg_init(&mCtrDrbg);
mbedtls_ctr_drbg_seed(&mCtrDrbg, mbedtls_entropy_func, &mEntropy, nullptr, 0);
mbedtls_ssl_init(&mSsl);

mbedtls_x509_crt_init(&mCaChain);
Expand Down Expand Up @@ -175,7 +182,8 @@ Error DtlsSession::Init(const DtlsConfig &aConfig)
{
ExitNow(error = ERROR_INVALID_ARGS("bad certificate; {}", ErrorFromMbedtlsError(fail).GetMessage()));
}
if (int fail = mbedtls_pk_parse_key(&mOwnKey, &aConfig.mOwnKey[0], aConfig.mOwnKey.size(), nullptr, 0))
if (int fail = mbedtls_pk_parse_key(&mOwnKey, &aConfig.mOwnKey[0], aConfig.mOwnKey.size(), nullptr, 0,
mbedtls_ctr_drbg_random, &mCtrDrbg))
{
ExitNow(error = ERROR_INVALID_ARGS("bad private key; {}", ErrorFromMbedtlsError(fail).GetMessage()));
}
Expand All @@ -192,13 +200,7 @@ Error DtlsSession::Init(const DtlsConfig &aConfig)
mCipherSuites.push_back(0);
mbedtls_ssl_conf_ciphersuites(&mConfig, &mCipherSuites[0]);

mbedtls_ssl_conf_export_keys_cb(&mConfig, HandleMbedtlsExportKeys, this);

// RNG & Entropy
if (int fail = mbedtls_ctr_drbg_seed(&mCtrDrbg, mbedtls_entropy_func, &mEntropy, nullptr, 0))
{
ExitNow(error = ErrorFromMbedtlsError(fail));
}
mbedtls_ssl_set_export_keys_cb(&mSsl, HandleMbedtlsExportKeys, this);
mbedtls_ssl_conf_rng(&mConfig, mbedtls_ctr_drbg_random, &mCtrDrbg);

// Cookie
Expand Down Expand Up @@ -335,35 +337,59 @@ std::string DtlsSession::GetStateString() const
return stateString;
}

int DtlsSession::HandleMbedtlsExportKeys(void *aDtlsSession,
const unsigned char *aMasterSecret,
const unsigned char *aKeyBlock,
size_t aMacLength,
size_t aKeyLength,
size_t aIvLength)
#if OT_COMM_CONFIG_CCM_ENABLE
const mbedtls_x509_crt *DtlsSession::GetPeerCertificate() const
{
return mSsl.session ? mSsl.session->peer_cert : nullptr;
}
#endif

void DtlsSession::HandleMbedtlsExportKeys(void *aDtlsSession,
mbedtls_ssl_key_export_type aType,
const unsigned char *aMasterSecret,
size_t aMasterSecretLen,
const unsigned char aClientRandom[32],
const unsigned char aServerRandom[32],
mbedtls_tls_prf_types aTlsPrfType)
{
auto dtlsSession = reinterpret_cast<DtlsSession *>(aDtlsSession);
return dtlsSession->HandleMbedtlsExportKeys(aMasterSecret, aKeyBlock, aMacLength, aKeyLength, aIvLength);
dtlsSession->HandleMbedtlsExportKeys(aType, aMasterSecret, aMasterSecretLen, aClientRandom, aServerRandom,
aTlsPrfType);
}

int DtlsSession::HandleMbedtlsExportKeys(const unsigned char *,
const unsigned char *aKeyBlock,
size_t aMacLength,
size_t aKeyLength,
size_t aIvLength)
void DtlsSession::HandleMbedtlsExportKeys(mbedtls_ssl_key_export_type aType,
const unsigned char *aMasterSecret,
size_t aMasterSecretLen,
const unsigned char aClientRandom[32],
const unsigned char aServerRandom[32],
mbedtls_tls_prf_types aTlsPrfType)
{
Sha256 sha256;
Sha256 sha256;
unsigned char keyBlock[kKeyBlockSize];
unsigned char randBytes[2 * kRandomBufferSize];

VerifyOrExit(!mPSK.empty());
VerifyOrExit(aType == MBEDTLS_SSL_KEY_EXPORT_TLS12_MASTER_SECRET);

memcpy(randBytes, aServerRandom, kRandomBufferSize);
memcpy(randBytes + kRandomBufferSize, aClientRandom, kRandomBufferSize);

// Retrieve the Key block from Master secret
mbedtls_ssl_tls_prf(aTlsPrfType, aMasterSecret, aMasterSecretLen, "key expansion", randBytes, sizeof(randBytes),
keyBlock, sizeof(keyBlock));

sha256.Start();
sha256.Update(aKeyBlock, 2 * static_cast<uint16_t>(aMacLength + aKeyLength + aIvLength));
sha256.Update(keyBlock, kKeyBlockSize);

mKek.resize(Sha256::kHashSize);
sha256.Finish(&mKek[0]);
sha256.Finish(mKek.data());

static_assert(Sha256::kHashSize >= kJoinerRouterKekLength, "Sha256::kHashSize >= kJoinerRouterKekLength");
mKek.resize(kJoinerRouterKekLength);

return 0;
LOG_DEBUG(LOG_REGION_DTLS, "Generated KEK");

exit:;
}

void DtlsSession::HandleEvent(short aFlags)
Expand Down
Loading
Loading