From 4865e5e689807e500bde212cceaf36e9524b6dde Mon Sep 17 00:00:00 2001 From: Amine Alami <43780877+Alami-Amine@users.noreply.github.com> Date: Mon, 23 Dec 2024 18:42:24 +0100 Subject: [PATCH] [CASESession] refactoring and improving testability of Sigma1 sending and handling, and Sigma2 Sending (#36679) * Removing networking bits from CASESession ParseSigma1 and creating EncodeSigma1 * integrating comments * Refactoring HandleSigma1_and_SendSigma2 and functions called within it * Adding Encrypted2 datalen check * renaming variables * Making Parsing and Encoding methods protected * Converting TLV tags to enum class * Making structs protected and making Unit Test inherit CASESession * integrating comments * Integrating comments in TestCASESession * integrate comments on CASESession * Integrating more comments * Using a Variant for the NextStep logic in HandleSigma1, to either send a Sigma message or a status report * fixing clang-tidy error * more comments * comments integration --- src/lib/support/CodeUtils.h | 27 + src/protocols/secure_channel/CASESession.cpp | 674 ++++++++++-------- src/protocols/secure_channel/CASESession.h | 144 +++- .../secure_channel/tests/TestCASESession.cpp | 381 ++++++++-- 4 files changed, 850 insertions(+), 376 deletions(-) diff --git a/src/lib/support/CodeUtils.h b/src/lib/support/CodeUtils.h index 01cc80bffcbe6c..cd07370fa92ec4 100644 --- a/src/lib/support/CodeUtils.h +++ b/src/lib/support/CodeUtils.h @@ -156,6 +156,33 @@ } \ } while (false) +/** + * @def ReturnErrorVariantOnFailure(expr) + * + * @brief + * This is for use when the calling function returns a Variant type. It returns a CHIP_ERROR variant with the corresponding error + * code if the expression returns an error. + * + * Example usage: + * + * @code + * ReturnErrorVariantOnFailure(NextStep, ParseSigma1(tlvReader, parsedSigma1)); + * @endcode + * + * @param[in] variantType The Variant type that the calling function returns. + * @param[in] expr An expression to be tested. + + */ +#define ReturnErrorVariantOnFailure(variantType, expr) \ + do \ + { \ + auto __err = (expr); \ + if (!::chip::ChipError::IsSuccess(__err)) \ + { \ + return variantType::Create<CHIP_ERROR>(__err); \ + } \ + } while (false) + /** * @def ReturnLogErrorOnFailure(expr) * diff --git a/src/protocols/secure_channel/CASESession.cpp b/src/protocols/secure_channel/CASESession.cpp index a012d12fcc56b4..6242cc2aa6199c 100644 --- a/src/protocols/secure_channel/CASESession.cpp +++ b/src/protocols/secure_channel/CASESession.cpp @@ -45,54 +45,71 @@ #include <protocols/secure_channel/SessionResumptionStorage.h> #include <protocols/secure_channel/StatusReport.h> #include <system/SystemClock.h> -#include <system/TLVPacketBufferBackingStore.h> #include <tracing/macros.h> #include <tracing/metric_event.h> #include <transport/SessionManager.h> namespace { +// TBEDataTags works for both sigma-2-tbedata and sigma-3-tbedata as they have the same tag numbers for the elements common between +// them. +enum class TBEDataTags : uint8_t +{ + kSenderNOC = 1, + kSenderICAC = 2, + kSignature = 3, + kResumptionID = 4, +}; -enum +// TBSDataTags works for both sigma-2-tbsdata and sigma-3-tbsdata as they have the same tag numbers for the elements common between +// them. +enum class TBSDataTags : uint8_t { - kTag_TBEData_SenderNOC = 1, - kTag_TBEData_SenderICAC = 2, - kTag_TBEData_Signature = 3, - kTag_TBEData_ResumptionID = 4, + kSenderNOC = 1, + kSenderICAC = 2, + kSenderPubKey = 3, + kReceiverPubKey = 4, }; -enum +enum class Sigma1Tags : uint8_t { - kTag_TBSData_SenderNOC = 1, - kTag_TBSData_SenderICAC = 2, - kTag_TBSData_SenderPubKey = 3, - kTag_TBSData_ReceiverPubKey = 4, + kInitiatorRandom = 1, + kInitiatorSessionId = 2, + kDestinationId = 3, + kInitiatorEphPubKey = 4, + kInitiatorSessionParams = 5, + kResumptionID = 6, + kResume1MIC = 7, }; -enum +enum class Sigma2Tags : uint8_t { - kTag_Sigma1_InitiatorRandom = 1, - kTag_Sigma1_InitiatorSessionId = 2, - kTag_Sigma1_DestinationId = 3, - kTag_Sigma1_InitiatorEphPubKey = 4, - kTag_Sigma1_InitiatorMRPParams = 5, - kTag_Sigma1_ResumptionID = 6, - kTag_Sigma1_InitiatorResumeMIC = 7, + kResponderRandom = 1, + kResponderSessionId = 2, + kResponderEphPubKey = 3, + kEncrypted2 = 4, + kResponderSessionParams = 5, }; -enum +enum class Sigma2ResumeTags : uint8_t { - kTag_Sigma2_ResponderRandom = 1, - kTag_Sigma2_ResponderSessionId = 2, - kTag_Sigma2_ResponderEphPubKey = 3, - kTag_Sigma2_Encrypted2 = 4, - kTag_Sigma2_ResponderMRPParams = 5, + kResumptionID = 1, + kSigma2ResumeMIC = 2, + kResponderSessionID = 3, + kResponderSessionParams = 4, }; -enum +enum class Sigma3Tags : uint8_t { - kTag_Sigma3_Encrypted3 = 1, + kEncrypted3 = 1, }; +// Utility to extract the underlying value of TLV Tag enum classes, used in TLV encoding and parsing. +template <typename Enum> +constexpr chip::TLV::Tag AsTlvContextTag(Enum e) +{ + return chip::TLV::ContextTag(chip::to_underlying(e)); +} + } // namespace namespace chip { @@ -103,6 +120,7 @@ using namespace Messaging; using namespace Encoding; using namespace Protocols::SecureChannel; using namespace Tracing; +using namespace TLV; constexpr uint8_t kKDFSR2Info[] = { 0x53, 0x69, 0x67, 0x6d, 0x61, 0x32 }; constexpr uint8_t kKDFSR3Info[] = { 0x53, 0x69, 0x67, 0x6d, 0x61, 0x33 }; @@ -778,42 +796,29 @@ void CASESession::HandleConnectionClosed(Transport::ActiveTCPConnectionState * c CHIP_ERROR CASESession::SendSigma1() { MATTER_TRACE_SCOPE("SendSigma1", "CASESession"); - size_t data_len = TLV::EstimateStructOverhead(kSigmaParamRandomNumberSize, // initiatorRandom - sizeof(uint16_t), // initiatorSessionId, - kSHA256_Hash_Length, // destinationId - kP256_PublicKey_Length, // InitiatorEphPubKey, - SessionParameters::kEstimatedTLVSize, // initiatorSessionParams - SessionResumptionStorage::kResumptionIdSize, CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES); - System::PacketBufferTLVWriter tlvWriter; - System::PacketBufferHandle msg_R1; - TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified; uint8_t destinationIdentifier[kSHA256_Hash_Length] = { 0 }; + // Struct that will be used as input to EncodeSigma1() method + EncodeSigma1Inputs encodeSigma1Inputs; + // Lookup fabric info. const auto * fabricInfo = mFabricsTable->FindFabricWithIndex(mFabricIndex); VerifyOrReturnError(fabricInfo != nullptr, CHIP_ERROR_INCORRECT_STATE); // Validate that we have a session ID allocated. VerifyOrReturnError(GetLocalSessionId().HasValue(), CHIP_ERROR_INCORRECT_STATE); + encodeSigma1Inputs.initiatorSessionId = GetLocalSessionId().Value(); // Generate an ephemeral keypair mEphemeralKey = mFabricsTable->AllocateEphemeralKeypairForCASE(); VerifyOrReturnError(mEphemeralKey != nullptr, CHIP_ERROR_NO_MEMORY); ReturnErrorOnFailure(mEphemeralKey->Initialize(ECPKeyTarget::ECDH)); + encodeSigma1Inputs.initiatorEphPubKey = &mEphemeralKey->Pubkey(); // Fill in the random value ReturnErrorOnFailure(DRBG_get_bytes(mInitiatorRandom, sizeof(mInitiatorRandom))); - - // Construct Sigma1 Msg - msg_R1 = System::PacketBufferHandle::New(data_len); - VerifyOrReturnError(!msg_R1.IsNull(), CHIP_ERROR_NO_MEMORY); - - tlvWriter.Init(std::move(msg_R1)); - ReturnErrorOnFailure(tlvWriter.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Structure, outerContainerType)); - ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(1), ByteSpan(mInitiatorRandom))); - // Retrieve Session Identifier - ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(2), GetLocalSessionId().Value())); + encodeSigma1Inputs.initiatorRandom = ByteSpan(mInitiatorRandom); // Generate a Destination Identifier based on the node we are attempting to reach { @@ -827,19 +832,15 @@ CHIP_ERROR CASESession::SendSigma1() Credentials::P256PublicKeySpan rootPubKeySpan{ rootPubKey.ConstBytes() }; MutableByteSpan destinationIdSpan(destinationIdentifier); - ReturnErrorOnFailure(GenerateCaseDestinationId(ByteSpan(mIPK), ByteSpan(mInitiatorRandom), rootPubKeySpan, fabricId, + ReturnErrorOnFailure(GenerateCaseDestinationId(ByteSpan(mIPK), encodeSigma1Inputs.initiatorRandom, rootPubKeySpan, fabricId, mPeerNodeId, destinationIdSpan)); + encodeSigma1Inputs.destinationId = destinationIdSpan; } - ReturnErrorOnFailure(tlvWriter.PutBytes(TLV::ContextTag(3), destinationIdentifier, sizeof(destinationIdentifier))); - - ReturnErrorOnFailure( - tlvWriter.PutBytes(TLV::ContextTag(4), mEphemeralKey->Pubkey(), static_cast<uint32_t>(mEphemeralKey->Pubkey().Length()))); VerifyOrReturnError(mLocalMRPConfig.HasValue(), CHIP_ERROR_INCORRECT_STATE); - ReturnErrorOnFailure(EncodeSessionParameters(TLV::ContextTag(5), mLocalMRPConfig.Value(), tlvWriter)); + encodeSigma1Inputs.initiatorMrpConfig = &mLocalMRPConfig.Value(); // Try to find persistent session, and resume it. - bool resuming = false; if (mSessionResumptionStorage != nullptr) { CHIP_ERROR err = mSessionResumptionStorage->FindByScopedNodeId(fabricInfo->GetScopedNodeIdForNode(mPeerNodeId), @@ -847,28 +848,28 @@ CHIP_ERROR CASESession::SendSigma1() if (err == CHIP_NO_ERROR) { // Found valid resumption state, try to resume the session. - ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(6), mResumeResumptionId)); - - uint8_t initiatorResume1MIC[CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES]; - MutableByteSpan resumeMICSpan(initiatorResume1MIC); - ReturnErrorOnFailure(GenerateSigmaResumeMIC(ByteSpan(mInitiatorRandom), ByteSpan(mResumeResumptionId), + encodeSigma1Inputs.resumptionId = mResumeResumptionId; + MutableByteSpan resumeMICSpan(encodeSigma1Inputs.initiatorResume1MICBuffer); + ReturnErrorOnFailure(GenerateSigmaResumeMIC(encodeSigma1Inputs.initiatorRandom, encodeSigma1Inputs.resumptionId, ByteSpan(kKDFS1RKeyInfo), ByteSpan(kResume1MIC_Nonce), resumeMICSpan)); - ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(7), resumeMICSpan)); - resuming = true; + encodeSigma1Inputs.initiatorResumeMIC = resumeMICSpan; + encodeSigma1Inputs.sessionResumptionRequested = true; } } - ReturnErrorOnFailure(tlvWriter.EndContainer(outerContainerType)); - ReturnErrorOnFailure(tlvWriter.Finalize(&msg_R1)); + System::PacketBufferHandle msgR1; + + // Encode Sigma1 in CHIP TLV Format + ReturnErrorOnFailure(EncodeSigma1(msgR1, encodeSigma1Inputs)); - ReturnErrorOnFailure(mCommissioningHash.AddData(ByteSpan{ msg_R1->Start(), msg_R1->DataLength() })); + ReturnErrorOnFailure(mCommissioningHash.AddData(ByteSpan{ msgR1->Start(), msgR1->DataLength() })); // Call delegate to send the msg to peer - ReturnErrorOnFailure(mExchangeCtxt.Value()->SendMessage(Protocols::SecureChannel::MsgType::CASE_Sigma1, std::move(msg_R1), + ReturnErrorOnFailure(mExchangeCtxt.Value()->SendMessage(Protocols::SecureChannel::MsgType::CASE_Sigma1, std::move(msgR1), SendFlags(SendMessageFlags::kExpectResponse))); - if (resuming) + if (encodeSigma1Inputs.sessionResumptionRequested) { mState = State::kSentSigma1Resume; @@ -892,12 +893,109 @@ CHIP_ERROR CASESession::SendSigma1() return CHIP_NO_ERROR; } +CHIP_ERROR CASESession::EncodeSigma1(System::PacketBufferHandle & msg, EncodeSigma1Inputs & input) +{ + MATTER_TRACE_SCOPE("EncodeSigma1", "CASESession"); + + VerifyOrReturnError(input.initiatorEphPubKey != nullptr, CHIP_ERROR_INVALID_ARGUMENT); + + size_t dataLen = EstimateStructOverhead(kSigmaParamRandomNumberSize, // initiatorRandom + sizeof(uint16_t), // initiatorSessionId, + kSHA256_Hash_Length, // destinationId + kP256_PublicKey_Length, // InitiatorEphPubKey, + SessionParameters::kEstimatedTLVSize, // initiatorSessionParams + SessionResumptionStorage::kResumptionIdSize, // resumptionId + CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES // initiatorResumeMIC + ); + + msg = System::PacketBufferHandle::New(dataLen); + VerifyOrReturnError(!msg.IsNull(), CHIP_ERROR_NO_MEMORY); + + System::PacketBufferTLVWriter tlvWriter; + tlvWriter.Init(std::move(msg)); + + TLVType outerContainerType = kTLVType_NotSpecified; + ReturnErrorOnFailure(tlvWriter.StartContainer(AnonymousTag(), kTLVType_Structure, outerContainerType)); + ReturnErrorOnFailure(tlvWriter.Put(AsTlvContextTag(Sigma1Tags::kInitiatorRandom), input.initiatorRandom)); + ReturnErrorOnFailure(tlvWriter.Put(AsTlvContextTag(Sigma1Tags::kInitiatorSessionId), input.initiatorSessionId)); + ReturnErrorOnFailure(tlvWriter.Put(AsTlvContextTag(Sigma1Tags::kDestinationId), input.destinationId)); + + ReturnErrorOnFailure(tlvWriter.PutBytes(AsTlvContextTag(Sigma1Tags::kInitiatorEphPubKey), *input.initiatorEphPubKey, + static_cast<uint32_t>(input.initiatorEphPubKey->Length()))); + + VerifyOrReturnError(input.initiatorMrpConfig != nullptr, CHIP_ERROR_INVALID_ARGUMENT); + ReturnErrorOnFailure( + EncodeSessionParameters(AsTlvContextTag(Sigma1Tags::kInitiatorSessionParams), *input.initiatorMrpConfig, tlvWriter)); + + if (input.sessionResumptionRequested) + { + ReturnErrorOnFailure(tlvWriter.Put(AsTlvContextTag(Sigma1Tags::kResumptionID), input.resumptionId)); + ReturnErrorOnFailure(tlvWriter.Put(AsTlvContextTag(Sigma1Tags::kResume1MIC), input.initiatorResumeMIC)); + } + + ReturnErrorOnFailure(tlvWriter.EndContainer(outerContainerType)); + ReturnErrorOnFailure(tlvWriter.Finalize(&msg)); + + return CHIP_NO_ERROR; +} + CHIP_ERROR CASESession::HandleSigma1_and_SendSigma2(System::PacketBufferHandle && msg) { MATTER_TRACE_SCOPE("HandleSigma1_and_SendSigma2", "CASESession"); - ReturnErrorOnFailure(HandleSigma1(std::move(msg))); - return CHIP_NO_ERROR; + CHIP_ERROR err = CHIP_NO_ERROR; + + // Parse and Validate Received Sigma1, and decide next step + NextStep nextStep = HandleSigma1(std::move(msg)); + VerifyOrExit(nextStep.Is<Step>(), err = nextStep.Get<CHIP_ERROR>()); + + switch (nextStep.Get<Step>()) + { + case Step::kSendSigma2: { + + System::PacketBufferHandle msgR2; + EncodeSigma2Inputs encodeSigma2; + + SuccessOrExit(err = PrepareSigma2(encodeSigma2)); + SuccessOrExit(err = EncodeSigma2(msgR2, encodeSigma2)); + + MATTER_LOG_METRIC_BEGIN(kMetricDeviceCASESessionSigma2); + SuccessOrExitAction(err = SendSigma2(std::move(msgR2)), MATTER_LOG_METRIC_END(kMetricDeviceCASESessionSigma2, err)); + + mDelegate->OnSessionEstablishmentStarted(); + break; + } + case Step::kSendSigma2Resume: { + + System::PacketBufferHandle msgR2Resume; + EncodeSigma2ResumeInputs encodeSigma2Resume; + + SuccessOrExit(err = PrepareSigma2Resume(encodeSigma2Resume)); + SuccessOrExit(err = EncodeSigma2Resume(msgR2Resume, encodeSigma2Resume)); + + MATTER_LOG_METRIC_BEGIN(kMetricDeviceCASESessionSigma2Resume); + SuccessOrExitAction(err = SendSigma2Resume(std::move(msgR2Resume)), + MATTER_LOG_METRIC_END(kMetricDeviceCASESessionSigma2Resume, err)); + + mDelegate->OnSessionEstablishmentStarted(); + break; + } + default: + break; + } + +exit: + if (err == CHIP_ERROR_KEY_NOT_FOUND) + { + SendStatusReport(mExchangeCtxt, kProtocolCodeNoSharedRoot); + mState = State::kInitialized; + } + else if (err != CHIP_NO_ERROR) + { + SendStatusReport(mExchangeCtxt, kProtocolCodeInvalidParam); + mState = State::kInitialized; + } + return err; } CHIP_ERROR CASESession::FindLocalNodeFromDestinationId(const ByteSpan & destinationId, const ByteSpan & initiatorRandom) @@ -931,7 +1029,7 @@ CHIP_ERROR CASESession::FindLocalNodeFromDestinationId(const ByteSpan & destinat MutableByteSpan candidateDestinationIdSpan(candidateDestinationId); ByteSpan candidateIpkSpan(ipkKeySet.epoch_keys[keyIdx].key); - err = GenerateCaseDestinationId(ByteSpan(candidateIpkSpan), ByteSpan(initiatorRandom), rootPubKeySpan, fabricId, nodeId, + err = GenerateCaseDestinationId(candidateIpkSpan, initiatorRandom, rootPubKeySpan, fabricId, nodeId, candidateDestinationIdSpan); if ((err == CHIP_NO_ERROR) && (candidateDestinationIdSpan.data_equal(destinationId))) { @@ -978,60 +1076,56 @@ CHIP_ERROR CASESession::TryResumeSession(SessionResumptionStorage::ConstResumpti return CHIP_NO_ERROR; } - -CHIP_ERROR CASESession::HandleSigma1(System::PacketBufferHandle && msg) +CASESession::NextStep CASESession::HandleSigma1(System::PacketBufferHandle && msg) { MATTER_TRACE_SCOPE("HandleSigma1", "CASESession"); - CHIP_ERROR err = CHIP_NO_ERROR; - System::PacketBufferTLVReader tlvReader; - - uint16_t initiatorSessionId; - ByteSpan destinationIdentifier; - ByteSpan initiatorRandom; - ChipLogProgress(SecureChannel, "Received Sigma1 msg"); MATTER_TRACE_COUNTER("Sigma1"); - bool sessionResumptionRequested = false; - ByteSpan resumptionId; - ByteSpan resume1MIC; - ByteSpan initiatorPubKey; + VerifyOrReturnError(mFabricsTable != nullptr, NextStep::Create<CHIP_ERROR>(CHIP_ERROR_INCORRECT_STATE)); - SuccessOrExit(err = mCommissioningHash.AddData(ByteSpan{ msg->Start(), msg->DataLength() })); + ReturnErrorVariantOnFailure(NextStep, mCommissioningHash.AddData(ByteSpan{ msg->Start(), msg->DataLength() })); + System::PacketBufferTLVReader tlvReader; tlvReader.Init(std::move(msg)); - SuccessOrExit(err = ParseSigma1(tlvReader, initiatorRandom, initiatorSessionId, destinationIdentifier, initiatorPubKey, - sessionResumptionRequested, resumptionId, resume1MIC)); - ChipLogDetail(SecureChannel, "Peer assigned session key ID %d", initiatorSessionId); - SetPeerSessionId(initiatorSessionId); + // Struct that will serve as output in ParseSigma1 + ParsedSigma1 parsedSigma1; + + ReturnErrorVariantOnFailure(NextStep, ParseSigma1(tlvReader, parsedSigma1)); + + ChipLogDetail(SecureChannel, "Peer assigned session key ID %d", parsedSigma1.initiatorSessionId); + SetPeerSessionId(parsedSigma1.initiatorSessionId); - VerifyOrExit(mFabricsTable != nullptr, err = CHIP_ERROR_INCORRECT_STATE); + // Set the MRP parameters provided in the Sigma1 message + if (parsedSigma1.initiatorMrpParamsPresent) + { + mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters( + GetRemoteSessionParameters()); + } - if (sessionResumptionRequested && resumptionId.size() == SessionResumptionStorage::kResumptionIdSize && + if (parsedSigma1.sessionResumptionRequested && + parsedSigma1.resumptionId.size() == SessionResumptionStorage::kResumptionIdSize && CHIP_NO_ERROR == - TryResumeSession(SessionResumptionStorage::ConstResumptionIdView(resumptionId.data()), resume1MIC, initiatorRandom)) + TryResumeSession(SessionResumptionStorage::ConstResumptionIdView(parsedSigma1.resumptionId.data()), + parsedSigma1.initiatorResumeMIC, parsedSigma1.initiatorRandom)) { - std::copy(initiatorRandom.begin(), initiatorRandom.end(), mInitiatorRandom); - std::copy(resumptionId.begin(), resumptionId.end(), mResumeResumptionId.begin()); + std::copy(parsedSigma1.initiatorRandom.begin(), parsedSigma1.initiatorRandom.end(), mInitiatorRandom); + std::copy(parsedSigma1.resumptionId.begin(), parsedSigma1.resumptionId.end(), mResumeResumptionId.begin()); - // Send Sigma2Resume message to the initiator - MATTER_LOG_METRIC_BEGIN(kMetricDeviceCASESessionSigma2Resume); - err = SendSigma2Resume(); - if (CHIP_NO_ERROR != err) - { - MATTER_LOG_METRIC_END(kMetricDeviceCASESessionSigma2Resume, err); - } - SuccessOrExit(err); + // Early returning here, since the next Step is known to be Sigma2Resume, and no further processing is needed for the + // Sigma1 message + return NextStep::Create<Step>(Step::kSendSigma2Resume); + } - mDelegate->OnSessionEstablishmentStarted(); + // ParseSigma1 ensures that: + // mRemotePubKey.Length() == initiatorPubKey.size() == kP256_PublicKey_Length. + memcpy(mRemotePubKey.Bytes(), parsedSigma1.initiatorEphPubKey.data(), mRemotePubKey.Length()); - // Early returning here, since we have sent Sigma2Resume, and no further processing is needed for the Sigma1 message - return CHIP_NO_ERROR; - } + CHIP_ERROR err = CHIP_NO_ERROR; // Attempt to match the initiator's desired destination based on local fabric table. - err = FindLocalNodeFromDestinationId(destinationIdentifier, initiatorRandom); + err = FindLocalNodeFromDestinationId(parsedSigma1.destinationId, parsedSigma1.initiatorRandom); if (err == CHIP_NO_ERROR) { ChipLogProgress(SecureChannel, "CASE matched destination ID: fabricIndex %u, NodeID 0x" ChipLogFormatX64, @@ -1039,87 +1133,82 @@ CHIP_ERROR CASESession::HandleSigma1(System::PacketBufferHandle && msg) // Side-effect of FindLocalNodeFromDestinationId success was that mFabricIndex/mLocalNodeId are now // set to the local fabric and associated NodeId that was targeted by the initiator. - } - else - { - ChipLogError(SecureChannel, "CASE failed to match destination ID with local fabrics"); - ChipLogByteSpan(SecureChannel, destinationIdentifier); - } - SuccessOrExit(err); - - // ParseSigma1 ensures that: - // mRemotePubKey.Length() == initiatorPubKey.size() == kP256_PublicKey_Length. - memcpy(mRemotePubKey.Bytes(), initiatorPubKey.data(), mRemotePubKey.Length()); - MATTER_LOG_METRIC_BEGIN(kMetricDeviceCASESessionSigma2); - err = SendSigma2(); - if (CHIP_NO_ERROR != err) - { - MATTER_LOG_METRIC_END(kMetricDeviceCASESessionSigma2, err); + return NextStep::Create<Step>(Step::kSendSigma2); } - SuccessOrExit(err); - mDelegate->OnSessionEstablishmentStarted(); + ChipLogError(SecureChannel, "CASE failed to match destination ID with local fabrics"); + ChipLogByteSpan(SecureChannel, parsedSigma1.destinationId); -exit: + // FindLocalNodeFromDestinationId returns CHIP_ERROR_KEY_NOT_FOUND if Sigma1's DestinationId does not match any + // candidateDestinationId, this will trigger a status Report with ProtocolCode = NoSharedTrustRoots. - if (err == CHIP_ERROR_KEY_NOT_FOUND) - { - SendStatusReport(mExchangeCtxt, kProtocolCodeNoSharedRoot); - mState = State::kInitialized; - } - else if (err != CHIP_NO_ERROR) - { - SendStatusReport(mExchangeCtxt, kProtocolCodeInvalidParam); - mState = State::kInitialized; - } - return err; + // Returning a CHIP_ERROR variant that will trigger a corresponding Status Report. + return NextStep::Create<CHIP_ERROR>(err); } -CHIP_ERROR CASESession::SendSigma2Resume() +CHIP_ERROR CASESession::PrepareSigma2Resume(EncodeSigma2ResumeInputs & outSigma2ResData) { - MATTER_TRACE_SCOPE("SendSigma2Resume", "CASESession"); - size_t max_sigma2_resume_data_len = - TLV::EstimateStructOverhead(SessionResumptionStorage::kResumptionIdSize, CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES, - sizeof(uint16_t), SessionParameters::kEstimatedTLVSize); + MATTER_TRACE_SCOPE("PrepareSigma2Resume", "CASESession"); - System::PacketBufferTLVWriter tlvWriter; - System::PacketBufferHandle msg_R2_resume; - TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified; + VerifyOrReturnError(mLocalMRPConfig.HasValue(), CHIP_ERROR_INCORRECT_STATE); // Validate that we have a session ID allocated. VerifyOrReturnError(GetLocalSessionId().HasValue(), CHIP_ERROR_INCORRECT_STATE); - - msg_R2_resume = System::PacketBufferHandle::New(max_sigma2_resume_data_len); - VerifyOrReturnError(!msg_R2_resume.IsNull(), CHIP_ERROR_NO_MEMORY); - - tlvWriter.Init(std::move(msg_R2_resume)); + outSigma2ResData.responderSessionId = GetLocalSessionId().Value(); // Generate a new resumption ID ReturnErrorOnFailure(DRBG_get_bytes(mNewResumptionId.data(), mNewResumptionId.size())); + outSigma2ResData.resumptionId = mNewResumptionId; - ReturnErrorOnFailure(tlvWriter.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Structure, outerContainerType)); - ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(1), mNewResumptionId)); - - uint8_t sigma2ResumeMIC[CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES]; - MutableByteSpan resumeMICSpan(sigma2ResumeMIC); ReturnErrorOnFailure(GenerateSigmaResumeMIC(ByteSpan(mInitiatorRandom), mNewResumptionId, ByteSpan(kKDFS2RKeyInfo), - ByteSpan(kResume2MIC_Nonce), resumeMICSpan)); + ByteSpan(kResume2MIC_Nonce), outSigma2ResData.resumeMIC)); - ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(2), resumeMICSpan)); + outSigma2ResData.responderMrpConfig = &mLocalMRPConfig.Value(); - ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(3), GetLocalSessionId().Value())); + return CHIP_NO_ERROR; +} - VerifyOrReturnError(mLocalMRPConfig.HasValue(), CHIP_ERROR_INCORRECT_STATE); - ReturnErrorOnFailure(EncodeSessionParameters(TLV::ContextTag(4), mLocalMRPConfig.Value(), tlvWriter)); +CHIP_ERROR CASESession::EncodeSigma2Resume(System::PacketBufferHandle & msgR2Resume, EncodeSigma2ResumeInputs & input) +{ + MATTER_TRACE_SCOPE("EncodeSigma2Resume", "CASESession"); + + VerifyOrReturnError(input.responderMrpConfig != nullptr, CHIP_ERROR_INVALID_ARGUMENT); + + size_t maxDatalLen = EstimateStructOverhead(SessionResumptionStorage::kResumptionIdSize, // resumptionID + CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES, // sigma2ResumeMIC + sizeof(uint16_t), // responderSessionID + SessionParameters::kEstimatedTLVSize // responderSessionParams + ); + + msgR2Resume = System::PacketBufferHandle::New(maxDatalLen); + VerifyOrReturnError(!msgR2Resume.IsNull(), CHIP_ERROR_NO_MEMORY); + + System::PacketBufferTLVWriter tlvWriter; + tlvWriter.Init(std::move(msgR2Resume)); + + TLVType outerContainerType = kTLVType_NotSpecified; + + ReturnErrorOnFailure(tlvWriter.StartContainer(AnonymousTag(), kTLVType_Structure, outerContainerType)); + ReturnErrorOnFailure(tlvWriter.Put(AsTlvContextTag(Sigma2ResumeTags::kResumptionID), input.resumptionId)); + ReturnErrorOnFailure(tlvWriter.Put(AsTlvContextTag(Sigma2ResumeTags::kSigma2ResumeMIC), input.resumeMIC)); + ReturnErrorOnFailure(tlvWriter.Put(AsTlvContextTag(Sigma2ResumeTags::kResponderSessionID), input.responderSessionId)); + + ReturnErrorOnFailure( + EncodeSessionParameters(AsTlvContextTag(Sigma2ResumeTags::kResponderSessionParams), *input.responderMrpConfig, tlvWriter)); ReturnErrorOnFailure(tlvWriter.EndContainer(outerContainerType)); - ReturnErrorOnFailure(tlvWriter.Finalize(&msg_R2_resume)); + ReturnErrorOnFailure(tlvWriter.Finalize(&msgR2Resume)); + + return CHIP_NO_ERROR; +} + +CHIP_ERROR CASESession::SendSigma2Resume(System::PacketBufferHandle && msgR2Resume) +{ // Call delegate to send the msg to peer ReturnErrorOnFailure(mExchangeCtxt.Value()->SendMessage(Protocols::SecureChannel::MsgType::CASE_Sigma2Resume, - std::move(msg_R2_resume), - SendFlags(SendMessageFlags::kExpectResponse))); + std::move(msgR2Resume), SendFlags(SendMessageFlags::kExpectResponse))); mState = State::kSentSigma2Resume; @@ -1128,12 +1217,15 @@ CHIP_ERROR CASESession::SendSigma2Resume() return CHIP_NO_ERROR; } -CHIP_ERROR CASESession::SendSigma2() +CHIP_ERROR CASESession::PrepareSigma2(EncodeSigma2Inputs & outSigma2Data) { - MATTER_TRACE_SCOPE("SendSigma2", "CASESession"); - VerifyOrReturnError(GetLocalSessionId().HasValue(), CHIP_ERROR_INCORRECT_STATE); + MATTER_TRACE_SCOPE("PrepareSigma2", "CASESession"); + VerifyOrReturnError(mFabricsTable != nullptr, CHIP_ERROR_INCORRECT_STATE); + VerifyOrReturnError(mLocalMRPConfig.HasValue(), CHIP_ERROR_INCORRECT_STATE); + VerifyOrReturnError(GetLocalSessionId().HasValue(), CHIP_ERROR_INCORRECT_STATE); + outSigma2Data.responderSessionId = GetLocalSessionId().Value(); chip::Platform::ScopedMemoryBuffer<uint8_t> icacBuf; VerifyOrReturnError(icacBuf.Alloc(kMaxCHIPCertLength), CHIP_ERROR_NO_MEMORY); @@ -1148,57 +1240,65 @@ CHIP_ERROR CASESession::SendSigma2() ReturnErrorOnFailure(mFabricsTable->FetchNOCCert(mFabricIndex, nocCert)); // Fill in the random value - uint8_t msg_rand[kSigmaParamRandomNumberSize]; - ReturnErrorOnFailure(DRBG_get_bytes(&msg_rand[0], sizeof(msg_rand))); + ReturnErrorOnFailure(DRBG_get_bytes(&outSigma2Data.responderRandom[0], sizeof(outSigma2Data.responderRandom))); // Generate an ephemeral keypair mEphemeralKey = mFabricsTable->AllocateEphemeralKeypairForCASE(); VerifyOrReturnError(mEphemeralKey != nullptr, CHIP_ERROR_NO_MEMORY); ReturnErrorOnFailure(mEphemeralKey->Initialize(ECPKeyTarget::ECDH)); + outSigma2Data.responderEphPubKey = &mEphemeralKey->Pubkey(); // Generate a Shared Secret ReturnErrorOnFailure(mEphemeralKey->ECDH_derive_secret(mRemotePubKey, mSharedSecret)); - uint8_t msg_salt[kIPKSize + kSigmaParamRandomNumberSize + kP256_PublicKey_Length + kSHA256_Hash_Length]; + uint8_t msgSalt[kIPKSize + kSigmaParamRandomNumberSize + kP256_PublicKey_Length + kSHA256_Hash_Length]; - MutableByteSpan saltSpan(msg_salt); - ReturnErrorOnFailure(ConstructSaltSigma2(ByteSpan(msg_rand), mEphemeralKey->Pubkey(), ByteSpan(mIPK), saltSpan)); + MutableByteSpan saltSpan(msgSalt); + ReturnErrorOnFailure( + ConstructSaltSigma2(ByteSpan(outSigma2Data.responderRandom), mEphemeralKey->Pubkey(), ByteSpan(mIPK), saltSpan)); AutoReleaseSessionKey sr2k(*mSessionManager->GetSessionKeystore()); ReturnErrorOnFailure(DeriveSigmaKey(saltSpan, ByteSpan(kKDFSR2Info), sr2k)); // Construct Sigma2 TBS Data - size_t msg_r2_signed_len = - TLV::EstimateStructOverhead(kMaxCHIPCertLength, kMaxCHIPCertLength, kP256_PublicKey_Length, kP256_PublicKey_Length); + size_t msgR2SignedLen = EstimateStructOverhead(kMaxCHIPCertLength, // responderNoc + kMaxCHIPCertLength, // responderICAC + kP256_PublicKey_Length, // responderEphPubKey + kP256_PublicKey_Length // InitiatorEphPubKey + ); - chip::Platform::ScopedMemoryBuffer<uint8_t> msg_R2_Signed; - VerifyOrReturnError(msg_R2_Signed.Alloc(msg_r2_signed_len), CHIP_ERROR_NO_MEMORY); - - ReturnErrorOnFailure(ConstructTBSData(nocCert, icaCert, ByteSpan(mEphemeralKey->Pubkey(), mEphemeralKey->Pubkey().Length()), - ByteSpan(mRemotePubKey, mRemotePubKey.Length()), msg_R2_Signed.Get(), msg_r2_signed_len)); - - // Generate a Signature P256ECDSASignature tbsData2Signature; - ReturnErrorOnFailure( - mFabricsTable->SignWithOpKeypair(mFabricIndex, ByteSpan{ msg_R2_Signed.Get(), msg_r2_signed_len }, tbsData2Signature)); - msg_R2_Signed.Free(); + { + chip::Platform::ScopedMemoryBuffer<uint8_t> msgR2Signed; + VerifyOrReturnError(msgR2Signed.Alloc(msgR2SignedLen), CHIP_ERROR_NO_MEMORY); + + ReturnErrorOnFailure(ConstructTBSData(nocCert, icaCert, ByteSpan(mEphemeralKey->Pubkey(), mEphemeralKey->Pubkey().Length()), + ByteSpan(mRemotePubKey, mRemotePubKey.Length()), msgR2Signed.Get(), msgR2SignedLen)); + // Generate a Signature + ReturnErrorOnFailure( + mFabricsTable->SignWithOpKeypair(mFabricIndex, ByteSpan{ msgR2Signed.Get(), msgR2SignedLen }, tbsData2Signature)); + } // Construct Sigma2 TBE Data - size_t msg_r2_signed_enc_len = TLV::EstimateStructOverhead(nocCert.size(), icaCert.size(), tbsData2Signature.Length(), - SessionResumptionStorage::kResumptionIdSize); + size_t msgR2SignedEncLen = EstimateStructOverhead(nocCert.size(), // responderNoc + icaCert.size(), // responderICAC + tbsData2Signature.Length(), // signature + SessionResumptionStorage::kResumptionIdSize // resumptionID + ); - chip::Platform::ScopedMemoryBuffer<uint8_t> msg_R2_Encrypted; - VerifyOrReturnError(msg_R2_Encrypted.Alloc(msg_r2_signed_enc_len + CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES), CHIP_ERROR_NO_MEMORY); + VerifyOrReturnError(outSigma2Data.msgR2Encrypted.Alloc(msgR2SignedEncLen + CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES), + CHIP_ERROR_NO_MEMORY); - TLV::TLVWriter tlvWriter; - TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified; + TLVWriter tlvWriter; + tlvWriter.Init(outSigma2Data.msgR2Encrypted.Get(), msgR2SignedEncLen); - tlvWriter.Init(msg_R2_Encrypted.Get(), msg_r2_signed_enc_len); - ReturnErrorOnFailure(tlvWriter.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Structure, outerContainerType)); - ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(kTag_TBEData_SenderNOC), nocCert)); + TLVType outerContainerType = kTLVType_NotSpecified; + + ReturnErrorOnFailure(tlvWriter.StartContainer(AnonymousTag(), kTLVType_Structure, outerContainerType)); + ReturnErrorOnFailure(tlvWriter.Put(AsTlvContextTag(TBEDataTags::kSenderNOC), nocCert)); if (!icaCert.empty()) { - ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(kTag_TBEData_SenderICAC), icaCert)); + ReturnErrorOnFailure(tlvWriter.Put(AsTlvContextTag(TBEDataTags::kSenderICAC), icaCert)); } // We are now done with ICAC and NOC certs so we can release the memory. @@ -1210,53 +1310,81 @@ CHIP_ERROR CASESession::SendSigma2() nocCert = MutableByteSpan{}; } - ReturnErrorOnFailure(tlvWriter.PutBytes(TLV::ContextTag(kTag_TBEData_Signature), tbsData2Signature.ConstBytes(), + ReturnErrorOnFailure(tlvWriter.PutBytes(AsTlvContextTag(TBEDataTags::kSignature), tbsData2Signature.ConstBytes(), static_cast<uint32_t>(tbsData2Signature.Length()))); // Generate a new resumption ID ReturnErrorOnFailure(DRBG_get_bytes(mNewResumptionId.data(), mNewResumptionId.size())); - ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(kTag_TBEData_ResumptionID), mNewResumptionId)); + ReturnErrorOnFailure(tlvWriter.Put(AsTlvContextTag(TBEDataTags::kResumptionID), mNewResumptionId)); ReturnErrorOnFailure(tlvWriter.EndContainer(outerContainerType)); ReturnErrorOnFailure(tlvWriter.Finalize()); - msg_r2_signed_enc_len = static_cast<size_t>(tlvWriter.GetLengthWritten()); - + msgR2SignedEncLen = static_cast<size_t>(tlvWriter.GetLengthWritten()); + outSigma2Data.encrypted2Length = msgR2SignedEncLen + CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES; // Generate the encrypted data blob - ReturnErrorOnFailure(AES_CCM_encrypt(msg_R2_Encrypted.Get(), msg_r2_signed_enc_len, nullptr, 0, sr2k.KeyHandle(), - kTBEData2_Nonce, kTBEDataNonceLength, msg_R2_Encrypted.Get(), - msg_R2_Encrypted.Get() + msg_r2_signed_enc_len, CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES)); + ReturnErrorOnFailure(AES_CCM_encrypt(outSigma2Data.msgR2Encrypted.Get(), msgR2SignedEncLen, nullptr, 0, sr2k.KeyHandle(), + kTBEData2_Nonce, kTBEDataNonceLength, outSigma2Data.msgR2Encrypted.Get(), + outSigma2Data.msgR2Encrypted.Get() + msgR2SignedEncLen, + CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES)); + + outSigma2Data.responderMrpConfig = &mLocalMRPConfig.Value(); - // Construct Sigma2 Msg - size_t size_of_local_session_id = sizeof(uint16_t); - size_t data_len = - TLV::EstimateStructOverhead(kSigmaParamRandomNumberSize, size_of_local_session_id, kP256_PublicKey_Length, - msg_r2_signed_enc_len, CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES, SessionParameters::kEstimatedTLVSize); + return CHIP_NO_ERROR; +} - System::PacketBufferHandle msg_R2 = System::PacketBufferHandle::New(data_len); - VerifyOrReturnError(!msg_R2.IsNull(), CHIP_ERROR_NO_MEMORY); +CHIP_ERROR CASESession::EncodeSigma2(System::PacketBufferHandle & msgR2, EncodeSigma2Inputs & input) +{ + VerifyOrReturnError(input.responderEphPubKey != nullptr, CHIP_ERROR_INVALID_ARGUMENT); + VerifyOrReturnError(input.msgR2Encrypted, CHIP_ERROR_INCORRECT_STATE); + // Check if length of msgR2Encrypted is set and is at least larger than the MIC length + VerifyOrReturnError(input.encrypted2Length > CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES, CHIP_ERROR_INCORRECT_STATE); + VerifyOrReturnError(input.responderMrpConfig != nullptr, CHIP_ERROR_INVALID_ARGUMENT); + + size_t dataLen = EstimateStructOverhead(kSigmaParamRandomNumberSize, // responderRandom + sizeof(uint16_t), // responderSessionId + kP256_PublicKey_Length, // responderEphPubKey + input.encrypted2Length, // encrypted2 + SessionParameters::kEstimatedTLVSize // responderSessionParams + ); + + msgR2 = System::PacketBufferHandle::New(dataLen); + VerifyOrReturnError(!msgR2.IsNull(), CHIP_ERROR_NO_MEMORY); System::PacketBufferTLVWriter tlvWriterMsg2; - outerContainerType = TLV::kTLVType_NotSpecified; + tlvWriterMsg2.Init(std::move(msgR2)); - tlvWriterMsg2.Init(std::move(msg_R2)); - ReturnErrorOnFailure(tlvWriterMsg2.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Structure, outerContainerType)); - ReturnErrorOnFailure(tlvWriterMsg2.PutBytes(TLV::ContextTag(1), &msg_rand[0], sizeof(msg_rand))); - ReturnErrorOnFailure(tlvWriterMsg2.Put(TLV::ContextTag(2), GetLocalSessionId().Value())); - ReturnErrorOnFailure(tlvWriterMsg2.PutBytes(TLV::ContextTag(3), mEphemeralKey->Pubkey(), - static_cast<uint32_t>(mEphemeralKey->Pubkey().Length()))); - ReturnErrorOnFailure(tlvWriterMsg2.PutBytes(TLV::ContextTag(4), msg_R2_Encrypted.Get(), - static_cast<uint32_t>(msg_r2_signed_enc_len + CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES))); + TLVType outerContainerType = kTLVType_NotSpecified; - VerifyOrReturnError(mLocalMRPConfig.HasValue(), CHIP_ERROR_INCORRECT_STATE); - ReturnErrorOnFailure(EncodeSessionParameters(TLV::ContextTag(5), mLocalMRPConfig.Value(), tlvWriterMsg2)); + ReturnErrorOnFailure(tlvWriterMsg2.StartContainer(AnonymousTag(), kTLVType_Structure, outerContainerType)); + + ReturnErrorOnFailure(tlvWriterMsg2.PutBytes(AsTlvContextTag(Sigma2Tags::kResponderRandom), &input.responderRandom[0], + sizeof(input.responderRandom))); + ReturnErrorOnFailure(tlvWriterMsg2.Put(AsTlvContextTag(Sigma2Tags::kResponderSessionId), input.responderSessionId)); + + ReturnErrorOnFailure(tlvWriterMsg2.PutBytes(AsTlvContextTag(Sigma2Tags::kResponderEphPubKey), *input.responderEphPubKey, + static_cast<uint32_t>(input.responderEphPubKey->Length()))); + + ReturnErrorOnFailure(tlvWriterMsg2.PutBytes(AsTlvContextTag(Sigma2Tags::kEncrypted2), input.msgR2Encrypted.Get(), + static_cast<uint32_t>(input.encrypted2Length))); + input.msgR2Encrypted.Free(); + + ReturnErrorOnFailure( + EncodeSessionParameters(AsTlvContextTag(Sigma2Tags::kResponderSessionParams), *input.responderMrpConfig, tlvWriterMsg2)); ReturnErrorOnFailure(tlvWriterMsg2.EndContainer(outerContainerType)); - ReturnErrorOnFailure(tlvWriterMsg2.Finalize(&msg_R2)); + ReturnErrorOnFailure(tlvWriterMsg2.Finalize(&msgR2)); + + return CHIP_NO_ERROR; +} - ReturnErrorOnFailure(mCommissioningHash.AddData(ByteSpan{ msg_R2->Start(), msg_R2->DataLength() })); +CHIP_ERROR CASESession::SendSigma2(System::PacketBufferHandle && msgR2) +{ + MATTER_TRACE_SCOPE("SendSigma2", "CASESession"); + + ReturnErrorOnFailure(mCommissioningHash.AddData(ByteSpan{ msgR2->Start(), msgR2->DataLength() })); // Call delegate to send the msg to peer - ReturnErrorOnFailure(mExchangeCtxt.Value()->SendMessage(Protocols::SecureChannel::MsgType::CASE_Sigma2, std::move(msg_R2), + ReturnErrorOnFailure(mExchangeCtxt.Value()->SendMessage(Protocols::SecureChannel::MsgType::CASE_Sigma2, std::move(msgR2), SendFlags(SendMessageFlags::kExpectResponse))); mState = State::kSentSigma2; @@ -1406,18 +1534,18 @@ CHIP_ERROR CASESession::HandleSigma2(System::PacketBufferHandle && msg) SuccessOrExit(err = tlvReader.EnterContainer(containerType)); // Retrieve Responder's Random value - SuccessOrExit(err = tlvReader.Next(TLV::kTLVType_ByteString, TLV::ContextTag(kTag_Sigma2_ResponderRandom))); + SuccessOrExit(err = tlvReader.Next(TLV::kTLVType_ByteString, AsTlvContextTag(Sigma2Tags::kResponderRandom))); SuccessOrExit(err = tlvReader.GetBytes(responderRandom, sizeof(responderRandom))); // Assign Session ID - SuccessOrExit(err = tlvReader.Next(TLV::kTLVType_UnsignedInteger, TLV::ContextTag(kTag_Sigma2_ResponderSessionId))); + SuccessOrExit(err = tlvReader.Next(TLV::kTLVType_UnsignedInteger, AsTlvContextTag(Sigma2Tags::kResponderSessionId))); SuccessOrExit(err = tlvReader.Get(responderSessionId)); ChipLogDetail(SecureChannel, "Peer assigned session session ID %d", responderSessionId); SetPeerSessionId(responderSessionId); // Retrieve Responder's Ephemeral Pubkey - SuccessOrExit(err = tlvReader.Next(TLV::kTLVType_ByteString, TLV::ContextTag(kTag_Sigma2_ResponderEphPubKey))); + SuccessOrExit(err = tlvReader.Next(TLV::kTLVType_ByteString, AsTlvContextTag(Sigma2Tags::kResponderEphPubKey))); SuccessOrExit(err = tlvReader.GetBytes(mRemotePubKey, static_cast<uint32_t>(mRemotePubKey.Length()))); // Generate a Shared Secret @@ -1433,7 +1561,7 @@ CHIP_ERROR CASESession::HandleSigma2(System::PacketBufferHandle && msg) SuccessOrExit(err = mCommissioningHash.AddData(ByteSpan{ buf, buflen })); // Generate decrypted data - SuccessOrExit(err = tlvReader.Next(TLV::kTLVType_ByteString, TLV::ContextTag(kTag_Sigma2_Encrypted2))); + SuccessOrExit(err = tlvReader.Next(TLV::kTLVType_ByteString, AsTlvContextTag(Sigma2Tags::kEncrypted2))); max_msg_r2_signed_enc_len = TLV::EstimateStructOverhead(Credentials::kMaxCHIPCertLength, Credentials::kMaxCHIPCertLength, tbsData2Signature.Length(), @@ -1457,15 +1585,15 @@ CHIP_ERROR CASESession::HandleSigma2(System::PacketBufferHandle && msg) SuccessOrExit(err = decryptedDataTlvReader.Next(containerType, TLV::AnonymousTag())); SuccessOrExit(err = decryptedDataTlvReader.EnterContainer(containerType)); - SuccessOrExit(err = decryptedDataTlvReader.Next(TLV::kTLVType_ByteString, TLV::ContextTag(kTag_TBEData_SenderNOC))); + SuccessOrExit(err = decryptedDataTlvReader.Next(TLV::kTLVType_ByteString, AsTlvContextTag(TBEDataTags::kSenderNOC))); SuccessOrExit(err = decryptedDataTlvReader.Get(responderNOC)); SuccessOrExit(err = decryptedDataTlvReader.Next()); - if (TLV::TagNumFromTag(decryptedDataTlvReader.GetTag()) == kTag_TBEData_SenderICAC) + if (decryptedDataTlvReader.GetTag() == AsTlvContextTag(TBEDataTags::kSenderICAC)) { VerifyOrExit(decryptedDataTlvReader.GetType() == TLV::kTLVType_ByteString, err = CHIP_ERROR_WRONG_TLV_TYPE); SuccessOrExit(err = decryptedDataTlvReader.Get(responderICAC)); - SuccessOrExit(err = decryptedDataTlvReader.Next(TLV::kTLVType_ByteString, TLV::ContextTag(kTag_TBEData_Signature))); + SuccessOrExit(err = decryptedDataTlvReader.Next(TLV::kTLVType_ByteString, AsTlvContextTag(TBEDataTags::kSignature))); } // Validate responder identity located in msg_r2_encrypted @@ -1492,7 +1620,7 @@ CHIP_ERROR CASESession::HandleSigma2(System::PacketBufferHandle && msg) ByteSpan(mEphemeralKey->Pubkey(), mEphemeralKey->Pubkey().Length()), msg_R2_Signed.Get(), msg_r2_signed_len)); - VerifyOrExit(TLV::TagNumFromTag(decryptedDataTlvReader.GetTag()) == kTag_TBEData_Signature, err = CHIP_ERROR_INVALID_TLV_TAG); + VerifyOrExit(decryptedDataTlvReader.GetTag() == AsTlvContextTag(TBEDataTags::kSignature), err = CHIP_ERROR_INVALID_TLV_TAG); VerifyOrExit(tbsData2Signature.Capacity() >= decryptedDataTlvReader.GetLength(), err = CHIP_ERROR_INVALID_TLV_ELEMENT); tbsData2Signature.SetLength(decryptedDataTlvReader.GetLength()); SuccessOrExit(err = decryptedDataTlvReader.GetBytes(tbsData2Signature.Bytes(), tbsData2Signature.Length())); @@ -1501,7 +1629,7 @@ CHIP_ERROR CASESession::HandleSigma2(System::PacketBufferHandle && msg) SuccessOrExit(err = responderPublicKey.ECDSA_validate_msg_signature(msg_R2_Signed.Get(), msg_r2_signed_len, tbsData2Signature)); // Retrieve session resumption ID - SuccessOrExit(err = decryptedDataTlvReader.Next(TLV::kTLVType_ByteString, TLV::ContextTag(kTag_TBEData_ResumptionID))); + SuccessOrExit(err = decryptedDataTlvReader.Next(TLV::kTLVType_ByteString, AsTlvContextTag(TBEDataTags::kResumptionID))); SuccessOrExit(err = decryptedDataTlvReader.GetBytes(mNewResumptionId.data(), mNewResumptionId.size())); // Retrieve peer CASE Authenticated Tags (CATs) from peer's NOC. @@ -1510,7 +1638,7 @@ CHIP_ERROR CASESession::HandleSigma2(System::PacketBufferHandle && msg) // Retrieve responderMRPParams if present if (tlvReader.Next() != CHIP_END_OF_TLV) { - SuccessOrExit(err = DecodeMRPParametersIfPresent(TLV::ContextTag(kTag_Sigma2_ResponderMRPParams), tlvReader)); + SuccessOrExit(err = DecodeMRPParametersIfPresent(AsTlvContextTag(Sigma2Tags::kResponderSessionParams), tlvReader)); mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters( GetRemoteSessionParameters()); } @@ -1629,10 +1757,10 @@ CHIP_ERROR CASESession::SendSigma3b(SendSigma3Data & data, bool & cancel) tlvWriter.Init(data.msg_R3_Encrypted.Get(), data.msg_r3_encrypted_len); ReturnErrorOnFailure(tlvWriter.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Structure, outerContainerType)); - ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(kTag_TBEData_SenderNOC), data.nocCert)); + ReturnErrorOnFailure(tlvWriter.Put(AsTlvContextTag(TBEDataTags::kSenderNOC), data.nocCert)); if (!data.icaCert.empty()) { - ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(kTag_TBEData_SenderICAC), data.icaCert)); + ReturnErrorOnFailure(tlvWriter.Put(AsTlvContextTag(TBEDataTags::kSenderICAC), data.icaCert)); } // We are now done with ICAC and NOC certs so we can release the memory. @@ -1644,7 +1772,7 @@ CHIP_ERROR CASESession::SendSigma3b(SendSigma3Data & data, bool & cancel) data.nocCert = MutableByteSpan{}; } - ReturnErrorOnFailure(tlvWriter.PutBytes(TLV::ContextTag(kTag_TBEData_Signature), data.tbsData3Signature.ConstBytes(), + ReturnErrorOnFailure(tlvWriter.PutBytes(AsTlvContextTag(TBEDataTags::kSignature), data.tbsData3Signature.ConstBytes(), static_cast<uint32_t>(data.tbsData3Signature.Length()))); ReturnErrorOnFailure(tlvWriter.EndContainer(outerContainerType)); ReturnErrorOnFailure(tlvWriter.Finalize()); @@ -1785,7 +1913,7 @@ CHIP_ERROR CASESession::HandleSigma3a(System::PacketBufferHandle && msg) max_msg_r3_signed_enc_len = TLV::EstimateStructOverhead(Credentials::kMaxCHIPCertLength, Credentials::kMaxCHIPCertLength, data.tbsData3Signature.Length(), kCaseOverheadForFutureTbeData); - SuccessOrExit(err = tlvReader.Next(TLV::kTLVType_ByteString, TLV::ContextTag(kTag_Sigma3_Encrypted3))); + SuccessOrExit(err = tlvReader.Next(TLV::kTLVType_ByteString, AsTlvContextTag(Sigma3Tags::kEncrypted3))); msg_r3_encrypted_len_with_tag = tlvReader.GetLength(); @@ -1816,15 +1944,15 @@ CHIP_ERROR CASESession::HandleSigma3a(System::PacketBufferHandle && msg) SuccessOrExit(err = decryptedDataTlvReader.Next(containerType, TLV::AnonymousTag())); SuccessOrExit(err = decryptedDataTlvReader.EnterContainer(containerType)); - SuccessOrExit(err = decryptedDataTlvReader.Next(TLV::kTLVType_ByteString, TLV::ContextTag(kTag_TBEData_SenderNOC))); + SuccessOrExit(err = decryptedDataTlvReader.Next(TLV::kTLVType_ByteString, AsTlvContextTag(TBEDataTags::kSenderNOC))); SuccessOrExit(err = decryptedDataTlvReader.Get(data.initiatorNOC)); SuccessOrExit(err = decryptedDataTlvReader.Next()); - if (TLV::TagNumFromTag(decryptedDataTlvReader.GetTag()) == kTag_TBEData_SenderICAC) + if (decryptedDataTlvReader.GetTag() == AsTlvContextTag(TBEDataTags::kSenderICAC)) { VerifyOrExit(decryptedDataTlvReader.GetType() == TLV::kTLVType_ByteString, err = CHIP_ERROR_WRONG_TLV_TYPE); SuccessOrExit(err = decryptedDataTlvReader.Get(data.initiatorICAC)); - SuccessOrExit(err = decryptedDataTlvReader.Next(TLV::kTLVType_ByteString, TLV::ContextTag(kTag_TBEData_Signature))); + SuccessOrExit(err = decryptedDataTlvReader.Next(TLV::kTLVType_ByteString, AsTlvContextTag(TBEDataTags::kSignature))); } // Step 4 - Construct Sigma3 TBS Data @@ -1837,8 +1965,7 @@ CHIP_ERROR CASESession::HandleSigma3a(System::PacketBufferHandle && msg) ByteSpan(mEphemeralKey->Pubkey(), mEphemeralKey->Pubkey().Length()), data.msg_R3_Signed.Get(), data.msg_r3_signed_len)); - VerifyOrExit(TLV::TagNumFromTag(decryptedDataTlvReader.GetTag()) == kTag_TBEData_Signature, - err = CHIP_ERROR_INVALID_TLV_TAG); + VerifyOrExit(decryptedDataTlvReader.GetTag() == AsTlvContextTag(TBEDataTags::kSignature), err = CHIP_ERROR_INVALID_TLV_TAG); VerifyOrExit(data.tbsData3Signature.Capacity() >= decryptedDataTlvReader.GetLength(), err = CHIP_ERROR_INVALID_TLV_ELEMENT); data.tbsData3Signature.SetLength(decryptedDataTlvReader.GetLength()); SuccessOrExit(err = decryptedDataTlvReader.GetBytes(data.tbsData3Signature.Bytes(), data.tbsData3Signature.Length())); @@ -1864,12 +1991,12 @@ CHIP_ERROR CASESession::HandleSigma3a(System::PacketBufferHandle && msg) SuccessOrExit(err = signedDataTlvReader.Next(TLV::kTLVType_Structure, TLV::AnonymousTag())); SuccessOrExit(err = signedDataTlvReader.EnterContainer(containerType)); - SuccessOrExit(err = signedDataTlvReader.Next(TLV::kTLVType_ByteString, TLV::ContextTag(kTag_TBSData_SenderNOC))); + SuccessOrExit(err = signedDataTlvReader.Next(TLV::kTLVType_ByteString, AsTlvContextTag(TBSDataTags::kSenderNOC))); SuccessOrExit(err = signedDataTlvReader.Get(data.initiatorNOC)); if (!data.initiatorICAC.empty()) { - SuccessOrExit(err = signedDataTlvReader.Next(TLV::kTLVType_ByteString, TLV::ContextTag(kTag_TBSData_SenderICAC))); + SuccessOrExit(err = signedDataTlvReader.Next(TLV::kTLVType_ByteString, AsTlvContextTag(TBSDataTags::kSenderICAC))); SuccessOrExit(err = signedDataTlvReader.Get(data.initiatorICAC)); } } @@ -2060,13 +2187,13 @@ CHIP_ERROR CASESession::ConstructTBSData(const ByteSpan & senderNOC, const ByteS tlvWriter.Init(tbsData, tbsDataLen); ReturnErrorOnFailure(tlvWriter.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Structure, outerContainerType)); - ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(kTag_TBSData_SenderNOC), senderNOC)); + ReturnErrorOnFailure(tlvWriter.Put(AsTlvContextTag(TBSDataTags::kSenderNOC), senderNOC)); if (!senderICAC.empty()) { - ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(kTag_TBSData_SenderICAC), senderICAC)); + ReturnErrorOnFailure(tlvWriter.Put(AsTlvContextTag(TBSDataTags::kSenderICAC), senderICAC)); } - ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(kTag_TBSData_SenderPubKey), senderPubKey)); - ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(kTag_TBSData_ReceiverPubKey), receiverPubKey)); + ReturnErrorOnFailure(tlvWriter.Put(AsTlvContextTag(TBSDataTags::kSenderPubKey), senderPubKey)); + ReturnErrorOnFailure(tlvWriter.Put(AsTlvContextTag(TBSDataTags::kReceiverPubKey), receiverPubKey)); ReturnErrorOnFailure(tlvWriter.EndContainer(outerContainerType)); ReturnErrorOnFailure(tlvWriter.Finalize()); tbsDataLen = static_cast<size_t>(tlvWriter.GetLengthWritten()); @@ -2171,65 +2298,56 @@ CHIP_ERROR CASESession::OnFailureStatusReport(Protocols::SecureChannel::GeneralS return err; } -CHIP_ERROR CASESession::ParseSigma1(TLV::ContiguousBufferTLVReader & tlvReader, ByteSpan & initiatorRandom, - uint16_t & initiatorSessionId, ByteSpan & destinationId, ByteSpan & initiatorEphPubKey, - bool & resumptionRequested, ByteSpan & resumptionId, ByteSpan & initiatorResumeMIC) +CHIP_ERROR CASESession::ParseSigma1(TLV::ContiguousBufferTLVReader & tlvReader, ParsedSigma1 & outParsedSigma1) { - using namespace TLV; - - constexpr uint8_t kInitiatorRandomTag = 1; - constexpr uint8_t kInitiatorSessionIdTag = 2; - constexpr uint8_t kDestinationIdTag = 3; - constexpr uint8_t kInitiatorPubKeyTag = 4; - constexpr uint8_t kInitiatorMRPParamsTag = 5; - constexpr uint8_t kResumptionIDTag = 6; - constexpr uint8_t kResume1MICTag = 7; TLVType containerType = kTLVType_Structure; ReturnErrorOnFailure(tlvReader.Next(containerType, AnonymousTag())); ReturnErrorOnFailure(tlvReader.EnterContainer(containerType)); - ReturnErrorOnFailure(tlvReader.Next(ContextTag(kInitiatorRandomTag))); - ReturnErrorOnFailure(tlvReader.GetByteView(initiatorRandom)); - VerifyOrReturnError(initiatorRandom.size() == kSigmaParamRandomNumberSize, CHIP_ERROR_INVALID_CASE_PARAMETER); + ReturnErrorOnFailure(tlvReader.Next(AsTlvContextTag(Sigma1Tags::kInitiatorRandom))); + ReturnErrorOnFailure(tlvReader.GetByteView(outParsedSigma1.initiatorRandom)); + VerifyOrReturnError(outParsedSigma1.initiatorRandom.size() == kSigmaParamRandomNumberSize, CHIP_ERROR_INVALID_CASE_PARAMETER); - ReturnErrorOnFailure(tlvReader.Next(ContextTag(kInitiatorSessionIdTag))); - ReturnErrorOnFailure(tlvReader.Get(initiatorSessionId)); + ReturnErrorOnFailure(tlvReader.Next(AsTlvContextTag(Sigma1Tags::kInitiatorSessionId))); + ReturnErrorOnFailure(tlvReader.Get(outParsedSigma1.initiatorSessionId)); - ReturnErrorOnFailure(tlvReader.Next(ContextTag(kDestinationIdTag))); - ReturnErrorOnFailure(tlvReader.GetByteView(destinationId)); - VerifyOrReturnError(destinationId.size() == kSHA256_Hash_Length, CHIP_ERROR_INVALID_CASE_PARAMETER); + ReturnErrorOnFailure(tlvReader.Next(AsTlvContextTag(Sigma1Tags::kDestinationId))); + ReturnErrorOnFailure(tlvReader.GetByteView(outParsedSigma1.destinationId)); + VerifyOrReturnError(outParsedSigma1.destinationId.size() == kSHA256_Hash_Length, CHIP_ERROR_INVALID_CASE_PARAMETER); - ReturnErrorOnFailure(tlvReader.Next(ContextTag(kInitiatorPubKeyTag))); - ReturnErrorOnFailure(tlvReader.GetByteView(initiatorEphPubKey)); - VerifyOrReturnError(initiatorEphPubKey.size() == kP256_PublicKey_Length, CHIP_ERROR_INVALID_CASE_PARAMETER); + ReturnErrorOnFailure(tlvReader.Next(AsTlvContextTag(Sigma1Tags::kInitiatorEphPubKey))); + ReturnErrorOnFailure(tlvReader.GetByteView(outParsedSigma1.initiatorEphPubKey)); + VerifyOrReturnError(outParsedSigma1.initiatorEphPubKey.size() == kP256_PublicKey_Length, CHIP_ERROR_INVALID_CASE_PARAMETER); // Optional members start here. CHIP_ERROR err = tlvReader.Next(); - if (err == CHIP_NO_ERROR && tlvReader.GetTag() == ContextTag(kInitiatorMRPParamsTag)) + if (err == CHIP_NO_ERROR && tlvReader.GetTag() == AsTlvContextTag(Sigma1Tags::kInitiatorSessionParams)) { - ReturnErrorOnFailure(DecodeMRPParametersIfPresent(TLV::ContextTag(kInitiatorMRPParamsTag), tlvReader)); - mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters( - GetRemoteSessionParameters()); + ReturnErrorOnFailure(DecodeMRPParametersIfPresent(AsTlvContextTag(Sigma1Tags::kInitiatorSessionParams), tlvReader)); + outParsedSigma1.initiatorMrpParamsPresent = true; + err = tlvReader.Next(); } bool resumptionIDTagFound = false; bool resume1MICTagFound = false; - if (err == CHIP_NO_ERROR && tlvReader.GetTag() == ContextTag(kResumptionIDTag)) + if (err == CHIP_NO_ERROR && tlvReader.GetTag() == AsTlvContextTag(Sigma1Tags::kResumptionID)) { resumptionIDTagFound = true; - ReturnErrorOnFailure(tlvReader.GetByteView(resumptionId)); - VerifyOrReturnError(resumptionId.size() == SessionResumptionStorage::kResumptionIdSize, CHIP_ERROR_INVALID_CASE_PARAMETER); + ReturnErrorOnFailure(tlvReader.GetByteView(outParsedSigma1.resumptionId)); + VerifyOrReturnError(outParsedSigma1.resumptionId.size() == SessionResumptionStorage::kResumptionIdSize, + CHIP_ERROR_INVALID_CASE_PARAMETER); err = tlvReader.Next(); } - if (err == CHIP_NO_ERROR && tlvReader.GetTag() == ContextTag(kResume1MICTag)) + if (err == CHIP_NO_ERROR && tlvReader.GetTag() == AsTlvContextTag(Sigma1Tags::kResume1MIC)) { resume1MICTagFound = true; - ReturnErrorOnFailure(tlvReader.GetByteView(initiatorResumeMIC)); - VerifyOrReturnError(initiatorResumeMIC.size() == CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES, CHIP_ERROR_INVALID_CASE_PARAMETER); + ReturnErrorOnFailure(tlvReader.GetByteView(outParsedSigma1.initiatorResumeMIC)); + VerifyOrReturnError(outParsedSigma1.initiatorResumeMIC.size() == CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES, + CHIP_ERROR_INVALID_CASE_PARAMETER); err = tlvReader.Next(); } @@ -2244,11 +2362,11 @@ CHIP_ERROR CASESession::ParseSigma1(TLV::ContiguousBufferTLVReader & tlvReader, if (resumptionIDTagFound && resume1MICTagFound) { - resumptionRequested = true; + outParsedSigma1.sessionResumptionRequested = true; } else if (!resumptionIDTagFound && !resume1MICTagFound) { - resumptionRequested = false; + outParsedSigma1.sessionResumptionRequested = false; } else { diff --git a/src/protocols/secure_channel/CASESession.h b/src/protocols/secure_channel/CASESession.h index 045d1982dd723c..a9d43194d0f56d 100644 --- a/src/protocols/secure_channel/CASESession.h +++ b/src/protocols/secure_channel/CASESession.h @@ -44,6 +44,7 @@ #include <protocols/secure_channel/SessionResumptionStorage.h> #include <system/SystemClock.h> #include <system/SystemPacketBuffer.h> +#include <system/TLVPacketBufferBackingStore.h> #include <transport/CryptoContext.h> #include <transport/raw/MessageHeader.h> #include <transport/raw/PeerAddress.h> @@ -117,28 +118,6 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler, */ void SetGroupDataProvider(Credentials::GroupDataProvider * groupDataProvider) { mGroupDataProvider = groupDataProvider; } - /** - * Parse a sigma1 message. This function will return success only if the - * message passes schema checks. Specifically: - * * The tags come in order. - * * The required tags are present. - * * The values for the tags that are present satisfy schema requirements - * (e.g. constraints on octet string lengths) - * * Either resumptionID and initiatorResume1MIC are both present or both - * absent. - * - * On success, the initiatorRandom, initiatorSessionId, destinationId, - * initiatorEphPubKey outparams will be set to the corresponding values in - * the message. - * - * On success, either the resumptionRequested outparam will be set to true - * and the resumptionID and initiatorResumeMIC outparams will be set to - * valid values, or the resumptionRequested outparam will be set to false. - */ - CHIP_ERROR ParseSigma1(TLV::ContiguousBufferTLVReader & tlvReader, ByteSpan & initiatorRandom, uint16_t & initiatorSessionId, - ByteSpan & destinationId, ByteSpan & initiatorEphPubKey, bool & resumptionRequested, - ByteSpan & resumptionId, ByteSpan & initiatorResumeMIC); - /** * @brief * Derive a secure session from the established session. The API will return error if called before session is established. @@ -211,6 +190,118 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler, // If this function returns true, the CASE session has been reset and is ready for a new session establishment. bool InvokeBackgroundWorkWatchdog(); +protected: + // Helper Enum for use in HandleSigma1_and_SendSigma2 + enum class Step : uint8_t + { + kSendSigma2, + kSendSigma2Resume, + }; + // Making NextStep a Variant allows HandleSigma() to return either a Step value (indicating + // the next Sigma step to send) or a CHIP_ERROR (indicating a failure that will trigger + // a Status Report). + using NextStep = Variant<Step, CHIP_ERROR>; + // This struct only serves as a base struct for EncodeSigma1Inputs and ParsedSigma1 + struct Sigma1Param + { + ByteSpan initiatorRandom; + uint16_t initiatorSessionId; + ByteSpan destinationId; + bool sessionResumptionRequested = false; + ByteSpan resumptionId; + ByteSpan initiatorResumeMIC; + }; + + struct EncodeSigma1Inputs : Sigma1Param + { + const Crypto::P256PublicKey * initiatorEphPubKey = nullptr; + const ReliableMessageProtocolConfig * initiatorMrpConfig = nullptr; + uint8_t initiatorResume1MICBuffer[Crypto::CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES]; + }; + + struct ParsedSigma1 : Sigma1Param + { + ByteSpan initiatorEphPubKey; + bool initiatorMrpParamsPresent = false; + }; + + struct EncodeSigma2Inputs + { + uint8_t responderRandom[kSigmaParamRandomNumberSize]; + uint16_t responderSessionId; + const Crypto::P256PublicKey * responderEphPubKey = nullptr; + // ScopedMemoryBufferWithSize is not used for msgR2Encrypted since encrypted2Length might differ from the allocated buffer + // size + Platform::ScopedMemoryBuffer<uint8_t> msgR2Encrypted; + size_t encrypted2Length = 0; + const ReliableMessageProtocolConfig * responderMrpConfig; + }; + + struct EncodeSigma2ResumeInputs + { + ByteSpan resumptionId; + uint8_t sigma2ResumeMICBuffer[Crypto::CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES]; + MutableByteSpan resumeMIC{ sigma2ResumeMICBuffer }; + uint16_t responderSessionId; + const ReliableMessageProtocolConfig * responderMrpConfig; + }; + + /** + * @brief Encodes a Sigma1 message into TLV format and allocates a buffer for it, which is owned by the PacketBufferHandle + * outparam. + * + * @param outMsg PacketBufferHandle passed by reference. A new buffer will be allocated and assigned to it within the + * method. + * + * @param inParam a struct containing all the values that will be encoded into TLV format + * + **/ + static CHIP_ERROR EncodeSigma1(System::PacketBufferHandle & outMsg, EncodeSigma1Inputs & inParam); + + /** + * Parse a Sigma1 message. This function will return success only if the + * message passes schema checks. Specifically: + * * The tags come in order. + * * The required tags are present. + * * The values for the tags that are present satisfy schema requirements + * (e.g. constraints on octet string lengths) + * * Either resumptionID and initiatorResume1MICBuffer are both present or both are + * absent. + * + * On success, the members of outParam will be set to the values corresponding to the message. + * These values will be valid as long as the buffer that the passed-in tlvReader is reading from is valid. + * + * On success, either the sessionResumptionRequested field will be set to true + * and the resumptionID and initiatorResumeMIC fields will be set to + * valid values, or the sessionResumptionRequested field will be set to false. + */ + CHIP_ERROR ParseSigma1(TLV::ContiguousBufferTLVReader & tlvReader, ParsedSigma1 & parsedMessage); + + /** + * @brief Encodes a Sigma2 message into TLV format and allocates a buffer for it, which is owned by the PacketBufferHandle + * outparam. + * + * @param outMsg PacketBufferHandle passed by reference. A new buffer will be allocated and assigned to it within the + * method. + * + * @param inParam a struct containing all the values that will be encoded into TLV format + * + **/ + + static CHIP_ERROR EncodeSigma2(System::PacketBufferHandle & outMsg, EncodeSigma2Inputs & inParam); + + /** + * @brief Encodes a Sigma2_Resume message into TLV format and allocates a buffer for it, which is owned by the + * PacketBufferHandle outparam. + * + * @param outMsg PacketBufferHandle passed by reference. A new buffer will be allocated and assigned to it within the + * method. + * + * @param inParam a struct containing all the values that will be encoded into TLV format + * + **/ + static CHIP_ERROR EncodeSigma2Resume(System::PacketBufferHandle & outMsg, EncodeSigma2ResumeInputs & inParam); + private: friend class TestCASESession; @@ -236,10 +327,15 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler, CHIP_ERROR SendSigma1(); CHIP_ERROR HandleSigma1_and_SendSigma2(System::PacketBufferHandle && msg); - CHIP_ERROR HandleSigma1(System::PacketBufferHandle && msg); + NextStep HandleSigma1(System::PacketBufferHandle && msg); CHIP_ERROR TryResumeSession(SessionResumptionStorage::ConstResumptionIdView resumptionId, ByteSpan resume1MIC, ByteSpan initiatorRandom); - CHIP_ERROR SendSigma2(); + + CHIP_ERROR PrepareSigma2(EncodeSigma2Inputs & output); + CHIP_ERROR PrepareSigma2Resume(EncodeSigma2ResumeInputs & output); + CHIP_ERROR SendSigma2(System::PacketBufferHandle && msg_R2); + CHIP_ERROR SendSigma2Resume(System::PacketBufferHandle && msg_R2_resume); + CHIP_ERROR HandleSigma2_and_SendSigma3(System::PacketBufferHandle && msg); CHIP_ERROR HandleSigma2(System::PacketBufferHandle && msg); CHIP_ERROR HandleSigma2Resume(System::PacketBufferHandle && msg); diff --git a/src/protocols/secure_channel/tests/TestCASESession.cpp b/src/protocols/secure_channel/tests/TestCASESession.cpp index 67b4cdd7d69faf..81a2142c716021 100644 --- a/src/protocols/secure_channel/tests/TestCASESession.cpp +++ b/src/protocols/secure_channel/tests/TestCASESession.cpp @@ -59,6 +59,23 @@ using namespace chip::Crypto; namespace chip { class TestCASESecurePairingDelegate; +// Exposing CASESession's Protected members in order to be able to call the protected methods, and instantiate protected structures. +// Also to be able to instantiate New CASESessions repeatedly inside a single TestCase (which is not possible if we inherit +// CASESession in the Test Fixture) +class CASESessionAccess : public CASESession +{ +public: + using CASESession::EncodeSigma1Inputs; + using CASESession::EncodeSigma2Inputs; + using CASESession::EncodeSigma2ResumeInputs; + using CASESession::ParsedSigma1; + + using CASESession::EncodeSigma1; + using CASESession::EncodeSigma2; + using CASESession::EncodeSigma2Resume; + using CASESession::ParseSigma1; +}; + class TestCASESession : public Test::LoopbackMessagingContext { public: @@ -614,24 +631,24 @@ struct Sigma1Params { // Purposefully not using constants like kSigmaParamRandomNumberSize that // the code uses, so we have a cross-check. - static constexpr size_t initiatorRandomLen = 32; - static constexpr uint16_t initiatorSessionId = 0; - static constexpr size_t destinationIdLen = 32; - static constexpr size_t initiatorEphPubKeyLen = 65; - static constexpr size_t resumptionIdLen = 0; // Nonzero means include it. - static constexpr size_t initiatorResumeMICLen = 0; // Nonzero means include it. - - static constexpr uint8_t initiatorRandomTag = 1; - static constexpr uint8_t initiatorSessionIdTag = 2; - static constexpr uint8_t destinationIdTag = 3; - static constexpr uint8_t initiatorEphPubKeyTag = 4; - static constexpr uint8_t resumptionIdTag = 6; - static constexpr uint8_t initiatorResumeMICTag = 7; + static constexpr size_t kInitiatorRandomLen = 32; + static constexpr uint16_t kInitiatorSessionId = 0; + static constexpr size_t kDestinationIdLen = 32; + static constexpr size_t kInitiatorEphPubKeyLen = 65; + static constexpr size_t kResumptionIdLen = 0; // Nonzero means include it. + static constexpr size_t kInitiatorResumeMICLen = 0; // Nonzero means include it. + + static constexpr uint8_t kInitiatorRandomTag = 1; + static constexpr uint8_t kInitiatorSessionIdTag = 2; + static constexpr uint8_t kDestinationIdTag = 3; + static constexpr uint8_t kInitiatorEphPubKeyTag = 4; + static constexpr uint8_t kResumptionIdTag = 6; + static constexpr uint8_t kInitiatorResumeMICTag = 7; static constexpr TLV::Tag NumToTag(uint8_t num) { return TLV::ContextTag(num); } - static constexpr bool includeStructEnd = true; + static constexpr bool kIncludeStructEnd = true; - static constexpr bool expectSuccess = true; + static constexpr bool kExpectSuccess = true; }; TEST_F(TestCASESession, DestinationIdTest) @@ -649,10 +666,10 @@ TEST_F(TestCASESession, DestinationIdTest) 0x9b, 0xc6, 0x1c, 0xd9, 0xc6, 0x2a, 0x2d, 0xf6, 0xd6, 0x4d, 0xfc, 0xaa, 0x9d, 0xc4, 0x72, 0xd4 }; - const uint8_t kInitiatorRandomFromSpec[Sigma1Params::initiatorRandomLen] = { 0x7e, 0x17, 0x12, 0x31, 0x56, 0x8d, 0xfa, 0x17, - 0x20, 0x6b, 0x3a, 0xcc, 0xf8, 0xfa, 0xec, 0x2f, - 0x4d, 0x21, 0xb5, 0x80, 0x11, 0x31, 0x96, 0xf4, - 0x7c, 0x7c, 0x4d, 0xeb, 0x81, 0x0a, 0x73, 0xdc }; + const uint8_t kInitiatorRandomFromSpec[Sigma1Params::kInitiatorRandomLen] = { 0x7e, 0x17, 0x12, 0x31, 0x56, 0x8d, 0xfa, 0x17, + 0x20, 0x6b, 0x3a, 0xcc, 0xf8, 0xfa, 0xec, 0x2f, + 0x4d, 0x21, 0xb5, 0x80, 0x11, 0x31, 0x96, 0xf4, + 0x7c, 0x7c, 0x4d, 0xeb, 0x81, 0x0a, 0x73, 0xdc }; const uint8_t kExpectedDestinationIdFromSpec[Crypto::kSHA256_Hash_Length] = { 0xdc, 0x35, 0xdd, 0x5f, 0xc9, 0x13, 0x4c, 0xc5, 0x54, 0x45, 0x38, 0xc9, 0xc3, 0xfc, 0x42, 0x97, @@ -686,7 +703,7 @@ TEST_F(TestCASESession, DestinationIdTest) } template <typename Params> -static CHIP_ERROR EncodeSigma1(MutableByteSpan & buf) +static CHIP_ERROR EncodeSigma1Helper(MutableByteSpan & buf) { using namespace TLV; @@ -695,39 +712,37 @@ static CHIP_ERROR EncodeSigma1(MutableByteSpan & buf) TLVType containerType; ReturnErrorOnFailure(writer.StartContainer(AnonymousTag(), kTLVType_Structure, containerType)); - uint8_t initiatorRandom[Params::initiatorRandomLen] = { 1 }; - ReturnErrorOnFailure(writer.Put(Params::NumToTag(Params::initiatorRandomTag), ByteSpan(initiatorRandom))); + uint8_t initiatorRandom[Params::kInitiatorRandomLen] = { 1 }; + ReturnErrorOnFailure(writer.Put(Params::NumToTag(Params::kInitiatorRandomTag), ByteSpan(initiatorRandom))); - ReturnErrorOnFailure(writer.Put(Params::NumToTag(Params::initiatorSessionIdTag), Params::initiatorSessionId)); + ReturnErrorOnFailure(writer.Put(Params::NumToTag(Params::kInitiatorSessionIdTag), Params::kInitiatorSessionId)); - uint8_t destinationId[Params::destinationIdLen] = { 2 }; - ReturnErrorOnFailure(writer.Put(Params::NumToTag(Params::destinationIdTag), ByteSpan(destinationId))); + uint8_t destinationId[Params::kDestinationIdLen] = { 2 }; + ReturnErrorOnFailure(writer.Put(Params::NumToTag(Params::kDestinationIdTag), ByteSpan(destinationId))); - uint8_t initiatorEphPubKey[Params::initiatorEphPubKeyLen] = { 3 }; - ReturnErrorOnFailure(writer.Put(Params::NumToTag(Params::initiatorEphPubKeyTag), ByteSpan(initiatorEphPubKey))); + uint8_t initiatorEphPubKey[Params::kInitiatorEphPubKeyLen] = { 3 }; + ReturnErrorOnFailure(writer.Put(Params::NumToTag(Params::kInitiatorEphPubKeyTag), ByteSpan(initiatorEphPubKey))); - // I wish we had "if constexpr" support here, so the compiler would know - // resumptionIdLen is nonzero inside the block.... - if (Params::resumptionIdLen != 0) + if constexpr (Params::kResumptionIdLen != 0) { - uint8_t resumptionId[Params::resumptionIdLen]; + uint8_t resumptionId[Params::kResumptionIdLen]; // to fix _FORTIFY_SOURCE issue, _FORTIFY_SOURCE=2 by default on Android - (&memset)(resumptionId, 4, Params::resumptionIdLen); + (&memset)(resumptionId, 4, Params::kResumptionIdLen); ReturnErrorOnFailure( - writer.Put(Params::NumToTag(Params::resumptionIdTag), ByteSpan(resumptionId, Params::resumptionIdLen))); + writer.Put(Params::NumToTag(Params::kResumptionIdTag), ByteSpan(resumptionId, Params::kResumptionIdLen))); } - if (Params::initiatorResumeMICLen != 0) + if constexpr (Params::kInitiatorResumeMICLen != 0) { - uint8_t initiatorResumeMIC[Params::initiatorResumeMICLen]; + uint8_t initiatorResumeMIC[Params::kInitiatorResumeMICLen]; // to fix _FORTIFY_SOURCE issue, _FORTIFY_SOURCE=2 by default on Android - (&memset)(initiatorResumeMIC, 5, Params::initiatorResumeMICLen); - ReturnErrorOnFailure(writer.Put(Params::NumToTag(Params::initiatorResumeMICTag), - ByteSpan(initiatorResumeMIC, Params::initiatorResumeMICLen))); + (&memset)(initiatorResumeMIC, 5, Params::kInitiatorResumeMICLen); + ReturnErrorOnFailure(writer.Put(Params::NumToTag(Params::kInitiatorResumeMICTag), + ByteSpan(initiatorResumeMIC, Params::kInitiatorResumeMICLen))); } - if (Params::includeStructEnd) + if constexpr (Params::kIncludeStructEnd) { ReturnErrorOnFailure(writer.EndContainer(containerType)); } @@ -741,37 +756,30 @@ static CHIP_ERROR EncodeSigma1(MutableByteSpan & buf) do \ { \ MutableByteSpan buf(mem.Get(), bufferSize); \ - EXPECT_EQ(EncodeSigma1<params>(buf), CHIP_NO_ERROR); \ + EXPECT_EQ(EncodeSigma1Helper<params>(buf), CHIP_NO_ERROR); \ \ TLV::ContiguousBufferTLVReader reader; \ reader.Init(buf); \ + CASESessionAccess session; \ + CASESessionAccess::ParsedSigma1 parsedSigma1; \ \ - ByteSpan initiatorRandom; \ - uint16_t initiatorSessionId; \ - ByteSpan destinationId; \ - ByteSpan initiatorEphPubKey; \ - bool resumptionRequested; \ - ByteSpan resumptionId; \ - ByteSpan initiatorResumeMIC; \ - CASESession session; \ - EXPECT_EQ(session.ParseSigma1(reader, initiatorRandom, initiatorSessionId, destinationId, initiatorEphPubKey, \ - resumptionRequested, resumptionId, initiatorResumeMIC) == CHIP_NO_ERROR, \ - params::expectSuccess); \ - if (params::expectSuccess) \ + EXPECT_EQ(session.ParseSigma1(reader, parsedSigma1) == CHIP_NO_ERROR, params::kExpectSuccess); \ + if (params::kExpectSuccess) \ { \ - EXPECT_EQ(resumptionRequested, params::resumptionIdLen != 0 && params::initiatorResumeMICLen != 0); \ + EXPECT_EQ(parsedSigma1.sessionResumptionRequested, \ + params::kResumptionIdLen != 0 && params::kInitiatorResumeMICLen != 0); \ /* Add other verification tests here as desired */ \ } \ } while (0) struct BadSigma1ParamsBase : public Sigma1Params { - static constexpr bool expectSuccess = false; + static constexpr bool kExpectSuccess = false; }; struct Sigma1NoStructEnd : public BadSigma1ParamsBase { - static constexpr bool includeStructEnd = false; + static constexpr bool kIncludeStructEnd = false; }; struct Sigma1WrongTags : public BadSigma1ParamsBase @@ -781,72 +789,72 @@ struct Sigma1WrongTags : public BadSigma1ParamsBase struct Sigma1TooLongRandom : public BadSigma1ParamsBase { - static constexpr size_t initiatorRandomLen = 33; + static constexpr size_t kInitiatorRandomLen = 33; }; struct Sigma1TooShortRandom : public BadSigma1ParamsBase { - static constexpr size_t initiatorRandomLen = 31; + static constexpr size_t kInitiatorRandomLen = 31; }; struct Sigma1TooLongDest : public BadSigma1ParamsBase { - static constexpr size_t destinationIdLen = 33; + static constexpr size_t kDestinationIdLen = 33; }; struct Sigma1TooShortDest : public BadSigma1ParamsBase { - static constexpr size_t destinationIdLen = 31; + static constexpr size_t kDestinationIdLen = 31; }; struct Sigma1TooLongPubkey : public BadSigma1ParamsBase { - static constexpr size_t initiatorEphPubKeyLen = 66; + static constexpr size_t kInitiatorEphPubKeyLen = 66; }; struct Sigma1TooShortPubkey : public BadSigma1ParamsBase { - static constexpr size_t initiatorEphPubKeyLen = 64; + static constexpr size_t kInitiatorEphPubKeyLen = 64; }; struct Sigma1WithResumption : public Sigma1Params { - static constexpr size_t resumptionIdLen = 16; - static constexpr size_t initiatorResumeMICLen = 16; + static constexpr size_t kResumptionIdLen = 16; + static constexpr size_t kInitiatorResumeMICLen = 16; }; struct Sigma1TooLongResumptionId : public Sigma1WithResumption { - static constexpr size_t resumptionIdLen = 17; - static constexpr bool expectSuccess = false; + static constexpr size_t kResumptionIdLen = 17; + static constexpr bool kExpectSuccess = false; }; struct Sigma1TooShortResumptionId : public BadSigma1ParamsBase { - static constexpr size_t resumptionIdLen = 15; - static constexpr bool expectSuccess = false; + static constexpr size_t kResumptionIdLen = 15; + static constexpr bool kExpectSuccess = false; }; struct Sigma1TooLongResumeMIC : public Sigma1WithResumption { - static constexpr size_t resumptionIdLen = 17; - static constexpr bool expectSuccess = false; + static constexpr size_t kResumptionIdLen = 17; + static constexpr bool kExpectSuccess = false; }; struct Sigma1TooShortResumeMIC : public Sigma1WithResumption { - static constexpr size_t initiatorResumeMICLen = 15; - static constexpr bool expectSuccess = false; + static constexpr size_t kInitiatorResumeMICLen = 15; + static constexpr bool kExpectSuccess = false; }; struct Sigma1SessionIdMax : public Sigma1Params { - static constexpr uint32_t initiatorSessionId = UINT16_MAX; + static constexpr uint32_t kInitiatorSessionId = UINT16_MAX; }; struct Sigma1SessionIdTooBig : public BadSigma1ParamsBase { - static constexpr uint32_t initiatorSessionId = UINT16_MAX + 1; + static constexpr uint32_t kInitiatorSessionId = UINT16_MAX + 1; }; TEST_F(TestCASESession, Sigma1ParsingTest) @@ -874,6 +882,231 @@ TEST_F(TestCASESession, Sigma1ParsingTest) TestSigma1Parsing(mem, bufferSize, Sigma1SessionIdTooBig); } +TEST_F(TestCASESession, EncodeSigma1Test) +{ + CASESessionAccess::EncodeSigma1Inputs encodeParams; + + uint8_t random[32]; + EXPECT_EQ(chip::Crypto::DRBG_get_bytes(&random[0], sizeof(random)), CHIP_NO_ERROR); + encodeParams.initiatorRandom = ByteSpan(random); + encodeParams.initiatorSessionId = 7315; + uint8_t destinationId[32] = { 0xDE, 0xAD }; + encodeParams.destinationId = ByteSpan(destinationId); + + ReliableMessageProtocolConfig mrpConfig = GetDefaultMRPConfig(); + encodeParams.initiatorMrpConfig = &mrpConfig; + + { + System::PacketBufferHandle msg; + // EncodeSigma1 should fail when there is no public key + EXPECT_EQ(CHIP_ERROR_INVALID_ARGUMENT, CASESessionAccess::EncodeSigma1(msg, encodeParams)); + } + + Crypto::P256Keypair * ephemeralKey = gDeviceOperationalKeystore.AllocateEphemeralKeypairForCASE(); + ASSERT_NE(ephemeralKey, nullptr); + EXPECT_EQ(CHIP_NO_ERROR, ephemeralKey->Initialize(ECPKeyTarget::ECDH)); + encodeParams.initiatorEphPubKey = &ephemeralKey->Pubkey(); + + { + System::PacketBufferHandle msg; + // EncodeSigma1 will Succeed when Public Key is provided + EXPECT_EQ(CHIP_NO_ERROR, CASESessionAccess::EncodeSigma1(msg, encodeParams)); + } + + { + System::PacketBufferHandle msg; + // EncodeSigma1 should fail when MRP config is missing + encodeParams.initiatorMrpConfig = nullptr; + EXPECT_EQ(CHIP_ERROR_INVALID_ARGUMENT, CASESessionAccess::EncodeSigma1(msg, encodeParams)); + } + + { + System::PacketBufferHandle msg; + // Succeed when MRP Config is provided + encodeParams.initiatorMrpConfig = &mrpConfig; + EXPECT_EQ(CHIP_NO_ERROR, CASESessionAccess::EncodeSigma1(msg, encodeParams)); + } + + { + System::PacketBufferHandle msg1; + + // Round Trip Test: Encode Sigma1, Parse it then verify parsed values + EXPECT_EQ(CHIP_NO_ERROR, CASESessionAccess::EncodeSigma1(msg1, encodeParams)); + + System::PacketBufferTLVReader tlvReader; + tlvReader.Init(std::move(msg1)); + + CASESessionAccess session; + CASESessionAccess::ParsedSigma1 parsedMessage; + + EXPECT_EQ(CHIP_NO_ERROR, session.ParseSigma1(tlvReader, parsedMessage)); + + // compare parsed values with original values + EXPECT_TRUE(parsedMessage.initiatorRandom.data_equal(encodeParams.initiatorRandom)); + EXPECT_EQ(parsedMessage.initiatorSessionId, encodeParams.initiatorSessionId); + EXPECT_TRUE(parsedMessage.destinationId.data_equal(encodeParams.destinationId)); + EXPECT_TRUE(parsedMessage.initiatorEphPubKey.data_equal( + ByteSpan(encodeParams.initiatorEphPubKey->ConstBytes(), encodeParams.initiatorEphPubKey->Length()))); + } + + { + // Round Trip Test: Sigma1 with Session Resumption + // Encode Sigma1 with Resumption, parse it and and verify with original values + chip::SessionResumptionStorage::ResumptionIdStorage resumptionId; + + EXPECT_EQ(chip::Crypto::DRBG_get_bytes(resumptionId.data(), resumptionId.size()), CHIP_NO_ERROR); + EXPECT_EQ(chip::Crypto::DRBG_get_bytes(&encodeParams.initiatorResume1MICBuffer[0], + sizeof(encodeParams.initiatorResume1MICBuffer)), + CHIP_NO_ERROR); + + encodeParams.resumptionId = ByteSpan(resumptionId.data(), resumptionId.size()); + encodeParams.initiatorResumeMIC = ByteSpan(encodeParams.initiatorResume1MICBuffer); + encodeParams.sessionResumptionRequested = true; + + System::PacketBufferHandle msg2; + + EXPECT_EQ(CHIP_NO_ERROR, CASESessionAccess::EncodeSigma1(msg2, encodeParams)); + + System::PacketBufferTLVReader tlvReader; + tlvReader.Init(std::move(msg2)); + + CASESessionAccess session; + CASESessionAccess::ParsedSigma1 parsedMessage; + + EXPECT_EQ(CHIP_NO_ERROR, session.ParseSigma1(tlvReader, parsedMessage)); + + // RoundTrip + EXPECT_TRUE(parsedMessage.initiatorRandom.data_equal(encodeParams.initiatorRandom)); + EXPECT_EQ(parsedMessage.initiatorSessionId, encodeParams.initiatorSessionId); + EXPECT_TRUE(parsedMessage.destinationId.data_equal(encodeParams.destinationId)); + EXPECT_TRUE(parsedMessage.initiatorEphPubKey.data_equal( + ByteSpan(encodeParams.initiatorEphPubKey->ConstBytes(), encodeParams.initiatorEphPubKey->Length()))); + + EXPECT_TRUE(parsedMessage.resumptionId.data_equal(encodeParams.resumptionId)); + EXPECT_TRUE(parsedMessage.initiatorResumeMIC.data_equal(encodeParams.initiatorResumeMIC)); + EXPECT_TRUE(parsedMessage.initiatorMrpParamsPresent); + } + // Release EphemeralKeyPair + gDeviceOperationalKeystore.ReleaseEphemeralKeypair(ephemeralKey); +} + +TEST_F(TestCASESession, EncodeSigma2Test) +{ + CASESessionAccess::EncodeSigma2Inputs encodeParams; + constexpr uint8_t kEncrypted2datalen = 100U; + + EXPECT_EQ(chip::Crypto::DRBG_get_bytes(&encodeParams.responderRandom[0], sizeof(encodeParams.responderRandom)), CHIP_NO_ERROR); + encodeParams.responderSessionId = 7315; + + // Generate Ephemeral Public Key + Crypto::P256Keypair * ephemeralKey = gDeviceOperationalKeystore.AllocateEphemeralKeypairForCASE(); + ASSERT_NE(ephemeralKey, nullptr); + EXPECT_EQ(CHIP_NO_ERROR, ephemeralKey->Initialize(ECPKeyTarget::ECDH)); + encodeParams.responderEphPubKey = &ephemeralKey->Pubkey(); + + // TBEData2Encrypted + encodeParams.encrypted2Length = kEncrypted2datalen + CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES; + encodeParams.msgR2Encrypted.Alloc(encodeParams.encrypted2Length); + + // responder Session Parameters + ReliableMessageProtocolConfig mrpConfig = GetDefaultMRPConfig(); + encodeParams.responderMrpConfig = &mrpConfig; + + { + System::PacketBufferHandle msg; + EXPECT_EQ(CHIP_NO_ERROR, CASESessionAccess::EncodeSigma2(msg, encodeParams)); + // EncodeSigma2 frees msgR2Encrypted after encoding it + encodeParams.msgR2Encrypted.Alloc(encodeParams.encrypted2Length); + } + + { + System::PacketBufferHandle msg; + + // EncodeSigma2 should fail when there is no public key + encodeParams.responderEphPubKey = nullptr; + EXPECT_EQ(CHIP_ERROR_INVALID_ARGUMENT, CASESessionAccess::EncodeSigma2(msg, encodeParams)); + } + + encodeParams.responderEphPubKey = &ephemeralKey->Pubkey(); + + { + System::PacketBufferHandle msg; + EXPECT_EQ(CHIP_NO_ERROR, CASESessionAccess::EncodeSigma2(msg, encodeParams)); + // EncodeSigma2 frees msgR2Encrypted after encoding it + encodeParams.msgR2Encrypted.Alloc(encodeParams.encrypted2Length); + } + + { + System::PacketBufferHandle msg; + // EncodeSigma2 should fail when TBEData2Encrypted is not allocated + encodeParams.msgR2Encrypted.Free(); + EXPECT_EQ(CHIP_ERROR_INCORRECT_STATE, CASESessionAccess::EncodeSigma2(msg, encodeParams)); + } + + encodeParams.msgR2Encrypted.Alloc(encodeParams.encrypted2Length); + { + System::PacketBufferHandle msg; + EXPECT_EQ(CHIP_NO_ERROR, CASESessionAccess::EncodeSigma2(msg, encodeParams)); + // EncodeSigma2 frees msgR2Encrypted after encoding it + encodeParams.msgR2Encrypted.Alloc(encodeParams.encrypted2Length); + } + + { + System::PacketBufferHandle msg; + // EncodeSigma2 should fail when the encrypted2Length is not set + encodeParams.encrypted2Length = 0; + EXPECT_EQ(CHIP_ERROR_INCORRECT_STATE, CASESessionAccess::EncodeSigma2(msg, encodeParams)); + } + // Set encrypted2Length again + encodeParams.encrypted2Length = kEncrypted2datalen + CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES; + + { + System::PacketBufferHandle msg; + // EncodeSigma2 should fail when MRP config is missing + encodeParams.responderMrpConfig = nullptr; + EXPECT_EQ(CHIP_ERROR_INVALID_ARGUMENT, CASESessionAccess::EncodeSigma2(msg, encodeParams)); + } + + { + System::PacketBufferHandle msg; + // Succeed when MRP Config is provided + encodeParams.responderMrpConfig = &mrpConfig; + EXPECT_EQ(CHIP_NO_ERROR, CASESessionAccess::EncodeSigma2(msg, encodeParams)); + } + // Release EphemeralKeyPair + gDeviceOperationalKeystore.ReleaseEphemeralKeypair(ephemeralKey); +} + +TEST_F(TestCASESession, EncodeSigma2ResumeTest) +{ + CASESessionAccess::EncodeSigma2ResumeInputs encodeParams; + + encodeParams.responderSessionId = 7315; + + // responder Session Parameters + ReliableMessageProtocolConfig mrpConfig = GetDefaultMRPConfig(); + encodeParams.responderMrpConfig = &mrpConfig; + + { + System::PacketBufferHandle msg; + EXPECT_EQ(CHIP_NO_ERROR, CASESessionAccess::EncodeSigma2Resume(msg, encodeParams)); + } + + { + System::PacketBufferHandle msg; + // EncodeSigma2Resume should fail when MRP config is missing + encodeParams.responderMrpConfig = nullptr; + EXPECT_EQ(CHIP_ERROR_INVALID_ARGUMENT, CASESessionAccess::EncodeSigma2Resume(msg, encodeParams)); + } + + { + System::PacketBufferHandle msg; + // Succeed when MRP Config is provided + encodeParams.responderMrpConfig = &mrpConfig; + EXPECT_EQ(CHIP_NO_ERROR, CASESessionAccess::EncodeSigma2Resume(msg, encodeParams)); + } +} + struct SessionResumptionTestStorage : SessionResumptionStorage { SessionResumptionTestStorage(CHIP_ERROR findMethodReturnCode, ScopedNodeId peerNodeId, ResumptionIdStorage * resumptionId, @@ -1126,7 +1359,7 @@ TEST_F(TestCASESession, Sigma1BadDestinationIdTest) MutableByteSpan buf(data->Start(), data->AvailableDataLength()); // This uses a bogus destination id that is not going to match anything in practice. - EXPECT_EQ(EncodeSigma1<Sigma1Params>(buf), CHIP_NO_ERROR); + EXPECT_EQ(EncodeSigma1Helper<Sigma1Params>(buf), CHIP_NO_ERROR); data->SetDataLength(static_cast<uint16_t>(buf.size())); Optional<SessionHandle> session = sessionManager.CreateUnauthenticatedSession(GetAliceAddress(), GetDefaultMRPConfig());