Skip to content

Commit

Permalink
tmp commit
Browse files Browse the repository at this point in the history
  • Loading branch information
wgtdkp committed Mar 31, 2024
1 parent bc818e1 commit 2d096f1
Show file tree
Hide file tree
Showing 7 changed files with 239 additions and 29 deletions.
11 changes: 11 additions & 0 deletions include/commissioner/error.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#define OT_COMM_ERROR_HPP_

#include <memory>
#include <ostream>
#include <string>

#include <commissioner/defines.hpp>
Expand Down Expand Up @@ -272,6 +273,16 @@ inline bool operator!=(const ErrorCode &aErrorCode, const Error &aError)
return !(aErrorCode == aError);
}

/**
* Allows pretty-print in unit tests.
*
* See https://google.github.io/googletest/advanced.html#teaching-googletest-how-to-print-your-values
*/
inline void PrintTo(const Error &error, std::ostream *os)
{
*os << error.ToString();
}

} // namespace commissioner

} // namespace ot
Expand Down
72 changes: 66 additions & 6 deletions src/library/coap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,7 @@ Error Message::AppendOption(OptionType aNumber, const OptionValue &aValue)
{
Error error;

VerifyOrExit(IsValidOption(aNumber, aValue),
error = ERROR_INVALID_ARGS("invalid CoAP option (number={})", aNumber));
VerifyOrExit(IsValidOption(aNumber, aValue), error = ERROR_INVALID_ARGS("invalid CoAP option {}", aNumber));

if (aNumber == OptionType::kUriPath)
{
Expand All @@ -219,7 +218,7 @@ Error Message::GetOption(std::string &aValue, OptionType aNumber) const
Error error;
auto option = GetOption(aNumber);

VerifyOrExit(option != nullptr, error = ERROR_NOT_FOUND("CoAP option (number={}) not found", aNumber));
VerifyOrExit(option != nullptr, error = ERROR_NOT_FOUND("CoAP option {} not found", aNumber));

aValue = option->GetStringValue();

Expand All @@ -232,7 +231,7 @@ Error Message::GetOption(uint32_t &aValue, OptionType aNumber) const
Error error;
auto option = GetOption(aNumber);

VerifyOrExit(option != nullptr, error = ERROR_NOT_FOUND("CoAP option (number={}) not found", aNumber));
VerifyOrExit(option != nullptr, error = ERROR_NOT_FOUND("CoAP option {} not found", aNumber));

aValue = option->GetUint32Value();

Expand All @@ -245,7 +244,7 @@ Error Message::GetOption(ByteArray &aValue, OptionType aNumber) const
Error error;
auto option = GetOption(aNumber);

VerifyOrExit(option != nullptr, error = ERROR_NOT_FOUND("CoAP option (number={}) not found", aNumber));
VerifyOrExit(option != nullptr, error = ERROR_NOT_FOUND("CoAP option {} not found", aNumber));

aValue = option->GetOpaqueValue();

Expand Down Expand Up @@ -355,7 +354,7 @@ Error Message::Serialize(OptionType aOptionNumber,
VerifyOrDie(utils::to_underlying(aOptionNumber) >= aLastOptionNumber);

VerifyOrExit(IsValidOption(aOptionNumber, aOptionValue),
error = ERROR_INVALID_ARGS("option (number={}) is not valid", aOptionNumber));
error = ERROR_INVALID_ARGS("option {} is not valid", aOptionNumber));

length = 1;
length += delta < kOption1ByteExtension ? 0 : (delta < kOption2ByteExtension ? 1 : 2);
Expand Down Expand Up @@ -1287,3 +1286,64 @@ std::string Message::GetRequestUri(void) const
} // namespace commissioner

} // namespace ot

auto fmt::formatter<ot::commissioner::coap::OptionType>::format(ot::commissioner::coap::OptionType optionType,
format_context &ctx) -> decltype(ctx.out())
{
using ot::commissioner::coap::OptionType;
string_view name;
switch (optionType)
{
case OptionType::kIfMatch:
name = "kIfMatch";
break;
case OptionType::kUriHost:
name = "kUriHost";
break;
case OptionType::kETag:
name = "kETag";
break;
case OptionType::kIfNonMatch:
name = "kIfNonMatch";
break;
case OptionType::kObserve:
name = "kObserve";
break;
case OptionType::kUriPort:
name = "kUriPort";
break;
case OptionType::kLocationPath:
name = "kLocationPath";
break;
case OptionType::kUriPath:
name = "kUriPath";
break;
case OptionType::kContentFormat:
name = "kContentFormat";
break;
case OptionType::kMaxAge:
name = "kMaxAge";
break;
case OptionType::kUriQuery:
name = "kUriQuery";
break;
case OptionType::kAccept:
name = "kAccept";
break;
case OptionType::kLocationQuery:
name = "kLocationQuery";
break;
case OptionType::kProxyUri:
name = "kProxyUri";
break;
case OptionType::kProxyScheme:
name = "kProxyScheme";
break;
case OptionType::kSize1:
name = "kSize1";
break;
default:
name = "unknown";
}
return formatter<string_view>::format(name, ctx);
}
8 changes: 8 additions & 0 deletions src/library/coap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
#include <queue>
#include <set>

#include <fmt/format.h>

#include <commissioner/defines.hpp>
#include <commissioner/error.hpp>

Expand Down Expand Up @@ -781,4 +783,10 @@ class Coap

} // namespace ot

/** Makes `OptionType` formattable as if the underlying integer. */
template <> struct fmt::formatter<ot::commissioner::coap::OptionType> : formatter<string_view>
{
auto format(ot::commissioner::coap::OptionType optionType, format_context &ctx) -> decltype(ctx.out());
};

#endif // OT_COMM_LIBRARY_COAP_HPP_
85 changes: 69 additions & 16 deletions src/library/dtls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ namespace ot {
namespace commissioner {

static const int kAuthMode = MBEDTLS_SSL_VERIFY_REQUIRED;
static const size_t kMaxContentLength = MBEDTLS_SSL_MAX_CONTENT_LEN;
static const size_t kMaxContentLength = MBEDTLS_SSL_IN_CONTENT_LEN;
static const size_t KMaxFragmentLengthCode = MBEDTLS_SSL_MAX_FRAG_LEN_1024;
static const size_t kMaxTransmissionUnit = 1280;

Expand Down Expand Up @@ -110,8 +110,9 @@ void DtlsSession::InitMbedtls()
{
mbedtls_ssl_config_init(&mConfig);
mbedtls_ssl_cookie_init(&mCookie);
mbedtls_ctr_drbg_init(&mCtrDrbg);
mbedtls_entropy_init(&mEntropy);
mbedtls_ctr_drbg_init(&mCtrDrbg);
mbedtls_ctr_drbg_seed(&mCtrDrbg, mbedtls_entropy_func, &mEntropy, nullptr, 0);
mbedtls_ssl_init(&mSsl);

mbedtls_x509_crt_init(&mCaChain);
Expand Down Expand Up @@ -175,7 +176,12 @@ Error DtlsSession::Init(const DtlsConfig &aConfig)
{
ExitNow(error = ERROR_INVALID_ARGS("bad certificate; {}", ErrorFromMbedtlsError(fail).GetMessage()));
}
#if (MBEDTLS_VERSION_NUMBER >= 0x03000000)
if (int fail = mbedtls_pk_parse_key(&mOwnKey, &aConfig.mOwnKey[0], aConfig.mOwnKey.size(), nullptr, 0,
mbedtls_ctr_drbg_random, &mCtrDrbg))
#else
if (int fail = mbedtls_pk_parse_key(&mOwnKey, &aConfig.mOwnKey[0], aConfig.mOwnKey.size(), nullptr, 0))
#endif
{
ExitNow(error = ERROR_INVALID_ARGS("bad private key; {}", ErrorFromMbedtlsError(fail).GetMessage()));
}
Expand All @@ -191,14 +197,11 @@ Error DtlsSession::Init(const DtlsConfig &aConfig)

mCipherSuites.push_back(0);
mbedtls_ssl_conf_ciphersuites(&mConfig, &mCipherSuites[0]);

#if (MBEDTLS_VERSION_NUMBER >= 0x03000000)
mbedtls_ssl_set_export_keys_cb(&mSsl, HandleMbedtlsExportKeys, this);
#else
mbedtls_ssl_conf_export_keys_cb(&mConfig, HandleMbedtlsExportKeys, this);

// RNG & Entropy
if (int fail = mbedtls_ctr_drbg_seed(&mCtrDrbg, mbedtls_entropy_func, &mEntropy, nullptr, 0))
{
ExitNow(error = ErrorFromMbedtlsError(fail));
}
#endif
mbedtls_ssl_conf_rng(&mConfig, mbedtls_ctr_drbg_random, &mCtrDrbg);

// Cookie
Expand Down Expand Up @@ -335,22 +338,71 @@ std::string DtlsSession::GetStateString() const
return stateString;
}

int DtlsSession::HandleMbedtlsExportKeys(void *aDtlsSession,
#if (MBEDTLS_VERSION_NUMBER >= 0x03000000)
void DtlsSession::HandleMbedtlsExportKeys(void *aDtlsSession,
mbedtls_ssl_key_export_type aType,
const unsigned char *aMasterSecret,
size_t aMasterSecretLen,
const unsigned char aClientRandom[32],
const unsigned char aServerRandom[32],
mbedtls_tls_prf_types aTlsPrfType)
{
auto dtlsSession = reinterpret_cast<DtlsSession *>(aDtlsSession);
dtlsSession->HandleMbedtlsExportKeys(aType, aMasterSecret, aMasterSecretLen, aClientRandom, aServerRandom,
aTlsPrfType);
}

void DtlsSession::HandleMbedtlsExportKeys(mbedtls_ssl_key_export_type aType,
const unsigned char *aMasterSecret,
size_t aMasterSecretLen,
const unsigned char aClientRandom[32],
const unsigned char aServerRandom[32],
mbedtls_tls_prf_types aTlsPrfType)
{
Sha256 sha256;
unsigned char keyBlock[kKeyBlockSize];
unsigned char randBytes[2 * kRandomBufferSize];

VerifyOrExit(!mPSK.empty());
VerifyOrExit(aType == MBEDTLS_SSL_KEY_EXPORT_TLS12_MASTER_SECRET);

memcpy(randBytes, aServerRandom, kRandomBufferSize);
memcpy(randBytes + kRandomBufferSize, aClientRandom, kRandomBufferSize);

// Retrieve the Key block from Master secret
mbedtls_ssl_tls_prf(aTlsPrfType, aMasterSecret, aMasterSecretLen, "key expansion", randBytes, sizeof(randBytes),
keyBlock, sizeof(keyBlock));

sha256.Start();
sha256.Update(keyBlock, kKeyBlockSize);

mKek.resize(Sha256::kHashSize);
sha256.Finish(mKek.data());

static_assert(Sha256::kHashSize >= kJoinerRouterKekLength, "Sha256::kHashSize >= kJoinerRouterKekLength");
mKek.resize(kJoinerRouterKekLength);

LOG_DEBUG(LOG_REGION_DTLS, "Generated KEK");

exit:;
}
#else
int DtlsSession::HandleMbedtlsExportKeys(void *aDtlsSession,
const unsigned char *aMasterSecret,
const unsigned char *aKeyBlock,
size_t aMacLength,
size_t aKeyLength,
size_t aIvLength)
size_t aMacLength,
size_t aKeyLength,
size_t aIvLength)
{
auto dtlsSession = reinterpret_cast<DtlsSession *>(aDtlsSession);
return dtlsSession->HandleMbedtlsExportKeys(aMasterSecret, aKeyBlock, aMacLength, aKeyLength, aIvLength);
}

int DtlsSession::HandleMbedtlsExportKeys(const unsigned char *,
const unsigned char *aKeyBlock,
size_t aMacLength,
size_t aKeyLength,
size_t aIvLength)
size_t aMacLength,
size_t aKeyLength,
size_t aIvLength)
{
Sha256 sha256;

Expand All @@ -365,6 +417,7 @@ int DtlsSession::HandleMbedtlsExportKeys(const unsigned char *,

return 0;
}
#endif // (MBEDTLS_VERSION_NUMBER >= 0x03000000)

void DtlsSession::HandleEvent(short aFlags)
{
Expand Down
34 changes: 30 additions & 4 deletions src/library/dtls.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,23 @@ class DtlsSession : public Endpoint
uint16_t GetPeerPort() const override { return mSocket->GetPeerPort(); }

uint16_t GetLocalPort() const { return mSocket->GetLocalPort(); }

const mbedtls_x509_crt *GetPeerCertificate() const { return mSsl.session ? mSsl.session->peer_cert : nullptr; }

const ByteArray &GetKek() const { return mKek; }
#if OT_COMM_CONFIG_CCM_ENABLE
const mbedtls_x509_crt *GetPeerCertificate() const
{
return mSsl.session ? mSsl.session->peer_cert : nullptr;
}
#endif
const ByteArray &GetKek() const
{
return mKek;
}

void HandleEvent(short aFlags);

private:
static constexpr size_t kKeyBlockSize = 40;
static constexpr size_t kRandomBufferSize = 32;

class DtlsTimer : public Timer
{
public:
Expand Down Expand Up @@ -170,6 +179,22 @@ class DtlsSession : public Endpoint
// Decide if we should stop processing this session by given error.
static bool ShouldStop(Error aError);

#if (MBEDTLS_VERSION_NUMBER >= 0x03000000)
static void HandleMbedtlsExportKeys(void *aDtlsSession,
mbedtls_ssl_key_export_type aType,
const unsigned char *aMasterSecret,
size_t aMasterSecretLen,
const unsigned char aClientRandom[32],
const unsigned char aServerRandom[32],
mbedtls_tls_prf_types aTlsPrfType);

void HandleMbedtlsExportKeys(mbedtls_ssl_key_export_type aType,
const unsigned char *aMasterSecret,
size_t aMasterSecretLen,
const unsigned char aClientRandom[32],
const unsigned char aServerRandom[32],
mbedtls_tls_prf_types aTlsPrfType);
#else
static int HandleMbedtlsExportKeys(void *aDtlsSession,
const unsigned char *aMasterSecret,
const unsigned char *aKeyBlock,
Expand All @@ -182,6 +207,7 @@ class DtlsSession : public Endpoint
size_t aMacLength,
size_t aKeyLength,
size_t aIvLength);
#endif

SocketPtr mSocket;
DtlsTimer mHandshakeTimer;
Expand Down
6 changes: 3 additions & 3 deletions third_party/mbedtls/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,21 +39,21 @@ add_subdirectory(repo)

## Include the user config file by absolute path to avoid exposing current directory.
target_compile_definitions(mbedtls
PUBLIC MBEDTLS_USER_CONFIG_FILE="${CMAKE_CURRENT_SOURCE_DIR}/mbedtls_user_config.h"
PUBLIC MBEDTLS_CONFIG_FILE="${CMAKE_CURRENT_SOURCE_DIR}/mbedtls-config.h"
)
target_include_directories(mbedtls
PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/repo/include
)

target_compile_definitions(mbedx509
PUBLIC MBEDTLS_USER_CONFIG_FILE="${CMAKE_CURRENT_SOURCE_DIR}/mbedtls_user_config.h"
PUBLIC MBEDTLS_CONFIG_FILE="${CMAKE_CURRENT_SOURCE_DIR}/mbedtls-config.h"
)
target_include_directories(mbedx509
PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/repo/include
)

target_compile_definitions(mbedcrypto
PUBLIC MBEDTLS_USER_CONFIG_FILE="${CMAKE_CURRENT_SOURCE_DIR}/mbedtls_user_config.h"
PUBLIC MBEDTLS_CONFIG_FILE="${CMAKE_CURRENT_SOURCE_DIR}/mbedtls-config.h"
)
target_include_directories(mbedcrypto
PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/repo/include
Expand Down
Loading

0 comments on commit 2d096f1

Please sign in to comment.