diff --git a/library-vcpkg/CMakeLists.txt b/library-vcpkg/CMakeLists.txt index dded1b52a5..1b79ccbe1a 100644 --- a/library-vcpkg/CMakeLists.txt +++ b/library-vcpkg/CMakeLists.txt @@ -1,4 +1,4 @@ -file(GLOB THE_SOURCES "${DPP_ROOT_PATH}/src/dpp/events/*.cpp" "${DPP_ROOT_PATH}/src/dpp/cluster/*.cpp" "${DPP_ROOT_PATH}/src/dpp/*.cpp" "${DPP_ROOT_PATH}/src/dpp/*.rc") +file(GLOB THE_SOURCES "${DPP_ROOT_PATH}/src/dpp/events/*.cpp" "${DPP_ROOT_PATH}/dpp/dave/*.cpp" "${DPP_ROOT_PATH}/src/dpp/cluster/*.cpp" "${DPP_ROOT_PATH}/src/dpp/*.cpp" "${DPP_ROOT_PATH}/src/dpp/*.rc") set(LIB_NAME "${PROJECT_NAME}") @@ -40,8 +40,8 @@ target_compile_options( PRIVATE "$<$:$<$:/sdl;/Od;/DEBUG;/MP;/DFD_SETSIZE=1024>>" "$<$:$<$:/O2;/Oi;/Oy;/GL;/Gy;/sdl;/MP;/DFD_SETSIZE=1024>>" - "$<$:$<$:-Wall;-Wempty-body;-Wno-psabi;-Wunknown-pragmas;-Wignored-qualifiers;-Wimplicit-fallthrough;-Wmissing-field-initializers;-Wsign-compare;-Wtype-limits;-Wuninitialized;-Wshift-negative-value;-pthread;-g;-Og;-fPIC>>" - "$<$:$<$:-Wall;-Wempty-body;-Wno-psabi;-Wunknown-pragmas;-Wignored-qualifiers;-Wimplicit-fallthrough;-Wmissing-field-initializers;-Wsign-compare;-Wtype-limits;-Wuninitialized;-Wshift-negative-value;-pthread;-O3;-fPIC>>" + "$<$:$<$:-fPIC;-Wall;-Wempty-body;-Wno-psabi;-Wunknown-pragmas;-Wignored-qualifiers;-Wimplicit-fallthrough;-Wmissing-field-initializers;-Wsign-compare;-Wtype-limits;-Wuninitialized;-Wshift-negative-value;-pthread;-g;-Og;-fPIC>>" + "$<$:$<$:-fPIC;-Wall;-Wempty-body;-Wno-psabi;-Wunknown-pragmas;-Wignored-qualifiers;-Wimplicit-fallthrough;-Wmissing-field-initializers;-Wsign-compare;-Wtype-limits;-Wuninitialized;-Wshift-negative-value;-pthread;-O3;-fPIC>>" "${AVX_FLAG}" ) @@ -60,6 +60,11 @@ target_include_directories( "$" ) +add_subdirectory("${DPP_ROOT_PATH}/mlspp" "mlspp") +include_directories("${DPP_ROOT_PATH}/mlspp/include") +include_directories("${DPP_ROOT_PATH}/mlspp/lib/bytes/include") +include_directories("${DPP_ROOT_PATH}/mlspp/lib/hpke/include") + set_target_properties( "${LIB_NAME}" PROPERTIES OUTPUT_NAME "dpp" @@ -90,6 +95,16 @@ target_link_libraries( $<$:Threads::Threads> ) +# Private statically linked dependencies +target_link_libraries( + ${LIB_NAME} PRIVATE + mlspp + mls_vectors + bytes + tls_syntax + hpke +) + set(CONFIG_FILE_NAME "${PROJECT_NAME}Config.cmake") set(EXPORTED_TARGETS_NAME "${PROJECT_NAME}Targets") set(EXPORTED_TARGETS_FILE_NAME "${EXPORTED_TARGETS_NAME}.cmake") diff --git a/library/CMakeLists.txt b/library/CMakeLists.txt index 00b2a8b651..4e67d19eff 100644 --- a/library/CMakeLists.txt +++ b/library/CMakeLists.txt @@ -197,7 +197,7 @@ if(MSVC) endif() string(REGEX REPLACE "/W[1|2|3|4]" "/W3" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") else() - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wno-unused-private-field -Wno-psabi -Wempty-body -Wignored-qualifiers -Wimplicit-fallthrough -Wmissing-field-initializers -Wsign-compare -Wtype-limits -Wuninitialized -Wshift-negative-value -pthread") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -Wall -Wno-unused-private-field -Wno-psabi -Wempty-body -Wignored-qualifiers -Wimplicit-fallthrough -Wmissing-field-initializers -Wsign-compare -Wtype-limits -Wuninitialized -Wshift-negative-value -pthread") set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -Og") if (NOT MINGW) @@ -209,11 +209,16 @@ set(modules_dir "../src") file(GLOB subdirlist ${modules_dir}/dpp) +add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/../mlspp" "mlspp") +include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../mlspp/include") +include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../mlspp/lib/bytes/include") +include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../mlspp/lib/hpke/include") + foreach (fullmodname ${subdirlist}) get_filename_component(modname ${fullmodname} NAME) set (modsrc "") - - file(GLOB modsrc "${modules_dir}/dpp/*.cpp" "${modules_dir}/dpp/events/*.cpp" "${modules_dir}/dpp/cluster/*.cpp" "${modules_dir}/dpp/*.rc") + + file(GLOB modsrc "${modules_dir}/dpp/*.cpp" "${modules_dir}/dpp/dave/*.cpp" "${modules_dir}/dpp/events/*.cpp" "${modules_dir}/dpp/cluster/*.cpp" "${modules_dir}/dpp/*.rc") if(BUILD_SHARED_LIBS) add_library(${modname} SHARED ${modsrc}) @@ -241,6 +246,7 @@ foreach (fullmodname ${subdirlist}) endif() if (WIN32 AND NOT MINGW) + if (NOT WINDOWS_32_BIT) target_link_libraries(${modname} PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/../win32/lib/libssl.lib" "${CMAKE_CURRENT_SOURCE_DIR}/../win32/lib/libcrypto.lib" @@ -262,6 +268,15 @@ foreach (fullmodname ${subdirlist}) endif () endif() + # Private statically linked dependencies + target_link_libraries(${modname} PRIVATE + mlspp + mls_vectors + bytes + tls_syntax + hpke + ) + if (HAVE_VOICE) target_link_libraries(${modname} PUBLIC ${sodium_LIBRARY_RELEASE} ${OPUS_LIBRARIES}) @@ -269,6 +284,8 @@ foreach (fullmodname ${subdirlist}) endif() endforeach() + + target_compile_features(dpp PUBLIC cxx_std_17) target_compile_features(dpp PRIVATE cxx_constexpr) target_compile_features(dpp PRIVATE cxx_auto_type) diff --git a/mlspp/CMakeLists.txt b/mlspp/CMakeLists.txt new file mode 100755 index 0000000000..895e62db2e --- /dev/null +++ b/mlspp/CMakeLists.txt @@ -0,0 +1,196 @@ +cmake_minimum_required(VERSION 3.13) + +project(mlspp + VERSION 0.1 + LANGUAGES CXX +) + +option(TESTING "Build tests" OFF) +option(CLANG_TIDY "Perform linting with clang-tidy" OFF) +option(SANITIZERS "Enable sanitizers" OFF) +option(MLS_NAMESPACE_SUFFIX "Namespace Suffix for CXX and CMake Export") +option(DISABLE_GREASE "Disables the inclusion of MLS protocol recommended GREASE values" OFF) +option(REQUIRE_BORINGSSL "Require BoringSSL instead of OpenSSL" OFF) + +if(MLS_NAMESPACE_SUFFIX) + set(MLS_CXX_NAMESPACE "mls_${MLS_NAMESPACE_SUFFIX}" CACHE STRING "Top-level Namespace for CXX") + set(MLS_EXPORT_NAMESPACE "MLSPP${MLS_NAMESPACE_SUFFIX}" CACHE STRING "Namespace for CMake Export") +else() + set(MLS_CXX_NAMESPACE "../include/dpp/mlspp/mls" CACHE STRING "Top-level Namespace for CXX") + set(MLS_EXPORT_NAMESPACE "MLSPP" CACHE STRING "Namespace for CMake Export") +endif() +message(STATUS "CXX Namespace: ${MLS_CXX_NAMESPACE}") +message(STATUS "CMake Export Namespace: ${MLS_EXPORT_NAMESPACE}") + + +### +### Global Config +### +set_property(GLOBAL PROPERTY USE_FOLDERS ON) + +configure_file( + "cmake/namespace.h.in" + "${CMAKE_CURRENT_SOURCE_DIR}/include/namespace.h" + @ONLY +) + +include(CheckCXXCompilerFlag) +include(CMakePackageConfigHelpers) +include(GNUInstallDirs) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +if (CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID MATCHES "GNU") + add_compile_options(-Wall -pedantic -Wextra -Werror -Wmissing-declarations) +elseif(MSVC) + add_compile_options(/W4 /WX) + add_definitions(-DWINDOWS) + + # MSVC helpfully recommends safer equivalents for things like + # getenv, but they are not portable. + add_definitions(-D_CRT_SECURE_NO_WARNINGS) +endif() + +if("$ENV{MACOSX_DEPLOYMENT_TARGET}" STREQUAL "10.11") + add_compile_options(-DVARIANT_COMPAT) +endif() + +if (DISABLE_GREASE) + add_compile_options(-DDISABLE_GREASE) +endif () + +### +### Dependencies +### + +# Configure vcpkg to only build release libraries +set(VCPKG_BUILD_TYPE release) + +# External libraries +find_package(OpenSSL REQUIRED) +if ( OPENSSL_FOUND ) + find_path(BORINGSSL_INCLUDE_DIR openssl/is_boringssl.h HINTS ${OPENSSL_INCLUDE_DIR} NO_DEFAULT_PATH) + + if (BORINGSSL_INCLUDE_DIR) + message(STATUS "Found OpenSSL includes are for BoringSSL") + + add_compile_definitions(WITH_BORINGSSL) + + if (CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID MATCHES "GNU") + add_compile_options(-Wno-gnu-anonymous-struct -Wno-nested-anon-types) + endif () + + file(STRINGS "${OPENSSL_INCLUDE_DIR}/openssl/crypto.h" boringssl_version_str + REGEX "^#[\t ]*define[\t ]+OPENSSL_VERSION_TEXT[\t ]+\"OpenSSL ([0-9])+\\.([0-9])+\\.([0-9])+ .+") + + string(REGEX REPLACE "^.*OPENSSL_VERSION_TEXT[\t ]+\"OpenSSL ([0-9]+\\.[0-9]+\\.[0-9])+ .+$" + "\\1" OPENSSL_VERSION "${boringssl_version_str}") + + elseif (REQUIRE_BORINGSSL) + message(FATAL_ERROR "BoringSSL required but not found") + endif () + + if (${OPENSSL_VERSION} VERSION_GREATER_EQUAL 3) + add_compile_definitions(WITH_OPENSSL3) + elseif(${OPENSSL_VERSION} VERSION_LESS 1.1.1) + message(FATAL_ERROR "OpenSSL 1.1.1 or greater is required") + endif() + message(STATUS "OpenSSL Found: ${OPENSSL_VERSION}") + message(STATUS "OpenSSL Include: ${OPENSSL_INCLUDE_DIR}") + message(STATUS "OpenSSL Libraries: ${OPENSSL_LIBRARIES}") +else() + message(FATAL_ERROR "No OpenSSL library found") +endif() + +# Internal libraries +add_subdirectory(lib) + +# Third-Party libraries in tree +add_subdirectory(third_party) + + +### +### Library Config +### + +set(LIB_NAME "${PROJECT_NAME}") + +file(GLOB_RECURSE LIB_HEADERS CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/include/*.h") +file(GLOB_RECURSE LIB_SOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp") + +add_library(${LIB_NAME} STATIC ${LIB_HEADERS} ${LIB_SOURCES}) +add_dependencies(${LIB_NAME} bytes tls_syntax hpke) +target_link_libraries(${LIB_NAME} bytes tls_syntax hpke) +target_include_directories(${LIB_NAME} + PUBLIC + $ + $ + PRIVATE + ${OPENSSL_INCLUDE_DIR} +) + +install(TARGETS ${LIB_NAME} EXPORT mlspp-targets) + + +### +### Exports +### +set(CMAKE_EXPORT_PACKAGE_REGISTRY ON) +export( + EXPORT + mlspp-targets + NAMESPACE + ${MLS_EXPORT_NAMESPACE}:: + FILE + ${MLS_EXPORT_NAMESPACE}Targets.cmake) +export(PACKAGE ${MLS_EXPORT_NAMESPACE}) + +configure_package_config_file(cmake/config.cmake.in + ${CMAKE_CURRENT_BINARY_DIR}/${MLS_EXPORT_NAMESPACE}Config.cmake + INSTALL_DESTINATION ${CMAKE_INSTALL_DATADIR}/${MLS_EXPORT_NAMESPACE} + NO_SET_AND_CHECK_MACRO) + +write_basic_package_version_file( + ${CMAKE_CURRENT_BINARY_DIR}/${MLS_EXPORT_NAMESPACE}ConfigVersion.cmake + VERSION ${PROJECT_VERSION} + COMPATIBILITY SameMajorVersion) + +### +### Install +### + +install( + EXPORT + mlspp-targets + NAMESPACE + ${MLS_EXPORT_NAMESPACE}:: + FILE + ${MLS_EXPORT_NAMESPACE}Targets.cmake + DESTINATION + ${CMAKE_INSTALL_DATADIR}/${MLS_EXPORT_NAMESPACE}) + +install( + DIRECTORY + include + DESTINATION + ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME}) + +install( + FILES + ${CMAKE_CURRENT_BINARY_DIR}/${MLS_EXPORT_NAMESPACE}Config.cmake + ${CMAKE_CURRENT_BINARY_DIR}/${MLS_EXPORT_NAMESPACE}ConfigVersion.cmake + DESTINATION + ${CMAKE_INSTALL_DATADIR}/${MLS_EXPORT_NAMESPACE}) + +install( + FILES + LICENSE + DESTINATION + ${CMAKE_INSTALL_DATADIR}/${PROJECT_NAME} + RENAME + copyright) + +include_directories("${CMAKE_CURRENT_SOURCE_DIR}/lib/bytes/include") +include_directories("${CMAKE_CURRENT_SOURCE_DIR}/lib/hpke/include") +include_directories("${CMAKE_CURRENT_SOURCE_DIR}/lib/mls_vectors/include") +include_directories("${CMAKE_CURRENT_SOURCE_DIR}/lib/tls_syntax/include") diff --git a/mlspp/LICENSE b/mlspp/LICENSE new file mode 100755 index 0000000000..044f922aa7 --- /dev/null +++ b/mlspp/LICENSE @@ -0,0 +1,25 @@ +BSD 2-Clause License + +Copyright (c) 2018, Cisco Systems +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/mlspp/cmake/config.cmake.in b/mlspp/cmake/config.cmake.in new file mode 100755 index 0000000000..82145af779 --- /dev/null +++ b/mlspp/cmake/config.cmake.in @@ -0,0 +1,4 @@ +@PACKAGE_INIT@ + +include(${CMAKE_CURRENT_LIST_DIR}/@MLS_EXPORT_NAMESPACE@Targets.cmake) +check_required_components(mlspp) diff --git a/mlspp/cmake/namespace.h.in b/mlspp/cmake/namespace.h.in new file mode 100755 index 0000000000..79dfad2cf5 --- /dev/null +++ b/mlspp/cmake/namespace.h.in @@ -0,0 +1,4 @@ +#pragma once + +// Configurable top-level MLS namespace +#define MLS_NAMESPACE @MLS_CXX_NAMESPACE@ diff --git a/mlspp/include/mls/common.h b/mlspp/include/mls/common.h new file mode 100755 index 0000000000..a32f9ef730 --- /dev/null +++ b/mlspp/include/mls/common.h @@ -0,0 +1,274 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +using namespace std::literals::string_literals; + +// Expose the bytes library globally +#include +using namespace mlspp::bytes_ns; + +// Expose the compatibility library globally +#include +namespace var = mlspp::tls::var; +namespace opt = mlspp::tls::opt; + +namespace mlspp { + +// Make variant equality work in the same way as optional equality, with +// automatic unwrapping. In other words +// +// v == T(x) <=> hold_alternative(v) && get(v) == x +// +// For consistency, we also define symmetric and negated version. In this +// house, we obey the symmetric law of equivalence relations! +template +bool +operator==(const var::variant& v, const T& t) +{ + return var::visit( + [&](const auto& arg) { + using U = std::decay_t; + if constexpr (std::is_same_v) { + return arg == t; + } else { + return false; + } + }, + v); +} + +template +bool +operator==(const T& t, const var::variant& v) +{ + return v == t; +} + +template +bool +operator!=(const var::variant& v, const T& t) +{ + return !(v == t); +} + +template +bool +operator!=(const T& t, const var::variant& v) +{ + return !(v == t); +} + +using epoch_t = uint64_t; + +/// +/// Get the current system clock time in the format MLS expects +/// + +uint64_t +seconds_since_epoch(); + +/// +/// Easy construction of overloaded lambdas +/// + +template +struct overloaded : Ts... +{ + using Ts::operator()...; + + // XXX(RLB) MSVC has a bug where it incorrectly computes the size of this + // type. Microsoft claims they have fixed it in the latest MSVC, and GitHub + // claims they are running a version with the fix. But in practice, we still + // hit it. Including this dummy variable is a work-around. + // + // https://developercommunity.visualstudio.com/t/runtime-stack-corruption-using-stdvisit/346200 + int dummy = 0; +}; + +// clang-format off +// XXX(RLB): For some reason, different versions of clang-format disagree on how +// this should be formatted. Probably because it's new syntax with C++17? +// Exempting it from clang-format for now. +template overloaded(Ts...) -> overloaded; +// clang-format on + +/// +/// Auto-generate equality and inequality operators for TLS-serializable things +/// + +template +inline typename std::enable_if::type +operator==(const T& lhs, const T& rhs) +{ + return lhs._tls_fields_w() == rhs._tls_fields_w(); +} + +template +inline typename std::enable_if::type +operator!=(const T& lhs, const T& rhs) +{ + return lhs._tls_fields_w() != rhs._tls_fields_w(); +} + +/// +/// Error types +/// + +// The `using parent = X` / `using parent::parent` construction here +// imports the constructors of the parent. + +class NotImplementedError : public std::exception +{ +public: + using parent = std::exception; + using parent::parent; +}; + +class ProtocolError : public std::runtime_error +{ +public: + using parent = std::runtime_error; + using parent::parent; +}; + +class IncompatibleNodesError : public std::invalid_argument +{ +public: + using parent = std::invalid_argument; + using parent::parent; +}; + +class InvalidParameterError : public std::invalid_argument +{ +public: + using parent = std::invalid_argument; + using parent::parent; +}; + +class InvalidPathError : public std::invalid_argument +{ +public: + using parent = std::invalid_argument; + using parent::parent; +}; + +class InvalidIndexError : public std::invalid_argument +{ +public: + using parent = std::invalid_argument; + using parent::parent; +}; + +class InvalidMessageTypeError : public std::invalid_argument +{ +public: + using parent = std::invalid_argument; + using parent::parent; +}; + +class MissingNodeError : public std::out_of_range +{ +public: + using parent = std::out_of_range; + using parent::parent; +}; + +class MissingStateError : public std::out_of_range +{ +public: + using parent = std::out_of_range; + using parent::parent; +}; + +// A slightly more elegant way to silence -Werror=unused-variable +template +void +silence_unused(const T& val) +{ + (void)val; +} + +namespace stdx { + +// XXX(RLB) This method takes any container in, but always puts the resuls in +// std::vector. The output could be made generic with a Rust-like syntax, +// defining a PendingTransform object that caches the inputs, with a template +// `collect()` method that puts them in an output container. Which makes the +// calling syntax as follows: +// +// auto out = stdx::transform(in, f).collect(); +// +// (You always need the explicit specialization, even if assigning it to an +// explicitly typed variable, because C++ won't infer return types.) +// +// Given that the above syntax is pretty chatty, and we never need anything +// other than vectors here anyway, I have left this as-is. +template +std::vector +transform(const Container& c, const UnaryOperation& op) +{ + auto out = std::vector{}; + auto ins = std::inserter(out, out.begin()); + std::transform(c.begin(), c.end(), ins, op); + return out; +} + +template +bool +any_of(const Container& c, const UnaryPredicate& pred) +{ + return std::any_of(c.begin(), c.end(), pred); +} + +template +bool +all_of(const Container& c, const UnaryPredicate& pred) +{ + return std::all_of(c.begin(), c.end(), pred); +} + +template +auto +count_if(const Container& c, const UnaryPredicate& pred) +{ + return std::count_if(c.begin(), c.end(), pred); +} + +template +bool +contains(const Container& c, const Value& val) +{ + return std::find(c.begin(), c.end(), val) != c.end(); +} + +template +auto +find_if(Container& c, const UnaryPredicate& pred) +{ + return std::find_if(c.begin(), c.end(), pred); +} + +template +auto +find_if(const Container& c, const UnaryPredicate& pred) +{ + return std::find_if(c.begin(), c.end(), pred); +} + +template +auto +upper_bound(const Container& c, const Value& val) +{ + return std::upper_bound(c.begin(), c.end(), val); +} + +} // namespace stdx + +} // namespace mlspp diff --git a/mlspp/include/mls/core_types.h b/mlspp/include/mls/core_types.h new file mode 100755 index 0000000000..6c591c0623 --- /dev/null +++ b/mlspp/include/mls/core_types.h @@ -0,0 +1,380 @@ +#pragma once + +#include "mls/credential.h" +#include "mls/crypto.h" +#include "mls/tree_math.h" + +namespace mlspp { + +// enum { +// reserved(0), +// mls10(1), +// (255) +// } ProtocolVersion; +enum class ProtocolVersion : uint16_t +{ + mls10 = 0x01, +}; + +extern const std::array all_supported_versions; + +// struct { +// ExtensionType extension_type; +// opaque extension_data; +// } Extension; +struct Extension +{ + using Type = uint16_t; + + Type type; + bytes data; + + TLS_SERIALIZABLE(type, data) +}; + +struct ExtensionType +{ + static constexpr Extension::Type application_id = 1; + static constexpr Extension::Type ratchet_tree = 2; + static constexpr Extension::Type required_capabilities = 3; + static constexpr Extension::Type external_pub = 4; + static constexpr Extension::Type external_senders = 5; + + // XXX(RLB) There is no IANA-registered type for this extension yet, so we use + // a value from the vendor-specific space + static constexpr Extension::Type sframe_parameters = 0xff02; +}; + +struct ExtensionList +{ + std::vector extensions; + + // XXX(RLB) It would be good if this maintained extensions in order. It might + // be possible to do this automatically by changing the storage to a + // map and extending the TLS code to marshal that type. + template + inline void add(const T& obj) + { + auto data = tls::marshal(obj); + add(T::type, std::move(data)); + } + + void add(Extension::Type type, bytes data); + + template + std::optional find() const + { + for (const auto& ext : extensions) { + if (ext.type == T::type) { + return tls::get(ext.data); + } + } + + return std::nullopt; + } + + bool has(uint16_t type) const; + + TLS_SERIALIZABLE(extensions) +}; + +// enum { +// reserved(0), +// key_package(1), +// update(2), +// commit(3), +// (255) +// } LeafNodeSource; +enum struct LeafNodeSource : uint8_t +{ + key_package = 1, + update = 2, + commit = 3, +}; + +// struct { +// ProtocolVersion versions; +// CipherSuite ciphersuites; +// ExtensionType extensions; +// ProposalType proposals; +// CredentialType credentials; +// } Capabilities; +struct Capabilities +{ + std::vector versions; + std::vector cipher_suites; + std::vector extensions; + std::vector proposals; + std::vector credentials; + + static Capabilities create_default(); + bool extensions_supported(const std::vector& required) const; + bool proposals_supported(const std::vector& required) const; + bool credential_supported(const Credential& credential) const; + + template + bool credentials_supported(const Container& required) const + { + return stdx::all_of(required, [&](CredentialType type) { + return stdx::contains(credentials, type); + }); + } + + TLS_SERIALIZABLE(versions, cipher_suites, extensions, proposals, credentials) +}; + +// struct { +// uint64 not_before; +// uint64 not_after; +// } Lifetime; +struct Lifetime +{ + uint64_t not_before; + uint64_t not_after; + + static Lifetime create_default(); + + TLS_SERIALIZABLE(not_before, not_after) +}; + +// struct { +// HPKEPublicKey encryption_key; +// SignaturePublicKey signature_key; +// Credential credential; +// Capabilities capabilities; +// +// LeafNodeSource leaf_node_source; +// select (leaf_node_source) { +// case add: +// Lifetime lifetime; +// +// case update: +// struct {} +// +// case commit: +// opaque parent_hash; +// } +// +// Extension extensions; +// // SignWithLabel(., "LeafNodeTBS", LeafNodeTBS) +// opaque signature; +// } LeafNode; +struct Empty +{ + TLS_SERIALIZABLE() +}; + +struct ParentHash +{ + bytes parent_hash; + TLS_SERIALIZABLE(parent_hash); +}; + +struct LeafNodeOptions +{ + std::optional credential; + std::optional capabilities; + std::optional extensions; +}; + +// TODO Move this to treekem.h +struct LeafNode +{ + HPKEPublicKey encryption_key; + SignaturePublicKey signature_key; + Credential credential; + Capabilities capabilities; + + var::variant content; + + ExtensionList extensions; + bytes signature; + + LeafNode() = default; + LeafNode(const LeafNode&) = default; + LeafNode(LeafNode&&) = default; + LeafNode& operator=(const LeafNode&) = default; + LeafNode& operator=(LeafNode&&) = default; + + LeafNode(CipherSuite cipher_suite, + HPKEPublicKey encryption_key_in, + SignaturePublicKey signature_key_in, + Credential credential_in, + Capabilities capabilities_in, + Lifetime lifetime_in, + ExtensionList extensions_in, + const SignaturePrivateKey& sig_priv); + + LeafNode for_update(CipherSuite cipher_suite, + const bytes& group_id, + LeafIndex leaf_index, + HPKEPublicKey encryption_key, + const LeafNodeOptions& opts, + const SignaturePrivateKey& sig_priv_in) const; + + LeafNode for_commit(CipherSuite cipher_suite, + const bytes& group_id, + LeafIndex leaf_index, + HPKEPublicKey encryption_key, + const bytes& parent_hash, + const LeafNodeOptions& opts, + const SignaturePrivateKey& sig_priv_in) const; + + void set_capabilities(Capabilities capabilities_in); + + LeafNodeSource source() const; + + struct MemberBinding + { + bytes group_id; + LeafIndex leaf_index; + TLS_SERIALIZABLE(group_id, leaf_index); + }; + + void sign(CipherSuite cipher_suite, + const SignaturePrivateKey& sig_priv, + const std::optional& binding); + bool verify(CipherSuite cipher_suite, + const std::optional& binding) const; + + bool verify_expiry(uint64_t now) const; + bool verify_extension_support(const ExtensionList& ext_list) const; + + TLS_SERIALIZABLE(encryption_key, + signature_key, + credential, + capabilities, + content, + extensions, + signature) + TLS_TRAITS(tls::pass, + tls::pass, + tls::pass, + tls::pass, + tls::variant, + tls::pass, + tls::pass) + +private: + LeafNode clone_with_options(HPKEPublicKey encryption_key, + const LeafNodeOptions& opts) const; + bytes to_be_signed(const std::optional& binding) const; +}; + +// Concrete extension types +struct RequiredCapabilitiesExtension +{ + std::vector extensions; + std::vector proposals; + + static const Extension::Type type; + TLS_SERIALIZABLE(extensions, proposals) +}; + +struct ApplicationIDExtension +{ + bytes id; + + static const Extension::Type type; + TLS_SERIALIZABLE(id) +}; + +/// +/// NodeType, ParentNode, and KeyPackage +/// + +// TODO move this to treekem.h +struct ParentNode +{ + HPKEPublicKey public_key; + bytes parent_hash; + std::vector unmerged_leaves; + + bytes hash(CipherSuite suite) const; + + TLS_SERIALIZABLE(public_key, parent_hash, unmerged_leaves) +}; + +// TODO Move this to messages.h +// struct { +// ProtocolVersion version; +// CipherSuite cipher_suite; +// HPKEPublicKey init_key; +// LeafNode leaf_node; +// Extension extensions; +// // SignWithLabel(., "KeyPackageTBS", KeyPackageTBS) +// opaque signature; +// } KeyPackage; +struct KeyPackage +{ + ProtocolVersion version; + CipherSuite cipher_suite; + HPKEPublicKey init_key; + LeafNode leaf_node; + ExtensionList extensions; + bytes signature; + + KeyPackage(); + KeyPackage(CipherSuite suite_in, + HPKEPublicKey init_key_in, + LeafNode leaf_node_in, + ExtensionList extensions_in, + const SignaturePrivateKey& sig_priv_in); + + KeyPackageRef ref() const; + + void sign(const SignaturePrivateKey& sig_priv); + bool verify() const; + + TLS_SERIALIZABLE(version, + cipher_suite, + init_key, + leaf_node, + extensions, + signature) + +private: + bytes to_be_signed() const; +}; + +/// +/// UpdatePath +/// + +// struct { +// HPKEPublicKey public_key; +// HPKECiphertext encrypted_path_secret; +// } UpdatePathNode; +struct UpdatePathNode +{ + HPKEPublicKey public_key; + std::vector encrypted_path_secret; + + TLS_SERIALIZABLE(public_key, encrypted_path_secret) +}; + +// struct { +// LeafNode leaf_node; +// UpdatePathNode nodes; +// } UpdatePath; +struct UpdatePath +{ + LeafNode leaf_node; + std::vector nodes; + + TLS_SERIALIZABLE(leaf_node, nodes) +}; + +} // namespace mlspp + +namespace mlspp::tls { + +TLS_VARIANT_MAP(mlspp::LeafNodeSource, + mlspp::Lifetime, + key_package) +TLS_VARIANT_MAP(mlspp::LeafNodeSource, mlspp::Empty, update) +TLS_VARIANT_MAP(mlspp::LeafNodeSource, + mlspp::ParentHash, + commit) + +} // namespace mlspp::tls diff --git a/mlspp/include/mls/credential.h b/mlspp/include/mls/credential.h new file mode 100755 index 0000000000..bcf77a36cc --- /dev/null +++ b/mlspp/include/mls/credential.h @@ -0,0 +1,228 @@ +#pragma once + +#include +#include + +namespace mlspp { + +namespace hpke { +struct UserInfoVC; +} + +// struct { +// opaque identity<0..2^16-1>; +// SignaturePublicKey public_key; +// } BasicCredential; +struct BasicCredential +{ + BasicCredential() {} + + BasicCredential(bytes identity_in) + : identity(std::move(identity_in)) + { + } + + bytes identity; + + TLS_SERIALIZABLE(identity) +}; + +struct X509Credential +{ + struct CertData + { + bytes data; + + TLS_SERIALIZABLE(data) + }; + + X509Credential() = default; + explicit X509Credential(const std::vector& der_chain_in); + + SignatureScheme signature_scheme() const; + SignaturePublicKey public_key() const; + bool valid_for(const SignaturePublicKey& pub) const; + + // TODO(rlb) This should be const or exposed via a method + std::vector der_chain; + +private: + SignaturePublicKey _public_key; + SignatureScheme _signature_scheme; +}; + +tls::ostream& +operator<<(tls::ostream& str, const X509Credential& obj); + +tls::istream& +operator>>(tls::istream& str, X509Credential& obj); + +struct UserInfoVCCredential +{ + UserInfoVCCredential() = default; + explicit UserInfoVCCredential(std::string userinfo_vc_jwt_in); + + std::string userinfo_vc_jwt; + + bool valid_for(const SignaturePublicKey& pub) const; + bool valid_from(const PublicJWK& pub) const; + + friend tls::ostream operator<<(tls::ostream& str, + const UserInfoVCCredential& obj); + friend tls::istream operator>>(tls::istream& str, UserInfoVCCredential& obj); + friend bool operator==(const UserInfoVCCredential& lhs, + const UserInfoVCCredential& rhs); + friend bool operator!=(const UserInfoVCCredential& lhs, + const UserInfoVCCredential& rhs); + +private: + std::shared_ptr _vc; +}; + +bool +operator==(const X509Credential& lhs, const X509Credential& rhs); + +enum struct CredentialType : uint16_t +{ + reserved = 0, + basic = 1, + x509 = 2, + + userinfo_vc_draft_00 = 0xFE00, + multi_draft_00 = 0xFF00, + + // GREASE values, included here mainly so that debugger output looks nice + GREASE_0 = 0x0A0A, + GREASE_1 = 0x1A1A, + GREASE_2 = 0x2A2A, + GREASE_3 = 0x3A3A, + GREASE_4 = 0x4A4A, + GREASE_5 = 0x5A5A, + GREASE_6 = 0x6A6A, + GREASE_7 = 0x7A7A, + GREASE_8 = 0x8A8A, + GREASE_9 = 0x9A9A, + GREASE_A = 0xAAAA, + GREASE_B = 0xBABA, + GREASE_C = 0xCACA, + GREASE_D = 0xDADA, + GREASE_E = 0xEAEA, +}; + +// struct { +// Credential credential; +// SignaturePublicKey credential_key; +// opaque signature; +// } CredentialBinding +// +// struct { +// CredentialBinding bindings; +// } MultiCredential; +struct CredentialBinding; +struct CredentialBindingInput; + +struct MultiCredential +{ + MultiCredential() = default; + MultiCredential(const std::vector& binding_inputs, + const SignaturePublicKey& signature_key); + + std::vector bindings; + + bool valid_for(const SignaturePublicKey& pub) const; + + TLS_SERIALIZABLE(bindings) +}; + +// struct { +// CredentialType credential_type; +// select (credential_type) { +// case basic: +// BasicCredential; +// +// case x509: +// opaque cert_data<1..2^24-1>; +// }; +// } Credential; +struct Credential +{ + Credential() = default; + + CredentialType type() const; + + template + const T& get() const + { + return var::get(_cred); + } + + static Credential basic(const bytes& identity); + static Credential x509(const std::vector& der_chain); + static Credential userinfo_vc(const std::string& userinfo_vc_jwt); + static Credential multi( + const std::vector& binding_inputs, + const SignaturePublicKey& signature_key); + + bool valid_for(const SignaturePublicKey& pub) const; + + TLS_SERIALIZABLE(_cred) + TLS_TRAITS(tls::variant) + +private: + using SpecificCredential = var::variant; + + Credential(SpecificCredential specific); + SpecificCredential _cred; +}; + +// XXX(RLB): This struct needs to appear below Credential so that all types are +// concrete at the appropriate points. +struct CredentialBindingInput +{ + CipherSuite cipher_suite; + Credential credential; + const SignaturePrivateKey& credential_priv; +}; + +struct CredentialBinding +{ + CipherSuite cipher_suite; + Credential credential; + SignaturePublicKey credential_key; + bytes signature; + + CredentialBinding() = default; + CredentialBinding(CipherSuite suite_in, + Credential credential_in, + const SignaturePrivateKey& credential_priv, + const SignaturePublicKey& signature_key); + + bool valid_for(const SignaturePublicKey& signature_key) const; + + TLS_SERIALIZABLE(cipher_suite, credential, credential_key, signature) + +private: + bytes to_be_signed(const SignaturePublicKey& signature_key) const; +}; + +} // namespace mlspp + +namespace mlspp::tls { + +TLS_VARIANT_MAP(mlspp::CredentialType, + mlspp::BasicCredential, + basic) +TLS_VARIANT_MAP(mlspp::CredentialType, + mlspp::X509Credential, + x509) +TLS_VARIANT_MAP(mlspp::CredentialType, + mlspp::UserInfoVCCredential, + userinfo_vc_draft_00) +TLS_VARIANT_MAP(mlspp::CredentialType, + mlspp::MultiCredential, + multi_draft_00) + +} // namespace mlspp::tls diff --git a/mlspp/include/mls/crypto.h b/mlspp/include/mls/crypto.h new file mode 100755 index 0000000000..f9924de02d --- /dev/null +++ b/mlspp/include/mls/crypto.h @@ -0,0 +1,266 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace mlspp { + +/// Signature Code points, borrowed from RFC 8446 +enum struct SignatureScheme : uint16_t +{ + ecdsa_secp256r1_sha256 = 0x0403, + ecdsa_secp384r1_sha384 = 0x0805, + ecdsa_secp521r1_sha512 = 0x0603, + ed25519 = 0x0807, + ed448 = 0x0808, + rsa_pkcs1_sha256 = 0x0401, +}; + +SignatureScheme +tls_signature_scheme(hpke::Signature::ID id); + +/// Cipher suites + +struct KeyAndNonce +{ + bytes key; + bytes nonce; +}; + +// opaque HashReference; +// HashReference KeyPackageRef; +// HashReference ProposalRef; +using HashReference = bytes; +using KeyPackageRef = HashReference; +using ProposalRef = HashReference; + +struct CipherSuite +{ + enum struct ID : uint16_t + { + unknown = 0x0000, + X25519_AES128GCM_SHA256_Ed25519 = 0x0001, + P256_AES128GCM_SHA256_P256 = 0x0002, + X25519_CHACHA20POLY1305_SHA256_Ed25519 = 0x0003, + X448_AES256GCM_SHA512_Ed448 = 0x0004, + P521_AES256GCM_SHA512_P521 = 0x0005, + X448_CHACHA20POLY1305_SHA512_Ed448 = 0x0006, + P384_AES256GCM_SHA384_P384 = 0x0007, + + // GREASE values, included here mainly so that debugger output looks nice + GREASE_0 = 0x0A0A, + GREASE_1 = 0x1A1A, + GREASE_2 = 0x2A2A, + GREASE_3 = 0x3A3A, + GREASE_4 = 0x4A4A, + GREASE_5 = 0x5A5A, + GREASE_6 = 0x6A6A, + GREASE_7 = 0x7A7A, + GREASE_8 = 0x8A8A, + GREASE_9 = 0x9A9A, + GREASE_A = 0xAAAA, + GREASE_B = 0xBABA, + GREASE_C = 0xCACA, + GREASE_D = 0xDADA, + GREASE_E = 0xEAEA, + }; + + CipherSuite(); + CipherSuite(ID id_in); + + ID cipher_suite() const { return id; } + SignatureScheme signature_scheme() const; + + size_t secret_size() const { return get().digest.hash_size; } + size_t key_size() const { return get().hpke.aead.key_size; } + size_t nonce_size() const { return get().hpke.aead.nonce_size; } + + bytes zero() const { return bytes(secret_size(), 0); } + const hpke::HPKE& hpke() const { return get().hpke; } + const hpke::Digest& digest() const { return get().digest; } + const hpke::Signature& sig() const { return get().sig; } + + bytes expand_with_label(const bytes& secret, + const std::string& label, + const bytes& context, + size_t length) const; + bytes derive_secret(const bytes& secret, const std::string& label) const; + bytes derive_tree_secret(const bytes& secret, + const std::string& label, + uint32_t generation, + size_t length) const; + + template + bytes ref(const T& value) const + { + return raw_ref(reference_label(), tls::marshal(value)); + } + + bytes raw_ref(const bytes& label, const bytes& value) const + { + // RefHash(label, value) = Hash(RefHashInput) + // + // struct { + // opaque label; + // opaque value; + // } RefHashInput; + auto w = tls::ostream(); + w << label << value; + return digest().hash(w.bytes()); + } + + TLS_SERIALIZABLE(id) + +private: + ID id; + + struct Ciphers + { + hpke::HPKE hpke; + const hpke::Digest& digest; + const hpke::Signature& sig; + }; + + const Ciphers& get() const; + + template + static const bytes& reference_label(); +}; + +#if WITH_BORINGSSL +extern const std::array all_supported_suites; +#else +extern const std::array all_supported_suites; +#endif + +// Utilities +using mlspp::hpke::random_bytes; + +// HPKE Keys +namespace encrypt_label { +extern const std::string update_path_node; +extern const std::string welcome; +} // namespace encrypt_label + +struct HPKECiphertext +{ + bytes kem_output; + bytes ciphertext; + + TLS_SERIALIZABLE(kem_output, ciphertext) +}; + +struct HPKEPublicKey +{ + bytes data; + + HPKECiphertext encrypt(CipherSuite suite, + const std::string& label, + const bytes& context, + const bytes& pt) const; + + std::tuple do_export(CipherSuite suite, + const bytes& info, + const std::string& label, + size_t size) const; + + TLS_SERIALIZABLE(data) +}; + +struct HPKEPrivateKey +{ + static HPKEPrivateKey generate(CipherSuite suite); + static HPKEPrivateKey parse(CipherSuite suite, const bytes& data); + static HPKEPrivateKey derive(CipherSuite suite, const bytes& secret); + + HPKEPrivateKey() = default; + + bytes data; + HPKEPublicKey public_key; + + bytes decrypt(CipherSuite suite, + const std::string& label, + const bytes& context, + const HPKECiphertext& ct) const; + + bytes do_export(CipherSuite suite, + const bytes& info, + const bytes& kem_output, + const std::string& label, + size_t size) const; + + void set_public_key(CipherSuite suite); + + TLS_SERIALIZABLE(data) + +private: + HPKEPrivateKey(bytes priv_data, bytes pub_data); +}; + +// Signature Keys +namespace sign_label { +extern const std::string mls_content; +extern const std::string leaf_node; +extern const std::string key_package; +extern const std::string group_info; +extern const std::string multi_credential; +} // namespace sign_label + +struct SignaturePublicKey +{ + static SignaturePublicKey from_jwk(CipherSuite suite, + const std::string& json_str); + + bytes data; + + bool verify(const CipherSuite& suite, + const std::string& label, + const bytes& message, + const bytes& signature) const; + + std::string to_jwk(CipherSuite suite) const; + + TLS_SERIALIZABLE(data) +}; + +struct PublicJWK +{ + SignatureScheme signature_scheme; + std::optional key_id; + SignaturePublicKey public_key; + + static PublicJWK parse(const std::string& jwk_json); +}; + +struct SignaturePrivateKey +{ + static SignaturePrivateKey generate(CipherSuite suite); + static SignaturePrivateKey parse(CipherSuite suite, const bytes& data); + static SignaturePrivateKey derive(CipherSuite suite, const bytes& secret); + static SignaturePrivateKey from_jwk(CipherSuite suite, + const std::string& json_str); + + SignaturePrivateKey() = default; + + bytes data; + SignaturePublicKey public_key; + + bytes sign(const CipherSuite& suite, + const std::string& label, + const bytes& message) const; + + void set_public_key(CipherSuite suite); + std::string to_jwk(CipherSuite suite) const; + + TLS_SERIALIZABLE(data) + +private: + SignaturePrivateKey(bytes priv_data, bytes pub_data); +}; + +} // namespace mlspp diff --git a/mlspp/include/mls/key_schedule.h b/mlspp/include/mls/key_schedule.h new file mode 100755 index 0000000000..b30dd766e4 --- /dev/null +++ b/mlspp/include/mls/key_schedule.h @@ -0,0 +1,205 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace mlspp { + +struct HashRatchet +{ + CipherSuite suite; + bytes next_secret; + uint32_t next_generation; + std::map cache; + + size_t key_size; + size_t nonce_size; + size_t secret_size; + + // These defaults are necessary for use with containers + HashRatchet() = default; + HashRatchet(const HashRatchet& other) = default; + HashRatchet(HashRatchet&& other) = default; + HashRatchet& operator=(const HashRatchet& other) = default; + HashRatchet& operator=(HashRatchet&& other) = default; + + HashRatchet(CipherSuite suite_in, bytes base_secret_in); + + std::tuple next(); + KeyAndNonce get(uint32_t generation); + void erase(uint32_t generation); +}; + +struct SecretTree +{ + SecretTree() = default; + SecretTree(CipherSuite suite_in, + LeafCount group_size_in, + bytes encryption_secret_in); + + bool has_leaf(LeafIndex sender) { return sender < group_size; } + + bytes get(LeafIndex sender); + +private: + CipherSuite suite; + LeafCount group_size; + NodeIndex root; + std::map secrets; + size_t secret_size; +}; + +using ReuseGuard = std::array; + +struct GroupKeySource +{ + enum struct RatchetType + { + handshake, + application, + }; + + GroupKeySource() = default; + GroupKeySource(CipherSuite suite_in, + LeafCount group_size, + bytes encryption_secret); + + bool has_leaf(LeafIndex sender) { return secret_tree.has_leaf(sender); } + + std::tuple next(ContentType content_type, + LeafIndex sender); + KeyAndNonce get(ContentType content_type, + LeafIndex sender, + uint32_t generation, + ReuseGuard reuse_guard); + void erase(ContentType type, LeafIndex sender, uint32_t generation); + +private: + CipherSuite suite; + SecretTree secret_tree; + + using Key = std::tuple; + std::map chains; + + HashRatchet& chain(RatchetType type, LeafIndex sender); + HashRatchet& chain(ContentType type, LeafIndex sender); + + static const std::array all_ratchet_types; +}; + +struct KeyScheduleEpoch +{ +private: + CipherSuite suite; + +public: + bytes joiner_secret; + bytes epoch_secret; + + bytes sender_data_secret; + bytes encryption_secret; + bytes exporter_secret; + bytes epoch_authenticator; + bytes external_secret; + bytes confirmation_key; + bytes membership_key; + bytes resumption_psk; + bytes init_secret; + + HPKEPrivateKey external_priv; + + KeyScheduleEpoch() = default; + + // Full initializer, used by invited joiner + static KeyScheduleEpoch joiner(CipherSuite suite_in, + const bytes& joiner_secret, + const std::vector& psks, + const bytes& context); + + // Ciphersuite-only initializer, used by external joiner + KeyScheduleEpoch(CipherSuite suite_in); + + // Initial epoch + KeyScheduleEpoch(CipherSuite suite_in, + const bytes& init_secret, + const bytes& context); + + static std::tuple external_init( + CipherSuite suite, + const HPKEPublicKey& external_pub); + bytes receive_external_init(const bytes& kem_output) const; + + KeyScheduleEpoch next(const bytes& commit_secret, + const std::vector& psks, + const std::optional& force_init_secret, + const bytes& context) const; + + GroupKeySource encryption_keys(LeafCount size) const; + bytes confirmation_tag(const bytes& confirmed_transcript_hash) const; + bytes do_export(const std::string& label, + const bytes& context, + size_t size) const; + PSKWithSecret resumption_psk_w_secret(ResumptionPSKUsage usage, + const bytes& group_id, + epoch_t epoch); + + static bytes make_psk_secret(CipherSuite suite, + const std::vector& psks); + static bytes welcome_secret(CipherSuite suite, + const bytes& joiner_secret, + const std::vector& psks); + static KeyAndNonce sender_data_keys(CipherSuite suite, + const bytes& sender_data_secret, + const bytes& ciphertext); + + // TODO(RLB) make these methods private, but accessible to test vectors + KeyScheduleEpoch(CipherSuite suite_in, + const bytes& init_secret, + const bytes& commit_secret, + const bytes& psk_secret, + const bytes& context); + KeyScheduleEpoch next_raw(const bytes& commit_secret, + const bytes& psk_secret, + const std::optional& force_init_secret, + const bytes& context) const; + static bytes welcome_secret_raw(CipherSuite suite, + const bytes& joiner_secret, + const bytes& psk_secret); + +private: + KeyScheduleEpoch(CipherSuite suite_in, + const bytes& joiner_secret, + const bytes& psk_secret, + const bytes& context); +}; + +bool +operator==(const KeyScheduleEpoch& lhs, const KeyScheduleEpoch& rhs); + +struct TranscriptHash +{ + CipherSuite suite; + bytes confirmed; + bytes interim; + + // For a new group + TranscriptHash(CipherSuite suite_in); + + // For joining a group + TranscriptHash(CipherSuite suite_in, + bytes confirmed_in, + const bytes& confirmation_tag); + + void update(const AuthenticatedContent& content_auth); + void update_confirmed(const AuthenticatedContent& content_auth); + void update_interim(const bytes& confirmation_tag); + void update_interim(const AuthenticatedContent& content_auth); +}; + +bool +operator==(const TranscriptHash& lhs, const TranscriptHash& rhs); + +} // namespace mlspp diff --git a/mlspp/include/mls/messages.h b/mlspp/include/mls/messages.h new file mode 100755 index 0000000000..583ccc2733 --- /dev/null +++ b/mlspp/include/mls/messages.h @@ -0,0 +1,752 @@ +#pragma once + +#include "mls/common.h" +#include "mls/core_types.h" +#include "mls/credential.h" +#include "mls/crypto.h" +#include "mls/treekem.h" +#include +#include + +namespace mlspp { + +struct ExternalPubExtension +{ + HPKEPublicKey external_pub; + + static const uint16_t type; + TLS_SERIALIZABLE(external_pub) +}; + +struct RatchetTreeExtension +{ + TreeKEMPublicKey tree; + + static const uint16_t type; + TLS_SERIALIZABLE(tree) +}; + +struct ExternalSender +{ + SignaturePublicKey signature_key; + Credential credential; + + TLS_SERIALIZABLE(signature_key, credential); +}; + +struct ExternalSendersExtension +{ + std::vector senders; + + static const uint16_t type; + TLS_SERIALIZABLE(senders); +}; + +struct SFrameParameters +{ + uint16_t cipher_suite; + uint8_t epoch_bits; + + static const uint16_t type; + TLS_SERIALIZABLE(cipher_suite, epoch_bits) +}; + +struct SFrameCapabilities +{ + std::vector cipher_suites; + + bool compatible(const SFrameParameters& params) const; + + static const uint16_t type; + TLS_SERIALIZABLE(cipher_suites) +}; + +/// +/// PSKs +/// +enum struct PSKType : uint8_t +{ + reserved = 0, + external = 1, + resumption = 2, +}; + +struct ExternalPSK +{ + bytes psk_id; + TLS_SERIALIZABLE(psk_id) +}; + +enum struct ResumptionPSKUsage : uint8_t +{ + reserved = 0, + application = 1, + reinit = 2, + branch = 3, +}; + +struct ResumptionPSK +{ + ResumptionPSKUsage usage; + bytes psk_group_id; + epoch_t psk_epoch; + TLS_SERIALIZABLE(usage, psk_group_id, psk_epoch) +}; + +struct PreSharedKeyID +{ + var::variant content; + bytes psk_nonce; + TLS_SERIALIZABLE(content, psk_nonce) + TLS_TRAITS(tls::variant, tls::pass) +}; + +struct PreSharedKeys +{ + std::vector psks; + TLS_SERIALIZABLE(psks) +}; + +struct PSKWithSecret +{ + PreSharedKeyID id; + bytes secret; +}; + +// struct { +// ProtocolVersion version = mls10; +// CipherSuite cipher_suite; +// opaque group_id; +// uint64 epoch; +// opaque tree_hash; +// opaque confirmed_transcript_hash; +// Extension extensions; +// } GroupContext; +struct GroupContext +{ + ProtocolVersion version{ ProtocolVersion::mls10 }; + CipherSuite cipher_suite; + bytes group_id; + epoch_t epoch; + bytes tree_hash; + bytes confirmed_transcript_hash; + ExtensionList extensions; + + GroupContext() = default; + GroupContext(CipherSuite cipher_suite_in, + bytes group_id_in, + epoch_t epoch_in, + bytes tree_hash_in, + bytes confirmed_transcript_hash_in, + ExtensionList extensions_in); + + TLS_SERIALIZABLE(version, + cipher_suite, + group_id, + epoch, + tree_hash, + confirmed_transcript_hash, + extensions) +}; + +// struct { +// GroupContext group_context; +// Extension extensions; +// MAC confirmation_tag; +// uint32 signer; +// // SignWithLabel(., "GroupInfoTBS", GroupInfoTBS) +// opaque signature; +// } GroupInfo; +struct GroupInfo +{ + GroupContext group_context; + ExtensionList extensions; + bytes confirmation_tag; + LeafIndex signer; + bytes signature; + + GroupInfo() = default; + GroupInfo(GroupContext group_context_in, + ExtensionList extensions_in, + bytes confirmation_tag_in); + + bytes to_be_signed() const; + void sign(const TreeKEMPublicKey& tree, + LeafIndex signer_index, + const SignaturePrivateKey& priv); + bool verify(const TreeKEMPublicKey& tree) const; + + // These methods exist only to simplify unit testing + void sign(LeafIndex signer_index, const SignaturePrivateKey& priv); + bool verify(const SignaturePublicKey& pub) const; + + TLS_SERIALIZABLE(group_context, + extensions, + confirmation_tag, + signer, + signature) +}; + +// struct { +// opaque joiner_secret<1..255>; +// optional path_secret; +// PreSharedKeys psks; +// } GroupSecrets; +struct GroupSecrets +{ + struct PathSecret + { + bytes secret; + + TLS_SERIALIZABLE(secret) + }; + + bytes joiner_secret; + std::optional path_secret; + PreSharedKeys psks; + + TLS_SERIALIZABLE(joiner_secret, path_secret, psks) +}; + +// struct { +// opaque key_package_hash<1..255>; +// HPKECiphertext encrypted_group_secrets; +// } EncryptedGroupSecrets; +struct EncryptedGroupSecrets +{ + KeyPackageRef new_member; + HPKECiphertext encrypted_group_secrets; + + TLS_SERIALIZABLE(new_member, encrypted_group_secrets) +}; + +// struct { +// ProtocolVersion version = mls10; +// CipherSuite cipher_suite; +// EncryptedGroupSecrets group_secretss<1..2^32-1>; +// opaque encrypted_group_info<1..2^32-1>; +// } Welcome; +struct Welcome +{ + CipherSuite cipher_suite; + std::vector secrets; + bytes encrypted_group_info; + + Welcome(); + Welcome(CipherSuite suite, + const bytes& joiner_secret, + const std::vector& psks, + const GroupInfo& group_info); + + void encrypt(const KeyPackage& kp, const std::optional& path_secret); + std::optional find(const KeyPackage& kp) const; + GroupSecrets decrypt_secrets(int kp_index, + const HPKEPrivateKey& init_priv) const; + GroupInfo decrypt(const bytes& joiner_secret, + const std::vector& psks) const; + + TLS_SERIALIZABLE(cipher_suite, secrets, encrypted_group_info) + +private: + bytes _joiner_secret; + PreSharedKeys _psks; + static KeyAndNonce group_info_key_nonce( + CipherSuite suite, + const bytes& joiner_secret, + const std::vector& psks); +}; + +/// +/// Proposals & Commit +/// + +// Add +struct Add +{ + KeyPackage key_package; + TLS_SERIALIZABLE(key_package) +}; + +// Update +struct Update +{ + LeafNode leaf_node; + TLS_SERIALIZABLE(leaf_node) +}; + +// Remove +struct Remove +{ + LeafIndex removed; + TLS_SERIALIZABLE(removed) +}; + +// PreSharedKey +struct PreSharedKey +{ + PreSharedKeyID psk; + TLS_SERIALIZABLE(psk) +}; + +// ReInit +struct ReInit +{ + bytes group_id; + ProtocolVersion version; + CipherSuite cipher_suite; + ExtensionList extensions; + + TLS_SERIALIZABLE(group_id, version, cipher_suite, extensions) +}; + +// ExternalInit +struct ExternalInit +{ + bytes kem_output; + TLS_SERIALIZABLE(kem_output) +}; + +// GroupContextExtensions +struct GroupContextExtensions +{ + ExtensionList group_context_extensions; + TLS_SERIALIZABLE(group_context_extensions) +}; + +struct ProposalType; + +struct Proposal +{ + using Type = uint16_t; + + var::variant + content; + + Type proposal_type() const; + + TLS_SERIALIZABLE(content) + TLS_TRAITS(tls::variant) +}; + +struct ProposalType +{ + static constexpr Proposal::Type invalid = 0; + static constexpr Proposal::Type add = 1; + static constexpr Proposal::Type update = 2; + static constexpr Proposal::Type remove = 3; + static constexpr Proposal::Type psk = 4; + static constexpr Proposal::Type reinit = 5; + static constexpr Proposal::Type external_init = 6; + static constexpr Proposal::Type group_context_extensions = 7; + + constexpr ProposalType() + : val(invalid) + { + } + + constexpr ProposalType(Proposal::Type pt) + : val(pt) + { + } + + Proposal::Type val; + TLS_SERIALIZABLE(val) +}; + +enum struct ProposalOrRefType : uint8_t +{ + reserved = 0, + value = 1, + reference = 2, +}; + +struct ProposalOrRef +{ + var::variant content; + + TLS_SERIALIZABLE(content) + TLS_TRAITS(tls::variant) +}; + +// struct { +// ProposalOrRef proposals<0..2^32-1>; +// optional path; +// } Commit; +struct Commit +{ + std::vector proposals; + std::optional path; + + // Validate that the commit is acceptable as an external commit, and if so, + // produce the public key from the ExternalInit proposal + std::optional valid_external() const; + + TLS_SERIALIZABLE(proposals, path) +}; + +// struct { +// opaque group_id<0..255>; +// uint32 epoch; +// uint32 sender; +// ContentType content_type; +// +// select (PublicMessage.content_type) { +// case handshake: +// GroupOperation operation; +// opaque confirmation<0..255>; +// +// case application: +// opaque application_data<0..2^32-1>; +// } +// +// opaque signature<0..2^16-1>; +// } PublicMessage; +struct ApplicationData +{ + bytes data; + TLS_SERIALIZABLE(data) +}; + +struct GroupContext; + +enum struct WireFormat : uint16_t +{ + reserved = 0, + mls_public_message = 1, + mls_private_message = 2, + mls_welcome = 3, + mls_group_info = 4, + mls_key_package = 5, +}; + +enum struct ContentType : uint8_t +{ + invalid = 0, + application = 1, + proposal = 2, + commit = 3, +}; + +enum struct SenderType : uint8_t +{ + invalid = 0, + member = 1, + external = 2, + new_member_proposal = 3, + new_member_commit = 4, +}; + +struct MemberSender +{ + LeafIndex sender; + TLS_SERIALIZABLE(sender); +}; + +struct ExternalSenderIndex +{ + uint32_t sender_index; + TLS_SERIALIZABLE(sender_index) +}; + +struct NewMemberProposalSender +{ + TLS_SERIALIZABLE() +}; + +struct NewMemberCommitSender +{ + TLS_SERIALIZABLE() +}; + +struct Sender +{ + var::variant + sender; + + SenderType sender_type() const; + + TLS_SERIALIZABLE(sender) + TLS_TRAITS(tls::variant) +}; + +/// +/// MLSMessage and friends +/// +struct GroupKeySource; + +struct GroupContent +{ + using RawContent = var::variant; + + bytes group_id; + epoch_t epoch; + Sender sender; + bytes authenticated_data; + RawContent content; + + GroupContent() = default; + GroupContent(bytes group_id_in, + epoch_t epoch_in, + Sender sender_in, + bytes authenticated_data_in, + RawContent content_in); + GroupContent(bytes group_id_in, + epoch_t epoch_in, + Sender sender_in, + bytes authenticated_data_in, + ContentType content_type); + + ContentType content_type() const; + + TLS_SERIALIZABLE(group_id, epoch, sender, authenticated_data, content) + TLS_TRAITS(tls::pass, + tls::pass, + tls::pass, + tls::pass, + tls::variant) +}; + +struct GroupContentAuthData +{ + ContentType content_type = ContentType::invalid; + bytes signature; + std::optional confirmation_tag; + + friend tls::ostream& operator<<(tls::ostream& str, + const GroupContentAuthData& obj); + friend tls::istream& operator>>(tls::istream& str, GroupContentAuthData& obj); + friend bool operator==(const GroupContentAuthData& lhs, + const GroupContentAuthData& rhs); +}; + +struct AuthenticatedContent +{ + WireFormat wire_format; + GroupContent content; + GroupContentAuthData auth; + + AuthenticatedContent() = default; + + static AuthenticatedContent sign(WireFormat wire_format, + GroupContent content, + CipherSuite suite, + const SignaturePrivateKey& sig_priv, + const std::optional& context); + bool verify(CipherSuite suite, + const SignaturePublicKey& sig_pub, + const std::optional& context) const; + + bytes confirmed_transcript_hash_input() const; + bytes interim_transcript_hash_input() const; + + void set_confirmation_tag(const bytes& confirmation_tag); + bool check_confirmation_tag(const bytes& confirmation_tag) const; + + friend tls::ostream& operator<<(tls::ostream& str, + const AuthenticatedContent& obj); + friend tls::istream& operator>>(tls::istream& str, AuthenticatedContent& obj); + friend bool operator==(const AuthenticatedContent& lhs, + const AuthenticatedContent& rhs); + +private: + AuthenticatedContent(WireFormat wire_format_in, GroupContent content_in); + AuthenticatedContent(WireFormat wire_format_in, + GroupContent content_in, + GroupContentAuthData auth_in); + + bytes to_be_signed(const std::optional& context) const; + + friend struct PublicMessage; + friend struct PrivateMessage; +}; + +struct ValidatedContent +{ + const AuthenticatedContent& authenticated_content() const; + + friend bool operator==(const ValidatedContent& lhs, + const ValidatedContent& rhs); + +private: + AuthenticatedContent content_auth; + + ValidatedContent(AuthenticatedContent content_auth_in); + + friend struct PublicMessage; + friend struct PrivateMessage; + friend class State; +}; + +struct PublicMessage +{ + PublicMessage() = default; + + bytes get_group_id() const { return content.group_id; } + epoch_t get_epoch() const { return content.epoch; } + + static PublicMessage protect(AuthenticatedContent content_auth, + CipherSuite suite, + const std::optional& membership_key, + const std::optional& context); + std::optional unprotect( + CipherSuite suite, + const std::optional& membership_key, + const std::optional& context) const; + + bool contains(const AuthenticatedContent& content_auth) const; + + // TODO(RLB) Make this private and expose only to tests + AuthenticatedContent authenticated_content() const; + + friend tls::ostream& operator<<(tls::ostream& str, const PublicMessage& obj); + friend tls::istream& operator>>(tls::istream& str, PublicMessage& obj); + friend bool operator==(const PublicMessage& lhs, const PublicMessage& rhs); + friend bool operator!=(const PublicMessage& lhs, const PublicMessage& rhs); + +private: + GroupContent content; + GroupContentAuthData auth; + std::optional membership_tag; + + PublicMessage(AuthenticatedContent content_auth); + + bytes membership_mac(CipherSuite suite, + const bytes& membership_key, + const std::optional& context) const; +}; + +struct PrivateMessage +{ + PrivateMessage() = default; + + bytes get_group_id() const { return group_id; } + epoch_t get_epoch() const { return epoch; } + + static PrivateMessage protect(AuthenticatedContent content_auth, + CipherSuite suite, + GroupKeySource& keys, + const bytes& sender_data_secret, + size_t padding_size); + std::optional unprotect( + CipherSuite suite, + GroupKeySource& keys, + const bytes& sender_data_secret) const; + + TLS_SERIALIZABLE(group_id, + epoch, + content_type, + authenticated_data, + encrypted_sender_data, + ciphertext) + +private: + bytes group_id; + epoch_t epoch; + ContentType content_type; + bytes authenticated_data; + bytes encrypted_sender_data; + bytes ciphertext; + + PrivateMessage(GroupContent content, + bytes encrypted_sender_data_in, + bytes ciphertext_in); +}; + +struct MLSMessage +{ + ProtocolVersion version = ProtocolVersion::mls10; + var::variant + message; + + bytes group_id() const; + epoch_t epoch() const; + WireFormat wire_format() const; + + MLSMessage() = default; + MLSMessage(PublicMessage public_message); + MLSMessage(PrivateMessage private_message); + MLSMessage(Welcome welcome); + MLSMessage(GroupInfo group_info); + MLSMessage(KeyPackage key_package); + + TLS_SERIALIZABLE(version, message) + TLS_TRAITS(tls::pass, tls::variant) +}; + +MLSMessage +external_proposal(CipherSuite suite, + const bytes& group_id, + epoch_t epoch, + const Proposal& proposal, + uint32_t signer_index, + const SignaturePrivateKey& sig_priv); + +} // namespace mlspp + +namespace mlspp::tls { + +TLS_VARIANT_MAP(mlspp::PSKType, mlspp::ExternalPSK, external) +TLS_VARIANT_MAP(mlspp::PSKType, + mlspp::ResumptionPSK, + resumption) + +TLS_VARIANT_MAP(mlspp::ProposalOrRefType, + mlspp::Proposal, + value) +TLS_VARIANT_MAP(mlspp::ProposalOrRefType, + mlspp::ProposalRef, + reference) + +TLS_VARIANT_MAP(mlspp::ProposalType, mlspp::Add, add) +TLS_VARIANT_MAP(mlspp::ProposalType, mlspp::Update, update) +TLS_VARIANT_MAP(mlspp::ProposalType, mlspp::Remove, remove) +TLS_VARIANT_MAP(mlspp::ProposalType, mlspp::PreSharedKey, psk) +TLS_VARIANT_MAP(mlspp::ProposalType, mlspp::ReInit, reinit) +TLS_VARIANT_MAP(mlspp::ProposalType, + mlspp::ExternalInit, + external_init) +TLS_VARIANT_MAP(mlspp::ProposalType, + mlspp::GroupContextExtensions, + group_context_extensions) + +TLS_VARIANT_MAP(mlspp::ContentType, + mlspp::ApplicationData, + application) +TLS_VARIANT_MAP(mlspp::ContentType, mlspp::Proposal, proposal) +TLS_VARIANT_MAP(mlspp::ContentType, mlspp::Commit, commit) + +TLS_VARIANT_MAP(mlspp::SenderType, mlspp::MemberSender, member) +TLS_VARIANT_MAP(mlspp::SenderType, + mlspp::ExternalSenderIndex, + external) +TLS_VARIANT_MAP(mlspp::SenderType, + mlspp::NewMemberProposalSender, + new_member_proposal) +TLS_VARIANT_MAP(mlspp::SenderType, + mlspp::NewMemberCommitSender, + new_member_commit) + +TLS_VARIANT_MAP(mlspp::WireFormat, + mlspp::PublicMessage, + mls_public_message) +TLS_VARIANT_MAP(mlspp::WireFormat, + mlspp::PrivateMessage, + mls_private_message) +TLS_VARIANT_MAP(mlspp::WireFormat, mlspp::Welcome, mls_welcome) +TLS_VARIANT_MAP(mlspp::WireFormat, + mlspp::GroupInfo, + mls_group_info) +TLS_VARIANT_MAP(mlspp::WireFormat, + mlspp::KeyPackage, + mls_key_package) + +} // namespace mlspp::tls diff --git a/mlspp/include/mls/session.h b/mlspp/include/mls/session.h new file mode 100755 index 0000000000..045fbfdb7e --- /dev/null +++ b/mlspp/include/mls/session.h @@ -0,0 +1,98 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace mlspp { + +class PendingJoin; +class Session; + +class Client +{ +public: + Client(CipherSuite suite_in, + SignaturePrivateKey sig_priv_in, + Credential cred_in); + + Session begin_session(const bytes& group_id) const; + + PendingJoin start_join() const; + +private: + const CipherSuite suite; + const SignaturePrivateKey sig_priv; + const Credential cred; +}; + +class PendingJoin +{ +public: + PendingJoin(PendingJoin&& other) noexcept; + PendingJoin& operator=(PendingJoin&& other) noexcept; + ~PendingJoin(); + bytes key_package() const; + Session complete(const bytes& welcome) const; + +private: + struct Inner; + std::unique_ptr inner; + + PendingJoin(Inner* inner); + friend class Client; +}; + +class Session +{ +public: + Session(Session&& other) noexcept; + Session& operator=(Session&& other) noexcept; + ~Session(); + + // Settings + void encrypt_handshake(bool enabled); + + // Message producers + bytes add(const bytes& key_package_data); + bytes update(); + bytes remove(uint32_t index); + std::tuple commit(const bytes& proposal); + std::tuple commit(const std::vector& proposals); + std::tuple commit(); + + // Message consumers + bool handle(const bytes& handshake_data); + + // Information about the current state + epoch_t epoch() const; + LeafIndex index() const; + CipherSuite cipher_suite() const; + const ExtensionList& extensions() const; + const TreeKEMPublicKey& tree() const; + bytes do_export(const std::string& label, + const bytes& context, + size_t size) const; + GroupInfo group_info() const; + std::vector roster() const; + bytes epoch_authenticator() const; + + // Application message protection + bytes protect(const bytes& plaintext); + bytes unprotect(const bytes& ciphertext); + +protected: + struct Inner; + std::unique_ptr inner; + + Session(Inner* inner); + friend class Client; + friend class PendingJoin; + + friend bool operator==(const Session& lhs, const Session& rhs); + friend bool operator!=(const Session& lhs, const Session& rhs); +}; + +} // namespace mlspp diff --git a/mlspp/include/mls/state.h b/mlspp/include/mls/state.h new file mode 100755 index 0000000000..84e4926263 --- /dev/null +++ b/mlspp/include/mls/state.h @@ -0,0 +1,431 @@ +#pragma once + +#include "mls/crypto.h" +#include "mls/key_schedule.h" +#include "mls/messages.h" +#include "mls/treekem.h" +#include +#include +#include + +namespace mlspp { + +// Index into the session roster +struct RosterIndex : public UInt32 +{ + using UInt32::UInt32; +}; + +struct CommitOpts +{ + std::vector extra_proposals; + bool inline_tree; + bool force_path; + LeafNodeOptions leaf_node_opts; +}; + +struct MessageOpts +{ + bool encrypt = false; + bytes authenticated_data; + size_t padding_size = 0; +}; + +class State +{ +public: + /// + /// Constructors + /// + + // Initialize an empty group + State(bytes group_id, + CipherSuite suite, + HPKEPrivateKey enc_priv, + SignaturePrivateKey sig_priv, + const LeafNode& leaf_node, + ExtensionList extensions); + + // Initialize a group from a Welcome + State(const HPKEPrivateKey& init_priv, + HPKEPrivateKey leaf_priv, + SignaturePrivateKey sig_priv, + const KeyPackage& key_package, + const Welcome& welcome, + const std::optional& tree, + std::map psks); + + // Join a group from outside + // XXX(RLB) To be fully general, we would need a few more options here, e.g., + // whether to include PSKs or evict our prior appearance. + static std::tuple external_join( + const bytes& leaf_secret, + SignaturePrivateKey sig_priv, + const KeyPackage& key_package, + const GroupInfo& group_info, + const std::optional& tree, + const MessageOpts& msg_opts, + std::optional remove_prior, + const std::map& psks); + + // Propose that a new member be added a group + static MLSMessage new_member_add(const bytes& group_id, + epoch_t epoch, + const KeyPackage& new_member, + const SignaturePrivateKey& sig_priv); + + /// + /// Message factories + /// + + Proposal add_proposal(const KeyPackage& key_package) const; + Proposal update_proposal(HPKEPrivateKey leaf_priv, + const LeafNodeOptions& opts); + Proposal remove_proposal(RosterIndex index) const; + Proposal remove_proposal(LeafIndex removed) const; + Proposal group_context_extensions_proposal(ExtensionList exts) const; + Proposal pre_shared_key_proposal(const bytes& external_psk_id) const; + Proposal pre_shared_key_proposal(const bytes& group_id, epoch_t epoch) const; + static Proposal reinit_proposal(bytes group_id, + ProtocolVersion version, + CipherSuite cipher_suite, + ExtensionList extensions); + + MLSMessage add(const KeyPackage& key_package, const MessageOpts& msg_opts); + MLSMessage update(HPKEPrivateKey leaf_priv, + const LeafNodeOptions& opts, + const MessageOpts& msg_opts); + MLSMessage remove(RosterIndex index, const MessageOpts& msg_opts); + MLSMessage remove(LeafIndex removed, const MessageOpts& msg_opts); + MLSMessage group_context_extensions(ExtensionList exts, + const MessageOpts& msg_opts); + MLSMessage pre_shared_key(const bytes& external_psk_id, + const MessageOpts& msg_opts); + MLSMessage pre_shared_key(const bytes& group_id, + epoch_t epoch, + const MessageOpts& msg_opts); + MLSMessage reinit(bytes group_id, + ProtocolVersion version, + CipherSuite cipher_suite, + ExtensionList extensions, + const MessageOpts& msg_opts); + + std::tuple commit( + const bytes& leaf_secret, + const std::optional& opts, + const MessageOpts& msg_opts); + + /// + /// Generic handshake message handlers + /// + std::optional handle(const MLSMessage& msg); + std::optional handle(const MLSMessage& msg, + std::optional cached_state); + + std::optional handle(const ValidatedContent& content_auth); + std::optional handle(const ValidatedContent& content_auth, + std::optional cached_state); + + /// + /// PSK management + /// + void add_resumption_psk(const bytes& group_id, epoch_t epoch, bytes secret); + void remove_resumption_psk(const bytes& group_id, epoch_t epoch); + void add_external_psk(const bytes& id, const bytes& secret); + void remove_external_psk(const bytes& id); + + /// + /// Accessors + /// + const bytes& group_id() const { return _group_id; } + epoch_t epoch() const { return _epoch; } + LeafIndex index() const { return _index; } + CipherSuite cipher_suite() const { return _suite; } + const ExtensionList& extensions() const { return _extensions; } + const TreeKEMPublicKey& tree() const { return _tree; } + const bytes& resumption_psk() const { return _key_schedule.resumption_psk; } + + bytes do_export(const std::string& label, + const bytes& context, + size_t size) const; + GroupInfo group_info(bool inline_tree) const; + + // Ordered list of credentials from non-blank leaves + std::vector roster() const; + + bytes epoch_authenticator() const; + + /// + /// Unwrap messages so that applications can inspect them + /// + ValidatedContent unwrap(const MLSMessage& msg); + + /// + /// Application encryption and decryption + /// + MLSMessage protect(const bytes& authenticated_data, + const bytes& pt, + size_t padding_size); + std::tuple unprotect(const MLSMessage& ct); + + // Assemble a group context for this state + GroupContext group_context() const; + + // Subgroup branching + std::tuple create_branch( + bytes group_id, + HPKEPrivateKey enc_priv, + SignaturePrivateKey sig_priv, + const LeafNode& leaf_node, + ExtensionList extensions, + const std::vector& key_packages, + const bytes& leaf_secret, + const CommitOpts& commit_opts) const; + State handle_branch(const HPKEPrivateKey& init_priv, + HPKEPrivateKey enc_priv, + SignaturePrivateKey sig_priv, + const KeyPackage& key_package, + const Welcome& welcome, + const std::optional& tree) const; + + // Reinitialization + struct Tombstone + { + std::tuple create_welcome( + HPKEPrivateKey enc_priv, + SignaturePrivateKey sig_priv, + const LeafNode& leaf_node, + const std::vector& key_packages, + const bytes& leaf_secret, + const CommitOpts& commit_opts) const; + State handle_welcome(const HPKEPrivateKey& init_priv, + HPKEPrivateKey enc_priv, + SignaturePrivateKey sig_priv, + const KeyPackage& key_package, + const Welcome& welcome, + const std::optional& tree) const; + + TLS_SERIALIZABLE(prior_group_id, prior_epoch, resumption_psk, reinit); + + const bytes epoch_authenticator; + const ReInit reinit; + + private: + Tombstone(const State& state_in, ReInit reinit_in); + + bytes prior_group_id; + epoch_t prior_epoch; + bytes resumption_psk; + + friend class State; + }; + + std::tuple reinit_commit( + const bytes& leaf_secret, + const std::optional& opts, + const MessageOpts& msg_opts); + Tombstone handle_reinit_commit(const MLSMessage& commit); + +protected: + // Shared confirmed state + // XXX(rlb@ipv.sx): Can these be made const? + CipherSuite _suite; + bytes _group_id; + epoch_t _epoch; + TreeKEMPublicKey _tree; + TreeKEMPrivateKey _tree_priv; + TranscriptHash _transcript_hash; + ExtensionList _extensions; + + // Shared secret state + KeyScheduleEpoch _key_schedule; + GroupKeySource _keys; + + // Per-participant state + LeafIndex _index; + SignaturePrivateKey _identity_priv; + + // Storage for PSKs + std::map _external_psks; + + using EpochRef = std::tuple; + std::map _resumption_psks; + + // Cache of Proposals and update secrets + struct CachedProposal + { + ProposalRef ref; + Proposal proposal; + std::optional sender; + }; + std::list _pending_proposals; + + struct CachedUpdate + { + HPKEPrivateKey update_priv; + Update proposal; + }; + std::optional _cached_update; + + // Assemble a preliminary, unjoined group state + State(SignaturePrivateKey sig_priv, + const GroupInfo& group_info, + const std::optional& tree); + + // Assemble a group from a Welcome, allowing for resumption PSKs + State(const HPKEPrivateKey& init_priv, + HPKEPrivateKey leaf_priv, + SignaturePrivateKey sig_priv, + const KeyPackage& key_package, + const Welcome& welcome, + const std::optional& tree, + std::map external_psks, + std::map resumption_psks); + + // Import a tree from an externally-provided tree or an extension + TreeKEMPublicKey import_tree(const bytes& tree_hash, + const std::optional& external, + const ExtensionList& extensions); + bool validate_tree() const; + + // Form a commit, covering all the cases with slightly different validation + // rules: + // * Normal + // * External + // * Branch + // * Reinit + struct NormalCommitParams + {}; + + struct ExternalCommitParams + { + KeyPackage joiner_key_package; + bytes force_init_secret; + }; + + struct RestartCommitParams + { + ResumptionPSKUsage allowed_usage; + }; + + struct ReInitCommitParams + {}; + + using CommitParams = var::variant; + + std::tuple commit( + const bytes& leaf_secret, + const std::optional& opts, + const MessageOpts& msg_opts, + CommitParams params); + + std::optional handle( + const MLSMessage& msg, + std::optional cached_state, + const std::optional& expected_params); + std::optional handle( + const ValidatedContent& val_content, + std::optional cached_state, + const std::optional& expected_params); + + // Create an MLSMessage encapsulating some content + template + AuthenticatedContent sign(const Sender& sender, + Inner&& content, + const bytes& authenticated_data, + bool encrypt) const; + + MLSMessage protect(AuthenticatedContent&& content_auth, size_t padding_size); + + template + MLSMessage protect_full(Inner&& content, const MessageOpts& msg_opts); + + // Apply the changes requested by various messages + LeafIndex apply(const Add& add); + void apply(LeafIndex target, const Update& update); + void apply(LeafIndex target, + const Update& update, + const HPKEPrivateKey& leaf_priv); + LeafIndex apply(const Remove& remove); + void apply(const GroupContextExtensions& gce); + std::vector apply(const std::vector& proposals, + Proposal::Type required_type); + std::tuple, std::vector> apply( + const std::vector& proposals); + + // Verify that a specific key package or all members support a given set of + // extensions + bool extensions_supported(const ExtensionList& exts) const; + + // Extract proposals and PSKs from cache + void cache_proposal(AuthenticatedContent content_auth); + std::optional resolve( + const ProposalOrRef& id, + std::optional sender_index) const; + std::vector must_resolve( + const std::vector& ids, + std::optional sender_index) const; + + std::vector resolve( + const std::vector& psks) const; + + // Check properties of proposals + bool valid(const LeafNode& leaf_node, + LeafNodeSource required_source, + std::optional index) const; + bool valid(const KeyPackage& key_package) const; + bool valid(const Add& add) const; + bool valid(LeafIndex sender, const Update& update) const; + bool valid(const Remove& remove) const; + bool valid(const PreSharedKey& psk) const; + static bool valid(const ReInit& reinit); + bool valid(const ExternalInit& external_init) const; + bool valid(const GroupContextExtensions& gce) const; + bool valid(std::optional sender, const Proposal& proposal) const; + + bool valid(const std::vector& proposals, + LeafIndex commit_sender, + const CommitParams& params) const; + bool valid_normal(const std::vector& proposals, + LeafIndex commit_sender) const; + bool valid_external(const std::vector& proposals) const; + static bool valid_reinit(const std::vector& proposals); + static bool valid_restart(const std::vector& proposals, + ResumptionPSKUsage allowed_usage); + + static bool valid_external_proposal_type(const Proposal::Type proposal_type); + + CommitParams infer_commit_type( + const std::optional& sender, + const std::vector& proposals, + const std::optional& expected_params) const; + static bool path_required(const std::vector& proposals); + + // Compare the **shared** attributes of the states + friend bool operator==(const State& lhs, const State& rhs); + friend bool operator!=(const State& lhs, const State& rhs); + + // Derive and set the secrets for an epoch, given some new entropy + void update_epoch_secrets(const bytes& commit_secret, + const std::vector& psks, + const std::optional& force_init_secret); + + // Signature verification over a handshake message + bool verify_internal(const AuthenticatedContent& content_auth) const; + bool verify_external(const AuthenticatedContent& content_auth) const; + bool verify_new_member_proposal( + const AuthenticatedContent& content_auth) const; + bool verify_new_member_commit(const AuthenticatedContent& content_auth) const; + bool verify(const AuthenticatedContent& content_auth) const; + + // Convert a Roster entry into LeafIndex + LeafIndex leaf_for_roster_entry(RosterIndex index) const; + + // Create a draft successor state + State successor() const; +}; + +} // namespace mlspp diff --git a/mlspp/include/mls/tree_math.h b/mlspp/include/mls/tree_math.h new file mode 100755 index 0000000000..6a11f47a6f --- /dev/null +++ b/mlspp/include/mls/tree_math.h @@ -0,0 +1,107 @@ +#pragma once + +#include +#include +#include + +// The below functions provide the index calculus for the tree +// structures used in MLS. They are premised on a "flat" +// representation of a balanced binary tree. Leaf nodes are +// even-numbered nodes, with the n-th leaf at 2*n. Intermediate +// nodes are held in odd-numbered nodes. For example, a 11-element +// tree has the following structure: +// +// X +// X +// X X X +// X X X X X +// X X X X X X X X X X X +// 0 1 2 3 4 5 6 7 8 9 a b c d e f 10 11 12 13 14 +// +// This allows us to compute relationships between tree nodes simply +// by manipulating indices, rather than having to maintain +// complicated structures in memory, even for partial trees. (The +// storage for a tree can just be a map[int]Node dictionary or an +// array.) The basic rule is that the high-order bits of parent and +// child nodes have the following relation: +// +// 01x = <00x, 10x> + +namespace mlspp { + +// Index types go in the overall namespace +// XXX(rlb@ipv.sx): Seems like this stuff can probably get +// simplified down a fair bit. +struct UInt32 +{ + uint32_t val; + + UInt32() + : val(0) + { + } + + explicit UInt32(uint32_t val_in) + : val(val_in) + { + } + + TLS_SERIALIZABLE(val) +}; + +struct NodeCount; + +struct LeafCount : public UInt32 +{ + using UInt32::UInt32; + explicit LeafCount(const NodeCount w); + + static LeafCount full(const LeafCount n); +}; + +struct NodeCount : public UInt32 +{ + using UInt32::UInt32; + explicit NodeCount(const LeafCount n); +}; + +struct NodeIndex; + +struct LeafIndex : public UInt32 +{ + using UInt32::UInt32; + explicit LeafIndex(const NodeIndex x); + bool operator<(const LeafIndex other) const { return val < other.val; } + bool operator<(const LeafCount other) const { return val < other.val; } + + NodeIndex ancestor(LeafIndex other) const; +}; + +struct NodeIndex : public UInt32 +{ + using UInt32::UInt32; + explicit NodeIndex(const LeafIndex x); + bool operator<(const NodeIndex other) const { return val < other.val; } + bool operator<(const NodeCount other) const { return val < other.val; } + + static NodeIndex root(LeafCount n); + + bool is_leaf() const; + bool is_below(NodeIndex other) const; + + NodeIndex left() const; + NodeIndex right() const; + NodeIndex parent() const; + NodeIndex sibling() const; + + // Returns the sibling of this node "relative to this ancestor" -- the child + // of `ancestor` that is not in the direct path of this node. + NodeIndex sibling(NodeIndex ancestor) const; + + std::vector dirpath(LeafCount n); + std::vector copath(LeafCount n); + + uint32_t level() const; +}; + +} // namespace mlspp diff --git a/mlspp/include/mls/treekem.h b/mlspp/include/mls/treekem.h new file mode 100755 index 0000000000..e3e1028d3d --- /dev/null +++ b/mlspp/include/mls/treekem.h @@ -0,0 +1,255 @@ +#pragma once + +#include "mls/common.h" +#include "mls/core_types.h" +#include "mls/crypto.h" +#include "mls/tree_math.h" +#include + +#define ENABLE_TREE_DUMP 1 + +namespace mlspp { + +enum struct NodeType : uint8_t +{ + reserved = 0x00, + leaf = 0x01, + parent = 0x02, +}; + +struct Node +{ + var::variant node; + + const HPKEPublicKey& public_key() const; + std::optional parent_hash() const; + + TLS_SERIALIZABLE(node) + TLS_TRAITS(tls::variant) +}; + +struct OptionalNode +{ + std::optional node; + + bool blank() const { return !node.has_value(); } + bool leaf() const + { + return !blank() && var::holds_alternative(opt::get(node).node); + } + + LeafNode& leaf_node() { return var::get(opt::get(node).node); } + + const LeafNode& leaf_node() const + { + return var::get(opt::get(node).node); + } + + ParentNode& parent_node() + { + return var::get(opt::get(node).node); + } + + const ParentNode& parent_node() const + { + return var::get(opt::get(node).node); + } + + TLS_SERIALIZABLE(node) +}; + +struct TreeKEMPublicKey; + +struct TreeKEMPrivateKey +{ + CipherSuite suite; + LeafIndex index; + bytes update_secret; + std::map path_secrets; + std::map private_key_cache; + + static TreeKEMPrivateKey solo(CipherSuite suite, + LeafIndex index, + HPKEPrivateKey leaf_priv); + static TreeKEMPrivateKey create(const TreeKEMPublicKey& pub, + LeafIndex from, + const bytes& leaf_secret); + static TreeKEMPrivateKey joiner(const TreeKEMPublicKey& pub, + LeafIndex index, + HPKEPrivateKey leaf_priv, + NodeIndex intersect, + const std::optional& path_secret); + + void set_leaf_priv(HPKEPrivateKey priv); + std::tuple shared_path_secret(LeafIndex to) const; + + bool have_private_key(NodeIndex n) const; + std::optional private_key(NodeIndex n); + std::optional private_key(NodeIndex n) const; + + void decap(LeafIndex from, + const TreeKEMPublicKey& pub, + const bytes& context, + const UpdatePath& path, + const std::vector& except); + + void truncate(LeafCount size); + + bool consistent(const TreeKEMPrivateKey& other) const; + bool consistent(const TreeKEMPublicKey& other) const; + +#if ENABLE_TREE_DUMP + void dump() const; +#endif + + // TODO(RLB) Make this private but exposed to test vectors + void implant(const TreeKEMPublicKey& pub, + NodeIndex start, + const bytes& path_secret); +}; + +struct TreeKEMPublicKey +{ + CipherSuite suite; + LeafCount size{ 0 }; + std::vector nodes; + + explicit TreeKEMPublicKey(CipherSuite suite); + + TreeKEMPublicKey() = default; + TreeKEMPublicKey(const TreeKEMPublicKey& other) = default; + TreeKEMPublicKey(TreeKEMPublicKey&& other) = default; + TreeKEMPublicKey& operator=(const TreeKEMPublicKey& other) = default; + TreeKEMPublicKey& operator=(TreeKEMPublicKey&& other) = default; + + LeafIndex allocate_leaf(); + LeafIndex add_leaf(const LeafNode& leaf); + void update_leaf(LeafIndex index, const LeafNode& leaf); + void blank_path(LeafIndex index); + + TreeKEMPrivateKey update(LeafIndex from, + const bytes& leaf_secret, + const bytes& group_id, + const SignaturePrivateKey& sig_priv, + const LeafNodeOptions& opts); + UpdatePath encap(const TreeKEMPrivateKey& priv, + const bytes& context, + const std::vector& except) const; + + void merge(LeafIndex from, const UpdatePath& path); + void set_hash_all(); + const bytes& get_hash(NodeIndex index); + bytes root_hash() const; + + bool parent_hash_valid(LeafIndex from, const UpdatePath& path) const; + bool parent_hash_valid() const; + + bool has_leaf(LeafIndex index) const; + std::optional find(const LeafNode& leaf) const; + std::optional leaf_node(LeafIndex index) const; + std::vector resolve(NodeIndex index) const; + + template + bool all_leaves(const UnaryPredicate& pred) const + { + for (LeafIndex i{ 0 }; i < size; i.val++) { + const auto& node = node_at(i); + if (node.blank()) { + continue; + } + + if (!pred(i, node.leaf_node())) { + return false; + } + } + + return true; + } + + template + bool any_leaf(const UnaryPredicate& pred) const + { + for (LeafIndex i{ 0 }; i < size; i.val++) { + const auto& node = node_at(i); + if (node.blank()) { + continue; + } + + if (pred(i, node.leaf_node())) { + return true; + } + } + + return false; + } + + using FilteredDirectPath = + std::vector>>; + FilteredDirectPath filtered_direct_path(NodeIndex index) const; + + void truncate(); + + OptionalNode& node_at(NodeIndex n); + const OptionalNode& node_at(NodeIndex n) const; + OptionalNode& node_at(LeafIndex n); + const OptionalNode& node_at(LeafIndex n) const; + + TLS_SERIALIZABLE(nodes) + +#if ENABLE_TREE_DUMP + void dump() const; +#endif + +private: + std::map hashes; + + void clear_hash_all(); + void clear_hash_path(LeafIndex index); + + bool has_parent_hash(NodeIndex child, const bytes& target_ph) const; + + bytes parent_hash(const ParentNode& parent, NodeIndex copath_child) const; + std::vector parent_hashes( + LeafIndex from, + const FilteredDirectPath& fdp, + const std::vector& path_nodes) const; + + using TreeHashCache = std::map>; + const bytes& original_tree_hash(TreeHashCache& cache, + NodeIndex index, + std::vector parent_except) const; + bytes original_parent_hash(TreeHashCache& cache, + NodeIndex parent, + NodeIndex sibling) const; + + bool exists_in_tree(const HPKEPublicKey& key, + std::optional except) const; + bool exists_in_tree(const SignaturePublicKey& key, + std::optional except) const; + + OptionalNode blank_node; + + friend struct TreeKEMPrivateKey; +}; + +tls::ostream& +operator<<(tls::ostream& str, const TreeKEMPublicKey& obj); +tls::istream& +operator>>(tls::istream& str, TreeKEMPublicKey& obj); + +struct LeafNodeHashInput; +struct ParentNodeHashInput; + +} // namespace mlspp + +namespace mlspp::tls { + +TLS_VARIANT_MAP(mlspp::NodeType, mlspp::LeafNodeHashInput, leaf) +TLS_VARIANT_MAP(mlspp::NodeType, + mlspp::ParentNodeHashInput, + parent) + +TLS_VARIANT_MAP(mlspp::NodeType, mlspp::LeafNode, leaf) +TLS_VARIANT_MAP(mlspp::NodeType, mlspp::ParentNode, parent) + +} // namespace mlspp::tls diff --git a/mlspp/include/namespace.h b/mlspp/include/namespace.h new file mode 100755 index 0000000000..d07ba5ee94 --- /dev/null +++ b/mlspp/include/namespace.h @@ -0,0 +1,4 @@ +#pragma once + +// Configurable top-level MLS namespace +#define MLS_NAMESPACE ../include/dpp/mlspp/mls diff --git a/mlspp/include/version.h b/mlspp/include/version.h new file mode 100755 index 0000000000..0cead31c4e --- /dev/null +++ b/mlspp/include/version.h @@ -0,0 +1,5 @@ +#pragma once + +/* Global version strings */ +extern const char VERSION[]; +extern const char HASHVAR[]; diff --git a/mlspp/lib/CMakeLists.txt b/mlspp/lib/CMakeLists.txt new file mode 100755 index 0000000000..31f9546e84 --- /dev/null +++ b/mlspp/lib/CMakeLists.txt @@ -0,0 +1,4 @@ +add_subdirectory(bytes) +add_subdirectory(hpke) +add_subdirectory(tls_syntax) +add_subdirectory(mls_vectors) diff --git a/mlspp/lib/bytes/CMakeLists.txt b/mlspp/lib/bytes/CMakeLists.txt new file mode 100755 index 0000000000..30d450fd35 --- /dev/null +++ b/mlspp/lib/bytes/CMakeLists.txt @@ -0,0 +1,37 @@ +set(CURRENT_LIB_NAME bytes) + +### +### Library Config +### + +file(GLOB_RECURSE LIB_HEADERS CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/include/*.h") +file(GLOB_RECURSE LIB_SOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp") + +add_library(${CURRENT_LIB_NAME} STATIC ${LIB_HEADERS} ${LIB_SOURCES}) +add_dependencies(${CURRENT_LIB_NAME} tls_syntax) +include_directories("${PROJECT_SOURCE_DIR}/../bytes/include") +target_link_libraries(${CURRENT_LIB_NAME} tls_syntax) +target_include_directories(${CURRENT_LIB_NAME} + PUBLIC + $ + $ + $ +) + +### +### Install +### + +install(TARGETS ${CURRENT_LIB_NAME} EXPORT mlspp-targets) +install( + DIRECTORY + include + DESTINATION + ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME} +) + +include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/bytes/include") +include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/hpke/include") +include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/mls_vectors/include") +include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/tls_syntax/include") + diff --git a/mlspp/lib/bytes/include/bytes/bytes.h b/mlspp/lib/bytes/include/bytes/bytes.h new file mode 100755 index 0000000000..d5ad4d1c05 --- /dev/null +++ b/mlspp/lib/bytes/include/bytes/bytes.h @@ -0,0 +1,127 @@ +#pragma once + +#include +#include "../../../tls_syntax/include/tls/tls_syntax.h" +#include + +namespace mlspp::bytes_ns { + +struct bytes +{ + // Ensure defaults + bytes() = default; + bytes(const bytes&) = default; + bytes& operator=(const bytes&) = default; + bytes(bytes&&) = default; + bytes& operator=(bytes&&) = default; + + // Zeroize on drop + ~bytes() + { + auto ptr = static_cast(_data.data()); + std::fill(ptr, ptr + _data.size(), uint8_t(0)); + } + + // Mimic std::vector ctors + bytes(size_t count, const uint8_t& value = 0) + : _data(count, value) + { + } + + bytes(std::initializer_list init) + : _data(init) + { + } + + template + bytes(const std::array& data) + : _data(data.begin(), data.end()) + { + } + + // Slice out sub-vectors (to avoid an iterator ctor) + bytes slice(size_t begin_index, size_t end_index) const + { + const auto begin_it = _data.begin() + begin_index; + const auto end_it = _data.begin() + end_index; + return std::vector(begin_it, end_it); + } + + // Freely convert to/from std::vector + bytes(const std::vector& vec) + : _data(vec) + { + } + + bytes(std::vector&& vec) + : _data(vec) + { + } + + operator const std::vector&() const { return _data; } + operator std::vector&() { return _data; } + operator std::vector&&() && { return std::move(_data); } + + const std::vector& as_vec() const { return _data; } + std::vector& as_vec() { return _data; } + + // Pass through methods + auto data() const { return _data.data(); } + auto data() { return _data.data(); } + + auto size() const { return _data.size(); } + auto empty() const { return _data.empty(); } + + auto begin() const { return _data.begin(); } + auto begin() { return _data.begin(); } + + auto end() const { return _data.end(); } + auto end() { return _data.end(); } + + const auto& at(size_t pos) const { return _data.at(pos); } + auto& at(size_t pos) { return _data.at(pos); } + + void resize(size_t count) { _data.resize(count); } + void reserve(size_t len) { _data.reserve(len); } + void push_back(uint8_t byte) { _data.push_back(byte); } + + // Equality operators + bool operator==(const bytes& other) const; + bool operator!=(const bytes& other) const; + + bool operator==(const std::vector& other) const; + bool operator!=(const std::vector& other) const; + + // Arithmetic operators + bytes& operator+=(const bytes& other); + bytes operator+(const bytes& rhs) const; + bytes operator^(const bytes& rhs) const; + + // Sorting operators (to allow usage as map keys) + bool operator<(const bytes& rhs) const; + + // Other, external operators + friend std::ostream& operator<<(std::ostream& out, const bytes& data); + friend bool operator==(const std::vector& lhs, const bytes& rhs); + friend bool operator!=(const std::vector& lhs, const bytes& rhs); + + // TLS syntax serialization + TLS_SERIALIZABLE(_data); + +private: + std::vector _data; +}; + +std::string +to_ascii(const bytes& data); + +bytes +from_ascii(const std::string& ascii); + +std::string +to_hex(const bytes& data); + +bytes +from_hex(const std::string& hex); + +} // namespace mlspp::bytes_ns diff --git a/mlspp/lib/bytes/include/tls/compat.h b/mlspp/lib/bytes/include/tls/compat.h new file mode 100755 index 0000000000..c226e44ecd --- /dev/null +++ b/mlspp/lib/bytes/include/tls/compat.h @@ -0,0 +1,68 @@ +#pragma once + +#include +#include + +#ifdef VARIANT_COMPAT +#include +#else +#include +#endif // VARIANT_COMPAT + +namespace mlspp::tls { + +// To balance backward-compatibility with macOS 10.11 with forward-compatibility +// with future versions of C++, we use `mpark::variant` or `std::variant` as +// needed, using `var::variant` to refer to whichever one is in use. +#ifdef VARIANT_COMPAT +namespace var = mpark; +#else +namespace var = std; +#endif // VARIANT_COMPAT + +// In a similar vein, we provide our own safe accessors for std::optional, since +// std::optional::value() is not available on macOS 10.11. +namespace opt { + +template +T& +get(std::optional& opt) +{ + if (!opt) { + throw std::runtime_error("bad_optional_access"); + } + return *opt; +} + +template +const T& +get(const std::optional& opt) +{ + if (!opt) { + throw std::runtime_error("bad_optional_access"); + } + return *opt; +} + +template +T&& +get(std::optional&& opt) +{ + if (!opt) { + throw std::runtime_error("bad_optional_access"); + } + return std::move(*opt); +} + +template +const T&& +get(const std::optional&& opt) +{ + if (!opt) { + throw std::runtime_error("bad_optional_access"); + } + return std::move(*opt); +} + +} // namespace opt +} // namespace mlspp::tls diff --git a/mlspp/lib/bytes/include/tls/tls_syntax.h b/mlspp/lib/bytes/include/tls/tls_syntax.h new file mode 100755 index 0000000000..09d5940d9d --- /dev/null +++ b/mlspp/lib/bytes/include/tls/tls_syntax.h @@ -0,0 +1,569 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace mlspp::tls { + +// For indicating no min or max in vector definitions +const size_t none = std::numeric_limits::max(); + +class WriteError : public std::invalid_argument +{ +public: + using parent = std::invalid_argument; + using parent::parent; +}; + +class ReadError : public std::invalid_argument +{ +public: + using parent = std::invalid_argument; + using parent::parent; +}; + +/// +/// Declarations of Streams and Traits +/// + +class ostream +{ +public: + static const size_t none = std::numeric_limits::max(); + + void write_raw(const std::vector& bytes); + + const std::vector& bytes() const { return _buffer; } + size_t size() const { return _buffer.size(); } + bool empty() const { return _buffer.empty(); } + +private: + std::vector _buffer; + ostream& write_uint(uint64_t value, int length); + + friend ostream& operator<<(ostream& out, bool data); + friend ostream& operator<<(ostream& out, uint8_t data); + friend ostream& operator<<(ostream& out, uint16_t data); + friend ostream& operator<<(ostream& out, uint32_t data); + friend ostream& operator<<(ostream& out, uint64_t data); + + template + friend ostream& operator<<(ostream& out, const std::vector& data); + + friend struct varint; +}; + +class istream +{ +public: + istream(const std::vector& data) + : _buffer(data) + { + // So that we can use the constant-time pop_back + std::reverse(_buffer.begin(), _buffer.end()); + } + + size_t size() const { return _buffer.size(); } + bool empty() const { return _buffer.empty(); } + + std::vector bytes() + { + auto bytes = _buffer; + std::reverse(bytes.begin(), bytes.end()); + return bytes; + } + +private: + istream() {} + std::vector _buffer; + uint8_t next(); + + template + istream& read_uint(T& data, size_t length) + { + uint64_t value = 0; + for (size_t i = 0; i < length; i += 1) { + value = (value << unsigned(8)) + next(); + } + data = static_cast(value); + return *this; + } + + friend istream& operator>>(istream& in, bool& data); + friend istream& operator>>(istream& in, uint8_t& data); + friend istream& operator>>(istream& in, uint16_t& data); + friend istream& operator>>(istream& in, uint32_t& data); + friend istream& operator>>(istream& in, uint64_t& data); + + template + friend istream& operator>>(istream& in, std::vector& data); + + friend struct varint; +}; + +// Traits must have static encode and decode methods, of the following form: +// +// static ostream& encode(ostream& str, const T& val); +// static istream& decode(istream& str, T& val); +// +// Trait types will never be constructed; only these static methods are used. +// The value arguments to encode and decode can be as strict or as loose as +// desired. +// +// Ultimately, all interesting encoding should be done through traits. +// +// * vectors +// * variants +// * varints + +struct pass +{ + template + static ostream& encode(ostream& str, const T& val); + + template + static istream& decode(istream& str, T& val); +}; + +template +struct variant +{ + template + static inline Ts type(const var::variant& data); + + template + static ostream& encode(ostream& str, const var::variant& data); + + template + static inline typename std::enable_if::type + read_variant(istream&, Te, var::variant&); + + template + static inline typename std::enable_if < + I::type read_variant(istream& str, + Te target_type, + var::variant& v); + + template + static istream& decode(istream& str, var::variant& data); +}; + +struct varint +{ + static ostream& encode(ostream& str, const uint64_t& val); + static istream& decode(istream& str, uint64_t& val); +}; + +/// +/// Writer implementations +/// + +// Primitive writers defined in .cpp file + +// Array writer +template +ostream& +operator<<(ostream& out, const std::array& data) +{ + for (const auto& item : data) { + out << item; + } + return out; +} + +// Optional writer +template +ostream& +operator<<(ostream& out, const std::optional& opt) +{ + if (!opt) { + return out << uint8_t(0); + } + + return out << uint8_t(1) << opt::get(opt); +} + +// Enum writer +template::value, int> = 0> +ostream& +operator<<(ostream& str, const T& val) +{ + auto u = static_cast>(val); + return str << u; +} + +// Vector writer +template +ostream& +operator<<(ostream& str, const std::vector& vec) +{ + // Pre-encode contents + ostream temp; + for (const auto& item : vec) { + temp << item; + } + + // Write the encoded length, then the pre-encoded data + varint::encode(str, temp._buffer.size()); + str.write_raw(temp.bytes()); + + return str; +} + +/// +/// Reader implementations +/// + +// Primitive type readers defined in .cpp file + +// Array reader +template +istream& +operator>>(istream& in, std::array& data) +{ + for (auto& item : data) { + in >> item; + } + return in; +} + +// Optional reader +template +istream& +operator>>(istream& in, std::optional& opt) +{ + uint8_t present = 0; + in >> present; + + switch (present) { + case 0: + opt.reset(); + return in; + + case 1: + opt.emplace(); + return in >> opt::get(opt); + + default: + throw std::invalid_argument("Malformed optional"); + } +} + +// Enum reader +// XXX(rlb): It would be nice if this could enforce that the values are valid, +// but C++ doesn't seem to have that ability. When used as a tag for variants, +// the variant reader will enforce, at least. +template::value, int> = 0> +istream& +operator>>(istream& str, T& val) +{ + std::underlying_type_t u; + str >> u; + val = static_cast(u); + return str; +} + +// Vector reader +template +istream& +operator>>(istream& str, std::vector& vec) +{ + // Read the encoded data size + auto size = uint64_t(0); + varint::decode(str, size); + if (size > str._buffer.size()) { + throw ReadError("Vector is longer than remaining data"); + } + + // Read the elements of the vector + // NB: Remember that we store the vector in reverse order + // NB: This requires that T be default-constructible + istream r; + r._buffer = + std::vector{ str._buffer.end() - size, str._buffer.end() }; + + vec.clear(); + while (r._buffer.size() > 0) { + vec.emplace_back(); + r >> vec.back(); + } + + // Truncate the primary buffer + str._buffer.erase(str._buffer.end() - size, str._buffer.end()); + + return str; +} + +// Abbreviations +template +std::vector +marshal(const T& value) +{ + ostream w; + w << value; + return w.bytes(); +} + +template +void +unmarshal(const std::vector& data, T& value) +{ + istream r(data); + r >> value; +} + +template +T +get(const std::vector& data, Tp... args) +{ + T value(args...); + unmarshal(data, value); + return value; +} + +// Use this macro to define struct serialization with minimal boilerplate +#define TLS_SERIALIZABLE(...) \ + static const bool _tls_serializable = true; \ + auto _tls_fields_r() \ + { \ + return std::forward_as_tuple(__VA_ARGS__); \ + } \ + auto _tls_fields_w() const \ + { \ + return std::forward_as_tuple(__VA_ARGS__); \ + } + +// If your struct contains nontrivial members (e.g., vectors), use this to +// define traits for them. +#define TLS_TRAITS(...) \ + static const bool _tls_has_traits = true; \ + using _tls_traits = std::tuple<__VA_ARGS__>; + +template +struct is_serializable +{ + template + static std::true_type test(decltype(U::_tls_serializable)); + + template + static std::false_type test(...); + + static const bool value = decltype(test(true))::value; +}; + +template +struct has_traits +{ + template + static std::true_type test(decltype(U::_tls_has_traits)); + + template + static std::false_type test(...); + + static const bool value = decltype(test(true))::value; +}; + +/// +/// Trait implementations +/// + +// Pass-through (normal encoding/decoding) +template +ostream& +pass::encode(ostream& str, const T& val) +{ + return str << val; +} + +template +istream& +pass::decode(istream& str, T& val) +{ + return str >> val; +} + +// Variant encoding +template +constexpr Ts +variant_map(); + +#define TLS_VARIANT_MAP(EnumType, MappedType, enum_value) \ + template<> \ + constexpr EnumType variant_map() \ + { \ + return EnumType::enum_value; \ + } + +template +template +inline Ts +variant::type(const var::variant& data) +{ + const auto get_type = [](const auto& v) { + return variant_map>(); + }; + return var::visit(get_type, data); +} + +template +template +ostream& +variant::encode(ostream& str, const var::variant& data) +{ + const auto write_variant = [&str](auto&& v) { + using Tv = std::decay_t; + str << variant_map() << v; + }; + var::visit(write_variant, data); + return str; +} + +template +template +inline typename std::enable_if::type +variant::read_variant(istream&, Te, var::variant&) +{ + throw ReadError("Invalid variant type label"); +} + +template + template + inline + typename std::enable_if < I::type + variant::read_variant(istream& str, + Te target_type, + var::variant& v) +{ + using Tc = var::variant_alternative_t>; + if (variant_map() == target_type) { + str >> v.template emplace(); + return; + } + + read_variant(str, target_type, v); +} + +template +template +istream& +variant::decode(istream& str, var::variant& data) +{ + Ts target_type; + str >> target_type; + read_variant(str, target_type, data); + return str; +} + +// Struct writer without traits (enabled by macro) +template +inline typename std::enable_if::type +write_tuple(ostream&, const std::tuple&) +{ +} + +template + inline typename std::enable_if < + I::type + write_tuple(ostream& str, const std::tuple& t) +{ + str << std::get(t); + write_tuple(str, t); +} + +template +inline + typename std::enable_if::value && !has_traits::value, + ostream&>::type + operator<<(ostream& str, const T& obj) +{ + write_tuple(str, obj._tls_fields_w()); + return str; +} + +// Struct writer with traits (enabled by macro) +template +inline typename std::enable_if::type +write_tuple_traits(ostream&, const std::tuple&) +{ +} + +template + inline typename std::enable_if < + I::type + write_tuple_traits(ostream& str, const std::tuple& t) +{ + std::tuple_element_t::encode(str, std::get(t)); + write_tuple_traits(str, t); +} + +template +inline + typename std::enable_if::value && has_traits::value, + ostream&>::type + operator<<(ostream& str, const T& obj) +{ + write_tuple_traits(str, obj._tls_fields_w()); + return str; +} + +// Struct reader without traits (enabled by macro) +template +inline typename std::enable_if::type +read_tuple(istream&, const std::tuple&) +{ +} + +template + inline + typename std::enable_if < I::type + read_tuple(istream& str, const std::tuple& t) +{ + str >> std::get(t); + read_tuple(str, t); +} + +template +inline + typename std::enable_if::value && !has_traits::value, + istream&>::type + operator>>(istream& str, T& obj) +{ + read_tuple(str, obj._tls_fields_r()); + return str; +} + +// Struct reader with traits (enabled by macro) +template +inline typename std::enable_if::type +read_tuple_traits(istream&, const std::tuple&) +{ +} + +template + inline typename std::enable_if < + I::type + read_tuple_traits(istream& str, const std::tuple& t) +{ + std::tuple_element_t::decode(str, std::get(t)); + read_tuple_traits(str, t); +} + +template +inline + typename std::enable_if::value && has_traits::value, + istream&>::type + operator>>(istream& str, T& obj) +{ + read_tuple_traits(str, obj._tls_fields_r()); + return str; +} + +} // namespace mlspp::tls diff --git a/mlspp/lib/bytes/src/bytes.cpp b/mlspp/lib/bytes/src/bytes.cpp new file mode 100755 index 0000000000..634b9f9aad --- /dev/null +++ b/mlspp/lib/bytes/src/bytes.cpp @@ -0,0 +1,146 @@ +#include + +#include +#include +#include +#include + +namespace mlspp::bytes_ns { + +bool +bytes::operator==(const bytes& other) const +{ + return *this == other._data; +} + +bool +bytes::operator!=(const bytes& other) const +{ + return !(*this == other._data); +} + +bool +bytes::operator==(const std::vector& other) const +{ + const size_t size = other.size(); + if (_data.size() != size) { + return false; + } + + unsigned char diff = 0; + for (size_t i = 0; i < size; ++i) { + // Not sure why the linter thinks `diff` is signed + // NOLINTNEXTLINE(hicpp-signed-bitwise) + diff |= (_data.at(i) ^ other.at(i)); + } + return (diff == 0); +} + +bool +bytes::operator!=(const std::vector& other) const +{ + return !(*this == other); +} + +bytes& +bytes::operator+=(const bytes& other) +{ + // Not sure what the default argument is here + // NOLINTNEXTLINE(fuchsia-default-arguments) + _data.insert(end(), other.begin(), other.end()); + return *this; +} + +bytes +bytes::operator+(const bytes& rhs) const +{ + bytes out = *this; + out += rhs; + return out; +} + +bool +bytes::operator<(const bytes& rhs) const +{ + return _data < rhs._data; +} + +bytes +bytes::operator^(const bytes& rhs) const +{ + if (size() != rhs.size()) { + throw std::invalid_argument("XOR with unequal size"); + } + + bytes out = *this; + for (size_t i = 0; i < size(); ++i) { + out.at(i) ^= rhs.at(i); + } + return out; +} + +std::string +to_ascii(const bytes& data) +{ + return { data.begin(), data.end() }; +} + +bytes +from_ascii(const std::string& ascii) +{ + return std::vector(ascii.begin(), ascii.end()); +} + +std::string +to_hex(const bytes& data) +{ + std::stringstream hex(std::ios_base::out); + hex.flags(std::ios::hex); + for (const auto& byte : data) { + hex << std::setw(2) << std::setfill('0') << int(byte); + } + return hex.str(); +} + +bytes +from_hex(const std::string& hex) +{ + if (hex.length() % 2 == 1) { + throw std::invalid_argument("Odd-length hex string"); + } + + auto len = hex.length() / 2; + auto out = bytes(len); + for (size_t i = 0; i < len; i += 1) { + const std::string byte = hex.substr(2 * i, 2); + out.at(i) = static_cast(strtol(byte.c_str(), nullptr, 16)); + } + + return out; +} + +std::ostream& +operator<<(std::ostream& out, const bytes& data) +{ + // Adjust this threshold to make output more compact + const size_t threshold = 0xffff; + if (data.size() < threshold) { + return out << to_hex(data); + } + + return out << to_hex(data.slice(0, threshold)) << "..."; +} + +bool +operator==(const std::vector& lhs, const bytes_ns::bytes& rhs) +{ + return rhs == lhs; +} + +bool +operator!=(const std::vector& lhs, const bytes_ns::bytes& rhs) +{ + return rhs != lhs; +} + +} // namespace mlspp::bytes_ns diff --git a/mlspp/lib/hpke/CMakeLists.txt b/mlspp/lib/hpke/CMakeLists.txt new file mode 100755 index 0000000000..f27690f8ad --- /dev/null +++ b/mlspp/lib/hpke/CMakeLists.txt @@ -0,0 +1,62 @@ +set(CURRENT_LIB_NAME hpke) + +### +### Library Config +### + +file(GLOB_RECURSE LIB_HEADERS CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/include/*.h") +file(GLOB_RECURSE LIB_SOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp") + +# -Werror=dangling-reference + +add_library(${CURRENT_LIB_NAME} STATIC ${LIB_HEADERS} ${LIB_SOURCES}) +add_dependencies(${CURRENT_LIB_NAME} bytes tls_syntax) +target_include_directories(${CURRENT_LIB_NAME} + PRIVATE + "${JSON_INCLUDE_INTERFACE}") + +target_link_libraries(${CURRENT_LIB_NAME} + PUBLIC + bytes tls_syntax + PRIVATE + OpenSSL::Crypto +) + +target_include_directories(${CURRENT_LIB_NAME} + PUBLIC + $ + $ + $ + PRIVATE + ${OPENSSL_INCLUDE_DIR} +) + +### +### Install +### + +install(TARGETS ${CURRENT_LIB_NAME} EXPORT mlspp-targets) +install( + DIRECTORY + include + DESTINATION + ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME} +) + +target_compile_options( + "${CURRENT_LIB_NAME}" + PUBLIC + "$<$:/bigobj;/Zc:preprocessor>" + PRIVATE + "$<$:$<$:/sdl;/Od;/DEBUG;/MP;/DFD_SETSIZE=1024>>" + "$<$:$<$:/O2;/Oi;/Oy;/GL;/Gy;/sdl;/MP;/DFD_SETSIZE=1024>>" + "$<$:$<$:-Wall;-Wno-dangling-reference;-Wempty-body;-Wno-psabi;-Wunknown-pragmas;-Wignored-qualifiers;-Wimplicit-fallthrough;-Wmissing-field-initializers;-Wsign-compare;-Wtype-limits;-Wuninitialized;-Wshift-negative-value;-pthread;-g;-Og;-fPIC>>" + "$<$:$<$:-Wall;-Wno-dangling-reference;-Wempty-body;-Wno-psabi;-Wunknown-pragmas;-Wignored-qualifiers;-Wimplicit-fallthrough;-Wmissing-field-initializers;-Wsign-compare;-Wtype-limits;-Wuninitialized;-Wshift-negative-value;-pthread;-O3;-fPIC>>" + "${AVX_FLAG}" +) + +include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/bytes/include") +include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/hpke/include") +include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/mls_vectors/include") +include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/tls_syntax/include") + diff --git a/mlspp/lib/hpke/include/hpke/base64.h b/mlspp/lib/hpke/include/hpke/base64.h new file mode 100755 index 0000000000..44ab7a5a09 --- /dev/null +++ b/mlspp/lib/hpke/include/hpke/base64.h @@ -0,0 +1,20 @@ +#pragma once + +#include +using namespace mlspp::bytes_ns; + +namespace mlspp::hpke { + +std::string +to_base64(const bytes& data); + +std::string +to_base64url(const bytes& data); + +bytes +from_base64(const std::string& enc); + +bytes +from_base64url(const std::string& enc); + +} // namespace mlspp::hpke diff --git a/mlspp/lib/hpke/include/hpke/certificate.h b/mlspp/lib/hpke/include/hpke/certificate.h new file mode 100755 index 0000000000..752c5f417b --- /dev/null +++ b/mlspp/lib/hpke/include/hpke/certificate.h @@ -0,0 +1,75 @@ +#pragma once +#include +#include + +#include +#include +#include +#include + +using namespace mlspp::bytes_ns; + +namespace mlspp::hpke { + +struct Certificate +{ +private: + struct ParsedCertificate; + std::unique_ptr parsed_cert; + +public: + struct NameType + { + static const int organization; + static const int common_name; + static const int organizational_unit; + static const int country; + static const int serial_number; + static const int state_or_province_name; + }; + + using ParsedName = std::map; + + // Certificate Expiration Status + enum struct ExpirationStatus + { + inactive, // now < notBefore + active, // notBefore < now < notAfter + expired, // notAfter < now + }; + + explicit Certificate(const bytes& der); + explicit Certificate(std::unique_ptr&& parsed_cert_in); + Certificate() = delete; + Certificate(const Certificate& other); + ~Certificate(); + + static std::vector parse_pem(const bytes& pem); + + bool valid_from(const Certificate& parent) const; + + // Accessors for parsed certificate elements + uint64_t issuer_hash() const; + uint64_t subject_hash() const; + ParsedName issuer() const; + ParsedName subject() const; + bool is_ca() const; + ExpirationStatus expiration_status() const; + std::optional subject_key_id() const; + std::optional authority_key_id() const; + std::vector email_addresses() const; + std::vector dns_names() const; + bytes hash() const; + std::chrono::system_clock::time_point not_before() const; + std::chrono::system_clock::time_point not_after() const; + Signature::ID public_key_algorithm() const; + Signature::ID signature_algorithm() const; + + const std::unique_ptr public_key; + const bytes raw; +}; + +bool +operator==(const Certificate& lhs, const Certificate& rhs); + +} // namespace mlspp::hpke diff --git a/mlspp/lib/hpke/include/hpke/digest.h b/mlspp/lib/hpke/include/hpke/digest.h new file mode 100755 index 0000000000..025669133e --- /dev/null +++ b/mlspp/lib/hpke/include/hpke/digest.h @@ -0,0 +1,37 @@ +#pragma once + +#include + +#include + +using namespace mlspp::bytes_ns; + +namespace mlspp::hpke { + +struct Digest +{ + enum struct ID + { + SHA256, + SHA384, + SHA512, + }; + + template + static const Digest& get(); + + const ID id; + + bytes hash(const bytes& data) const; + bytes hmac(const bytes& key, const bytes& data) const; + + const size_t hash_size; + +private: + explicit Digest(ID id); + + bytes hmac_for_hkdf_extract(const bytes& key, const bytes& data) const; + friend struct HKDF; +}; + +} // namespace mlspp::hpke diff --git a/mlspp/lib/hpke/include/hpke/hpke.h b/mlspp/lib/hpke/include/hpke/hpke.h new file mode 100755 index 0000000000..4cf8c00a28 --- /dev/null +++ b/mlspp/lib/hpke/include/hpke/hpke.h @@ -0,0 +1,253 @@ +#pragma once + +#include +#include + +#include +using namespace mlspp::bytes_ns; + +namespace mlspp::hpke { + +struct KEM +{ + enum struct ID : uint16_t + { + DHKEM_P256_SHA256 = 0x0010, + DHKEM_P384_SHA384 = 0x0011, + DHKEM_P521_SHA512 = 0x0012, + DHKEM_X25519_SHA256 = 0x0020, +#if !defined(WITH_BORINGSSL) + DHKEM_X448_SHA512 = 0x0021, +#endif + }; + + template + static const KEM& get(); + + virtual ~KEM() = default; + + struct PublicKey + { + virtual ~PublicKey() = default; + }; + + struct PrivateKey + { + virtual ~PrivateKey() = default; + virtual std::unique_ptr public_key() const = 0; + }; + + const ID id; + const size_t secret_size; + const size_t enc_size; + const size_t pk_size; + const size_t sk_size; + + virtual std::unique_ptr generate_key_pair() const = 0; + virtual std::unique_ptr derive_key_pair( + const bytes& ikm) const = 0; + + virtual bytes serialize(const PublicKey& pk) const = 0; + virtual std::unique_ptr deserialize(const bytes& enc) const = 0; + + virtual bytes serialize_private(const PrivateKey& sk) const; + virtual std::unique_ptr deserialize_private( + const bytes& skm) const; + + // (shared_secret, enc) + virtual std::pair encap(const PublicKey& pkR) const = 0; + virtual bytes decap(const bytes& enc, const PrivateKey& skR) const = 0; + + // (shared_secret, enc) + virtual std::pair auth_encap(const PublicKey& pkR, + const PrivateKey& skS) const; + virtual bytes auth_decap(const bytes& enc, + const PublicKey& pkS, + const PrivateKey& skR) const; + +protected: + KEM(ID id_in, + size_t secret_size_in, + size_t enc_size_in, + size_t pk_size_in, + size_t sk_size_in); +}; + +struct KDF +{ + enum struct ID : uint16_t + { + HKDF_SHA256 = 0x0001, + HKDF_SHA384 = 0x0002, + HKDF_SHA512 = 0x0003, + }; + + template + static const KDF& get(); + + virtual ~KDF() = default; + + const ID id; + const size_t hash_size; + + virtual bytes extract(const bytes& salt, const bytes& ikm) const = 0; + virtual bytes expand(const bytes& prk, + const bytes& info, + size_t size) const = 0; + + bytes labeled_extract(const bytes& suite_id, + const bytes& salt, + const bytes& label, + const bytes& ikm) const; + bytes labeled_expand(const bytes& suite_id, + const bytes& prk, + const bytes& label, + const bytes& info, + size_t size) const; + +protected: + KDF(ID id_in, size_t hash_size_in); +}; + +struct AEAD +{ + enum struct ID : uint16_t + { + AES_128_GCM = 0x0001, + AES_256_GCM = 0x0002, + CHACHA20_POLY1305 = 0x0003, + + // Reserved identifier for pseudo-AEAD on contexts that only allow export + export_only = 0xffff, + }; + + template + static const AEAD& get(); + + virtual ~AEAD() = default; + + const ID id; + const size_t key_size; + const size_t nonce_size; + + virtual bytes seal(const bytes& key, + const bytes& nonce, + const bytes& aad, + const bytes& pt) const = 0; + virtual std::optional open(const bytes& key, + const bytes& nonce, + const bytes& aad, + const bytes& ct) const = 0; + +protected: + AEAD(ID id_in, size_t key_size_in, size_t nonce_size_in); +}; + +struct Context +{ + bytes do_export(const bytes& exporter_context, size_t size) const; + +protected: + bytes suite; + bytes key; + bytes nonce; + bytes exporter_secret; + const KDF& kdf; + const AEAD& aead; + + bytes current_nonce() const; + void increment_seq(); + +private: + uint64_t seq; + + Context(bytes suite_in, + bytes key_in, + bytes nonce_in, + bytes exporter_secret_in, + const KDF& kdf_in, + const AEAD& aead_in); + + friend struct HPKE; + friend struct HPKETest; + friend bool operator==(const Context& lhs, const Context& rhs); +}; + +struct SenderContext : public Context +{ + SenderContext(Context&& c); + bytes seal(const bytes& aad, const bytes& pt); +}; + +struct ReceiverContext : public Context +{ + ReceiverContext(Context&& c); + std::optional open(const bytes& aad, const bytes& ct); +}; + +struct HPKE +{ + enum struct Mode : uint8_t + { + base = 0, + psk = 1, + auth = 2, + auth_psk = 3, + }; + + HPKE(KEM::ID kem_id, KDF::ID kdf_id, AEAD::ID aead_id); + + using SenderInfo = std::pair; + + SenderInfo setup_base_s(const KEM::PublicKey& pkR, const bytes& info) const; + ReceiverContext setup_base_r(const bytes& enc, + const KEM::PrivateKey& skR, + const bytes& info) const; + + SenderInfo setup_psk_s(const KEM::PublicKey& pkR, + const bytes& info, + const bytes& psk, + const bytes& psk_id) const; + ReceiverContext setup_psk_r(const bytes& enc, + const KEM::PrivateKey& skR, + const bytes& info, + const bytes& psk, + const bytes& psk_id) const; + + SenderInfo setup_auth_s(const KEM::PublicKey& pkR, + const bytes& info, + const KEM::PrivateKey& skS) const; + ReceiverContext setup_auth_r(const bytes& enc, + const KEM::PrivateKey& skR, + const bytes& info, + const KEM::PublicKey& pkS) const; + + SenderInfo setup_auth_psk_s(const KEM::PublicKey& pkR, + const bytes& info, + const bytes& psk, + const bytes& psk_id, + const KEM::PrivateKey& skS) const; + ReceiverContext setup_auth_psk_r(const bytes& enc, + const KEM::PrivateKey& skR, + const bytes& info, + const bytes& psk, + const bytes& psk_id, + const KEM::PublicKey& pkS) const; + + bytes suite; + const KEM& kem; + const KDF& kdf; + const AEAD& aead; + +private: + static bool verify_psk_inputs(Mode mode, + const bytes& psk, + const bytes& psk_id); + Context key_schedule(Mode mode, + const bytes& shared_secret, + const bytes& info, + const bytes& psk, + const bytes& psk_id) const; +}; + +} // namespace mlspp::hpke diff --git a/mlspp/lib/hpke/include/hpke/random.h b/mlspp/lib/hpke/include/hpke/random.h new file mode 100755 index 0000000000..ad344e659c --- /dev/null +++ b/mlspp/lib/hpke/include/hpke/random.h @@ -0,0 +1,11 @@ +#pragma once + +#include +using namespace mlspp::bytes_ns; + +namespace mlspp::hpke { + +bytes +random_bytes(size_t size); + +} // namespace mlspp::hpke diff --git a/mlspp/lib/hpke/include/hpke/signature.h b/mlspp/lib/hpke/include/hpke/signature.h new file mode 100755 index 0000000000..8e5985a840 --- /dev/null +++ b/mlspp/lib/hpke/include/hpke/signature.h @@ -0,0 +1,89 @@ +#pragma once + +#include + +#include +using namespace mlspp::bytes_ns; + +namespace mlspp::hpke { + +struct Signature +{ + enum struct ID + { + P256_SHA256, + P384_SHA384, + P521_SHA512, + Ed25519, +#if !defined(WITH_BORINGSSL) + Ed448, +#endif + RSA_SHA256, + RSA_SHA384, + RSA_SHA512, + }; + + template + static const Signature& get(); + + virtual ~Signature() = default; + + struct PublicKey + { + virtual ~PublicKey() = default; + }; + + struct PrivateKey + { + virtual ~PrivateKey() = default; + virtual std::unique_ptr public_key() const = 0; + }; + + const ID id; + + virtual std::unique_ptr generate_key_pair() const = 0; + virtual std::unique_ptr derive_key_pair( + const bytes& ikm) const = 0; + + virtual bytes serialize(const PublicKey& pk) const = 0; + virtual std::unique_ptr deserialize(const bytes& enc) const = 0; + + virtual bytes serialize_private(const PrivateKey& sk) const = 0; + virtual std::unique_ptr deserialize_private( + const bytes& skm) const = 0; + + struct PrivateJWK + { + const Signature& sig; + std::optional key_id; + std::unique_ptr key; + }; + static PrivateJWK parse_jwk_private(const std::string& jwk_json); + + struct PublicJWK + { + const Signature& sig; + std::optional key_id; + std::unique_ptr key; + }; + static PublicJWK parse_jwk(const std::string& jwk_json); + + virtual std::unique_ptr import_jwk_private( + const std::string& jwk_json) const = 0; + virtual std::unique_ptr import_jwk( + const std::string& jwk_json) const = 0; + virtual std::string export_jwk_private(const PrivateKey& env) const = 0; + virtual std::string export_jwk(const PublicKey& env) const = 0; + + virtual bytes sign(const bytes& data, const PrivateKey& sk) const = 0; + virtual bool verify(const bytes& data, + const bytes& sig, + const PublicKey& pk) const = 0; + + static std::unique_ptr generate_rsa(size_t bits); + +protected: + Signature(ID id_in); +}; + +} // namespace mlspp::hpke diff --git a/mlspp/lib/hpke/include/hpke/userinfo_vc.h b/mlspp/lib/hpke/include/hpke/userinfo_vc.h new file mode 100755 index 0000000000..9f3cae27f4 --- /dev/null +++ b/mlspp/lib/hpke/include/hpke/userinfo_vc.h @@ -0,0 +1,82 @@ +#pragma once +#include +#include + +#include +#include +#include +#include + +using namespace mlspp::bytes_ns; + +namespace mlspp::hpke { + +struct UserInfoClaimsAddress +{ + std::optional formatted; + std::optional street_address; + std::optional locality; + std::optional region; + std::optional postal_code; + std::optional country; +}; + +struct UserInfoClaims +{ + + std::optional sub; + std::optional name; + std::optional given_name; + std::optional family_name; + std::optional middle_name; + std::optional nickname; + std::optional preferred_username; + std::optional profile; + std::optional picture; + std::optional website; + std::optional email; + std::optional email_verified; + std::optional gender; + std::optional birthdate; + std::optional zoneinfo; + std::optional locale; + std::optional phone_number; + std::optional phone_number_verified; + std::optional address; + std::optional updated_at; + + static UserInfoClaims from_json(const std::string& cred_subject); +}; + +struct UserInfoVC +{ +private: + struct ParsedCredential; + std::shared_ptr parsed_cred; + +public: + explicit UserInfoVC(std::string jwt); + UserInfoVC() = default; + UserInfoVC(const UserInfoVC& other) = default; + ~UserInfoVC() = default; + UserInfoVC& operator=(const UserInfoVC& other) = default; + UserInfoVC& operator=(UserInfoVC&& other) = default; + + const Signature& signature_algorithm() const; + std::string issuer() const; + std::optional key_id() const; + std::chrono::system_clock::time_point not_before() const; + std::chrono::system_clock::time_point not_after() const; + const std::string& raw_credential() const; + const UserInfoClaims& subject() const; + const Signature::PublicJWK& public_key() const; + + bool valid_from(const Signature::PublicKey& issuer_key) const; + + std::string raw; +}; + +bool +operator==(const UserInfoVC& lhs, const UserInfoVC& rhs); + +} // namespace mlspp::hpke diff --git a/mlspp/lib/hpke/src/aead_cipher.cpp b/mlspp/lib/hpke/src/aead_cipher.cpp new file mode 100755 index 0000000000..b200647789 --- /dev/null +++ b/mlspp/lib/hpke/src/aead_cipher.cpp @@ -0,0 +1,321 @@ +#include "aead_cipher.h" +#include "openssl_common.h" + +#include + +#if WITH_BORINGSSL +#include +#endif + +namespace mlspp::hpke { + +/// +/// ExportOnlyCipher +/// +bytes +ExportOnlyCipher::seal(const bytes& /* key */, + const bytes& /* nonce */, + const bytes& /* aad */, + const bytes& /* pt */) const +{ + throw std::runtime_error("seal() on export-only context"); +} + +std::optional +ExportOnlyCipher::open(const bytes& /* key */, + const bytes& /* nonce */, + const bytes& /* aad */, + const bytes& /* ct */) const +{ + throw std::runtime_error("open() on export-only context"); +} + +ExportOnlyCipher::ExportOnlyCipher() + : AEAD(AEAD::ID::export_only, 0, 0) +{ +} + +/// +/// AEADCipher +/// +AEADCipher +make_aead(AEAD::ID cipher_in) +{ + return { cipher_in }; +} + +template<> +const AEADCipher& +AEADCipher::get() +{ + static const auto instance = make_aead(AEAD::ID::AES_128_GCM); + return instance; +} + +template<> +const AEADCipher& +AEADCipher::get() +{ + static const auto instance = make_aead(AEAD::ID::AES_256_GCM); + return instance; +} + +template<> +const AEADCipher& +AEADCipher::get() +{ + static const auto instance = make_aead(AEAD::ID::CHACHA20_POLY1305); + return instance; +} + +static size_t +cipher_key_size(AEAD::ID cipher) +{ + switch (cipher) { + case AEAD::ID::AES_128_GCM: + return 16; + + case AEAD::ID::AES_256_GCM: + case AEAD::ID::CHACHA20_POLY1305: + return 32; + + default: + throw std::runtime_error("Unsupported algorithm"); + } +} + +static size_t +cipher_nonce_size(AEAD::ID cipher) +{ + switch (cipher) { + case AEAD::ID::AES_128_GCM: + case AEAD::ID::AES_256_GCM: + case AEAD::ID::CHACHA20_POLY1305: + return 12; + + default: + throw std::runtime_error("Unsupported algorithm"); + } +} + +static size_t +cipher_tag_size(AEAD::ID cipher) +{ + switch (cipher) { + case AEAD::ID::AES_128_GCM: + case AEAD::ID::AES_256_GCM: + case AEAD::ID::CHACHA20_POLY1305: + return 16; + + default: + throw std::runtime_error("Unsupported algorithm"); + } +} + +#if WITH_BORINGSSL +static const EVP_AEAD* +boringssl_cipher(AEAD::ID cipher) +{ + switch (cipher) { + case AEAD::ID::AES_128_GCM: + return EVP_aead_aes_128_gcm(); + + case AEAD::ID::AES_256_GCM: + return EVP_aead_aes_256_gcm(); + + case AEAD::ID::CHACHA20_POLY1305: + return EVP_aead_chacha20_poly1305(); + + default: + throw std::runtime_error("Unsupported algorithm"); + } +} +#else +static const EVP_CIPHER* +openssl_cipher(AEAD::ID cipher) +{ + switch (cipher) { + case AEAD::ID::AES_128_GCM: + return EVP_aes_128_gcm(); + + case AEAD::ID::AES_256_GCM: + return EVP_aes_256_gcm(); + + case AEAD::ID::CHACHA20_POLY1305: + return EVP_chacha20_poly1305(); + + default: + throw std::runtime_error("Unsupported algorithm"); + } +} +#endif // WITH_BORINGSSL + +AEADCipher::AEADCipher(AEAD::ID id_in) + : AEAD(id_in, cipher_key_size(id_in), cipher_nonce_size(id_in)) + , tag_size(cipher_tag_size(id)) +{ +} + +bytes +AEADCipher::seal(const bytes& key, + const bytes& nonce, + const bytes& aad, + const bytes& pt) const +{ +#if WITH_BORINGSSL + auto ctx = make_typed_unique( + EVP_AEAD_CTX_new(boringssl_cipher(id), key.data(), key.size(), tag_size)); + if (ctx == nullptr) { + throw openssl_error(); + } + + auto ct = bytes(pt.size() + tag_size); + auto out_len = ct.size(); + if (1 != EVP_AEAD_CTX_seal(ctx.get(), + ct.data(), + &out_len, + ct.size(), + nonce.data(), + nonce.size(), + pt.data(), + pt.size(), + aad.data(), + aad.size())) { + throw openssl_error(); + } + + return ct; +#else + auto ctx = make_typed_unique(EVP_CIPHER_CTX_new()); + if (ctx == nullptr) { + throw openssl_error(); + } + + const auto* cipher = openssl_cipher(id); + if (1 != EVP_EncryptInit(ctx.get(), cipher, key.data(), nonce.data())) { + throw openssl_error(); + } + + int outlen = 0; + if (!aad.empty()) { + if (1 != EVP_EncryptUpdate(ctx.get(), + nullptr, + &outlen, + aad.data(), + static_cast(aad.size()))) { + throw openssl_error(); + } + } + + bytes ct(pt.size()); + if (1 != EVP_EncryptUpdate(ctx.get(), + ct.data(), + &outlen, + pt.data(), + static_cast(pt.size()))) { + throw openssl_error(); + } + + // Providing nullptr as an argument is safe here because this + // function never writes with GCM; it only computes the tag + if (1 != EVP_EncryptFinal(ctx.get(), nullptr, &outlen)) { + throw openssl_error(); + } + + bytes tag(tag_size); + if (1 != EVP_CIPHER_CTX_ctrl(ctx.get(), + EVP_CTRL_GCM_GET_TAG, + static_cast(tag_size), + tag.data())) { + throw openssl_error(); + } + + ct += tag; + return ct; +#endif // WITH_BORINGSSL +} + +std::optional +AEADCipher::open(const bytes& key, + const bytes& nonce, + const bytes& aad, + const bytes& ct) const +{ + if (ct.size() < tag_size) { + throw std::runtime_error("AEAD ciphertext smaller than tag size"); + } + +#if WITH_BORINGSSL + auto ctx = make_typed_unique(EVP_AEAD_CTX_new( + boringssl_cipher(id), key.data(), key.size(), cipher_tag_size(id))); + if (ctx == nullptr) { + throw openssl_error(); + } + + auto pt = bytes(ct.size() - tag_size); + auto out_len = pt.size(); + if (1 != EVP_AEAD_CTX_open(ctx.get(), + pt.data(), + &out_len, + pt.size(), + nonce.data(), + nonce.size(), + ct.data(), + ct.size(), + aad.data(), + aad.size())) { + throw openssl_error(); + } + + return pt; +#else + auto ctx = make_typed_unique(EVP_CIPHER_CTX_new()); + if (ctx == nullptr) { + throw openssl_error(); + } + + const auto* cipher = openssl_cipher(id); + if (1 != EVP_DecryptInit(ctx.get(), cipher, key.data(), nonce.data())) { + throw openssl_error(); + } + + auto inner_ct_size = ct.size() - tag_size; + auto tag = ct.slice(inner_ct_size, ct.size()); + if (1 != EVP_CIPHER_CTX_ctrl(ctx.get(), + EVP_CTRL_GCM_SET_TAG, + static_cast(tag_size), + tag.data())) { + throw openssl_error(); + } + + int out_size = 0; + if (!aad.empty()) { + if (1 != EVP_DecryptUpdate(ctx.get(), + nullptr, + &out_size, + aad.data(), + static_cast(aad.size()))) { + throw openssl_error(); + } + } + + bytes pt(inner_ct_size); + if (1 != EVP_DecryptUpdate(ctx.get(), + pt.data(), + &out_size, + ct.data(), + static_cast(inner_ct_size))) { + throw openssl_error(); + } + + // Providing nullptr as an argument is safe here because this + // function never writes with GCM; it only verifies the tag + if (1 != EVP_DecryptFinal(ctx.get(), nullptr, &out_size)) { + throw std::runtime_error("AEAD authentication failure"); + } + + return pt; +#endif // WITH_BORINGSSL +} + +} // namespace mlspp::hpke diff --git a/mlspp/lib/hpke/src/aead_cipher.h b/mlspp/lib/hpke/src/aead_cipher.h new file mode 100755 index 0000000000..32a36964fc --- /dev/null +++ b/mlspp/lib/hpke/src/aead_cipher.h @@ -0,0 +1,45 @@ +#pragma once + +#include + +namespace mlspp::hpke { + +struct ExportOnlyCipher : public AEAD +{ + ExportOnlyCipher(); + ~ExportOnlyCipher() override = default; + + bytes seal(const bytes& key, + const bytes& nonce, + const bytes& aad, + const bytes& pt) const override; + std::optional open(const bytes& key, + const bytes& nonce, + const bytes& aad, + const bytes& ct) const override; +}; + +struct AEADCipher : public AEAD +{ + template + static const AEADCipher& get(); + + ~AEADCipher() override = default; + + bytes seal(const bytes& key, + const bytes& nonce, + const bytes& aad, + const bytes& pt) const override; + std::optional open(const bytes& key, + const bytes& nonce, + const bytes& aad, + const bytes& ct) const override; + +private: + const size_t tag_size; + + AEADCipher(AEAD::ID id_in); + friend AEADCipher make_aead(AEAD::ID cipher_in); +}; + +} // namespace mlspp::hpke diff --git a/mlspp/lib/hpke/src/base64.cpp b/mlspp/lib/hpke/src/base64.cpp new file mode 100755 index 0000000000..7db48df75b --- /dev/null +++ b/mlspp/lib/hpke/src/base64.cpp @@ -0,0 +1,105 @@ +#include + +#include "openssl_common.h" + +#include +#include +#include + +namespace mlspp::hpke { + +std::string +to_base64(const bytes& data) +{ + if (data.empty()) { + return ""; + } + +#if WITH_BORINGSSL + const auto data_size = data.size(); +#else + const auto data_size = static_cast(data.size()); +#endif + + // base64 encoding produces 4 characters for every 3 input bytes (rounded up) + const auto out_size = (data_size + 2) / 3 * 4; + auto out = bytes(out_size + 1); // NUL terminator + + const auto result = EVP_EncodeBlock(out.data(), data.data(), data_size); + if (result != out_size) { + throw openssl_error(); + } + + out.resize(out.size() - 1); // strip NUL terminator + return to_ascii(out); +} + +std::string +to_base64url(const bytes& data) +{ + if (data.empty()) { + return ""; + } + + auto encoded = to_base64(data); + + auto pad_start = encoded.find_first_of('='); + if (pad_start != std::string::npos) { + encoded = encoded.substr(0, pad_start); + } + + std::replace(encoded.begin(), encoded.end(), '+', '-'); + std::replace(encoded.begin(), encoded.end(), '/', '_'); + + return encoded; +} + +bytes +from_base64(const std::string& enc) +{ + if (enc.length() == 0) { + return {}; + } + + if (enc.length() % 4 != 0) { + throw std::runtime_error("Base64 length is not divisible by 4"); + } + + const auto in = from_ascii(enc); + const auto in_size = static_cast(in.size()); + const auto out_size = in_size / 4 * 3; + auto out = bytes(out_size); + + const auto result = EVP_DecodeBlock(out.data(), in.data(), in_size); + if (result != out_size) { + throw openssl_error(); + } + + if (enc.substr(enc.length() - 2, enc.length()) == "==") { + out.resize(out.size() - 2); + } else if (enc.substr(enc.length() - 1, enc.length()) == "=") { + out.resize(out.size() - 1); + } + + return out; +} + +bytes +from_base64url(const std::string& enc) +{ + if (enc.empty()) { + return {}; + } + + auto enc_copy = enc; + std::replace(enc_copy.begin(), enc_copy.end(), '-', '+'); + std::replace(enc_copy.begin(), enc_copy.end(), '_', '/'); + + while (enc_copy.length() % 4 != 0) { + enc_copy += "="; + } + + return from_base64(enc_copy); +} + +} // namespace mlspp::hpke diff --git a/mlspp/lib/hpke/src/certificate.cpp b/mlspp/lib/hpke/src/certificate.cpp new file mode 100755 index 0000000000..f4b72dee37 --- /dev/null +++ b/mlspp/lib/hpke/src/certificate.cpp @@ -0,0 +1,539 @@ +#include "group.h" +#include "openssl_common.h" +#include "rsa.h" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mlspp::hpke { +/// +/// Utility functions +/// + +static std::optional +asn1_octet_string_to_bytes(const ASN1_OCTET_STRING* octets) +{ + if (octets == nullptr) { + return std::nullopt; + } + const auto* ptr = ASN1_STRING_get0_data(octets); + const auto len = ASN1_STRING_length(octets); + // NOLINTNEXTLINE (cppcoreguidelines-pro-bounds-pointer-arithmetic) + return std::vector(ptr, ptr + len); +} + +static std::string +asn1_string_to_std_string(const ASN1_STRING* asn1_string) +{ + const auto* data = ASN1_STRING_get0_data(asn1_string); + const auto data_size = static_cast(ASN1_STRING_length(asn1_string)); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + auto str = std::string(reinterpret_cast(data)); + if (str.size() != data_size) { + throw std::runtime_error("Malformed ASN.1 string"); + } + return str; +} + +static std::chrono::system_clock::time_point +asn1_time_to_chrono(const ASN1_TIME* asn1_time) +{ + auto epoch_chrono = std::chrono::system_clock::time_point(); + auto epoch_time_t = std::chrono::system_clock::to_time_t(epoch_chrono); + auto epoch_asn1 = make_typed_unique(ASN1_TIME_set(nullptr, epoch_time_t)); + if (!epoch_asn1) { + throw openssl_error(); + } + + auto secs = int(0); + auto days = int(0); + if (ASN1_TIME_diff(&days, &secs, epoch_asn1.get(), asn1_time) != 1) { + throw openssl_error(); + } + + auto delta = std::chrono::seconds(secs) + std::chrono::hours(24 * days); + return std::chrono::system_clock::time_point(delta); +} + +/// +/// ParsedCertificate +/// + +const int Certificate::NameType::organization = NID_organizationName; +const int Certificate::NameType::common_name = NID_commonName; +const int Certificate::NameType::organizational_unit = + NID_organizationalUnitName; +const int Certificate::NameType::country = NID_countryName; +const int Certificate::NameType::serial_number = NID_serialNumber; +const int Certificate::NameType::state_or_province_name = + NID_stateOrProvinceName; + +struct RFC822Name +{ + std::string value; +}; + +struct DNSName +{ + std::string value; +}; + +using GeneralName = tls::var::variant; + +struct Certificate::ParsedCertificate +{ + + static std::unique_ptr parse(const bytes& der) + { + const auto* buf = der.data(); + auto cert = + make_typed_unique(d2i_X509(nullptr, &buf, static_cast(der.size()))); + if (cert == nullptr) { + throw openssl_error(); + } + + return std::make_unique(cert.release()); + } + + static bytes compute_digest(const X509* cert) + { + const auto* md = EVP_sha256(); + auto digest = bytes(EVP_MD_size(md)); + unsigned int out_size = 0; + if (1 != X509_digest(cert, md, digest.data(), &out_size)) { + throw openssl_error(); + } + return digest; + } + + // Note: This method does not implement total general name parsing. + // Duplicate entries are not supported; if they are present, the last one + // presented by OpenSSL is chosen. + static ParsedName parse_names(const X509_NAME* x509_name) + { + if (x509_name == nullptr) { + throw openssl_error(); + } + + ParsedName parsed_name; + + for (int i = X509_NAME_entry_count(x509_name) - 1; i >= 0; i--) { + auto* entry = X509_NAME_get_entry(x509_name, i); + if (entry == nullptr) { + continue; + } + + auto* oid = X509_NAME_ENTRY_get_object(entry); + auto* asn_str = X509_NAME_ENTRY_get_data(entry); + if (oid == nullptr || asn_str == nullptr) { + continue; + } + + const int nid = OBJ_obj2nid(oid); + const std::string parsed_value = asn1_string_to_std_string(asn_str); + parsed_name[nid] = parsed_value; + } + + return parsed_name; + } + + // Parse Subject Key Identifier Extension + static std::optional parse_skid(X509* cert) + { + return asn1_octet_string_to_bytes(X509_get0_subject_key_id(cert)); + } + + // Parse Authority Key Identifier + static std::optional parse_akid(X509* cert) + { + return asn1_octet_string_to_bytes(X509_get0_authority_key_id(cert)); + } + + static std::vector parse_san(X509* cert) + { + std::vector names; + +#ifdef WITH_BORINGSSL + using san_names_nb_t = size_t; +#else + using san_names_nb_t = int; +#endif + + san_names_nb_t san_names_nb = 0; + + auto* ext_ptr = + X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + auto* san_ptr = reinterpret_cast(ext_ptr); + const auto san_names = make_typed_unique(san_ptr); + san_names_nb = sk_GENERAL_NAME_num(san_names.get()); + + // Check each name within the extension + for (san_names_nb_t i = 0; i < san_names_nb; i++) { + auto* current_name = sk_GENERAL_NAME_value(san_names.get(), i); + if (current_name->type == GEN_DNS) { + const auto dns_name = + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-union-access) + asn1_string_to_std_string(current_name->d.dNSName); + names.emplace_back(DNSName{ dns_name }); + } else if (current_name->type == GEN_EMAIL) { + const auto email = + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-union-access + asn1_string_to_std_string(current_name->d.rfc822Name); + names.emplace_back(RFC822Name{ email }); + } + } + + return names; + } + + explicit ParsedCertificate(X509* x509_in) + : x509(x509_in, typed_delete) + , pub_key_id(public_key_algorithm(x509.get())) + , sig_algo(signature_algorithm(x509.get())) + , issuer_hash(X509_issuer_name_hash(x509.get())) + , subject_hash(X509_subject_name_hash(x509.get())) + , issuer(parse_names(X509_get_issuer_name(x509.get()))) + , subject(parse_names(X509_get_subject_name(x509.get()))) + , subject_key_id(parse_skid(x509.get())) + , authority_key_id(parse_akid(x509.get())) + , sub_alt_names(parse_san(x509.get())) + , is_ca(X509_check_ca(x509.get()) != 0) + , hash(compute_digest(x509.get())) + , not_before(asn1_time_to_chrono(X509_get0_notBefore(x509.get()))) + , not_after(asn1_time_to_chrono(X509_get0_notAfter(x509.get()))) + { + } + + ParsedCertificate(const ParsedCertificate& other) + : x509(nullptr, typed_delete) + , pub_key_id(public_key_algorithm(other.x509.get())) + , sig_algo(signature_algorithm(other.x509.get())) + , issuer_hash(other.issuer_hash) + , subject_hash(other.subject_hash) + , issuer(other.issuer) + , subject(other.subject) + , subject_key_id(other.subject_key_id) + , authority_key_id(other.authority_key_id) + , sub_alt_names(other.sub_alt_names) + , is_ca(other.is_ca) + , hash(other.hash) + , not_before(other.not_before) + , not_after(other.not_after) + { + if (1 != X509_up_ref(other.x509.get())) { + throw openssl_error(); + } + x509.reset(other.x509.get()); + } + + static Signature::ID public_key_algorithm(X509* x509) + { +#if WITH_BORINGSSL + const auto pub = make_typed_unique(X509_get_pubkey(x509)); + const auto* pub_ptr = pub.get(); +#else + const auto* pub_ptr = X509_get0_pubkey(x509); +#endif + + switch (EVP_PKEY_base_id(pub_ptr)) { + case EVP_PKEY_ED25519: + return Signature::ID::Ed25519; +#if !defined(WITH_BORINGSSL) + case EVP_PKEY_ED448: + return Signature::ID::Ed448; +#endif + case EVP_PKEY_EC: { + auto key_size = EVP_PKEY_bits(pub_ptr); + switch (key_size) { + case 256: + return Signature::ID::P256_SHA256; + case 384: + return Signature::ID::P384_SHA384; + case 521: + return Signature::ID::P521_SHA512; + default: + throw std::runtime_error("Unknown curve"); + } + } + case EVP_PKEY_RSA: + // RSA public keys are not specific to an algorithm + return Signature::ID::RSA_SHA256; + default: + break; + } + throw std::runtime_error("Unsupported public key algorithm"); + } + + static Signature::ID signature_algorithm(X509* cert) + { + auto nid = X509_get_signature_nid(cert); + switch (nid) { + case EVP_PKEY_ED25519: + return Signature::ID::Ed25519; +#if !defined(WITH_BORINGSSL) + case EVP_PKEY_ED448: + return Signature::ID::Ed448; +#endif + case NID_ecdsa_with_SHA256: + return Signature::ID::P256_SHA256; + case NID_ecdsa_with_SHA384: + return Signature::ID::P384_SHA384; + case NID_ecdsa_with_SHA512: + return Signature::ID::P521_SHA512; + case NID_sha1WithRSAEncryption: + // We fall through to SHA256 for SHA1 because we do not implement SHA-1. + case NID_sha256WithRSAEncryption: + return Signature::ID::RSA_SHA256; + case NID_sha384WithRSAEncryption: + return Signature::ID::RSA_SHA384; + case NID_sha512WithRSAEncryption: + return Signature::ID::RSA_SHA512; + default: + break; + } + + throw std::runtime_error("Unsupported signature algorithm"); + } + + typed_unique_ptr public_key() const + { + return make_typed_unique(X509_get_pubkey(x509.get())); + } + + Certificate::ExpirationStatus expiration_status() const + { + auto now = std::chrono::system_clock::now(); + + if (now < not_before) { + return Certificate::ExpirationStatus::inactive; + } + + if (now > not_after) { + return Certificate::ExpirationStatus::expired; + } + + return Certificate::ExpirationStatus::active; + } + + bytes raw() const + { + auto out = bytes(i2d_X509(x509.get(), nullptr)); + auto* ptr = out.data(); + i2d_X509(x509.get(), &ptr); + return out; + } + + typed_unique_ptr x509; + const Signature::ID pub_key_id; + const Signature::ID sig_algo; + const uint64_t issuer_hash; + const uint64_t subject_hash; + const ParsedName issuer; + const ParsedName subject; + const std::optional subject_key_id; + const std::optional authority_key_id; + const std::vector sub_alt_names; + const bool is_ca; + const bytes hash; + const std::chrono::system_clock::time_point not_before; + const std::chrono::system_clock::time_point not_after; +}; + +/// +/// Certificate +/// + +static std::unique_ptr +signature_key(EVP_PKEY* pkey) +{ + switch (EVP_PKEY_base_id(pkey)) { + case EVP_PKEY_RSA: + return std::make_unique(pkey); + + case EVP_PKEY_ED448: + case EVP_PKEY_ED25519: + case EVP_PKEY_EC: + return std::make_unique(pkey); + + default: + throw std::runtime_error("Unsupported algorithm"); + } +} + +Certificate::Certificate(std::unique_ptr&& parsed_cert_in) + : parsed_cert(std::move(parsed_cert_in)) + , public_key(signature_key(parsed_cert->public_key().release())) + , raw(parsed_cert->raw()) +{ +} + +Certificate::Certificate(const bytes& der) + : parsed_cert(ParsedCertificate::parse(der)) + , public_key(signature_key(parsed_cert->public_key().release())) + , raw(der) +{ +} + +Certificate::Certificate(const Certificate& other) + : parsed_cert(std::make_unique(*other.parsed_cert)) + , public_key(signature_key(parsed_cert->public_key().release())) + , raw(other.raw) +{ +} + +Certificate::~Certificate() = default; + +std::vector +Certificate::parse_pem(const bytes& pem) +{ + auto size_int = static_cast(pem.size()); + auto bio = make_typed_unique(BIO_new_mem_buf(pem.data(), size_int)); + if (!bio) { + throw openssl_error(); + } + + auto certs = std::vector(); + while (true) { + auto x509 = make_typed_unique( + PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr)); + if (!x509) { + // NOLINTNEXTLINE(hicpp-signed-bitwise) + auto err = ERR_GET_REASON(ERR_peek_last_error()); + if (err == PEM_R_NO_START_LINE) { + // No more objects to read + break; + } + + throw openssl_error(); + } + + auto parsed = std::make_unique(x509.release()); + certs.emplace_back(std::move(parsed)); + } + + return certs; +} + +bool +Certificate::valid_from(const Certificate& parent) const +{ + auto pub = parent.parsed_cert->public_key(); + return (1 == X509_verify(parsed_cert->x509.get(), pub.get())); +} + +uint64_t +Certificate::issuer_hash() const +{ + return parsed_cert->issuer_hash; +} + +uint64_t +Certificate::subject_hash() const +{ + return parsed_cert->subject_hash; +} + +Certificate::ParsedName +Certificate::subject() const +{ + return parsed_cert->subject; +} + +Certificate::ParsedName +Certificate::issuer() const +{ + return parsed_cert->issuer; +} + +bool +Certificate::is_ca() const +{ + return parsed_cert->is_ca; +} + +Certificate::ExpirationStatus +Certificate::expiration_status() const +{ + return parsed_cert->expiration_status(); +} + +std::optional +Certificate::subject_key_id() const +{ + return parsed_cert->subject_key_id; +} + +std::optional +Certificate::authority_key_id() const +{ + return parsed_cert->authority_key_id; +} + +std::vector +Certificate::email_addresses() const +{ + std::vector emails; + for (const auto& name : parsed_cert->sub_alt_names) { + if (tls::var::holds_alternative(name)) { + emails.emplace_back(tls::var::get(name).value); + } + } + return emails; +} + +std::vector +Certificate::dns_names() const +{ + std::vector domains; + for (const auto& name : parsed_cert->sub_alt_names) { + if (tls::var::holds_alternative(name)) { + domains.emplace_back(tls::var::get(name).value); + } + } + + return domains; +} + +bytes +Certificate::hash() const +{ + return parsed_cert->hash; +} + +std::chrono::system_clock::time_point +Certificate::not_before() const +{ + return parsed_cert->not_before; +} + +std::chrono::system_clock::time_point +Certificate::not_after() const +{ + return parsed_cert->not_after; +} + +Signature::ID +Certificate::public_key_algorithm() const +{ + return parsed_cert->pub_key_id; +} + +Signature::ID +Certificate::signature_algorithm() const +{ + return parsed_cert->sig_algo; +} + +bool +operator==(const Certificate& lhs, const Certificate& rhs) +{ + return lhs.raw == rhs.raw; +} + +} // namespace mlspp::hpke diff --git a/mlspp/lib/hpke/src/common.cpp b/mlspp/lib/hpke/src/common.cpp new file mode 100755 index 0000000000..dabbfd6286 --- /dev/null +++ b/mlspp/lib/hpke/src/common.cpp @@ -0,0 +1,20 @@ +#include "common.h" + +namespace mlspp::hpke { + +bytes +i2osp(uint64_t val, size_t size) +{ + auto out = bytes(size, 0); + auto max = size; + if (size > 8) { + max = 8; + } + + for (size_t i = 0; i < max; i++) { + out.at(size - i - 1) = static_cast(val >> (8 * i)); + } + return out; +} + +} // namespace mlspp::hpke diff --git a/mlspp/lib/hpke/src/common.h b/mlspp/lib/hpke/src/common.h new file mode 100755 index 0000000000..d6938f6a1f --- /dev/null +++ b/mlspp/lib/hpke/src/common.h @@ -0,0 +1,10 @@ +#pragma once + +#include + +namespace mlspp::hpke { + +bytes +i2osp(uint64_t val, size_t size); + +} // namespace mlspp::hpke diff --git a/mlspp/lib/hpke/src/dhkem.cpp b/mlspp/lib/hpke/src/dhkem.cpp new file mode 100755 index 0000000000..fb23c568cc --- /dev/null +++ b/mlspp/lib/hpke/src/dhkem.cpp @@ -0,0 +1,216 @@ +#include "dhkem.h" + +#include "common.h" + +namespace mlspp::hpke { + +DHKEM::PrivateKey::PrivateKey(Group::PrivateKey* group_priv_in) + : group_priv(group_priv_in) +{ +} + +std::unique_ptr +DHKEM::PrivateKey::public_key() const +{ + return group_priv->public_key(); +} + +DHKEM +make_dhkem(KEM::ID kem_id_in, const Group& group_in, const KDF& kdf_in) +{ + return { kem_id_in, group_in, kdf_in }; +} + +template<> +const DHKEM& +DHKEM::get() +{ + static const auto instance = make_dhkem(KEM::ID::DHKEM_P256_SHA256, + Group::get(), + KDF::get()); + return instance; +} + +template<> +const DHKEM& +DHKEM::get() +{ + static const auto instance = make_dhkem(KEM::ID::DHKEM_P384_SHA384, + Group::get(), + KDF::get()); + return instance; +} + +template<> +const DHKEM& +DHKEM::get() +{ + static const auto instance = make_dhkem(KEM::ID::DHKEM_P521_SHA512, + Group::get(), + KDF::get()); + return instance; +} + +template<> +const DHKEM& +DHKEM::get() +{ + static const auto instance = make_dhkem(KEM::ID::DHKEM_X25519_SHA256, + Group::get(), + KDF::get()); + return instance; +} + +#if !defined(WITH_BORINGSSL) +template<> +const DHKEM& +DHKEM::get() +{ + static const auto instance = make_dhkem(KEM::ID::DHKEM_X448_SHA512, + Group::get(), + KDF::get()); + return instance; +} +#endif + +DHKEM::DHKEM(KEM::ID kem_id_in, const Group& group_in, const KDF& kdf_in) + : KEM(kem_id_in, + kdf_in.hash_size, + group_in.pk_size, + group_in.pk_size, + group_in.sk_size) + , group(group_in) + , kdf(kdf_in) +{ + static const auto label_kem = from_ascii("KEM"); + suite_id = label_kem + i2osp(uint16_t(kem_id_in), 2); +} + +std::unique_ptr +DHKEM::generate_key_pair() const +{ + return std::make_unique( + group.generate_key_pair().release()); +} + +std::unique_ptr +DHKEM::derive_key_pair(const bytes& ikm) const +{ + return std::make_unique( + group.derive_key_pair(suite_id, ikm).release()); +} + +bytes +DHKEM::serialize(const KEM::PublicKey& pk) const +{ + const auto& gpk = dynamic_cast(pk); + return group.serialize(gpk); +} + +std::unique_ptr +DHKEM::deserialize(const bytes& enc) const +{ + return group.deserialize(enc); +} + +bytes +DHKEM::serialize_private(const KEM::PrivateKey& sk) const +{ + const auto& gsk = dynamic_cast(sk); + return group.serialize_private(*gsk.group_priv); +} + +std::unique_ptr +DHKEM::deserialize_private(const bytes& skm) const +{ + return std::make_unique(group.deserialize_private(skm).release()); +} + +std::pair +DHKEM::encap(const KEM::PublicKey& pkR) const +{ + const auto& gpkR = dynamic_cast(pkR); + + auto skE = group.generate_key_pair(); + auto pkE = skE->public_key(); + + auto zz = group.dh(*skE, gpkR); + auto enc = group.serialize(*pkE); + + auto pkRm = group.serialize(gpkR); + auto kem_context = enc + pkRm; + + auto shared_secret = extract_and_expand(zz, kem_context); + return std::make_pair(shared_secret, enc); +} + +bytes +DHKEM::decap(const bytes& enc, const KEM::PrivateKey& skR) const +{ + const auto& gskR = dynamic_cast(skR); + auto pkR = gskR.group_priv->public_key(); + auto pkE = group.deserialize(enc); + auto zz = group.dh(*gskR.group_priv, *pkE); + + auto pkRm = group.serialize(*pkR); + auto kem_context = enc + pkRm; + return extract_and_expand(zz, kem_context); +} + +std::pair +DHKEM::auth_encap(const KEM::PublicKey& pkR, const KEM::PrivateKey& skS) const +{ + const auto& gpkR = dynamic_cast(pkR); + const auto& gskS = dynamic_cast(skS); + + auto skE = group.generate_key_pair(); + auto pkE = skE->public_key(); + auto pkS = gskS.group_priv->public_key(); + + auto zzER = group.dh(*skE, gpkR); + auto zzSR = group.dh(*gskS.group_priv, gpkR); + auto zz = zzER + zzSR; + auto enc = group.serialize(*pkE); + + auto pkRm = group.serialize(gpkR); + auto pkSm = group.serialize(*pkS); + auto kem_context = enc + pkRm + pkSm; + + auto shared_secret = extract_and_expand(zz, kem_context); + return std::make_pair(shared_secret, enc); +} + +bytes +DHKEM::auth_decap(const bytes& enc, + const KEM::PublicKey& pkS, + const KEM::PrivateKey& skR) const +{ + const auto& gpkS = dynamic_cast(pkS); + const auto& gskR = dynamic_cast(skR); + + auto pkE = group.deserialize(enc); + auto pkR = gskR.group_priv->public_key(); + + auto zzER = group.dh(*gskR.group_priv, *pkE); + auto zzSR = group.dh(*gskR.group_priv, gpkS); + auto zz = zzER + zzSR; + + auto pkRm = group.serialize(*pkR); + auto pkSm = group.serialize(gpkS); + auto kem_context = enc + pkRm + pkSm; + + return extract_and_expand(zz, kem_context); +} + +bytes +DHKEM::extract_and_expand(const bytes& dh, const bytes& kem_context) const +{ + static const auto label_eae_prk = from_ascii("eae_prk"); + static const auto label_shared_secret = from_ascii("shared_secret"); + + auto eae_prk = kdf.labeled_extract(suite_id, {}, label_eae_prk, dh); + return kdf.labeled_expand( + suite_id, eae_prk, label_shared_secret, kem_context, secret_size); +} + +} // namespace mlspp::hpke diff --git a/mlspp/lib/hpke/src/dhkem.h b/mlspp/lib/hpke/src/dhkem.h new file mode 100755 index 0000000000..5a0feb9fca --- /dev/null +++ b/mlspp/lib/hpke/src/dhkem.h @@ -0,0 +1,57 @@ +#pragma once + +#include + +#include "group.h" + +namespace mlspp::hpke { + +struct DHKEM : public KEM +{ + struct PrivateKey : public KEM::PrivateKey + { + PrivateKey(Group::PrivateKey* group_priv_in); + std::unique_ptr public_key() const override; + + std::unique_ptr group_priv; + }; + + template + static const DHKEM& get(); + + ~DHKEM() override = default; + + std::unique_ptr generate_key_pair() const override; + std::unique_ptr derive_key_pair( + const bytes& ikm) const override; + + bytes serialize(const KEM::PublicKey& pk) const override; + std::unique_ptr deserialize(const bytes& enc) const override; + + bytes serialize_private(const KEM::PrivateKey& sk) const override; + std::unique_ptr deserialize_private( + const bytes& skm) const override; + + std::pair encap(const KEM::PublicKey& pk) const override; + bytes decap(const bytes& enc, const KEM::PrivateKey& sk) const override; + + std::pair auth_encap(const KEM::PublicKey& pkR, + const KEM::PrivateKey& skS) const override; + bytes auth_decap(const bytes& enc, + const KEM::PublicKey& pkS, + const KEM::PrivateKey& skR) const override; + +private: + const Group& group; + const KDF& kdf; + bytes suite_id; + + bytes extract_and_expand(const bytes& dh, const bytes& kem_context) const; + + DHKEM(KEM::ID kem_id_in, const Group& group_in, const KDF& kdf_in); + friend DHKEM make_dhkem(KEM::ID kem_id_in, + const Group& group_in, + const KDF& kdf_in); +}; + +} // namespace mlspp::hpke diff --git a/mlspp/lib/hpke/src/digest.cpp b/mlspp/lib/hpke/src/digest.cpp new file mode 100755 index 0000000000..28b301bc1d --- /dev/null +++ b/mlspp/lib/hpke/src/digest.cpp @@ -0,0 +1,187 @@ +#include + +#include +#include +#if defined(WITH_OPENSSL3) +#include +#endif + +#include "openssl_common.h" + +namespace mlspp::hpke { + +static const EVP_MD* +openssl_digest_type(Digest::ID digest) +{ + switch (digest) { + case Digest::ID::SHA256: + return EVP_sha256(); + + case Digest::ID::SHA384: + return EVP_sha384(); + + case Digest::ID::SHA512: + return EVP_sha512(); + + default: + throw std::runtime_error("Unsupported ciphersuite"); + } +} + +#if defined(WITH_OPENSSL3) +static std::string +openssl_digest_name(Digest::ID digest) +{ + switch (digest) { + case Digest::ID::SHA256: + return OSSL_DIGEST_NAME_SHA2_256; + + case Digest::ID::SHA384: + return OSSL_DIGEST_NAME_SHA2_384; + + case Digest::ID::SHA512: + return OSSL_DIGEST_NAME_SHA2_512; + + default: + throw std::runtime_error("Unsupported digest algorithm"); + } +} +#endif + +template<> +const Digest& +Digest::get() +{ + static const Digest instance(Digest::ID::SHA256); + return instance; +} + +template<> +const Digest& +Digest::get() +{ + static const Digest instance(Digest::ID::SHA384); + return instance; +} + +template<> +const Digest& +Digest::get() +{ + static const Digest instance(Digest::ID::SHA512); + return instance; +} + +Digest::Digest(Digest::ID id_in) + : id(id_in) + , hash_size(EVP_MD_size(openssl_digest_type(id_in))) +{ +} + +bytes +Digest::hash(const bytes& data) const +{ + auto md = bytes(hash_size); + unsigned int size = 0; + const auto* type = openssl_digest_type(id); + if (1 != + EVP_Digest(data.data(), data.size(), md.data(), &size, type, nullptr)) { + throw openssl_error(); + } + + return md; +} + +bytes +Digest::hmac(const bytes& key, const bytes& data) const +{ + auto md = bytes(hash_size); + unsigned int size = 0; + const auto* type = openssl_digest_type(id); + if (nullptr == HMAC(type, + key.data(), + static_cast(key.size()), + data.data(), + static_cast(data.size()), + md.data(), + &size)) { + throw openssl_error(); + } + + return md; +} + +bytes +Digest::hmac_for_hkdf_extract(const bytes& key, const bytes& data) const +{ +#if defined(WITH_OPENSSL3) + auto digest_name = openssl_digest_name(id); + std::array params = { + OSSL_PARAM_construct_utf8_string( + OSSL_ALG_PARAM_DIGEST, digest_name.data(), 0), + OSSL_PARAM_construct_end() + }; + const auto mac = + make_typed_unique(EVP_MAC_fetch(nullptr, OSSL_MAC_NAME_HMAC, nullptr)); + const auto ctx = make_typed_unique(EVP_MAC_CTX_new(mac.get())); +#else + const auto* type = openssl_digest_type(id); + auto ctx = make_typed_unique(HMAC_CTX_new()); +#endif + if (ctx == nullptr) { + throw openssl_error(); + } + + // Some FIPS-enabled libraries are overly conservative in their interpretation + // of NIST SP 800-131A, which requires HMAC keys to be at least 112 bits long. + // That document does not impose that requirement on HKDF, so we disable FIPS + // enforcement for purposes of HKDF. + // + // https://doi.org/10.6028/NIST.SP.800-131Ar2 + auto key_size = static_cast(key.size()); + // OpenSSL 3 does not support the flag EVP_MD_CTX_FLAG_NON_FIPS_ALLOW anymore. + // However, OpenSSL 3 in FIPS mode doesn't seem to check the HMAC key size + // constraint. +#if !defined(WITH_OPENSSL3) && !defined(WITH_BORINGSSL) + static const auto fips_min_hmac_key_len = 14; + if (FIPS_mode() != 0 && key_size < fips_min_hmac_key_len) { + HMAC_CTX_set_flags(ctx.get(), EVP_MD_CTX_FLAG_NON_FIPS_ALLOW); + } +#endif + + // Guard against sending nullptr to HMAC_Init_ex + const auto* key_data = key.data(); + const auto non_null_zero_length_key = uint8_t(0); + if (key_data == nullptr) { + key_data = &non_null_zero_length_key; + } + + auto md = bytes(hash_size); +#if defined(WITH_OPENSSL3) + if (1 != EVP_MAC_init(ctx.get(), key_data, key_size, params.data())) { + throw openssl_error(); + } + if (1 != EVP_MAC_update(ctx.get(), data.data(), data.size())) { + throw openssl_error(); + } + size_t size = 0; + if (1 != EVP_MAC_final(ctx.get(), md.data(), &size, hash_size)) { + throw openssl_error(); + } +#else + if (1 != HMAC_Init_ex(ctx.get(), key_data, key_size, type, nullptr)) { + throw openssl_error(); + } + if (1 != HMAC_Update(ctx.get(), data.data(), data.size())) { + throw openssl_error(); + } + unsigned int size = 0; + if (1 != HMAC_Final(ctx.get(), md.data(), &size)) { + throw openssl_error(); + } +#endif + + return md; +} + +} // namespace mlspp::hpke diff --git a/mlspp/lib/hpke/src/group.cpp b/mlspp/lib/hpke/src/group.cpp new file mode 100755 index 0000000000..f9f9d7744c --- /dev/null +++ b/mlspp/lib/hpke/src/group.cpp @@ -0,0 +1,1077 @@ +#include "group.h" + +#include + +#include "common.h" +#include "openssl_common.h" + +#include "openssl/bn.h" +#include "openssl/ec.h" +#include "openssl/evp.h" +#include "openssl/obj_mac.h" +#if defined(WITH_OPENSSL3) +#include "openssl/core_names.h" +#include "openssl/param_build.h" +#endif + +namespace mlspp::hpke { + +static inline size_t +group_dh_size(Group::ID group_id); + +static inline const EVP_MD* +group_sig_digest(Group::ID group_id) +{ + switch (group_id) { + case Group::ID::P256: + return EVP_sha256(); + case Group::ID::P384: + return EVP_sha384(); + case Group::ID::P521: + return EVP_sha512(); + + // EdDSA does its own hashing internally + case Group::ID::Ed25519: + case Group::ID::Ed448: + return nullptr; + + // Groups not used for signature + case Group::ID::X25519: + case Group::ID::X448: + throw std::runtime_error("Signature not supported for group"); + + default: + throw std::runtime_error("Unknown group"); + } +} + +/// +/// General implementation with OpenSSL EVP_PKEY +/// + +EVPGroup::EVPGroup(Group::ID group_id, const KDF& kdf) + : Group(group_id, kdf) +{ +} + +EVPGroup::PublicKey::PublicKey(EVP_PKEY* pkey_in) + : pkey(pkey_in, typed_delete) +{ +} + +EVPGroup::PrivateKey::PrivateKey(EVP_PKEY* pkey_in) + : pkey(pkey_in, typed_delete) +{ +} + +std::unique_ptr +EVPGroup::PrivateKey::public_key() const +{ + if (1 != EVP_PKEY_up_ref(pkey.get())) { + throw openssl_error(); + } + return std::make_unique(pkey.get()); +} + +std::unique_ptr +EVPGroup::generate_key_pair() const +{ + return derive_key_pair({}, random_bytes(sk_size)); +} + +bytes +EVPGroup::dh(const Group::PrivateKey& sk, const Group::PublicKey& pk) const +{ + const auto& rsk = dynamic_cast(sk); + const auto& rpk = dynamic_cast(pk); + + // This and the next line are acceptable because the OpenSSL + // functions fail to mark the required EVP_PKEYs as const, even + // though they are not modified. + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + auto* priv_pkey = const_cast(rsk.pkey.get()); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + auto* pub_pkey = const_cast(rpk.pkey.get()); + + auto ctx = make_typed_unique(EVP_PKEY_CTX_new(priv_pkey, nullptr)); + if (ctx == nullptr) { + throw openssl_error(); + } + + if (1 != EVP_PKEY_derive_init(ctx.get())) { + throw openssl_error(); + } + + if (1 != EVP_PKEY_derive_set_peer(ctx.get(), pub_pkey)) { + throw openssl_error(); + } + + size_t out_len = 0; + if (1 != EVP_PKEY_derive(ctx.get(), nullptr, &out_len)) { + throw openssl_error(); + } + + bytes out(out_len); + uint8_t* ptr = out.data(); + if (1 != (EVP_PKEY_derive(ctx.get(), ptr, &out_len))) { + throw openssl_error(); + } + + return out; +} + +bytes +EVPGroup::sign(const bytes& data, const Group::PrivateKey& sk) const +{ + const auto& rsk = dynamic_cast(sk); + + auto ctx = make_typed_unique(EVP_MD_CTX_create()); + if (ctx == nullptr) { + throw openssl_error(); + } + + const auto* digest = group_sig_digest(id); + if (1 != + EVP_DigestSignInit(ctx.get(), nullptr, digest, nullptr, rsk.pkey.get())) { + throw openssl_error(); + } + + size_t siglen = EVP_PKEY_size(rsk.pkey.get()); + bytes sig(siglen); + if (1 != EVP_DigestSign( + ctx.get(), sig.data(), &siglen, data.data(), data.size())) { + throw openssl_error(); + } + + sig.resize(siglen); + return sig; +} + +bool +EVPGroup::verify(const bytes& data, + const bytes& sig, + const Group::PublicKey& pk) const +{ + const auto& rpk = dynamic_cast(pk); + + auto ctx = make_typed_unique(EVP_MD_CTX_create()); + if (ctx == nullptr) { + throw openssl_error(); + } + + const auto* digest = group_sig_digest(id); + if (1 != EVP_DigestVerifyInit( + ctx.get(), nullptr, digest, nullptr, rpk.pkey.get())) { + throw openssl_error(); + } + + auto rv = EVP_DigestVerify( + ctx.get(), sig.data(), sig.size(), data.data(), data.size()); + + return rv == 1; +} + +/// +/// DH over "normal" curves +/// + +struct ECKeyGroup : public EVPGroup +{ + ECKeyGroup(Group::ID group_id, const KDF& kdf) + : EVPGroup(group_id, kdf) + , curve_nid(group_to_nid(group_id)) + { + } + +#if defined(WITH_OPENSSL3) + typed_unique_ptr keypair_evp_key( + const typed_unique_ptr& priv) const + { + const auto* name = OBJ_nid2sn(curve_nid); + if (name == nullptr) { + throw std::runtime_error("Unsupported algorithm"); + } + + auto group = make_typed_unique( + EC_GROUP_new_by_curve_name_ex(nullptr, nullptr, curve_nid)); + if (group == nullptr) { + throw openssl_error(); + } + + auto pt = make_typed_unique(EC_POINT_new(group.get())); + if (pt == nullptr) { + throw openssl_error(); + } + + if (1 != EC_POINT_mul( + group.get(), pt.get(), priv.get(), nullptr, nullptr, nullptr)) { + throw openssl_error(); + } + + const auto pt_size = EC_POINT_point2oct(group.get(), + pt.get(), + POINT_CONVERSION_UNCOMPRESSED, + nullptr, + 0, + nullptr); + if (0 == pt_size) { + throw openssl_error(); + } + + bytes pub(pt_size); + if (EC_POINT_point2oct(group.get(), + pt.get(), + POINT_CONVERSION_UNCOMPRESSED, + pub.data(), + pt_size, + nullptr) != pt_size) { + throw openssl_error(); + } + + auto builder = make_typed_unique(OSSL_PARAM_BLD_new()); + if (builder == nullptr || + 1 != OSSL_PARAM_BLD_push_utf8_string( + builder.get(), OSSL_PKEY_PARAM_GROUP_NAME, name, 0) || + 1 != OSSL_PARAM_BLD_push_BN( + builder.get(), OSSL_PKEY_PARAM_PRIV_KEY, priv.get()) || + 1 != + OSSL_PARAM_BLD_push_octet_string( + builder.get(), OSSL_PKEY_PARAM_PUB_KEY, pub.data(), pub.size())) { + throw openssl_error(); + } + + auto params = make_typed_unique(OSSL_PARAM_BLD_to_param(builder.get())); + auto ctx = + make_typed_unique(EVP_PKEY_CTX_new_from_name(nullptr, "EC", nullptr)); + auto key = make_typed_unique(EVP_PKEY_new()); + auto* key_ptr = key.get(); + if (params == nullptr || ctx == nullptr || key == nullptr || + EVP_PKEY_fromdata_init(ctx.get()) <= 0 || + EVP_PKEY_fromdata( + ctx.get(), &key_ptr, EVP_PKEY_KEYPAIR, params.get()) <= 0) { + throw openssl_error(); + } + ctx.reset(); + + ctx = make_typed_unique( + EVP_PKEY_CTX_new_from_pkey(nullptr, key.get(), nullptr)); + if (EVP_PKEY_check(ctx.get()) <= 0) { + throw openssl_error(); + } + + return key; + } + + typed_unique_ptr public_evp_key(const bytes& pub) const + { + const auto* name = OBJ_nid2sn(curve_nid); + if (name == nullptr) { + throw std::runtime_error("Unsupported algorithm"); + } + + auto group = make_typed_unique( + EC_GROUP_new_by_curve_name_ex(nullptr, nullptr, curve_nid)); + if (group == nullptr) { + throw openssl_error(); + } + + auto builder = make_typed_unique(OSSL_PARAM_BLD_new()); + if (builder == nullptr || + 1 != OSSL_PARAM_BLD_push_utf8_string( + builder.get(), OSSL_PKEY_PARAM_GROUP_NAME, name, 0) || + 1 != + OSSL_PARAM_BLD_push_octet_string( + builder.get(), OSSL_PKEY_PARAM_PUB_KEY, pub.data(), pub.size())) { + throw openssl_error(); + } + + auto params = make_typed_unique(OSSL_PARAM_BLD_to_param(builder.get())); + auto ctx = + make_typed_unique(EVP_PKEY_CTX_new_from_name(nullptr, "EC", nullptr)); + auto key = make_typed_unique(EVP_PKEY_new()); + auto* key_ptr = key.get(); + if (params == nullptr || ctx == nullptr || key == nullptr || + EVP_PKEY_fromdata_init(ctx.get()) <= 0 || + EVP_PKEY_fromdata( + ctx.get(), &key_ptr, EVP_PKEY_KEYPAIR, params.get()) <= 0) { + throw openssl_error(); + } + ctx.reset(); + + ctx = make_typed_unique( + EVP_PKEY_CTX_new_from_pkey(nullptr, key.get(), nullptr)); + if (EVP_PKEY_public_check(ctx.get()) <= 0) { + throw openssl_error(); + } + + return key; + } +#endif + + std::unique_ptr derive_key_pair( + const bytes& suite_id, + const bytes& ikm) const override + { + static const int retry_limit = 255; + static const auto label_dkp_prk = from_ascii("dkp_prk"); + static const auto label_candidate = from_ascii("candidate"); + + auto dkp_prk = kdf.labeled_extract(suite_id, {}, label_dkp_prk, ikm); + +#if defined(WITH_OPENSSL3) + auto* group = EC_GROUP_new_by_curve_name_ex(nullptr, nullptr, curve_nid); + auto group_ptr = make_typed_unique(group); +#else + auto eckey = new_ec_key(); + const auto* group = EC_KEY_get0_group(eckey.get()); +#endif + + auto order = make_typed_unique(BN_new()); + if (1 != EC_GROUP_get_order(group, order.get(), nullptr)) { + throw openssl_error(); + } + + auto sk = make_typed_unique(BN_new()); + BN_zero(sk.get()); + + auto counter = int(0); + while (BN_is_zero(sk.get()) != 0 || BN_cmp(sk.get(), order.get()) != -1) { + auto ctr = i2osp(counter, 1); + auto candidate = + kdf.labeled_expand(suite_id, dkp_prk, label_candidate, ctr, sk_size); + candidate.at(0) &= bitmask(); + sk.reset(BN_bin2bn( + candidate.data(), static_cast(candidate.size()), nullptr)); + + counter += 1; + if (counter > retry_limit) { + throw std::runtime_error("DeriveKeyPair iteration limit exceeded"); + } + } + +#if defined(WITH_OPENSSL3) + auto key = keypair_evp_key(sk); + return std::make_unique(key.release()); +#else + auto pt = make_typed_unique(EC_POINT_new(group)); + EC_POINT_mul(group, pt.get(), sk.get(), nullptr, nullptr, nullptr); + + EC_KEY_set_private_key(eckey.get(), sk.get()); + EC_KEY_set_public_key(eckey.get(), pt.get()); + + auto pkey = to_pkey(eckey.release()); + return std::make_unique(pkey.release()); +#endif + } + + bytes serialize(const Group::PublicKey& pk) const override + { + const auto& rpk = dynamic_cast(pk); +#if defined(WITH_OPENSSL3) + OSSL_PARAM* param = nullptr; + if (1 != EVP_PKEY_todata(rpk.pkey.get(), EVP_PKEY_PUBLIC_KEY, ¶m)) { + throw openssl_error(); + } + auto param_ptr = make_typed_unique(param); + + const OSSL_PARAM* pk_param = + OSSL_PARAM_locate_const(param_ptr.get(), OSSL_PKEY_PARAM_PUB_KEY); + if (pk_param == nullptr) { + return bytes({}, 0); + } + + size_t len = 0; + if (1 != OSSL_PARAM_get_octet_string(pk_param, nullptr, 0, &len)) { + return bytes({}, 0); + } + + bytes buf(len); + void* data_ptr = buf.data(); + if (1 != OSSL_PARAM_get_octet_string(pk_param, &data_ptr, len, nullptr)) { + return bytes({}, 0); + } + + // Prior to OpenSSL 3.0.8, we will always get compressed point from + // OSSL_PKEY_PARAM_PUB_KEY, so we will have to do the following conversion. + // From OpenSSL 3.0.8, we can obtain the uncompressed point value by setting + // OSSL_PKEY_PARAM_EC_POINT_CONVERSION_FORMAT appropriately. + auto group = make_typed_unique( + EC_GROUP_new_by_curve_name_ex(nullptr, nullptr, curve_nid)); + if (group == nullptr) { + return bytes({}, 0); + } + auto point = make_typed_unique(EC_POINT_new(group.get())); + const auto* oct_ptr = static_cast(data_ptr); + if (1 != + EC_POINT_oct2point(group.get(), point.get(), oct_ptr, len, nullptr)) { + return bytes({}, 0); + } + len = EC_POINT_point2oct(group.get(), + point.get(), + POINT_CONVERSION_UNCOMPRESSED, + nullptr, + 0, + nullptr); + if (0 == len) { + return bytes({}, 0); + } + bytes out(len); + auto* data = out.data(); + if (EC_POINT_point2oct(group.get(), + point.get(), + POINT_CONVERSION_UNCOMPRESSED, + data, + len, + nullptr) != len) { + return bytes({}, 0); + } +#else + auto* pub = EVP_PKEY_get0_EC_KEY(rpk.pkey.get()); + + auto len = i2o_ECPublicKey(pub, nullptr); + if (len != static_cast(pk_size)) { + throw openssl_error(); + } + + bytes out(len); + auto* data = out.data(); + if (i2o_ECPublicKey(pub, &data) == 0) { + throw openssl_error(); + } +#endif + return out; + } + + std::unique_ptr deserialize(const bytes& enc) const override + { +#if defined(WITH_OPENSSL3) + auto key = public_evp_key(enc); + if (key == nullptr) { + throw std::runtime_error("Unable to deserialize the public key"); + } + return std::make_unique(key.release()); +#else + auto eckey = new_ec_key(); + auto* eckey_ptr = eckey.get(); + const auto* data_ptr = enc.data(); + if (nullptr == + o2i_ECPublicKey(&eckey_ptr, + &data_ptr, + static_cast( // NOLINT(google-runtime-int) + enc.size()))) { + throw openssl_error(); + } + + auto pkey = to_pkey(eckey.release()); + return std::make_unique(pkey.release()); +#endif + } + + bytes serialize_private(const Group::PrivateKey& sk) const override + { + const auto& rsk = dynamic_cast(sk); +#if defined(WITH_OPENSSL3) + OSSL_PARAM* param = nullptr; + if (1 != EVP_PKEY_todata(rsk.pkey.get(), EVP_PKEY_KEYPAIR, ¶m)) { + throw openssl_error(); + } + auto param_ptr = make_typed_unique(param); + + const OSSL_PARAM* sk_param = + OSSL_PARAM_locate_const(param_ptr.get(), OSSL_PKEY_PARAM_PRIV_KEY); + if (sk_param == nullptr) { + return bytes({}, 0); + } + + BIGNUM* d = nullptr; + if (1 != OSSL_PARAM_get_BN(sk_param, &d)) { + return bytes({}, 0); + } + auto d_ptr = make_typed_unique(d); +#else + auto* eckey = EVP_PKEY_get0_EC_KEY(rsk.pkey.get()); + const auto* d = EC_KEY_get0_private_key(eckey); +#endif + + auto out = bytes(BN_num_bytes(d)); +#if WITH_BORINGSSL + // In BoringSSL, BN_bn2bin returns size_t + const auto out_size = out.size(); +#else + // In OpenSSL, BN_bn2bin returns int + const auto out_size = static_cast(out.size()); +#endif + if (BN_bn2bin(d, out.data()) != out_size) { + throw openssl_error(); + } + + const auto zeros_needed = group_dh_size(id) - out.size(); + auto leading_zeros = bytes(zeros_needed, 0); + return leading_zeros + out; + } + + std::unique_ptr deserialize_private( + const bytes& skm) const override + { +#if defined(WITH_OPENSSL3) + auto priv = make_typed_unique( + BN_bin2bn(skm.data(), static_cast(skm.size()), nullptr)); + if (priv == nullptr) { + throw std::runtime_error("Unable to deserialize the private key"); + } + auto key = keypair_evp_key(priv); + return std::make_unique(key.release()); +#else + auto eckey = new_ec_key(); + const auto* group = EC_KEY_get0_group(eckey.get()); + const auto d = make_typed_unique( + BN_bin2bn(skm.data(), static_cast(skm.size()), nullptr)); + auto pt = make_typed_unique(EC_POINT_new(group)); + + EC_POINT_mul(group, pt.get(), d.get(), nullptr, nullptr, nullptr); + EC_KEY_set_private_key(eckey.get(), d.get()); + EC_KEY_set_public_key(eckey.get(), pt.get()); + + auto pkey = to_pkey(eckey.release()); + return std::make_unique(pkey.release()); +#endif + } + + // EC Key + std::tuple coordinates( + const Group::PublicKey& pk) const override + { + auto bn_x = make_typed_unique(BN_new()); + auto bn_y = make_typed_unique(BN_new()); + const auto& rpk = dynamic_cast(pk); + +#if defined(WITH_OPENSSL3) + // Raw pointer OK here because it becomes managed as soon as possible + OSSL_PARAM* param_ptr = nullptr; + if (1 != EVP_PKEY_todata(rpk.pkey.get(), EVP_PKEY_PUBLIC_KEY, ¶m_ptr)) { + throw openssl_error(); + } + + auto param = make_typed_unique(param_ptr); + + // Raw pointer OK here because it is non-owning + const auto* pk_param = + OSSL_PARAM_locate_const(param.get(), OSSL_PKEY_PARAM_PUB_KEY); + if (pk_param == nullptr) { + throw std::runtime_error("Failed to locate OSSL_PKEY_PARAM_PUB_KEY"); + } + + // Copy the octet string representation of the key into a buffer + auto len = size_t(0); + if (1 != OSSL_PARAM_get_octet_string(pk_param, nullptr, 0, &len)) { + throw std::runtime_error("Failed to get OSSL_PKEY_PARAM_PUB_KEY len"); + } + + auto buf = bytes(len); + auto* buf_ptr = buf.data(); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + auto* buf_ptr_void = reinterpret_cast(buf_ptr); + if (1 != + OSSL_PARAM_get_octet_string(pk_param, &buf_ptr_void, len, nullptr)) { + throw std::runtime_error("Failed to get OSSL_PKEY_PARAM_PUB_KEY data"); + } + + // Parse the octet string representation into an EC_POINT + auto group = make_typed_unique( + EC_GROUP_new_by_curve_name_ex(nullptr, nullptr, curve_nid)); + if (group == nullptr) { + throw openssl_error(); + } + + auto point = make_typed_unique(EC_POINT_new(group.get())); + if (point == nullptr) { + throw openssl_error(); + } + + if (1 != + EC_POINT_oct2point(group.get(), point.get(), buf_ptr, len, nullptr)) { + throw openssl_error(); + } + + // Retrieve the affine coordinates of the point + if (1 != EC_POINT_get_affine_coordinates( + group.get(), point.get(), bn_x.get(), bn_y.get(), nullptr)) { + throw openssl_error(); + } +#else + // Raw pointers are non-owning + auto* pub = EVP_PKEY_get0_EC_KEY(rpk.pkey.get()); + const auto* point = EC_KEY_get0_public_key(pub); + const auto* group = EC_KEY_get0_group(pub); + + if (1 != EC_POINT_get_affine_coordinates_GFp( + group, point, bn_x.get(), bn_y.get(), nullptr)) { + throw openssl_error(); + } +#endif + const auto x_size = BN_num_bytes(bn_x.get()); + auto x = bytes(x_size); + if (BN_bn2bin(bn_x.get(), x.data()) != x_size) { + throw openssl_error(); + } + + const auto y_size = BN_num_bytes(bn_y.get()); + auto y = bytes(y_size); + if (BN_bn2bin(bn_y.get(), y.data()) != y_size) { + throw openssl_error(); + } + + const auto zeros_needed_x = dh_size - x.size(); + const auto zeros_needed_y = dh_size - y.size(); + auto leading_zeros_x = bytes(zeros_needed_x, 0); + auto leading_zeros_y = bytes(zeros_needed_y, 0); + + return { leading_zeros_x + x, leading_zeros_y + y }; + } + + // EC Key + std::unique_ptr public_key_from_coordinates( + const bytes& x, + const bytes& y) const override + { + auto bn_x = make_typed_unique( + BN_bin2bn(x.data(), static_cast(x.size()), nullptr)); + auto bn_y = make_typed_unique( + BN_bin2bn(y.data(), static_cast(y.size()), nullptr)); + + if (bn_x == nullptr || bn_y == nullptr) { + throw std::runtime_error("Failed to convert bn_x or bn_y"); + } + +#if defined(WITH_OPENSSL3) + const auto group = make_typed_unique( + EC_GROUP_new_by_curve_name_ex(nullptr, nullptr, curve_nid)); + if (group == nullptr) { + throw std::runtime_error("Failed to create EC_GROUP"); + } + + // Construct a point with the given coordinates + auto point = make_typed_unique(EC_POINT_new(group.get())); + if (group == nullptr) { + throw std::runtime_error("Failed to create EC_POINT"); + } + + if (1 != EC_POINT_set_affine_coordinates( + group.get(), point.get(), bn_x.get(), bn_y.get(), nullptr)) { + throw openssl_error(); + } + + // Serialize the point + const auto point_size = EC_POINT_point2oct(group.get(), + point.get(), + POINT_CONVERSION_UNCOMPRESSED, + nullptr, + 0, + nullptr); + if (0 == point_size) { + throw openssl_error(); + } + + auto pub = bytes(point_size); + if (EC_POINT_point2oct(group.get(), + point.get(), + POINT_CONVERSION_UNCOMPRESSED, + pub.data(), + point_size, + nullptr) != point_size) { + throw openssl_error(); + } + + // Initialize a public key from the serialized point + auto key = public_evp_key(pub); + return std::make_unique(key.release()); +#else + auto eckey = new_ec_key(); + if (eckey == nullptr) { + throw std::runtime_error("Failed to create EC_KEY"); + } + + // Group pointer is non-owning + const auto* group = EC_KEY_get0_group(eckey.get()); + auto point = make_typed_unique(EC_POINT_new(group)); + + if (1 != EC_POINT_set_affine_coordinates_GFp( + group, point.get(), bn_x.get(), bn_y.get(), nullptr)) { + throw openssl_error(); + } + + if (1 != EC_KEY_set_public_key(eckey.get(), point.get())) { + throw openssl_error(); + } + + auto pkey = to_pkey(eckey.release()); + return std::make_unique(pkey.release()); +#endif + } + +private: + int curve_nid; + +#if !defined(WITH_OPENSSL3) + typed_unique_ptr new_ec_key() const + { + return make_typed_unique(EC_KEY_new_by_curve_name(curve_nid)); + } + + static typed_unique_ptr to_pkey(EC_KEY* eckey) + { + auto* pkey = EVP_PKEY_new(); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-cstyle-cast) + EVP_PKEY_assign_EC_KEY(pkey, eckey); + return make_typed_unique(pkey); + } +#endif + + static inline int group_to_nid(Group::ID group_id) + { + switch (group_id) { + case Group::ID::P256: + return NID_X9_62_prime256v1; + case Group::ID::P384: + return NID_secp384r1; + case Group::ID::P521: + return NID_secp521r1; + default: + throw std::runtime_error("Unsupported algorithm"); + } + } + + uint8_t bitmask() const + { + switch (id) { + case Group::ID::P256: + case Group::ID::P384: + return 0xff; + + case Group::ID::P521: + return 0x01; + + default: + throw std::runtime_error("Unsupported algorithm"); + } + } +}; + +/// +/// DH over "raw" curves +/// + +struct RawKeyGroup : public EVPGroup +{ + RawKeyGroup(Group::ID group_id, const KDF& kdf) + : EVPGroup(group_id, kdf) + , evp_type(group_to_evp(group_id)) + { + } + + template + static const RawKeyGroup instance; + + std::unique_ptr derive_key_pair( + const bytes& suite_id, + const bytes& ikm) const override + { + static const auto label_dkp_prk = from_ascii("dkp_prk"); + static const auto label_sk = from_ascii("sk"); + + auto dkp_prk = kdf.labeled_extract(suite_id, {}, label_dkp_prk, ikm); + auto skm = kdf.labeled_expand(suite_id, dkp_prk, label_sk, {}, sk_size); + return deserialize_private(skm); + } + + bytes serialize(const Group::PublicKey& pk) const override + { + const auto& rpk = dynamic_cast(pk); + auto raw = bytes(pk_size); + auto* data_ptr = raw.data(); + auto data_len = raw.size(); + if (1 != EVP_PKEY_get_raw_public_key(rpk.pkey.get(), data_ptr, &data_len)) { + throw openssl_error(); + } + + return raw; + } + + std::unique_ptr deserialize(const bytes& enc) const override + { + auto* pkey = + EVP_PKEY_new_raw_public_key(evp_type, nullptr, enc.data(), enc.size()); + if (pkey == nullptr) { + throw openssl_error(); + } + + return std::make_unique(pkey); + } + + bytes serialize_private(const Group::PrivateKey& sk) const override + { + const auto& rsk = dynamic_cast(sk); + auto raw = bytes(sk_size); + auto* data_ptr = raw.data(); + auto data_len = raw.size(); + if (1 != + EVP_PKEY_get_raw_private_key(rsk.pkey.get(), data_ptr, &data_len)) { + throw openssl_error(); + } + + return raw; + } + + std::unique_ptr deserialize_private( + const bytes& skm) const override + { + auto* pkey = + EVP_PKEY_new_raw_private_key(evp_type, nullptr, skm.data(), skm.size()); + if (pkey == nullptr) { + throw openssl_error(); + } + + return std::make_unique(pkey); + } + + // Raw Key + std::tuple coordinates( + const Group::PublicKey& pk) const override + { + const auto& rpk = dynamic_cast(pk); + auto raw = bytes(pk_size); + auto* data_ptr = raw.data(); + auto data_len = raw.size(); + + if (1 != EVP_PKEY_get_raw_public_key(rpk.pkey.get(), data_ptr, &data_len)) { + throw openssl_error(); + } + + return { raw, {} }; + } + + // Raw Key + std::unique_ptr public_key_from_coordinates( + const bytes& x, + const bytes& /* y */) const override + { + return deserialize(x); + } + +private: + const int evp_type; + + static inline int group_to_evp(Group::ID group_id) + { + switch (group_id) { + case Group::ID::X25519: + return EVP_PKEY_X25519; + case Group::ID::X448: + return EVP_PKEY_X448; + case Group::ID::Ed25519: + return EVP_PKEY_ED25519; + case Group::ID::Ed448: + return EVP_PKEY_ED448; + default: + throw std::runtime_error("Unsupported algorithm"); + } + } +}; + +/// +/// General DH group +/// + +template<> +const Group& +Group::get() +{ + static const ECKeyGroup instance(Group::ID::P256, + KDF::get()); + + return instance; +} + +template<> +const Group& +Group::get() +{ + static const ECKeyGroup instance(Group::ID::P384, + KDF::get()); + + return instance; +} + +template<> +const Group& +Group::get() +{ + static const ECKeyGroup instance(Group::ID::P521, + KDF::get()); + + return instance; +} + +template<> +const Group& +Group::get() +{ + static const RawKeyGroup instance(Group::ID::X25519, + KDF::get()); + return instance; +} + +template<> +const Group& +Group::get() +{ + static const RawKeyGroup instance(Group::ID::Ed25519, + KDF::get()); + return instance; +} + +// BoringSSL doesn't support X448 / Ed448 +#if !defined(WITH_BORINGSSL) +template<> +const Group& +Group::get() +{ + static const RawKeyGroup instance(Group::ID::X448, + KDF::get()); + return instance; +} +#endif + +template<> +const Group& +Group::get() +{ + static const RawKeyGroup instance(Group::ID::Ed448, + KDF::get()); + return instance; +} + +static inline size_t +group_dh_size(Group::ID group_id) +{ + switch (group_id) { + case Group::ID::P256: + return 32; + case Group::ID::P384: + return 48; + case Group::ID::P521: + return 66; + case Group::ID::X25519: + return 32; + case Group::ID::X448: + return 56; + + // Non-DH groups + case Group::ID::Ed25519: + case Group::ID::Ed448: + return 0; + + default: + throw std::runtime_error("Unknown group"); + } +} + +static inline size_t +group_pk_size(Group::ID group_id) +{ + switch (group_id) { + case Group::ID::P256: + return 65; + case Group::ID::P384: + return 97; + case Group::ID::P521: + return 133; + case Group::ID::X25519: + case Group::ID::Ed25519: + return 32; + case Group::ID::X448: + return 56; + case Group::ID::Ed448: + return 57; + + default: + throw std::runtime_error("Unknown group"); + } +} + +static inline size_t +group_sk_size(Group::ID group_id) +{ + switch (group_id) { + case Group::ID::P256: + return 32; + case Group::ID::P384: + return 48; + case Group::ID::P521: + return 66; + case Group::ID::X25519: + case Group::ID::Ed25519: + return 32; + case Group::ID::X448: + return 56; + case Group::ID::Ed448: + return 57; + + default: + throw std::runtime_error("Unknown group"); + } +} + +static inline std::string +group_jwk_curve_name(Group::ID group_id) +{ + switch (group_id) { + case Group::ID::P256: + return "P-256"; + case Group::ID::P384: + return "P-384"; + case Group::ID::P521: + return "P-521"; + case Group::ID::Ed25519: + return "Ed25519"; + case Group::ID::Ed448: + return "Ed448"; + case Group::ID::X25519: + return "X25519"; + case Group::ID::X448: + return "X448"; + default: + throw std::runtime_error("Unknown group"); + } +} + +static inline std::string +group_jwk_key_type(Group::ID group_id) +{ + switch (group_id) { + case Group::ID::P256: + case Group::ID::P384: + case Group::ID::P521: + return "EC"; + case Group::ID::Ed25519: + case Group::ID::Ed448: + case Group::ID::X25519: + case Group::ID::X448: + return "OKP"; + default: + throw std::runtime_error("Unknown group"); + } +} + +Group::Group(ID group_id_in, const KDF& kdf_in) + : id(group_id_in) + , dh_size(group_dh_size(group_id_in)) + , pk_size(group_pk_size(group_id_in)) + , sk_size(group_sk_size(group_id_in)) + , jwk_key_type(group_jwk_key_type(group_id_in)) + , jwk_curve_name(group_jwk_curve_name(group_id_in)) + , kdf(kdf_in) +{ +} + +} // namespace mlspp::hpke diff --git a/mlspp/lib/hpke/src/group.h b/mlspp/lib/hpke/src/group.h new file mode 100755 index 0000000000..73b52c715f --- /dev/null +++ b/mlspp/lib/hpke/src/group.h @@ -0,0 +1,116 @@ +#pragma once + +#include +#include + +#include "openssl_common.h" +#include + +namespace mlspp::hpke { + +struct Group +{ + enum struct ID : uint8_t + { + P256, + P384, + P521, + X25519, + X448, + Ed25519, + Ed448, + }; + + struct PublicKey + : public KEM::PublicKey + , public Signature::PublicKey + { + virtual ~PublicKey() = default; + }; + + struct PrivateKey + { + virtual ~PrivateKey() = default; + virtual std::unique_ptr public_key() const = 0; + }; + + template + static const Group& get(); + + virtual ~Group() = default; + + const ID id; + const size_t dh_size; + const size_t pk_size; + const size_t sk_size; + const std::string jwk_key_type; + const std::string jwk_curve_name; + + virtual std::unique_ptr generate_key_pair() const = 0; + virtual std::unique_ptr derive_key_pair( + const bytes& suite_id, + const bytes& ikm) const = 0; + + virtual bytes serialize(const PublicKey& pk) const = 0; + virtual std::unique_ptr deserialize(const bytes& enc) const = 0; + + virtual bytes serialize_private(const PrivateKey& sk) const = 0; + virtual std::unique_ptr deserialize_private( + const bytes& skm) const = 0; + + virtual bytes dh(const PrivateKey& sk, const PublicKey& pk) const = 0; + + virtual bytes sign(const bytes& data, const PrivateKey& sk) const = 0; + virtual bool verify(const bytes& data, + const bytes& sig, + const PublicKey& pk) const = 0; + + virtual std::tuple coordinates(const PublicKey& pk) const = 0; + virtual std::unique_ptr public_key_from_coordinates( + const bytes& x, + const bytes& y) const = 0; + +protected: + const KDF& kdf; + + friend struct DHKEM; + + Group(ID group_id_in, const KDF& kdf_in); +}; + +struct EVPGroup : public Group +{ + EVPGroup(Group::ID group_id, const KDF& kdf); + + struct PublicKey : public Group::PublicKey + { + explicit PublicKey(EVP_PKEY* pkey_in); + ~PublicKey() override = default; + + // NOLINTNEXTLINE(misc-non-private-member-variables-in-classes) + typed_unique_ptr pkey; + }; + + struct PrivateKey : public Group::PrivateKey + { + explicit PrivateKey(EVP_PKEY* pkey_in); + ~PrivateKey() override = default; + + std::unique_ptr public_key() const override; + + // NOLINTNEXTLINE(misc-non-private-member-variables-in-classes) + typed_unique_ptr pkey; + }; + + std::unique_ptr generate_key_pair() const override; + + bytes dh(const Group::PrivateKey& sk, + const Group::PublicKey& pk) const override; + + bytes sign(const bytes& data, const Group::PrivateKey& sk) const override; + bool verify(const bytes& data, + const bytes& sig, + const Group::PublicKey& pk) const override; +}; + +} // namespace mlspp::hpke diff --git a/mlspp/lib/hpke/src/hkdf.cpp b/mlspp/lib/hpke/src/hkdf.cpp new file mode 100755 index 0000000000..274b4ad25a --- /dev/null +++ b/mlspp/lib/hpke/src/hkdf.cpp @@ -0,0 +1,79 @@ +#include "hkdf.h" +#include "openssl_common.h" + +#include +#include +#include + +namespace mlspp::hpke { + +template<> +const HKDF& +HKDF::get() +{ + static const HKDF instance(Digest::get()); + return instance; +} + +template<> +const HKDF& +HKDF::get() +{ + static const HKDF instance(Digest::get()); + return instance; +} + +template<> +const HKDF& +HKDF::get() +{ + static const HKDF instance(Digest::get()); + return instance; +} + +static KDF::ID +digest_to_kdf(Digest::ID digest_id) +{ + switch (digest_id) { + case Digest::ID::SHA256: + return KDF::ID::HKDF_SHA256; + case Digest::ID::SHA384: + return KDF::ID::HKDF_SHA384; + case Digest::ID::SHA512: + return KDF::ID::HKDF_SHA512; + } + + throw std::runtime_error("Unsupported algorithm"); +} + +HKDF::HKDF(const Digest& digest_in) + : KDF(digest_to_kdf(digest_in.id), digest_in.hash_size) + , digest(digest_in) +{ +} + +bytes +HKDF::extract(const bytes& salt, const bytes& ikm) const +{ + return digest.hmac_for_hkdf_extract(salt, ikm); +} + +bytes +HKDF::expand(const bytes& prk, const bytes& info, size_t size) const +{ + auto okm = bytes{}; + auto i = uint8_t(0x00); + auto Ti = bytes{}; + while (okm.size() < size) { + i += 1; + auto block = Ti + info + bytes{ i }; + + Ti = digest.hmac(prk, block); + okm += Ti; + } + + okm.resize(size); + return okm; +} + +} // namespace mlspp::hpke diff --git a/mlspp/lib/hpke/src/hkdf.h b/mlspp/lib/hpke/src/hkdf.h new file mode 100755 index 0000000000..e6951be6b5 --- /dev/null +++ b/mlspp/lib/hpke/src/hkdf.h @@ -0,0 +1,24 @@ +#pragma once + +#include +#include + +namespace mlspp::hpke { + +struct HKDF : public KDF +{ + template + static const HKDF& get(); + + ~HKDF() override = default; + + bytes extract(const bytes& salt, const bytes& ikm) const override; + bytes expand(const bytes& prk, const bytes& info, size_t size) const override; + +private: + const Digest& digest; + + explicit HKDF(const Digest& digest_in); +}; + +} // namespace mlspp::hpke diff --git a/mlspp/lib/hpke/src/hpke.cpp b/mlspp/lib/hpke/src/hpke.cpp new file mode 100755 index 0000000000..7ee2faf670 --- /dev/null +++ b/mlspp/lib/hpke/src/hpke.cpp @@ -0,0 +1,540 @@ +#include +#include + +#include "aead_cipher.h" +#include "common.h" +#include "dhkem.h" +#include "hkdf.h" + +#include +#include +#include + +namespace mlspp::hpke { + +/// +/// Helper functions and constants +/// + +static const bytes& +label_exp() +{ + static const bytes val = from_ascii("exp"); + return val; +} + +static const bytes& +label_hpke() +{ + static const bytes val = from_ascii("HPKE"); + return val; +} + +static const bytes& +label_hpke_version() +{ + static const bytes val = from_ascii("HPKE-v1"); + return val; +} + +static const bytes& +label_info_hash() +{ + static const bytes val = from_ascii("info_hash"); + return val; +} + +static const bytes& +label_key() +{ + static const bytes val = from_ascii("key"); + return val; +} + +static const bytes& +label_base_nonce() +{ + static const bytes val = from_ascii("base_nonce"); + return val; +} + +static const bytes& +label_psk_id_hash() +{ + static const bytes val = from_ascii("psk_id_hash"); + return val; +} + +static const bytes& +label_sec() +{ + static const bytes val = from_ascii("sec"); + return val; +} + +static const bytes& +label_secret() +{ + static const bytes val = from_ascii("secret"); + return val; +} + +/// +/// Factory methods for primitives +/// + +KEM::KEM(ID id_in, + size_t secret_size_in, + size_t enc_size_in, + size_t pk_size_in, + size_t sk_size_in) + : id(id_in) + , secret_size(secret_size_in) + , enc_size(enc_size_in) + , pk_size(pk_size_in) + , sk_size(sk_size_in) +{ +} + +template<> +const KEM& +KEM::get() +{ + return DHKEM::get(); +} + +template<> +const KEM& +KEM::get() +{ + return DHKEM::get(); +} + +template<> +const KEM& +KEM::get() +{ + return DHKEM::get(); +} + +template<> +const KEM& +KEM::get() +{ + return DHKEM::get(); +} + +#if !defined(WITH_BORINGSSL) +template<> +const KEM& +KEM::get() +{ + return DHKEM::get(); +} +#endif + +bytes +KEM::serialize_private(const KEM::PrivateKey& /* unused */) const +{ + throw std::runtime_error("Not implemented"); +} + +std::unique_ptr +KEM::deserialize_private(const bytes& /* unused */) const +{ + throw std::runtime_error("Not implemented"); +} + +std::pair +KEM::auth_encap(const PublicKey& /* unused */, + const PrivateKey& /* unused */) const +{ + throw std::runtime_error("Not implemented"); +} + +bytes +KEM::auth_decap(const bytes& /* unused */, + const PublicKey& /* unused */, + const PrivateKey& /* unused */) const +{ + throw std::runtime_error("Not implemented"); +} + +template<> +const KDF& +KDF::get() +{ + return HKDF::get(); +} + +template<> +const KDF& +KDF::get() +{ + return HKDF::get(); +} + +template<> +const KDF& +KDF::get() +{ + return HKDF::get(); +} + +KDF::KDF(ID id_in, size_t hash_size_in) + : id(id_in) + , hash_size(hash_size_in) +{ +} + +bytes +KDF::labeled_extract(const bytes& suite_id, + const bytes& salt, + const bytes& label, + const bytes& ikm) const +{ + auto labeled_ikm = label_hpke_version() + suite_id + label + ikm; + return extract(salt, labeled_ikm); +} + +bytes +KDF::labeled_expand(const bytes& suite_id, + const bytes& prk, + const bytes& label, + const bytes& info, + size_t size) const +{ + auto labeled_info = + i2osp(size, 2) + label_hpke_version() + suite_id + label + info; + return expand(prk, labeled_info, size); +} + +template<> +const AEAD& +AEAD::get() +{ + return AEADCipher::get(); +} + +template<> +const AEAD& +AEAD::get() +{ + return AEADCipher::get(); +} + +template<> +const AEAD& +AEAD::get() +{ + return AEADCipher::get(); +} + +template<> +const AEAD& +AEAD::get() +{ + static const auto export_only = ExportOnlyCipher{}; + return export_only; +} + +AEAD::AEAD(ID id_in, size_t key_size_in, size_t nonce_size_in) + : id(id_in) + , key_size(key_size_in) + , nonce_size(nonce_size_in) +{ +} + +/// +/// Encryption Contexts +/// + +bytes +Context::do_export(const bytes& exporter_context, size_t size) const +{ + return kdf.labeled_expand( + suite, exporter_secret, label_sec(), exporter_context, size); +} + +bytes +Context::current_nonce() const +{ + auto curr = i2osp(seq, aead.nonce_size); + return curr ^ nonce; +} + +void +Context::increment_seq() +{ + if (seq == std::numeric_limits::max()) { + throw std::runtime_error("Sequence number overflow"); + } + + seq += 1; +} + +Context::Context(bytes suite_in, + bytes key_in, + bytes nonce_in, + bytes exporter_secret_in, + const KDF& kdf_in, + const AEAD& aead_in) + : suite(std::move(suite_in)) + , key(std::move(key_in)) + , nonce(std::move(nonce_in)) + , exporter_secret(std::move(exporter_secret_in)) + , kdf(kdf_in) + , aead(aead_in) + , seq(0) +{ +} + +bool +operator==(const Context& lhs, const Context& rhs) +{ + // TODO(RLB) Compare KDF and AEAD algorithms + auto suite = (lhs.suite == rhs.suite); + auto key = (lhs.key == rhs.key); + auto nonce = (lhs.nonce == rhs.nonce); + auto exporter_secret = (lhs.exporter_secret == rhs.exporter_secret); + auto seq = (lhs.seq == rhs.seq); + return suite && key && nonce && exporter_secret && seq; +} + +SenderContext::SenderContext(Context&& c) + : Context(std::move(c)) +{ +} + +bytes +SenderContext::seal(const bytes& aad, const bytes& pt) +{ + auto ct = aead.seal(key, current_nonce(), aad, pt); + increment_seq(); + return ct; +} + +ReceiverContext::ReceiverContext(Context&& c) + : Context(std::move(c)) +{ +} + +std::optional +ReceiverContext::open(const bytes& aad, const bytes& ct) +{ + auto maybe_pt = aead.open(key, current_nonce(), aad, ct); + increment_seq(); + return maybe_pt; +} + +/// +/// HPKE +/// + +static const bytes default_psk = {}; +static const bytes default_psk_id = {}; + +static bytes +suite_id(KEM::ID kem_id, KDF::ID kdf_id, AEAD::ID aead_id) +{ + return label_hpke() + i2osp(static_cast(kem_id), 2) + + i2osp(static_cast(kdf_id), 2) + + i2osp(static_cast(aead_id), 2); +} + +static const KEM& +select_kem(KEM::ID id) +{ + switch (id) { + case KEM::ID::DHKEM_P256_SHA256: + return KEM::get(); + case KEM::ID::DHKEM_P384_SHA384: + return KEM::get(); + case KEM::ID::DHKEM_P521_SHA512: + return KEM::get(); + case KEM::ID::DHKEM_X25519_SHA256: + return KEM::get(); +#if !defined(WITH_BORINGSSL) + case KEM::ID::DHKEM_X448_SHA512: + return KEM::get(); +#endif + default: + throw std::runtime_error("Unsupported algorithm"); + } +} + +static const KDF& +select_kdf(KDF::ID id) +{ + switch (id) { + case KDF::ID::HKDF_SHA256: + return KDF::get(); + case KDF::ID::HKDF_SHA384: + return KDF::get(); + case KDF::ID::HKDF_SHA512: + return KDF::get(); + default: + throw std::runtime_error("Unsupported algorithm"); + } +} + +static const AEAD& +select_aead(AEAD::ID id) +{ + switch (id) { + case AEAD::ID::AES_128_GCM: + return AEAD::get(); + case AEAD::ID::AES_256_GCM: + return AEAD::get(); + case AEAD::ID::CHACHA20_POLY1305: + return AEAD::get(); + case AEAD::ID::export_only: + return AEAD::get(); + default: + throw std::runtime_error("Unsupported algorithm"); + } +} + +HPKE::HPKE(KEM::ID kem_id, KDF::ID kdf_id, AEAD::ID aead_id) + : suite(suite_id(kem_id, kdf_id, aead_id)) + , kem(select_kem(kem_id)) + , kdf(select_kdf(kdf_id)) + , aead(select_aead(aead_id)) +{ +} + +HPKE::SenderInfo +HPKE::setup_base_s(const KEM::PublicKey& pkR, const bytes& info) const +{ + auto [shared_secret, enc] = kem.encap(pkR); + auto ctx = + key_schedule(Mode::base, shared_secret, info, default_psk, default_psk_id); + return std::make_pair(enc, SenderContext(std::move(ctx))); +} + +ReceiverContext +HPKE::setup_base_r(const bytes& enc, + const KEM::PrivateKey& skR, + const bytes& info) const +{ + auto pkRm = kem.serialize(*skR.public_key()); + auto shared_secret = kem.decap(enc, skR); + auto ctx = + key_schedule(Mode::base, shared_secret, info, default_psk, default_psk_id); + return { std::move(ctx) }; +} + +HPKE::SenderInfo +HPKE::setup_psk_s(const KEM::PublicKey& pkR, + const bytes& info, + const bytes& psk, + const bytes& psk_id) const +{ + auto [shared_secret, enc] = kem.encap(pkR); + auto ctx = key_schedule(Mode::psk, shared_secret, info, psk, psk_id); + return std::make_pair(enc, SenderContext(std::move(ctx))); +} + +ReceiverContext +HPKE::setup_psk_r(const bytes& enc, + const KEM::PrivateKey& skR, + const bytes& info, + const bytes& psk, + const bytes& psk_id) const +{ + auto shared_secret = kem.decap(enc, skR); + auto ctx = key_schedule(Mode::psk, shared_secret, info, psk, psk_id); + return { std::move(ctx) }; +} + +HPKE::SenderInfo +HPKE::setup_auth_s(const KEM::PublicKey& pkR, + const bytes& info, + const KEM::PrivateKey& skS) const +{ + auto [shared_secret, enc] = kem.auth_encap(pkR, skS); + auto ctx = + key_schedule(Mode::auth, shared_secret, info, default_psk, default_psk_id); + return std::make_pair(enc, SenderContext(std::move(ctx))); +} + +ReceiverContext +HPKE::setup_auth_r(const bytes& enc, + const KEM::PrivateKey& skR, + const bytes& info, + const KEM::PublicKey& pkS) const +{ + auto shared_secret = kem.auth_decap(enc, pkS, skR); + auto ctx = + key_schedule(Mode::auth, shared_secret, info, default_psk, default_psk_id); + return { std::move(ctx) }; +} + +HPKE::SenderInfo +HPKE::setup_auth_psk_s(const KEM::PublicKey& pkR, + const bytes& info, + const bytes& psk, + const bytes& psk_id, + const KEM::PrivateKey& skS) const +{ + auto [shared_secret, enc] = kem.auth_encap(pkR, skS); + auto ctx = key_schedule(Mode::auth_psk, shared_secret, info, psk, psk_id); + return std::make_pair(enc, SenderContext(std::move(ctx))); +} + +ReceiverContext +HPKE::setup_auth_psk_r(const bytes& enc, + const KEM::PrivateKey& skR, + const bytes& info, + const bytes& psk, + const bytes& psk_id, + const KEM::PublicKey& pkS) const +{ + auto shared_secret = kem.auth_decap(enc, pkS, skR); + auto ctx = key_schedule(Mode::auth_psk, shared_secret, info, psk, psk_id); + return { std::move(ctx) }; +} + +bool +HPKE::verify_psk_inputs(Mode mode, const bytes& psk, const bytes& psk_id) +{ + auto got_psk = (psk != default_psk); + auto got_psk_id = (psk_id != default_psk_id); + if (got_psk != got_psk_id) { + return false; + } + + return (!got_psk && (mode == Mode::base || mode == Mode::auth)) || + (got_psk && (mode == Mode::psk || mode == Mode::auth_psk)); +} + +Context +HPKE::key_schedule(Mode mode, + const bytes& shared_secret, + const bytes& info, + const bytes& psk, + const bytes& psk_id) const +{ + if (!verify_psk_inputs(mode, psk, psk_id)) { + throw std::runtime_error("Invalid PSK inputs"); + } + + auto psk_id_hash = + kdf.labeled_extract(suite, {}, label_psk_id_hash(), psk_id); + auto info_hash = kdf.labeled_extract(suite, {}, label_info_hash(), info); + auto mode_bytes = bytes{ uint8_t(mode) }; + auto key_schedule_context = mode_bytes + psk_id_hash + info_hash; + + auto secret = kdf.labeled_extract(suite, shared_secret, label_secret(), psk); + + auto key = kdf.labeled_expand( + suite, secret, label_key(), key_schedule_context, aead.key_size); + auto nonce = kdf.labeled_expand( + suite, secret, label_base_nonce(), key_schedule_context, aead.nonce_size); + auto exporter_secret = kdf.labeled_expand( + suite, secret, label_exp(), key_schedule_context, kdf.hash_size); + + return { suite, key, nonce, exporter_secret, kdf, aead }; +} + +} // namespace mlspp::hpke diff --git a/mlspp/lib/hpke/src/openssl_common.cpp b/mlspp/lib/hpke/src/openssl_common.cpp new file mode 100755 index 0000000000..e520059d1c --- /dev/null +++ b/mlspp/lib/hpke/src/openssl_common.cpp @@ -0,0 +1,160 @@ +#include "openssl_common.h" + +#include +#include +#include +#include +#include +#include +#if defined(WITH_OPENSSL3) +#include +#endif + +namespace mlspp::hpke { + +template<> +void +typed_delete(EVP_CIPHER_CTX* ptr) +{ + EVP_CIPHER_CTX_free(ptr); +} + +#if WITH_BORINGSSL +template<> +void +typed_delete(EVP_AEAD_CTX* ptr) +{ + EVP_AEAD_CTX_free(ptr); +} +#endif + +template<> +void +typed_delete(EVP_PKEY_CTX* ptr) +{ + EVP_PKEY_CTX_free(ptr); +} + +template<> +void +typed_delete(EVP_MD_CTX* ptr) +{ + EVP_MD_CTX_free(ptr); +} + +#if !defined(WITH_OPENSSL3) +template<> +void +typed_delete(HMAC_CTX* ptr) +{ + HMAC_CTX_free(ptr); +} +#endif + +template<> +void +typed_delete(EVP_PKEY* ptr) +{ + EVP_PKEY_free(ptr); +} + +template<> +void +typed_delete(BIGNUM* ptr) +{ + BN_free(ptr); +} + +template<> +void +typed_delete(EC_POINT* ptr) +{ + EC_POINT_free(ptr); +} + +#if !defined(WITH_OPENSSL3) +template<> +void +typed_delete(EC_KEY* ptr) +{ + EC_KEY_free(ptr); +} +#endif + +#if defined(WITH_OPENSSL3) +template<> +void +typed_delete(EVP_MAC* ptr) +{ + EVP_MAC_free(ptr); +} + +template<> +void +typed_delete(EVP_MAC_CTX* ptr) +{ + EVP_MAC_CTX_free(ptr); +} + +template<> +void +typed_delete(EC_GROUP* ptr) +{ + EC_GROUP_free(ptr); +} + +template<> +void +typed_delete(OSSL_PARAM_BLD* ptr) +{ + OSSL_PARAM_BLD_free(ptr); +} + +template<> +void +typed_delete(OSSL_PARAM* ptr) +{ + OSSL_PARAM_free(ptr); +} +#endif + +template<> +void +typed_delete(X509* ptr) +{ + X509_free(ptr); +} + +template<> +void +typed_delete(STACK_OF(GENERAL_NAME) * ptr) +{ + sk_GENERAL_NAME_pop_free(ptr, GENERAL_NAME_free); +} + +template<> +void +typed_delete(BIO* ptr) +{ + BIO_vfree(ptr); +} + +template<> +void +typed_delete(ASN1_TIME* ptr) +{ + ASN1_TIME_free(ptr); +} + +/// +/// Map OpenSSL errors to C++ exceptions +/// + +std::runtime_error +openssl_error() +{ + auto code = ERR_get_error(); + return std::runtime_error(ERR_error_string(code, nullptr)); +} + +} // namespace mlspp::hpke diff --git a/mlspp/lib/hpke/src/openssl_common.h b/mlspp/lib/hpke/src/openssl_common.h new file mode 100755 index 0000000000..d9282bd83b --- /dev/null +++ b/mlspp/lib/hpke/src/openssl_common.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +namespace mlspp::hpke { + +template +void +typed_delete(T* ptr); + +template +using typed_unique_ptr = std::unique_ptr)>; + +template +typed_unique_ptr +make_typed_unique(T* ptr) +{ + return typed_unique_ptr(ptr, typed_delete); +} + +std::runtime_error +openssl_error(); + +} // namespace mlspp::hpke diff --git a/mlspp/lib/hpke/src/random.cpp b/mlspp/lib/hpke/src/random.cpp new file mode 100755 index 0000000000..6f00c6bae5 --- /dev/null +++ b/mlspp/lib/hpke/src/random.cpp @@ -0,0 +1,19 @@ +#include + +#include "openssl_common.h" + +#include + +namespace mlspp::hpke { + +bytes +random_bytes(size_t size) +{ + auto rand = bytes(size); + if (1 != RAND_bytes(rand.data(), static_cast(size))) { + throw openssl_error(); + } + return rand; +} + +} // namespace mlspp::hpke diff --git a/mlspp/lib/hpke/src/rsa.cpp b/mlspp/lib/hpke/src/rsa.cpp new file mode 100755 index 0000000000..de6aba172f --- /dev/null +++ b/mlspp/lib/hpke/src/rsa.cpp @@ -0,0 +1,207 @@ +#include "rsa.h" + +#include "common.h" +#include "openssl/rsa.h" +#include "openssl_common.h" + +namespace mlspp::hpke { + +std::unique_ptr +RSASignature::generate_key_pair() const +{ + throw std::runtime_error("Not implemented"); +} + +std::unique_ptr +RSASignature::derive_key_pair(const bytes& /*ikm*/) const +{ + throw std::runtime_error("Not implemented"); +} + +std::unique_ptr +RSASignature::generate_key_pair(size_t bits) +{ + auto ctx = make_typed_unique(EVP_PKEY_CTX_new_id(EVP_PKEY_RSA, nullptr)); + if (ctx == nullptr) { + throw openssl_error(); + } + + if (EVP_PKEY_keygen_init(ctx.get()) <= 0) { + throw openssl_error(); + } + + // NOLINTNEXTLINE(hicpp-signed-bitwise) + if (EVP_PKEY_CTX_set_rsa_keygen_bits(ctx.get(), static_cast(bits)) <= + 0) { + throw openssl_error(); + } + + auto* pkey = static_cast(nullptr); + if (EVP_PKEY_keygen(ctx.get(), &pkey) <= 0) { + throw openssl_error(); + } + + return std::make_unique(pkey); +} + +// TODO(rlb): Implement derive() with sizes + +bytes +RSASignature::serialize(const Signature::PublicKey& pk) const +{ + const auto& rpk = dynamic_cast(pk); + const int len = i2d_PublicKey(rpk.pkey.get(), nullptr); + auto raw = bytes(len); + auto* data_ptr = raw.data(); + if (len != i2d_PublicKey(rpk.pkey.get(), &data_ptr)) { + throw openssl_error(); + } + return raw; +} + +std::unique_ptr +RSASignature::deserialize(const bytes& enc) const +{ + const auto* data_ptr = enc.data(); + auto* pkey = d2i_PublicKey( + EVP_PKEY_RSA, nullptr, &data_ptr, static_cast(enc.size())); + if (pkey == nullptr) { + throw openssl_error(); + } + return std::make_unique(pkey); +} + +bytes +RSASignature::serialize_private(const Signature::PrivateKey& sk) const +{ + const auto& rsk = dynamic_cast(sk); + const int len = i2d_PrivateKey(rsk.pkey.get(), nullptr); + auto raw = bytes(len); + auto* data_ptr = raw.data(); + if (len != i2d_PrivateKey(rsk.pkey.get(), &data_ptr)) { + throw openssl_error(); + } + + return raw; +} + +std::unique_ptr +RSASignature::deserialize_private(const bytes& skm) const +{ + const auto* data_ptr = skm.data(); + auto* pkey = d2i_PrivateKey( + EVP_PKEY_RSA, nullptr, &data_ptr, static_cast(skm.size())); + if (pkey == nullptr) { + throw openssl_error(); + } + return std::make_unique(pkey); +} + +bytes +RSASignature::sign(const bytes& data, const Signature::PrivateKey& sk) const +{ + const auto& rsk = dynamic_cast(sk); + + auto ctx = make_typed_unique(EVP_MD_CTX_create()); + if (ctx == nullptr) { + throw openssl_error(); + } + + if (1 != + EVP_DigestSignInit(ctx.get(), nullptr, md, nullptr, rsk.pkey.get())) { + throw openssl_error(); + } + + size_t siglen = EVP_PKEY_size(rsk.pkey.get()); + bytes sig(siglen); + if (1 != EVP_DigestSign( + ctx.get(), sig.data(), &siglen, data.data(), data.size())) { + throw openssl_error(); + } + + sig.resize(siglen); + return sig; +} + +bool +RSASignature::verify(const bytes& data, + const bytes& sig, + const Signature::PublicKey& pk) const +{ + const auto& rpk = dynamic_cast(pk); + + auto ctx = make_typed_unique(EVP_MD_CTX_create()); + if (ctx == nullptr) { + throw openssl_error(); + } + + if (1 != + EVP_DigestVerifyInit(ctx.get(), nullptr, md, nullptr, rpk.pkey.get())) { + throw openssl_error(); + } + + auto rv = EVP_DigestVerify( + ctx.get(), sig.data(), sig.size(), data.data(), data.size()); + + return rv == 1; +} + +// TODO(RLB) Implement these methods. No concrete need, but might be nice for +// completeness. +std::unique_ptr +RSASignature::import_jwk_private(const std::string& /* json_str */) const +{ + throw std::runtime_error("not implemented"); +} + +std::unique_ptr +RSASignature::import_jwk(const std::string& /* json_str */) const +{ + throw std::runtime_error("not implemented"); +} + +std::string +RSASignature::export_jwk_private(const Signature::PrivateKey& /* sk */) const +{ + throw std::runtime_error("not implemented"); +} + +std::string +RSASignature::export_jwk(const Signature::PublicKey& /* pk */) const +{ + throw std::runtime_error("not implemented"); +} + +const EVP_MD* +RSASignature::digest_to_md(Digest::ID digest) +{ + // NOLINTNEXTLINE(hicpp-multiway-paths-covered) + switch (digest) { + case Digest::ID::SHA256: + return EVP_sha256(); + case Digest::ID::SHA384: + return EVP_sha384(); + case Digest::ID::SHA512: + return EVP_sha512(); + default: + throw std::runtime_error("Unsupported digest"); + } +} + +Signature::ID +RSASignature::digest_to_sig(Digest::ID digest) +{ + // NOLINTNEXTLINE(hicpp-multiway-paths-covered) + switch (digest) { + case Digest::ID::SHA256: + return Signature::ID::RSA_SHA256; + case Digest::ID::SHA384: + return Signature::ID::RSA_SHA384; + case Digest::ID::SHA512: + return Signature::ID::RSA_SHA512; + default: + throw std::runtime_error("Unsupported digest"); + } +} + +} // namespace mlspp::hpke diff --git a/mlspp/lib/hpke/src/rsa.h b/mlspp/lib/hpke/src/rsa.h new file mode 100755 index 0000000000..dde38bde7e --- /dev/null +++ b/mlspp/lib/hpke/src/rsa.h @@ -0,0 +1,97 @@ +#pragma once + +#include +#include +#include + +#include "openssl_common.h" +#include +#include + +namespace mlspp::hpke { + +// XXX(RLB): There is a lot of code in RSASignature that is duplicated in +// EVPGroup. I have allowed this duplication rather than factoring it out +// because I would like to be able to cleanly remove RSA later. +struct RSASignature : public Signature +{ + struct PublicKey : public Signature::PublicKey + { + explicit PublicKey(EVP_PKEY* pkey_in) + : pkey(pkey_in, typed_delete) + { + } + + ~PublicKey() override = default; + + typed_unique_ptr pkey; + }; + + struct PrivateKey : public Signature::PrivateKey + { + explicit PrivateKey(EVP_PKEY* pkey_in) + : pkey(pkey_in, typed_delete) + { + } + + ~PrivateKey() override = default; + + std::unique_ptr public_key() const override + { + if (1 != EVP_PKEY_up_ref(pkey.get())) { + throw openssl_error(); + } + return std::make_unique(pkey.get()); + } + + typed_unique_ptr pkey; + }; + + explicit RSASignature(Digest::ID digest) + : Signature(digest_to_sig(digest)) + , md(digest_to_md(digest)) + { + } + + std::unique_ptr generate_key_pair() const override; + + std::unique_ptr derive_key_pair( + const bytes& /*ikm*/) const override; + + static std::unique_ptr generate_key_pair(size_t bits); + + // TODO(rlb): Implement derive() with sizes + + bytes serialize(const Signature::PublicKey& pk) const override; + + std::unique_ptr deserialize( + const bytes& enc) const override; + + bytes serialize_private(const Signature::PrivateKey& sk) const override; + + std::unique_ptr deserialize_private( + const bytes& skm) const override; + + bytes sign(const bytes& data, const Signature::PrivateKey& sk) const override; + + bool verify(const bytes& data, + const bytes& sig, + const Signature::PublicKey& pk) const override; + + std::unique_ptr import_jwk_private( + const std::string& json_str) const override; + std::unique_ptr import_jwk( + const std::string& json_str) const override; + std::string export_jwk_private( + const Signature::PrivateKey& sk) const override; + std::string export_jwk(const Signature::PublicKey& pk) const override; + +private: + const EVP_MD* md; + + static const EVP_MD* digest_to_md(Digest::ID digest); + + static Signature::ID digest_to_sig(Digest::ID digest); +}; + +} // namespace mlspp::hpke diff --git a/mlspp/lib/hpke/src/signature.cpp b/mlspp/lib/hpke/src/signature.cpp new file mode 100755 index 0000000000..a866b2843c --- /dev/null +++ b/mlspp/lib/hpke/src/signature.cpp @@ -0,0 +1,344 @@ +#include +#include +#include +#include + +#include "dhkem.h" +#include "rsa.h" + +#include "dpp/nlohmann/json.hpp" +#include +#include +#include +#include +#include + +using nlohmann::json; + +namespace mlspp::hpke { + +struct GroupSignature : public Signature +{ + struct PrivateKey : public Signature::PrivateKey + { + explicit PrivateKey(Group::PrivateKey* group_priv_in) + : group_priv(group_priv_in) + { + } + + std::unique_ptr public_key() const override + { + return group_priv->public_key(); + } + + std::unique_ptr group_priv; + }; + + static Signature::ID group_to_sig(Group::ID group_id) + { + switch (group_id) { + case Group::ID::P256: + return Signature::ID::P256_SHA256; + case Group::ID::P384: + return Signature::ID::P384_SHA384; + case Group::ID::P521: + return Signature::ID::P521_SHA512; + case Group::ID::Ed25519: + return Signature::ID::Ed25519; +#if !defined(WITH_BORINGSSL) + case Group::ID::Ed448: + return Signature::ID::Ed448; +#endif + default: + throw std::runtime_error("Unsupported group"); + } + } + + explicit GroupSignature(const Group& group_in) + : Signature(group_to_sig(group_in.id)) + , group(group_in) + { + } + + std::unique_ptr generate_key_pair() const override + { + return std::make_unique(group.generate_key_pair().release()); + } + + std::unique_ptr derive_key_pair( + const bytes& ikm) const override + { + return std::make_unique( + group.derive_key_pair({}, ikm).release()); + } + + bytes serialize(const Signature::PublicKey& pk) const override + { + const auto& rpk = dynamic_cast(pk); + return group.serialize(rpk); + } + + std::unique_ptr deserialize( + const bytes& enc) const override + { + return group.deserialize(enc); + } + + bytes serialize_private(const Signature::PrivateKey& sk) const override + { + const auto& rsk = dynamic_cast(sk); + return group.serialize_private(*rsk.group_priv); + } + + std::unique_ptr deserialize_private( + const bytes& skm) const override + { + return std::make_unique( + group.deserialize_private(skm).release()); + } + + bytes sign(const bytes& data, const Signature::PrivateKey& sk) const override + { + const auto& rsk = dynamic_cast(sk); + return group.sign(data, *rsk.group_priv); + } + + bool verify(const bytes& data, + const bytes& sig, + const Signature::PublicKey& pk) const override + { + const auto& rpk = dynamic_cast(pk); + return group.verify(data, sig, rpk); + } + + std::unique_ptr import_jwk_private( + const std::string& jwk_json) const override + { + const auto jwk = validate_jwk_json(jwk_json, true); + + const auto d = from_base64url(jwk.at("d")); + auto gsk = group.deserialize_private(d); + + return std::make_unique(gsk.release()); + } + + std::unique_ptr import_jwk( + const std::string& jwk_json) const override + { + const auto jwk = validate_jwk_json(jwk_json, false); + + const auto x = from_base64url(jwk.at("x")); + auto y = bytes{}; + if (jwk.contains("y")) { + y = from_base64url(jwk.at("y")); + } + + return group.public_key_from_coordinates(x, y); + } + + std::string export_jwk(const Signature::PublicKey& pk) const override + { + const auto& gpk = dynamic_cast(pk); + const auto jwk_json = export_jwk_json(gpk); + return jwk_json.dump(); + } + + std::string export_jwk_private(const Signature::PrivateKey& sk) const override + { + const auto& gssk = dynamic_cast(sk); + const auto& gsk = gssk.group_priv; + const auto gpk = gsk->public_key(); + + auto jwk_json = export_jwk_json(*gpk); + + // encode the private key + const auto enc = serialize_private(sk); + jwk_json.emplace("d", to_base64url(enc)); + + return jwk_json.dump(); + } + +private: + const Group& group; + + json validate_jwk_json(const std::string& jwk_json, bool private_key) const + { + json jwk = json::parse(jwk_json); + + if (jwk.empty() || !jwk.contains("kty") || !jwk.contains("crv") || + !jwk.contains("x") || (private_key && !jwk.contains("d"))) { + throw std::runtime_error("malformed JWK"); + } + + if (jwk.at("kty") != group.jwk_key_type) { + throw std::runtime_error("invalid JWK key type"); + } + + if (jwk.at("crv") != group.jwk_curve_name) { + throw std::runtime_error("invalid JWK curve"); + } + + return jwk; + } + + json export_jwk_json(const Group::PublicKey& pk) const + { + const auto [x, y] = group.coordinates(pk); + + json jwk = json::object({ + { "crv", group.jwk_curve_name }, + { "kty", group.jwk_key_type }, + }); + + if (group.jwk_key_type == "EC") { + jwk.emplace("x", to_base64url(x)); + jwk.emplace("y", to_base64url(y)); + } else if (group.jwk_key_type == "OKP") { + jwk.emplace("x", to_base64url(x)); + } else { + throw std::runtime_error("unknown key type"); + } + + return jwk; + } +}; + +template<> +const Signature& +Signature::get() +{ + static const auto instance = GroupSignature(Group::get()); + return instance; +} + +template<> +const Signature& +Signature::get() +{ + static const auto instance = GroupSignature(Group::get()); + return instance; +} + +template<> +const Signature& +Signature::get() +{ + static const auto instance = GroupSignature(Group::get()); + return instance; +} + +template<> +const Signature& +Signature::get() +{ + static const auto instance = GroupSignature(Group::get()); + return instance; +} + +#if !defined(WITH_BORINGSSL) +template<> +const Signature& +Signature::get() +{ + static const auto instance = GroupSignature(Group::get()); + return instance; +} +#endif + +template<> +const Signature& +Signature::get() +{ + static const auto instance = RSASignature(Digest::ID::SHA256); + return instance; +} + +template<> +const Signature& +Signature::get() +{ + static const auto instance = RSASignature(Digest::ID::SHA384); + return instance; +} + +template<> +const Signature& +Signature::get() +{ + static const auto instance = RSASignature(Digest::ID::SHA512); + return instance; +} + +Signature::Signature(Signature::ID id_in) + : id(id_in) +{ +} + +std::unique_ptr +Signature::generate_rsa(size_t bits) +{ + return RSASignature::generate_key_pair(bits); +} + +static const Signature& +sig_from_jwk(const std::string& jwk_json) +{ + using KeyTypeAndCurve = std::tuple; + static const auto alg_sig_map = std::map + { + { { "EC", "P-256" }, Signature::get() }, + { { "EC", "P-384" }, Signature::get() }, + { { "EC", "P-512" }, Signature::get() }, + { { "OKP", "Ed25519" }, Signature::get() }, +#if !defined(WITH_BORINGSSL) + { { "OKP", "Ed448" }, Signature::get() }, +#endif + // TODO(RLB): RSA + }; + + const auto jwk = json::parse(jwk_json); + const auto& kty = jwk.at("kty"); + + auto crv = std::string(""); + if (jwk.contains("crv")) { + crv = jwk.at("crv"); + } + + const auto key = KeyTypeAndCurve{ kty, crv }; + return alg_sig_map.at(key); +} + +Signature::PrivateJWK +Signature::parse_jwk_private(const std::string& jwk_json) +{ + // XXX(RLB): This JSON-parses the JWK twice. I'm assuming that this is a less + // bad cost than changing the import_jwk method signature to take `json`. + const auto& sig = sig_from_jwk(jwk_json); + const auto jwk = json::parse(jwk_json); + auto priv = sig.import_jwk_private(jwk_json); + + auto kid = std::optional{}; + if (jwk.contains("kid")) { + kid = jwk.at("kid").get(); + } + + return { sig, kid, std::move(priv) }; +} + +Signature::PublicJWK +Signature::parse_jwk(const std::string& jwk_json) +{ + // XXX(RLB): Same double-parsing comment as with `parse_jwk_private` + const auto& sig = sig_from_jwk(jwk_json); + const auto jwk = json::parse(jwk_json); + auto pub = sig.import_jwk(jwk_json); + + auto kid = std::optional{}; + if (jwk.contains("kid")) { + kid = jwk.at("kid").get(); + } + + return { sig, kid, std::move(pub) }; +} + +} // namespace mlspp::hpke diff --git a/mlspp/lib/hpke/src/userinfo_vc.cpp b/mlspp/lib/hpke/src/userinfo_vc.cpp new file mode 100755 index 0000000000..e211a821ca --- /dev/null +++ b/mlspp/lib/hpke/src/userinfo_vc.cpp @@ -0,0 +1,401 @@ +#include +#include +#include +#include "dpp/nlohmann/json.hpp" +#include + +using nlohmann::json; + +namespace mlspp::hpke { + +static const std::string name_attr = "name"; +static const std::string sub_attr = "sub"; +static const std::string given_name_attr = "given_name"; +static const std::string family_name_attr = "family_name"; +static const std::string middle_name_attr = "middle_name"; +static const std::string nickname_attr = "nickname"; +static const std::string preferred_username_attr = "preferred_username"; +static const std::string profile_attr = "profile"; +static const std::string picture_attr = "picture"; +static const std::string website_attr = "website"; +static const std::string email_attr = "email"; +static const std::string email_verified_attr = "email_verified"; +static const std::string gender_attr = "gender"; +static const std::string birthdate_attr = "birthdate"; +static const std::string zoneinfo_attr = "zoneinfo"; +static const std::string locale_attr = "locale"; +static const std::string phone_number_attr = "phone_number"; +static const std::string phone_number_verified_attr = "phone_number_verified"; +static const std::string address_attr = "address"; +static const std::string address_formatted_attr = "formatted"; +static const std::string address_street_address_attr = "street_address"; +static const std::string address_locality_attr = "locality"; +static const std::string address_region_attr = "region"; +static const std::string address_postal_code_attr = "postal_code"; +static const std::string address_country_attr = "country"; +static const std::string updated_at_attr = "updated_at"; + +template +static std::optional +get_optional(const json& json_object, const std::string& field_name) +{ + if (!json_object.contains(field_name)) { + return std::nullopt; + } + + return { json_object.at(field_name).get() }; +} + +/// +/// ParsedCredential +/// +static const Signature& +signature_from_alg(const std::string& alg) +{ + static const auto alg_sig_map = std::map + { + { "ES256", Signature::get() }, + { "ES384", Signature::get() }, + { "ES512", Signature::get() }, + { "Ed25519", Signature::get() }, +#if !defined(WITH_BORINGSSL) + { "Ed448", Signature::get() }, +#endif + { "RS256", Signature::get() }, + { "RS384", Signature::get() }, + { "RS512", Signature::get() }, + }; + + return alg_sig_map.at(alg); +} + +static std::chrono::system_clock::time_point +epoch_time(int64_t seconds_since_epoch) +{ + const auto delta = std::chrono::seconds(seconds_since_epoch); + return std::chrono::system_clock::time_point(delta); +} + +static bool +is_ecdsa(const Signature& sig) +{ + return sig.id == Signature::ID::P256_SHA256 || + sig.id == Signature::ID::P384_SHA384 || + sig.id == Signature::ID::P521_SHA512; +} + +// OpenSSL expects ECDSA signatures to be in DER form. JWS provides the +// signature in raw R||S form. So we need to do some manual DER encoding. +static bytes +jws_to_der_sig(const bytes& jws_sig) +{ + // Inputs that are too large will result in invalid DER encodings with this + // code. At this size, the combination of the DER integer headers and the + // integer data will overflow the one-byte DER struct length. + static const auto max_sig_size = size_t(250); + if (jws_sig.size() > max_sig_size) { + throw std::runtime_error("JWS signature too large"); + } + + if (jws_sig.size() % 2 != 0) { + throw std::runtime_error("Malformed JWS signature"); + } + + const auto int_size = jws_sig.size() / 2; + const auto jws_sig_cut = + jws_sig.begin() + static_cast(int_size); + + // Compute the encoded size of R and S integer data, adding a zero byte if + // needed to clear the sign bit + const auto r_big = (jws_sig.at(0) >= 0x80); + const auto s_big = (jws_sig.at(int_size) >= 0x80); + + const auto r_size = int_size + (r_big ? 1 : 0); + const auto s_size = int_size + (s_big ? 1 : 0); + + // Compute the size of the DER-encoded signature + static const auto int_header_size = 2; + const auto r_int_size = int_header_size + r_size; + const auto s_int_size = int_header_size + s_size; + + const auto content_size = r_int_size + s_int_size; + const auto content_big = (content_size > 0x80); + + auto der_header_size = 2 + (content_big ? 1 : 0); + const auto der_size = der_header_size + content_size; + + // Allocate the DER buffer + auto der = bytes(der_size, 0); + + // Write the header + der.at(0) = 0x30; + if (content_big) { + der.at(1) = 0x81; + der.at(2) = static_cast(content_size); + } else { + der.at(1) = static_cast(content_size); + } + + // Write R, virtually padding with a zero byte if needed + const auto r_start = der_header_size; + const auto r_data_start = r_start + int_header_size + (r_big ? 1 : 0); + const auto r_data_begin = + der.begin() + static_cast(r_data_start); + + der.at(r_start) = 0x02; + der.at(r_start + 1) = static_cast(r_size); + std::copy(jws_sig.begin(), jws_sig_cut, r_data_begin); + + // Write S, virtually padding with a zero byte if needed + const auto s_start = der_header_size + r_int_size; + const auto s_data_start = s_start + int_header_size + (s_big ? 1 : 0); + const auto s_data_begin = + der.begin() + static_cast(s_data_start); + + der.at(s_start) = 0x02; + der.at(s_start + 1) = static_cast(s_size); + std::copy(jws_sig_cut, jws_sig.end(), s_data_begin); + + return der; +} + +struct UserInfoVC::ParsedCredential +{ + // Header fields + const Signature& signature_algorithm; // `alg` + std::optional key_id; // `kid` + + // Top-level Payload fields + std::string issuer; // `iss` + std::chrono::system_clock::time_point not_before; // `nbf` + std::chrono::system_clock::time_point not_after; // `exp` + + // Credential subject fields + UserInfoClaims credential_subject; + Signature::PublicJWK public_key; + + // Signature verification information + bytes to_be_signed; + bytes signature; + + ParsedCredential(const Signature& signature_algorithm_in, + std::optional key_id_in, + std::string issuer_in, + std::chrono::system_clock::time_point not_before_in, + std::chrono::system_clock::time_point not_after_in, + UserInfoClaims credential_subject_in, + Signature::PublicJWK&& public_key_in, + bytes to_be_signed_in, + bytes signature_in) + : signature_algorithm(signature_algorithm_in) + , key_id(std::move(key_id_in)) + , issuer(std::move(issuer_in)) + , not_before(not_before_in) + , not_after(not_after_in) + , credential_subject(std::move(credential_subject_in)) + , public_key(std::move(public_key_in)) + , to_be_signed(std::move(to_be_signed_in)) + , signature(std::move(signature_in)) + { + } + + static std::shared_ptr parse(const std::string& jwt) + { + // Split the JWT into its header, payload, and signature + const auto first_dot = jwt.find_first_of('.'); + const auto last_dot = jwt.find_last_of('.'); + if (first_dot == std::string::npos || last_dot == std::string::npos || + first_dot == last_dot || last_dot > jwt.length() - 2) { + throw std::runtime_error("malformed JWT; not enough '.' characters"); + } + + const auto header_b64 = jwt.substr(0, first_dot); + const auto payload_b64 = + jwt.substr(first_dot + 1, last_dot - first_dot - 1); + const auto signature_b64 = jwt.substr(last_dot + 1); + + // Parse the components + const auto header = json::parse(to_ascii(from_base64url(header_b64))); + const auto payload = json::parse(to_ascii(from_base64url(payload_b64))); + + // Prepare the validation inputs + const auto hdr = header.at("alg"); + const auto& sig = signature_from_alg(hdr); + const auto to_be_signed = from_ascii(header_b64 + "." + payload_b64); + auto signature = from_base64url(signature_b64); + if (is_ecdsa(sig)) { + signature = jws_to_der_sig(signature); + } + + auto kid = std::optional{}; + if (header.contains("kid")) { + kid = header.at("kid").get(); + } + + // Verify the VC parts + const auto& vc = payload.at("vc"); + + static const auto context = + std::vector{ { "https://www.w3.org/2018/credentials/v1" } }; + const auto vc_context = vc.at("@context").get>(); + if (vc_context != context) { + throw std::runtime_error("malformed VC: incorrect context value"); + } + + static const auto type = std::vector{ + "VerifiableCredential", + "UserInfoCredential", + }; + if (vc.at("type") != type) { + throw std::runtime_error("malformed VC: incorrect type value"); + } + + // Parse the subject public key + static const std::string did_jwk_prefix = "did:jwk:"; + const auto id = vc.at("credentialSubject").at("id").get(); + if (id.find(did_jwk_prefix) != 0) { + throw std::runtime_error("malformed UserInfo VC: ID is not did:jwk"); + } + + const auto jwk = to_ascii(from_base64url(id.substr(did_jwk_prefix.size()))); + auto public_key = Signature::parse_jwk(jwk); + + // Extract the salient parts + return std::make_shared( + sig, + kid, + + payload.at("iss"), + epoch_time(payload.at("nbf").get()), + epoch_time(payload.at("exp").get()), + + UserInfoClaims::from_json(vc.at("credentialSubject").dump()), + std::move(public_key), + + to_be_signed, + signature); + } + + bool verify(const Signature::PublicKey& issuer_key) + { + return signature_algorithm.verify(to_be_signed, signature, issuer_key); + } +}; + +/// +/// UserInfoClaims +/// +UserInfoClaims +UserInfoClaims::from_json(const std::string& cred_subject) +{ + const auto& cred_subject_json = nlohmann::json::parse(cred_subject); + + std::optional address_opt = {}; + + if (cred_subject_json.contains(address_attr)) { + auto address_json = cred_subject_json.at(address_attr); + address_opt = { + get_optional(address_json, address_formatted_attr), + get_optional(address_json, address_street_address_attr), + get_optional(address_json, address_locality_attr), + get_optional(address_json, address_region_attr), + get_optional(address_json, address_postal_code_attr), + get_optional(address_json, address_country_attr) + }; + } + + return { + get_optional(cred_subject_json, sub_attr), + get_optional(cred_subject_json, name_attr), + get_optional(cred_subject_json, given_name_attr), + get_optional(cred_subject_json, family_name_attr), + get_optional(cred_subject_json, middle_name_attr), + get_optional(cred_subject_json, nickname_attr), + get_optional(cred_subject_json, preferred_username_attr), + get_optional(cred_subject_json, profile_attr), + get_optional(cred_subject_json, picture_attr), + get_optional(cred_subject_json, website_attr), + get_optional(cred_subject_json, email_attr), + get_optional(cred_subject_json, email_verified_attr), + get_optional(cred_subject_json, gender_attr), + get_optional(cred_subject_json, birthdate_attr), + get_optional(cred_subject_json, zoneinfo_attr), + get_optional(cred_subject_json, locale_attr), + get_optional(cred_subject_json, phone_number_attr), + get_optional(cred_subject_json, phone_number_verified_attr), + address_opt, + get_optional(cred_subject_json, updated_at_attr), + }; +} + +/// +/// UserInfoVC +/// + +UserInfoVC::UserInfoVC(std::string jwt) + : parsed_cred(ParsedCredential::parse(jwt)) + , raw(std::move(jwt)) +{ +} + +const Signature& +UserInfoVC::signature_algorithm() const +{ + return parsed_cred->signature_algorithm; +} + +std::string +UserInfoVC::issuer() const +{ + return parsed_cred->issuer; +} + +std::optional +UserInfoVC::key_id() const +{ + return parsed_cred->key_id; +} + +bool +UserInfoVC::valid_from(const Signature::PublicKey& issuer_key) const +{ + return parsed_cred->verify(issuer_key); +} + +const std::string& +UserInfoVC::raw_credential() const +{ + return raw; +} + +const UserInfoClaims& +UserInfoVC::subject() const +{ + return parsed_cred->credential_subject; +} + +std::chrono::system_clock::time_point +UserInfoVC::not_before() const +{ + return parsed_cred->not_before; +} + +std::chrono::system_clock::time_point +UserInfoVC::not_after() const +{ + return parsed_cred->not_after; +} + +const Signature::PublicJWK& +UserInfoVC::public_key() const +{ + return parsed_cred->public_key; +} + +bool +operator==(const UserInfoVC& lhs, const UserInfoVC& rhs) +{ + return lhs.raw == rhs.raw; +} + +} // namespace mlspp::hpke diff --git a/mlspp/lib/mls_vectors/CMakeLists.txt b/mlspp/lib/mls_vectors/CMakeLists.txt new file mode 100755 index 0000000000..e16f4994f1 --- /dev/null +++ b/mlspp/lib/mls_vectors/CMakeLists.txt @@ -0,0 +1,36 @@ +set(CURRENT_LIB_NAME mls_vectors) + +### +### Library Config +### + +file(GLOB_RECURSE LIB_HEADERS CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/include/*.h") +file(GLOB_RECURSE LIB_SOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp") + +add_library(${CURRENT_LIB_NAME} STATIC ${LIB_HEADERS} ${LIB_SOURCES}) +add_dependencies(${CURRENT_LIB_NAME} mlspp) +target_link_libraries(${CURRENT_LIB_NAME} mlspp bytes tls_syntax) +target_include_directories(${CURRENT_LIB_NAME} + PUBLIC + $ + $ + $ +) + +### +### Install +### + +install(TARGETS ${CURRENT_LIB_NAME} EXPORT mlspp-targets) +install( + DIRECTORY + include + DESTINATION + ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME} +) + +include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/bytes/include") +include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/hpke/include") +include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/mls_vectors/include") +include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/tls_syntax/include") + diff --git a/mlspp/lib/mls_vectors/include/mls_vectors/mls_vectors.h b/mlspp/lib/mls_vectors/include/mls_vectors/mls_vectors.h new file mode 100755 index 0000000000..25c148a94d --- /dev/null +++ b/mlspp/lib/mls_vectors/include/mls_vectors/mls_vectors.h @@ -0,0 +1,577 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mls_vectors { + +struct PseudoRandom +{ + struct Generator + { + Generator() = default; + Generator(mlspp::CipherSuite suite_in, const std::string& label); + Generator sub(const std::string& label) const; + + bytes secret(const std::string& label) const; + bytes generate(const std::string& label, size_t size) const; + + uint16_t uint16(const std::string& label) const; + uint32_t uint32(const std::string& label) const; + uint64_t uint64(const std::string& label) const; + + mlspp::SignaturePrivateKey signature_key( + const std::string& label) const; + mlspp::HPKEPrivateKey hpke_key(const std::string& label) const; + + size_t output_length() const; + + private: + mlspp::CipherSuite suite; + bytes seed; + + Generator(mlspp::CipherSuite suite_in, bytes seed_in); + }; + + PseudoRandom() = default; + PseudoRandom(mlspp::CipherSuite suite, const std::string& label); + + Generator prg; +}; + +struct TreeMathTestVector +{ + using OptionalNode = std::optional; + + mlspp::LeafCount n_leaves; + mlspp::NodeCount n_nodes; + mlspp::NodeIndex root; + std::vector left; + std::vector right; + std::vector parent; + std::vector sibling; + + std::optional null_if_invalid( + mlspp::NodeIndex input, + mlspp::NodeIndex answer) const; + + TreeMathTestVector() = default; + TreeMathTestVector(uint32_t n_leaves); + std::optional verify() const; +}; + +struct CryptoBasicsTestVector : PseudoRandom +{ + struct RefHash + { + std::string label; + bytes value; + bytes out; + + RefHash() = default; + RefHash(mlspp::CipherSuite suite, + const PseudoRandom::Generator& prg); + std::optional verify(mlspp::CipherSuite suite) const; + }; + + struct ExpandWithLabel + { + bytes secret; + std::string label; + bytes context; + uint16_t length; + bytes out; + + ExpandWithLabel() = default; + ExpandWithLabel(mlspp::CipherSuite suite, + const PseudoRandom::Generator& prg); + std::optional verify(mlspp::CipherSuite suite) const; + }; + + struct DeriveSecret + { + bytes secret; + std::string label; + bytes out; + + DeriveSecret() = default; + DeriveSecret(mlspp::CipherSuite suite, + const PseudoRandom::Generator& prg); + std::optional verify(mlspp::CipherSuite suite) const; + }; + + struct DeriveTreeSecret + { + bytes secret; + std::string label; + uint32_t generation; + uint16_t length; + bytes out; + + DeriveTreeSecret() = default; + DeriveTreeSecret(mlspp::CipherSuite suite, + const PseudoRandom::Generator& prg); + std::optional verify(mlspp::CipherSuite suite) const; + }; + + struct SignWithLabel + { + mlspp::SignaturePrivateKey priv; + mlspp::SignaturePublicKey pub; + bytes content; + std::string label; + bytes signature; + + SignWithLabel() = default; + SignWithLabel(mlspp::CipherSuite suite, + const PseudoRandom::Generator& prg); + std::optional verify(mlspp::CipherSuite suite) const; + }; + + struct EncryptWithLabel + { + mlspp::HPKEPrivateKey priv; + mlspp::HPKEPublicKey pub; + std::string label; + bytes context; + bytes plaintext; + bytes kem_output; + bytes ciphertext; + + EncryptWithLabel() = default; + EncryptWithLabel(mlspp::CipherSuite suite, + const PseudoRandom::Generator& prg); + std::optional verify(mlspp::CipherSuite suite) const; + }; + + mlspp::CipherSuite cipher_suite; + + RefHash ref_hash; + ExpandWithLabel expand_with_label; + DeriveSecret derive_secret; + DeriveTreeSecret derive_tree_secret; + SignWithLabel sign_with_label; + EncryptWithLabel encrypt_with_label; + + CryptoBasicsTestVector() = default; + CryptoBasicsTestVector(mlspp::CipherSuite suite); + std::optional verify() const; +}; + +struct SecretTreeTestVector : PseudoRandom +{ + struct SenderData + { + bytes sender_data_secret; + bytes ciphertext; + bytes key; + bytes nonce; + + SenderData() = default; + SenderData(mlspp::CipherSuite suite, + const PseudoRandom::Generator& prg); + std::optional verify(mlspp::CipherSuite suite) const; + }; + + struct RatchetStep + { + uint32_t generation; + bytes handshake_key; + bytes handshake_nonce; + bytes application_key; + bytes application_nonce; + }; + + mlspp::CipherSuite cipher_suite; + + SenderData sender_data; + + bytes encryption_secret; + std::vector> leaves; + + SecretTreeTestVector() = default; + SecretTreeTestVector(mlspp::CipherSuite suite, + uint32_t n_leaves, + const std::vector& generations); + std::optional verify() const; +}; + +struct KeyScheduleTestVector : PseudoRandom +{ + struct Export + { + std::string label; + bytes context; + size_t length; + bytes secret; + }; + + struct Epoch + { + // Chosen by the generator + bytes tree_hash; + bytes commit_secret; + bytes psk_secret; + bytes confirmed_transcript_hash; + + // Computed values + bytes group_context; + + bytes joiner_secret; + bytes welcome_secret; + bytes init_secret; + + bytes sender_data_secret; + bytes encryption_secret; + bytes exporter_secret; + bytes epoch_authenticator; + bytes external_secret; + bytes confirmation_key; + bytes membership_key; + bytes resumption_psk; + + mlspp::HPKEPublicKey external_pub; + Export exporter; + }; + + mlspp::CipherSuite cipher_suite; + + bytes group_id; + bytes initial_init_secret; + + std::vector epochs; + + KeyScheduleTestVector() = default; + KeyScheduleTestVector(mlspp::CipherSuite suite, uint32_t n_epochs); + std::optional verify() const; +}; + +struct MessageProtectionTestVector : PseudoRandom +{ + mlspp::CipherSuite cipher_suite; + + bytes group_id; + mlspp::epoch_t epoch; + bytes tree_hash; + bytes confirmed_transcript_hash; + + mlspp::SignaturePrivateKey signature_priv; + mlspp::SignaturePublicKey signature_pub; + + bytes encryption_secret; + bytes sender_data_secret; + bytes membership_key; + + mlspp::Proposal proposal; + mlspp::MLSMessage proposal_pub; + mlspp::MLSMessage proposal_priv; + + mlspp::Commit commit; + mlspp::MLSMessage commit_pub; + mlspp::MLSMessage commit_priv; + + bytes application; + mlspp::MLSMessage application_priv; + + MessageProtectionTestVector() = default; + MessageProtectionTestVector(mlspp::CipherSuite suite); + std::optional verify(); + +private: + mlspp::GroupKeySource group_keys() const; + mlspp::GroupContext group_context() const; + + mlspp::MLSMessage protect_pub( + const mlspp::GroupContent::RawContent& raw_content) const; + mlspp::MLSMessage protect_priv( + const mlspp::GroupContent::RawContent& raw_content); + std::optional unprotect( + const mlspp::MLSMessage& message); +}; + +struct PSKSecretTestVector : PseudoRandom +{ + struct PSK + { + bytes psk_id; + bytes psk_nonce; + bytes psk; + }; + + mlspp::CipherSuite cipher_suite; + std::vector psks; + bytes psk_secret; + + PSKSecretTestVector() = default; + PSKSecretTestVector(mlspp::CipherSuite suite, size_t n_psks); + std::optional verify() const; +}; + +struct TranscriptTestVector : PseudoRandom +{ + mlspp::CipherSuite cipher_suite; + + bytes confirmation_key; + bytes interim_transcript_hash_before; + + mlspp::AuthenticatedContent authenticated_content; + + bytes confirmed_transcript_hash_after; + bytes interim_transcript_hash_after; + + TranscriptTestVector() = default; + TranscriptTestVector(mlspp::CipherSuite suite); + std::optional verify() const; +}; + +struct WelcomeTestVector : PseudoRandom +{ + mlspp::CipherSuite cipher_suite; + + mlspp::HPKEPrivateKey init_priv; + mlspp::SignaturePublicKey signer_pub; + + mlspp::MLSMessage key_package; + mlspp::MLSMessage welcome; + + WelcomeTestVector() = default; + WelcomeTestVector(mlspp::CipherSuite suite); + std::optional verify() const; +}; + +// XXX(RLB): The |structure| of the example trees below is to avoid compile +// errors from gcc's -Werror=comment when there is a '\' character at the end of +// a line. Inspired by a similar bug in Chromium: +// https://codereview.chromium.org/874663003/patch/1/10001 +enum struct TreeStructure +{ + // Full trees on N leaves, created by member k adding member k+1 + full_tree_2, + full_tree_3, + full_tree_4, + full_tree_5, + full_tree_6, + full_tree_7, + full_tree_8, + full_tree_32, + full_tree_33, + full_tree_34, + + // | W | + // | ______|______ | + // | / \ | + // | U Y | + // | __|__ __|__ | + // | / \ / \ | + // | T _ X Z | + // | / \ / \ / \ / \ | + // | A B C _ E F G H | + // + // * Start with full tree on 8 members + // * 0 commits removeing 2 and 3, and adding a new member + internal_blanks_no_skipping, + + // | W | + // | ______|______ | + // | / \ | + // | _ Y | + // | __|__ __|__ | + // | / \ / \ | + // | _ _ X Z | + // | / \ / \ / \ / \ | + // | A _ _ _ E F G H | + // + // * Start with full tree on 8 members + // * 0 commitsremoveing 1, 2, and 3 + internal_blanks_with_skipping, + + // | W[H] | + // | ______|______ | + // | / \ | + // | U Y[H] | + // | __|__ __|__ | + // | / \ / \ | + // | T V X _ | + // | / \ / \ / \ / \ | + // | A B C D E F G H | + // + // * Start with full tree on 7 members + // * 0 commits adding a member in a partial Commit (no path) + unmerged_leaves_no_skipping, + + // | W [F] | + // | ______|______ | + // | / \ | + // | U Y [F] | + // | __|__ __|__ | + // | / \ / \ | + // | T _ _ _ | + // | / \ / \ / \ / \ | + // | A B C D E F G _ | + // + // == Fig. 20 / {{parent-hash-tree}} + // * 0 creates group + // * 0 adds 1, ..., 6 in a partial Commit + // * O commits removing 5 + // * 4 commits without any proposals + // * 0 commits adding a new member in a partial Commit + unmerged_leaves_with_skipping, +}; + +extern std::array all_tree_structures; +extern std::array treekem_test_tree_structures; + +struct TreeHashTestVector : PseudoRandom +{ + mlspp::CipherSuite cipher_suite; + bytes group_id; + + mlspp::TreeKEMPublicKey tree; + std::vector tree_hashes; + std::vector> resolutions; + + TreeHashTestVector() = default; + TreeHashTestVector(mlspp::CipherSuite suite, + TreeStructure tree_structure); + std::optional verify(); +}; + +struct TreeOperationsTestVector : PseudoRandom +{ + enum struct Scenario + { + add_right_edge, + add_internal, + update, + remove_right_edge, + remove_internal, + }; + + static const std::vector all_scenarios; + + mlspp::CipherSuite cipher_suite; + + mlspp::TreeKEMPublicKey tree_before; + bytes tree_hash_before; + + mlspp::Proposal proposal; + mlspp::LeafIndex proposal_sender; + + mlspp::TreeKEMPublicKey tree_after; + bytes tree_hash_after; + + TreeOperationsTestVector() = default; + TreeOperationsTestVector(mlspp::CipherSuite suite, Scenario scenario); + std::optional verify(); +}; + +struct TreeKEMTestVector : PseudoRandom +{ + struct PathSecret + { + mlspp::NodeIndex node; + bytes path_secret; + }; + + struct LeafPrivateInfo + { + mlspp::LeafIndex index; + mlspp::HPKEPrivateKey encryption_priv; + mlspp::SignaturePrivateKey signature_priv; + std::vector path_secrets; + }; + + struct UpdatePathInfo + { + mlspp::LeafIndex sender; + mlspp::UpdatePath update_path; + std::vector> path_secrets; + bytes commit_secret; + bytes tree_hash_after; + }; + + mlspp::CipherSuite cipher_suite; + + bytes group_id; + mlspp::epoch_t epoch; + bytes confirmed_transcript_hash; + + mlspp::TreeKEMPublicKey ratchet_tree; + + std::vector leaves_private; + std::vector update_paths; + + TreeKEMTestVector() = default; + TreeKEMTestVector(mlspp::CipherSuite suite, + TreeStructure tree_structure); + std::optional verify(); +}; + +struct MessagesTestVector : PseudoRandom +{ + bytes mls_welcome; + bytes mls_group_info; + bytes mls_key_package; + + bytes ratchet_tree; + bytes group_secrets; + + bytes add_proposal; + bytes update_proposal; + bytes remove_proposal; + bytes pre_shared_key_proposal; + bytes re_init_proposal; + bytes external_init_proposal; + bytes group_context_extensions_proposal; + + bytes commit; + + bytes public_message_proposal; + bytes public_message_commit; + bytes private_message; + + MessagesTestVector(); + std::optional verify() const; +}; + +struct PassiveClientTestVector : PseudoRandom +{ + struct PSK + { + bytes psk_id; + bytes psk; + }; + + struct Epoch + { + std::vector proposals; + mlspp::MLSMessage commit; + bytes epoch_authenticator; + }; + + mlspp::CipherSuite cipher_suite; + + mlspp::MLSMessage key_package; + mlspp::SignaturePrivateKey signature_priv; + mlspp::HPKEPrivateKey encryption_priv; + mlspp::HPKEPrivateKey init_priv; + + std::vector external_psks; + + mlspp::MLSMessage welcome; + std::optional ratchet_tree; + bytes initial_epoch_authenticator; + + std::vector epochs; + + PassiveClientTestVector() = default; + std::optional verify(); +}; + +} // namespace mls_vectors diff --git a/mlspp/lib/mls_vectors/src/mls_vectors.cpp b/mlspp/lib/mls_vectors/src/mls_vectors.cpp new file mode 100755 index 0000000000..a6176efa55 --- /dev/null +++ b/mlspp/lib/mls_vectors/src/mls_vectors.cpp @@ -0,0 +1,2052 @@ +#include +#include +#include +#include + +#include // XXX + +namespace mls_vectors { + +using namespace mlspp; + +/// +/// Assertions for verifying test vectors +/// + +template, int> = 0> +std::ostream& +operator<<(std::ostream& str, const T& obj) +{ + auto u = static_cast>(obj); + return str << u; +} + +static std::ostream& +operator<<(std::ostream& str, const NodeIndex& obj) +{ + return str << obj.val; +} + +static std::ostream& +operator<<(std::ostream& str, const NodeCount& obj) +{ + return str << obj.val; +} + +template +static std::ostream& +operator<<(std::ostream& str, const std::optional& obj) +{ + if (!obj) { + return str << "(nullopt)"; + } + + return str << opt::get(obj); +} + +static std::ostream& +operator<<(std::ostream& str, const std::vector& obj) +{ + return str << to_hex(obj); +} + +template +static std::ostream& +operator<<(std::ostream& str, const std::vector& obj) +{ + for (const auto& val : obj) { + str << val << " "; + } + return str; +} + +static std::ostream& +operator<<(std::ostream& str, const GroupContent::RawContent& obj) +{ + return var::visit( + overloaded{ + [&](const Proposal&) -> std::ostream& { return str << "[Proposal]"; }, + [&](const Commit&) -> std::ostream& { return str << "[Commit]"; }, + [&](const ApplicationData&) -> std::ostream& { + return str << "[ApplicationData]"; + }, + }, + obj); +} + +template +inline std::enable_if_t +operator<<(std::ostream& str, const T& obj) +{ + return str << to_hex(tls::marshal(obj)); +} + +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define VERIFY(label, test) \ + if (auto err = verify_bool(label, test)) { \ + return err; \ + } + +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define VERIFY_EQUAL(label, actual, expected) \ + if (auto err = verify_equal(label, actual, expected)) { \ + return err; \ + } + +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define VERIFY_TLS_RTT(label, Type, expected) \ + if (auto err = verify_round_trip(label, expected)) { \ + return err; \ + } + +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define VERIFY_TLS_RTT_VAL(label, Type, expected, val) \ + if (auto err = verify_round_trip(label, expected, val)) { \ + return err; \ + } + +template +static std::optional +verify_bool(const std::string& label, const T& test) +{ + if (test) { + return std::nullopt; + } + + return label; +} + +template +static std::optional +verify_equal(const std::string& label, const T& actual, const U& expected) +{ + if (actual == expected) { + return std::nullopt; + } + + auto ss = std::stringstream(); + ss << "Error: " << label << " " << actual << " != " << expected; + return ss.str(); +} + +template +static std::optional +verify_round_trip(const std::string& label, const bytes& expected) +{ + auto noop = [](const auto& /* unused */) { return true; }; + return verify_round_trip(label, expected, noop); +} + +template +static std::optional +verify_round_trip(const std::string& label, const bytes& expected, const F& val) +{ + auto obj = T{}; + try { + obj = tls::get(expected); + } catch (const std::exception& e) { + auto ss = std::stringstream(); + ss << "Decode error: " << label << " " << e.what(); + return ss.str(); + } + + if (!val(obj)) { + auto ss = std::stringstream(); + ss << "Validation error: " << label; + return ss.str(); + } + + auto actual = tls::marshal(obj); + VERIFY_EQUAL(label, actual, expected); + return std::nullopt; +} + +/// +/// PseudoRandom +/// + +PseudoRandom::Generator::Generator(CipherSuite suite_in, + const std::string& label) + : suite(suite_in) + , seed(suite.hpke().kdf.extract({}, from_ascii(label))) +{ +} + +PseudoRandom::Generator::Generator(CipherSuite suite_in, bytes seed_in) + : suite(suite_in) + , seed(std::move(seed_in)) +{ +} + +PseudoRandom::Generator +PseudoRandom::Generator::sub(const std::string& label) const +{ + return { suite, suite.derive_secret(seed, label) }; +} + +bytes +PseudoRandom::Generator::secret(const std::string& label) const +{ + return suite.derive_secret(seed, label); +} + +bytes +PseudoRandom::Generator::generate(const std::string& label, size_t size) const +{ + return suite.expand_with_label(seed, label, {}, size); +} + +uint16_t +PseudoRandom::Generator::uint16(const std::string& label) const +{ + auto data = generate(label, 2); + return tls::get(data); +} + +uint32_t +PseudoRandom::Generator::uint32(const std::string& label) const +{ + auto data = generate(label, 4); + return tls::get(data); +} + +uint64_t +PseudoRandom::Generator::uint64(const std::string& label) const +{ + auto data = generate(label, 8); + return tls::get(data); +} + +SignaturePrivateKey +PseudoRandom::Generator::signature_key(const std::string& label) const +{ + auto data = generate(label, suite.secret_size()); + return SignaturePrivateKey::derive(suite, data); +} + +HPKEPrivateKey +PseudoRandom::Generator::hpke_key(const std::string& label) const +{ + auto data = generate(label, suite.secret_size()); + return HPKEPrivateKey::derive(suite, data); +} + +size_t +PseudoRandom::Generator::output_length() const +{ + return suite.secret_size(); +} + +PseudoRandom::PseudoRandom(CipherSuite suite, const std::string& label) + : prg(suite, label) +{ +} + +/// +/// TreeMathTestVector +/// + +// XXX(RLB): This is a hack to get the tests working in the right format. In +// reality, the tree math functions should be updated to be fallible. +std::optional +TreeMathTestVector::null_if_invalid(NodeIndex input, NodeIndex answer) const +{ + // For some invalid cases (e.g., leaf.left()), we currently return the node + // itself instead of null + if (input == answer) { + return std::nullopt; + } + + // NodeIndex::parent is irrespective of tree size, so we might step out of the + // tree under consideration. + if (answer.val >= n_nodes.val) { + return std::nullopt; + } + + return answer; +} + +TreeMathTestVector::TreeMathTestVector(uint32_t n_leaves_in) + : n_leaves(n_leaves_in) + , n_nodes(n_leaves) + , root(NodeIndex::root(n_leaves)) + , left(n_nodes.val) + , right(n_nodes.val) + , parent(n_nodes.val) + , sibling(n_nodes.val) +{ + for (NodeIndex x{ 0 }; x.val < n_nodes.val; x.val++) { + left[x.val] = null_if_invalid(x, x.left()); + right[x.val] = null_if_invalid(x, x.right()); + parent[x.val] = null_if_invalid(x, x.parent()); + sibling[x.val] = null_if_invalid(x, x.sibling()); + } +} + +std::optional +TreeMathTestVector::verify() const +{ + VERIFY_EQUAL("n_nodes", n_nodes, NodeCount(n_leaves)); + VERIFY_EQUAL("root", root, NodeIndex::root(n_leaves)); + + for (NodeIndex x{ 0 }; x.val < n_nodes.val; x.val++) { + VERIFY_EQUAL("left", null_if_invalid(x, x.left()), left[x.val]); + VERIFY_EQUAL("right", null_if_invalid(x, x.right()), right[x.val]); + VERIFY_EQUAL("parent", null_if_invalid(x, x.parent()), parent[x.val]); + VERIFY_EQUAL("sibling", null_if_invalid(x, x.sibling()), sibling[x.val]); + } + + return std::nullopt; +} + +/// +/// TreeMathTestVector +/// + +CryptoBasicsTestVector::RefHash::RefHash(CipherSuite suite, + const PseudoRandom::Generator& prg) + : label("RefHash") + , value(prg.secret("value")) + , out(suite.raw_ref(from_ascii(label), value)) +{ +} + +std::optional +CryptoBasicsTestVector::RefHash::verify(CipherSuite suite) const +{ + VERIFY_EQUAL("ref hash", out, suite.raw_ref(from_ascii(label), value)); + return std::nullopt; +} + +CryptoBasicsTestVector::ExpandWithLabel::ExpandWithLabel( + CipherSuite suite, + const PseudoRandom::Generator& prg) + : secret(prg.secret("secret")) + , label("ExpandWithLabel") + , context(prg.secret("context")) + , length(static_cast(prg.output_length())) + , out(suite.expand_with_label(secret, label, context, length)) +{ +} + +std::optional +CryptoBasicsTestVector::ExpandWithLabel::verify(CipherSuite suite) const +{ + VERIFY_EQUAL("expand with label", + out, + suite.expand_with_label(secret, label, context, length)); + return std::nullopt; +} + +CryptoBasicsTestVector::DeriveSecret::DeriveSecret( + CipherSuite suite, + const PseudoRandom::Generator& prg) + : secret(prg.secret("secret")) + , label("DeriveSecret") + , out(suite.derive_secret(secret, label)) +{ +} + +std::optional +CryptoBasicsTestVector::DeriveSecret::verify(CipherSuite suite) const +{ + VERIFY_EQUAL("derive secret", out, suite.derive_secret(secret, label)); + return std::nullopt; +} + +CryptoBasicsTestVector::DeriveTreeSecret::DeriveTreeSecret( + CipherSuite suite, + const PseudoRandom::Generator& prg) + : secret(prg.secret("secret")) + , label("DeriveTreeSecret") + , generation(prg.uint32("generation")) + , length(static_cast(prg.output_length())) + , out(suite.derive_tree_secret(secret, label, generation, length)) +{ +} + +std::optional +CryptoBasicsTestVector::DeriveTreeSecret::verify(CipherSuite suite) const +{ + VERIFY_EQUAL("derive tree secret", + out, + suite.derive_tree_secret(secret, label, generation, length)); + return std::nullopt; +} + +CryptoBasicsTestVector::SignWithLabel::SignWithLabel( + CipherSuite suite, + const PseudoRandom::Generator& prg) + : priv(prg.signature_key("priv")) + , pub(priv.public_key) + , content(prg.secret("content")) + , label("SignWithLabel") + , signature(priv.sign(suite, label, content)) +{ +} + +std::optional +CryptoBasicsTestVector::SignWithLabel::verify(CipherSuite suite) const +{ + VERIFY("verify with label", pub.verify(suite, label, content, signature)); + + auto new_signature = priv.sign(suite, label, content); + VERIFY("sign with label", pub.verify(suite, label, content, new_signature)); + + return std::nullopt; +} + +CryptoBasicsTestVector::EncryptWithLabel::EncryptWithLabel( + CipherSuite suite, + const PseudoRandom::Generator& prg) + : priv(prg.hpke_key("priv")) + , pub(priv.public_key) + , label("EncryptWithLabel") + , context(prg.secret("context")) + , plaintext(prg.secret("plaintext")) +{ + auto ct = pub.encrypt(suite, label, context, plaintext); + kem_output = ct.kem_output; + ciphertext = ct.ciphertext; +} + +std::optional +CryptoBasicsTestVector::EncryptWithLabel::verify(CipherSuite suite) const +{ + auto ct = HPKECiphertext{ kem_output, ciphertext }; + auto pt = priv.decrypt(suite, label, context, ct); + VERIFY_EQUAL("decrypt with label", pt, plaintext); + + auto new_ct = pub.encrypt(suite, label, context, plaintext); + auto new_pt = priv.decrypt(suite, label, context, new_ct); + VERIFY_EQUAL("encrypt with label", new_pt, plaintext); + + return std::nullopt; +} + +CryptoBasicsTestVector::CryptoBasicsTestVector(CipherSuite suite) + : PseudoRandom(suite, "crypto-basics") + , cipher_suite(suite) + , ref_hash(suite, prg.sub("ref_hash")) + , expand_with_label(suite, prg.sub("expand_with_label")) + , derive_secret(suite, prg.sub("derive_secret")) + , derive_tree_secret(suite, prg.sub("derive_tree_secret")) + , sign_with_label(suite, prg.sub("sign_with_label")) + , encrypt_with_label(suite, prg.sub("encrypt_with_label")) +{ +} + +std::optional +CryptoBasicsTestVector::verify() const +{ + auto result = ref_hash.verify(cipher_suite); + if (result) { + return result; + } + + result = expand_with_label.verify(cipher_suite); + if (result) { + return result; + } + + result = derive_secret.verify(cipher_suite); + if (result) { + return result; + } + + result = derive_tree_secret.verify(cipher_suite); + if (result) { + return result; + } + + result = sign_with_label.verify(cipher_suite); + if (result) { + return result; + } + + result = encrypt_with_label.verify(cipher_suite); + if (result) { + return result; + } + + return std::nullopt; +} + +/// +/// SecretTreeTestVector +/// + +SecretTreeTestVector::SenderData::SenderData(mlspp::CipherSuite suite, + const PseudoRandom::Generator& prg) + : sender_data_secret(prg.secret("sender_data_secret")) + , ciphertext(prg.secret("ciphertext")) +{ + auto key_and_nonce = + KeyScheduleEpoch::sender_data_keys(suite, sender_data_secret, ciphertext); + key = key_and_nonce.key; + nonce = key_and_nonce.nonce; +} + +std::optional +SecretTreeTestVector::SenderData::verify(mlspp::CipherSuite suite) const +{ + auto key_and_nonce = + KeyScheduleEpoch::sender_data_keys(suite, sender_data_secret, ciphertext); + VERIFY_EQUAL("sender data key", key, key_and_nonce.key); + VERIFY_EQUAL("sender data nonce", nonce, key_and_nonce.nonce); + return std::nullopt; +} + +SecretTreeTestVector::SecretTreeTestVector( + mlspp::CipherSuite suite, + uint32_t n_leaves, + const std::vector& generations) + : PseudoRandom(suite, "secret-tree") + , cipher_suite(suite) + , sender_data(suite, prg.sub("sender_data")) + , encryption_secret(prg.secret("encryption_secret")) +{ + auto src = + GroupKeySource(cipher_suite, LeafCount{ n_leaves }, encryption_secret); + leaves.resize(n_leaves); + auto zero_reuse_guard = ReuseGuard{ 0, 0, 0, 0 }; + for (uint32_t i = 0; i < n_leaves; i++) { + auto leaf = LeafIndex{ i }; + + for (const auto generation : generations) { + auto hs = + src.get(ContentType::proposal, leaf, generation, zero_reuse_guard); + auto app = + src.get(ContentType::application, leaf, generation, zero_reuse_guard); + + leaves.at(i).push_back( + RatchetStep{ generation, hs.key, hs.nonce, app.key, app.nonce }); + + src.erase(ContentType::proposal, leaf, generation); + src.erase(ContentType::application, leaf, generation); + } + } +} + +std::optional +SecretTreeTestVector::verify() const +{ + auto sender_data_error = sender_data.verify(cipher_suite); + if (sender_data_error) { + return sender_data_error; + } + + auto n_leaves = static_cast(leaves.size()); + auto src = + GroupKeySource(cipher_suite, LeafCount{ n_leaves }, encryption_secret); + auto zero_reuse_guard = ReuseGuard{ 0, 0, 0, 0 }; + for (uint32_t i = 0; i < n_leaves; i++) { + auto leaf = LeafIndex{ i }; + + for (const auto& step : leaves[i]) { + auto generation = step.generation; + + auto hs = + src.get(ContentType::proposal, leaf, generation, zero_reuse_guard); + VERIFY_EQUAL("hs key", hs.key, step.handshake_key); + VERIFY_EQUAL("hs nonce", hs.nonce, step.handshake_nonce); + + auto app = + src.get(ContentType::application, leaf, generation, zero_reuse_guard); + VERIFY_EQUAL("app key", app.key, step.application_key); + VERIFY_EQUAL("app nonce", app.nonce, step.application_nonce); + } + } + + return std::nullopt; +} + +/// +/// KeyScheduleTestVector +/// + +KeyScheduleTestVector::KeyScheduleTestVector(CipherSuite suite, + uint32_t n_epochs) + : PseudoRandom(suite, "key-schedule") + , cipher_suite(suite) + , group_id(prg.secret("group_id")) + , initial_init_secret(prg.secret("group_id")) +{ + auto group_context = GroupContext{ suite, group_id, 0, {}, {}, {} }; + auto epoch = KeyScheduleEpoch(cipher_suite); + epoch.init_secret = initial_init_secret; + + for (uint64_t i = 0; i < n_epochs; i++) { + auto epoch_prg = prg.sub(to_hex(tls::marshal(i))); + + group_context.tree_hash = epoch_prg.secret("tree_hash"); + group_context.confirmed_transcript_hash = + epoch_prg.secret("confirmed_transcript_hash"); + auto ctx = tls::marshal(group_context); + + // TODO(RLB) Add Test case for externally-driven epoch change + auto commit_secret = epoch_prg.secret("commit_secret"); + auto psk_secret = epoch_prg.secret("psk_secret"); + epoch = epoch.next_raw(commit_secret, psk_secret, std::nullopt, ctx); + + auto welcome_secret = KeyScheduleEpoch::welcome_secret_raw( + cipher_suite, epoch.joiner_secret, psk_secret); + + auto exporter_prg = epoch_prg.sub("exporter"); + auto exporter_label = to_hex(exporter_prg.secret("label")); + auto exporter_context = exporter_prg.secret("context"); + auto exporter_length = cipher_suite.secret_size(); + auto exported = + epoch.do_export(exporter_label, exporter_context, exporter_length); + + epochs.push_back({ group_context.tree_hash, + commit_secret, + psk_secret, + group_context.confirmed_transcript_hash, + + ctx, + + epoch.joiner_secret, + welcome_secret, + epoch.init_secret, + + epoch.sender_data_secret, + epoch.encryption_secret, + epoch.exporter_secret, + epoch.epoch_authenticator, + epoch.external_secret, + epoch.confirmation_key, + epoch.membership_key, + epoch.resumption_psk, + + epoch.external_priv.public_key, + + { + exporter_label, + exporter_context, + exporter_length, + exported, + } }); + + group_context.epoch += 1; + } +} + +std::optional +KeyScheduleTestVector::verify() const +{ + auto group_context = GroupContext{ cipher_suite, group_id, 0, {}, {}, {} }; + auto epoch = KeyScheduleEpoch(cipher_suite); + epoch.init_secret = initial_init_secret; + + for (const auto& tve : epochs) { + group_context.tree_hash = tve.tree_hash; + group_context.confirmed_transcript_hash = tve.confirmed_transcript_hash; + auto ctx = tls::marshal(group_context); + VERIFY_EQUAL("group context", ctx, tve.group_context); + + epoch = + epoch.next_raw(tve.commit_secret, tve.psk_secret, std::nullopt, ctx); + + // Verify the rest of the epoch + VERIFY_EQUAL("joiner secret", epoch.joiner_secret, tve.joiner_secret); + + auto welcome_secret = KeyScheduleEpoch::welcome_secret_raw( + cipher_suite, tve.joiner_secret, tve.psk_secret); + VERIFY_EQUAL("welcome secret", welcome_secret, tve.welcome_secret); + + VERIFY_EQUAL( + "sender data secret", epoch.sender_data_secret, tve.sender_data_secret); + VERIFY_EQUAL( + "encryption secret", epoch.encryption_secret, tve.encryption_secret); + VERIFY_EQUAL("exporter secret", epoch.exporter_secret, tve.exporter_secret); + VERIFY_EQUAL("epoch authenticator", + epoch.epoch_authenticator, + tve.epoch_authenticator); + VERIFY_EQUAL("external secret", epoch.external_secret, tve.external_secret); + VERIFY_EQUAL( + "confirmation key", epoch.confirmation_key, tve.confirmation_key); + VERIFY_EQUAL("membership key", epoch.membership_key, tve.membership_key); + VERIFY_EQUAL("resumption psk", epoch.resumption_psk, tve.resumption_psk); + VERIFY_EQUAL("init secret", epoch.init_secret, tve.init_secret); + + VERIFY_EQUAL( + "external pub", epoch.external_priv.public_key, tve.external_pub); + + auto exported = epoch.do_export( + tve.exporter.label, tve.exporter.context, tve.exporter.length); + VERIFY_EQUAL("exported", exported, tve.exporter.secret); + + group_context.epoch += 1; + } + + return std::nullopt; +} + +/// +/// MessageProtectionTestVector +/// + +MessageProtectionTestVector::MessageProtectionTestVector(CipherSuite suite) + : PseudoRandom(suite, "message-protection") + , cipher_suite(suite) + , group_id(prg.secret("group_id")) + , epoch(prg.uint64("epoch")) + , tree_hash(prg.secret("tree_hash")) + , confirmed_transcript_hash(prg.secret("confirmed_transcript_hash")) + , signature_priv(prg.signature_key("signature_priv")) + , signature_pub(signature_priv.public_key) + , encryption_secret(prg.secret("encryption_secret")) + , sender_data_secret(prg.secret("sender_data_secret")) + , membership_key(prg.secret("membership_key")) + , proposal{ GroupContextExtensions{} } + , commit{ /* XXX(RLB) this is technically invalid, empty w/o path */ } + , application{ prg.secret("application") } +{ + proposal_pub = protect_pub(proposal); + proposal_priv = protect_priv(proposal); + + commit_pub = protect_pub(commit); + commit_priv = protect_priv(commit); + + application_priv = protect_priv(ApplicationData{ application }); +} + +std::optional +MessageProtectionTestVector::verify() +{ + // Initialize fields that don't get set from JSON + prg = PseudoRandom::Generator(cipher_suite, "message-protection"); + signature_priv.set_public_key(cipher_suite); + + // Sanity check the key pairs + VERIFY_EQUAL("sig kp", signature_priv.public_key, signature_pub); + + // Verify proposal unprotect as PublicMessage + auto proposal_pub_unprotected = unprotect(proposal_pub); + VERIFY("proposal pub unprotect auth", proposal_pub_unprotected); + VERIFY_EQUAL("proposal pub unprotect", + opt::get(proposal_pub_unprotected).content, + proposal); + + // Verify proposal unprotect as PrivateMessage + auto proposal_priv_unprotected = unprotect(proposal_priv); + VERIFY("proposal priv unprotect auth", proposal_priv_unprotected); + VERIFY_EQUAL("proposal priv unprotect", + opt::get(proposal_priv_unprotected).content, + proposal); + + // Verify commit unprotect as PublicMessage + auto commit_pub_unprotected = unprotect(commit_pub); + VERIFY("commit pub unprotect auth", commit_pub_unprotected); + VERIFY_EQUAL( + "commit pub unprotect", opt::get(commit_pub_unprotected).content, commit); + + // Verify commit unprotect as PrivateMessage + auto commit_priv_unprotected = unprotect(commit_priv); + VERIFY("commit priv unprotect auth", commit_priv_unprotected); + VERIFY_EQUAL( + "commit priv unprotect", opt::get(commit_priv_unprotected).content, commit); + + // Verify application data unprotect as PrivateMessage + auto app_unprotected = unprotect(application_priv); + VERIFY("app priv unprotect auth", app_unprotected); + VERIFY_EQUAL("app priv unprotect", + opt::get(app_unprotected).content, + ApplicationData{ application }); + + // Verify protect/unprotect round-trips + // XXX(RLB): Note that because (a) unprotect() deletes keys from the ratchet + // and (b) we are using the same ratchet to send and receive, we need to do + // these round-trip tests after all the unprotect tests are done. Otherwise + // the protect() calls here will re-use generations used the test vector, and + // then unprotect() will delete the keys, then when you go to decrypt the test + // vector object, you'll get "expired key". It might be good to have better + // safeguards around such reuse. + auto proposal_pub_protected = protect_pub(proposal); + auto proposal_pub_protected_unprotected = unprotect(proposal_pub_protected); + VERIFY("proposal pub protect/unprotect auth", + proposal_pub_protected_unprotected); + VERIFY_EQUAL("proposal pub protect/unprotect", + opt::get(proposal_pub_protected_unprotected).content, + proposal); + + auto proposal_priv_protected = protect_priv(proposal); + auto proposal_priv_protected_unprotected = unprotect(proposal_priv_protected); + VERIFY("proposal priv protect/unprotect auth", + proposal_priv_protected_unprotected); + VERIFY_EQUAL("proposal priv protect/unprotect", + opt::get(proposal_priv_protected_unprotected).content, + proposal); + + auto commit_pub_protected = protect_pub(commit); + auto commit_pub_protected_unprotected = unprotect(commit_pub_protected); + VERIFY("commit pub protect/unprotect auth", commit_pub_protected_unprotected); + VERIFY_EQUAL("commit pub protect/unprotect", + opt::get(commit_pub_protected_unprotected).content, + commit); + + auto commit_priv_protected = protect_priv(commit); + auto commit_priv_protected_unprotected = unprotect(commit_priv_protected); + VERIFY("commit priv protect/unprotect auth", + commit_priv_protected_unprotected); + VERIFY_EQUAL("commit priv protect/unprotect", + opt::get(commit_priv_protected_unprotected).content, + commit); + + auto app_protected = protect_priv(ApplicationData{ application }); + auto app_protected_unprotected = unprotect(app_protected); + VERIFY("app priv protect/unprotect auth", app_protected_unprotected); + VERIFY_EQUAL("app priv protect/unprotect", + opt::get(app_protected_unprotected).content, + ApplicationData{ application }); + + return std::nullopt; +} + +GroupKeySource +MessageProtectionTestVector::group_keys() const +{ + return { cipher_suite, LeafCount{ 2 }, encryption_secret }; +} + +GroupContext +MessageProtectionTestVector::group_context() const +{ + return GroupContext{ + cipher_suite, group_id, epoch, tree_hash, confirmed_transcript_hash, {} + }; +} + +MLSMessage +MessageProtectionTestVector::protect_pub( + const mlspp::GroupContent::RawContent& raw_content) const +{ + auto sender = Sender{ MemberSender{ LeafIndex{ 1 } } }; + auto authenticated_data = bytes{}; + + auto content = + GroupContent{ group_id, epoch, sender, authenticated_data, raw_content }; + + auto auth_content = AuthenticatedContent::sign(WireFormat::mls_public_message, + content, + cipher_suite, + signature_priv, + group_context()); + if (content.content_type() == ContentType::commit) { + auto confirmation_tag = prg.secret("confirmation_tag"); + auth_content.set_confirmation_tag(confirmation_tag); + } + + return PublicMessage::protect( + auth_content, cipher_suite, membership_key, group_context()); +} + +MLSMessage +MessageProtectionTestVector::protect_priv( + const mlspp::GroupContent::RawContent& raw_content) +{ + auto sender = Sender{ MemberSender{ LeafIndex{ 1 } } }; + auto authenticated_data = bytes{}; + auto padding_size = size_t(0); + + auto content = + GroupContent{ group_id, epoch, sender, authenticated_data, raw_content }; + + auto auth_content = + AuthenticatedContent::sign(WireFormat::mls_private_message, + content, + cipher_suite, + signature_priv, + group_context()); + if (content.content_type() == ContentType::commit) { + auto confirmation_tag = prg.secret("confirmation_tag"); + auth_content.set_confirmation_tag(confirmation_tag); + } + + auto keys = group_keys(); + return PrivateMessage::protect( + auth_content, cipher_suite, keys, sender_data_secret, padding_size); +} + +std::optional +MessageProtectionTestVector::unprotect(const MLSMessage& message) +{ + auto do_unprotect = + overloaded{ [&](const PublicMessage& pt) { + return pt.unprotect( + cipher_suite, membership_key, group_context()); + }, + [&](const PrivateMessage& ct) { + auto keys = group_keys(); + return ct.unprotect(cipher_suite, keys, sender_data_secret); + }, + [](const auto& /* other */) -> std::optional { + return std::nullopt; + } }; + + auto maybe_auth_content = var::visit(do_unprotect, message.message); + if (!maybe_auth_content) { + return std::nullopt; + } + + auto val_content = opt::get(maybe_auth_content); + const auto& auth_content = val_content.authenticated_content(); + if (!auth_content.verify(cipher_suite, signature_pub, group_context())) { + return std::nullopt; + } + + return auth_content.content; +} + +/// +/// PSKTestVector +/// +static std::vector +to_psk_w_secret(const std::vector& psks) +{ + auto pskws = std::vector(psks.size()); + std::transform( + std::begin(psks), std::end(psks), std::begin(pskws), [](const auto& psk) { + auto ext_id = ExternalPSK{ psk.psk_id }; + auto id = PreSharedKeyID{ ext_id, psk.psk_nonce }; + return PSKWithSecret{ id, psk.psk }; + }); + + return pskws; +} + +PSKSecretTestVector::PSKSecretTestVector(mlspp::CipherSuite suite, + size_t n_psks) + : PseudoRandom(suite, "psk_secret") + , cipher_suite(suite) + , psks(n_psks) +{ + uint32_t i = 0; + for (auto& psk : psks) { + auto ix = to_hex(tls::marshal(i)); + i += 1; + + psk.psk_id = prg.secret("psk_id" + ix); + psk.psk_nonce = prg.secret("psk_nonce" + ix); + psk.psk = prg.secret("psk" + ix); + } + + psk_secret = + KeyScheduleEpoch::make_psk_secret(cipher_suite, to_psk_w_secret(psks)); +} + +std::optional +PSKSecretTestVector::verify() const +{ + auto actual = + KeyScheduleEpoch::make_psk_secret(cipher_suite, to_psk_w_secret(psks)); + VERIFY_EQUAL("psk secret", actual, psk_secret); + + return std::nullopt; +} + +/// +/// TranscriptTestVector +/// +TranscriptTestVector::TranscriptTestVector(CipherSuite suite) + : PseudoRandom(suite, "transcript") + , cipher_suite(suite) + , interim_transcript_hash_before(prg.secret("interim_transcript_hash_before")) +{ + auto transcript = TranscriptHash(suite); + transcript.interim = interim_transcript_hash_before; + + auto group_id = prg.secret("group_id"); + auto epoch = prg.uint64("epoch"); + auto group_context_obj = + GroupContext{ suite, + group_id, + epoch, + prg.secret("tree_hash_before"), + prg.secret("confirmed_transcript_hash_before"), + {} }; + auto group_context = tls::marshal(group_context_obj); + + auto init_secret = prg.secret("init_secret"); + auto ks_epoch = KeyScheduleEpoch(suite, init_secret, group_context); + + auto sig_priv = prg.signature_key("sig_priv"); + auto leaf_index = LeafIndex{ 0 }; + + authenticated_content = AuthenticatedContent::sign( + WireFormat::mls_public_message, + GroupContent{ + group_id, epoch, { MemberSender{ leaf_index } }, {}, Commit{} }, + suite, + sig_priv, + group_context_obj); + + transcript.update_confirmed(authenticated_content); + + const auto confirmation_tag = ks_epoch.confirmation_tag(transcript.confirmed); + authenticated_content.set_confirmation_tag(confirmation_tag); + + transcript.update_interim(authenticated_content); + + // Store the required data + confirmation_key = ks_epoch.confirmation_key; + confirmed_transcript_hash_after = transcript.confirmed; + interim_transcript_hash_after = transcript.interim; +} + +std::optional +TranscriptTestVector::verify() const +{ + auto transcript = TranscriptHash(cipher_suite); + transcript.interim = interim_transcript_hash_before; + + transcript.update(authenticated_content); + VERIFY_EQUAL( + "confirmed", transcript.confirmed, confirmed_transcript_hash_after); + VERIFY_EQUAL("interim", transcript.interim, interim_transcript_hash_after); + + auto confirmation_tag = + cipher_suite.digest().hmac(confirmation_key, transcript.confirmed); + VERIFY_EQUAL("confirmation tag", + confirmation_tag, + authenticated_content.auth.confirmation_tag); + + return std::nullopt; +} + +/// +/// WelcomeTestVector +/// +WelcomeTestVector::WelcomeTestVector(CipherSuite suite) + : PseudoRandom(suite, "welcome") + , cipher_suite(suite) + , init_priv(prg.hpke_key("init_priv")) +{ + auto joiner_secret = prg.secret("joiner_secret"); + auto group_id = prg.secret("group_id"); + auto epoch = epoch_t(prg.uint64("epoch")); + auto tree_hash = prg.secret("tree_hash"); + auto confirmed_transcript_hash = prg.secret("confirmed_transcript_hash"); + auto enc_priv = prg.hpke_key("enc_priv"); + auto sig_priv = prg.signature_key("sig_priv"); + auto cred = Credential::basic(prg.secret("identity")); + + auto signer_index = LeafIndex{ prg.uint32("signer") }; + auto signer_priv = prg.signature_key("signer_priv"); + signer_pub = signer_priv.public_key; + + auto leaf_node = LeafNode{ + cipher_suite, + enc_priv.public_key, + sig_priv.public_key, + cred, + Capabilities::create_default(), + Lifetime::create_default(), + {}, + sig_priv, + }; + auto key_package_obj = KeyPackage{ + cipher_suite, init_priv.public_key, leaf_node, {}, sig_priv, + }; + key_package = key_package_obj; + + auto group_context = GroupContext{ + cipher_suite, group_id, epoch, tree_hash, confirmed_transcript_hash, {} + }; + + auto key_schedule = KeyScheduleEpoch::joiner( + cipher_suite, joiner_secret, {}, tls::marshal(group_context)); + auto confirmation_tag = + key_schedule.confirmation_tag(confirmed_transcript_hash); + + auto group_info = GroupInfo{ + group_context, + {}, + confirmation_tag, + }; + group_info.sign(signer_index, signer_priv); + + auto welcome_obj = Welcome(cipher_suite, joiner_secret, {}, group_info); + welcome_obj.encrypt(key_package_obj, std::nullopt); + welcome = welcome_obj; +} + +std::optional +WelcomeTestVector::verify() const +{ + VERIFY_EQUAL( + "kp format", key_package.wire_format(), WireFormat::mls_key_package); + VERIFY_EQUAL( + "welcome format", welcome.wire_format(), WireFormat::mls_welcome); + + const auto& key_package_obj = var::get(key_package.message); + const auto& welcome_obj = var::get(welcome.message); + + VERIFY_EQUAL("kp suite", key_package_obj.cipher_suite, cipher_suite); + VERIFY_EQUAL("welcome suite", welcome_obj.cipher_suite, cipher_suite); + + auto maybe_kpi = welcome_obj.find(key_package_obj); + VERIFY("found key package", maybe_kpi); + + auto kpi = opt::get(maybe_kpi); + auto group_secrets = welcome_obj.decrypt_secrets(kpi, init_priv); + auto group_info = welcome_obj.decrypt(group_secrets.joiner_secret, {}); + + // Verify signature on GroupInfo + VERIFY("group info verify", group_info.verify(signer_pub)); + + // Verify confirmation tag + const auto& group_context = group_info.group_context; + auto key_schedule = KeyScheduleEpoch::joiner( + cipher_suite, group_secrets.joiner_secret, {}, tls::marshal(group_context)); + auto confirmation_tag = + key_schedule.confirmation_tag(group_context.confirmed_transcript_hash); + + return std::nullopt; +} + +/// +/// TreeTestCase +/// + +std::array all_tree_structures{ + TreeStructure::full_tree_2, + TreeStructure::full_tree_3, + TreeStructure::full_tree_4, + TreeStructure::full_tree_5, + TreeStructure::full_tree_6, + TreeStructure::full_tree_7, + TreeStructure::full_tree_8, + TreeStructure::full_tree_32, + TreeStructure::full_tree_33, + TreeStructure::full_tree_34, + TreeStructure::internal_blanks_no_skipping, + TreeStructure::internal_blanks_with_skipping, + TreeStructure::unmerged_leaves_no_skipping, + TreeStructure::unmerged_leaves_with_skipping, +}; + +std::array treekem_test_tree_structures{ + // All cases except the big ones + TreeStructure::full_tree_2, + TreeStructure::full_tree_3, + TreeStructure::full_tree_4, + TreeStructure::full_tree_5, + TreeStructure::full_tree_6, + TreeStructure::full_tree_7, + TreeStructure::full_tree_8, + TreeStructure::internal_blanks_no_skipping, + TreeStructure::internal_blanks_with_skipping, + TreeStructure::unmerged_leaves_no_skipping, + TreeStructure::unmerged_leaves_with_skipping, +}; + +struct TreeTestCase +{ + CipherSuite suite; + PseudoRandom::Generator prg; + + bytes group_id; + uint32_t leaf_counter = 0; + uint32_t path_counter = 0; + + struct PrivateState + { + SignaturePrivateKey sig_priv; + TreeKEMPrivateKey priv; + std::vector senders; + }; + + std::map privs; + TreeKEMPublicKey pub; + + TreeTestCase(CipherSuite suite_in, PseudoRandom::Generator prg_in) + : suite(suite_in) + , prg(std::move(prg_in)) + , group_id(prg.secret("group_id")) + , pub(suite) + { + auto [where, enc_priv, sig_priv] = add_leaf(); + auto tree_priv = TreeKEMPrivateKey::solo(suite, where, enc_priv); + auto priv_state = PrivateState{ sig_priv, tree_priv, { LeafIndex{ 0 } } }; + privs.insert_or_assign(where, priv_state); + } + + std::tuple add_leaf() + { + leaf_counter += 1; + auto ix = to_hex(tls::marshal(leaf_counter)); + auto enc_priv = prg.hpke_key("encryption_key" + ix); + auto sig_priv = prg.signature_key("signature_key" + ix); + auto identity = prg.secret("identity" + ix); + + auto credential = Credential::basic(identity); + auto leaf_node = LeafNode{ suite, + enc_priv.public_key, + sig_priv.public_key, + credential, + Capabilities::create_default(), + Lifetime::create_default(), + {}, + sig_priv }; + auto where = pub.add_leaf(leaf_node); + pub.set_hash_all(); + return { where, enc_priv, sig_priv }; + } + + void commit(LeafIndex from, + const std::vector& remove, + bool add, + std::optional maybe_context) + { + // Remove members from the tree + for (auto i : remove) { + pub.blank_path(i); + privs.erase(i); + } + pub.set_hash_all(); + + auto joiner = std::vector{}; + auto maybe_enc_priv = std::optional{}; + auto maybe_sig_priv = std::optional{}; + if (add) { + auto [where, enc_priv, sig_priv] = add_leaf(); + joiner.push_back(where); + maybe_enc_priv = enc_priv; + maybe_sig_priv = sig_priv; + } + + auto path_secret = std::optional{}; + if (maybe_context) { + // Create an UpdatePath + path_counter += 1; + auto ix = to_hex(tls::marshal(path_counter)); + auto leaf_secret = prg.secret("leaf_secret" + ix); + auto priv = privs.at(from); + + auto context = opt::get(maybe_context); + auto pub_before = pub; + auto sender_priv = + pub.update(from, leaf_secret, group_id, priv.sig_priv, {}); + auto path = pub.encap(sender_priv, context, joiner); + + // Process the UpdatePath at all the members + for (auto& pair : privs) { + // XXX(RLB): It might seem like this could be done with a simple + // destructuring assignment, either here or in the `for` clause above. + // However, either of these options cause clang-tidy to segfault when + // evaulating the "bugprone-unchecked-optional-access" lint. + const auto& leaf = pair.first; + auto& priv_state = pair.second; + if (leaf == from) { + priv_state = + PrivateState{ priv_state.sig_priv, sender_priv, { from } }; + continue; + } + + priv_state.priv.decap(from, pub_before, context, path, joiner); + priv_state.senders.push_back(from); + } + + // Look up the path secret for the joiner + if (!joiner.empty()) { + auto index = joiner.front(); + auto [overlap, shared_path_secret, ok] = + sender_priv.shared_path_secret(index); + silence_unused(overlap); + silence_unused(ok); + + path_secret = shared_path_secret; + } + } + + // Add a private entry for the joiner if we added someone + if (!joiner.empty()) { + auto index = joiner.front(); + auto ancestor = index.ancestor(from); + auto enc_priv = opt::get(maybe_enc_priv); + auto sig_priv = opt::get(maybe_sig_priv); + auto tree_priv = + TreeKEMPrivateKey::joiner(pub, index, enc_priv, ancestor, path_secret); + privs.insert_or_assign(index, + PrivateState{ sig_priv, tree_priv, { from } }); + } + } + + static TreeTestCase full(CipherSuite suite, + const PseudoRandom::Generator& prg, + LeafCount leaves, + const std::string& label) + { + auto tc = TreeTestCase{ suite, prg.sub(label) }; + + for (LeafIndex i{ 0 }; i.val < leaves.val - 1; i.val++) { + tc.commit( + i, {}, true, tc.prg.secret("context" + to_hex(tls::marshal(i)))); + } + + return tc; + } + + static TreeTestCase with_structure(CipherSuite suite, + const PseudoRandom::Generator& prg, + TreeStructure tree_structure) + { + switch (tree_structure) { + case TreeStructure::full_tree_2: + return full(suite, prg, LeafCount{ 2 }, "full_tree_2"); + + case TreeStructure::full_tree_3: + return full(suite, prg, LeafCount{ 3 }, "full_tree_3"); + + case TreeStructure::full_tree_4: + return full(suite, prg, LeafCount{ 4 }, "full_tree_4"); + + case TreeStructure::full_tree_5: + return full(suite, prg, LeafCount{ 5 }, "full_tree_5"); + + case TreeStructure::full_tree_6: + return full(suite, prg, LeafCount{ 6 }, "full_tree_6"); + + case TreeStructure::full_tree_7: + return full(suite, prg, LeafCount{ 7 }, "full_tree_7"); + + case TreeStructure::full_tree_8: + return full(suite, prg, LeafCount{ 8 }, "full_tree_8"); + + case TreeStructure::full_tree_32: + return full(suite, prg, LeafCount{ 32 }, "full_tree_32"); + + case TreeStructure::full_tree_33: + return full(suite, prg, LeafCount{ 33 }, "full_tree_33"); + + case TreeStructure::full_tree_34: + return full(suite, prg, LeafCount{ 34 }, "full_tree_34"); + + case TreeStructure::internal_blanks_no_skipping: { + auto tc = TreeTestCase::full( + suite, prg, LeafCount{ 8 }, "internal_blanks_no_skipping"); + auto context = tc.prg.secret("context"); + tc.commit( + LeafIndex{ 0 }, { LeafIndex{ 2 }, LeafIndex{ 3 } }, true, context); + return tc; + } + + case TreeStructure::internal_blanks_with_skipping: { + auto tc = TreeTestCase::full( + suite, prg, LeafCount{ 8 }, "internal_blanks_with_skipping"); + auto context = tc.prg.secret("context"); + tc.commit(LeafIndex{ 0 }, + { LeafIndex{ 1 }, LeafIndex{ 2 }, LeafIndex{ 3 } }, + false, + context); + return tc; + } + + case TreeStructure::unmerged_leaves_no_skipping: { + auto tc = TreeTestCase::full( + suite, prg, LeafCount{ 7 }, "unmerged_leaves_no_skipping"); + auto context = tc.prg.secret("context"); + tc.commit(LeafIndex{ 0 }, {}, true, std::nullopt); + return tc; + } + + case TreeStructure::unmerged_leaves_with_skipping: { + auto tc = TreeTestCase::full( + suite, prg, LeafCount{ 1 }, "unmerged_leaves_with_skipping"); + + // 0 adds 1..6 + tc.commit(LeafIndex{ 0 }, {}, true, std::nullopt); + tc.commit(LeafIndex{ 0 }, {}, true, std::nullopt); + tc.commit(LeafIndex{ 0 }, {}, true, std::nullopt); + tc.commit(LeafIndex{ 0 }, {}, true, std::nullopt); + tc.commit(LeafIndex{ 0 }, {}, true, std::nullopt); + tc.commit(LeafIndex{ 0 }, {}, true, std::nullopt); + + // 0 reemoves 5 + tc.commit(LeafIndex{ 0 }, + { LeafIndex{ 5 } }, + false, + tc.prg.secret("context_remove5")); + + // 4 commits without any proupposals + tc.commit(LeafIndex{ 4 }, {}, false, tc.prg.secret("context_update4")); + + // 0 adds a new member + tc.commit(LeafIndex{ 0 }, {}, true, std::nullopt); + + return tc; + } + + default: + throw InvalidParameterError("Unsupported tree structure"); + } + } +}; + +/// +/// TreeHashTestVector +/// +TreeHashTestVector::TreeHashTestVector(mlspp::CipherSuite suite, + TreeStructure tree_structure) + : PseudoRandom(suite, "tree-hashes") + , cipher_suite(suite) +{ + auto tc = TreeTestCase::with_structure(suite, prg, tree_structure); + tree = tc.pub; + group_id = tc.group_id; + + auto width = NodeCount(tree.size); + for (NodeIndex i{ 0 }; i < width; i.val++) { + tree_hashes.push_back(tree.get_hash(i)); + resolutions.push_back(tree.resolve(i)); + } +} + +std::optional +TreeHashTestVector::verify() +{ + // Finish setting up the tree + tree.suite = cipher_suite; + tree.set_hash_all(); + + // Verify that each leaf node is properly signed + for (LeafIndex i{ 0 }; i < tree.size; i.val++) { + auto maybe_leaf = tree.leaf_node(i); + if (!maybe_leaf) { + continue; + } + + auto leaf = opt::get(maybe_leaf); + auto leaf_valid = leaf.verify(cipher_suite, { { group_id, i } }); + VERIFY("leaf sig valid", leaf_valid); + } + + // Verify the tree hashes + auto width = NodeCount{ tree.size }; + for (NodeIndex i{ 0 }; i < width; i.val++) { + VERIFY_EQUAL("tree hash", tree.get_hash(i), tree_hashes.at(i.val)); + VERIFY_EQUAL("resolution", tree.resolve(i), resolutions.at(i.val)); + } + + // Verify parent hashes + VERIFY("parent hash valid", tree.parent_hash_valid()); + + // Verify the resolutions + for (NodeIndex i{ 0 }; i < width; i.val++) { + VERIFY_EQUAL("resolution", tree.resolve(i), resolutions[i.val]); + } + + return std::nullopt; +} + +/// +/// TreeOperationsTestVector +/// + +const std::vector + TreeOperationsTestVector::all_scenarios{ + Scenario::add_right_edge, Scenario::add_internal, Scenario::update, + Scenario::remove_right_edge, Scenario::remove_internal, + }; + +TreeOperationsTestVector::TreeOperationsTestVector( + mlspp::CipherSuite suite, + Scenario scenario) + : PseudoRandom(suite, "tree-operations") + , cipher_suite(suite) + , proposal_sender(0) +{ + auto init_priv = prg.hpke_key("init_key"); + auto enc_priv = prg.hpke_key("encryption_key"); + auto sig_priv = prg.signature_key("signature_key"); + auto identity = prg.secret("identity"); + auto credential = Credential::basic(identity); + auto key_package = KeyPackage{ + suite, + init_priv.public_key, + { suite, + enc_priv.public_key, + sig_priv.public_key, + credential, + Capabilities::create_default(), + Lifetime::create_default(), + {}, + sig_priv }, + {}, + sig_priv, + }; + + switch (scenario) { + case Scenario::add_right_edge: { + auto tc = TreeTestCase::full(suite, prg, LeafCount{ 8 }, "tc"); + + proposal = Proposal{ Add{ key_package } }; + + tree_before = tc.pub; + tree_hash_before = tree_before.root_hash(); + + tree_after = tree_before; + tree_after.add_leaf(key_package.leaf_node); + break; + } + + case Scenario::add_internal: { + auto tc = TreeTestCase::full(suite, prg, LeafCount{ 8 }, "tc"); + + proposal = Proposal{ Add{ key_package } }; + + tree_before = tc.pub; + tree_before.blank_path(LeafIndex{ 4 }); + tree_before.set_hash_all(); + tree_hash_before = tree_before.root_hash(); + + tree_after = tree_before; + tree_after.add_leaf(key_package.leaf_node); + break; + } + + case Scenario::update: { + auto tc = TreeTestCase::full(suite, prg, LeafCount{ 8 }, "tc"); + + proposal_sender = LeafIndex{ 3 }; + proposal = Proposal{ Update{ key_package.leaf_node } }; + + tree_before = tc.pub; + tree_hash_before = tree_before.root_hash(); + + tree_after = tree_before; + tree_after.update_leaf(proposal_sender, key_package.leaf_node); + break; + } + + case Scenario::remove_right_edge: { + auto tc = TreeTestCase::full(suite, prg, LeafCount{ 9 }, "tc"); + + auto removed = LeafIndex{ 8 }; + proposal = Proposal{ Remove{ removed } }; + + tree_before = tc.pub; + tree_hash_before = tree_before.root_hash(); + + tree_after = tree_before; + tree_after.blank_path(removed); + tree_after.truncate(); + break; + } + + case Scenario::remove_internal: { + auto tc = TreeTestCase::full(suite, prg, LeafCount{ 8 }, "tc"); + + auto removed = LeafIndex{ 4 }; + proposal = Proposal{ Remove{ removed } }; + + tree_before = tc.pub; + tree_hash_before = tree_before.root_hash(); + + tree_after = tree_before; + tree_after.blank_path(removed); + tree_after.truncate(); + break; + } + } + + tree_after.set_hash_all(); + tree_hash_after = tree_after.root_hash(); +} + +std::optional +TreeOperationsTestVector::verify() +{ + tree_before.suite = cipher_suite; + tree_before.set_hash_all(); + + auto tree = tree_before; + VERIFY_EQUAL("tree hash before", tree.root_hash(), tree_hash_before); + + auto apply = overloaded{ + [&](const Add& add) { tree.add_leaf(add.key_package.leaf_node); }, + + [&](const Update& update) { + tree.update_leaf(proposal_sender, update.leaf_node); + }, + + [&](const Remove& remove) { + tree.blank_path(remove.removed); + tree.truncate(); + }, + + [](const auto& /* other */) { + throw InvalidParameterError("invalid proposal type"); + }, + }; + + var::visit(apply, proposal.content); + VERIFY_EQUAL("tree after", tree, tree_after); + + tree.set_hash_all(); + VERIFY_EQUAL("tree hash after", tree.root_hash(), tree_hash_after); + + return std::nullopt; +} + +/// +/// TreeKEMTestVector +/// + +TreeKEMTestVector::TreeKEMTestVector(mlspp::CipherSuite suite, + TreeStructure tree_structure) + : PseudoRandom(suite, "treekem") + , cipher_suite(suite) +{ + auto tc = TreeTestCase::with_structure(cipher_suite, prg, tree_structure); + + group_id = tc.group_id; + epoch = prg.uint64("epoch"); + confirmed_transcript_hash = prg.secret("confirmed_transcript_hash"); + + ratchet_tree = tc.pub; + + // Serialize out the private states + for (LeafIndex index{ 0 }; index < ratchet_tree.size; index.val++) { + if (tc.privs.count(index) == 0) { + continue; + } + + auto priv_state = tc.privs.at(index); + auto enc_priv = priv_state.priv.private_key_cache.at(NodeIndex(index)); + auto path_secrets = std::vector{}; + for (const auto& [node, path_secret] : priv_state.priv.path_secrets) { + if (node == NodeIndex(index)) { + // No need to serialize a secret for the leaf node + continue; + } + + path_secrets.push_back(PathSecret{ node, path_secret }); + } + + leaves_private.push_back(LeafPrivateInfo{ + index, + enc_priv, + priv_state.sig_priv, + path_secrets, + }); + } + + // Create test update paths + for (LeafIndex sender{ 0 }; sender < ratchet_tree.size; sender.val++) { + if (!tc.pub.has_leaf(sender)) { + continue; + } + + auto leaf_secret = prg.secret("update_path" + to_hex(tls::marshal(sender))); + const auto& sig_priv = tc.privs.at(sender).sig_priv; + + auto pub = tc.pub; + auto new_sender_priv = + pub.update(sender, leaf_secret, group_id, sig_priv, {}); + + auto group_context = GroupContext{ cipher_suite, + group_id, + epoch, + pub.root_hash(), + confirmed_transcript_hash, + {} }; + auto ctx = tls::marshal(group_context); + + auto path = pub.encap(new_sender_priv, ctx, {}); + + auto path_secrets = std::vector>{}; + for (LeafIndex to{ 0 }; to < ratchet_tree.size; to.val++) { + if (to == sender || !pub.has_leaf(to)) { + path_secrets.emplace_back(std::nullopt); + continue; + } + + auto [overlap, path_secret, ok] = new_sender_priv.shared_path_secret(to); + silence_unused(overlap); + silence_unused(ok); + + path_secrets.emplace_back(path_secret); + } + + update_paths.push_back(UpdatePathInfo{ + sender, + path, + path_secrets, + new_sender_priv.update_secret, + pub.root_hash(), + }); + } +} + +std::optional +TreeKEMTestVector::verify() +{ + // Finish initializing the ratchet tree + ratchet_tree.suite = cipher_suite; + ratchet_tree.set_hash_all(); + + // Validate public state + VERIFY("parent hash valid", ratchet_tree.parent_hash_valid()); + + for (LeafIndex i{ 0 }; i < ratchet_tree.size; i.val++) { + auto maybe_leaf = ratchet_tree.leaf_node(i); + if (!maybe_leaf) { + continue; + } + + auto leaf = opt::get(maybe_leaf); + VERIFY("leaf sig", leaf.verify(cipher_suite, { { group_id, i } })); + } + + // Import private keys + std::map tree_privs; + std::map sig_privs; + for (const auto& info : leaves_private) { + auto enc_priv = info.encryption_priv; + auto sig_priv = info.signature_priv; + enc_priv.set_public_key(cipher_suite); + sig_priv.set_public_key(cipher_suite); + + auto priv = TreeKEMPrivateKey{}; + priv.suite = cipher_suite; + priv.index = info.index; + priv.private_key_cache.insert_or_assign(NodeIndex(info.index), enc_priv); + + for (const auto& entry : info.path_secrets) { + priv.path_secrets.insert_or_assign(entry.node, entry.path_secret); + } + + VERIFY("priv consistent", priv.consistent(ratchet_tree)); + + tree_privs.insert_or_assign(info.index, priv); + sig_privs.insert_or_assign(info.index, sig_priv); + } + + for (const auto& info : update_paths) { + // Test decap of the existing group secrets + const auto& from = info.sender; + const auto& path = info.update_path; + VERIFY("path parent hash valid", + ratchet_tree.parent_hash_valid(from, path)); + + auto ratchet_tree_after = ratchet_tree; + ratchet_tree_after.merge(from, path); + ratchet_tree_after.set_hash_all(); + VERIFY_EQUAL( + "tree hash after", ratchet_tree_after.root_hash(), info.tree_hash_after); + + auto group_context = GroupContext{ cipher_suite, + group_id, + epoch, + ratchet_tree_after.root_hash(), + confirmed_transcript_hash, + {} }; + auto ctx = tls::marshal(group_context); + + for (LeafIndex to{ 0 }; to < ratchet_tree_after.size; to.val++) { + if (to == from || !ratchet_tree_after.has_leaf(to)) { + continue; + } + + auto priv = tree_privs.at(to); + priv.decap(from, ratchet_tree_after, ctx, path, {}); + VERIFY_EQUAL("commit secret", priv.update_secret, info.commit_secret); + + auto [overlap, path_secret, ok] = priv.shared_path_secret(from); + silence_unused(overlap); + silence_unused(ok); + VERIFY_EQUAL("path secret", path_secret, info.path_secrets[to.val]); + } + + // Test encap/decap + auto ratchet_tree_encap = ratchet_tree; + auto leaf_secret = random_bytes(cipher_suite.secret_size()); + const auto& sig_priv = sig_privs.at(from); + auto new_sender_priv = + ratchet_tree_encap.update(from, leaf_secret, group_id, sig_priv, {}); + auto new_path = ratchet_tree_encap.encap(new_sender_priv, ctx, {}); + VERIFY("new path parent hash valid", + ratchet_tree.parent_hash_valid(from, path)); + + for (LeafIndex to{ 0 }; to < ratchet_tree_encap.size; to.val++) { + if (to == from || !ratchet_tree_encap.has_leaf(to)) { + continue; + } + + auto priv = tree_privs.at(to); + priv.decap(from, ratchet_tree_encap, ctx, new_path, {}); + VERIFY_EQUAL( + "commit secret", priv.update_secret, new_sender_priv.update_secret); + } + } + + return std::nullopt; +} + +/// +/// MessagesTestVector +/// + +MessagesTestVector::MessagesTestVector() + : PseudoRandom(CipherSuite::ID::X25519_AES128GCM_SHA256_Ed25519, "messages") +{ + auto suite = CipherSuite{ CipherSuite::ID::X25519_AES128GCM_SHA256_Ed25519 }; + auto epoch = epoch_t(prg.uint64("epoch")); + auto index = LeafIndex{ prg.uint32("index") }; + auto user_id = prg.secret("user_id"); + auto group_id = prg.secret("group_id"); + // auto opaque = bytes(32, 0xD3); + // auto mac = bytes(32, 0xD5); + + auto app_id_ext = ApplicationIDExtension{ prg.secret("app_id") }; + auto ext_list = ExtensionList{}; + ext_list.add(app_id_ext); + + auto group_context = GroupContext{ suite, + group_id, + epoch, + prg.secret("tree_hash"), + prg.secret("confirmed_trasncript_hash"), + ext_list }; + + auto version = ProtocolVersion::mls10; + auto hpke_priv = prg.hpke_key("hpke_priv"); + auto hpke_priv_2 = prg.hpke_key("hpke_priv_2"); + auto hpke_pub = hpke_priv.public_key; + auto hpke_pub_2 = hpke_priv_2.public_key; + auto hpke_ct = + HPKECiphertext{ prg.secret("kem_output"), prg.secret("ciphertext") }; + auto sig_priv = prg.signature_key("signature_priv"); + auto sig_priv_2 = prg.signature_key("signature_priv_2"); + auto sig_pub = sig_priv.public_key; + auto sig_pub_2 = sig_priv_2.public_key; + + // KeyPackage and extensions + auto cred = Credential::basic(user_id); + auto leaf_node = LeafNode{ suite, + hpke_pub, + sig_pub, + cred, + Capabilities::create_default(), + Lifetime::create_default(), + ext_list, + sig_priv }; + auto leaf_node_2 = LeafNode{ suite, + hpke_pub_2, + sig_pub_2, + cred, + Capabilities::create_default(), + Lifetime::create_default(), + ext_list, + sig_priv_2 }; + auto key_package_obj = KeyPackage{ suite, hpke_pub, leaf_node, {}, sig_priv }; + + auto leaf_node_update = + leaf_node.for_update(suite, group_id, index, hpke_pub, {}, sig_priv); + auto leaf_node_commit = leaf_node.for_commit( + suite, group_id, index, hpke_pub, prg.secret("parent_hash"), {}, sig_priv); + + auto sender = Sender{ MemberSender{ index } }; + + auto tree = TreeKEMPublicKey{ suite }; + tree.add_leaf(leaf_node); + tree.add_leaf(leaf_node_2); + auto ratchet_tree_obj = RatchetTreeExtension{ tree }; + + // Welcome and its substituents + auto group_info_obj = + GroupInfo{ group_context, ext_list, prg.secret("confirmation_tag") }; + auto joiner_secret = prg.secret("joiner_secret"); + auto path_secret = prg.secret("path_secret"); + auto psk_id = ExternalPSK{ prg.secret("psk_id") }; + auto psk_nonce = prg.secret("psk_nonce"); + auto group_secrets_obj = GroupSecrets{ joiner_secret, + { { path_secret } }, + PreSharedKeys{ { + { psk_id, psk_nonce }, + } } }; + auto welcome_obj = Welcome{ suite, joiner_secret, {}, group_info_obj }; + welcome_obj.encrypt(key_package_obj, path_secret); + + // Proposals + auto add = Add{ key_package_obj }; + auto update = Update{ leaf_node_update }; + auto remove = Remove{ index }; + auto pre_shared_key = PreSharedKey{ psk_id, psk_nonce }; + auto reinit = ReInit{ group_id, version, suite, {} }; + auto external_init = ExternalInit{ prg.secret("external_init") }; + + // Commit + auto proposal_ref = ProposalRef{ 32, 0xa0 }; + + auto commit_obj = Commit{ { + { proposal_ref }, + { Proposal{ add } }, + }, + UpdatePath{ + leaf_node_commit, + { + { hpke_pub, { hpke_ct, hpke_ct } }, + { hpke_pub, { hpke_ct, hpke_ct, hpke_ct } }, + }, + } }; + + // AuthenticatedContent with Application / Proposal / Commit + + // PublicMessage + auto membership_key = prg.secret("membership_key"); + + auto content_auth_proposal = AuthenticatedContent::sign( + WireFormat::mls_public_message, + { group_id, epoch, sender, {}, Proposal{ remove } }, + suite, + sig_priv, + group_context); + auto public_message_proposal_obj = PublicMessage::protect( + content_auth_proposal, suite, membership_key, group_context); + + auto content_auth_commit = + AuthenticatedContent::sign(WireFormat::mls_public_message, + { group_id, epoch, sender, {}, commit_obj }, + suite, + sig_priv, + group_context); + content_auth_commit.set_confirmation_tag(prg.secret("confirmation_tag")); + auto public_message_commit_obj = PublicMessage::protect( + content_auth_commit, suite, membership_key, group_context); + + // PrivateMessage + auto content_auth_application_obj = AuthenticatedContent::sign( + WireFormat::mls_private_message, + { group_id, epoch, sender, {}, ApplicationData{} }, + suite, + sig_priv, + group_context); + + auto keys = GroupKeySource( + suite, LeafCount{ index.val + 1 }, prg.secret("encryption_secret")); + auto private_message_obj = + PrivateMessage::protect(content_auth_application_obj, + suite, + keys, + prg.secret("sender_data_secret"), + 10); + + // Serialize out all the objects + mls_welcome = tls::marshal(MLSMessage{ welcome_obj }); + mls_group_info = tls::marshal(MLSMessage{ group_info_obj }); + mls_key_package = tls::marshal(MLSMessage{ key_package_obj }); + + ratchet_tree = tls::marshal(ratchet_tree_obj); + group_secrets = tls::marshal(group_secrets_obj); + + add_proposal = tls::marshal(add); + update_proposal = tls::marshal(update); + remove_proposal = tls::marshal(remove); + pre_shared_key_proposal = tls::marshal(pre_shared_key); + re_init_proposal = tls::marshal(reinit); + external_init_proposal = tls::marshal(external_init); + + commit = tls::marshal(commit_obj); + + public_message_proposal = + tls::marshal(MLSMessage{ public_message_proposal_obj }); + public_message_commit = tls::marshal(MLSMessage{ public_message_commit_obj }); + private_message = tls::marshal(MLSMessage{ private_message_obj }); +} + +std::optional +MessagesTestVector::verify() const +{ + // TODO(RLB) Verify signatures + // TODO(RLB) Verify content types in PublicMessage objects + auto require_format = [](WireFormat format) { + return + [format](const MLSMessage& msg) { return msg.wire_format() == format; }; + }; + + VERIFY_TLS_RTT_VAL("Welcome", + MLSMessage, + mls_welcome, + require_format(WireFormat::mls_welcome)); + VERIFY_TLS_RTT_VAL("GroupInfo", + MLSMessage, + mls_group_info, + require_format(WireFormat::mls_group_info)); + VERIFY_TLS_RTT_VAL("KeyPackage", + MLSMessage, + mls_key_package, + require_format(WireFormat::mls_key_package)); + + VERIFY_TLS_RTT("RatchetTree", RatchetTreeExtension, ratchet_tree); + VERIFY_TLS_RTT("GroupSecrets", GroupSecrets, group_secrets); + + VERIFY_TLS_RTT("Add", Add, add_proposal); + VERIFY_TLS_RTT("Update", Update, update_proposal); + VERIFY_TLS_RTT("Remove", Remove, remove_proposal); + VERIFY_TLS_RTT("PreSharedKey", PreSharedKey, pre_shared_key_proposal); + VERIFY_TLS_RTT("ReInit", ReInit, re_init_proposal); + VERIFY_TLS_RTT("ExternalInit", ExternalInit, external_init_proposal); + + VERIFY_TLS_RTT("Commit", Commit, commit); + + VERIFY_TLS_RTT_VAL("Public(Proposal)", + MLSMessage, + public_message_proposal, + require_format(WireFormat::mls_public_message)); + VERIFY_TLS_RTT_VAL("Public(Commit)", + MLSMessage, + public_message_commit, + require_format(WireFormat::mls_public_message)); + VERIFY_TLS_RTT_VAL("PrivateMessage", + MLSMessage, + private_message, + require_format(WireFormat::mls_private_message)); + + return std::nullopt; +} + +std::optional +PassiveClientTestVector::verify() +{ + // Import everything + signature_priv.set_public_key(cipher_suite); + encryption_priv.set_public_key(cipher_suite); + init_priv.set_public_key(cipher_suite); + + const auto& key_package_raw = var::get(key_package.message); + const auto& welcome_raw = var::get(welcome.message); + + auto ext_psks = std::map{}; + for (const auto& [id, psk] : external_psks) { + ext_psks.insert_or_assign(id, psk); + } + + // Join the group and follow along + auto state = State(init_priv, + encryption_priv, + signature_priv, + key_package_raw, + welcome_raw, + ratchet_tree, + ext_psks); + VERIFY_EQUAL( + "initial epoch", state.epoch_authenticator(), initial_epoch_authenticator); + + for (const auto& tve : epochs) { + for (const auto& proposal : tve.proposals) { + state.handle(proposal); + } + + state = opt::get(state.handle(tve.commit)); + VERIFY_EQUAL( + "epoch auth", state.epoch_authenticator(), tve.epoch_authenticator) + } + + return std::nullopt; +} + +} // namespace mls_vectors diff --git a/mlspp/lib/tls_syntax/CMakeLists.txt b/mlspp/lib/tls_syntax/CMakeLists.txt new file mode 100755 index 0000000000..f929c53e93 --- /dev/null +++ b/mlspp/lib/tls_syntax/CMakeLists.txt @@ -0,0 +1,35 @@ +set(CURRENT_LIB_NAME tls_syntax) + +### +### Library Config +### + +file(GLOB_RECURSE LIB_HEADERS CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/include/*.h") +file(GLOB_RECURSE LIB_SOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp") + +add_library(${CURRENT_LIB_NAME} STATIC ${LIB_HEADERS} ${LIB_SOURCES}) +add_dependencies(${CURRENT_LIB_NAME} third_party) +target_link_libraries(${CURRENT_LIB_NAME} PUBLIC third_party) +target_include_directories(${CURRENT_LIB_NAME} + PUBLIC + $ + $ + $ +) + +### +### Install +### + +install(TARGETS ${CURRENT_LIB_NAME} EXPORT mlspp-targets) +install( + DIRECTORY + include + DESTINATION + ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME} +) + +include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/bytes/include") +include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/hpke/include") +include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/mls_vectors/include") +include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/tls_syntax/include") diff --git a/mlspp/lib/tls_syntax/include/tls/compat.h b/mlspp/lib/tls_syntax/include/tls/compat.h new file mode 100755 index 0000000000..c226e44ecd --- /dev/null +++ b/mlspp/lib/tls_syntax/include/tls/compat.h @@ -0,0 +1,68 @@ +#pragma once + +#include +#include + +#ifdef VARIANT_COMPAT +#include +#else +#include +#endif // VARIANT_COMPAT + +namespace mlspp::tls { + +// To balance backward-compatibility with macOS 10.11 with forward-compatibility +// with future versions of C++, we use `mpark::variant` or `std::variant` as +// needed, using `var::variant` to refer to whichever one is in use. +#ifdef VARIANT_COMPAT +namespace var = mpark; +#else +namespace var = std; +#endif // VARIANT_COMPAT + +// In a similar vein, we provide our own safe accessors for std::optional, since +// std::optional::value() is not available on macOS 10.11. +namespace opt { + +template +T& +get(std::optional& opt) +{ + if (!opt) { + throw std::runtime_error("bad_optional_access"); + } + return *opt; +} + +template +const T& +get(const std::optional& opt) +{ + if (!opt) { + throw std::runtime_error("bad_optional_access"); + } + return *opt; +} + +template +T&& +get(std::optional&& opt) +{ + if (!opt) { + throw std::runtime_error("bad_optional_access"); + } + return std::move(*opt); +} + +template +const T&& +get(const std::optional&& opt) +{ + if (!opt) { + throw std::runtime_error("bad_optional_access"); + } + return std::move(*opt); +} + +} // namespace opt +} // namespace mlspp::tls diff --git a/mlspp/lib/tls_syntax/include/tls/tls_syntax.h b/mlspp/lib/tls_syntax/include/tls/tls_syntax.h new file mode 100755 index 0000000000..09d5940d9d --- /dev/null +++ b/mlspp/lib/tls_syntax/include/tls/tls_syntax.h @@ -0,0 +1,569 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace mlspp::tls { + +// For indicating no min or max in vector definitions +const size_t none = std::numeric_limits::max(); + +class WriteError : public std::invalid_argument +{ +public: + using parent = std::invalid_argument; + using parent::parent; +}; + +class ReadError : public std::invalid_argument +{ +public: + using parent = std::invalid_argument; + using parent::parent; +}; + +/// +/// Declarations of Streams and Traits +/// + +class ostream +{ +public: + static const size_t none = std::numeric_limits::max(); + + void write_raw(const std::vector& bytes); + + const std::vector& bytes() const { return _buffer; } + size_t size() const { return _buffer.size(); } + bool empty() const { return _buffer.empty(); } + +private: + std::vector _buffer; + ostream& write_uint(uint64_t value, int length); + + friend ostream& operator<<(ostream& out, bool data); + friend ostream& operator<<(ostream& out, uint8_t data); + friend ostream& operator<<(ostream& out, uint16_t data); + friend ostream& operator<<(ostream& out, uint32_t data); + friend ostream& operator<<(ostream& out, uint64_t data); + + template + friend ostream& operator<<(ostream& out, const std::vector& data); + + friend struct varint; +}; + +class istream +{ +public: + istream(const std::vector& data) + : _buffer(data) + { + // So that we can use the constant-time pop_back + std::reverse(_buffer.begin(), _buffer.end()); + } + + size_t size() const { return _buffer.size(); } + bool empty() const { return _buffer.empty(); } + + std::vector bytes() + { + auto bytes = _buffer; + std::reverse(bytes.begin(), bytes.end()); + return bytes; + } + +private: + istream() {} + std::vector _buffer; + uint8_t next(); + + template + istream& read_uint(T& data, size_t length) + { + uint64_t value = 0; + for (size_t i = 0; i < length; i += 1) { + value = (value << unsigned(8)) + next(); + } + data = static_cast(value); + return *this; + } + + friend istream& operator>>(istream& in, bool& data); + friend istream& operator>>(istream& in, uint8_t& data); + friend istream& operator>>(istream& in, uint16_t& data); + friend istream& operator>>(istream& in, uint32_t& data); + friend istream& operator>>(istream& in, uint64_t& data); + + template + friend istream& operator>>(istream& in, std::vector& data); + + friend struct varint; +}; + +// Traits must have static encode and decode methods, of the following form: +// +// static ostream& encode(ostream& str, const T& val); +// static istream& decode(istream& str, T& val); +// +// Trait types will never be constructed; only these static methods are used. +// The value arguments to encode and decode can be as strict or as loose as +// desired. +// +// Ultimately, all interesting encoding should be done through traits. +// +// * vectors +// * variants +// * varints + +struct pass +{ + template + static ostream& encode(ostream& str, const T& val); + + template + static istream& decode(istream& str, T& val); +}; + +template +struct variant +{ + template + static inline Ts type(const var::variant& data); + + template + static ostream& encode(ostream& str, const var::variant& data); + + template + static inline typename std::enable_if::type + read_variant(istream&, Te, var::variant&); + + template + static inline typename std::enable_if < + I::type read_variant(istream& str, + Te target_type, + var::variant& v); + + template + static istream& decode(istream& str, var::variant& data); +}; + +struct varint +{ + static ostream& encode(ostream& str, const uint64_t& val); + static istream& decode(istream& str, uint64_t& val); +}; + +/// +/// Writer implementations +/// + +// Primitive writers defined in .cpp file + +// Array writer +template +ostream& +operator<<(ostream& out, const std::array& data) +{ + for (const auto& item : data) { + out << item; + } + return out; +} + +// Optional writer +template +ostream& +operator<<(ostream& out, const std::optional& opt) +{ + if (!opt) { + return out << uint8_t(0); + } + + return out << uint8_t(1) << opt::get(opt); +} + +// Enum writer +template::value, int> = 0> +ostream& +operator<<(ostream& str, const T& val) +{ + auto u = static_cast>(val); + return str << u; +} + +// Vector writer +template +ostream& +operator<<(ostream& str, const std::vector& vec) +{ + // Pre-encode contents + ostream temp; + for (const auto& item : vec) { + temp << item; + } + + // Write the encoded length, then the pre-encoded data + varint::encode(str, temp._buffer.size()); + str.write_raw(temp.bytes()); + + return str; +} + +/// +/// Reader implementations +/// + +// Primitive type readers defined in .cpp file + +// Array reader +template +istream& +operator>>(istream& in, std::array& data) +{ + for (auto& item : data) { + in >> item; + } + return in; +} + +// Optional reader +template +istream& +operator>>(istream& in, std::optional& opt) +{ + uint8_t present = 0; + in >> present; + + switch (present) { + case 0: + opt.reset(); + return in; + + case 1: + opt.emplace(); + return in >> opt::get(opt); + + default: + throw std::invalid_argument("Malformed optional"); + } +} + +// Enum reader +// XXX(rlb): It would be nice if this could enforce that the values are valid, +// but C++ doesn't seem to have that ability. When used as a tag for variants, +// the variant reader will enforce, at least. +template::value, int> = 0> +istream& +operator>>(istream& str, T& val) +{ + std::underlying_type_t u; + str >> u; + val = static_cast(u); + return str; +} + +// Vector reader +template +istream& +operator>>(istream& str, std::vector& vec) +{ + // Read the encoded data size + auto size = uint64_t(0); + varint::decode(str, size); + if (size > str._buffer.size()) { + throw ReadError("Vector is longer than remaining data"); + } + + // Read the elements of the vector + // NB: Remember that we store the vector in reverse order + // NB: This requires that T be default-constructible + istream r; + r._buffer = + std::vector{ str._buffer.end() - size, str._buffer.end() }; + + vec.clear(); + while (r._buffer.size() > 0) { + vec.emplace_back(); + r >> vec.back(); + } + + // Truncate the primary buffer + str._buffer.erase(str._buffer.end() - size, str._buffer.end()); + + return str; +} + +// Abbreviations +template +std::vector +marshal(const T& value) +{ + ostream w; + w << value; + return w.bytes(); +} + +template +void +unmarshal(const std::vector& data, T& value) +{ + istream r(data); + r >> value; +} + +template +T +get(const std::vector& data, Tp... args) +{ + T value(args...); + unmarshal(data, value); + return value; +} + +// Use this macro to define struct serialization with minimal boilerplate +#define TLS_SERIALIZABLE(...) \ + static const bool _tls_serializable = true; \ + auto _tls_fields_r() \ + { \ + return std::forward_as_tuple(__VA_ARGS__); \ + } \ + auto _tls_fields_w() const \ + { \ + return std::forward_as_tuple(__VA_ARGS__); \ + } + +// If your struct contains nontrivial members (e.g., vectors), use this to +// define traits for them. +#define TLS_TRAITS(...) \ + static const bool _tls_has_traits = true; \ + using _tls_traits = std::tuple<__VA_ARGS__>; + +template +struct is_serializable +{ + template + static std::true_type test(decltype(U::_tls_serializable)); + + template + static std::false_type test(...); + + static const bool value = decltype(test(true))::value; +}; + +template +struct has_traits +{ + template + static std::true_type test(decltype(U::_tls_has_traits)); + + template + static std::false_type test(...); + + static const bool value = decltype(test(true))::value; +}; + +/// +/// Trait implementations +/// + +// Pass-through (normal encoding/decoding) +template +ostream& +pass::encode(ostream& str, const T& val) +{ + return str << val; +} + +template +istream& +pass::decode(istream& str, T& val) +{ + return str >> val; +} + +// Variant encoding +template +constexpr Ts +variant_map(); + +#define TLS_VARIANT_MAP(EnumType, MappedType, enum_value) \ + template<> \ + constexpr EnumType variant_map() \ + { \ + return EnumType::enum_value; \ + } + +template +template +inline Ts +variant::type(const var::variant& data) +{ + const auto get_type = [](const auto& v) { + return variant_map>(); + }; + return var::visit(get_type, data); +} + +template +template +ostream& +variant::encode(ostream& str, const var::variant& data) +{ + const auto write_variant = [&str](auto&& v) { + using Tv = std::decay_t; + str << variant_map() << v; + }; + var::visit(write_variant, data); + return str; +} + +template +template +inline typename std::enable_if::type +variant::read_variant(istream&, Te, var::variant&) +{ + throw ReadError("Invalid variant type label"); +} + +template + template + inline + typename std::enable_if < I::type + variant::read_variant(istream& str, + Te target_type, + var::variant& v) +{ + using Tc = var::variant_alternative_t>; + if (variant_map() == target_type) { + str >> v.template emplace(); + return; + } + + read_variant(str, target_type, v); +} + +template +template +istream& +variant::decode(istream& str, var::variant& data) +{ + Ts target_type; + str >> target_type; + read_variant(str, target_type, data); + return str; +} + +// Struct writer without traits (enabled by macro) +template +inline typename std::enable_if::type +write_tuple(ostream&, const std::tuple&) +{ +} + +template + inline typename std::enable_if < + I::type + write_tuple(ostream& str, const std::tuple& t) +{ + str << std::get(t); + write_tuple(str, t); +} + +template +inline + typename std::enable_if::value && !has_traits::value, + ostream&>::type + operator<<(ostream& str, const T& obj) +{ + write_tuple(str, obj._tls_fields_w()); + return str; +} + +// Struct writer with traits (enabled by macro) +template +inline typename std::enable_if::type +write_tuple_traits(ostream&, const std::tuple&) +{ +} + +template + inline typename std::enable_if < + I::type + write_tuple_traits(ostream& str, const std::tuple& t) +{ + std::tuple_element_t::encode(str, std::get(t)); + write_tuple_traits(str, t); +} + +template +inline + typename std::enable_if::value && has_traits::value, + ostream&>::type + operator<<(ostream& str, const T& obj) +{ + write_tuple_traits(str, obj._tls_fields_w()); + return str; +} + +// Struct reader without traits (enabled by macro) +template +inline typename std::enable_if::type +read_tuple(istream&, const std::tuple&) +{ +} + +template + inline + typename std::enable_if < I::type + read_tuple(istream& str, const std::tuple& t) +{ + str >> std::get(t); + read_tuple(str, t); +} + +template +inline + typename std::enable_if::value && !has_traits::value, + istream&>::type + operator>>(istream& str, T& obj) +{ + read_tuple(str, obj._tls_fields_r()); + return str; +} + +// Struct reader with traits (enabled by macro) +template +inline typename std::enable_if::type +read_tuple_traits(istream&, const std::tuple&) +{ +} + +template + inline typename std::enable_if < + I::type + read_tuple_traits(istream& str, const std::tuple& t) +{ + std::tuple_element_t::decode(str, std::get(t)); + read_tuple_traits(str, t); +} + +template +inline + typename std::enable_if::value && has_traits::value, + istream&>::type + operator>>(istream& str, T& obj) +{ + read_tuple_traits(str, obj._tls_fields_r()); + return str; +} + +} // namespace mlspp::tls diff --git a/mlspp/lib/tls_syntax/src/tls_syntax.cpp b/mlspp/lib/tls_syntax/src/tls_syntax.cpp new file mode 100755 index 0000000000..efbea150b5 --- /dev/null +++ b/mlspp/lib/tls_syntax/src/tls_syntax.cpp @@ -0,0 +1,178 @@ +#include + +// NOLINTNEXTLINE(llvmlibc-implementation-in-namespace) +namespace mlspp::tls { + +void +ostream::write_raw(const std::vector& bytes) +{ + // Not sure what the default argument is here + _buffer.insert(_buffer.end(), bytes.begin(), bytes.end()); +} + +// Primitive type writers +ostream& +ostream::write_uint(uint64_t value, int length) +{ + for (int i = length - 1; i >= 0; --i) { + _buffer.push_back(static_cast(value >> unsigned(8 * i))); + } + return *this; +} + +ostream& +operator<<(ostream& out, bool data) +{ + if (data) { + return out << uint8_t(1); + } + + return out << uint8_t(0); +} + +ostream& +operator<<(ostream& out, uint8_t data) // NOLINT(llvmlibc-callee-namespace) +{ + return out.write_uint(data, 1); +} + +ostream& +operator<<(ostream& out, uint16_t data) +{ + return out.write_uint(data, 2); +} + +ostream& +operator<<(ostream& out, uint32_t data) +{ + return out.write_uint(data, 4); +} + +ostream& +operator<<(ostream& out, uint64_t data) +{ + return out.write_uint(data, 8); +} + +// Because pop_back() on an empty vector is undefined +uint8_t +istream::next() +{ + if (_buffer.empty()) { + throw ReadError("Attempt to read from empty buffer"); + } + + const uint8_t value = _buffer.back(); + _buffer.pop_back(); + return value; +} + +// Primitive type readers + +istream& +operator>>(istream& in, bool& data) +{ + uint8_t val = 0; + in >> val; + + // Linter thinks uint8_t is signed (?) + // NOLINTNEXTLINE(hicpp-signed-bitwise) + if ((val & 0xFE) != 0) { + throw ReadError("Malformed boolean"); + } + + data = (val == 1); + return in; +} + +istream& +operator>>(istream& in, uint8_t& data) // NOLINT(llvmlibc-callee-namespace) +{ + return in.read_uint(data, 1); +} + +istream& +operator>>(istream& in, uint16_t& data) +{ + return in.read_uint(data, 2); +} + +istream& +operator>>(istream& in, uint32_t& data) +{ + return in.read_uint(data, 4); +} + +istream& +operator>>(istream& in, uint64_t& data) +{ + return in.read_uint(data, 8); +} + +// Varint encoding +static constexpr size_t VARINT_HEADER_OFFSET = 6; +static constexpr uint64_t VARINT_1_HEADER = 0x00; // 0 << V1_OFFSET +static constexpr uint64_t VARINT_2_HEADER = 0x4000; // 1 << V2_OFFSET +static constexpr uint64_t VARINT_4_HEADER = 0x80000000; // 2 << V4_OFFSET +static constexpr uint64_t VARINT_1_MAX = 0x3f; +static constexpr uint64_t VARINT_2_MAX = 0x3fff; +static constexpr uint64_t VARINT_4_MAX = 0x3fffffff; + +ostream& +varint::encode(ostream& str, const uint64_t& val) +{ + if (val <= VARINT_1_MAX) { + return str.write_uint(VARINT_1_HEADER | val, 1); + } + + if (val <= VARINT_2_MAX) { + return str.write_uint(VARINT_2_HEADER | val, 2); + } + + if (val <= VARINT_4_MAX) { + return str.write_uint(VARINT_4_HEADER | val, 4); + } + + throw WriteError("Varint value exceeds maximum size"); +} + +istream& +varint::decode(istream& str, uint64_t& val) +{ + auto log_size = size_t(str._buffer.back() >> VARINT_HEADER_OFFSET); + if (log_size > 2) { + throw ReadError("Malformed varint header"); + } + + auto read = uint64_t(0); + auto read_bytes = size_t(size_t(1) << log_size); + str.read_uint(read, read_bytes); + + switch (log_size) { + case 0: + read ^= VARINT_1_HEADER; + break; + + case 1: + read ^= VARINT_2_HEADER; + if (read <= VARINT_1_MAX) { + throw ReadError("Non-minimal varint"); + } + break; + + case 2: + read ^= VARINT_4_HEADER; + if (read <= VARINT_2_MAX) { + throw ReadError("Non-minimal varint"); + } + break; + + default: + throw ReadError("Malformed varint header"); + } + + val = read; + return str; +} + +} // namespace mlspp::tls diff --git a/mlspp/src/common.cpp b/mlspp/src/common.cpp new file mode 100755 index 0000000000..6895bfb4d7 --- /dev/null +++ b/mlspp/src/common.cpp @@ -0,0 +1,13 @@ +#include "mls/common.h" + +namespace mlspp { + +uint64_t +seconds_since_epoch() +{ + // TODO(RLB) This should use std::chrono, but that seems not to be available + // on some platforms. + return std::time(nullptr); +} + +} // namespace mlspp diff --git a/mlspp/src/core_types.cpp b/mlspp/src/core_types.cpp new file mode 100755 index 0000000000..19cee0d4c2 --- /dev/null +++ b/mlspp/src/core_types.cpp @@ -0,0 +1,443 @@ +#include "mls/core_types.h" +#include "mls/messages.h" + +#include "grease.h" + +#include + +namespace mlspp { + +/// +/// Extensions +/// + +const Extension::Type RequiredCapabilitiesExtension::type = + ExtensionType::required_capabilities; +const Extension::Type ApplicationIDExtension::type = + ExtensionType::application_id; + +const std::array default_extensions = { + ExtensionType::application_id, ExtensionType::ratchet_tree, + ExtensionType::required_capabilities, ExtensionType::external_pub, + ExtensionType::external_senders, +}; + +const std::array default_proposals = { + ProposalType::add, + ProposalType::update, + ProposalType::remove, + ProposalType::psk, + ProposalType::reinit, + ProposalType::external_init, + ProposalType::group_context_extensions, +}; + +const std::array all_supported_versions = { + ProtocolVersion::mls10 +}; + +const std::array all_supported_ciphersuites = { + CipherSuite::ID::X25519_AES128GCM_SHA256_Ed25519, + CipherSuite::ID::P256_AES128GCM_SHA256_P256, + CipherSuite::ID::X25519_CHACHA20POLY1305_SHA256_Ed25519, + CipherSuite::ID::X448_AES256GCM_SHA512_Ed448, + CipherSuite::ID::P521_AES256GCM_SHA512_P521, + CipherSuite::ID::X448_CHACHA20POLY1305_SHA512_Ed448, +}; + +const std::array all_supported_credentials = { + CredentialType::basic, + CredentialType::x509, + CredentialType::userinfo_vc_draft_00, + CredentialType::multi_draft_00 +}; + +Capabilities +Capabilities::create_default() +{ + return { + { all_supported_versions.begin(), all_supported_versions.end() }, + { all_supported_ciphersuites.begin(), all_supported_ciphersuites.end() }, + { /* No non-default extensions */ }, + { /* No non-default proposals */ }, + { all_supported_credentials.begin(), all_supported_credentials.end() }, + }; +} + +bool +Capabilities::extensions_supported( + const std::vector& required) const +{ + return stdx::all_of(required, [&](Extension::Type type) { + if (stdx::contains(default_extensions, type)) { + return true; + } + + return stdx::contains(extensions, type); + }); +} + +bool +Capabilities::proposals_supported( + const std::vector& required) const +{ + return stdx::all_of(required, [&](Proposal::Type type) { + if (stdx::contains(default_proposals, type)) { + return true; + } + + return stdx::contains(proposals, type); + }); +} + +bool +Capabilities::credential_supported(const Credential& credential) const +{ + return stdx::contains(credentials, credential.type()); +} + +Lifetime +Lifetime::create_default() +{ + return Lifetime{ 0x0000000000000000, 0xffffffffffffffff }; +} + +void +ExtensionList::add(uint16_t type, bytes data) +{ + auto curr = stdx::find_if( + extensions, [&](const Extension& ext) -> bool { return ext.type == type; }); + if (curr != extensions.end()) { + curr->data = std::move(data); + return; + } + + extensions.push_back({ type, std::move(data) }); +} + +bool +ExtensionList::has(uint16_t type) const +{ + return stdx::any_of(extensions, + [&](const Extension& ext) { return ext.type == type; }); +} + +/// +/// LeafNode +/// +LeafNode::LeafNode(CipherSuite cipher_suite, + HPKEPublicKey encryption_key_in, + SignaturePublicKey signature_key_in, + Credential credential_in, + Capabilities capabilities_in, + Lifetime lifetime_in, + ExtensionList extensions_in, + const SignaturePrivateKey& sig_priv) + : encryption_key(std::move(encryption_key_in)) + , signature_key(std::move(signature_key_in)) + , credential(std::move(credential_in)) + , capabilities(std::move(capabilities_in)) + , content(lifetime_in) + , extensions(std::move(extensions_in)) +{ + grease(extensions); + grease(capabilities, extensions); + sign(cipher_suite, sig_priv, std::nullopt); +} + +void +LeafNode::set_capabilities(Capabilities capabilities_in) +{ + capabilities = std::move(capabilities_in); + grease(capabilities, extensions); +} + +LeafNode +LeafNode::for_update(CipherSuite cipher_suite, + const bytes& group_id, + LeafIndex leaf_index, + HPKEPublicKey encryption_key_in, + const LeafNodeOptions& opts, + const SignaturePrivateKey& sig_priv) const +{ + auto clone = clone_with_options(std::move(encryption_key_in), opts); + + clone.content = Empty{}; + clone.sign(cipher_suite, sig_priv, { { group_id, leaf_index } }); + + return clone; +} + +LeafNode +LeafNode::for_commit(CipherSuite cipher_suite, + const bytes& group_id, + LeafIndex leaf_index, + HPKEPublicKey encryption_key_in, + const bytes& parent_hash, + const LeafNodeOptions& opts, + const SignaturePrivateKey& sig_priv) const +{ + auto clone = clone_with_options(std::move(encryption_key_in), opts); + + clone.content = ParentHash{ parent_hash }; + clone.sign(cipher_suite, sig_priv, { { group_id, leaf_index } }); + + return clone; +} + +LeafNodeSource +LeafNode::source() const +{ + return tls::variant::type(content); +} + +void +LeafNode::sign(CipherSuite cipher_suite, + const SignaturePrivateKey& sig_priv, + const std::optional& binding) +{ + const auto tbs = to_be_signed(binding); + + if (sig_priv.public_key != signature_key) { + throw InvalidParameterError("Signature key mismatch"); + } + + if (!credential.valid_for(signature_key)) { + throw InvalidParameterError("Credential not valid for signature key"); + } + + signature = sig_priv.sign(cipher_suite, sign_label::leaf_node, tbs); +} + +bool +LeafNode::verify(CipherSuite cipher_suite, + const std::optional& binding) const +{ + const auto tbs = to_be_signed(binding); + + if (CredentialType::x509 == credential.type()) { + const auto& cred = credential.get(); + if (cred.signature_scheme() != + tls_signature_scheme(cipher_suite.sig().id)) { + throw std::runtime_error("Signature algorithm invalid"); + } + } + + return signature_key.verify( + cipher_suite, sign_label::leaf_node, tbs, signature); +} + +bool +LeafNode::verify_expiry(uint64_t now) const +{ + const auto valid = overloaded{ + [now](const Lifetime& lt) { + return lt.not_before <= now && now <= lt.not_after; + }, + [](const auto& /* other */) { return false; }, + }; + return var::visit(valid, content); +} + +bool +LeafNode::verify_extension_support(const ExtensionList& ext_list) const +{ + // Verify that extensions in the list are supported + auto ext_types = stdx::transform( + ext_list.extensions, [](const auto& ext) { return ext.type; }); + + if (!capabilities.extensions_supported(ext_types)) { + return false; + } + + // If there's a RequiredCapabilities extension, verify support + const auto maybe_req_capas = ext_list.find(); + if (!maybe_req_capas) { + return true; + } + + const auto& req_capas = opt::get(maybe_req_capas); + return capabilities.extensions_supported(req_capas.extensions) && + capabilities.proposals_supported(req_capas.proposals); +} + +LeafNode +LeafNode::clone_with_options(HPKEPublicKey encryption_key_in, + const LeafNodeOptions& opts) const +{ + auto clone = *this; + + clone.encryption_key = std::move(encryption_key_in); + + if (opts.credential) { + clone.credential = opt::get(opts.credential); + } + + if (opts.capabilities) { + clone.capabilities = opt::get(opts.capabilities); + } + + if (opts.extensions) { + clone.extensions = opt::get(opts.extensions); + } + + return clone; +} + +// struct { +// HPKEPublicKey encryption_key; +// SignaturePublicKey signature_key; +// Credential credential; +// Capabilities capabilities; +// +// LeafNodeSource leaf_node_source; +// select (leaf_node_source) { +// case key_package: +// Lifetime lifetime; +// +// case update: +// struct{}; +// +// case commit: +// opaque parent_hash; +// } +// +// Extension extensions; +// +// select (leaf_node_source) { +// case key_package: +// struct{}; +// +// case update: +// opaque group_id; +// +// case commit: +// opaque group_id; +// } +// } LeafNodeTBS; +struct LeafNodeTBS +{ + const HPKEPublicKey& encryption_key; + const SignaturePublicKey& signature_key; + const Credential& credential; + const Capabilities& capabilities; + const var::variant& content; + const ExtensionList& extensions; + + TLS_SERIALIZABLE(encryption_key, + signature_key, + credential, + capabilities, + content, + extensions) + TLS_TRAITS(tls::pass, + tls::pass, + tls::pass, + tls::pass, + tls::variant, + tls::pass) +}; + +bytes +LeafNode::to_be_signed(const std::optional& binding) const +{ + tls::ostream w; + + w << LeafNodeTBS{ + encryption_key, signature_key, credential, + capabilities, content, extensions, + }; + + switch (source()) { + case LeafNodeSource::key_package: + break; + + case LeafNodeSource::update: + case LeafNodeSource::commit: + w << opt::get(binding); + } + + return w.bytes(); +} + +/// +/// NodeType, ParentNode, and KeyPackage +/// + +bytes +ParentNode::hash(CipherSuite suite) const +{ + return suite.digest().hash(tls::marshal(this)); +} + +KeyPackage::KeyPackage() + : version(ProtocolVersion::mls10) + , cipher_suite(CipherSuite::ID::unknown) +{ +} + +KeyPackage::KeyPackage(CipherSuite suite_in, + HPKEPublicKey init_key_in, + LeafNode leaf_node_in, + ExtensionList extensions_in, + const SignaturePrivateKey& sig_priv_in) + : version(ProtocolVersion::mls10) + , cipher_suite(suite_in) + , init_key(std::move(init_key_in)) + , leaf_node(std::move(leaf_node_in)) + , extensions(std::move(extensions_in)) +{ + grease(extensions); + sign(sig_priv_in); +} + +KeyPackageRef +KeyPackage::ref() const +{ + return cipher_suite.ref(*this); +} + +void +KeyPackage::sign(const SignaturePrivateKey& sig_priv) +{ + auto tbs = to_be_signed(); + signature = sig_priv.sign(cipher_suite, sign_label::key_package, tbs); +} + +bool +KeyPackage::verify() const +{ + // Verify the inner leaf node + if (!leaf_node.verify(cipher_suite, std::nullopt)) { + return false; + } + + // Check that the inner leaf node is intended for use in a KeyPackage + if (leaf_node.source() != LeafNodeSource::key_package) { + return false; + } + + // Verify the KeyPackage + const auto tbs = to_be_signed(); + + if (CredentialType::x509 == leaf_node.credential.type()) { + const auto& cred = leaf_node.credential.get(); + if (cred.signature_scheme() != + tls_signature_scheme(cipher_suite.sig().id)) { + throw std::runtime_error("Signature algorithm invalid"); + } + } + + return leaf_node.signature_key.verify( + cipher_suite, sign_label::key_package, tbs, signature); +} + +bytes +KeyPackage::to_be_signed() const +{ + tls::ostream out; + out << version << cipher_suite << init_key << leaf_node << extensions; + return out.bytes(); +} + +} // namespace mlspp diff --git a/mlspp/src/credential.cpp b/mlspp/src/credential.cpp new file mode 100755 index 0000000000..4edb3a9b82 --- /dev/null +++ b/mlspp/src/credential.cpp @@ -0,0 +1,298 @@ +#include +#include +#include +#include + +namespace mlspp { + +/// +/// X509Credential +/// + +using mlspp::hpke::Certificate; // NOLINT(misc-unused-using-decls) +using mlspp::hpke::Signature; // NOLINT(misc-unused-using-decls) +using mlspp::hpke::UserInfoVC; // NOLINT(misc-unused-using-decls) + +static const Signature& +find_signature(Signature::ID id) +{ + switch (id) { + case Signature::ID::P256_SHA256: + return Signature::get(); + case Signature::ID::P384_SHA384: + return Signature::get(); + case Signature::ID::P521_SHA512: + return Signature::get(); + case Signature::ID::Ed25519: + return Signature::get(); +#if !defined(WITH_BORINGSSL) + case Signature::ID::Ed448: + return Signature::get(); +#endif + case Signature::ID::RSA_SHA256: + return Signature::get(); + default: + throw InvalidParameterError("Unsupported algorithm"); + } +} + +static std::vector +bytes_to_x509_credential_data(const std::vector& data_in) +{ + return stdx::transform( + data_in, [](const bytes& der) { return X509Credential::CertData{ der }; }); +} + +X509Credential::X509Credential(const std::vector& der_chain_in) + : der_chain(bytes_to_x509_credential_data(der_chain_in)) +{ + if (der_chain.empty()) { + throw std::invalid_argument("empty certificate chain"); + } + + // Parse the chain + auto parsed = std::vector(); + for (const auto& cert : der_chain) { + parsed.emplace_back(cert.data); + } + + // first element represents leaf cert + const auto& sig = find_signature(parsed[0].public_key_algorithm()); + const auto pub_data = sig.serialize(*parsed[0].public_key); + _signature_scheme = tls_signature_scheme(parsed[0].public_key_algorithm()); + _public_key = SignaturePublicKey{ pub_data }; + + // verify chain for valid signatures + for (size_t i = 0; i < der_chain.size() - 1; i++) { + if (!parsed[i].valid_from(parsed[i + 1])) { + throw std::runtime_error("Certificate Chain validation failure"); + } + } +} + +SignatureScheme +X509Credential::signature_scheme() const +{ + return _signature_scheme; +} + +SignaturePublicKey +X509Credential::public_key() const +{ + return _public_key; +} + +bool +X509Credential::valid_for(const SignaturePublicKey& pub) const +{ + return pub == public_key(); +} + +tls::ostream& +operator<<(tls::ostream& str, const X509Credential& obj) +{ + return str << obj.der_chain; +} + +tls::istream& +operator>>(tls::istream& str, X509Credential& obj) +{ + auto der_chain = std::vector{}; + str >> der_chain; + + auto der_in = stdx::transform( + der_chain, [](const auto& cert_data) { return cert_data.data; }); + obj = X509Credential(der_in); + + return str; +} + +bool +operator==(const X509Credential& lhs, const X509Credential& rhs) +{ + return lhs.der_chain == rhs.der_chain; +} + +/// +/// UserInfoVCCredential +/// +UserInfoVCCredential::UserInfoVCCredential(std::string userinfo_vc_jwt_in) + : userinfo_vc_jwt(std::move(userinfo_vc_jwt_in)) + , _vc(std::make_shared(userinfo_vc_jwt)) +{ +} + +bool +// NOLINTNEXTLINE(readability-convert-member-functions-to-static) +UserInfoVCCredential::valid_for(const SignaturePublicKey& pub) const +{ + const auto& vc_pub = _vc->public_key(); + return pub.data == vc_pub.sig.serialize(*vc_pub.key); +} + +bool +UserInfoVCCredential::valid_from(const PublicJWK& pub) const +{ + const auto& sig = _vc->signature_algorithm(); + if (pub.signature_scheme != tls_signature_scheme(sig.id)) { + return false; + } + + const auto sig_pub = sig.deserialize(pub.public_key.data); + return _vc->valid_from(*sig_pub); +} + +tls::ostream +operator<<(tls::ostream& str, const UserInfoVCCredential& obj) +{ + return str << from_ascii(obj.userinfo_vc_jwt); +} + +tls::istream +operator>>(tls::istream& str, UserInfoVCCredential& obj) +{ + auto jwt = bytes{}; + str >> jwt; + obj = UserInfoVCCredential(to_ascii(jwt)); + return str; +} + +bool +operator==(const UserInfoVCCredential& lhs, const UserInfoVCCredential& rhs) +{ + return lhs.userinfo_vc_jwt == rhs.userinfo_vc_jwt; +} + +bool +operator!=(const UserInfoVCCredential& lhs, const UserInfoVCCredential& rhs) +{ + return !(lhs == rhs); +} + +/// +/// CredentialBinding and MultiCredential +/// + +struct CredentialBindingTBS +{ + const CipherSuite& cipher_suite; + const Credential& credential; + const SignaturePublicKey& credential_key; + const SignaturePublicKey& signature_key; + + TLS_SERIALIZABLE(cipher_suite, credential, credential_key, signature_key) +}; + +CredentialBinding::CredentialBinding(CipherSuite cipher_suite_in, + Credential credential_in, + const SignaturePrivateKey& credential_priv, + const SignaturePublicKey& signature_key) + : cipher_suite(cipher_suite_in) + , credential(std::move(credential_in)) + , credential_key(credential_priv.public_key) +{ + if (credential.type() == CredentialType::multi_draft_00) { + throw InvalidParameterError("Multi-credentials cannot be nested"); + } + + if (!credential.valid_for(credential_key)) { + throw InvalidParameterError("Credential key does not match credential"); + } + + signature = credential_priv.sign( + cipher_suite, sign_label::multi_credential, to_be_signed(signature_key)); +} + +bytes +CredentialBinding::to_be_signed(const SignaturePublicKey& signature_key) const +{ + return tls::marshal(CredentialBindingTBS{ + cipher_suite, credential, credential_key, signature_key }); +} + +bool +CredentialBinding::valid_for(const SignaturePublicKey& signature_key) const +{ + auto valid_self = credential.valid_for(credential_key); + auto valid_other = credential_key.verify(cipher_suite, + sign_label::multi_credential, + to_be_signed(signature_key), + signature); + + return valid_self && valid_other; +} + +MultiCredential::MultiCredential( + const std::vector& binding_inputs, + const SignaturePublicKey& signature_key) +{ + bindings = + stdx::transform(binding_inputs, [&](auto&& input) { + return CredentialBinding(input.cipher_suite, + input.credential, + input.credential_priv, + signature_key); + }); +} + +bool +MultiCredential::valid_for(const SignaturePublicKey& pub) const +{ + return stdx::all_of( + bindings, [&](const auto& binding) { return binding.valid_for(pub); }); +} + +/// +/// Credential +/// + +CredentialType +Credential::type() const +{ + return tls::variant::type(_cred); +} + +Credential +Credential::basic(const bytes& identity) +{ + return { BasicCredential{ identity } }; +} + +Credential +Credential::x509(const std::vector& der_chain) +{ + return { X509Credential{ der_chain } }; +} + +Credential +Credential::multi(const std::vector& binding_inputs, + const SignaturePublicKey& signature_key) +{ + return { MultiCredential{ binding_inputs, signature_key } }; +} + +Credential +Credential::userinfo_vc(const std::string& userinfo_vc_jwt) +{ + return { UserInfoVCCredential{ userinfo_vc_jwt } }; +} + +bool +Credential::valid_for(const SignaturePublicKey& pub) const +{ + const auto pub_key_match = overloaded{ + [&](const X509Credential& x509) { return x509.valid_for(pub); }, + [](const BasicCredential& /* basic */) { return true; }, + [&](const UserInfoVCCredential& vc) { return vc.valid_for(pub); }, + [&](const MultiCredential& multi) { return multi.valid_for(pub); }, + }; + + return var::visit(pub_key_match, _cred); +} + +Credential::Credential(SpecificCredential specific) + : _cred(std::move(specific)) +{ +} + +} // namespace mlspp diff --git a/mlspp/src/crypto.cpp b/mlspp/src/crypto.cpp new file mode 100755 index 0000000000..5126036acd --- /dev/null +++ b/mlspp/src/crypto.cpp @@ -0,0 +1,498 @@ +#include +#include +#include + +#include + +using mlspp::hpke::AEAD; // NOLINT(misc-unused-using-decls) +using mlspp::hpke::Digest; // NOLINT(misc-unused-using-decls) +using mlspp::hpke::HPKE; // NOLINT(misc-unused-using-decls) +using mlspp::hpke::KDF; // NOLINT(misc-unused-using-decls) +using mlspp::hpke::KEM; // NOLINT(misc-unused-using-decls) +using mlspp::hpke::Signature; // NOLINT(misc-unused-using-decls) + +namespace mlspp { + +SignatureScheme +tls_signature_scheme(Signature::ID id) +{ + switch (id) { + case Signature::ID::P256_SHA256: + return SignatureScheme::ecdsa_secp256r1_sha256; + case Signature::ID::P384_SHA384: + return SignatureScheme::ecdsa_secp384r1_sha384; + case Signature::ID::P521_SHA512: + return SignatureScheme::ecdsa_secp521r1_sha512; + case Signature::ID::Ed25519: + return SignatureScheme::ed25519; +#if !defined(WITH_BORINGSSL) + case Signature::ID::Ed448: + return SignatureScheme::ed448; +#endif + case Signature::ID::RSA_SHA256: + return SignatureScheme::rsa_pkcs1_sha256; + default: + throw InvalidParameterError("Unsupported algorithm"); + } +} + +/// +/// CipherSuites and details +/// + +CipherSuite::CipherSuite() + : id(ID::unknown) +{ +} + +CipherSuite::CipherSuite(ID id_in) + : id(id_in) +{ +} + +SignatureScheme +CipherSuite::signature_scheme() const +{ + switch (id) { + case ID::X25519_AES128GCM_SHA256_Ed25519: + case ID::X25519_CHACHA20POLY1305_SHA256_Ed25519: + return SignatureScheme::ed25519; + case ID::P256_AES128GCM_SHA256_P256: + return SignatureScheme::ecdsa_secp256r1_sha256; + case ID::X448_AES256GCM_SHA512_Ed448: + case ID::X448_CHACHA20POLY1305_SHA512_Ed448: + return SignatureScheme::ed448; + case ID::P521_AES256GCM_SHA512_P521: + return SignatureScheme::ecdsa_secp521r1_sha512; + case ID::P384_AES256GCM_SHA384_P384: + return SignatureScheme::ecdsa_secp384r1_sha384; + default: + throw InvalidParameterError("Unsupported algorithm"); + } +} + +const CipherSuite::Ciphers& +CipherSuite::get() const +{ + static const auto ciphers_X25519_AES128GCM_SHA256_Ed25519 = + CipherSuite::Ciphers{ + HPKE(KEM::ID::DHKEM_X25519_SHA256, + KDF::ID::HKDF_SHA256, + AEAD::ID::AES_128_GCM), + Digest::get(), + Signature::get(), + }; + + static const auto ciphers_P256_AES128GCM_SHA256_P256 = CipherSuite::Ciphers{ + HPKE( + KEM::ID::DHKEM_P256_SHA256, KDF::ID::HKDF_SHA256, AEAD::ID::AES_128_GCM), + Digest::get(), + Signature::get(), + }; + + static const auto ciphers_X25519_CHACHA20POLY1305_SHA256_Ed25519 = + CipherSuite::Ciphers{ + HPKE(KEM::ID::DHKEM_X25519_SHA256, + KDF::ID::HKDF_SHA256, + AEAD::ID::CHACHA20_POLY1305), + Digest::get(), + Signature::get(), + }; + + static const auto ciphers_P521_AES256GCM_SHA512_P521 = CipherSuite::Ciphers{ + HPKE( + KEM::ID::DHKEM_P521_SHA512, KDF::ID::HKDF_SHA512, AEAD::ID::AES_256_GCM), + Digest::get(), + Signature::get(), + }; + + static const auto ciphers_P384_AES256GCM_SHA384_P384 = CipherSuite::Ciphers{ + HPKE( + KEM::ID::DHKEM_P384_SHA384, KDF::ID::HKDF_SHA384, AEAD::ID::AES_256_GCM), + Digest::get(), + Signature::get(), + }; + +#if !defined(WITH_BORINGSSL) + static const auto ciphers_X448_AES256GCM_SHA512_Ed448 = CipherSuite::Ciphers{ + HPKE( + KEM::ID::DHKEM_X448_SHA512, KDF::ID::HKDF_SHA512, AEAD::ID::AES_256_GCM), + Digest::get(), + Signature::get(), + }; + + static const auto ciphers_X448_CHACHA20POLY1305_SHA512_Ed448 = + CipherSuite::Ciphers{ + HPKE(KEM::ID::DHKEM_X448_SHA512, + KDF::ID::HKDF_SHA512, + AEAD::ID::CHACHA20_POLY1305), + Digest::get(), + Signature::get(), + }; +#endif + + switch (id) { + case ID::unknown: + throw InvalidParameterError("Uninitialized ciphersuite"); + + case ID::X25519_AES128GCM_SHA256_Ed25519: + return ciphers_X25519_AES128GCM_SHA256_Ed25519; + + case ID::P256_AES128GCM_SHA256_P256: + return ciphers_P256_AES128GCM_SHA256_P256; + + case ID::X25519_CHACHA20POLY1305_SHA256_Ed25519: + return ciphers_X25519_CHACHA20POLY1305_SHA256_Ed25519; + + case ID::P521_AES256GCM_SHA512_P521: + return ciphers_P521_AES256GCM_SHA512_P521; + + case ID::P384_AES256GCM_SHA384_P384: + return ciphers_P384_AES256GCM_SHA384_P384; + +#if !defined(WITH_BORINGSSL) + case ID::X448_AES256GCM_SHA512_Ed448: + return ciphers_X448_AES256GCM_SHA512_Ed448; + + case ID::X448_CHACHA20POLY1305_SHA512_Ed448: + return ciphers_X448_CHACHA20POLY1305_SHA512_Ed448; +#endif + + default: + throw InvalidParameterError("Unsupported ciphersuite"); + } +} + +struct HKDFLabel +{ + uint16_t length; + bytes label; + bytes context; + + TLS_SERIALIZABLE(length, label, context) +}; + +bytes +CipherSuite::expand_with_label(const bytes& secret, + const std::string& label, + const bytes& context, + size_t length) const +{ + auto mls_label = from_ascii(std::string("MLS 1.0 ") + label); + auto length16 = static_cast(length); + auto label_bytes = tls::marshal(HKDFLabel{ length16, mls_label, context }); + return get().hpke.kdf.expand(secret, label_bytes, length); +} + +bytes +CipherSuite::derive_secret(const bytes& secret, const std::string& label) const +{ + return expand_with_label(secret, label, {}, secret_size()); +} + +bytes +CipherSuite::derive_tree_secret(const bytes& secret, + const std::string& label, + uint32_t generation, + size_t length) const +{ + return expand_with_label(secret, label, tls::marshal(generation), length); +} + +#if WITH_BORINGSSL +const std::array all_supported_suites = { + CipherSuite::ID::X25519_AES128GCM_SHA256_Ed25519, + CipherSuite::ID::P256_AES128GCM_SHA256_P256, + CipherSuite::ID::X25519_CHACHA20POLY1305_SHA256_Ed25519, + CipherSuite::ID::P521_AES256GCM_SHA512_P521, + CipherSuite::ID::P384_AES256GCM_SHA384_P384, +}; +#else +const std::array all_supported_suites = { + CipherSuite::ID::X25519_AES128GCM_SHA256_Ed25519, + CipherSuite::ID::P256_AES128GCM_SHA256_P256, + CipherSuite::ID::X25519_CHACHA20POLY1305_SHA256_Ed25519, + CipherSuite::ID::P521_AES256GCM_SHA512_P521, + CipherSuite::ID::P384_AES256GCM_SHA384_P384, + CipherSuite::ID::X448_CHACHA20POLY1305_SHA512_Ed448, + CipherSuite::ID::X448_AES256GCM_SHA512_Ed448, +}; +#endif + +// MakeKeyPackageRef(value) = KDF.expand( +// KDF.extract("", value), "MLS 1.0 KeyPackage Reference", 16) +template<> +const bytes& +CipherSuite::reference_label() +{ + static const auto label = from_ascii("MLS 1.0 KeyPackage Reference"); + return label; +} + +// MakeProposalRef(value) = KDF.expand( +// KDF.extract("", value), "MLS 1.0 Proposal Reference", 16) +// +// Even though the label says "Proposal", we actually hash the entire enclosing +// AuthenticatedContent object. +template<> +const bytes& +CipherSuite::reference_label() +{ + static const auto label = from_ascii("MLS 1.0 Proposal Reference"); + return label; +} + +/// +/// HPKEPublicKey and HPKEPrivateKey +/// + +// This function produces a non-literal type, so it can't be constexpr. +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define MLS_1_0_PLUS(label) from_ascii("MLS 1.0 " label) + +static bytes +mls_1_0_plus(const std::string& label) +{ + auto plus = "MLS 1.0 "s + label; + return from_ascii(plus); +} + +namespace encrypt_label { +const std::string update_path_node = "UpdatePathNode"; +const std::string welcome = "Welcome"; +} // namespace encrypt_label + +struct EncryptContext +{ + const bytes& label; + const bytes& content; + TLS_SERIALIZABLE(label, content) +}; + +HPKECiphertext +HPKEPublicKey::encrypt(CipherSuite suite, + const std::string& label, + const bytes& context, + const bytes& pt) const +{ + auto label_plus = mls_1_0_plus(label); + auto encrypt_context = tls::marshal(EncryptContext{ label_plus, context }); + auto pkR = suite.hpke().kem.deserialize(data); + auto [enc, ctx] = suite.hpke().setup_base_s(*pkR, encrypt_context); + auto ct = ctx.seal({}, pt); + return HPKECiphertext{ enc, ct }; +} + +std::tuple +HPKEPublicKey::do_export(CipherSuite suite, + const bytes& info, + const std::string& label, + size_t size) const +{ + auto label_data = from_ascii(label); + auto pkR = suite.hpke().kem.deserialize(data); + auto [enc, ctx] = suite.hpke().setup_base_s(*pkR, info); + auto exported = ctx.do_export(label_data, size); + return std::make_tuple(enc, exported); +} + +HPKEPrivateKey +HPKEPrivateKey::generate(CipherSuite suite) +{ + auto priv = suite.hpke().kem.generate_key_pair(); + auto priv_data = suite.hpke().kem.serialize_private(*priv); + auto pub = priv->public_key(); + auto pub_data = suite.hpke().kem.serialize(*pub); + return { priv_data, pub_data }; +} + +HPKEPrivateKey +HPKEPrivateKey::parse(CipherSuite suite, const bytes& data) +{ + auto priv = suite.hpke().kem.deserialize_private(data); + auto pub = priv->public_key(); + auto pub_data = suite.hpke().kem.serialize(*pub); + return { data, pub_data }; +} + +HPKEPrivateKey +HPKEPrivateKey::derive(CipherSuite suite, const bytes& secret) +{ + auto priv = suite.hpke().kem.derive_key_pair(secret); + auto priv_data = suite.hpke().kem.serialize_private(*priv); + auto pub = priv->public_key(); + auto pub_data = suite.hpke().kem.serialize(*pub); + return { priv_data, pub_data }; +} + +bytes +HPKEPrivateKey::decrypt(CipherSuite suite, + const std::string& label, + const bytes& context, + const HPKECiphertext& ct) const +{ + auto label_plus = mls_1_0_plus(label); + auto encrypt_context = tls::marshal(EncryptContext{ label_plus, context }); + auto skR = suite.hpke().kem.deserialize_private(data); + auto ctx = suite.hpke().setup_base_r(ct.kem_output, *skR, encrypt_context); + auto pt = ctx.open({}, ct.ciphertext); + if (!pt) { + throw InvalidParameterError("HPKE decryption failure"); + } + + return opt::get(pt); +} + +bytes +HPKEPrivateKey::do_export(CipherSuite suite, + const bytes& info, + const bytes& kem_output, + const std::string& label, + size_t size) const +{ + auto label_data = from_ascii(label); + auto skR = suite.hpke().kem.deserialize_private(data); + auto ctx = suite.hpke().setup_base_r(kem_output, *skR, info); + return ctx.do_export(label_data, size); +} + +HPKEPrivateKey::HPKEPrivateKey(bytes priv_data, bytes pub_data) + : data(std::move(priv_data)) + , public_key{ std::move(pub_data) } +{ +} + +void +HPKEPrivateKey::set_public_key(CipherSuite suite) +{ + const auto priv = suite.hpke().kem.deserialize_private(data); + auto pub = priv->public_key(); + public_key.data = suite.hpke().kem.serialize(*pub); +} + +/// +/// SignaturePublicKey and SignaturePrivateKey +/// +namespace sign_label { +const std::string mls_content = "FramedContentTBS"; +const std::string leaf_node = "LeafNodeTBS"; +const std::string key_package = "KeyPackageTBS"; +const std::string group_info = "GroupInfoTBS"; +const std::string multi_credential = "MultiCredential"; +} // namespace sign_label + +struct SignContent +{ + const bytes& label; + const bytes& content; + TLS_SERIALIZABLE(label, content) +}; + +bool +SignaturePublicKey::verify(const CipherSuite& suite, + const std::string& label, + const bytes& message, + const bytes& signature) const +{ + auto label_plus = mls_1_0_plus(label); + const auto content = tls::marshal(SignContent{ label_plus, message }); + auto pub = suite.sig().deserialize(data); + return suite.sig().verify(content, signature, *pub); +} + +SignaturePublicKey +SignaturePublicKey::from_jwk(CipherSuite suite, const std::string& json_str) +{ + auto pub = suite.sig().import_jwk(json_str); + auto pub_data = suite.sig().serialize(*pub); + return SignaturePublicKey{ pub_data }; +} + +std::string +SignaturePublicKey::to_jwk(CipherSuite suite) const +{ + auto pub = suite.sig().deserialize(data); + return suite.sig().export_jwk(*pub); +} + +PublicJWK +PublicJWK::parse(const std::string& jwk_json) +{ + const auto parsed = Signature::parse_jwk(jwk_json); + const auto scheme = tls_signature_scheme(parsed.sig.id); + const auto pub_data = parsed.sig.serialize(*parsed.key); + return { scheme, parsed.key_id, { pub_data } }; +} + +SignaturePrivateKey +SignaturePrivateKey::generate(CipherSuite suite) +{ + auto priv = suite.sig().generate_key_pair(); + auto priv_data = suite.sig().serialize_private(*priv); + auto pub = priv->public_key(); + auto pub_data = suite.sig().serialize(*pub); + return { priv_data, pub_data }; +} + +SignaturePrivateKey +SignaturePrivateKey::parse(CipherSuite suite, const bytes& data) +{ + auto priv = suite.sig().deserialize_private(data); + auto pub = priv->public_key(); + auto pub_data = suite.sig().serialize(*pub); + return { data, pub_data }; +} + +SignaturePrivateKey +SignaturePrivateKey::derive(CipherSuite suite, const bytes& secret) +{ + auto priv = suite.sig().derive_key_pair(secret); + auto priv_data = suite.sig().serialize_private(*priv); + auto pub = priv->public_key(); + auto pub_data = suite.sig().serialize(*pub); + return { priv_data, pub_data }; +} + +bytes +SignaturePrivateKey::sign(const CipherSuite& suite, + const std::string& label, + const bytes& message) const +{ + auto label_plus = mls_1_0_plus(label); + const auto content = tls::marshal(SignContent{ label_plus, message }); + const auto priv = suite.sig().deserialize_private(data); + return suite.sig().sign(content, *priv); +} + +SignaturePrivateKey::SignaturePrivateKey(bytes priv_data, bytes pub_data) + : data(std::move(priv_data)) + , public_key{ std::move(pub_data) } +{ +} + +void +SignaturePrivateKey::set_public_key(CipherSuite suite) +{ + const auto priv = suite.sig().deserialize_private(data); + auto pub = priv->public_key(); + public_key.data = suite.sig().serialize(*pub); +} + +SignaturePrivateKey +SignaturePrivateKey::from_jwk(CipherSuite suite, const std::string& json_str) +{ + auto priv = suite.sig().import_jwk_private(json_str); + auto priv_data = suite.sig().serialize_private(*priv); + auto pub = priv->public_key(); + auto pub_data = suite.sig().serialize(*pub); + return { priv_data, pub_data }; +} + +std::string +SignaturePrivateKey::to_jwk(CipherSuite suite) const +{ + const auto priv = suite.sig().deserialize_private(data); + return suite.sig().export_jwk_private(*priv); +} + +} // namespace mlspp diff --git a/mlspp/src/grease.cpp b/mlspp/src/grease.cpp new file mode 100755 index 0000000000..05aaddbe9e --- /dev/null +++ b/mlspp/src/grease.cpp @@ -0,0 +1,126 @@ +#include "grease.h" + +#include +#include + +namespace mlspp { + +#ifdef DISABLE_GREASE + +void +grease([[maybe_unused]] Capabilities& capabilities, + [[maybe_unused]] const ExtensionList& extensions) +{ +} + +void +grease([[maybe_unused]] ExtensionList& extensions) +{ +} + +#else + +// Randomness parmeters: +// * Given a list of N items, insert max(1, rand(p_grease * N)) GREASE values +// * Each GREASE value added is distinct, unless more than 15 values are needed +// * For extensions, each GREASE extension has rand(n_grease_ext) random bytes +// of data +const size_t log_p_grease = 1; // -log2(p_grease) => p_grease = 1/2 +const size_t max_grease_ext_size = 16; + +const std::array grease_values = { 0x0A0A, 0x1A1A, 0x2A2A, 0x3A3A, + 0x4A4A, 0x5A5A, 0x6A6A, 0x7A7A, + 0x8A8A, 0x9A9A, 0xAAAA, 0xBABA, + 0xCACA, 0xDADA, 0xEAEA }; + +static size_t +rand_int(size_t n) +{ + static auto seed = std::random_device()(); + static auto rng = std::mt19937(seed); + return std::uniform_int_distribution(0, n)(rng); +} + +static uint16_t +grease_value() +{ + const auto where = rand_int(grease_values.size() - 1); + return grease_values.at(where); +} + +static bool +grease_value(uint16_t val) +{ + static constexpr auto grease_mask = uint16_t(0x0F0F); + return ((val & grease_mask) == 0x0A0A) && val != 0xFAFA; +} + +static std::set +grease_sample(size_t count) +{ + auto vals = std::set{}; + + while (vals.size() < count) { + uint16_t val = grease_value(); + while (vals.count(val) > 0 && vals.size() < grease_values.size()) { + val = grease_value(); + } + + vals.insert(val); + } + + return vals; +} + +template +static void +grease(std::vector& vec) +{ + const auto count = std::max(size_t(1), rand_int(vec.size() >> log_p_grease)); + for (const auto val : grease_sample(count)) { + const auto where = static_cast(rand_int(vec.size())); + vec.insert(std::begin(vec) + where, static_cast(val)); + } +} + +void +grease(Capabilities& capabilities, const ExtensionList& extensions) +{ + // Add GREASE to the appropriate portions of the capabilities + grease(capabilities.cipher_suites); + grease(capabilities.extensions); + grease(capabilities.proposals); + grease(capabilities.credentials); + + // Ensure that the GREASE extensions are reflected in Capabilities.extensions + for (const auto& ext : extensions.extensions) { + if (!grease_value(ext.type)) { + continue; + } + + if (stdx::contains(capabilities.extensions, ext.type)) { + continue; + } + + const auto where = + static_cast(rand_int(capabilities.extensions.size())); + const auto where_ptr = std::begin(capabilities.extensions) + where; + capabilities.extensions.insert(where_ptr, ext.type); + } +} + +void +grease(ExtensionList& extensions) +{ + auto& ext = extensions.extensions; + const auto count = std::max(size_t(1), rand_int(ext.size() >> log_p_grease)); + for (const auto ext_type : grease_sample(count)) { + const auto where = static_cast(rand_int(ext.size())); + auto ext_data = random_bytes(rand_int(max_grease_ext_size)); + ext.insert(std::begin(ext) + where, { ext_type, std::move(ext_data) }); + } +} + +#endif // DISABLE_GREASE + +} // namespace mlspp diff --git a/mlspp/src/grease.h b/mlspp/src/grease.h new file mode 100755 index 0000000000..94ed157eed --- /dev/null +++ b/mlspp/src/grease.h @@ -0,0 +1,13 @@ +#pragma once + +#include "mls/core_types.h" + +namespace mlspp { + +void +grease(Capabilities& capabilities, const ExtensionList& extensions); + +void +grease(ExtensionList& extensions); + +} // namespace mlspp diff --git a/mlspp/src/key_schedule.cpp b/mlspp/src/key_schedule.cpp new file mode 100755 index 0000000000..7a5788103e --- /dev/null +++ b/mlspp/src/key_schedule.cpp @@ -0,0 +1,579 @@ +#include + +namespace mlspp { + +/// +/// Key Derivation Functions +/// + +struct TreeContext +{ + NodeIndex node; + uint32_t generation = 0; + + TLS_SERIALIZABLE(node, generation) +}; + +/// +/// HashRatchet +/// + +HashRatchet::HashRatchet(CipherSuite suite_in, bytes base_secret_in) + : suite(suite_in) + , next_secret(std::move(base_secret_in)) + , next_generation(0) + , key_size(suite.hpke().aead.key_size) + , nonce_size(suite.hpke().aead.nonce_size) + , secret_size(suite.secret_size()) +{ +} + +std::tuple +HashRatchet::next() +{ + auto generation = next_generation; + auto key = suite.derive_tree_secret(next_secret, "key", generation, key_size); + auto nonce = + suite.derive_tree_secret(next_secret, "nonce", generation, nonce_size); + auto secret = + suite.derive_tree_secret(next_secret, "secret", generation, secret_size); + + next_generation += 1; + next_secret = secret; + + cache[generation] = { key, nonce }; + return { generation, cache.at(generation) }; +} + +// Note: This construction deliberately does not preserve the forward-secrecy +// invariant, in that keys/nonces are not deleted after they are used. +// Otherwise, it would not be possible for a node to send to itself. Keys can +// be deleted once they are not needed by calling HashRatchet::erase(). +KeyAndNonce +HashRatchet::get(uint32_t generation) +{ + if (cache.count(generation) > 0) { + auto out = cache.at(generation); + return out; + } + + if (next_generation > generation) { + throw ProtocolError("Request for expired key"); + } + + while (next_generation <= generation) { + next(); + } + + return cache.at(generation); +} + +void +HashRatchet::erase(uint32_t generation) +{ + if (cache.count(generation) == 0) { + return; + } + + cache.erase(generation); +} + +/// +/// SecretTree +/// + +SecretTree::SecretTree(CipherSuite suite_in, + LeafCount group_size_in, + bytes encryption_secret_in) + : suite(suite_in) + , group_size(LeafCount::full(group_size_in)) + , root(NodeIndex::root(group_size)) + , secret_size(suite_in.secret_size()) +{ + secrets.emplace(root, std::move(encryption_secret_in)); +} + +bytes +SecretTree::get(LeafIndex sender) +{ + static const auto context_left = from_ascii("left"); + static const auto context_right = from_ascii("right"); + auto node = NodeIndex(sender); + + // Find an ancestor that is populated + auto dirpath = node.dirpath(group_size); + dirpath.insert(dirpath.begin(), node); + dirpath.push_back(root); + uint32_t curr = 0; + for (; curr < dirpath.size(); ++curr) { + auto i = dirpath.at(curr); + if (secrets.count(i) > 0) { + break; + } + } + + if (curr > dirpath.size()) { + throw InvalidParameterError("No secret found to derive base key"); + } + + // Derive down + for (; curr > 0; --curr) { + auto curr_node = dirpath.at(curr); + auto left = curr_node.left(); + auto right = curr_node.right(); + + auto& secret = secrets.at(curr_node); + + const auto left_secret = + suite.expand_with_label(secret, "tree", context_left, secret_size); + const auto right_secret = + suite.expand_with_label(secret, "tree", context_right, secret_size); + + secrets.insert_or_assign(left, left_secret); + secrets.insert_or_assign(right, right_secret); + } + + // Copy the leaf + auto out = secrets.at(node); + + // Zeroize along the direct path + for (auto i : dirpath) { + secrets.erase(i); + } + + return out; +} + +/// +/// ReuseGuard +/// + +static ReuseGuard +new_reuse_guard() +{ + auto random = random_bytes(4); + auto guard = ReuseGuard(); + std::copy(random.begin(), random.end(), guard.begin()); + return guard; +} + +static void +apply_reuse_guard(const ReuseGuard& guard, bytes& nonce) +{ + for (size_t i = 0; i < guard.size(); i++) { + nonce.at(i) ^= guard.at(i); + } +} + +/// +/// GroupKeySource +/// + +GroupKeySource::GroupKeySource(CipherSuite suite_in, + LeafCount group_size, + bytes encryption_secret) + : suite(suite_in) + , secret_tree(suite, group_size, std::move(encryption_secret)) +{ +} + +HashRatchet& +GroupKeySource::chain(ContentType type, LeafIndex sender) +{ + switch (type) { + case ContentType::proposal: + case ContentType::commit: + return chain(RatchetType::handshake, sender); + + case ContentType::application: + return chain(RatchetType::application, sender); + + default: + throw InvalidParameterError("Invalid content type"); + } +} + +HashRatchet& +GroupKeySource::chain(RatchetType type, LeafIndex sender) +{ + auto key = Key{ type, sender }; + if (chains.count(key) > 0) { + return chains[key]; + } + + auto secret_size = suite.secret_size(); + auto leaf_secret = secret_tree.get(sender); + + auto handshake_secret = + suite.expand_with_label(leaf_secret, "handshake", {}, secret_size); + auto application_secret = + suite.expand_with_label(leaf_secret, "application", {}, secret_size); + + chains.emplace(Key{ RatchetType::handshake, sender }, + HashRatchet{ suite, handshake_secret }); + chains.emplace(Key{ RatchetType::application, sender }, + HashRatchet{ suite, application_secret }); + + return chains[key]; +} + +std::tuple +GroupKeySource::next(ContentType type, LeafIndex sender) +{ + auto [generation, keys] = chain(type, sender).next(); + + auto reuse_guard = new_reuse_guard(); + apply_reuse_guard(reuse_guard, keys.nonce); + + return { generation, reuse_guard, keys }; +} + +KeyAndNonce +GroupKeySource::get(ContentType type, + LeafIndex sender, + uint32_t generation, + ReuseGuard reuse_guard) +{ + auto keys = chain(type, sender).get(generation); + apply_reuse_guard(reuse_guard, keys.nonce); + return keys; +} + +void +GroupKeySource::erase(ContentType type, LeafIndex sender, uint32_t generation) +{ + return chain(type, sender).erase(generation); +} + +// struct { +// opaque group_id<0..255>; +// uint64 epoch; +// ContentType content_type; +// opaque authenticated_data<0..2^32-1>; +// } ContentAAD; +struct ContentAAD +{ + const bytes& group_id; + const epoch_t epoch; + const ContentType content_type; + const bytes& authenticated_data; + + TLS_SERIALIZABLE(group_id, epoch, content_type, authenticated_data) +}; + +/// +/// KeyScheduleEpoch +/// + +struct PSKLabel +{ + const PreSharedKeyID& id; + uint16_t index; + uint16_t count; + + TLS_SERIALIZABLE(id, index, count); +}; + +static bytes +make_joiner_secret(CipherSuite suite, + const bytes& context, + const bytes& init_secret, + const bytes& commit_secret) +{ + auto pre_joiner_secret = suite.hpke().kdf.extract(init_secret, commit_secret); + return suite.expand_with_label( + pre_joiner_secret, "joiner", context, suite.secret_size()); +} + +static bytes +make_epoch_secret(CipherSuite suite, + const bytes& joiner_secret, + const bytes& psk_secret, + const bytes& context) +{ + auto member_secret = suite.hpke().kdf.extract(joiner_secret, psk_secret); + return suite.expand_with_label( + member_secret, "epoch", context, suite.secret_size()); +} + +KeyScheduleEpoch +KeyScheduleEpoch::joiner(CipherSuite suite_in, + const bytes& joiner_secret, + const std::vector& psks, + const bytes& context) +{ + return { suite_in, joiner_secret, make_psk_secret(suite_in, psks), context }; +} + +KeyScheduleEpoch::KeyScheduleEpoch(CipherSuite suite_in, + const bytes& joiner_secret, + const bytes& psk_secret, + const bytes& context) + : suite(suite_in) + , joiner_secret(joiner_secret) + , epoch_secret( + make_epoch_secret(suite_in, joiner_secret, psk_secret, context)) + , sender_data_secret(suite.derive_secret(epoch_secret, "sender data")) + , encryption_secret(suite.derive_secret(epoch_secret, "encryption")) + , exporter_secret(suite.derive_secret(epoch_secret, "exporter")) + , epoch_authenticator(suite.derive_secret(epoch_secret, "authentication")) + , external_secret(suite.derive_secret(epoch_secret, "external")) + , confirmation_key(suite.derive_secret(epoch_secret, "confirm")) + , membership_key(suite.derive_secret(epoch_secret, "membership")) + , resumption_psk(suite.derive_secret(epoch_secret, "resumption")) + , init_secret(suite.derive_secret(epoch_secret, "init")) + , external_priv(HPKEPrivateKey::derive(suite, external_secret)) +{ +} + +KeyScheduleEpoch::KeyScheduleEpoch(CipherSuite suite_in) + : suite(suite_in) +{ +} + +KeyScheduleEpoch::KeyScheduleEpoch(CipherSuite suite_in, + const bytes& init_secret, + const bytes& context) + : KeyScheduleEpoch( + suite_in, + make_joiner_secret(suite_in, context, init_secret, suite_in.zero()), + { /* no PSKs */ }, + context) +{ +} + +KeyScheduleEpoch::KeyScheduleEpoch(CipherSuite suite_in, + const bytes& init_secret, + const bytes& commit_secret, + const bytes& psk_secret, + const bytes& context) + : KeyScheduleEpoch( + suite_in, + make_joiner_secret(suite_in, context, init_secret, commit_secret), + psk_secret, + context) +{ +} + +std::tuple +KeyScheduleEpoch::external_init(CipherSuite suite, + const HPKEPublicKey& external_pub) +{ + auto size = suite.secret_size(); + return external_pub.do_export( + suite, {}, "MLS 1.0 external init secret", size); +} + +bytes +KeyScheduleEpoch::receive_external_init(const bytes& kem_output) const +{ + auto size = suite.secret_size(); + return external_priv.do_export( + suite, {}, kem_output, "MLS 1.0 external init secret", size); +} + +KeyScheduleEpoch +KeyScheduleEpoch::next(const bytes& commit_secret, + const std::vector& psks, + const std::optional& force_init_secret, + const bytes& context) const +{ + return next_raw( + commit_secret, make_psk_secret(suite, psks), force_init_secret, context); +} + +KeyScheduleEpoch +KeyScheduleEpoch::next_raw(const bytes& commit_secret, + const bytes& psk_secret, + const std::optional& force_init_secret, + const bytes& context) const +{ + auto actual_init_secret = init_secret; + if (force_init_secret) { + actual_init_secret = opt::get(force_init_secret); + } + + return { suite, actual_init_secret, commit_secret, psk_secret, context }; +} + +GroupKeySource +KeyScheduleEpoch::encryption_keys(LeafCount size) const +{ + return { suite, size, encryption_secret }; +} + +bytes +KeyScheduleEpoch::confirmation_tag(const bytes& confirmed_transcript_hash) const +{ + return suite.digest().hmac(confirmation_key, confirmed_transcript_hash); +} + +bytes +KeyScheduleEpoch::do_export(const std::string& label, + const bytes& context, + size_t size) const +{ + auto secret = suite.derive_secret(exporter_secret, label); + auto context_hash = suite.digest().hash(context); + return suite.expand_with_label(secret, "exported", context_hash, size); +} + +PSKWithSecret +KeyScheduleEpoch::resumption_psk_w_secret(ResumptionPSKUsage usage, + const bytes& group_id, + epoch_t epoch) +{ + auto nonce = random_bytes(suite.secret_size()); + auto psk = ResumptionPSK{ usage, group_id, epoch }; + return { { psk, nonce }, resumption_psk }; +} + +bytes +KeyScheduleEpoch::make_psk_secret(CipherSuite suite, + const std::vector& psks) +{ + auto psk_secret = suite.zero(); + auto count = uint16_t(psks.size()); + auto index = uint16_t(0); + for (const auto& psk : psks) { + auto psk_extracted = suite.hpke().kdf.extract(suite.zero(), psk.secret); + auto psk_label = tls::marshal(PSKLabel{ psk.id, index, count }); + auto psk_input = suite.expand_with_label( + psk_extracted, "derived psk", psk_label, suite.secret_size()); + psk_secret = suite.hpke().kdf.extract(psk_input, psk_secret); + index += 1; + } + return psk_secret; +} + +bytes +KeyScheduleEpoch::welcome_secret(CipherSuite suite, + const bytes& joiner_secret, + const std::vector& psks) +{ + auto psk_secret = make_psk_secret(suite, psks); + return welcome_secret_raw(suite, joiner_secret, psk_secret); +} + +bytes +KeyScheduleEpoch::welcome_secret_raw(CipherSuite suite, + const bytes& joiner_secret, + const bytes& psk_secret) +{ + auto extract = suite.hpke().kdf.extract(joiner_secret, psk_secret); + return suite.derive_secret(extract, "welcome"); +} + +KeyAndNonce +KeyScheduleEpoch::sender_data_keys(CipherSuite suite, + const bytes& sender_data_secret, + const bytes& ciphertext) +{ + auto sample_size = suite.secret_size(); + auto sample = bytes(sample_size); + if (ciphertext.size() <= sample_size) { + sample = ciphertext; + } else { + sample = ciphertext.slice(0, sample_size); + } + + auto key_size = suite.hpke().aead.key_size; + auto nonce_size = suite.hpke().aead.nonce_size; + return { + suite.expand_with_label(sender_data_secret, "key", sample, key_size), + suite.expand_with_label(sender_data_secret, "nonce", sample, nonce_size), + }; +} + +bool +operator==(const KeyScheduleEpoch& lhs, const KeyScheduleEpoch& rhs) +{ + auto epoch_secret = (lhs.epoch_secret == rhs.epoch_secret); + auto sender_data_secret = (lhs.sender_data_secret == rhs.sender_data_secret); + auto encryption_secret = (lhs.encryption_secret == rhs.encryption_secret); + auto exporter_secret = (lhs.exporter_secret == rhs.exporter_secret); + auto confirmation_key = (lhs.confirmation_key == rhs.confirmation_key); + auto init_secret = (lhs.init_secret == rhs.init_secret); + auto external_priv = (lhs.external_priv == rhs.external_priv); + + return epoch_secret && sender_data_secret && encryption_secret && + exporter_secret && confirmation_key && init_secret && external_priv; +} + +// struct { +// WireFormat wire_format; +// GroupContent content; // with content.content_type == commit +// opaque signature; +// } ConfirmedTranscriptHashInput; +struct ConfirmedTranscriptHashInput +{ + WireFormat wire_format; + const GroupContent& content; + const bytes& signature; + + TLS_SERIALIZABLE(wire_format, content, signature) +}; + +// struct { +// MAC confirmation_tag; +// } InterimTranscriptHashInput; +struct InterimTranscriptHashInput +{ + bytes confirmation_tag; + + TLS_SERIALIZABLE(confirmation_tag) +}; + +TranscriptHash::TranscriptHash(CipherSuite suite_in) + : suite(suite_in) +{ +} + +TranscriptHash::TranscriptHash(CipherSuite suite_in, + bytes confirmed_in, + const bytes& confirmation_tag) + : suite(suite_in) + , confirmed(std::move(confirmed_in)) +{ + update_interim(confirmation_tag); +} + +void +TranscriptHash::update(const AuthenticatedContent& content_auth) +{ + update_confirmed(content_auth); + update_interim(content_auth); +} + +void +TranscriptHash::update_confirmed(const AuthenticatedContent& content_auth) +{ + const auto transcript = + interim + content_auth.confirmed_transcript_hash_input(); + confirmed = suite.digest().hash(transcript); +} + +void +TranscriptHash::update_interim(const bytes& confirmation_tag) +{ + const auto transcript = confirmed + tls::marshal(confirmation_tag); + interim = suite.digest().hash(transcript); +} + +void +TranscriptHash::update_interim(const AuthenticatedContent& content_auth) +{ + const auto transcript = + confirmed + content_auth.interim_transcript_hash_input(); + interim = suite.digest().hash(transcript); +} + +bool +operator==(const TranscriptHash& lhs, const TranscriptHash& rhs) +{ + auto confirmed = (lhs.confirmed == rhs.confirmed); + auto interim = (lhs.interim == rhs.interim); + return confirmed && interim; +} + +} // namespace mlspp diff --git a/mlspp/src/messages.cpp b/mlspp/src/messages.cpp new file mode 100755 index 0000000000..5372ea5236 --- /dev/null +++ b/mlspp/src/messages.cpp @@ -0,0 +1,947 @@ +#include +#include +#include +#include + +#include "grease.h" + +namespace mlspp { + +// Extensions + +const Extension::Type ExternalPubExtension::type = ExtensionType::external_pub; +const Extension::Type RatchetTreeExtension::type = ExtensionType::ratchet_tree; +const Extension::Type ExternalSendersExtension::type = + ExtensionType::external_senders; +const Extension::Type SFrameParameters::type = ExtensionType::sframe_parameters; +const Extension::Type SFrameCapabilities::type = + ExtensionType::sframe_parameters; + +bool +SFrameCapabilities::compatible(const SFrameParameters& params) const +{ + return stdx::contains(cipher_suites, params.cipher_suite); +} + +// GroupContext + +GroupContext::GroupContext(CipherSuite cipher_suite_in, + bytes group_id_in, + epoch_t epoch_in, + bytes tree_hash_in, + bytes confirmed_transcript_hash_in, + ExtensionList extensions_in) + : cipher_suite(cipher_suite_in) + , group_id(std::move(group_id_in)) + , epoch(epoch_in) + , tree_hash(std::move(tree_hash_in)) + , confirmed_transcript_hash(std::move(confirmed_transcript_hash_in)) + , extensions(std::move(extensions_in)) +{ +} + +// GroupInfo + +GroupInfo::GroupInfo(GroupContext group_context_in, + ExtensionList extensions_in, + bytes confirmation_tag_in) + : group_context(std::move(group_context_in)) + , extensions(std::move(extensions_in)) + , confirmation_tag(std::move(confirmation_tag_in)) + , signer(0) +{ + grease(extensions); +} + +struct GroupInfoTBS +{ + GroupContext group_context; + ExtensionList extensions; + bytes confirmation_tag; + LeafIndex signer; + + TLS_SERIALIZABLE(group_context, extensions, confirmation_tag, signer) +}; + +bytes +GroupInfo::to_be_signed() const +{ + return tls::marshal( + GroupInfoTBS{ group_context, extensions, confirmation_tag, signer }); +} + +void +GroupInfo::sign(const TreeKEMPublicKey& tree, + LeafIndex signer_index, + const SignaturePrivateKey& priv) +{ + auto maybe_leaf = tree.leaf_node(signer_index); + if (!maybe_leaf) { + throw InvalidParameterError("Cannot sign from a blank leaf"); + } + + if (priv.public_key != opt::get(maybe_leaf).signature_key) { + throw InvalidParameterError("Bad key for index"); + } + + signer = signer_index; + signature = priv.sign(tree.suite, sign_label::group_info, to_be_signed()); +} + +bool +GroupInfo::verify(const TreeKEMPublicKey& tree) const +{ + auto maybe_leaf = tree.leaf_node(signer); + if (!maybe_leaf) { + throw InvalidParameterError("Signer not found"); + } + + const auto& leaf = opt::get(maybe_leaf); + return verify(leaf.signature_key); +} + +void +GroupInfo::sign(LeafIndex signer_index, const SignaturePrivateKey& priv) +{ + signer = signer_index; + signature = priv.sign( + group_context.cipher_suite, sign_label::group_info, to_be_signed()); +} + +bool +GroupInfo::verify(const SignaturePublicKey& pub) const +{ + return pub.verify(group_context.cipher_suite, + sign_label::group_info, + to_be_signed(), + signature); +} + +// Welcome + +Welcome::Welcome() + : cipher_suite(CipherSuite::ID::unknown) +{ +} + +Welcome::Welcome(CipherSuite suite, + const bytes& joiner_secret, + const std::vector& psks, + const GroupInfo& group_info) + : cipher_suite(suite) + , _joiner_secret(joiner_secret) +{ + // Cache the list of PSK IDs + for (const auto& psk : psks) { + _psks.psks.push_back(psk.id); + } + + // Pre-encrypt the GroupInfo + auto [key, nonce] = group_info_key_nonce(suite, joiner_secret, psks); + auto group_info_data = tls::marshal(group_info); + encrypted_group_info = + cipher_suite.hpke().aead.seal(key, nonce, {}, group_info_data); +} + +std::optional +Welcome::find(const KeyPackage& kp) const +{ + auto ref = kp.ref(); + for (size_t i = 0; i < secrets.size(); i++) { + if (ref == secrets[i].new_member) { + return static_cast(i); + } + } + return std::nullopt; +} + +void +Welcome::encrypt(const KeyPackage& kp, const std::optional& path_secret) +{ + auto gs = GroupSecrets{ _joiner_secret, std::nullopt, _psks }; + if (path_secret) { + gs.path_secret = GroupSecrets::PathSecret{ opt::get(path_secret) }; + } + + auto gs_data = tls::marshal(gs); + auto enc_gs = kp.init_key.encrypt( + kp.cipher_suite, encrypt_label::welcome, encrypted_group_info, gs_data); + secrets.push_back({ kp.ref(), enc_gs }); +} + +GroupSecrets +Welcome::decrypt_secrets(int kp_index, const HPKEPrivateKey& init_priv) const +{ + auto secrets_data = + init_priv.decrypt(cipher_suite, + encrypt_label::welcome, + encrypted_group_info, + secrets.at(kp_index).encrypted_group_secrets); + return tls::get(secrets_data); +} + +GroupInfo +Welcome::decrypt(const bytes& joiner_secret, + const std::vector& psks) const +{ + auto [key, nonce] = group_info_key_nonce(cipher_suite, joiner_secret, psks); + auto group_info_data = + cipher_suite.hpke().aead.open(key, nonce, {}, encrypted_group_info); + if (!group_info_data) { + throw ProtocolError("Welcome decryption failed"); + } + + return tls::get(opt::get(group_info_data)); +} + +KeyAndNonce +Welcome::group_info_key_nonce(CipherSuite suite, + const bytes& joiner_secret, + const std::vector& psks) +{ + auto welcome_secret = + KeyScheduleEpoch::welcome_secret(suite, joiner_secret, psks); + + // XXX(RLB): These used to be done with ExpandWithLabel. Should we do that + // instead, for better domain separation? (In particular, including "mls10") + // That is what we do for the sender data key/nonce. + auto key = + suite.expand_with_label(welcome_secret, "key", {}, suite.key_size()); + auto nonce = + suite.expand_with_label(welcome_secret, "nonce", {}, suite.nonce_size()); + return { std::move(key), std::move(nonce) }; +} + +// Commit +std::optional +Commit::valid_external() const +{ + // External Commits MUST contain a path field (and is therefore a "full" + // Commit). The joiner is added at the leftmost free leaf node (just as if + // they were added with an Add proposal), and the path is calculated relative + // to that leaf node. + // + // The Commit MUST NOT include any proposals by reference, since an external + // joiner cannot determine the validity of proposals sent within the group + const auto all_by_value = stdx::all_of(proposals, [](const auto& p) { + return var::holds_alternative(p.content); + }); + if (!path || !all_by_value) { + return std::nullopt; + } + + const auto ext_init_ptr = stdx::find_if(proposals, [](const auto& p) { + const auto proposal = var::get(p.content); + return proposal.proposal_type() == ProposalType::external_init; + }); + if (ext_init_ptr == proposals.end()) { + return std::nullopt; + } + + const auto& ext_init_proposal = var::get(ext_init_ptr->content); + const auto& ext_init = var::get(ext_init_proposal.content); + return ext_init.kem_output; +} + +// PublicMessage +Proposal::Type +Proposal::proposal_type() const +{ + return tls::variant::type(content).val; +} + +SenderType +Sender::sender_type() const +{ + return tls::variant::type(sender); +} + +tls::ostream& +operator<<(tls::ostream& str, const GroupContentAuthData& obj) +{ + switch (obj.content_type) { + case ContentType::proposal: + case ContentType::application: + return str << obj.signature; + + case ContentType::commit: + return str << obj.signature << opt::get(obj.confirmation_tag); + + default: + throw InvalidParameterError("Invalid content type"); + } +} + +tls::istream& +operator>>(tls::istream& str, GroupContentAuthData& obj) +{ + switch (obj.content_type) { + case ContentType::proposal: + case ContentType::application: + return str >> obj.signature; + + case ContentType::commit: + obj.confirmation_tag.emplace(); + return str >> obj.signature >> opt::get(obj.confirmation_tag); + + default: + throw InvalidParameterError("Invalid content type"); + } +} + +bool +operator==(const GroupContentAuthData& lhs, const GroupContentAuthData& rhs) +{ + return lhs.content_type == rhs.content_type && + lhs.signature == rhs.signature && + lhs.confirmation_tag == rhs.confirmation_tag; +} + +GroupContent::GroupContent(bytes group_id_in, + epoch_t epoch_in, + Sender sender_in, + bytes authenticated_data_in, + RawContent content_in) + : group_id(std::move(group_id_in)) + , epoch(epoch_in) + , sender(sender_in) + , authenticated_data(std::move(authenticated_data_in)) + , content(std::move(content_in)) +{ +} + +GroupContent::GroupContent(bytes group_id_in, + epoch_t epoch_in, + Sender sender_in, + bytes authenticated_data_in, + ContentType content_type) + : group_id(std::move(group_id_in)) + , epoch(epoch_in) + , sender(sender_in) + , authenticated_data(std::move(authenticated_data_in)) +{ + switch (content_type) { + case ContentType::commit: + content.emplace(); + break; + + case ContentType::proposal: + content.emplace(); + break; + + case ContentType::application: + content.emplace(); + break; + + default: + throw InvalidParameterError("Invalid content type"); + } +} + +ContentType +GroupContent::content_type() const +{ + return tls::variant::type(content); +} + +AuthenticatedContent +AuthenticatedContent::sign(WireFormat wire_format, + GroupContent content, + CipherSuite suite, + const SignaturePrivateKey& sig_priv, + const std::optional& context) +{ + if (wire_format == WireFormat::mls_public_message && + content.content_type() == ContentType::application) { + throw InvalidParameterError( + "Application data cannot be sent as PublicMessage"); + } + + auto content_auth = AuthenticatedContent{ wire_format, std::move(content) }; + auto tbs = content_auth.to_be_signed(context); + content_auth.auth.signature = + sig_priv.sign(suite, sign_label::mls_content, tbs); + return content_auth; +} + +bool +AuthenticatedContent::verify(CipherSuite suite, + const SignaturePublicKey& sig_pub, + const std::optional& context) const +{ + if (wire_format == WireFormat::mls_public_message && + content.content_type() == ContentType::application) { + return false; + } + + auto tbs = to_be_signed(context); + return sig_pub.verify(suite, sign_label::mls_content, tbs, auth.signature); +} + +struct ConfirmedTranscriptHashInput +{ + WireFormat wire_format; + const GroupContent& content; + const bytes& signature; + + TLS_SERIALIZABLE(wire_format, content, signature); +}; + +struct InterimTranscriptHashInput +{ + const bytes& confirmation_tag; + + TLS_SERIALIZABLE(confirmation_tag); +}; + +bytes +AuthenticatedContent::confirmed_transcript_hash_input() const +{ + return tls::marshal(ConfirmedTranscriptHashInput{ + wire_format, + content, + auth.signature, + }); +} + +bytes +AuthenticatedContent::interim_transcript_hash_input() const +{ + return tls::marshal( + InterimTranscriptHashInput{ opt::get(auth.confirmation_tag) }); +} + +void +AuthenticatedContent::set_confirmation_tag(const bytes& confirmation_tag) +{ + auth.confirmation_tag = confirmation_tag; +} + +bool +AuthenticatedContent::check_confirmation_tag( + const bytes& confirmation_tag) const +{ + return confirmation_tag == opt::get(auth.confirmation_tag); +} + +tls::ostream& +operator<<(tls::ostream& str, const AuthenticatedContent& obj) +{ + return str << obj.wire_format << obj.content << obj.auth; +} + +tls::istream& +operator>>(tls::istream& str, AuthenticatedContent& obj) +{ + str >> obj.wire_format >> obj.content; + + obj.auth.content_type = obj.content.content_type(); + return str >> obj.auth; +} + +bool +operator==(const AuthenticatedContent& lhs, const AuthenticatedContent& rhs) +{ + return lhs.wire_format == rhs.wire_format && lhs.content == rhs.content && + lhs.auth == rhs.auth; +} + +AuthenticatedContent::AuthenticatedContent(WireFormat wire_format_in, + GroupContent content_in) + : wire_format(wire_format_in) + , content(std::move(content_in)) +{ + auth.content_type = content.content_type(); +} + +AuthenticatedContent::AuthenticatedContent(WireFormat wire_format_in, + GroupContent content_in, + GroupContentAuthData auth_in) + : wire_format(wire_format_in) + , content(std::move(content_in)) + , auth(std::move(auth_in)) +{ +} + +const AuthenticatedContent& +ValidatedContent::authenticated_content() const +{ + return content_auth; +} + +ValidatedContent::ValidatedContent(AuthenticatedContent content_auth_in) + : content_auth(std::move(content_auth_in)) +{ +} + +bool +operator==(const ValidatedContent& lhs, const ValidatedContent& rhs) +{ + return lhs.content_auth == rhs.content_auth; +} + +struct GroupContentTBS +{ + WireFormat wire_format = WireFormat::reserved; + const GroupContent& content; + const std::optional& context; +}; + +static tls::ostream& +operator<<(tls::ostream& str, const GroupContentTBS& obj) +{ + str << ProtocolVersion::mls10 << obj.wire_format << obj.content; + + switch (obj.content.sender.sender_type()) { + case SenderType::member: + case SenderType::new_member_commit: + str << opt::get(obj.context); + break; + + case SenderType::external: + case SenderType::new_member_proposal: + break; + + default: + throw InvalidParameterError("Invalid sender type"); + } + + return str; +} + +bytes +AuthenticatedContent::to_be_signed( + const std::optional& context) const +{ + return tls::marshal(GroupContentTBS{ + wire_format, + content, + context, + }); +} + +PublicMessage +PublicMessage::protect(AuthenticatedContent content_auth, + CipherSuite suite, + const std::optional& membership_key, + const std::optional& context) +{ + auto pt = PublicMessage(std::move(content_auth)); + + // Add the membership_mac if required + switch (pt.content.sender.sender_type()) { + case SenderType::member: + pt.membership_tag = + pt.membership_mac(suite, opt::get(membership_key), context); + break; + + default: + break; + } + + return pt; +} + +std::optional +PublicMessage::unprotect(CipherSuite suite, + const std::optional& membership_key, + const std::optional& context) const +{ + // Verify the membership_tag if the message was sent within the group + switch (content.sender.sender_type()) { + case SenderType::member: { + auto candidate = membership_mac(suite, opt::get(membership_key), context); + if (candidate != opt::get(membership_tag)) { + return std::nullopt; + } + break; + } + + default: + break; + } + + return { { AuthenticatedContent{ + WireFormat::mls_public_message, + content, + auth, + } } }; +} + +bool +PublicMessage::contains(const AuthenticatedContent& content_auth) const +{ + return content == content_auth.content && auth == content_auth.auth; +} + +AuthenticatedContent +PublicMessage::authenticated_content() const +{ + auto auth_content = AuthenticatedContent{}; + auth_content.wire_format = WireFormat::mls_public_message; + auth_content.content = content; + auth_content.auth = auth; + return auth_content; +} + +PublicMessage::PublicMessage(AuthenticatedContent content_auth) + : content(std::move(content_auth.content)) + , auth(std::move(content_auth.auth)) +{ + if (content_auth.wire_format != WireFormat::mls_public_message) { + throw InvalidParameterError("Wire format mismatch (not mls_plaintext)"); + } +} + +struct GroupContentTBM +{ + GroupContentTBS content_tbs; + GroupContentAuthData auth; + + TLS_SERIALIZABLE(content_tbs, auth); +}; + +bytes +PublicMessage::membership_mac(CipherSuite suite, + const bytes& membership_key, + const std::optional& context) const +{ + auto tbm = tls::marshal(GroupContentTBM{ + { WireFormat::mls_public_message, content, context }, + auth, + }); + + return suite.digest().hmac(membership_key, tbm); +} + +tls::ostream& +operator<<(tls::ostream& str, const PublicMessage& obj) +{ + switch (obj.content.sender.sender_type()) { + case SenderType::member: + return str << obj.content << obj.auth << opt::get(obj.membership_tag); + + case SenderType::external: + case SenderType::new_member_proposal: + case SenderType::new_member_commit: + return str << obj.content << obj.auth; + + default: + throw InvalidParameterError("Invalid sender type"); + } +} + +tls::istream& +operator>>(tls::istream& str, PublicMessage& obj) +{ + str >> obj.content; + + obj.auth.content_type = obj.content.content_type(); + str >> obj.auth; + + if (obj.content.sender.sender_type() == SenderType::member) { + obj.membership_tag.emplace(); + str >> opt::get(obj.membership_tag); + } + + return str; +} + +bool +operator==(const PublicMessage& lhs, const PublicMessage& rhs) +{ + return lhs.content == rhs.content && lhs.auth == rhs.auth && + lhs.membership_tag == rhs.membership_tag; +} + +bool +operator!=(const PublicMessage& lhs, const PublicMessage& rhs) +{ + return !(lhs == rhs); +} + +static bytes +marshal_ciphertext_content(const GroupContent& content, + const GroupContentAuthData& auth, + size_t padding_size) +{ + auto w = tls::ostream{}; + var::visit([&w](const auto& val) { w << val; }, content.content); + w << auth; + w.write_raw(bytes(padding_size, 0)); + return w.bytes(); +} + +static void +unmarshal_ciphertext_content(const bytes& content_pt, + GroupContent& content, + GroupContentAuthData& auth) +{ + auto r = tls::istream(content_pt); + + var::visit([&r](auto& val) { r >> val; }, content.content); + r >> auth; + + const auto padding = r.bytes(); + const auto nonzero = [](const auto& x) { return x != 0; }; + if (stdx::any_of(padding, nonzero)) { + throw ProtocolError("Malformed AuthenticatedContentTBE padding"); + } +} + +struct ContentAAD +{ + const bytes& group_id; + const epoch_t epoch; + const ContentType content_type; + const bytes& authenticated_data; + + TLS_SERIALIZABLE(group_id, epoch, content_type, authenticated_data) +}; + +struct SenderData +{ + LeafIndex sender{ 0 }; + uint32_t generation{ 0 }; + ReuseGuard reuse_guard{ 0, 0, 0, 0 }; + + TLS_SERIALIZABLE(sender, generation, reuse_guard) +}; + +struct SenderDataAAD +{ + const bytes& group_id; + const epoch_t epoch; + const ContentType content_type; + + TLS_SERIALIZABLE(group_id, epoch, content_type) +}; + +PrivateMessage +PrivateMessage::protect(AuthenticatedContent content_auth, + CipherSuite suite, + GroupKeySource& keys, + const bytes& sender_data_secret, + size_t padding_size) +{ + // Pull keys from the secret tree + auto index = + var::get(content_auth.content.sender.sender).sender; + auto content_type = content_auth.content.content_type(); + auto [generation, reuse_guard, content_keys] = keys.next(content_type, index); + + // Encrypt the content + auto content_pt = marshal_ciphertext_content( + content_auth.content, content_auth.auth, padding_size); + auto content_aad = tls::marshal(ContentAAD{ + content_auth.content.group_id, + content_auth.content.epoch, + content_auth.content.content_type(), + content_auth.content.authenticated_data, + }); + + auto content_ct = suite.hpke().aead.seal( + content_keys.key, content_keys.nonce, content_aad, content_pt); + + // Encrypt the sender data + auto sender_index = + var::get(content_auth.content.sender.sender).sender; + auto sender_data_pt = tls::marshal(SenderData{ + sender_index, + generation, + reuse_guard, + }); + auto sender_data_aad = tls::marshal(SenderDataAAD{ + content_auth.content.group_id, + content_auth.content.epoch, + content_auth.content.content_type(), + }); + + auto sender_data_keys = + KeyScheduleEpoch::sender_data_keys(suite, sender_data_secret, content_ct); + + auto sender_data_ct = suite.hpke().aead.seal(sender_data_keys.key, + sender_data_keys.nonce, + sender_data_aad, + sender_data_pt); + + return PrivateMessage{ + std::move(content_auth.content), + std::move(sender_data_ct), + std::move(content_ct), + }; +} + +std::optional +PrivateMessage::unprotect(CipherSuite suite, + GroupKeySource& keys, + const bytes& sender_data_secret) const +{ + // Decrypt and parse the sender data + auto sender_data_keys = + KeyScheduleEpoch::sender_data_keys(suite, sender_data_secret, ciphertext); + auto sender_data_aad = tls::marshal(SenderDataAAD{ + group_id, + epoch, + content_type, + }); + + auto sender_data_pt = suite.hpke().aead.open(sender_data_keys.key, + sender_data_keys.nonce, + sender_data_aad, + encrypted_sender_data); + if (!sender_data_pt) { + return std::nullopt; + } + + auto sender_data = tls::get(opt::get(sender_data_pt)); + if (!keys.has_leaf(sender_data.sender)) { + return std::nullopt; + } + + // Decrypt the content + auto content_keys = keys.get(content_type, + sender_data.sender, + sender_data.generation, + sender_data.reuse_guard); + keys.erase(content_type, sender_data.sender, sender_data.generation); + + auto content_aad = tls::marshal(ContentAAD{ + group_id, + epoch, + content_type, + authenticated_data, + }); + + auto content_pt = suite.hpke().aead.open( + content_keys.key, content_keys.nonce, content_aad, ciphertext); + if (!content_pt) { + return std::nullopt; + } + + // Parse the content + auto content = GroupContent{ group_id, + epoch, + { MemberSender{ sender_data.sender } }, + authenticated_data, + content_type }; + auto auth = GroupContentAuthData{ content_type, {}, {} }; + + unmarshal_ciphertext_content(opt::get(content_pt), content, auth); + + return { { AuthenticatedContent{ + WireFormat::mls_private_message, + std::move(content), + std::move(auth), + } } }; +} + +PrivateMessage::PrivateMessage(GroupContent content, + bytes encrypted_sender_data_in, + bytes ciphertext_in) + : group_id(std::move(content.group_id)) + , epoch(content.epoch) + , content_type(content.content_type()) + , authenticated_data(std::move(content.authenticated_data)) + , encrypted_sender_data(std::move(encrypted_sender_data_in)) + , ciphertext(std::move(ciphertext_in)) +{ +} + +bytes +MLSMessage::group_id() const +{ + return var::visit( + overloaded{ + [](const PublicMessage& pt) -> bytes { return pt.get_group_id(); }, + [](const PrivateMessage& ct) -> bytes { return ct.get_group_id(); }, + [](const GroupInfo& gi) -> bytes { return gi.group_context.group_id; }, + [](const auto& /* unused */) -> bytes { + throw InvalidParameterError("MLSMessage has no group_id"); + }, + }, + message); +} + +epoch_t +MLSMessage::epoch() const +{ + return var::visit( + overloaded{ + [](const PublicMessage& pt) -> epoch_t { return pt.get_epoch(); }, + [](const PrivateMessage& pt) -> epoch_t { return pt.get_epoch(); }, + [](const auto& /* unused */) -> epoch_t { + throw InvalidParameterError("MLSMessage has no epoch"); + }, + }, + message); +} + +WireFormat +MLSMessage::wire_format() const +{ + return tls::variant::type(message); +} + +MLSMessage::MLSMessage(PublicMessage public_message) + : message(std::move(public_message)) +{ +} + +MLSMessage::MLSMessage(PrivateMessage private_message) + : message(std::move(private_message)) +{ +} + +MLSMessage::MLSMessage(Welcome welcome) + : message(std::move(welcome)) +{ +} + +MLSMessage::MLSMessage(GroupInfo group_info) + : message(std::move(group_info)) +{ +} + +MLSMessage::MLSMessage(KeyPackage key_package) + : message(std::move(key_package)) +{ +} + +MLSMessage +external_proposal(CipherSuite suite, + const bytes& group_id, + epoch_t epoch, + const Proposal& proposal, + uint32_t signer_index, + const SignaturePrivateKey& sig_priv) +{ + switch (proposal.proposal_type()) { + // These proposal types are OK + case ProposalType::add: + case ProposalType::remove: + case ProposalType::psk: + case ProposalType::reinit: + case ProposalType::group_context_extensions: + break; + + // These proposal types are forbidden + case ProposalType::invalid: + case ProposalType::update: + case ProposalType::external_init: + default: + throw ProtocolError("External proposal has invalid type"); + } + + auto content = GroupContent{ group_id, + epoch, + { ExternalSenderIndex{ signer_index } }, + { /* no authenticated data */ }, + { proposal } }; + auto content_auth = AuthenticatedContent::sign( + WireFormat::mls_public_message, std::move(content), suite, sig_priv, {}); + + return PublicMessage::protect(std::move(content_auth), suite, {}, {}); +} + +} // namespace mlspp diff --git a/mlspp/src/session.cpp b/mlspp/src/session.cpp new file mode 100755 index 0000000000..f96b6bba63 --- /dev/null +++ b/mlspp/src/session.cpp @@ -0,0 +1,437 @@ +#include +#include + +#include + +namespace mlspp { + +/// +/// Inner struct declarations for PendingJoin and Session +/// + +struct PendingJoin::Inner +{ + const CipherSuite suite; + const HPKEPrivateKey init_priv; + const HPKEPrivateKey leaf_priv; + const SignaturePrivateKey sig_priv; + const KeyPackage key_package; + + Inner(CipherSuite suite_in, + SignaturePrivateKey sig_priv_in, + Credential cred_in); + + static PendingJoin create(CipherSuite suite, + SignaturePrivateKey sig_priv, + Credential cred); +}; + +struct Session::Inner +{ + std::deque history; + std::map outbound_cache; + bool encrypt_handshake{ false }; + + explicit Inner(State state); + + static Session begin(CipherSuite suite, + const bytes& group_id, + const HPKEPrivateKey& leaf_priv, + const SignaturePrivateKey& sig_priv, + const LeafNode& leaf_node); + static Session join(const HPKEPrivateKey& init_priv, + const HPKEPrivateKey& leaf_priv, + const SignaturePrivateKey& sig_priv, + const KeyPackage& key_package, + const bytes& welcome_data); + + bytes fresh_secret() const; + MLSMessage import_handshake(const bytes& encoded) const; + State& for_epoch(epoch_t epoch); +}; + +/// +/// Client +/// + +Client::Client(CipherSuite suite_in, + SignaturePrivateKey sig_priv_in, + Credential cred_in) + : suite(suite_in) + , sig_priv(std::move(sig_priv_in)) + , cred(std::move(cred_in)) +{ +} + +Session +Client::begin_session(const bytes& group_id) const +{ + auto leaf_priv = HPKEPrivateKey::generate(suite); + auto leaf_node = LeafNode(suite, + leaf_priv.public_key, + sig_priv.public_key, + cred, + Capabilities::create_default(), + Lifetime::create_default(), + {}, + sig_priv); + + return Session::Inner::begin(suite, group_id, leaf_priv, sig_priv, leaf_node); +} + +PendingJoin +Client::start_join() const +{ + return PendingJoin::Inner::create(suite, sig_priv, cred); +} + +/// +/// PendingJoin +/// + +PendingJoin::Inner::Inner(CipherSuite suite_in, + SignaturePrivateKey sig_priv_in, + Credential cred_in) + : suite(suite_in) + , init_priv(HPKEPrivateKey::generate(suite)) + , leaf_priv(HPKEPrivateKey::generate(suite)) + , sig_priv(std::move(sig_priv_in)) + , key_package(suite, + init_priv.public_key, + LeafNode(suite, + leaf_priv.public_key, + sig_priv.public_key, + std::move(cred_in), + Capabilities::create_default(), + Lifetime::create_default(), + {}, + sig_priv), + {}, + sig_priv) +{ +} + +PendingJoin +PendingJoin::Inner::create(CipherSuite suite, + SignaturePrivateKey sig_priv, + Credential cred) +{ + auto inner = + std::make_unique(suite, std::move(sig_priv), std::move(cred)); + return { inner.release() }; +} + +PendingJoin::PendingJoin(PendingJoin&& other) noexcept = default; + +PendingJoin& +PendingJoin::operator=(PendingJoin&& other) noexcept = default; + +PendingJoin::~PendingJoin() = default; + +PendingJoin::PendingJoin(Inner* inner_in) + : inner(inner_in) +{ +} + +bytes +PendingJoin::key_package() const +{ + return tls::marshal(inner->key_package); +} + +Session +PendingJoin::complete(const bytes& welcome) const +{ + return Session::Inner::join(inner->init_priv, + inner->leaf_priv, + inner->sig_priv, + inner->key_package, + welcome); +} + +/// +/// Session +/// + +Session::Inner::Inner(State state) + : history{ std::move(state) } + , encrypt_handshake(true) +{ +} + +Session +Session::Inner::begin(CipherSuite suite, + const bytes& group_id, + const HPKEPrivateKey& leaf_priv, + const SignaturePrivateKey& sig_priv, + const LeafNode& leaf_node) +{ + auto state = State(group_id, suite, leaf_priv, sig_priv, leaf_node, {}); + auto inner = std::make_unique(state); + return { inner.release() }; +} + +Session +Session::Inner::join(const HPKEPrivateKey& init_priv, + const HPKEPrivateKey& leaf_priv, + const SignaturePrivateKey& sig_priv, + const KeyPackage& key_package, + const bytes& welcome_data) +{ + auto welcome = tls::get(welcome_data); + + auto state = State( + init_priv, leaf_priv, sig_priv, key_package, welcome, std::nullopt, {}); + auto inner = std::make_unique(state); + return { inner.release() }; +} + +bytes +Session::Inner::fresh_secret() const +{ + const auto suite = history.front().cipher_suite(); + return random_bytes(suite.secret_size()); +} + +MLSMessage +Session::Inner::import_handshake(const bytes& encoded) const +{ + auto msg = tls::get(encoded); + + switch (msg.wire_format()) { + case WireFormat::mls_public_message: + if (encrypt_handshake) { + throw ProtocolError("Handshake not encrypted as required"); + } + + return msg; + + case WireFormat::mls_private_message: { + if (!encrypt_handshake) { + throw ProtocolError("Unexpected handshake encryption"); + } + + return msg; + } + + default: + throw InvalidParameterError("Illegal wire format"); + } +} + +State& +Session::Inner::for_epoch(epoch_t epoch) +{ + for (auto& state : history) { + if (state.epoch() == epoch) { + return state; + } + } + + throw MissingStateError("No state for epoch"); +} + +Session::Session(Session&& other) noexcept = default; + +Session& +Session::operator=(Session&& other) noexcept = default; + +Session::~Session() = default; + +Session::Session(Inner* inner_in) + : inner(inner_in) +{ +} + +void +Session::encrypt_handshake(bool enabled) +{ + inner->encrypt_handshake = enabled; +} + +bytes +Session::add(const bytes& key_package_data) +{ + auto key_package = tls::get(key_package_data); + auto proposal = inner->history.front().add( + key_package, { inner->encrypt_handshake, {}, 0 }); + return tls::marshal(proposal); +} + +bytes +Session::update() +{ + auto leaf_secret = inner->fresh_secret(); + + auto leaf_priv = HPKEPrivateKey::generate(cipher_suite()); + auto proposal = inner->history.front().update( + std::move(leaf_priv), {}, { inner->encrypt_handshake, {}, 0 }); + return tls::marshal(proposal); +} + +bytes +Session::remove(uint32_t index) +{ + auto proposal = inner->history.front().remove( + RosterIndex{ index }, { inner->encrypt_handshake, {}, 0 }); + return tls::marshal(proposal); +} + +std::tuple +Session::commit(const bytes& proposal) +{ + return commit(std::vector{ proposal }); +} + +std::tuple +Session::commit(const std::vector& proposals) +{ + auto provisional_state = inner->history.front(); + for (const auto& proposal_data : proposals) { + auto msg = inner->import_handshake(proposal_data); + auto maybe_state = provisional_state.handle(msg); + if (maybe_state) { + throw InvalidParameterError("Invalid proposal; actually a commit"); + } + } + + inner->history.front() = std::move(provisional_state); + return commit(); +} + +std::tuple +Session::commit() +{ + auto commit_secret = inner->fresh_secret(); + auto encrypt = inner->encrypt_handshake; + auto [commit, welcome, new_state] = inner->history.front().commit( + commit_secret, CommitOpts{ {}, true, encrypt, {} }, { encrypt, {}, 0 }); + + auto commit_msg = tls::marshal(commit); + auto welcome_msg = tls::marshal(welcome); + + inner->outbound_cache.insert({ commit_msg, new_state }); + return std::make_tuple(welcome_msg, commit_msg); +} + +bool +Session::handle(const bytes& handshake_data) +{ + auto msg = inner->import_handshake(handshake_data); + + auto maybe_cached_state = std::optional{}; + auto node = inner->outbound_cache.extract(handshake_data); + if (!node.empty()) { + maybe_cached_state = node.mapped(); + } + + auto maybe_next_state = + inner->history.front().handle(msg, maybe_cached_state); + if (!maybe_next_state) { + return false; + } + + inner->history.emplace_front(opt::get(maybe_next_state)); + return true; +} + +epoch_t +Session::epoch() const +{ + return inner->history.front().epoch(); +} + +LeafIndex +Session::index() const +{ + return inner->history.front().index(); +} + +CipherSuite +Session::cipher_suite() const +{ + return inner->history.front().cipher_suite(); +} + +const ExtensionList& +Session::extensions() const +{ + return inner->history.front().extensions(); +} + +const TreeKEMPublicKey& +Session::tree() const +{ + return inner->history.front().tree(); +} + +bytes +Session::do_export(const std::string& label, + const bytes& context, + size_t size) const +{ + return inner->history.front().do_export(label, context, size); +} + +GroupInfo +Session::group_info() const +{ + return inner->history.front().group_info(true); +} + +std::vector +Session::roster() const +{ + return inner->history.front().roster(); +} + +bytes +Session::epoch_authenticator() const +{ + return inner->history.front().epoch_authenticator(); +} + +bytes +Session::protect(const bytes& plaintext) +{ + auto msg = inner->history.front().protect({}, plaintext, 0); + return tls::marshal(msg); +} + +// TODO(rlb@ipv.sx): It would be good to expose identity information +// here, since ciphertexts are authenticated per sender. Who sent +// this ciphertext? +bytes +Session::unprotect(const bytes& ciphertext) +{ + auto ciphertext_obj = tls::get(ciphertext); + auto& state = inner->for_epoch(ciphertext_obj.epoch()); + auto [aad, pt] = state.unprotect(ciphertext_obj); + silence_unused(aad); + return pt; +} + +bool +operator==(const Session& lhs, const Session& rhs) +{ + if (lhs.inner->encrypt_handshake != rhs.inner->encrypt_handshake) { + return false; + } + + auto size = std::min(lhs.inner->history.size(), rhs.inner->history.size()); + for (size_t i = 0; i < size; i += 1) { + if (lhs.inner->history.at(i) != rhs.inner->history.at(i)) { + return false; + } + } + + return true; +} + +bool +operator!=(const Session& lhs, const Session& rhs) +{ + return !(lhs == rhs); +} + +} // namespace mlspp diff --git a/mlspp/src/state.cpp b/mlspp/src/state.cpp new file mode 100755 index 0000000000..ed15c1e9eb --- /dev/null +++ b/mlspp/src/state.cpp @@ -0,0 +1,2219 @@ +#include +#include + +namespace mlspp { + +/// +/// Constructors +/// + +State::State(bytes group_id, + CipherSuite suite, + HPKEPrivateKey enc_priv, + SignaturePrivateKey sig_priv, + const LeafNode& leaf_node, + ExtensionList extensions) + : _suite(suite) + , _group_id(std::move(group_id)) + , _epoch(0) + , _tree(suite) + , _transcript_hash(suite) + , _extensions(std::move(extensions)) + , _index(0) + , _identity_priv(std::move(sig_priv)) +{ + // Verify that the client supports the proposed group extensions + if (!leaf_node.verify_extension_support(_extensions)) { + throw InvalidParameterError("Client doesn't support required extensions"); + } + + _index = _tree.add_leaf(leaf_node); + _tree.set_hash_all(); + _tree_priv = TreeKEMPrivateKey::solo(suite, _index, std::move(enc_priv)); + if (!_tree_priv.consistent(_tree)) { + throw InvalidParameterError("LeafNode inconsistent with private key"); + } + + // XXX(RLB): Convert KeyScheduleEpoch to take GroupContext? + auto ctx = tls::marshal(group_context()); + _key_schedule = + KeyScheduleEpoch(_suite, random_bytes(_suite.secret_size()), ctx); + _keys = _key_schedule.encryption_keys(_tree.size); + + // Update the interim transcript hash with a virtual confirmation tag + _transcript_hash.update_interim( + _key_schedule.confirmation_tag(_transcript_hash.confirmed)); +} + +TreeKEMPublicKey +State::import_tree(const bytes& tree_hash, + const std::optional& external, + const ExtensionList& extensions) +{ + auto tree = TreeKEMPublicKey(_suite); + auto maybe_tree_extn = extensions.find(); + if (external) { + tree = opt::get(external); + } else if (maybe_tree_extn) { + tree = opt::get(maybe_tree_extn).tree; + } else { + throw InvalidParameterError("No tree available"); + } + + tree.suite = _suite; + + tree.set_hash_all(); + if (tree.root_hash() != tree_hash) { + throw InvalidParameterError("Tree does not match GroupInfo"); + } + + return tree; +} + +bool +State::validate_tree() const +{ + // The functionality here is somewhat duplicative of State::valid(const + // LeafNode&). Simply calling that method, however, would result in this + // method having quadratic scaling, since each call to valid() does a linear + // scan through the tree to check uniqueness of keys and compatibility of + // credential support. + + // Validate that the tree is parent-hash valid + if (!_tree.parent_hash_valid()) { + return false; + } + + // Validate the signatures on all leaves + const auto signature_valid = + _tree.all_leaves([&](auto i, const auto& leaf_node) { + auto binding = std::optional{}; + switch (leaf_node.source()) { + case LeafNodeSource::commit: + case LeafNodeSource::update: + binding = LeafNode::MemberBinding{ _group_id, i }; + break; + + default: + // Nothing to do + break; + } + + return leaf_node.verify(_suite, binding); + }); + if (!signature_valid) { + return false; + } + + // Collect cross-tree properties + auto n_leaves = size_t(0); + auto encryption_keys = std::set{}; + auto signature_keys = std::set{}; + auto credential_types = std::set{}; + _tree.all_leaves([&](auto /* i */, const auto& leaf_node) { + n_leaves += 1; + encryption_keys.insert(leaf_node.encryption_key.data); + signature_keys.insert(leaf_node.signature_key.data); + credential_types.insert(leaf_node.credential.type()); + return true; + }); + + // Verify uniqueness of keys + if (encryption_keys.size() != n_leaves) { + return false; + } + + if (signature_keys.size() != n_leaves) { + return false; + } + + // Verify that each leaf indicates support for all required parameters + return _tree.all_leaves([&](auto /* i */, const auto& leaf_node) { + const auto supports_group_extensions = + leaf_node.verify_extension_support(_extensions); + const auto supports_own_extensions = + leaf_node.verify_extension_support(leaf_node.extensions); + const auto supports_group_credentials = + leaf_node.capabilities.credentials_supported(credential_types); + return supports_group_extensions && supports_own_extensions && + supports_group_credentials; + }); +} + +State::State(SignaturePrivateKey sig_priv, + const GroupInfo& group_info, + const std::optional& tree) + : _suite(group_info.group_context.cipher_suite) + , _group_id(group_info.group_context.group_id) + , _epoch(group_info.group_context.epoch) + , _tree(import_tree(group_info.group_context.tree_hash, + tree, + group_info.extensions)) + , _transcript_hash(_suite, + group_info.group_context.confirmed_transcript_hash, + group_info.confirmation_tag) + , _extensions(group_info.group_context.extensions) + , _key_schedule(_suite) + , _index(0) + , _identity_priv(std::move(sig_priv)) +{ + if (!validate_tree()) { + throw InvalidParameterError("Invalid tree"); + } + + // The following are not set: + // _index + // _tree_priv + // + // This ctor should only be used within external_commit, in which case these + // fields are populated by the subsequent commit() +} + +// Initialize a group from a Welcome +State::State(const HPKEPrivateKey& init_priv, + HPKEPrivateKey leaf_priv, + SignaturePrivateKey sig_priv, + const KeyPackage& key_package, + const Welcome& welcome, + const std::optional& tree, + std::map external_psks) + : State(init_priv, + std::move(leaf_priv), + std::move(sig_priv), + key_package, + welcome, + tree, + std::move(external_psks), + {}) +{ +} + +State::State(const HPKEPrivateKey& init_priv, + HPKEPrivateKey leaf_priv, + SignaturePrivateKey sig_priv, + const KeyPackage& key_package, + const Welcome& welcome, + const std::optional& tree, + std::map external_psks, + std::map resumption_psks) + : _suite(welcome.cipher_suite) + , _epoch(0) + , _tree(welcome.cipher_suite) + , _transcript_hash(welcome.cipher_suite) + , _identity_priv(std::move(sig_priv)) + , _external_psks(std::move(external_psks)) + , _resumption_psks(std::move(resumption_psks)) +{ + auto maybe_kpi = welcome.find(key_package); + if (!maybe_kpi) { + throw InvalidParameterError("Welcome not intended for key package"); + } + auto kpi = opt::get(maybe_kpi); + + if (key_package.cipher_suite != welcome.cipher_suite) { + throw InvalidParameterError("Ciphersuite mismatch"); + } + + // Decrypt the GroupSecrets and look up required PSKs + auto secrets = welcome.decrypt_secrets(kpi, init_priv); + auto psks = resolve(secrets.psks.psks); + + // Decrypt the GroupInfo + auto group_info = welcome.decrypt(secrets.joiner_secret, psks); + if (group_info.group_context.cipher_suite != _suite) { + throw InvalidParameterError("GroupInfo and Welcome ciphersuites disagree"); + } + + // Import the tree from the argument or from the extension + _tree = import_tree( + group_info.group_context.tree_hash, tree, group_info.extensions); + + // Verify the signature on the GroupInfo + if (!group_info.verify(_tree)) { + throw InvalidParameterError("Invalid GroupInfo"); + } + + // Ingest the GroupSecrets and GroupInfo + _epoch = group_info.group_context.epoch; + _group_id = group_info.group_context.group_id; + + _transcript_hash.confirmed = + group_info.group_context.confirmed_transcript_hash; + _transcript_hash.update_interim(group_info.confirmation_tag); + + _extensions = group_info.group_context.extensions; + + // Validate that the tree is in fact consistent with the group's parameters + if (!validate_tree()) { + throw InvalidParameterError("Invalid tree"); + } + + // Construct TreeKEM private key from parts provided + auto maybe_index = _tree.find(key_package.leaf_node); + if (!maybe_index) { + throw InvalidParameterError("New joiner not in tree"); + } + + _index = opt::get(maybe_index); + + auto ancestor = _index.ancestor(group_info.signer); + auto path_secret = std::optional{}; + if (secrets.path_secret) { + path_secret = opt::get(secrets.path_secret).secret; + } + + _tree_priv = TreeKEMPrivateKey::joiner( + _tree, _index, std::move(leaf_priv), ancestor, path_secret); + + // Ratchet forward into the current epoch + auto group_ctx = tls::marshal(group_context()); + _key_schedule = + KeyScheduleEpoch::joiner(_suite, secrets.joiner_secret, psks, group_ctx); + _keys = _key_schedule.encryption_keys(_tree.size); + + // Verify the confirmation + const auto confirmation_tag = + _key_schedule.confirmation_tag(_transcript_hash.confirmed); + if (confirmation_tag != group_info.confirmation_tag) { + throw ProtocolError("Confirmation failed to verify"); + } +} + +std::tuple +State::external_join(const bytes& leaf_secret, + SignaturePrivateKey sig_priv, + const KeyPackage& key_package, + const GroupInfo& group_info, + const std::optional& tree, + const MessageOpts& msg_opts, + std::optional remove_prior, + const std::map& psks) +{ + // Create a preliminary state + auto initial_state = State(std::move(sig_priv), group_info, tree); + + // Look up the external public key for the group + const auto maybe_external_pub = + group_info.extensions.find(); + if (!maybe_external_pub) { + throw InvalidParameterError("No external pub in GroupInfo"); + } + + const auto& external_pub = opt::get(maybe_external_pub).external_pub; + + // Insert an ExternalInit proposal + auto opts = CommitOpts{}; + auto [enc, force_init_secret] = + KeyScheduleEpoch::external_init(key_package.cipher_suite, external_pub); + auto ext_init = Proposal{ ExternalInit{ enc } }; + opts.extra_proposals.push_back(ext_init); + + // Evict a prior appearance if required + if (remove_prior) { + auto remove = initial_state.remove_proposal(opt::get(remove_prior)); + opts.extra_proposals.push_back(remove); + } + + // Inject PSKs + for (const auto& [id, secret] : psks) { + initial_state.add_external_psk(id, secret); + auto psk = initial_state.pre_shared_key_proposal(id); + opts.extra_proposals.push_back(psk); + } + + // Use the preliminary state to create a commit and advance to a real state + auto params = ExternalCommitParams{ key_package, force_init_secret }; + auto [commit_msg, welcome, state] = + initial_state.commit(leaf_secret, opts, msg_opts, params); + silence_unused(welcome); + return { commit_msg, state }; +} + +MLSMessage +State::new_member_add(const bytes& group_id, + epoch_t epoch, + const KeyPackage& new_member, + const SignaturePrivateKey& sig_priv) +{ + const auto suite = new_member.cipher_suite; + auto proposal = Proposal{ Add{ new_member } }; + auto content = GroupContent{ group_id, + epoch, + { NewMemberProposalSender{} }, + { /* no authenticated data */ }, + { std::move(proposal) } }; + auto content_auth = AuthenticatedContent::sign( + WireFormat::mls_public_message, std::move(content), suite, sig_priv, {}); + + return PublicMessage::protect(std::move(content_auth), suite, {}, {}); +} + +/// +/// Proposal and commit factories +/// +template +MLSMessage +State::protect_full(Inner&& inner_content, const MessageOpts& msg_opts) +{ + auto content_auth = sign({ MemberSender{ _index } }, + std::forward(inner_content), + msg_opts.authenticated_data, + msg_opts.encrypt); + return protect(std::move(content_auth), msg_opts.padding_size); +} + +template +AuthenticatedContent +State::sign(const Sender& sender, + Inner&& inner_content, + const bytes& authenticated_data, + bool encrypt) const +{ + auto content = GroupContent{ _group_id, + _epoch, + sender, + authenticated_data, + { std::forward(inner_content) } }; + + auto wire_format = (encrypt) ? WireFormat::mls_private_message + : WireFormat::mls_public_message; + + auto content_auth = AuthenticatedContent::sign( + wire_format, std::move(content), _suite, _identity_priv, group_context()); + + return content_auth; +} + +MLSMessage +State::protect(AuthenticatedContent&& content_auth, size_t padding_size) +{ + switch (content_auth.wire_format) { + case WireFormat::mls_public_message: + return PublicMessage::protect(std::move(content_auth), + _suite, + _key_schedule.membership_key, + group_context()); + + case WireFormat::mls_private_message: + return PrivateMessage::protect(std::move(content_auth), + _suite, + _keys, + _key_schedule.sender_data_secret, + padding_size); + + default: + throw InvalidParameterError("Malformed AuthenticatedContent"); + } +} + +ValidatedContent +State::unwrap(const MLSMessage& msg) +{ + if (msg.version != ProtocolVersion::mls10) { + throw InvalidParameterError("Unsupported version"); + } + + const auto unprotect = overloaded{ + [&](const PublicMessage& pt) -> ValidatedContent { + if (pt.get_group_id() != _group_id) { + throw ProtocolError("PublicMessage not for this group"); + } + + if (pt.get_epoch() != _epoch) { + throw ProtocolError("PublicMessage not for this epoch"); + } + + auto maybe_content_auth = + pt.unprotect(_suite, _key_schedule.membership_key, group_context()); + if (!maybe_content_auth) { + throw ProtocolError("Membership tag failed to verify"); + } + return opt::get(maybe_content_auth); + }, + + [&](const PrivateMessage& ct) -> ValidatedContent { + if (ct.get_group_id() != _group_id) { + throw ProtocolError("PrivateMessage not for this group"); + } + + if (ct.get_epoch() != _epoch) { + throw ProtocolError("PrivateMessage not for this epoch"); + } + + auto maybe_content_auth = + ct.unprotect(_suite, _keys, _key_schedule.sender_data_secret); + if (!maybe_content_auth) { + throw ProtocolError("PrivateMessage decryption failure"); + } + return opt::get(maybe_content_auth); + }, + + [](const auto& /* unused */) -> ValidatedContent { + throw ProtocolError("Invalid wire format"); + }, + }; + + auto val_content = var::visit(unprotect, msg.message); + if (!verify(val_content.content_auth)) { + throw InvalidParameterError("Message signature failed to verify"); + } + + return val_content; +} + +Proposal +State::add_proposal(const KeyPackage& key_package) const +{ + // Check that the key package is validly signed + if (!key_package.verify()) { + throw InvalidParameterError("Invalid signature on key package"); + } + + // Check that the group's basic properties are supported + auto now = seconds_since_epoch(); + if (!key_package.leaf_node.verify_expiry(now)) { + throw InvalidParameterError("Expired key package"); + } + + // Check that the group's extensions are supported + if (!key_package.leaf_node.verify_extension_support(_extensions)) { + throw InvalidParameterError( + "Key package does not support group's extensions"); + } + + return { Add{ key_package } }; +} + +Proposal +State::update_proposal(HPKEPrivateKey leaf_priv, const LeafNodeOptions& opts) +{ + if (_cached_update) { + throw ProtocolError("Only one update may be generated per epoch"); + } + + auto leaf = opt::get(_tree.leaf_node(_index)); + + auto new_leaf = leaf.for_update( + _suite, _group_id, _index, leaf_priv.public_key, opts, _identity_priv); + + auto update = Update{ new_leaf }; + _cached_update = CachedUpdate{ std::move(leaf_priv), update }; + return { update }; +} + +Proposal +State::remove_proposal(RosterIndex index) const +{ + return remove_proposal(leaf_for_roster_entry(index)); +} + +Proposal +State::remove_proposal(LeafIndex removed) const +{ + if (!_tree.has_leaf(removed)) { + throw InvalidParameterError("Remove on blank leaf"); + } + + return { Remove{ removed } }; +} + +Proposal +State::group_context_extensions_proposal(ExtensionList exts) const +{ + if (!extensions_supported(exts)) { + throw InvalidParameterError("Unsupported extensions"); + } + + return { GroupContextExtensions{ std::move(exts) } }; +} + +Proposal +State::pre_shared_key_proposal(const bytes& external_psk_id) const +{ + if (_external_psks.count(external_psk_id) == 0) { + throw InvalidParameterError("Unknown PSK"); + } + + auto psk_id = PreSharedKeyID{ + { ExternalPSK{ external_psk_id } }, + random_bytes(_suite.secret_size()), + }; + return { PreSharedKey{ psk_id } }; +} + +Proposal +State::pre_shared_key_proposal(const bytes& group_id, epoch_t epoch) const +{ + if (epoch != _epoch && _resumption_psks.count({ group_id, epoch }) == 0) { + throw InvalidParameterError("Unknown PSK"); + } + + auto psk_id = PreSharedKeyID{ + { ResumptionPSK{ ResumptionPSKUsage::application, group_id, epoch } }, + random_bytes(_suite.secret_size()), + }; + return { PreSharedKey{ psk_id } }; +} + +Proposal +State::reinit_proposal(bytes group_id, + ProtocolVersion version, + CipherSuite cipher_suite, + ExtensionList extensions) +{ + return { ReInit{ + std::move(group_id), version, cipher_suite, std::move(extensions) } }; +} + +MLSMessage +State::add(const KeyPackage& key_package, const MessageOpts& msg_opts) +{ + return protect_full(add_proposal(key_package), msg_opts); +} + +MLSMessage +State::update(HPKEPrivateKey leaf_priv, + const LeafNodeOptions& opts, + const MessageOpts& msg_opts) +{ + return protect_full(update_proposal(std::move(leaf_priv), opts), msg_opts); +} + +MLSMessage +State::remove(RosterIndex index, const MessageOpts& msg_opts) +{ + return protect_full(remove_proposal(index), msg_opts); +} + +MLSMessage +State::remove(LeafIndex removed, const MessageOpts& msg_opts) +{ + return protect_full(remove_proposal(removed), msg_opts); +} + +MLSMessage +State::group_context_extensions(ExtensionList exts, const MessageOpts& msg_opts) +{ + return protect_full(group_context_extensions_proposal(std::move(exts)), + msg_opts); +} + +MLSMessage +State::pre_shared_key(const bytes& external_psk_id, const MessageOpts& msg_opts) +{ + return protect_full(pre_shared_key_proposal(external_psk_id), msg_opts); +} + +MLSMessage +State::pre_shared_key(const bytes& group_id, + epoch_t epoch, + const MessageOpts& msg_opts) +{ + return protect_full(pre_shared_key_proposal(group_id, epoch), msg_opts); +} + +MLSMessage +State::reinit(bytes group_id, + ProtocolVersion version, + CipherSuite cipher_suite, + ExtensionList extensions, + const MessageOpts& msg_opts) +{ + return protect_full( + reinit_proposal( + std::move(group_id), version, cipher_suite, std::move(extensions)), + msg_opts); +} + +std::tuple +State::commit(const bytes& leaf_secret, + const std::optional& opts, + const MessageOpts& msg_opts) +{ + return commit(leaf_secret, opts, msg_opts, NormalCommitParams{}); +} + +std::tuple +State::commit(const bytes& leaf_secret, + const std::optional& opts, + const MessageOpts& msg_opts, + CommitParams params) +{ + // Construct a commit from cached proposals + // TODO(rlb) ignore some proposals: + // * Update after Update + // * Update after Remove + // * Remove after Remove + Commit commit; + auto joiners = std::vector{}; + for (const auto& cached : _pending_proposals) { + if (var::holds_alternative(cached.proposal.content)) { + const auto& add = var::get(cached.proposal.content); + joiners.push_back(add.key_package); + } + + commit.proposals.push_back({ cached.ref }); + } + + // Add the extra proposals to those we had cached + if (opts) { + const auto& extra_proposals = opt::get(opts).extra_proposals; + for (const auto& proposal : extra_proposals) { + if (var::holds_alternative(proposal.content)) { + const auto& add = var::get(proposal.content); + joiners.push_back(add.key_package); + } + + commit.proposals.push_back({ proposal }); + } + } + + // If this is an external commit, insert an ExternalInit proposal + auto external_commit = std::optional{}; + if (var::holds_alternative(params)) { + external_commit = var::get(params); + } + + auto force_init_secret = std::optional{}; + if (external_commit) { + force_init_secret = opt::get(external_commit).force_init_secret; + } + + // Apply proposals + State next = successor(); + + const auto proposals = must_resolve(commit.proposals, _index); + if (!valid(proposals, _index, params)) { + throw ProtocolError("Invalid proposal list"); + } + + const auto [joiner_locations, psks] = next.apply(proposals); + + if (external_commit) { + const auto& leaf_node = + opt::get(external_commit).joiner_key_package.leaf_node; + next._index = next._tree.add_leaf(leaf_node); + } + + // If this is an external commit, indicate it in the sender field + auto sender = Sender{ MemberSender{ _index } }; + if (external_commit) { + sender = Sender{ NewMemberCommitSender{} }; + } + + // KEM new entropy to the group and the new joiners + auto commit_secret = _suite.zero(); + auto path_secrets = + std::vector>(joiner_locations.size()); + auto force_path = opts && opt::get(opts).force_path; + if (force_path || path_required(proposals)) { + auto leaf_node_opts = LeafNodeOptions{}; + if (opts) { + leaf_node_opts = opt::get(opts).leaf_node_opts; + } + + auto new_priv = next._tree.update( + next._index, leaf_secret, next._group_id, _identity_priv, leaf_node_opts); + + auto ctx = tls::marshal(GroupContext{ + next._suite, + next._group_id, + next._epoch + 1, + next._tree.root_hash(), + next._transcript_hash.confirmed, + next._extensions, + }); + auto path = next._tree.encap(new_priv, ctx, joiner_locations); + + next._tree_priv = new_priv; + commit.path = path; + commit_secret = new_priv.update_secret; + + for (size_t i = 0; i < joiner_locations.size(); i++) { + auto [overlap, shared_path_secret, ok] = + new_priv.shared_path_secret(joiner_locations[i]); + silence_unused(overlap); + silence_unused(ok); + + path_secrets[i] = shared_path_secret; + } + } + + // Create the Commit message and advance the transcripts / key schedule + auto commit_content_auth = + sign(sender, commit, msg_opts.authenticated_data, msg_opts.encrypt); + + next._transcript_hash.update_confirmed(commit_content_auth); + next._epoch += 1; + next.update_epoch_secrets(commit_secret, psks, force_init_secret); + + const auto confirmation_tag = + next._key_schedule.confirmation_tag(next._transcript_hash.confirmed); + commit_content_auth.set_confirmation_tag(confirmation_tag); + + next._transcript_hash.update_interim(commit_content_auth); + + auto commit_message = + protect(std::move(commit_content_auth), msg_opts.padding_size); + + // Complete the GroupInfo and form the Welcome + auto group_info = GroupInfo{ + { + next._suite, + next._group_id, + next._epoch, + next._tree.root_hash(), + next._transcript_hash.confirmed, + next._extensions, + }, + { /* No other extensions */ }, + { confirmation_tag }, + }; + if (opts && opt::get(opts).inline_tree) { + group_info.extensions.add(RatchetTreeExtension{ next._tree }); + } + group_info.sign(next._tree, next._index, next._identity_priv); + + auto welcome = + Welcome{ _suite, next._key_schedule.joiner_secret, psks, group_info }; + for (size_t i = 0; i < joiners.size(); i++) { + welcome.encrypt(joiners[i], path_secrets[i]); + } + + return std::make_tuple(commit_message, welcome, next); +} + +/// +/// Message handlers +/// + +GroupContext +State::group_context() const +{ + return GroupContext{ + _suite, + _group_id, + _epoch, + _tree.root_hash(), + _transcript_hash.confirmed, + _extensions, + }; +} + +std::optional +State::handle(const MLSMessage& msg) +{ + return handle(unwrap(msg), std::nullopt, std::nullopt); +} + +std::optional +State::handle(const MLSMessage& msg, std::optional cached_state) +{ + return handle(unwrap(msg), std::move(cached_state), std::nullopt); +} + +std::optional +State::handle(const ValidatedContent& content_auth) +{ + return handle(content_auth, std::nullopt, std::nullopt); +} + +std::optional +State::handle(const ValidatedContent& content_auth, + std::optional cached_state) +{ + return handle(content_auth, std::move(cached_state), std::nullopt); +} + +std::optional +State::handle(const MLSMessage& msg, + std::optional cached_state, + const std::optional& expected_params) +{ + return handle(unwrap(msg), std::move(cached_state), expected_params); +} + +std::optional +State::handle(const ValidatedContent& val_content, + std::optional cached_state, + const std::optional& expected_params) +{ + // Dispatch on content type + const auto& content_auth = val_content.authenticated_content(); + const auto& content = content_auth.content; + switch (content.content_type()) { + // Proposals get queued, do not result in a state transition + case ContentType::proposal: + cache_proposal(content_auth); + return std::nullopt; + + // Commits are handled in the remainder of this method + case ContentType::commit: + break; + + // Any other content type in this method is an error + default: + throw InvalidParameterError("Invalid content type"); + } + + switch (content.sender.sender_type()) { + case SenderType::member: + case SenderType::new_member_commit: + break; + + default: + throw ProtocolError("Invalid commit sender type"); + } + + auto sender = std::optional(); + if (content.sender.sender_type() == SenderType::member) { + sender = var::get(content.sender.sender).sender; + } + + if (sender == _index) { + if (cached_state) { + // Verify that the cached state is a plausible successor to this state + const auto& next = opt::get(cached_state); + if (next._group_id != _group_id || next._epoch != _epoch + 1 || + next._index != _index) { + throw InvalidParameterError("Invalid successor state"); + } + + return next; + } + + throw InvalidParameterError("Handle own commits with caching"); + } + + // Apply the commit + const auto& commit = var::get(content.content); + const auto proposals = must_resolve(commit.proposals, sender); + + const auto params = infer_commit_type(sender, proposals, expected_params); + auto external_commit = var::holds_alternative(params); + + // Check that a path is present when required + if (path_required(proposals) && !commit.path) { + throw ProtocolError("Path required but not present"); + } + + // Apply the proposals + auto next = successor(); + auto [joiner_locations, psks] = next.apply(proposals); + + // If this is an external commit, add the joiner to the tree and note the + // location where they were added. Also, compute the "externally forced" + // value that we will use for the init_secret (as opposed to the init_secret + // from the key schedule). + auto force_init_secret = std::optional{}; + auto sender_location = LeafIndex{ 0 }; + if (!external_commit) { + sender_location = opt::get(sender); + } else { + // Find where the joiner will be added + sender_location = next._tree.allocate_leaf(); + + // Extract the forced init secret + auto kem_output = commit.valid_external(); + if (!kem_output) { + throw ProtocolError("Invalid external commit"); + } + + force_init_secret = + _key_schedule.receive_external_init(opt::get(kem_output)); + } + + // Decapsulate and apply the UpdatePath, if provided + auto commit_secret = _suite.zero(); + if (commit.path) { + const auto& path = opt::get(commit.path); + + if (!valid(path.leaf_node, LeafNodeSource::commit, sender_location)) { + throw ProtocolError("Commit path has invalid leaf node"); + } + + if (!next._tree.parent_hash_valid(sender_location, path)) { + throw ProtocolError("Commit path has invalid parent hash"); + } + + next._tree.merge(sender_location, path); + + auto ctx = tls::marshal(GroupContext{ + next._suite, + next._group_id, + next._epoch + 1, + next._tree.root_hash(), + next._transcript_hash.confirmed, + next._extensions, + }); + next._tree_priv.decap( + sender_location, next._tree, ctx, path, joiner_locations); + + commit_secret = next._tree_priv.update_secret; + } + + // Update the transcripts and advance the key schedule + next._transcript_hash.update(content_auth); + next._epoch += 1; + next.update_epoch_secrets(commit_secret, { psks }, force_init_secret); + + // Verify the confirmation MAC + const auto confirmation_tag = + next._key_schedule.confirmation_tag(next._transcript_hash.confirmed); + if (!content_auth.check_confirmation_tag(confirmation_tag)) { + throw ProtocolError("Confirmation failed to verify"); + } + + return next; +} + +/// +/// Subgroup branching +/// + +// Parameters: +// * ctor inputs +// * leaf_secret +// * commit_opts +std::tuple +State::create_branch(bytes group_id, + HPKEPrivateKey enc_priv, + SignaturePrivateKey sig_priv, + const LeafNode& leaf_node, + ExtensionList extensions, + const std::vector& key_packages, + const bytes& leaf_secret, + const CommitOpts& commit_opts) const +{ + // Create new empty group with the appropriate PSK + auto new_group = + State{ std::move(group_id), _suite, std::move(enc_priv), + std::move(sig_priv), leaf_node, std::move(extensions) }; + + new_group.add_resumption_psk(_group_id, _epoch, _key_schedule.resumption_psk); + + // Create Add proposals + auto proposals = stdx::transform( + key_packages, [&](const auto& kp) { return new_group.add_proposal(kp); }); + + // Create PSK proposal + proposals.push_back({ PreSharedKey{ + { ResumptionPSK{ ResumptionPSKUsage::branch, _group_id, _epoch }, + random_bytes(_suite.secret_size()) } } }); + + // Commit the Add and PSK proposals + auto opts = CommitOpts{ + proposals, + commit_opts.inline_tree, + commit_opts.force_path, + commit_opts.leaf_node_opts, + }; + auto [_commit, welcome, state] = new_group.commit( + leaf_secret, opts, {}, RestartCommitParams{ ResumptionPSKUsage::branch }); + return { state, welcome }; +} + +State +State::handle_branch(const HPKEPrivateKey& init_priv, + HPKEPrivateKey enc_priv, + SignaturePrivateKey sig_priv, + const KeyPackage& key_package, + const Welcome& welcome, + const std::optional& tree) const +{ + auto resumption_psks = + std::map{ { { _group_id, _epoch }, + _key_schedule.resumption_psk } }; + auto branch_state = State{ + init_priv, + std::move(enc_priv), + std::move(sig_priv), + key_package, + welcome, + tree, + {}, + resumption_psks, + }; + + if (branch_state._suite != _suite) { + throw ProtocolError("Attempt to branch with a different ciphersuite"); + } + + if (branch_state._epoch != 1) { + throw ProtocolError("Branch not done at the beginning of the group"); + } + + return branch_state; +} + +State::Tombstone::Tombstone(const State& state_in, ReInit reinit_in) + : epoch_authenticator(state_in.epoch_authenticator()) + , reinit(std::move(reinit_in)) + , prior_group_id(state_in._group_id) + , prior_epoch(state_in._epoch) + , resumption_psk(state_in._key_schedule.resumption_psk) +{ +} + +std::tuple +State::Tombstone::create_welcome(HPKEPrivateKey enc_priv, + SignaturePrivateKey sig_priv, + const LeafNode& leaf_node, + const std::vector& key_packages, + const bytes& leaf_secret, + const CommitOpts& commit_opts) const +{ + // Create new empty group with the appropriate PSK + auto new_group = + State{ reinit.group_id, reinit.cipher_suite, std::move(enc_priv), + std::move(sig_priv), leaf_node, reinit.extensions }; + + new_group.add_resumption_psk(prior_group_id, prior_epoch, resumption_psk); + + // Create Add proposals + auto proposals = stdx::transform( + key_packages, [&](const auto& kp) { return new_group.add_proposal(kp); }); + + // Create PSK proposal + proposals.push_back({ PreSharedKey{ + { ResumptionPSK{ ResumptionPSKUsage::reinit, prior_group_id, prior_epoch }, + random_bytes(reinit.cipher_suite.secret_size()) } } }); + + // Commit the Add and PSK proposals + auto opts = CommitOpts{ + proposals, + commit_opts.inline_tree, + commit_opts.force_path, + commit_opts.leaf_node_opts, + }; + auto [_commit, welcome, state] = new_group.commit( + leaf_secret, opts, {}, RestartCommitParams{ ResumptionPSKUsage::reinit }); + return { state, welcome }; +} + +State +State::Tombstone::handle_welcome( + const HPKEPrivateKey& init_priv, + HPKEPrivateKey enc_priv, + SignaturePrivateKey sig_priv, + const KeyPackage& key_package, + const Welcome& welcome, + const std::optional& tree) const +{ + auto resumption_psks = + std::map{ { { prior_group_id, prior_epoch }, + resumption_psk } }; + auto new_state = State{ + init_priv, + std::move(enc_priv), + std::move(sig_priv), + key_package, + welcome, + tree, + {}, + resumption_psks, + }; + + if (new_state._suite != reinit.cipher_suite) { + throw ProtocolError("Attempt to reinit with the wrong ciphersuite"); + } + + if (new_state._epoch != 1) { + throw ProtocolError("ReInit not done at the beginning of the group"); + } + + return new_state; +} + +std::tuple +State::reinit_commit(const bytes& leaf_secret, + const std::optional& opts, + const MessageOpts& msg_opts) +{ + // Ensure that either the proposal cache or the inline proposals have a ReInit + // proposal, and no others. + auto reinit_proposal = Proposal{}; + if (_pending_proposals.size() == 1) { + reinit_proposal = _pending_proposals.front().proposal; + } else if (opts && opt::get(opts).extra_proposals.size() == 1) { + reinit_proposal = opt::get(opts).extra_proposals.front(); + } else { + throw ProtocolError("Illegal proposals for reinitialization"); + } + + auto reinit = var::get(reinit_proposal.content); + + // Create the commit + const auto [commit_msg, welcome, new_state] = + commit(leaf_secret, opts, msg_opts, ReInitCommitParams{}); + silence_unused(welcome); + + // Create the Tombstone from the terminal state + return { { new_state, reinit }, commit_msg }; +} + +State::Tombstone +State::handle_reinit_commit(const MLSMessage& commit_msg) +{ + // Verify the signature and process the commit + const auto val_content = unwrap(commit_msg); + const auto& content_auth = val_content.authenticated_content(); + if (!verify(content_auth)) { + throw InvalidParameterError("Message signature failed to verify"); + } + + auto new_state = + opt::get(handle(content_auth, std::nullopt, ReInitCommitParams{})); + + // Extract the ReInit and create the Tombstone + const auto& commit = var::get(content_auth.content.content); + const auto proposals = must_resolve(commit.proposals, std::nullopt); + if (!valid_reinit(proposals)) { + throw ProtocolError("Invalid proposals for reinit"); + } + + const auto& reinit_proposal = proposals.front(); + const auto& reinit = var::get(reinit_proposal.proposal.content); + return Tombstone{ new_state, reinit }; +} + +/// +/// Internals +/// + +LeafIndex +State::apply(const Add& add) +{ + return _tree.add_leaf(add.key_package.leaf_node); +} + +void +State::apply(LeafIndex target, const Update& update) +{ + _tree.update_leaf(target, update.leaf_node); +} + +void +State::apply(LeafIndex target, + const Update& update, + const HPKEPrivateKey& leaf_priv) +{ + _tree.update_leaf(target, update.leaf_node); + _tree_priv.set_leaf_priv(leaf_priv); +} + +LeafIndex +State::apply(const Remove& remove) +{ + if (!_tree.has_leaf(remove.removed)) { + throw ProtocolError("Attempt to remove non-member"); + } + + _tree.blank_path(remove.removed); + return remove.removed; +} + +void +State::apply(const GroupContextExtensions& gce) +{ + // TODO(RLB): Update spec to clarify that you MUST verify that the new + // extensions are compatible with all members. + if (!extensions_supported(gce.group_context_extensions)) { + throw ProtocolError("Unsupported extensions in GroupContextExtensions"); + } + + _extensions = gce.group_context_extensions; +} + +bool +State::extensions_supported(const ExtensionList& exts) const +{ + return _tree.all_leaves([&](auto /* i */, const auto& leaf_node) { + return leaf_node.verify_extension_support(exts); + }); +} + +void +State::cache_proposal(AuthenticatedContent content_auth) +{ + auto ref = _suite.ref(content_auth); + if (stdx::any_of(_pending_proposals, + [&](const auto& cached) { return cached.ref == ref; })) { + return; + } + + auto sender_location = std::optional(); + if (content_auth.content.sender.sender_type() == SenderType::member) { + const auto& sender = content_auth.content.sender.sender; + sender_location = var::get(sender).sender; + } + + const auto& proposal = var::get(content_auth.content.content); + + if (content_auth.content.sender.sender_type() == SenderType::external && + !valid_external_proposal_type(proposal.proposal_type())) { + throw ProtocolError("Invalid external proposal"); + } + + if (!valid(sender_location, proposal)) { + throw ProtocolError("Invalid proposal"); + } + + _pending_proposals.push_back({ + _suite.ref(content_auth), + proposal, + sender_location, + }); +} + +std::optional +State::resolve(const ProposalOrRef& id, + std::optional sender_index) const +{ + if (var::holds_alternative(id.content)) { + return CachedProposal{ + {}, + var::get(id.content), + sender_index, + }; + } + + const auto& ref = var::get(id.content); + for (const auto& cached : _pending_proposals) { + if (cached.ref == ref) { + return cached; + } + } + + return std::nullopt; +} + +std::vector +// NOLINTNEXTLINE(readability-convert-member-functions-to-static) +State::must_resolve(const std::vector& ids, + std::optional sender_index) const +{ + auto must_resolve = [&](const auto& id) { + return opt::get(resolve(id, sender_index)); + }; + return stdx::transform(ids, must_resolve); +} + +std::vector +State::resolve(const std::vector& psks) const +{ + return stdx::transform(psks, [&](const auto& psk_id) { + auto get_secret = overloaded{ + [&](const ExternalPSK& ext_psk) { + if (_external_psks.count(ext_psk.psk_id) == 0) { + throw ProtocolError("Unknown external PSK"); + } + + return _external_psks.at(ext_psk.psk_id); + }, + + [&](const ResumptionPSK& res_psk) { + if (res_psk.psk_epoch == _epoch) { + return _key_schedule.resumption_psk; + } + + auto key = std::make_tuple(res_psk.psk_group_id, res_psk.psk_epoch); + if (_resumption_psks.count(key) == 0) { + throw ProtocolError("Unknown Resumption PSK"); + } + + return _resumption_psks.at(key); + }, + }; + + auto secret = var::visit(get_secret, psk_id.content); + return PSKWithSecret{ psk_id, secret }; + }); +} + +std::vector +State::apply(const std::vector& proposals, + Proposal::Type required_type) +{ + auto locations = std::vector{}; + for (const auto& cached : proposals) { + auto proposal_type = cached.proposal.proposal_type(); + if (proposal_type != required_type) { + continue; + } + + switch (proposal_type) { + case ProposalType::add: { + locations.push_back(apply(var::get(cached.proposal.content))); + break; + } + + case ProposalType::update: { + const auto& update = var::get(cached.proposal.content); + + if (!cached.sender) { + throw ProtocolError("Update without target leaf"); + } + + auto target = opt::get(cached.sender); + if (target != _index) { + apply(target, update); + break; + } + + if (!_cached_update) { + throw ProtocolError("Self-update with no cached secret"); + } + + const auto& cached_update = opt::get(_cached_update); + if (update != cached_update.proposal) { + throw ProtocolError("Self-update does not match cached data"); + } + + apply(target, update, cached_update.update_priv); + locations.push_back(target); + break; + } + + case ProposalType::remove: { + const auto& remove = var::get(cached.proposal.content); + locations.push_back(apply(remove)); + break; + } + + case ProposalType::group_context_extensions: { + const auto& gce = + var::get(cached.proposal.content); + apply(gce); + break; + } + + default: + throw ProtocolError("Unsupported proposal type"); + } + } + + // The cached update needs to be reset after applying proposals, so that it is + // in a clean state for the next epoch. + _cached_update.reset(); + + return locations; +} + +std::tuple, std::vector> +State::apply(const std::vector& proposals) +{ + apply(proposals, ProposalType::update); + apply(proposals, ProposalType::remove); + auto joiner_locations = apply(proposals, ProposalType::add); + apply(proposals, ProposalType::group_context_extensions); + + // Extract the PSK proposals and look up the secrets + // TODO(RLB): Factor this out, and also factor the above methods into + // apply_update, apply_remove, etc. + auto psk_ids = std::vector{}; + for (const auto& cached : proposals) { + if (cached.proposal.proposal_type() != ProposalType::psk) { + continue; + } + + const auto& proposal = var::get(cached.proposal.content); + psk_ids.push_back(proposal.psk); + } + auto psks = resolve(psk_ids); + + _tree.truncate(); + _tree_priv.truncate(_tree.size); + _tree.set_hash_all(); + return { joiner_locations, psks }; +} + +/// +/// Message protection +/// + +MLSMessage +State::protect(const bytes& authenticated_data, + const bytes& pt, + size_t padding_size) +{ + auto msg_opts = MessageOpts{ true, authenticated_data, padding_size }; + return protect_full(ApplicationData{ pt }, msg_opts); +} + +std::tuple +State::unprotect(const MLSMessage& ct) +{ + const auto val_content = unwrap(ct); + const auto& content_auth = val_content.authenticated_content(); + + if (!verify(content_auth)) { + throw InvalidParameterError("Message signature failed to verify"); + } + + if (content_auth.content.content_type() != ContentType::application) { + throw ProtocolError("Unprotect of handshake message"); + } + + if (content_auth.wire_format != WireFormat::mls_private_message) { + throw ProtocolError("Application data not sent as PrivateMessage"); + } + + return { + content_auth.content.authenticated_data, + var::get(content_auth.content.content).data, + }; +} + +/// +/// Properties of a proposal list +/// + +bool +State::valid(const LeafNode& leaf_node, + LeafNodeSource required_source, + std::optional index) const +{ + // Verify that the credential in the LeafNode is valid as described in Section + // 5.3.1. + // XXX(RLB) N/A, no credential validation in the library right now + + // Verify the leaf_node_source field: + const auto correct_source = (leaf_node.source() == required_source); + + // Verify that the signature on the LeafNode is valid using signature_key. + auto binding = std::optional{}; + switch (required_source) { + case LeafNodeSource::commit: + case LeafNodeSource::update: + binding = LeafNode::MemberBinding{ _group_id, opt::get(index) }; + break; + + default: + // Nothing to do + break; + } + + const auto signature_valid = leaf_node.verify(_suite, binding); + + // Verify that the LeafNode is compatible with the group's parameters. If the + // GroupContext has a required_capabilities extension, then the required + // extensions, proposals, and credential types MUST be listed in the + // LeafNode's capabilities field. + const auto supports_group_extensions = + leaf_node.verify_extension_support(_extensions); + + // TODO(RLB) Verify the lifetime field + + // Verify that the credential type is supported by all members of the group, + // as specified by the capabilities field of each member's LeafNode, and that + // the capabilities field of this LeafNode indicates support for all the + // credential types currently in use by other members. + // + // Verify that the following fields are unique among the members of the group: + // signature_key + // encryption_key + // + // Note: Uniqueness of signature and encryption keys is assured by the + // tree operations (add/update), so we do not need to verify those here. + const auto mutual_credential_support = + _tree.all_leaves([&](auto /* i */, const auto& leaf) { + return leaf.capabilities.credential_supported(leaf_node.credential) && + leaf_node.capabilities.credential_supported(leaf.credential); + }); + + // Verify that the extensions in the LeafNode are supported by checking that + // the ID for each extension in the extensions field is listed in the + // capabilities.extensions field of the LeafNode. + auto supports_own_extensions = + leaf_node.verify_extension_support(leaf_node.extensions); + + return (signature_valid && supports_group_extensions && correct_source && + mutual_credential_support && supports_own_extensions); +} + +bool +State::valid(const KeyPackage& key_package) const +{ + // Verify that the ciphersuite and protocol version of the KeyPackage match + // those in the GroupContext. + const auto correct_ciphersuite = (key_package.cipher_suite == _suite); + + // Verify that the signature on the KeyPackage is valid using the public key + // in leaf_node.credential. + const auto valid_signature = key_package.verify(); + + // Verify that the leaf_node of the KeyPackage is valid for a KeyPackage + // according to Section 7.3. + const auto leaf_node_valid = + valid(key_package.leaf_node, LeafNodeSource::key_package, std::nullopt); + + // Verify that the value of leaf_node.encryption_key is different from the + // value of the init_key field. + const auto distinct_keys = + (key_package.init_key != key_package.leaf_node.encryption_key); + + return (correct_ciphersuite && valid_signature && leaf_node_valid && + distinct_keys); +} + +bool +State::valid(const Add& add) const +{ + return valid(add.key_package); +} + +bool +State::valid(LeafIndex sender, const Update& update) const +{ + const auto maybe_leaf = _tree.leaf_node(sender); + if (!maybe_leaf) { + return false; + } + + return valid(update.leaf_node, LeafNodeSource::update, sender); +} + +bool +State::valid(const Remove& remove) const +{ + // We mark self-removes invalid here even though a resync Commit will + // sometimes cause them. This is OK because this method is only called from + // the normal proposal list validation method, not the external commit one. + auto in_tree = remove.removed < _tree.size && _tree.has_leaf(remove.removed); + auto not_me = remove.removed != _index; + return in_tree && not_me; +} + +bool +State::valid(const PreSharedKey& psk) const +{ + // External PSKs are allowed if we have the corresponding secret + if (var::holds_alternative(psk.psk.content)) { + const auto& ext_psk = var::get(psk.psk.content); + return _external_psks.count(ext_psk.psk_id) > 0; + } + + // Resumption PSKs are allowed only with usage 'application', and only if we + // have the corresponding secret. + if (var::holds_alternative(psk.psk.content)) { + const auto& res_psk = var::get(psk.psk.content); + if (res_psk.usage != ResumptionPSKUsage::application) { + return false; + } + + const auto key = std::make_tuple(res_psk.psk_group_id, res_psk.psk_epoch); + return res_psk.psk_epoch == _epoch || _resumption_psks.count(key) > 0; + } + + return false; +} + +bool +State::valid(const ReInit& reinit) +{ + // Check that the version and CipherSuite are ones we support + auto supported_version = (reinit.version == ProtocolVersion::mls10); + auto supported_suite = + stdx::contains(all_supported_suites, reinit.cipher_suite.cipher_suite()); + + return supported_version && supported_suite; +} + +bool +State::valid(const ExternalInit& external_init) const +{ + return external_init.kem_output.size() == _suite.hpke().kem.enc_size; +} + +bool +State::valid(const GroupContextExtensions& gce) const +{ + return extensions_supported(gce.group_context_extensions); +} + +bool +State::valid(std::optional sender, const Proposal& proposal) const +{ + const auto specifically_valid = overloaded{ + [&](const Update& update) { return valid(opt::get(sender), update); }, + [&](const auto& proposal) { return valid(proposal); }, + }; + return var::visit(specifically_valid, proposal.content); +} + +bool +State::valid(const std::vector& proposals, + LeafIndex commit_sender, + const CommitParams& params) const +{ + auto specifically = overloaded{ + [&](const NormalCommitParams& /* unused */) { + return valid_normal(proposals, commit_sender); + }, + [&](const ExternalCommitParams& /* unused */) { + return valid_external(proposals); + }, + [&](const RestartCommitParams& params) { + return valid_restart(proposals, params.allowed_usage); + }, + [&](const ReInitCommitParams& /* unused */) { + return valid_reinit(proposals); + }, + }; + + return var::visit(specifically, params); +} + +bool +// NB(RLB): clang-tidy thinks this can be static, but it can't. +// NOLINTNEXTLINE(readability-convert-member-functions-to-static) +State::valid_normal(const std::vector& proposals, + LeafIndex commit_sender) const +{ + // It contains an individual proposal that is invalid as specified in Section + // 12.1. + const auto has_invalid_proposal = + stdx::any_of(proposals, [&](const auto& cached) { + return !valid(cached.sender, cached.proposal); + }); + + // It contains an Update proposal generated by the committer. + const auto has_self_update = stdx::any_of(proposals, [&](const auto& cached) { + return cached.proposal.proposal_type() == ProposalType::update && + cached.sender == commit_sender; + }); + + // It contains a Remove proposal that removes the committer. + const auto has_self_remove = stdx::any_of(proposals, [&](const auto& cached) { + return cached.proposal.proposal_type() == ProposalType::remove && + var::get(cached.proposal.content).removed == commit_sender; + }); + + // It contains multiple Update and/or Remove proposals that apply to the same + // leaf. If the committer has received multiple such proposals they SHOULD + // prefer any Remove received, or the most recent Update if there are no + // Removes. + auto updated_or_removed = std::set{}; + const auto has_dup_update_remove = + stdx::any_of(proposals, [&](const auto& cached) { + auto index = LeafIndex{ 0 }; + switch (cached.proposal.proposal_type()) { + case ProposalType::update: + index = opt::get(cached.sender); + break; + + case ProposalType::remove: + index = var::get(cached.proposal.content).removed; + break; + + default: + return false; + } + + if (stdx::contains(updated_or_removed, index)) { + return true; + } + + updated_or_removed.insert(index); + return false; + }); + + // It contains multiple Add proposals that contain KeyPackages that represent + // the same client according to the application (for example, identical + // signature keys). + auto signature_keys = std::vector{}; + const auto has_dup_signature_key = + stdx::any_of(proposals, [&](const auto& cached) { + if (cached.proposal.proposal_type() != ProposalType::add) { + return false; + } + + auto key_package = var::get(cached.proposal.content).key_package; + auto signature_key = key_package.leaf_node.signature_key; + if (stdx::contains(signature_keys, signature_key)) { + return true; + } + + signature_keys.push_back(signature_key); + return false; + }); + + // It contains an Add proposal with a KeyPackage that represents a client + // already in the group according to the application, unless there is a Remove + // proposal in the list removing the matching client from the group. + // TODO(RLB) + + // It contains multiple PreSharedKey proposals that reference the same + // PreSharedKeyID. + auto psk_ids = std::vector{}; + const auto has_dup_psk_id = stdx::any_of(proposals, [&](const auto& cached) { + if (cached.proposal.proposal_type() != ProposalType::psk) { + return false; + } + + auto psk_id = var::get(cached.proposal.content).psk; + if (stdx::contains(psk_ids, psk_id)) { + return true; + } + + psk_ids.push_back(psk_id); + return false; + }); + + // It contains multiple GroupContextExtensions proposals. + const auto gce_count = stdx::count_if(proposals, [](const auto& cached) { + return cached.proposal.proposal_type() == + ProposalType::group_context_extensions; + }); + const auto has_multiple_gce = (gce_count > 1); + + // It contains a ReInit proposal together with any other proposal. If the + // committer has received other proposals during the epoch, they SHOULD prefer + // them over the ReInit proposal, allowing the ReInit to be resent and applied + // in a subsequent epoch. + const auto has_reinit = stdx::any_of(proposals, [](const auto& cached) { + return cached.proposal.proposal_type() == ProposalType::reinit; + }); + + // It contains an ExternalInit proposal. + const auto has_external_init = + stdx::any_of(proposals, [](const auto& cached) { + return cached.proposal.proposal_type() == ProposalType::external_init; + }); + + // It contains a proposal with a non-default proposal type that is not + // supported by some members of the group that will process the Commit (i.e., + // members being added or removed by the Commit do not need to support the + // proposal type). + // XXX(RLB): N/A, no non-default proposal types + + // After processing the commit the ratchet tree is invalid, in particular, if + // it contains any leaf node that is invalid according to Section 7.3. + // + // NB(RLB): Leaf nodes are already checked in the individual proposal check at + // the top. So the focus here is key uniqueness. We check this by checking + // uniqueness of encryption keys across the Adds and Updates in this list of + // proposals. The keys have already been checked to be distinct from any keys + // already in the tree. + auto enc_keys = std::vector{}; + const auto has_dup_enc_key = stdx::any_of(proposals, [&](const auto& cached) { + const auto get_enc_key = + overloaded{ [](const Add& add) -> std::optional { + return add.key_package.leaf_node.encryption_key; + }, + [](const Update& update) -> std::optional { + return update.leaf_node.encryption_key; + }, + + [](const auto& /* default */) + -> std::optional { return std::nullopt; } }; + auto maybe_enc_key = var::visit(get_enc_key, cached.proposal.content); + if (!maybe_enc_key) { + return false; + } + + const auto& enc_key = opt::get(maybe_enc_key); + if (stdx::contains(enc_keys, enc_key)) { + return true; + } + + enc_keys.push_back(enc_key); + return false; + }); + + return !(has_invalid_proposal || has_self_update || has_self_remove || + has_dup_update_remove || has_dup_signature_key || has_dup_psk_id || + has_multiple_gce || has_reinit || has_external_init || + has_dup_enc_key); +} + +bool +State::valid_reinit(const std::vector& proposals) +{ + // Check that the list contains a ReInit proposal + const auto has_reinit = stdx::any_of(proposals, [](const auto& cached) { + return cached.proposal.proposal_type() == ProposalType::reinit; + }); + + // Check whether the list contains any disallowed proposals + const auto has_disallowed = stdx::any_of(proposals, [](const auto& cached) { + return cached.proposal.proposal_type() != ProposalType::reinit; + }); + + return has_reinit && !has_disallowed; +} + +bool +State::valid_restart(const std::vector& proposals, + ResumptionPSKUsage allowed_usage) +{ + // Check that the list has exactly one resumption PSK proposal with the + // allowed usage and any other PSKs are external + auto found_allowed = false; + const auto acceptable_psks = stdx::all_of(proposals, [&](const auto& cached) { + if (cached.proposal.proposal_type() != ProposalType::psk) { + return true; + } + + const auto& psk = var::get(cached.proposal.content); + if (var::holds_alternative(psk.psk.content)) { + return true; + } + + const auto& res_psk = var::get(psk.psk.content); + const auto allowed = res_psk.usage == allowed_usage; + if (found_allowed && allowed) { + return false; + } + + found_allowed = found_allowed || allowed; + return true; + }); + + return acceptable_psks && found_allowed; +} + +bool +State::valid_external_proposal_type(const Proposal::Type proposal_type) +{ + switch (proposal_type) { + case ProposalType::add: + case ProposalType::remove: + case ProposalType::psk: + case ProposalType::reinit: + case ProposalType::group_context_extensions: + return true; + + default: + return false; + } +} + +bool +// NOLINTNEXTLINE(readability-convert-member-functions-to-static) +State::valid_external(const std::vector& proposals) const +{ + // Exactly one ExternalInit + auto ext_init_count = stdx::count_if(proposals, [](const auto& cached) { + return cached.proposal.proposal_type() == ProposalType::external_init; + }); + auto one_ext_init = (ext_init_count == 1); + + // At most one Remove proposal, with which the joiner removes an old version + // of themselves. If a Remove proposal is present, then the LeafNode in the + // path field of the external commit MUST meet the same criteria as would the + // LeafNode in an Update for the removed leaf (see Section 12.1.2). In + // particular, the credential in the LeafNode MUST present a set of + // identifiers that is acceptable to the application for the removed + // participant. + // TODO(RLB) Verify that Remove is properly formed + auto remove_count = stdx::count_if(proposals, [](const auto& cached) { + return cached.proposal.proposal_type() == ProposalType::remove; + }); + auto no_dup_remove = (remove_count <= 1); + + // Zero or more PreSharedKey proposals. + // No other proposals. + auto no_disallowed = stdx::all_of(proposals, [&](const auto& cached) { + switch (cached.proposal.proposal_type()) { + case ProposalType::external_init: + case ProposalType::remove: + return true; + + case ProposalType::psk: + return valid(var::get(cached.proposal.content)); + + default: + return false; + } + }); + + return one_ext_init && no_dup_remove && no_disallowed; +} + +State::CommitParams +State::infer_commit_type( + const std::optional& sender, + const std::vector& proposals, + const std::optional& expected_params) const +{ + // If an expected type was provided, validate against it + if (expected_params) { + const auto& expected = opt::get(expected_params); + + auto specifically = overloaded{ + [&](const NormalCommitParams& /* unused */) { + return sender && valid_normal(proposals, opt::get(sender)); + }, + [&](const ExternalCommitParams& /* unused */) { + return !sender && valid_external(proposals); + }, + [&](const RestartCommitParams& params) { + return sender && valid_restart(proposals, params.allowed_usage); + }, + [&](const ReInitCommitParams& /* unused */) { + return sender && valid_reinit(proposals); + }, + }; + + if (!var::visit(specifically, expected)) { + throw ProtocolError("Invalid proposal list"); + } + + return expected; + } + + // Otherwise, check to see if this is a valid external or normal commit + if (!sender && valid_external(proposals)) { + return ExternalCommitParams{}; + } + + if (sender && valid_normal(proposals, opt::get(sender))) { + return NormalCommitParams{}; + } + + throw ProtocolError("Invalid proposal list"); +} + +bool +State::path_required(const std::vector& proposals) +{ + static const auto path_required_types = std::set{ + ProposalType::update, + ProposalType::remove, + ProposalType::external_init, + ProposalType::group_context_extensions, + }; + + if (proposals.empty()) { + return true; + } + + return stdx::any_of(proposals, [](const auto& cp) { + return path_required_types.count(cp.proposal.proposal_type()) != 0; + }); +} + +/// +/// Inner logic and convenience functions +/// + +bool +operator==(const State& lhs, const State& rhs) +{ + auto suite = (lhs._suite == rhs._suite); + auto group_id = (lhs._group_id == rhs._group_id); + auto epoch = (lhs._epoch == rhs._epoch); + auto tree = (lhs._tree == rhs._tree); + auto transcript_hash = (lhs._transcript_hash == rhs._transcript_hash); + auto key_schedule = (lhs._key_schedule == rhs._key_schedule); + auto extensions = (lhs._extensions == rhs._extensions); + + return suite && group_id && epoch && tree && transcript_hash && + key_schedule && extensions; +} + +bool +operator!=(const State& lhs, const State& rhs) +{ + return !(lhs == rhs); +} + +void +State::update_epoch_secrets(const bytes& commit_secret, + const std::vector& psks, + const std::optional& force_init_secret) +{ + auto ctx = tls::marshal(GroupContext{ + _suite, + _group_id, + _epoch, + _tree.root_hash(), + _transcript_hash.confirmed, + _extensions, + }); + _key_schedule = + _key_schedule.next(commit_secret, psks, force_init_secret, ctx); + _keys = _key_schedule.encryption_keys(_tree.size); +} + +/// +/// Message encryption and decryption +/// +bool +State::verify_internal(const AuthenticatedContent& content_auth) const +{ + const auto& sender = + var::get(content_auth.content.sender.sender).sender; + auto maybe_leaf = _tree.leaf_node(sender); + if (!maybe_leaf) { + throw InvalidParameterError("Signature from blank node"); + } + + const auto& pub = opt::get(maybe_leaf).signature_key; + return content_auth.verify(_suite, pub, group_context()); +} + +bool +State::verify_external(const AuthenticatedContent& content_auth) const +{ + const auto& ext_sender = + var::get(content_auth.content.sender.sender); + const auto senders_ext = _extensions.find(); + const auto& senders = opt::get(senders_ext).senders; + const auto& pub = senders.at(ext_sender.sender_index).signature_key; + return content_auth.verify(_suite, pub, group_context()); +} + +bool +State::verify_new_member_proposal( + const AuthenticatedContent& content_auth) const +{ + const auto& proposal = var::get(content_auth.content.content); + const auto& add = var::get(proposal.content); + const auto& pub = add.key_package.leaf_node.signature_key; + return content_auth.verify(_suite, pub, group_context()); +} + +bool +State::verify_new_member_commit(const AuthenticatedContent& content_auth) const +{ + const auto& commit = var::get(content_auth.content.content); + const auto& path = opt::get(commit.path); + const auto& pub = path.leaf_node.signature_key; + return content_auth.verify(_suite, pub, group_context()); +} + +bool +State::verify(const AuthenticatedContent& content_auth) const +{ + switch (content_auth.content.sender.sender_type()) { + case SenderType::member: + return verify_internal(content_auth); + + case SenderType::external: + return verify_external(content_auth); + + case SenderType::new_member_proposal: + return verify_new_member_proposal(content_auth); + + case SenderType::new_member_commit: + return verify_new_member_commit(content_auth); + + default: + throw ProtocolError("Invalid sender type"); + } +} + +void +State::add_resumption_psk(const bytes& group_id, epoch_t epoch, bytes secret) +{ + _resumption_psks.insert_or_assign({ group_id, epoch }, std::move(secret)); +} + +void +State::remove_resumption_psk(const bytes& group_id, epoch_t epoch) +{ + _resumption_psks.erase({ group_id, epoch }); +} + +void +State::add_external_psk(const bytes& id, const bytes& secret) +{ + _external_psks.insert_or_assign(id, secret); +} + +void +State::remove_external_psk(const bytes& id) +{ + _external_psks.erase(id); +} + +bytes +State::do_export(const std::string& label, + const bytes& context, + size_t size) const +{ + return _key_schedule.do_export(label, context, size); +} + +GroupInfo +State::group_info(bool inline_tree) const +{ + auto group_info = GroupInfo{ + { + _suite, + _group_id, + _epoch, + _tree.root_hash(), + _transcript_hash.confirmed, + _extensions, + }, + { /* No other extensions */ }, + _key_schedule.confirmation_tag(_transcript_hash.confirmed), + }; + + group_info.extensions.add( + ExternalPubExtension{ _key_schedule.external_priv.public_key }); + + if (inline_tree) { + group_info.extensions.add(RatchetTreeExtension{ _tree }); + } + + group_info.sign(_tree, _index, _identity_priv); + return group_info; +} + +std::vector +State::roster() const +{ + auto leaves = std::vector{}; + leaves.reserve(_tree.size.val); + + _tree.all_leaves([&](auto /* i */, auto leaf) { + leaves.push_back(leaf); + return true; + }); + + return leaves; +} + +bytes +State::epoch_authenticator() const +{ + return _key_schedule.epoch_authenticator; +} + +LeafIndex +State::leaf_for_roster_entry(RosterIndex index) const +{ + auto visited = RosterIndex{ 0 }; + auto found = std::optional{}; + _tree.all_leaves([&](auto i, const auto& /* leaf_node */) { + if (visited == index) { + found = i; + return false; + } + + visited.val += 1; + return true; + }); + + return opt::get(found); +} + +State +State::successor() const +{ + // Copy everything, then clear things that shouldn't be copied + auto next = *this; + next._pending_proposals.clear(); + + // Copy forward a resumption PSK + next.add_resumption_psk(_group_id, _epoch, _key_schedule.resumption_psk); + + return next; +} + +} // namespace mlspp diff --git a/mlspp/src/tree_math.cpp b/mlspp/src/tree_math.cpp new file mode 100755 index 0000000000..eef9423889 --- /dev/null +++ b/mlspp/src/tree_math.cpp @@ -0,0 +1,223 @@ +#include "mls/tree_math.h" +#include "mls/common.h" + +#include + +static const uint32_t one = 0x01; + +static uint32_t +log2(uint32_t x) +{ + if (x == 0) { + return 0; + } + + uint32_t k = 0; + while ((x >> k) > 0) { + k += 1; + } + return k - 1; +} + +namespace mlspp { + +LeafCount::LeafCount(const NodeCount w) +{ + if (w.val == 0) { + val = 0; + return; + } + + if ((w.val & one) == 0) { + throw InvalidParameterError("Only odd node counts describe trees"); + } + + val = (w.val >> one) + 1; +} + +LeafCount +LeafCount::full(const LeafCount n) +{ + auto w = uint32_t(1); + while (w < n.val) { + w <<= 1U; + } + return LeafCount{ w }; +} + +NodeCount::NodeCount(const LeafCount n) + : UInt32(2 * (n.val - 1) + 1) +{ +} + +LeafIndex::LeafIndex(NodeIndex x) + : UInt32(0) +{ + if (x.val % 2 == 1) { + throw InvalidParameterError("Only even node indices describe leaves"); + } + + val = x.val >> 1; // NOLINT(hicpp-signed-bitwise) +} + +NodeIndex +LeafIndex::ancestor(LeafIndex other) const +{ + auto ln = NodeIndex(*this); + auto rn = NodeIndex(other); + if (ln == rn) { + return ln; + } + + uint8_t k = 0; + while (ln != rn) { + ln.val = ln.val >> 1U; + rn.val = rn.val >> 1U; + k += 1; + } + + const uint32_t prefix = ln.val << k; + const uint32_t stop = (1U << uint8_t(k - 1)); + return NodeIndex{ prefix + (stop - 1) }; +} + +NodeIndex::NodeIndex(LeafIndex x) + : UInt32(2 * x.val) +{ +} + +NodeIndex +NodeIndex::root(LeafCount n) +{ + if (n.val == 0) { + throw std::runtime_error("Root for zero-size tree is undefined"); + } + + auto w = NodeCount(n); + return NodeIndex{ (one << log2(w.val)) - 1 }; +} + +bool +NodeIndex::is_leaf() const +{ + return val % 2 == 0; +} + +bool +NodeIndex::is_below(NodeIndex other) const +{ + auto lx = level(); + auto ly = other.level(); + return lx <= ly && (val >> (ly + 1) == other.val >> (ly + 1)); +} + +NodeIndex +NodeIndex::left() const +{ + if (is_leaf()) { + return *this; + } + + // The clang analyzer doesn't realize that is_leaf() assures that level >= 1 + // NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult) + return NodeIndex{ val ^ (one << (level() - 1)) }; +} + +NodeIndex +NodeIndex::right() const +{ + if (is_leaf()) { + return *this; + } + + return NodeIndex{ val ^ (uint32_t(0x03) << (level() - 1)) }; +} + +NodeIndex +NodeIndex::parent() const +{ + auto k = level(); + return NodeIndex{ (val | (one << k)) & ~(one << (k + 1)) }; +} + +NodeIndex +NodeIndex::sibling() const +{ + return sibling(parent()); +} + +NodeIndex +NodeIndex::sibling(NodeIndex ancestor) const +{ + if (!is_below(ancestor)) { + throw InvalidParameterError("Node is not below claimed ancestor"); + } + + auto l = ancestor.left(); + auto r = ancestor.right(); + + if (is_below(l)) { + return r; + } + + return l; +} + +std::vector +NodeIndex::dirpath(LeafCount n) +{ + if (val >= NodeCount(n).val) { + throw InvalidParameterError("Request for dirpath outside of tree"); + } + + auto d = std::vector{}; + + auto r = root(n); + if (*this == r) { + return d; + } + + auto p = parent(); + while (p.val != r.val) { + d.push_back(p); + p = p.parent(); + } + + // Include the root except in a one-member tree + if (val != r.val) { + d.push_back(p); + } + + return d; +} + +std::vector +NodeIndex::copath(LeafCount n) +{ + auto d = dirpath(n); + if (d.empty()) { + return {}; + } + + // Prepend leaf; omit root + d.insert(d.begin(), *this); + d.pop_back(); + + return stdx::transform(d, [](auto x) { return x.sibling(); }); +} + +uint32_t +NodeIndex::level() const +{ + if ((val & one) == 0) { + return 0; + } + + uint32_t k = 0; + while (((val >> k) & one) == 1) { + k += 1; + } + return k; +} + +} // namespace mlspp diff --git a/mlspp/src/treekem.cpp b/mlspp/src/treekem.cpp new file mode 100755 index 0000000000..fe3aebc2a2 --- /dev/null +++ b/mlspp/src/treekem.cpp @@ -0,0 +1,1127 @@ +#include + +#if ENABLE_TREE_DUMP +#include +#endif + +namespace mlspp { + +// Utility method used for removing leaves from a resolution +static void +remove_leaves(std::vector& res, const std::vector& except) +{ + for (const auto& leaf : except) { + auto it = std::find(res.begin(), res.end(), NodeIndex(leaf)); + if (it == res.end()) { + continue; + } + + res.erase(it); + } +} + +/// +/// Node +/// + +const HPKEPublicKey& +Node::public_key() const +{ + const auto get_key = overloaded{ + [](const LeafNode& n) -> const HPKEPublicKey& { return n.encryption_key; }, + [](const ParentNode& n) -> const HPKEPublicKey& { return n.public_key; }, + }; + return var::visit(get_key, node); +} + +std::optional +Node::parent_hash() const +{ + const auto get_leaf_ph = overloaded{ + [](const ParentHash& ph) -> std::optional { return ph.parent_hash; }, + [](const auto& /* other */) -> std::optional { + return std::nullopt; + }, + }; + + const auto get_ph = overloaded{ + [&](const LeafNode& node) -> std::optional { + return var::visit(get_leaf_ph, node.content); + }, + [](const ParentNode& node) -> std::optional { + return node.parent_hash; + }, + }; + + return var::visit(get_ph, node); +} + +/// +/// TreeKEMPrivateKey +/// + +TreeKEMPrivateKey +TreeKEMPrivateKey::solo(CipherSuite suite, + LeafIndex index, + HPKEPrivateKey leaf_priv) +{ + auto priv = TreeKEMPrivateKey{ suite, index, {}, {}, {} }; + priv.private_key_cache.insert({ NodeIndex(index), std::move(leaf_priv) }); + return priv; +} + +TreeKEMPrivateKey +TreeKEMPrivateKey::create(const TreeKEMPublicKey& pub, + LeafIndex from, + const bytes& leaf_secret) +{ + auto priv = TreeKEMPrivateKey{ pub.suite, from, {}, {}, {} }; + priv.implant(pub, NodeIndex(from), leaf_secret); + return priv; +} + +TreeKEMPrivateKey +TreeKEMPrivateKey::joiner(const TreeKEMPublicKey& pub, + LeafIndex index, + HPKEPrivateKey leaf_priv, + NodeIndex intersect, + const std::optional& path_secret) +{ + auto priv = TreeKEMPrivateKey{ pub.suite, index, {}, {}, {} }; + priv.private_key_cache.insert({ NodeIndex(index), std::move(leaf_priv) }); + if (path_secret) { + priv.implant(pub, intersect, opt::get(path_secret)); + } + return priv; +} + +void +TreeKEMPrivateKey::implant(const TreeKEMPublicKey& pub, + NodeIndex start, + const bytes& path_secret) +{ + const auto fdp = pub.filtered_direct_path(start); + auto secret = path_secret; + + path_secrets.insert_or_assign(start, secret); + private_key_cache.erase(start); + + for (const auto& [n, _res] : fdp) { + secret = pub.suite.derive_secret(secret, "path"); + path_secrets.insert_or_assign(n, secret); + private_key_cache.erase(n); + } + + update_secret = pub.suite.derive_secret(secret, "path"); +} + +std::optional +TreeKEMPrivateKey::private_key(NodeIndex n) const +{ + auto pki = private_key_cache.find(n); + if (pki != private_key_cache.end()) { + return pki->second; + } + + auto i = path_secrets.find(n); + if (i == path_secrets.end()) { + return std::nullopt; + } + + auto node_secret = suite.derive_secret(i->second, "node"); + return HPKEPrivateKey::derive(suite, node_secret); +} + +bool +TreeKEMPrivateKey::have_private_key(NodeIndex n) const +{ + auto path_secret = path_secrets.find(n) != path_secrets.end(); + auto cached_priv = private_key_cache.find(n) != private_key_cache.end(); + return path_secret || cached_priv; +} + +std::optional +TreeKEMPrivateKey::private_key(NodeIndex n) +{ + auto priv = static_cast(*this).private_key(n); + if (priv) { + private_key_cache.insert_or_assign(n, opt::get(priv)); + } + return priv; +} + +void +TreeKEMPrivateKey::set_leaf_priv(HPKEPrivateKey priv) +{ + auto n = NodeIndex(index); + path_secrets.erase(n); + private_key_cache.insert_or_assign(n, std::move(priv)); +} + +std::tuple +TreeKEMPrivateKey::shared_path_secret(LeafIndex to) const +{ + auto n = index.ancestor(to); + auto i = path_secrets.find(n); + if (i == path_secrets.end()) { + return std::make_tuple(n, bytes{}, false); + } + + return std::make_tuple(n, i->second, true); +} + +#if ENABLE_TREE_DUMP +// XXX(RLB) This should ultimately be deleted, but it is handy for interop +// debugging, so I'm keeping it around for now. If re-enabled, you'll also need +// to add the appropriate declarations to treekem.h and include + +void +TreeKEMPrivateKey::dump() const +{ + for (const auto& [node, _] : path_secrets) { + private_key(node); + } + + std::cout << "Tree (priv):" << std::endl; + std::cout << " Index: " << NodeIndex(index).val << std::endl; + + std::cout << " Secrets: " << std::endl; + for (const auto& [n, path_secret] : path_secrets) { + auto node_secret = suite.derive_secret(path_secret, "node"); + auto sk = HPKEPrivateKey::derive(suite, node_secret); + + auto psm = to_hex(path_secret).substr(0, 8); + auto pkm = to_hex(sk.public_key.data).substr(0, 8); + std::cout << " " << n.val << " => " << psm << " => " << pkm << std::endl; + } + + std::cout << " Cached key pairs: " << std::endl; + for (const auto& [n, sk] : private_key_cache) { + auto pkm = to_hex(sk.public_key.data).substr(0, 8); + std::cout << " " << n.val << " => " << pkm << std::endl; + } +} + +void +TreeKEMPublicKey::dump() const +{ + std::cout << "Tree:" << std::endl; + auto width = NodeCount(size); + for (auto i = NodeIndex{ 0 }; i.val < width.val; i.val++) { + printf(" %03d : ", i.val); // NOLINT + if (!node_at(i).blank()) { + auto pkRm = to_hex(opt::get(node_at(i).node).public_key().data); + std::cout << pkRm.substr(0, 8); + } else { + std::cout << " "; + } + + std::cout << " | "; + for (uint32_t j = 0; j < i.level(); j++) { + std::cout << " "; + } + + if (!node_at(i).blank()) { + std::cout << "X"; + + if (!i.is_leaf()) { + auto parent = node_at(i).parent_node(); + std::cout << " ["; + for (const auto u : parent.unmerged_leaves) { + std::cout << u.val << ", "; + } + std::cout << "]"; + } + + } else { + std::cout << "_"; + } + + std::cout << std::endl; + } +} +#endif + +void +TreeKEMPrivateKey::decap(LeafIndex from, + const TreeKEMPublicKey& pub, + const bytes& context, + const UpdatePath& path, + const std::vector& except) +{ + // Identify which node in the path secret we will be decrypting + auto ni = NodeIndex(index); + auto dp = pub.filtered_direct_path(NodeIndex(from)); + if (dp.size() != path.nodes.size()) { + throw ProtocolError("Malformed direct path"); + } + + size_t dpi = 0; + auto overlap_node = NodeIndex{}; + auto res = std::vector{}; + for (dpi = 0; dpi < dp.size(); dpi++) { + const auto [dpn, dpres] = dp[dpi]; + if (ni.is_below(dpn)) { + overlap_node = dpn; + res = dpres; + break; + } + } + + if (dpi == dp.size()) { + throw ProtocolError("No overlap in path"); + } + + // Identify which node in the resolution of the copath we will use to decrypt + remove_leaves(res, except); + if (res.size() != path.nodes[dpi].encrypted_path_secret.size()) { + throw ProtocolError("Malformed direct path node"); + } + + size_t resi = 0; + const NodeIndex res_overlap_node; + for (resi = 0; resi < res.size(); resi++) { + if (have_private_key(res[resi])) { + break; + } + } + + if (resi == res.size()) { + throw ProtocolError("No private key to decrypt path secret"); + } + + // Decrypt and implant + auto priv = opt::get(private_key(res[resi])); + auto path_secret = priv.decrypt(suite, + encrypt_label::update_path_node, + context, + path.nodes[dpi].encrypted_path_secret[resi]); + implant(pub, overlap_node, path_secret); + + // Check that the resulting state is consistent with the public key + if (!consistent(pub)) { + throw ProtocolError("TreeKEMPublicKey inconsistent with TreeKEMPrivateKey"); + } +} + +void +TreeKEMPrivateKey::truncate(LeafCount size) +{ + auto ni = NodeIndex(LeafIndex{ size.val - 1 }); + auto to_remove = std::vector{}; + for (const auto& entry : path_secrets) { + if (entry.first.val > ni.val) { + to_remove.push_back(entry.first); + } + } + + for (auto n : to_remove) { + path_secrets.erase(n); + private_key_cache.erase(n); + } +} + +bool +TreeKEMPrivateKey::consistent(const TreeKEMPrivateKey& other) const +{ + if (suite != other.suite) { + return false; + } + + if (update_secret != other.update_secret) { + return false; + } + + const auto match_if_present = [&](const auto& entry) { + auto other_entry = other.path_secrets.find(entry.first); + if (other_entry == other.path_secrets.end()) { + return true; + } + + return entry.second == other_entry->second; + }; + return stdx::all_of(path_secrets, match_if_present); +} + +bool +TreeKEMPrivateKey::consistent(const TreeKEMPublicKey& other) const +{ + if (suite != other.suite) { + return false; + } + + for (const auto& [node, _] : path_secrets) { + private_key(node); + } + + return stdx::all_of(private_key_cache, [other](const auto& entry) { + const auto& [node, priv] = entry; + const auto& opt_node = other.node_at(node).node; + if (!opt_node) { + // It's OK for a TreeKEMPrivateKey to have private keys + // for nodes that are blank in the TreeKEMPublicKey. + // This will happen traniently during Commit + // processing, since proposals will be applied in the + // public tree and not in the private tree. + return true; + } + + const auto& pub = opt::get(opt_node).public_key(); + return priv.public_key == pub; + }); +} + +/// +/// TreeKEMPublicKey +/// + +TreeKEMPublicKey::TreeKEMPublicKey(CipherSuite suite_in) + : suite(suite_in) +{ +} + +LeafIndex +TreeKEMPublicKey::allocate_leaf() +{ + // Find the leftmost blank leaf node + auto index = LeafIndex(0); + while (index.val < size.val && !node_at(index).blank()) { + index.val++; + } + + // Extend the tree if necessary + if (index.val >= size.val) { + if (size.val == 0) { + size.val = 1; + nodes.resize(1); + } else { + size.val *= 2; + nodes.resize(2 * nodes.size() + 1); + } + } + + return index; +} + +LeafIndex +TreeKEMPublicKey::add_leaf(const LeafNode& leaf) +{ + // Check that the leaf node's keys are not already present in the tree + if (exists_in_tree(leaf.encryption_key, std::nullopt)) { + throw InvalidParameterError("Duplicate encryption key"); + } + + if (exists_in_tree(leaf.signature_key, std::nullopt)) { + throw InvalidParameterError("Duplicate signature key"); + } + + // Allocate a blank leaf for this node + const auto index = allocate_leaf(); + + // Set the leaf + node_at(index).node = Node{ leaf }; + + // Update the unmerged list + for (auto& n : NodeIndex(index).dirpath(size)) { + if (!node_at(n).node) { + continue; + } + + auto& parent = var::get(opt::get(node_at(n).node).node); + + // Insert into unmerged leaves while maintaining order + const auto insert_point = stdx::upper_bound(parent.unmerged_leaves, index); + parent.unmerged_leaves.insert(insert_point, index); + } + + clear_hash_path(index); + return index; +} + +void +TreeKEMPublicKey::update_leaf(LeafIndex index, const LeafNode& leaf) +{ + // Check that the leaf node's keys are not already present in the tree, except + // for the signature key, which is allowed to repeat. + if (exists_in_tree(leaf.encryption_key, std::nullopt)) { + throw InvalidParameterError("Duplicate encryption key"); + } + + if (exists_in_tree(leaf.signature_key, index)) { + throw InvalidParameterError("Duplicate signature key"); + } + + blank_path(index); + node_at(NodeIndex(index)).node = Node{ leaf }; + clear_hash_path(index); +} + +void +TreeKEMPublicKey::blank_path(LeafIndex index) +{ + if (nodes.empty()) { + return; + } + + auto ni = NodeIndex(index); + node_at(ni).node.reset(); + for (auto n : ni.dirpath(size)) { + node_at(n).node.reset(); + } + + clear_hash_path(index); +} + +void +TreeKEMPublicKey::merge(LeafIndex from, const UpdatePath& path) +{ + update_leaf(from, path.leaf_node); + + auto dp = filtered_direct_path(NodeIndex(from)); + if (dp.size() != path.nodes.size()) { + throw ProtocolError("Malformed direct path"); + } + + auto ph = parent_hashes(from, dp, path.nodes); + for (size_t i = 0; i < dp.size(); i++) { + auto [n, _res] = dp[i]; + + auto parent_hash = bytes{}; + if (i < dp.size() - 1) { + parent_hash = ph[i + 1]; + } + + node_at(n).node = + Node{ ParentNode{ path.nodes[i].public_key, parent_hash, {} } }; + } + + set_hash_all(); +} + +void +TreeKEMPublicKey::set_hash_all() +{ + auto r = NodeIndex::root(size); + get_hash(r); +} + +bytes +TreeKEMPublicKey::root_hash() const +{ + auto r = NodeIndex::root(size); + if (hashes.count(r) == 0) { + throw InvalidParameterError("Root hash not set"); + } + + return hashes.at(r); +} + +bool +TreeKEMPublicKey::has_parent_hash(NodeIndex child, const bytes& target_ph) const +{ + const auto res = resolve(child); + return stdx::any_of(res, [&](auto nr) { + return opt::get(node_at(nr).node).parent_hash() == target_ph; + }); +} + +bool +TreeKEMPublicKey::parent_hash_valid() const +{ + auto cache = TreeHashCache{}; + + auto width = NodeCount(size); + auto height = NodeIndex::root(size).level(); + for (auto level = uint32_t(1); level <= height; level++) { + auto stride = uint32_t(2) << level; + auto start = NodeIndex{ (stride >> 1U) - 1 }; + + for (auto p = start; p.val < width.val; p.val += stride) { + if (node_at(p).blank()) { + continue; + } + + auto l = p.left(); + auto r = p.right(); + + auto lh = original_parent_hash(cache, p, r); + auto rh = original_parent_hash(cache, p, l); + + if (!has_parent_hash(l, lh) && !has_parent_hash(r, rh)) { + dump(); + return false; + } + } + } + return true; +} + +std::vector +TreeKEMPublicKey::resolve(NodeIndex index) const +{ + auto at_leaf = (index.level() == 0); + if (!node_at(index).blank()) { + auto out = std::vector{ index }; + if (index.is_leaf()) { + return out; + } + + const auto& node = node_at(index); + auto unmerged = + stdx::transform(node.parent_node().unmerged_leaves, + [](LeafIndex x) { return NodeIndex(x); }); + + out.insert(out.end(), unmerged.begin(), unmerged.end()); + return out; + } + + if (at_leaf) { + return {}; + } + + auto l = resolve(index.left()); + auto r = resolve(index.right()); + l.insert(l.end(), r.begin(), r.end()); + return l; +} + +TreeKEMPublicKey::FilteredDirectPath +TreeKEMPublicKey::filtered_direct_path(NodeIndex index) const +{ + auto fdp = FilteredDirectPath{}; + + const auto cp = index.copath(size); + auto last = index; + for (auto n : cp) { + const auto p = n.parent(); + const auto res = resolve(n); + last = p; + if (res.empty()) { + continue; + } + + fdp.emplace_back(p, res); + } + + return fdp; +} + +bool +TreeKEMPublicKey::has_leaf(LeafIndex index) const +{ + return !node_at(index).blank(); +} + +std::optional +TreeKEMPublicKey::find(const LeafNode& leaf) const +{ + for (LeafIndex i{ 0 }; i < size; i.val++) { + const auto& node = node_at(i); + if (!node.blank() && node.leaf_node() == leaf) { + return i; + } + } + + return std::nullopt; +} + +std::optional +TreeKEMPublicKey::leaf_node(LeafIndex index) const +{ + const auto& node = node_at(index); + if (node.blank()) { + return std::nullopt; + } + + return node.leaf_node(); +} + +TreeKEMPrivateKey +TreeKEMPublicKey::update(LeafIndex from, + const bytes& leaf_secret, + const bytes& group_id, + const SignaturePrivateKey& sig_priv, + const LeafNodeOptions& opts) +{ + // Grab information about the sender + const auto& leaf_node = node_at(from); + if (leaf_node.blank()) { + throw InvalidParameterError("Cannot update from blank node"); + } + + // Generate path secrets + auto priv = TreeKEMPrivateKey::create(*this, from, leaf_secret); + auto dp = filtered_direct_path(NodeIndex(from)); + + // Encrypt path secrets to the copath, forming a stub UpdatePath with no + // encryptions + auto path_nodes = stdx::transform(dp, [&](const auto& dpn) { + auto [n, _res] = dpn; + + auto path_secret = priv.path_secrets.at(n); + auto node_priv = opt::get(priv.private_key(n)); + + return UpdatePathNode{ node_priv.public_key, {} }; + }); + + // Update and re-sign the leaf_node + auto ph = parent_hashes(from, dp, path_nodes); + auto ph0 = bytes{}; + if (!ph.empty()) { + ph0 = ph[0]; + } + + auto leaf_pub = opt::get(priv.private_key(NodeIndex(from))).public_key; + auto new_leaf = leaf_node.leaf_node().for_commit( + suite, group_id, from, leaf_pub, ph0, opts, sig_priv); + + // Merge the changes into the tree + merge(from, UpdatePath{ std::move(new_leaf), std::move(path_nodes) }); + + return priv; +} + +UpdatePath +TreeKEMPublicKey::encap(const TreeKEMPrivateKey& priv, + const bytes& context, + const std::vector& except) const +{ + auto dp = filtered_direct_path(NodeIndex(priv.index)); + + // Encrypt path secrets to the copath + auto path_nodes = stdx::transform(dp, [&](const auto& dpn) { + // We need the copy here so that we can modify the resolution. + // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) + auto [n, res] = dpn; + remove_leaves(res, except); + + auto path_secret = priv.path_secrets.at(n); + auto node_priv = opt::get(priv.private_key(n)); + + auto ct = stdx::transform(res, [&](auto nr) { + const auto& node_pub = opt::get(node_at(nr).node).public_key(); + return node_pub.encrypt( + suite, encrypt_label::update_path_node, context, path_secret); + }); + + return UpdatePathNode{ node_priv.public_key, std::move(ct) }; + }); + + // Package everything into an UpdatePath + auto new_leaf = opt::get(leaf_node(priv.index)); + auto path = UpdatePath{ new_leaf, std::move(path_nodes) }; + + return path; +} + +void +TreeKEMPublicKey::truncate() +{ + if (size.val == 0) { + return; + } + + // Clear the parent hashes across blank leaves before truncating + auto index = LeafIndex{ size.val - 1 }; + for (; index.val > 0; index.val--) { + if (!node_at(index).blank()) { + break; + } + clear_hash_path(index); + } + + if (node_at(index).blank()) { + nodes.clear(); + return; + } + + // Remove the right subtree until the tree is of minimal size + while (size.val / 2 > index.val) { + nodes.resize(nodes.size() / 2); + size.val /= 2; + } +} + +OptionalNode& +TreeKEMPublicKey::node_at(NodeIndex n) +{ + auto width = NodeCount(size); + if (n.val >= width.val) { + throw InvalidParameterError("Node index not in tree"); + } + + if (n.val >= nodes.size()) { + return blank_node; + } + + return nodes.at(n.val); +} + +const OptionalNode& +TreeKEMPublicKey::node_at(NodeIndex n) const +{ + auto width = NodeCount(size); + if (n.val >= width.val) { + throw InvalidParameterError("Node index not in tree"); + } + + if (n.val >= nodes.size()) { + return blank_node; + } + + return nodes.at(n.val); +} + +OptionalNode& +TreeKEMPublicKey::node_at(LeafIndex n) +{ + return node_at(NodeIndex(n)); +} + +const OptionalNode& +TreeKEMPublicKey::node_at(LeafIndex n) const +{ + return node_at(NodeIndex(n)); +} + +void +TreeKEMPublicKey::clear_hash_all() +{ + hashes.clear(); +} + +void +TreeKEMPublicKey::clear_hash_path(LeafIndex index) +{ + auto dp = NodeIndex(index).dirpath(size); + hashes.erase(NodeIndex(index)); + for (auto n : dp) { + hashes.erase(n); + } +} + +struct LeafNodeHashInput +{ + LeafIndex leaf_index; + std::optional leaf_node; + TLS_SERIALIZABLE(leaf_index, leaf_node) +}; + +struct ParentNodeHashInput +{ + std::optional parent_node; + const bytes& left_hash; + const bytes& right_hash; + TLS_SERIALIZABLE(parent_node, left_hash, right_hash) +}; + +struct TreeHashInput +{ + var::variant node; + TLS_SERIALIZABLE(node); + TLS_TRAITS(tls::variant) +}; + +const bytes& +TreeKEMPublicKey::get_hash(NodeIndex index) +{ + if (hashes.count(index) > 0) { + return hashes.at(index); + } + + auto hash_input = bytes{}; + const auto& node = node_at(index); + if (index.level() == 0) { + auto input = LeafNodeHashInput{ LeafIndex(index), {} }; + if (!node.blank()) { + input.leaf_node = node.leaf_node(); + } + + hash_input = tls::marshal(TreeHashInput{ input }); + } else { + auto input = ParentNodeHashInput{ + {}, + get_hash(index.left()), + get_hash(index.right()), + }; + + if (!node.blank()) { + input.parent_node = node.parent_node(); + } + + hash_input = tls::marshal(TreeHashInput{ input }); + } + + auto hash = suite.digest().hash(hash_input); + hashes.insert_or_assign(index, hash); + return hashes.at(index); +} + +// struct { +// HPKEPublicKey encryption_key; +// opaque parent_hash; +// opaque original_sibling_tree_hash; +// } ParentHashInput; +struct ParentHashInput +{ + const HPKEPublicKey& public_key; + const bytes& parent_hash; + const bytes& original_child_resolution; + + TLS_SERIALIZABLE(public_key, parent_hash, original_child_resolution) +}; + +bytes +TreeKEMPublicKey::parent_hash(const ParentNode& parent, + NodeIndex copath_child) const +{ + if (hashes.count(copath_child) == 0) { + throw InvalidParameterError("Child hash not set"); + } + + auto hash_input = ParentHashInput{ + parent.public_key, + parent.parent_hash, + hashes.at(copath_child), + }; + + return suite.digest().hash(tls::marshal(hash_input)); +} + +std::vector +TreeKEMPublicKey::parent_hashes( + LeafIndex from, + const FilteredDirectPath& fdp, + const std::vector& path_nodes) const +{ + // An empty filtered direct path indicates a one-member tree, since there's + // nobody else there to encrypt with. In this special case, there's no + // parent hashing to be done. + if (fdp.empty()) { + return {}; + } + + // The list of nodes for whom parent hashes are computed, namely: Direct path + // excluding the last entry, including leaf + auto from_node = NodeIndex(from); + auto dp = fdp; + auto [last, _res_last] = dp.back(); + dp.pop_back(); + dp.insert(dp.begin(), { from_node, {} }); + + if (dp.size() != path_nodes.size()) { + throw ProtocolError("Malformed UpdatePath"); + } + + // Parent hash for all the parents, starting from the last entry of the + // filtered direct path + auto last_hash = bytes{}; + auto ph = std::vector(dp.size()); + for (int i = static_cast(dp.size()) - 1; i >= 0; i--) { + auto [n, _res] = dp[i]; + auto s = n.sibling(last); + + auto parent_node = ParentNode{ path_nodes[i].public_key, last_hash, {} }; + last_hash = parent_hash(parent_node, s); + ph[i] = last_hash; + + last = n; + } + + return ph; +} + +const bytes& +TreeKEMPublicKey::original_tree_hash(TreeHashCache& cache, + NodeIndex index, + std::vector parent_except) const +{ + // Scope the unmerged leaves list down to this subtree + auto except = std::vector{}; + std::copy_if(parent_except.begin(), + parent_except.end(), + std::back_inserter(except), + [&](auto i) { return NodeIndex(i).is_below(index); }); + + auto have_local_changes = !except.empty(); + + // If there are no local changes, then we can use the cached tree hash + if (!have_local_changes) { + return hashes.at(index); + } + + // If this method has been called before with the same number of excluded + // leaves (which implies the same set), then use the cached value. + if (auto it = cache.find(index); it != cache.end()) { + const auto& [key, value] = *it; + const auto& [except_size, hash] = value; + if (except_size == except.size()) { + return hash; + } + } + + // If there is no entry in either cache, recompute the value + auto hash = bytes{}; + if (index.is_leaf()) { + // A leaf node with local changes is by definition excluded from the parent + // hash. So we return the hash of an empty leaf. + auto leaf_hash_input = LeafNodeHashInput{ LeafIndex(index), std::nullopt }; + hash = suite.digest().hash(tls::marshal(TreeHashInput{ leaf_hash_input })); + } else { + // If there is no cached value, recalculate the child hashes with the + // specified `except` list, removing the `except` list from + // `unmerged_leaves`. + auto parent_hash_input = ParentNodeHashInput{ + std::nullopt, + original_tree_hash(cache, index.left(), except), + original_tree_hash(cache, index.right(), except), + }; + + if (!node_at(index).blank()) { + parent_hash_input.parent_node = node_at(index).parent_node(); + auto& unmerged_leaves = + opt::get(parent_hash_input.parent_node).unmerged_leaves; + auto end = std::remove_if( + unmerged_leaves.begin(), unmerged_leaves.end(), [&](auto leaf) { + return std::count(except.begin(), except.end(), leaf) != 0; + }); + unmerged_leaves.erase(end, unmerged_leaves.end()); + } + + hash = + suite.digest().hash(tls::marshal(TreeHashInput{ parent_hash_input })); + } + + cache.insert_or_assign(index, std::make_pair(except.size(), hash)); + return cache.at(index).second; +} + +bytes +TreeKEMPublicKey::original_parent_hash(TreeHashCache& cache, + NodeIndex parent, + NodeIndex sibling) const +{ + const auto& parent_node = node_at(parent).parent_node(); + const auto& unmerged = parent_node.unmerged_leaves; + const auto& sibling_hash = original_tree_hash(cache, sibling, unmerged); + + return suite.digest().hash(tls::marshal(ParentHashInput{ + parent_node.public_key, + parent_node.parent_hash, + sibling_hash, + })); +} + +bool +TreeKEMPublicKey::parent_hash_valid(LeafIndex from, + const UpdatePath& path) const +{ + auto fdp = filtered_direct_path(NodeIndex(from)); + auto hash_chain = parent_hashes(from, fdp, path.nodes); + auto leaf_ph = + var::visit(overloaded{ + [](const ParentHash& ph) -> std::optional { + return ph.parent_hash; + }, + [](const auto& /* other */) -> std::optional { + return std::nullopt; + }, + }, + path.leaf_node.content); + + // If there are no nodes to hash, then ParentHash MUST be omitted + if (hash_chain.empty()) { + return !leaf_ph; + } + + return leaf_ph && opt::get(leaf_ph) == hash_chain[0]; +} + +bool +TreeKEMPublicKey::exists_in_tree(const HPKEPublicKey& key, + std::optional except) const +{ + return any_leaf([&](auto i, const auto& node) { + return i != except && node.encryption_key == key; + }); +} + +bool +TreeKEMPublicKey::exists_in_tree(const SignaturePublicKey& key, + std::optional except) const +{ + return any_leaf([&](auto i, const auto& node) { + return i != except && node.signature_key == key; + }); +} + +tls::ostream& +operator<<(tls::ostream& str, const TreeKEMPublicKey& obj) +{ + // Empty tree + if (obj.size.val == 0) { + return str << std::vector{}; + } + + LeafIndex cut = LeafIndex{ obj.size.val - 1 }; + while (cut.val > 0 && obj.node_at(cut).blank()) { + cut.val -= 1; + } + + const auto begin = obj.nodes.begin(); + const auto end = begin + NodeIndex(cut).val + 1; + const auto view = std::vector(begin, end); + return str << view; +} + +tls::istream& +operator>>(tls::istream& str, TreeKEMPublicKey& obj) +{ + // Read the node list + str >> obj.nodes; + if (obj.nodes.empty()) { + return str; + } + + // Verify that the tree is well-formed and minimal + if (obj.nodes.size() % 2 == 0) { + throw ProtocolError("Malformed ratchet tree: even number of nodes"); + } + + if (obj.nodes.back().blank()) { + throw ProtocolError("Ratchet tree does not use minimal encoding"); + } + + // Adjust the size value to fit the non-blank nodes + obj.size.val = 1; + while (NodeCount(obj.size).val < obj.nodes.size()) { + obj.size.val *= 2; + } + + // Add blank nodes to the end + obj.nodes.resize(NodeCount(obj.size).val); + + // Verify the basic structure of the tree is sane + for (size_t i = 0; i < obj.nodes.size(); i++) { + if (obj.nodes[i].blank()) { + continue; + } + + const auto& node = opt::get(obj.nodes[i].node).node; + auto at_leaf = (i % 2 == 0); + auto holds_leaf = var::holds_alternative(node); + auto holds_parent = var::holds_alternative(node); + + if (at_leaf && !holds_leaf) { + throw InvalidParameterError("Parent node in leaf node position"); + } + + if (!at_leaf && !holds_parent) { + throw InvalidParameterError("Leaf node in parent node position"); + } + } + + return str; +} + +} // namespace mlspp diff --git a/mlspp/third_party/CMakeLists.txt b/mlspp/third_party/CMakeLists.txt new file mode 100755 index 0000000000..c1612f59be --- /dev/null +++ b/mlspp/third_party/CMakeLists.txt @@ -0,0 +1,17 @@ +set(CURRENT_LIB_NAME third_party) + +add_library(${CURRENT_LIB_NAME} INTERFACE) +target_include_directories(${CURRENT_LIB_NAME} + INTERFACE + $ + $ +) + +### +### Install +### + +install(TARGETS ${CURRENT_LIB_NAME} EXPORT mlspp-targets) +install(FILES variant.hpp + DESTINATION + ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME}) diff --git a/mlspp/third_party/variant.hpp b/mlspp/third_party/variant.hpp new file mode 100755 index 0000000000..990a44cbab --- /dev/null +++ b/mlspp/third_party/variant.hpp @@ -0,0 +1,2813 @@ +// MPark.Variant +// +// Copyright Michael Park, 2015-2017 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE.md or copy at http://boost.org/LICENSE_1_0.txt) + +#ifndef MPARK_VARIANT_HPP +#define MPARK_VARIANT_HPP + +/* + variant synopsis + +namespace std { + + // 20.7.2, class template variant + template + class variant { + public: + + // 20.7.2.1, constructors + constexpr variant() noexcept(see below); + variant(const variant&); + variant(variant&&) noexcept(see below); + + template constexpr variant(T&&) noexcept(see below); + + template + constexpr explicit variant(in_place_type_t, Args&&...); + + template + constexpr explicit variant( + in_place_type_t, initializer_list, Args&&...); + + template + constexpr explicit variant(in_place_index_t, Args&&...); + + template + constexpr explicit variant( + in_place_index_t, initializer_list, Args&&...); + + // 20.7.2.2, destructor + ~variant(); + + // 20.7.2.3, assignment + variant& operator=(const variant&); + variant& operator=(variant&&) noexcept(see below); + + template variant& operator=(T&&) noexcept(see below); + + // 20.7.2.4, modifiers + template + T& emplace(Args&&...); + + template + T& emplace(initializer_list, Args&&...); + + template + variant_alternative& emplace(Args&&...); + + template + variant_alternative& emplace(initializer_list, Args&&...); + + // 20.7.2.5, value status + constexpr bool valueless_by_exception() const noexcept; + constexpr size_t index() const noexcept; + + // 20.7.2.6, swap + void swap(variant&) noexcept(see below); + }; + + // 20.7.3, variant helper classes + template struct variant_size; // undefined + + template + constexpr size_t variant_size_v = variant_size::value; + + template struct variant_size; + template struct variant_size; + template struct variant_size; + + template + struct variant_size>; + + template struct variant_alternative; // undefined + + template + using variant_alternative_t = typename variant_alternative::type; + + template struct variant_alternative; + template struct variant_alternative; + template struct variant_alternative; + + template + struct variant_alternative>; + + constexpr size_t variant_npos = -1; + + // 20.7.4, value access + template + constexpr bool holds_alternative(const variant&) noexcept; + + template + constexpr variant_alternative_t>& + get(variant&); + + template + constexpr variant_alternative_t>&& + get(variant&&); + + template + constexpr variant_alternative_t> const& + get(const variant&); + + template + constexpr variant_alternative_t> const&& + get(const variant&&); + + template + constexpr T& get(variant&); + + template + constexpr T&& get(variant&&); + + template + constexpr const T& get(const variant&); + + template + constexpr const T&& get(const variant&&); + + template + constexpr add_pointer_t>> + get_if(variant*) noexcept; + + template + constexpr add_pointer_t>> + get_if(const variant*) noexcept; + + template + constexpr add_pointer_t + get_if(variant*) noexcept; + + template + constexpr add_pointer_t + get_if(const variant*) noexcept; + + // 20.7.5, relational operators + template + constexpr bool operator==(const variant&, const variant&); + + template + constexpr bool operator!=(const variant&, const variant&); + + template + constexpr bool operator<(const variant&, const variant&); + + template + constexpr bool operator>(const variant&, const variant&); + + template + constexpr bool operator<=(const variant&, const variant&); + + template + constexpr bool operator>=(const variant&, const variant&); + + // 20.7.6, visitation + template + constexpr see below visit(Visitor&&, Variants&&...); + + // 20.7.7, class monostate + struct monostate; + + // 20.7.8, monostate relational operators + constexpr bool operator<(monostate, monostate) noexcept; + constexpr bool operator>(monostate, monostate) noexcept; + constexpr bool operator<=(monostate, monostate) noexcept; + constexpr bool operator>=(monostate, monostate) noexcept; + constexpr bool operator==(monostate, monostate) noexcept; + constexpr bool operator!=(monostate, monostate) noexcept; + + // 20.7.9, specialized algorithms + template + void swap(variant&, variant&) noexcept(see below); + + // 20.7.10, class bad_variant_access + class bad_variant_access; + + // 20.7.11, hash support + template struct hash; + template struct hash>; + template <> struct hash; + +} // namespace std + +*/ + +#include +#include +#include +#include +#include +#include +#include + +// MPark.Variant +// +// Copyright Michael Park, 2015-2017 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE.md or copy at http://boost.org/LICENSE_1_0.txt) + +#ifndef MPARK_CONFIG_HPP +#define MPARK_CONFIG_HPP + +// MSVC 2015 Update 3. +#if __cplusplus < 201103L && (!defined(_MSC_VER) || _MSC_FULL_VER < 190024210) +#error "MPark.Variant requires C++11 support." +#endif + +#ifndef __has_attribute +#define __has_attribute(x) 0 +#endif + +#ifndef __has_builtin +#define __has_builtin(x) 0 +#endif + +#ifndef __has_include +#define __has_include(x) 0 +#endif + +#ifndef __has_feature +#define __has_feature(x) 0 +#endif + +#if __has_attribute(always_inline) || defined(__GNUC__) +#define MPARK_ALWAYS_INLINE __attribute__((__always_inline__)) inline +#elif defined(_MSC_VER) +#define MPARK_ALWAYS_INLINE __forceinline +#else +#define MPARK_ALWAYS_INLINE inline +#endif + +#if __has_builtin(__builtin_addressof) || \ + (defined(__GNUC__) && __GNUC__ >= 7) || defined(_MSC_VER) +#define MPARK_BUILTIN_ADDRESSOF +#endif + +#if __has_builtin(__builtin_unreachable) || defined(__GNUC__) +#define MPARK_BUILTIN_UNREACHABLE __builtin_unreachable() +#elif defined(_MSC_VER) +#define MPARK_BUILTIN_UNREACHABLE __assume(false) +#else +#define MPARK_BUILTIN_UNREACHABLE +#endif + +#if __has_builtin(__type_pack_element) +#define MPARK_TYPE_PACK_ELEMENT +#endif + +#if defined(__cpp_constexpr) && __cpp_constexpr >= 200704 && \ + !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 9) +#define MPARK_CPP11_CONSTEXPR +#endif + +#if defined(__cpp_constexpr) && __cpp_constexpr >= 201304 +#define MPARK_CPP14_CONSTEXPR +#endif + +#if __has_feature(cxx_exceptions) || defined(__cpp_exceptions) || \ + (defined(_MSC_VER) && defined(_CPPUNWIND)) +#define MPARK_EXCEPTIONS +#endif + +#if defined(__cpp_generic_lambdas) || defined(_MSC_VER) +#define MPARK_GENERIC_LAMBDAS +#endif + +#if defined(__cpp_lib_integer_sequence) +#define MPARK_INTEGER_SEQUENCE +#endif + +#if defined(__cpp_return_type_deduction) || defined(_MSC_VER) +#define MPARK_RETURN_TYPE_DEDUCTION +#endif + +#if defined(__cpp_lib_transparent_operators) || defined(_MSC_VER) +#define MPARK_TRANSPARENT_OPERATORS +#endif + +#if defined(__cpp_variable_templates) || defined(_MSC_VER) +#define MPARK_VARIABLE_TEMPLATES +#endif + +#if !defined(__GLIBCXX__) || __has_include() // >= libstdc++-5 +#define MPARK_TRIVIALITY_TYPE_TRAITS +#define MPARK_INCOMPLETE_TYPE_TRAITS +#endif + +#endif // MPARK_CONFIG_HPP + +// MPark.Variant +// +// Copyright Michael Park, 2015-2017 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE.md or copy at http://boost.org/LICENSE_1_0.txt) + +#ifndef MPARK_IN_PLACE_HPP +#define MPARK_IN_PLACE_HPP + +#include + + +namespace mpark { + + struct in_place_t { explicit in_place_t() = default; }; + + template + struct in_place_index_t { explicit in_place_index_t() = default; }; + + template + struct in_place_type_t { explicit in_place_type_t() = default; }; + +#ifdef MPARK_VARIABLE_TEMPLATES + constexpr in_place_t in_place{}; + + template constexpr in_place_index_t in_place_index{}; + + template constexpr in_place_type_t in_place_type{}; +#endif + +} // namespace mpark + +#endif // MPARK_IN_PLACE_HPP + +// MPark.Variant +// +// Copyright Michael Park, 2015-2017 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE.md or copy at http://boost.org/LICENSE_1_0.txt) + +#ifndef MPARK_LIB_HPP +#define MPARK_LIB_HPP + +#include +#include +#include +#include + + +#define MPARK_RETURN(...) \ + noexcept(noexcept(__VA_ARGS__)) -> decltype(__VA_ARGS__) { return __VA_ARGS__; } + +namespace mpark { + namespace lib { + template + struct identity { using type = T; }; + + inline namespace cpp14 { + template + struct array { + constexpr const T &operator[](std::size_t index) const { + return data[index]; + } + + T data[N == 0 ? 1 : N]; + }; + + template + using add_pointer_t = typename std::add_pointer::type; + + template + using common_type_t = typename std::common_type::type; + + template + using decay_t = typename std::decay::type; + + template + using enable_if_t = typename std::enable_if::type; + + template + using remove_const_t = typename std::remove_const::type; + + template + using remove_reference_t = typename std::remove_reference::type; + + template + inline constexpr T &&forward(remove_reference_t &t) noexcept { + return static_cast(t); + } + + template + inline constexpr T &&forward(remove_reference_t &&t) noexcept { + static_assert(!std::is_lvalue_reference::value, + "can not forward an rvalue as an lvalue"); + return static_cast(t); + } + + template + inline constexpr remove_reference_t &&move(T &&t) noexcept { + return static_cast &&>(t); + } + +#ifdef MPARK_INTEGER_SEQUENCE + using std::integer_sequence; + using std::index_sequence; + using std::make_index_sequence; + using std::index_sequence_for; +#else + template + struct integer_sequence { + using value_type = T; + static constexpr std::size_t size() noexcept { return sizeof...(Is); } + }; + + template + using index_sequence = integer_sequence; + + template + struct make_index_sequence_concat; + + template + struct make_index_sequence_concat, + index_sequence> + : identity> {}; + + template + struct make_index_sequence_impl; + + template + using make_index_sequence = typename make_index_sequence_impl::type; + + template + struct make_index_sequence_impl + : make_index_sequence_concat, + make_index_sequence> {}; + + template <> + struct make_index_sequence_impl<0> : identity> {}; + + template <> + struct make_index_sequence_impl<1> : identity> {}; + + template + using index_sequence_for = make_index_sequence; +#endif + + // +#ifdef MPARK_TRANSPARENT_OPERATORS + using equal_to = std::equal_to<>; +#else + struct equal_to { + template + inline constexpr auto operator()(Lhs &&lhs, Rhs &&rhs) const + MPARK_RETURN(lib::forward(lhs) == lib::forward(rhs)) + }; +#endif + +#ifdef MPARK_TRANSPARENT_OPERATORS + using not_equal_to = std::not_equal_to<>; +#else + struct not_equal_to { + template + inline constexpr auto operator()(Lhs &&lhs, Rhs &&rhs) const + MPARK_RETURN(lib::forward(lhs) != lib::forward(rhs)) + }; +#endif + +#ifdef MPARK_TRANSPARENT_OPERATORS + using less = std::less<>; +#else + struct less { + template + inline constexpr auto operator()(Lhs &&lhs, Rhs &&rhs) const + MPARK_RETURN(lib::forward(lhs) < lib::forward(rhs)) + }; +#endif + +#ifdef MPARK_TRANSPARENT_OPERATORS + using greater = std::greater<>; +#else + struct greater { + template + inline constexpr auto operator()(Lhs &&lhs, Rhs &&rhs) const + MPARK_RETURN(lib::forward(lhs) > lib::forward(rhs)) + }; +#endif + +#ifdef MPARK_TRANSPARENT_OPERATORS + using less_equal = std::less_equal<>; +#else + struct less_equal { + template + inline constexpr auto operator()(Lhs &&lhs, Rhs &&rhs) const + MPARK_RETURN(lib::forward(lhs) <= lib::forward(rhs)) + }; +#endif + +#ifdef MPARK_TRANSPARENT_OPERATORS + using greater_equal = std::greater_equal<>; +#else + struct greater_equal { + template + inline constexpr auto operator()(Lhs &&lhs, Rhs &&rhs) const + MPARK_RETURN(lib::forward(lhs) >= lib::forward(rhs)) + }; +#endif + } // namespace cpp14 + + inline namespace cpp17 { + + // + template + using bool_constant = std::integral_constant; + + template + struct voider : identity {}; + + template + using void_t = typename voider::type; + + namespace detail { + namespace swappable { + + using std::swap; + + template + struct is_swappable { + private: + template (), + std::declval()))> + inline static std::true_type test(int); + + template + inline static std::false_type test(...); + + public: + static constexpr bool value = decltype(test(0))::value; + }; + + template + struct is_nothrow_swappable { + static constexpr bool value = + noexcept(swap(std::declval(), std::declval())); + }; + + template + struct is_nothrow_swappable : std::false_type {}; + + } // namespace swappable + } // namespace detail + + using detail::swappable::is_swappable; + + template + using is_nothrow_swappable = + detail::swappable::is_nothrow_swappable::value, T>; + + // + namespace detail { + + template + struct is_reference_wrapper : std::false_type {}; + + template + struct is_reference_wrapper> + : std::true_type {}; + + template + struct Invoke; + + template <> + struct Invoke { + template + inline static constexpr auto invoke(R T::*pmf, Arg &&arg, Args &&... args) + MPARK_RETURN((lib::forward(arg).*pmf)(lib::forward(args)...)) + }; + + template <> + struct Invoke { + template + inline static constexpr auto invoke(R T::*pmf, Arg &&arg, Args &&... args) + MPARK_RETURN((lib::forward(arg).get().*pmf)(lib::forward(args)...)) + }; + + template <> + struct Invoke { + template + inline static constexpr auto invoke(R T::*pmf, Arg &&arg, Args &&... args) + MPARK_RETURN(((*lib::forward(arg)).*pmf)(lib::forward(args)...)) + }; + + template <> + struct Invoke { + template + inline static constexpr auto invoke(R T::*pmo, Arg &&arg) + MPARK_RETURN(lib::forward(arg).*pmo) + }; + + template <> + struct Invoke { + template + inline static constexpr auto invoke(R T::*pmo, Arg &&arg) + MPARK_RETURN(lib::forward(arg).get().*pmo) + }; + + template <> + struct Invoke { + template + inline static constexpr auto invoke(R T::*pmo, Arg &&arg) + MPARK_RETURN((*lib::forward(arg)).*pmo) + }; + + template + inline constexpr auto invoke(R T::*f, Arg &&arg, Args &&... args) + MPARK_RETURN( + Invoke::value, + (std::is_base_of>::value + ? 0 + : is_reference_wrapper>::value + ? 1 + : 2)>::invoke(f, + lib::forward(arg), + lib::forward(args)...)) + +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4100) +#endif + template + inline constexpr auto invoke(F &&f, Args &&... args) + MPARK_RETURN(lib::forward(f)(lib::forward(args)...)) +#ifdef _MSC_VER +#pragma warning(pop) +#endif + } // namespace detail + + template + inline constexpr auto invoke(F &&f, Args &&... args) + MPARK_RETURN(detail::invoke(lib::forward(f), + lib::forward(args)...)) + + namespace detail { + + template + struct invoke_result {}; + + template + struct invoke_result(), std::declval()...))>, + F, + Args...> + : identity(), std::declval()...))> {}; + + } // namespace detail + + template + using invoke_result = detail::invoke_result; + + template + using invoke_result_t = typename invoke_result::type; + + namespace detail { + + template + struct is_invocable : std::false_type {}; + + template + struct is_invocable>, F, Args...> + : std::true_type {}; + + template + struct is_invocable_r : std::false_type {}; + + template + struct is_invocable_r>, + R, + F, + Args...> + : std::is_convertible, R> {}; + + } // namespace detail + + template + using is_invocable = detail::is_invocable; + + template + using is_invocable_r = detail::is_invocable_r; + + namespace detail { + + template + struct is_nothrow_invocable { + static constexpr bool value = + noexcept(lib::invoke(std::declval(), std::declval()...)); + }; + + template + struct is_nothrow_invocable : std::false_type {}; + + template + struct is_nothrow_invocable_r { + private: + inline static R impl() { + return lib::invoke(std::declval(), std::declval()...); + } + + public: + static constexpr bool value = noexcept(impl()); + }; + + template + struct is_nothrow_invocable_r : std::false_type {}; + + } // namespace detail + + template + using is_nothrow_invocable = detail:: + is_nothrow_invocable::value, F, Args...>; + + template + using is_nothrow_invocable_r = + detail::is_nothrow_invocable_r::value, + R, + F, + Args...>; + + // +#ifdef MPARK_BUILTIN_ADDRESSOF + template + inline constexpr T *addressof(T &arg) noexcept { + return __builtin_addressof(arg); + } +#else + namespace detail { + + namespace has_addressof_impl { + + struct fail; + + template + inline fail operator&(T &&); + + template + inline static constexpr bool impl() { + return (std::is_class::value || std::is_union::value) && + !std::is_same()), fail>::value; + } + + } // namespace has_addressof_impl + + template + using has_addressof = bool_constant()>; + + template + inline constexpr T *addressof(T &arg, std::true_type) noexcept { + return std::addressof(arg); + } + + template + inline constexpr T *addressof(T &arg, std::false_type) noexcept { + return &arg; + } + + } // namespace detail + + template + inline constexpr T *addressof(T &arg) noexcept { + return detail::addressof(arg, detail::has_addressof{}); + } +#endif + + template + inline constexpr T *addressof(const T &&) = delete; + + } // namespace cpp17 + + template + struct remove_all_extents : identity {}; + + template + struct remove_all_extents> : remove_all_extents {}; + + template + using remove_all_extents_t = typename remove_all_extents::type; + + template + using size_constant = std::integral_constant; + + template + struct indexed_type : size_constant { using type = T; }; + + template + using all = std::is_same, + integer_sequence>; + +#ifdef MPARK_TYPE_PACK_ELEMENT + template + using type_pack_element_t = __type_pack_element; +#else + template + struct type_pack_element_impl { + private: + template + struct set; + + template + struct set> : indexed_type... {}; + + template + inline static std::enable_if impl(indexed_type); + + inline static std::enable_if impl(...); + + public: + using type = decltype(impl(set>{})); + }; + + template + using type_pack_element = typename type_pack_element_impl::type; + + template + using type_pack_element_t = typename type_pack_element::type; +#endif + +#ifdef MPARK_TRIVIALITY_TYPE_TRAITS + using std::is_trivially_copy_constructible; + using std::is_trivially_move_constructible; + using std::is_trivially_copy_assignable; + using std::is_trivially_move_assignable; +#else + template + struct is_trivially_copy_constructible + : bool_constant< + std::is_copy_constructible::value && __has_trivial_copy(T)> {}; + + template + struct is_trivially_move_constructible : bool_constant<__is_trivial(T)> {}; + + template + struct is_trivially_copy_assignable + : bool_constant< + std::is_copy_assignable::value && __has_trivial_assign(T)> {}; + + template + struct is_trivially_move_assignable : bool_constant<__is_trivial(T)> {}; +#endif + + template + struct dependent_type : T {}; + + template + struct push_back; + + template + using push_back_t = typename push_back::type; + + template + struct push_back, J> { + using type = index_sequence; + }; + + } // namespace lib +} // namespace mpark + +#undef MPARK_RETURN + +#endif // MPARK_LIB_HPP + + +namespace mpark { + +#ifdef MPARK_RETURN_TYPE_DEDUCTION + +#define AUTO auto +#define AUTO_RETURN(...) { return __VA_ARGS__; } + +#define AUTO_REFREF auto && +#define AUTO_REFREF_RETURN(...) { return __VA_ARGS__; } + +#define DECLTYPE_AUTO decltype(auto) +#define DECLTYPE_AUTO_RETURN(...) { return __VA_ARGS__; } + +#else + +#define AUTO auto +#define AUTO_RETURN(...) \ + -> lib::decay_t { return __VA_ARGS__; } + +#define AUTO_REFREF auto +#define AUTO_REFREF_RETURN(...) \ + -> decltype((__VA_ARGS__)) { \ + static_assert(std::is_reference::value, ""); \ + return __VA_ARGS__; \ + } + +#define DECLTYPE_AUTO auto +#define DECLTYPE_AUTO_RETURN(...) \ + -> decltype(__VA_ARGS__) { return __VA_ARGS__; } + +#endif + + class bad_variant_access : public std::exception { + public: + virtual const char *what() const noexcept override { return "bad_variant_access"; } + }; + + [[noreturn]] inline void throw_bad_variant_access() { +#ifdef MPARK_EXCEPTIONS + throw bad_variant_access{}; +#else + std::terminate(); + MPARK_BUILTIN_UNREACHABLE; +#endif + } + + template + class variant; + + template + struct variant_size; + +#ifdef MPARK_VARIABLE_TEMPLATES + template + constexpr std::size_t variant_size_v = variant_size::value; +#endif + + template + struct variant_size : variant_size {}; + + template + struct variant_size : variant_size {}; + + template + struct variant_size : variant_size {}; + + template + struct variant_size> : lib::size_constant {}; + + template + struct variant_alternative; + + template + using variant_alternative_t = typename variant_alternative::type; + + template + struct variant_alternative + : std::add_const> {}; + + template + struct variant_alternative + : std::add_volatile> {}; + + template + struct variant_alternative + : std::add_cv> {}; + + template + struct variant_alternative> { + static_assert(I < sizeof...(Ts), + "index out of bounds in `std::variant_alternative<>`"); + using type = lib::type_pack_element_t; + }; + + constexpr std::size_t variant_npos = static_cast(-1); + + namespace detail { + + constexpr std::size_t not_found = static_cast(-1); + constexpr std::size_t ambiguous = static_cast(-2); + +#ifdef MPARK_CPP14_CONSTEXPR + template + inline constexpr std::size_t find_index() { + constexpr lib::array matches = { + {std::is_same::value...} + }; + std::size_t result = not_found; + for (std::size_t i = 0; i < sizeof...(Ts); ++i) { + if (matches[i]) { + if (result != not_found) { + return ambiguous; + } + result = i; + } + } + return result; + } +#else + inline constexpr std::size_t find_index_impl(std::size_t result, + std::size_t) { + return result; + } + + template + inline constexpr std::size_t find_index_impl(std::size_t result, + std::size_t idx, + bool b, + Bs... bs) { + return b ? (result != not_found ? ambiguous + : find_index_impl(idx, idx + 1, bs...)) + : find_index_impl(result, idx + 1, bs...); + } + + template + inline constexpr std::size_t find_index() { + return find_index_impl(not_found, 0, std::is_same::value...); + } +#endif + + template + using find_index_sfinae_impl = + lib::enable_if_t>; + + template + using find_index_sfinae = find_index_sfinae_impl()>; + + template + struct find_index_checked_impl : lib::size_constant { + static_assert(I != not_found, "the specified type is not found."); + static_assert(I != ambiguous, "the specified type is ambiguous."); + }; + + template + using find_index_checked = find_index_checked_impl()>; + + struct valueless_t {}; + + enum class Trait { TriviallyAvailable, Available, Unavailable }; + + template class IsTriviallyAvailable, + template class IsAvailable> + inline constexpr Trait trait() { + return IsTriviallyAvailable::value + ? Trait::TriviallyAvailable + : IsAvailable::value ? Trait::Available + : Trait::Unavailable; + } + +#ifdef MPARK_CPP14_CONSTEXPR + template + inline constexpr Trait common_trait(Traits... traits_) { + Trait result = Trait::TriviallyAvailable; + lib::array traits = {{traits_...}}; + for (std::size_t i = 0; i < sizeof...(Traits); ++i) { + Trait t = traits[i]; + if (static_cast(t) > static_cast(result)) { + result = t; + } + } + return result; + } +#else + inline constexpr Trait common_trait_impl(Trait result) { return result; } + + template + inline constexpr Trait common_trait_impl(Trait result, + Trait t, + Traits... ts) { + return static_cast(t) > static_cast(result) + ? common_trait_impl(t, ts...) + : common_trait_impl(result, ts...); + } + + template + inline constexpr Trait common_trait(Traits... ts) { + return common_trait_impl(Trait::TriviallyAvailable, ts...); + } +#endif + + template + struct traits { + static constexpr Trait copy_constructible_trait = + common_trait(trait()...); + + static constexpr Trait move_constructible_trait = + common_trait(trait()...); + + static constexpr Trait copy_assignable_trait = + common_trait(copy_constructible_trait, + trait()...); + + static constexpr Trait move_assignable_trait = + common_trait(move_constructible_trait, + trait()...); + + static constexpr Trait destructible_trait = + common_trait(trait()...); + }; + + namespace access { + + struct recursive_union { +#ifdef MPARK_RETURN_TYPE_DEDUCTION + template + inline static constexpr auto &&get_alt(V &&v, in_place_index_t<0>) { + return lib::forward(v).head_; + } + + template + inline static constexpr auto &&get_alt(V &&v, in_place_index_t) { + return get_alt(lib::forward(v).tail_, in_place_index_t{}); + } +#else + template + struct get_alt_impl { + template + inline constexpr AUTO_REFREF operator()(V &&v) const + AUTO_REFREF_RETURN(get_alt_impl{}(lib::forward(v).tail_)) + }; + + template + struct get_alt_impl<0, Dummy> { + template + inline constexpr AUTO_REFREF operator()(V &&v) const + AUTO_REFREF_RETURN(lib::forward(v).head_) + }; + + template + inline static constexpr AUTO_REFREF get_alt(V &&v, in_place_index_t) + AUTO_REFREF_RETURN(get_alt_impl{}(lib::forward(v))) +#endif + }; + + struct base { + template + inline static constexpr AUTO_REFREF get_alt(V &&v) +#ifdef _MSC_VER + AUTO_REFREF_RETURN(recursive_union::get_alt( + lib::forward(v).data_, in_place_index_t{})) +#else + AUTO_REFREF_RETURN(recursive_union::get_alt( + data(lib::forward(v)), in_place_index_t{})) +#endif + }; + + struct variant { + template + inline static constexpr AUTO_REFREF get_alt(V &&v) + AUTO_REFREF_RETURN(base::get_alt(lib::forward(v).impl_)) + }; + + } // namespace access + + namespace visitation { + +#if defined(MPARK_CPP14_CONSTEXPR) && !defined(_MSC_VER) +#define MPARK_VARIANT_SWITCH_VISIT +#endif + + struct base { + template + using dispatch_result_t = decltype( + lib::invoke(std::declval(), + access::base::get_alt<0>(std::declval())...)); + + template + struct expected { + template + inline static constexpr bool but_got() { + return std::is_same::value; + } + }; + + template + struct visit_return_type_check { + static_assert( + expected::template but_got(), + "`visit` requires the visitor to have a single return type"); + + template + inline static constexpr DECLTYPE_AUTO invoke(Visitor &&visitor, + Alts &&... alts) + DECLTYPE_AUTO_RETURN(lib::invoke(lib::forward(visitor), + lib::forward(alts)...)) + }; + +#ifdef MPARK_VARIANT_SWITCH_VISIT + template + struct dispatcher; + + template + struct dispatcher { + template + MPARK_ALWAYS_INLINE static constexpr R dispatch( + F &&, typename ITs::type &&..., Vs &&...) { + MPARK_BUILTIN_UNREACHABLE; + } + + template + MPARK_ALWAYS_INLINE static constexpr R dispatch_case(F &&, Vs &&...) { + MPARK_BUILTIN_UNREACHABLE; + } + + template + MPARK_ALWAYS_INLINE static constexpr R dispatch_at(std::size_t, + F &&, + Vs &&...) { + MPARK_BUILTIN_UNREACHABLE; + } + }; + + template + struct dispatcher { + template + MPARK_ALWAYS_INLINE static constexpr R dispatch( + F &&f, typename ITs::type &&... visited_vs) { + using Expected = R; + using Actual = decltype(lib::invoke( + lib::forward(f), + access::base::get_alt( + lib::forward(visited_vs))...)); + return visit_return_type_check::invoke( + lib::forward(f), + access::base::get_alt( + lib::forward(visited_vs))...); + } + + template + MPARK_ALWAYS_INLINE static constexpr R dispatch( + F &&f, typename ITs::type &&... visited_vs, V &&v, Vs &&... vs) { +#define MPARK_DISPATCH(I) \ + dispatcher<(I < lib::decay_t::size()), \ + R, \ + ITs..., \ + lib::indexed_type>:: \ + template dispatch<0>(lib::forward(f), \ + lib::forward(visited_vs)..., \ + lib::forward(v), \ + lib::forward(vs)...) + +#define MPARK_DEFAULT(I) \ + dispatcher<(I < lib::decay_t::size()), R, ITs...>::template dispatch( \ + lib::forward(f), \ + lib::forward(visited_vs)..., \ + lib::forward(v), \ + lib::forward(vs)...) + + switch (v.index()) { + case B + 0: return MPARK_DISPATCH(B + 0); + case B + 1: return MPARK_DISPATCH(B + 1); + case B + 2: return MPARK_DISPATCH(B + 2); + case B + 3: return MPARK_DISPATCH(B + 3); + case B + 4: return MPARK_DISPATCH(B + 4); + case B + 5: return MPARK_DISPATCH(B + 5); + case B + 6: return MPARK_DISPATCH(B + 6); + case B + 7: return MPARK_DISPATCH(B + 7); + case B + 8: return MPARK_DISPATCH(B + 8); + case B + 9: return MPARK_DISPATCH(B + 9); + case B + 10: return MPARK_DISPATCH(B + 10); + case B + 11: return MPARK_DISPATCH(B + 11); + case B + 12: return MPARK_DISPATCH(B + 12); + case B + 13: return MPARK_DISPATCH(B + 13); + case B + 14: return MPARK_DISPATCH(B + 14); + case B + 15: return MPARK_DISPATCH(B + 15); + case B + 16: return MPARK_DISPATCH(B + 16); + case B + 17: return MPARK_DISPATCH(B + 17); + case B + 18: return MPARK_DISPATCH(B + 18); + case B + 19: return MPARK_DISPATCH(B + 19); + case B + 20: return MPARK_DISPATCH(B + 20); + case B + 21: return MPARK_DISPATCH(B + 21); + case B + 22: return MPARK_DISPATCH(B + 22); + case B + 23: return MPARK_DISPATCH(B + 23); + case B + 24: return MPARK_DISPATCH(B + 24); + case B + 25: return MPARK_DISPATCH(B + 25); + case B + 26: return MPARK_DISPATCH(B + 26); + case B + 27: return MPARK_DISPATCH(B + 27); + case B + 28: return MPARK_DISPATCH(B + 28); + case B + 29: return MPARK_DISPATCH(B + 29); + case B + 30: return MPARK_DISPATCH(B + 30); + case B + 31: return MPARK_DISPATCH(B + 31); + default: return MPARK_DEFAULT(B + 32); + } + +#undef MPARK_DEFAULT +#undef MPARK_DISPATCH + } + + template + MPARK_ALWAYS_INLINE static constexpr R dispatch_case(F &&f, + Vs &&... vs) { + using Expected = R; + using Actual = decltype( + lib::invoke(lib::forward(f), + access::base::get_alt(lib::forward(vs))...)); + return visit_return_type_check::invoke( + lib::forward(f), + access::base::get_alt(lib::forward(vs))...); + } + + template + MPARK_ALWAYS_INLINE static constexpr R dispatch_at(std::size_t index, + F &&f, + V &&v, + Vs &&... vs) { + static_assert(lib::all<(lib::decay_t::size() == + lib::decay_t::size())...>::value, + "all of the variants must be the same size."); +#define MPARK_DISPATCH_AT(I) \ + dispatcher<(I < lib::decay_t::size()), R>::template dispatch_case( \ + lib::forward(f), lib::forward(v), lib::forward(vs)...) + +#define MPARK_DEFAULT(I) \ + dispatcher<(I < lib::decay_t::size()), R>::template dispatch_at( \ + index, lib::forward(f), lib::forward(v), lib::forward(vs)...) + + switch (index) { + case B + 0: return MPARK_DISPATCH_AT(B + 0); + case B + 1: return MPARK_DISPATCH_AT(B + 1); + case B + 2: return MPARK_DISPATCH_AT(B + 2); + case B + 3: return MPARK_DISPATCH_AT(B + 3); + case B + 4: return MPARK_DISPATCH_AT(B + 4); + case B + 5: return MPARK_DISPATCH_AT(B + 5); + case B + 6: return MPARK_DISPATCH_AT(B + 6); + case B + 7: return MPARK_DISPATCH_AT(B + 7); + case B + 8: return MPARK_DISPATCH_AT(B + 8); + case B + 9: return MPARK_DISPATCH_AT(B + 9); + case B + 10: return MPARK_DISPATCH_AT(B + 10); + case B + 11: return MPARK_DISPATCH_AT(B + 11); + case B + 12: return MPARK_DISPATCH_AT(B + 12); + case B + 13: return MPARK_DISPATCH_AT(B + 13); + case B + 14: return MPARK_DISPATCH_AT(B + 14); + case B + 15: return MPARK_DISPATCH_AT(B + 15); + case B + 16: return MPARK_DISPATCH_AT(B + 16); + case B + 17: return MPARK_DISPATCH_AT(B + 17); + case B + 18: return MPARK_DISPATCH_AT(B + 18); + case B + 19: return MPARK_DISPATCH_AT(B + 19); + case B + 20: return MPARK_DISPATCH_AT(B + 20); + case B + 21: return MPARK_DISPATCH_AT(B + 21); + case B + 22: return MPARK_DISPATCH_AT(B + 22); + case B + 23: return MPARK_DISPATCH_AT(B + 23); + case B + 24: return MPARK_DISPATCH_AT(B + 24); + case B + 25: return MPARK_DISPATCH_AT(B + 25); + case B + 26: return MPARK_DISPATCH_AT(B + 26); + case B + 27: return MPARK_DISPATCH_AT(B + 27); + case B + 28: return MPARK_DISPATCH_AT(B + 28); + case B + 29: return MPARK_DISPATCH_AT(B + 29); + case B + 30: return MPARK_DISPATCH_AT(B + 30); + case B + 31: return MPARK_DISPATCH_AT(B + 31); + default: return MPARK_DEFAULT(B + 32); + } + +#undef MPARK_DEFAULT +#undef MPARK_DISPATCH_AT + } + }; +#else + template + inline static constexpr const T &at(const T &elem) noexcept { + return elem; + } + + template + inline static constexpr const lib::remove_all_extents_t &at( + const lib::array &elems, std::size_t i, Is... is) noexcept { + return at(elems[i], is...); + } + + template + inline static constexpr lib::array, sizeof...(Fs) + 1> + make_farray(F &&f, Fs &&... fs) { + return {{lib::forward(f), lib::forward(fs)...}}; + } + + template + struct make_fmatrix_impl { + + template + inline static constexpr dispatch_result_t dispatch( + F &&f, Vs &&... vs) { + using Expected = dispatch_result_t; + using Actual = decltype(lib::invoke( + lib::forward(f), + access::base::get_alt(lib::forward(vs))...)); + return visit_return_type_check::invoke( + lib::forward(f), + access::base::get_alt(lib::forward(vs))...); + } + +#ifdef MPARK_RETURN_TYPE_DEDUCTION + template + inline static constexpr auto impl(lib::index_sequence) { + return &dispatch; + } + + template + inline static constexpr auto impl(Is, + lib::index_sequence, + Ls... ls) { + return make_farray(impl(lib::push_back_t{}, ls...)...); + } +#else + template + struct impl; + + template + struct impl> { + inline constexpr AUTO operator()() const + AUTO_RETURN(&dispatch) + }; + + template + struct impl, Ls...> { + inline constexpr AUTO operator()() const + AUTO_RETURN( + make_farray(impl, Ls...>{}()...)) + }; +#endif + }; + +#ifdef MPARK_RETURN_TYPE_DEDUCTION + template + inline static constexpr auto make_fmatrix() { + return make_fmatrix_impl::impl( + lib::index_sequence<>{}, + lib::make_index_sequence::size()>{}...); + } +#else + template + inline static constexpr AUTO make_fmatrix() + AUTO_RETURN( + typename make_fmatrix_impl::template impl< + lib::index_sequence<>, + lib::make_index_sequence::size()>...>{}()) +#endif + + template + struct make_fdiagonal_impl { + template + inline static constexpr dispatch_result_t dispatch( + F &&f, Vs &&... vs) { + using Expected = dispatch_result_t; + using Actual = decltype( + lib::invoke(lib::forward(f), + access::base::get_alt(lib::forward(vs))...)); + return visit_return_type_check::invoke( + lib::forward(f), + access::base::get_alt(lib::forward(vs))...); + } + + template + inline static constexpr AUTO impl(lib::index_sequence) + AUTO_RETURN(make_farray(&dispatch...)) + }; + + template + inline static constexpr auto make_fdiagonal() + -> decltype(make_fdiagonal_impl::impl( + lib::make_index_sequence::size()>{})) { + static_assert(lib::all<(lib::decay_t::size() == + lib::decay_t::size())...>::value, + "all of the variants must be the same size."); + return make_fdiagonal_impl::impl( + lib::make_index_sequence::size()>{}); + } +#endif + }; + +#if !defined(MPARK_VARIANT_SWITCH_VISIT) && \ + (!defined(_MSC_VER) || _MSC_VER >= 1910) + template + using fmatrix_t = decltype(base::make_fmatrix()); + + template + struct fmatrix { + static constexpr fmatrix_t value = + base::make_fmatrix(); + }; + + template + constexpr fmatrix_t fmatrix::value; + + template + using fdiagonal_t = decltype(base::make_fdiagonal()); + + template + struct fdiagonal { + static constexpr fdiagonal_t value = + base::make_fdiagonal(); + }; + + template + constexpr fdiagonal_t fdiagonal::value; +#endif + + struct alt { + template + inline static constexpr DECLTYPE_AUTO visit_alt(Visitor &&visitor, + Vs &&... vs) +#ifdef MPARK_VARIANT_SWITCH_VISIT + DECLTYPE_AUTO_RETURN( + base::dispatcher< + true, + base::dispatch_result_t(vs)))...>>:: + template dispatch<0>(lib::forward(visitor), + as_base(lib::forward(vs))...)) +#elif !defined(_MSC_VER) || _MSC_VER >= 1910 + DECLTYPE_AUTO_RETURN(base::at( + fmatrix(vs)))...>::value, + vs.index()...)(lib::forward(visitor), + as_base(lib::forward(vs))...)) +#else + DECLTYPE_AUTO_RETURN(base::at( + base::make_fmatrix(vs)))...>(), + vs.index()...)(lib::forward(visitor), + as_base(lib::forward(vs))...)) +#endif + + template + inline static constexpr DECLTYPE_AUTO visit_alt_at(std::size_t index, + Visitor &&visitor, + Vs &&... vs) +#ifdef MPARK_VARIANT_SWITCH_VISIT + DECLTYPE_AUTO_RETURN( + base::dispatcher< + true, + base::dispatch_result_t(vs)))...>>:: + template dispatch_at<0>(index, + lib::forward(visitor), + as_base(lib::forward(vs))...)) +#elif !defined(_MSC_VER) || _MSC_VER >= 1910 + DECLTYPE_AUTO_RETURN(base::at( + fdiagonal(vs)))...>::value, + index)(lib::forward(visitor), + as_base(lib::forward(vs))...)) +#else + DECLTYPE_AUTO_RETURN(base::at( + base::make_fdiagonal(vs)))...>(), + index)(lib::forward(visitor), + as_base(lib::forward(vs))...)) +#endif + }; + + struct variant { + private: + template + struct visitor { + template + inline static constexpr bool does_not_handle() { + return lib::is_invocable::value; + } + }; + + template + struct visit_exhaustiveness_check { + static_assert(visitor::template does_not_handle(), + "`visit` requires the visitor to be exhaustive."); + + inline static constexpr DECLTYPE_AUTO invoke(Visitor &&visitor, + Values &&... values) + DECLTYPE_AUTO_RETURN(lib::invoke(lib::forward(visitor), + lib::forward(values)...)) + }; + + template + struct value_visitor { + Visitor &&visitor_; + + template + inline constexpr DECLTYPE_AUTO operator()(Alts &&... alts) const + DECLTYPE_AUTO_RETURN( + visit_exhaustiveness_check< + Visitor, + decltype((lib::forward(alts).value))...>:: + invoke(lib::forward(visitor_), + lib::forward(alts).value...)) + }; + + template + inline static constexpr AUTO make_value_visitor(Visitor &&visitor) + AUTO_RETURN(value_visitor{lib::forward(visitor)}) + + public: + template + inline static constexpr DECLTYPE_AUTO visit_alt(Visitor &&visitor, + Vs &&... vs) + DECLTYPE_AUTO_RETURN(alt::visit_alt(lib::forward(visitor), + lib::forward(vs).impl_...)) + + template + inline static constexpr DECLTYPE_AUTO visit_alt_at(std::size_t index, + Visitor &&visitor, + Vs &&... vs) + DECLTYPE_AUTO_RETURN( + alt::visit_alt_at(index, + lib::forward(visitor), + lib::forward(vs).impl_...)) + + template + inline static constexpr DECLTYPE_AUTO visit_value(Visitor &&visitor, + Vs &&... vs) + DECLTYPE_AUTO_RETURN( + visit_alt(make_value_visitor(lib::forward(visitor)), + lib::forward(vs)...)) + + template + inline static constexpr DECLTYPE_AUTO visit_value_at(std::size_t index, + Visitor &&visitor, + Vs &&... vs) + DECLTYPE_AUTO_RETURN( + visit_alt_at(index, + make_value_visitor(lib::forward(visitor)), + lib::forward(vs)...)) + }; + + } // namespace visitation + + template + struct alt { + using value_type = T; + +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4244) +#endif + template + inline explicit constexpr alt(in_place_t, Args &&... args) + : value(lib::forward(args)...) {} +#ifdef _MSC_VER +#pragma warning(pop) +#endif + + T value; + }; + + template + union recursive_union; + + template + union recursive_union {}; + +#define MPARK_VARIANT_RECURSIVE_UNION(destructible_trait, destructor) \ + template \ + union recursive_union { \ + public: \ + inline explicit constexpr recursive_union(valueless_t) noexcept \ + : dummy_{} {} \ + \ + template \ + inline explicit constexpr recursive_union(in_place_index_t<0>, \ + Args &&... args) \ + : head_(in_place_t{}, lib::forward(args)...) {} \ + \ + template \ + inline explicit constexpr recursive_union(in_place_index_t, \ + Args &&... args) \ + : tail_(in_place_index_t{}, lib::forward(args)...) {} \ + \ + recursive_union(const recursive_union &) = default; \ + recursive_union(recursive_union &&) = default; \ + \ + destructor \ + \ + recursive_union &operator=(const recursive_union &) = default; \ + recursive_union &operator=(recursive_union &&) = default; \ + \ + private: \ + char dummy_; \ + alt head_; \ + recursive_union tail_; \ + \ + friend struct access::recursive_union; \ + } + + MPARK_VARIANT_RECURSIVE_UNION(Trait::TriviallyAvailable, + ~recursive_union() = default;); + MPARK_VARIANT_RECURSIVE_UNION(Trait::Available, + ~recursive_union() {}); + MPARK_VARIANT_RECURSIVE_UNION(Trait::Unavailable, + ~recursive_union() = delete;); + +#undef MPARK_VARIANT_RECURSIVE_UNION + + using index_t = unsigned int; + + template + class base { + public: + inline explicit constexpr base(valueless_t tag) noexcept + : data_(tag), index_(static_cast(-1)) {} + + template + inline explicit constexpr base(in_place_index_t, Args &&... args) + : data_(in_place_index_t{}, lib::forward(args)...), + index_(I) {} + + inline constexpr bool valueless_by_exception() const noexcept { + return index_ == static_cast(-1); + } + + inline constexpr std::size_t index() const noexcept { + return valueless_by_exception() ? variant_npos : index_; + } + + protected: + using data_t = recursive_union; + + friend inline constexpr base &as_base(base &b) { return b; } + friend inline constexpr const base &as_base(const base &b) { return b; } + friend inline constexpr base &&as_base(base &&b) { return lib::move(b); } + friend inline constexpr const base &&as_base(const base &&b) { return lib::move(b); } + + friend inline constexpr data_t &data(base &b) { return b.data_; } + friend inline constexpr const data_t &data(const base &b) { return b.data_; } + friend inline constexpr data_t &&data(base &&b) { return lib::move(b).data_; } + friend inline constexpr const data_t &&data(const base &&b) { return lib::move(b).data_; } + + inline static constexpr std::size_t size() { return sizeof...(Ts); } + + data_t data_; + index_t index_; + + friend struct access::base; + friend struct visitation::base; + }; + + struct dtor { +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4100) +#endif + template + inline void operator()(Alt &alt) const noexcept { alt.~Alt(); } +#ifdef _MSC_VER +#pragma warning(pop) +#endif + }; + +#if !defined(_MSC_VER) || _MSC_VER >= 1910 +#define MPARK_INHERITING_CTOR(type, base) using base::base; +#else +#define MPARK_INHERITING_CTOR(type, base) \ + template \ + inline explicit constexpr type(Args &&... args) \ + : base(lib::forward(args)...) {} +#endif + + template + class destructor; + +#define MPARK_VARIANT_DESTRUCTOR(destructible_trait, definition, destroy) \ + template \ + class destructor, destructible_trait> \ + : public base { \ + using super = base; \ + \ + public: \ + MPARK_INHERITING_CTOR(destructor, super) \ + using super::operator=; \ + \ + destructor(const destructor &) = default; \ + destructor(destructor &&) = default; \ + definition \ + destructor &operator=(const destructor &) = default; \ + destructor &operator=(destructor &&) = default; \ + \ + protected: \ + destroy \ + } + + MPARK_VARIANT_DESTRUCTOR( + Trait::TriviallyAvailable, + ~destructor() = default;, + inline void destroy() noexcept { + this->index_ = static_cast(-1); + }); + + MPARK_VARIANT_DESTRUCTOR( + Trait::Available, + ~destructor() { destroy(); }, + inline void destroy() noexcept { + if (!this->valueless_by_exception()) { + visitation::alt::visit_alt(dtor{}, *this); + } + this->index_ = static_cast(-1); + }); + + MPARK_VARIANT_DESTRUCTOR( + Trait::Unavailable, + ~destructor() = delete;, + inline void destroy() noexcept = delete;); + +#undef MPARK_VARIANT_DESTRUCTOR + + template + class constructor : public destructor { + using super = destructor; + + public: + MPARK_INHERITING_CTOR(constructor, super) + using super::operator=; + + protected: +#ifndef MPARK_GENERIC_LAMBDAS + struct ctor { + template + inline void operator()(LhsAlt &lhs_alt, RhsAlt &&rhs_alt) const { + constructor::construct_alt(lhs_alt, + lib::forward(rhs_alt).value); + } + }; +#endif + + template + inline static T &construct_alt(alt &a, Args &&... args) { + auto *result = ::new (static_cast(lib::addressof(a))) + alt(in_place_t{}, lib::forward(args)...); + return result->value; + } + + template + inline static void generic_construct(constructor &lhs, Rhs &&rhs) { + lhs.destroy(); + if (!rhs.valueless_by_exception()) { + visitation::alt::visit_alt_at( + rhs.index(), +#ifdef MPARK_GENERIC_LAMBDAS + [](auto &lhs_alt, auto &&rhs_alt) { + constructor::construct_alt( + lhs_alt, lib::forward(rhs_alt).value); + } +#else + ctor{} +#endif + , + lhs, + lib::forward(rhs)); + lhs.index_ = rhs.index_; + } + } + }; + + template + class move_constructor; + +#define MPARK_VARIANT_MOVE_CONSTRUCTOR(move_constructible_trait, definition) \ + template \ + class move_constructor, move_constructible_trait> \ + : public constructor> { \ + using super = constructor>; \ + \ + public: \ + MPARK_INHERITING_CTOR(move_constructor, super) \ + using super::operator=; \ + \ + move_constructor(const move_constructor &) = default; \ + definition \ + ~move_constructor() = default; \ + move_constructor &operator=(const move_constructor &) = default; \ + move_constructor &operator=(move_constructor &&) = default; \ + } + + MPARK_VARIANT_MOVE_CONSTRUCTOR( + Trait::TriviallyAvailable, + move_constructor(move_constructor &&that) = default;); + + MPARK_VARIANT_MOVE_CONSTRUCTOR( + Trait::Available, + move_constructor(move_constructor &&that) noexcept( + lib::all::value...>::value) + : move_constructor(valueless_t{}) { + this->generic_construct(*this, lib::move(that)); + }); + + MPARK_VARIANT_MOVE_CONSTRUCTOR( + Trait::Unavailable, + move_constructor(move_constructor &&) = delete;); + +#undef MPARK_VARIANT_MOVE_CONSTRUCTOR + + template + class copy_constructor; + +#define MPARK_VARIANT_COPY_CONSTRUCTOR(copy_constructible_trait, definition) \ + template \ + class copy_constructor, copy_constructible_trait> \ + : public move_constructor> { \ + using super = move_constructor>; \ + \ + public: \ + MPARK_INHERITING_CTOR(copy_constructor, super) \ + using super::operator=; \ + \ + definition \ + copy_constructor(copy_constructor &&) = default; \ + ~copy_constructor() = default; \ + copy_constructor &operator=(const copy_constructor &) = default; \ + copy_constructor &operator=(copy_constructor &&) = default; \ + } + + MPARK_VARIANT_COPY_CONSTRUCTOR( + Trait::TriviallyAvailable, + copy_constructor(const copy_constructor &that) = default;); + + MPARK_VARIANT_COPY_CONSTRUCTOR( + Trait::Available, + copy_constructor(const copy_constructor &that) + : copy_constructor(valueless_t{}) { + this->generic_construct(*this, that); + }); + + MPARK_VARIANT_COPY_CONSTRUCTOR( + Trait::Unavailable, + copy_constructor(const copy_constructor &) = delete;); + +#undef MPARK_VARIANT_COPY_CONSTRUCTOR + + template + class assignment : public copy_constructor { + using super = copy_constructor; + + public: + MPARK_INHERITING_CTOR(assignment, super) + using super::operator=; + + template + inline /* auto & */ auto emplace(Args &&... args) + -> decltype(this->construct_alt(access::base::get_alt(*this), + lib::forward(args)...)) { + this->destroy(); + auto &result = this->construct_alt(access::base::get_alt(*this), + lib::forward(args)...); + this->index_ = I; + return result; + } + + protected: +#ifndef MPARK_GENERIC_LAMBDAS + template + struct assigner { + template + inline void operator()(ThisAlt &this_alt, ThatAlt &&that_alt) const { + self->assign_alt(this_alt, lib::forward(that_alt).value); + } + assignment *self; + }; +#endif + + template + inline void assign_alt(alt &a, Arg &&arg) { + if (this->index() == I) { +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4244) +#endif + a.value = lib::forward(arg); +#ifdef _MSC_VER +#pragma warning(pop) +#endif + } else { + struct { + void operator()(std::true_type) const { + this_->emplace(lib::forward(arg_)); + } + void operator()(std::false_type) const { + this_->emplace(T(lib::forward(arg_))); + } + assignment *this_; + Arg &&arg_; + } impl{this, lib::forward(arg)}; + impl(lib::bool_constant< + std::is_nothrow_constructible::value || + !std::is_nothrow_move_constructible::value>{}); + } + } + + template + inline void generic_assign(That &&that) { + if (this->valueless_by_exception() && that.valueless_by_exception()) { + // do nothing. + } else if (that.valueless_by_exception()) { + this->destroy(); + } else { + visitation::alt::visit_alt_at( + that.index(), +#ifdef MPARK_GENERIC_LAMBDAS + [this](auto &this_alt, auto &&that_alt) { + this->assign_alt( + this_alt, lib::forward(that_alt).value); + } +#else + assigner{this} +#endif + , + *this, + lib::forward(that)); + } + } + }; + + template + class move_assignment; + +#define MPARK_VARIANT_MOVE_ASSIGNMENT(move_assignable_trait, definition) \ + template \ + class move_assignment, move_assignable_trait> \ + : public assignment> { \ + using super = assignment>; \ + \ + public: \ + MPARK_INHERITING_CTOR(move_assignment, super) \ + using super::operator=; \ + \ + move_assignment(const move_assignment &) = default; \ + move_assignment(move_assignment &&) = default; \ + ~move_assignment() = default; \ + move_assignment &operator=(const move_assignment &) = default; \ + definition \ + } + + MPARK_VARIANT_MOVE_ASSIGNMENT( + Trait::TriviallyAvailable, + move_assignment &operator=(move_assignment &&that) = default;); + + MPARK_VARIANT_MOVE_ASSIGNMENT( + Trait::Available, + move_assignment & + operator=(move_assignment &&that) noexcept( + lib::all<(std::is_nothrow_move_constructible::value && + std::is_nothrow_move_assignable::value)...>::value) { + this->generic_assign(lib::move(that)); + return *this; + }); + + MPARK_VARIANT_MOVE_ASSIGNMENT( + Trait::Unavailable, + move_assignment &operator=(move_assignment &&) = delete;); + +#undef MPARK_VARIANT_MOVE_ASSIGNMENT + + template + class copy_assignment; + +#define MPARK_VARIANT_COPY_ASSIGNMENT(copy_assignable_trait, definition) \ + template \ + class copy_assignment, copy_assignable_trait> \ + : public move_assignment> { \ + using super = move_assignment>; \ + \ + public: \ + MPARK_INHERITING_CTOR(copy_assignment, super) \ + using super::operator=; \ + \ + copy_assignment(const copy_assignment &) = default; \ + copy_assignment(copy_assignment &&) = default; \ + ~copy_assignment() = default; \ + definition \ + copy_assignment &operator=(copy_assignment &&) = default; \ + } + + MPARK_VARIANT_COPY_ASSIGNMENT( + Trait::TriviallyAvailable, + copy_assignment &operator=(const copy_assignment &that) = default;); + + MPARK_VARIANT_COPY_ASSIGNMENT( + Trait::Available, + copy_assignment &operator=(const copy_assignment &that) { + this->generic_assign(that); + return *this; + }); + + MPARK_VARIANT_COPY_ASSIGNMENT( + Trait::Unavailable, + copy_assignment &operator=(const copy_assignment &) = delete;); + +#undef MPARK_VARIANT_COPY_ASSIGNMENT + + template + class impl : public copy_assignment> { + using super = copy_assignment>; + + public: + MPARK_INHERITING_CTOR(impl, super) + using super::operator=; + + template + inline void assign(Arg &&arg) { + this->assign_alt(access::base::get_alt(*this), + lib::forward(arg)); + } + + inline void swap(impl &that) { + if (this->valueless_by_exception() && that.valueless_by_exception()) { + // do nothing. + } else if (this->index() == that.index()) { + visitation::alt::visit_alt_at(this->index(), +#ifdef MPARK_GENERIC_LAMBDAS + [](auto &this_alt, auto &that_alt) { + using std::swap; + swap(this_alt.value, + that_alt.value); + } +#else + swapper{} +#endif + , + *this, + that); + } else { + impl *lhs = this; + impl *rhs = lib::addressof(that); + if (lhs->move_nothrow() && !rhs->move_nothrow()) { + std::swap(lhs, rhs); + } + impl tmp(lib::move(*rhs)); +#ifdef MPARK_EXCEPTIONS + // EXTENSION: When the move construction of `lhs` into `rhs` throws + // and `tmp` is nothrow move constructible then we move `tmp` back + // into `rhs` and provide the strong exception safety guarantee. + try { + this->generic_construct(*rhs, lib::move(*lhs)); + } catch (...) { + if (tmp.move_nothrow()) { + this->generic_construct(*rhs, lib::move(tmp)); + } + throw; + } +#else + this->generic_construct(*rhs, lib::move(*lhs)); +#endif + this->generic_construct(*lhs, lib::move(tmp)); + } + } + + private: +#ifndef MPARK_GENERIC_LAMBDAS + struct swapper { + template + inline void operator()(ThisAlt &this_alt, ThatAlt &that_alt) const { + using std::swap; + swap(this_alt.value, that_alt.value); + } + }; +#endif + + inline constexpr bool move_nothrow() const { + return this->valueless_by_exception() || + lib::array{ + {std::is_nothrow_move_constructible::value...} + }[this->index()]; + } + }; + +#undef MPARK_INHERITING_CTOR + + template + struct overload_leaf { + using F = lib::size_constant (*)(T); + operator F() const { return nullptr; } + }; + + template + struct overload_impl { + private: + template + struct impl; + + template + struct impl> : overload_leaf... {}; + + public: + using type = impl>; + }; + + template + using overload = typename overload_impl::type; + + template + using best_match = lib::invoke_result_t, T &&>; + + template + struct is_in_place_index : std::false_type {}; + + template + struct is_in_place_index> : std::true_type {}; + + template + struct is_in_place_type : std::false_type {}; + + template + struct is_in_place_type> : std::true_type {}; + + } // detail + + template + class variant { + static_assert(0 < sizeof...(Ts), + "variant must consist of at least one alternative."); + + static_assert(lib::all::value...>::value, + "variant can not have an array type as an alternative."); + + static_assert(lib::all::value...>::value, + "variant can not have a reference type as an alternative."); + + static_assert(lib::all::value...>::value, + "variant can not have a void type as an alternative."); + + public: + template < + typename Front = lib::type_pack_element_t<0, Ts...>, + lib::enable_if_t::value, int> = 0> + inline constexpr variant() noexcept( + std::is_nothrow_default_constructible::value) + : impl_(in_place_index_t<0>{}) {} + + variant(const variant &) = default; + variant(variant &&) = default; + + template < + typename Arg, + typename Decayed = lib::decay_t, + lib::enable_if_t::value, int> = 0, + lib::enable_if_t::value, int> = 0, + lib::enable_if_t::value, int> = 0, + std::size_t I = detail::best_match::value, + typename T = lib::type_pack_element_t, + lib::enable_if_t::value, int> = 0> + inline constexpr variant(Arg &&arg) noexcept( + std::is_nothrow_constructible::value) + : impl_(in_place_index_t{}, lib::forward(arg)) {} + + template < + std::size_t I, + typename... Args, + typename T = lib::type_pack_element_t, + lib::enable_if_t::value, int> = 0> + inline explicit constexpr variant( + in_place_index_t, + Args &&... args) noexcept(std::is_nothrow_constructible::value) + : impl_(in_place_index_t{}, lib::forward(args)...) {} + + template < + std::size_t I, + typename Up, + typename... Args, + typename T = lib::type_pack_element_t, + lib::enable_if_t &, + Args...>::value, + int> = 0> + inline explicit constexpr variant( + in_place_index_t, + std::initializer_list il, + Args &&... args) noexcept(std:: + is_nothrow_constructible< + T, + std::initializer_list &, + Args...>::value) + : impl_(in_place_index_t{}, il, lib::forward(args)...) {} + + template < + typename T, + typename... Args, + std::size_t I = detail::find_index_sfinae::value, + lib::enable_if_t::value, int> = 0> + inline explicit constexpr variant( + in_place_type_t, + Args &&... args) noexcept(std::is_nothrow_constructible::value) + : impl_(in_place_index_t{}, lib::forward(args)...) {} + + template < + typename T, + typename Up, + typename... Args, + std::size_t I = detail::find_index_sfinae::value, + lib::enable_if_t &, + Args...>::value, + int> = 0> + inline explicit constexpr variant( + in_place_type_t, + std::initializer_list il, + Args &&... args) noexcept(std:: + is_nothrow_constructible< + T, + std::initializer_list &, + Args...>::value) + : impl_(in_place_index_t{}, il, lib::forward(args)...) {} + + ~variant() = default; + + variant &operator=(const variant &) = default; + variant &operator=(variant &&) = default; + + template , variant>::value, + int> = 0, + std::size_t I = detail::best_match::value, + typename T = lib::type_pack_element_t, + lib::enable_if_t<(std::is_assignable::value && + std::is_constructible::value), + int> = 0> + inline variant &operator=(Arg &&arg) noexcept( + (std::is_nothrow_assignable::value && + std::is_nothrow_constructible::value)) { + impl_.template assign(lib::forward(arg)); + return *this; + } + + template < + std::size_t I, + typename... Args, + typename T = lib::type_pack_element_t, + lib::enable_if_t::value, int> = 0> + inline T &emplace(Args &&... args) { + return impl_.template emplace(lib::forward(args)...); + } + + template < + std::size_t I, + typename Up, + typename... Args, + typename T = lib::type_pack_element_t, + lib::enable_if_t &, + Args...>::value, + int> = 0> + inline T &emplace(std::initializer_list il, Args &&... args) { + return impl_.template emplace(il, lib::forward(args)...); + } + + template < + typename T, + typename... Args, + std::size_t I = detail::find_index_sfinae::value, + lib::enable_if_t::value, int> = 0> + inline T &emplace(Args &&... args) { + return impl_.template emplace(lib::forward(args)...); + } + + template < + typename T, + typename Up, + typename... Args, + std::size_t I = detail::find_index_sfinae::value, + lib::enable_if_t &, + Args...>::value, + int> = 0> + inline T &emplace(std::initializer_list il, Args &&... args) { + return impl_.template emplace(il, lib::forward(args)...); + } + + inline constexpr bool valueless_by_exception() const noexcept { + return impl_.valueless_by_exception(); + } + + inline constexpr std::size_t index() const noexcept { + return impl_.index(); + } + + template , + Dummy>::value && + lib::dependent_type, + Dummy>::value)...>::value, + int> = 0> + inline void swap(variant &that) noexcept( + lib::all<(std::is_nothrow_move_constructible::value && + lib::is_nothrow_swappable::value)...>::value) { + impl_.swap(that.impl_); + } + + private: + detail::impl impl_; + + friend struct detail::access::variant; + friend struct detail::visitation::variant; + }; + + template + inline constexpr bool holds_alternative(const variant &v) noexcept { + return v.index() == I; + } + + template + inline constexpr bool holds_alternative(const variant &v) noexcept { + return holds_alternative::value>(v); + } + + namespace detail { + template + struct generic_get_impl { + constexpr generic_get_impl(int) noexcept {} + + constexpr AUTO_REFREF operator()(V &&v) const + AUTO_REFREF_RETURN( + access::variant::get_alt(lib::forward(v)).value) + }; + + template + inline constexpr AUTO_REFREF generic_get(V &&v) + AUTO_REFREF_RETURN(generic_get_impl( + holds_alternative(v) ? 0 : (throw_bad_variant_access(), 0))( + lib::forward(v))) + } // namespace detail + + template + inline constexpr variant_alternative_t> &get( + variant &v) { + return detail::generic_get(v); + } + + template + inline constexpr variant_alternative_t> &&get( + variant &&v) { + return detail::generic_get(lib::move(v)); + } + + template + inline constexpr const variant_alternative_t> &get( + const variant &v) { + return detail::generic_get(v); + } + + template + inline constexpr const variant_alternative_t> &&get( + const variant &&v) { + return detail::generic_get(lib::move(v)); + } + + template + inline constexpr T &get(variant &v) { + return get::value>(v); + } + + template + inline constexpr T &&get(variant &&v) { + return get::value>(lib::move(v)); + } + + template + inline constexpr const T &get(const variant &v) { + return get::value>(v); + } + + template + inline constexpr const T &&get(const variant &&v) { + return get::value>(lib::move(v)); + } + + namespace detail { + + template + inline constexpr /* auto * */ AUTO generic_get_if(V *v) noexcept + AUTO_RETURN(v && holds_alternative(*v) + ? lib::addressof(access::variant::get_alt(*v).value) + : nullptr) + + } // namespace detail + + template + inline constexpr lib::add_pointer_t>> + get_if(variant *v) noexcept { + return detail::generic_get_if(v); + } + + template + inline constexpr lib::add_pointer_t< + const variant_alternative_t>> + get_if(const variant *v) noexcept { + return detail::generic_get_if(v); + } + + template + inline constexpr lib::add_pointer_t + get_if(variant *v) noexcept { + return get_if::value>(v); + } + + template + inline constexpr lib::add_pointer_t + get_if(const variant *v) noexcept { + return get_if::value>(v); + } + + namespace detail { + template + struct convert_to_bool { + template + inline constexpr bool operator()(Lhs &&lhs, Rhs &&rhs) const { + static_assert(std::is_convertible, + bool>::value, + "relational operators must return a type" + " implicitly convertible to bool"); + return lib::invoke( + RelOp{}, lib::forward(lhs), lib::forward(rhs)); + } + }; + } // namespace detail + + template + inline constexpr bool operator==(const variant &lhs, + const variant &rhs) { + using detail::visitation::variant; + using equal_to = detail::convert_to_bool; +#ifdef MPARK_CPP14_CONSTEXPR + if (lhs.index() != rhs.index()) return false; + if (lhs.valueless_by_exception()) return true; + return variant::visit_value_at(lhs.index(), equal_to{}, lhs, rhs); +#else + return lhs.index() == rhs.index() && + (lhs.valueless_by_exception() || + variant::visit_value_at(lhs.index(), equal_to{}, lhs, rhs)); +#endif + } + + template + inline constexpr bool operator!=(const variant &lhs, + const variant &rhs) { + using detail::visitation::variant; + using not_equal_to = detail::convert_to_bool; +#ifdef MPARK_CPP14_CONSTEXPR + if (lhs.index() != rhs.index()) return true; + if (lhs.valueless_by_exception()) return false; + return variant::visit_value_at(lhs.index(), not_equal_to{}, lhs, rhs); +#else + return lhs.index() != rhs.index() || + (!lhs.valueless_by_exception() && + variant::visit_value_at(lhs.index(), not_equal_to{}, lhs, rhs)); +#endif + } + + template + inline constexpr bool operator<(const variant &lhs, + const variant &rhs) { + using detail::visitation::variant; + using less = detail::convert_to_bool; +#ifdef MPARK_CPP14_CONSTEXPR + if (rhs.valueless_by_exception()) return false; + if (lhs.valueless_by_exception()) return true; + if (lhs.index() < rhs.index()) return true; + if (lhs.index() > rhs.index()) return false; + return variant::visit_value_at(lhs.index(), less{}, lhs, rhs); +#else + return !rhs.valueless_by_exception() && + (lhs.valueless_by_exception() || lhs.index() < rhs.index() || + (lhs.index() == rhs.index() && + variant::visit_value_at(lhs.index(), less{}, lhs, rhs))); +#endif + } + + template + inline constexpr bool operator>(const variant &lhs, + const variant &rhs) { + using detail::visitation::variant; + using greater = detail::convert_to_bool; +#ifdef MPARK_CPP14_CONSTEXPR + if (lhs.valueless_by_exception()) return false; + if (rhs.valueless_by_exception()) return true; + if (lhs.index() > rhs.index()) return true; + if (lhs.index() < rhs.index()) return false; + return variant::visit_value_at(lhs.index(), greater{}, lhs, rhs); +#else + return !lhs.valueless_by_exception() && + (rhs.valueless_by_exception() || lhs.index() > rhs.index() || + (lhs.index() == rhs.index() && + variant::visit_value_at(lhs.index(), greater{}, lhs, rhs))); +#endif + } + + template + inline constexpr bool operator<=(const variant &lhs, + const variant &rhs) { + using detail::visitation::variant; + using less_equal = detail::convert_to_bool; +#ifdef MPARK_CPP14_CONSTEXPR + if (lhs.valueless_by_exception()) return true; + if (rhs.valueless_by_exception()) return false; + if (lhs.index() < rhs.index()) return true; + if (lhs.index() > rhs.index()) return false; + return variant::visit_value_at(lhs.index(), less_equal{}, lhs, rhs); +#else + return lhs.valueless_by_exception() || + (!rhs.valueless_by_exception() && + (lhs.index() < rhs.index() || + (lhs.index() == rhs.index() && + variant::visit_value_at(lhs.index(), less_equal{}, lhs, rhs)))); +#endif + } + + template + inline constexpr bool operator>=(const variant &lhs, + const variant &rhs) { + using detail::visitation::variant; + using greater_equal = detail::convert_to_bool; +#ifdef MPARK_CPP14_CONSTEXPR + if (rhs.valueless_by_exception()) return true; + if (lhs.valueless_by_exception()) return false; + if (lhs.index() > rhs.index()) return true; + if (lhs.index() < rhs.index()) return false; + return variant::visit_value_at(lhs.index(), greater_equal{}, lhs, rhs); +#else + return rhs.valueless_by_exception() || + (!lhs.valueless_by_exception() && + (lhs.index() > rhs.index() || + (lhs.index() == rhs.index() && + variant::visit_value_at( + lhs.index(), greater_equal{}, lhs, rhs)))); +#endif + } + + struct monostate {}; + + inline constexpr bool operator<(monostate, monostate) noexcept { + return false; + } + + inline constexpr bool operator>(monostate, monostate) noexcept { + return false; + } + + inline constexpr bool operator<=(monostate, monostate) noexcept { + return true; + } + + inline constexpr bool operator>=(monostate, monostate) noexcept { + return true; + } + + inline constexpr bool operator==(monostate, monostate) noexcept { + return true; + } + + inline constexpr bool operator!=(monostate, monostate) noexcept { + return false; + } + +#ifdef MPARK_CPP14_CONSTEXPR + namespace detail { + + inline constexpr bool all(std::initializer_list bs) { + for (bool b : bs) { + if (!b) { + return false; + } + } + return true; + } + + } // namespace detail + + template + inline constexpr decltype(auto) visit(Visitor &&visitor, Vs &&... vs) { + return (detail::all({!vs.valueless_by_exception()...}) + ? (void)0 + : throw_bad_variant_access()), + detail::visitation::variant::visit_value( + lib::forward(visitor), lib::forward(vs)...); + } +#else + namespace detail { + + template + inline constexpr bool all_impl(const lib::array &bs, + std::size_t idx) { + return idx >= N || (bs[idx] && all_impl(bs, idx + 1)); + } + + template + inline constexpr bool all(const lib::array &bs) { + return all_impl(bs, 0); + } + + } // namespace detail + + template + inline constexpr DECLTYPE_AUTO visit(Visitor &&visitor, Vs &&... vs) + DECLTYPE_AUTO_RETURN( + (detail::all( + lib::array{{!vs.valueless_by_exception()...}}) + ? (void)0 + : throw_bad_variant_access()), + detail::visitation::variant::visit_value(lib::forward(visitor), + lib::forward(vs)...)) +#endif + + template + inline auto swap(variant &lhs, + variant &rhs) noexcept(noexcept(lhs.swap(rhs))) + -> decltype(lhs.swap(rhs)) { + lhs.swap(rhs); + } + + namespace detail { + + template + using enabled_type = T; + + namespace hash { + + template + constexpr bool meets_requirements() noexcept { + return std::is_copy_constructible::value && + std::is_move_constructible::value && + lib::is_invocable_r::value; + } + + template + constexpr bool is_enabled() noexcept { + using H = std::hash; + return meets_requirements() && + std::is_default_constructible::value && + std::is_copy_assignable::value && + std::is_move_assignable::value; + } + + } // namespace hash + + } // namespace detail + +#undef AUTO +#undef AUTO_RETURN + +#undef AUTO_REFREF +#undef AUTO_REFREF_RETURN + +#undef DECLTYPE_AUTO +#undef DECLTYPE_AUTO_RETURN + +} // namespace mpark + +namespace std { + + template + struct hash, + mpark::lib::enable_if_t>()...>::value>>> { + using argument_type = mpark::variant; + using result_type = std::size_t; + + inline result_type operator()(const argument_type &v) const { + using mpark::detail::visitation::variant; + std::size_t result = + v.valueless_by_exception() + ? 299792458 // Random value chosen by the universe upon creation + : variant::visit_alt( +#ifdef MPARK_GENERIC_LAMBDAS + [](const auto &alt) { + using alt_type = mpark::lib::decay_t; + using value_type = mpark::lib::remove_const_t< + typename alt_type::value_type>; + return hash{}(alt.value); + } +#else + hasher{} +#endif + , + v); + return hash_combine(result, hash{}(v.index())); + } + + private: +#ifndef MPARK_GENERIC_LAMBDAS + struct hasher { + template + inline std::size_t operator()(const Alt &alt) const { + using alt_type = mpark::lib::decay_t; + using value_type = + mpark::lib::remove_const_t; + return hash{}(alt.value); + } + }; +#endif + + static std::size_t hash_combine(std::size_t lhs, std::size_t rhs) { + return lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); + } + }; + + template <> + struct hash { + using argument_type = mpark::monostate; + using result_type = std::size_t; + + inline result_type operator()(const argument_type &) const noexcept { + return 66740831; // return a fundamentally attractive random value. + } + }; + +} // namespace std + +#endif // MPARK_VARIANT_HPP diff --git a/src/dpp/dave/array_view.h b/src/dpp/dave/array_view.h new file mode 100755 index 0000000000..f426e44a4c --- /dev/null +++ b/src/dpp/dave/array_view.h @@ -0,0 +1,43 @@ +#pragma once + +#include +#include + +namespace discord { +namespace dave { + +template +class ArrayView { +public: + ArrayView() = default; + ArrayView(T* data, size_t size) + : data_(data) + , size_(size) + { + } + + size_t size() const { return size_; } + T* data() const { return data_; } + + T* begin() const { return data_; } + T* end() const { return data_ + size_; } + +private: + T* data_ = nullptr; + size_t size_ = 0; +}; + +template +inline ArrayView MakeArrayView(T* data, size_t size) +{ + return ArrayView(data, size); +} + +template +inline ArrayView MakeArrayView(std::vector& data) +{ + return ArrayView(data.data(), data.size()); +} + +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/clock.h b/src/dpp/dave/clock.h new file mode 100755 index 0000000000..314d6777d9 --- /dev/null +++ b/src/dpp/dave/clock.h @@ -0,0 +1,24 @@ +#pragma once + +#include + +namespace discord { +namespace dave { + +class IClock { +public: + using BaseClock = std::chrono::steady_clock; + using TimePoint = BaseClock::time_point; + using Duration = BaseClock::duration; + + virtual ~IClock() = default; + virtual TimePoint Now() const = 0; +}; + +class Clock : public IClock { +public: + TimePoint Now() const override { return BaseClock::now(); } +}; + +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/codec_utils.cpp b/src/dpp/dave/codec_utils.cpp new file mode 100755 index 0000000000..fb62b32208 --- /dev/null +++ b/src/dpp/dave/codec_utils.cpp @@ -0,0 +1,435 @@ +#include "codec_utils.h" + +#include +#include +#include + +#include "logger.h" +#include "leb128.h" + +namespace discord { +namespace dave { +namespace codec_utils { + +UnencryptedFrameHeaderSize BytesCoveringH264PPS(const uint8_t* payload, + const uint64_t sizeRemaining) +{ + // the payload starts with three exponential golomb encoded values + // (first_mb_in_slice, sps_id, pps_id) + // the depacketizer needs the pps_id unencrypted + // and the payload has RBSP encoding that we need to work around + + constexpr uint8_t kEmulationPreventionByte = 0x03; + + uint64_t payloadBitIndex = 0; + auto zeroBitCount = 0; + auto parsedExpGolombValues = 0; + + while (payloadBitIndex < sizeRemaining * 8 && parsedExpGolombValues < 3) { + auto bitIndex = payloadBitIndex % 8; + auto byteIndex = payloadBitIndex / 8; + auto payloadByte = payload[byteIndex]; + + // if we're starting a new byte + // check if this is an emulation prevention byte + // which we skip over + if (bitIndex == 0) { + if (byteIndex >= 2 && payloadByte == kEmulationPreventionByte && + payload[byteIndex - 1] == 0 && payload[byteIndex - 2] == 0) { + payloadBitIndex += 8; + continue; + } + } + + if ((payloadByte & (1 << (7 - bitIndex))) == 0) { + // still in the run of leading zero bits + ++zeroBitCount; + ++payloadBitIndex; + + if (zeroBitCount >= 32) { + assert(false && "Unexpectedly large exponential golomb encoded value"); + return 0; + } + } + else { + // we hit a one + // skip forward the number of bits dictated by the leading number of zeroes + parsedExpGolombValues += 1; + payloadBitIndex += 1 + zeroBitCount; + zeroBitCount = 0; + } + } + + // return the number of bytes that covers the last exp golomb encoded value + return (payloadBitIndex / 8) + 1; +} + +const uint8_t kH26XNaluLongStartCode[] = {0, 0, 0, 1}; +constexpr uint8_t kH26XNaluShortStartSequenceSize = 3; + +using IndexStartCodeSizePair = std::pair; + +std::optional FindNextH26XNaluIndex(const uint8_t* buffer, + const size_t bufferSize, + const size_t searchStartIndex = 0) +{ + constexpr uint8_t kH26XStartCodeHighestPossibleValue = 1; + constexpr uint8_t kH26XStartCodeEndByteValue = 1; + constexpr uint8_t kH26XStartCodeLeadingBytesValue = 0; + + if (bufferSize < kH26XNaluShortStartSequenceSize) { + return std::nullopt; + } + + // look for NAL unit 3 or 4 byte start code + for (size_t i = searchStartIndex; i < bufferSize - kH26XNaluShortStartSequenceSize;) { + if (buffer[i + 2] > kH26XStartCodeHighestPossibleValue) { + // third byte is not 0 or 1, can't be a start code + i += kH26XNaluShortStartSequenceSize; + } + else if (buffer[i + 2] == kH26XStartCodeEndByteValue) { + // third byte matches the start code end byte, might be a start code sequence + if (buffer[i + 1] == kH26XStartCodeLeadingBytesValue && + buffer[i] == kH26XStartCodeLeadingBytesValue) { + // confirmed start sequence {0, 0, 1} + auto nalUnitStartIndex = i + kH26XNaluShortStartSequenceSize; + + if (i >= 1 && buffer[i - 1] == kH26XStartCodeLeadingBytesValue) { + // 4 byte start code + return std::optional({nalUnitStartIndex, 4}); + } + else { + // 3 byte start code + return std::optional({nalUnitStartIndex, 3}); + } + } + + i += kH26XNaluShortStartSequenceSize; + } + else { + // third byte is 0, might be a four byte start code + ++i; + } + } + + return std::nullopt; +} + +bool ProcessFrameOpus(OutboundFrameProcessor& processor, ArrayView frame) +{ + processor.AddEncryptedBytes(frame.data(), frame.size()); + return true; +} + +bool ProcessFrameVp8(OutboundFrameProcessor& processor, ArrayView frame) +{ + constexpr uint8_t kVP8KeyFrameUnencryptedBytes = 10; + constexpr uint8_t kVP8DeltaFrameUnencryptedBytes = 1; + + // parse the VP8 payload header to determine if it's a key frame + // https://datatracker.ietf.org/doc/html/rfc7741#section-4.3 + + // 0 1 2 3 4 5 6 7 + // +-+-+-+-+-+-+-+-+ + // |Size0|H| VER |P| + // +-+-+-+-+-+-+-+-+ + // P is an inverse key frame flag + + // if this is a key frame the depacketizer will read 10 bytes into the payload header + // if this is a delta frame the depacketizer only needs the first byte of the payload + // header (since that's where the key frame flag is) + + size_t unencryptedHeaderBytes = 0; + if ((frame.data()[0] & 0x01) == 0) { + unencryptedHeaderBytes = kVP8KeyFrameUnencryptedBytes; + } + else { + unencryptedHeaderBytes = kVP8DeltaFrameUnencryptedBytes; + } + + processor.AddUnencryptedBytes(frame.data(), unencryptedHeaderBytes); + processor.AddEncryptedBytes(frame.data() + unencryptedHeaderBytes, + frame.size() - unencryptedHeaderBytes); + return true; +} + +bool ProcessFrameVp9(OutboundFrameProcessor& processor, ArrayView frame) +{ + // payload descriptor is unencrypted in each packet + // and includes all information the depacketizer needs + processor.AddEncryptedBytes(frame.data(), frame.size()); + return true; +} + +bool ProcessFrameH264(OutboundFrameProcessor& processor, ArrayView frame) +{ + // minimize the amount of unencrypted header data for H264 depending on the NAL unit + // type from WebRTC, see: src/modules/rtp_rtcp/source/rtp_format_h264.cc + // src/common_video/h264/h264_common.cc + // src/modules/rtp_rtcp/source/video_rtp_depacketizer_h264.cc + + // constexpr uint8_t kH264SBit = 0x80; + constexpr uint8_t kH264NalHeaderTypeMask = 0x1F; + constexpr uint8_t kH264NalTypeSlice = 1; + constexpr uint8_t kH264NalTypeIdr = 5; + constexpr uint8_t kH264NalUnitHeaderSize = 1; + + // this frame can be packetized as a STAP-A or a FU-A + // so we need to look at the first NAL units to determine how many bytes + // the packetizer/depacketizer will need into the payload + if (frame.size() < kH26XNaluShortStartSequenceSize + kH264NalUnitHeaderSize) { + assert(false && "H264 frame is too small to contain a NAL unit"); + DISCORD_LOG(LS_WARNING) << "H264 frame is too small to contain a NAL unit"; + return false; + } + + auto naluIndexPair = FindNextH26XNaluIndex(frame.data(), frame.size()); + while (naluIndexPair && naluIndexPair->first < frame.size() - 1) { + auto [nalUnitStartIndex, startCodeSize] = *naluIndexPair; + + auto nalType = frame.data()[nalUnitStartIndex] & kH264NalHeaderTypeMask; + + // copy the start code and then the NAL unit + + // Because WebRTC will convert them all start codes to 4-byte on the receiver side + // always write a long start code and then the NAL unit + processor.AddUnencryptedBytes(kH26XNaluLongStartCode, sizeof(kH26XNaluLongStartCode)); + + auto nextNaluIndexPair = + FindNextH26XNaluIndex(frame.data(), frame.size(), nalUnitStartIndex); + auto nextNaluStart = nextNaluIndexPair.has_value() + ? nextNaluIndexPair->first - nextNaluIndexPair->second + : frame.size(); + + if (nalType == kH264NalTypeSlice || nalType == kH264NalTypeIdr) { + // once we've hit a slice or an IDR + // we just need to cover getting to the PPS ID + auto nalUnitPayloadStart = nalUnitStartIndex + kH264NalUnitHeaderSize; + auto nalUnitPPSBytes = BytesCoveringH264PPS(frame.data() + nalUnitPayloadStart, + frame.size() - nalUnitPayloadStart); + + processor.AddUnencryptedBytes(frame.data() + nalUnitStartIndex, + kH264NalUnitHeaderSize + nalUnitPPSBytes); + processor.AddEncryptedBytes( + frame.data() + nalUnitStartIndex + kH264NalUnitHeaderSize + nalUnitPPSBytes, + nextNaluStart - nalUnitStartIndex - kH264NalUnitHeaderSize - nalUnitPPSBytes); + } + else { + // copy the whole NAL unit + processor.AddUnencryptedBytes(frame.data() + nalUnitStartIndex, + nextNaluStart - nalUnitStartIndex); + } + + naluIndexPair = nextNaluIndexPair; + } + + return true; +} + +bool ProcessFrameH265(OutboundFrameProcessor& processor, ArrayView frame) +{ + // minimize the amount of unencrypted header data for H265 depending on the NAL unit + // type from WebRTC, see: src/modules/rtp_rtcp/source/rtp_format_h265.cc + // src/common_video/h265/h265_common.cc + // src/modules/rtp_rtcp/source/video_rtp_depacketizer_h265.cc + + constexpr uint8_t kH265NalHeaderTypeMask = 0x7E; + constexpr uint8_t kH265NalTypeVclCutoff = 32; + constexpr uint8_t kH265NalUnitHeaderSize = 2; + + // this frame can be packetized as a STAP-A or a FU-A + // so we need to look at the first NAL units to determine how many bytes + // the packetizer/depacketizer will need into the payload + if (frame.size() < kH26XNaluShortStartSequenceSize + kH265NalUnitHeaderSize) { + assert(false && "H265 frame is too small to contain a NAL unit"); + DISCORD_LOG(LS_WARNING) << "H265 frame is too small to contain a NAL unit"; + return false; + } + + // look for NAL unit 3 or 4 byte start code + auto naluIndexPair = FindNextH26XNaluIndex(frame.data(), frame.size()); + while (naluIndexPair && naluIndexPair->first < frame.size() - 1) { + auto [nalUnitStartIndex, startCodeSize] = *naluIndexPair; + + uint8_t nalType = (frame.data()[nalUnitStartIndex] & kH265NalHeaderTypeMask) >> 1; + + // copy the start code and then the NAL unit + + // Because WebRTC will convert them all start codes to 4-byte on the receiver side + // always write a long start code and then the NAL unit + processor.AddUnencryptedBytes(kH26XNaluLongStartCode, sizeof(kH26XNaluLongStartCode)); + + auto nextNaluIndexPair = + FindNextH26XNaluIndex(frame.data(), frame.size(), nalUnitStartIndex); + auto nextNaluStart = nextNaluIndexPair.has_value() + ? nextNaluIndexPair->first - nextNaluIndexPair->second + : frame.size(); + + if (nalType < kH265NalTypeVclCutoff) { + // found a VCL NAL, encrypt the payload only + processor.AddUnencryptedBytes(frame.data() + nalUnitStartIndex, kH265NalUnitHeaderSize); + processor.AddEncryptedBytes(frame.data() + nalUnitStartIndex + kH265NalUnitHeaderSize, + nextNaluStart - nalUnitStartIndex - kH265NalUnitHeaderSize); + } + else { + // copy the whole NAL unit + processor.AddUnencryptedBytes(frame.data() + nalUnitStartIndex, + nextNaluStart - nalUnitStartIndex); + } + + naluIndexPair = nextNaluIndexPair; + } + + return true; +} + +bool ProcessFrameAv1(OutboundFrameProcessor& processor, ArrayView frame) +{ + constexpr uint8_t kAv1ObuHeaderHasExtensionMask = 0b0'0000'100; + constexpr uint8_t kAv1ObuHeaderHasSizeMask = 0b0'0000'010; + constexpr uint8_t kAv1ObuHeaderTypeMask = 0b0'1111'000; + constexpr uint8_t kObuTypeTemporalDelimiter = 2; + constexpr uint8_t kObuTypeTileList = 8; + constexpr uint8_t kObuTypePadding = 15; + constexpr uint8_t kObuExtensionSizeBytes = 1; + + size_t i = 0; + while (i < frame.size()) { + // Read the OBU header. + size_t obuHeaderIndex = i; + uint8_t obuHeader = frame.data()[obuHeaderIndex]; + i += sizeof(obuHeader); + + bool obuHasExtension = obuHeader & kAv1ObuHeaderHasExtensionMask; + bool obuHasSize = obuHeader & kAv1ObuHeaderHasSizeMask; + int obuType = (obuHeader & kAv1ObuHeaderTypeMask) >> 3; + + if (obuHasExtension) { + // Skip extension byte + i += kObuExtensionSizeBytes; + } + + if (i >= frame.size()) { + // Malformed frame + assert(false && "Malformed AV1 frame: header overflows frame"); + DISCORD_LOG(LS_WARNING) << "Malformed AV1 frame: header overflows frame"; + return false; + } + + size_t obuPayloadSize = 0; + if (obuHasSize) { + // Read payload size + const uint8_t* start = frame.data() + i; + const uint8_t* ptr = start; + obuPayloadSize = ReadLeb128(ptr, frame.end()); + if (!ptr) { + // Malformed frame + assert(false && "Malformed AV1 frame: invalid LEB128 size"); + DISCORD_LOG(LS_WARNING) << "Malformed AV1 frame: invalid LEB128 size"; + return false; + } + i += ptr - start; + } + else { + // If the size is not present, the OBU extends to the end of the frame. + obuPayloadSize = frame.size() - i; + } + + const auto obuPayloadIndex = i; + + if (i + obuPayloadSize > frame.size()) { + // Malformed frame + assert(false && "Malformed AV1 frame: payload overflows frame"); + DISCORD_LOG(LS_WARNING) << "Malformed AV1 frame: payload overflows frame"; + return false; + } + + i += obuPayloadSize; + + // We only copy the OBUs that will not get dropped by the packetizer + if (obuType != kObuTypeTemporalDelimiter && obuType != kObuTypeTileList && + obuType != kObuTypePadding) { + // if this is the last OBU, we may need to flip the "has size" bit + // which allows us to append necessary protocol data to the frame + bool rewrittenWithoutSize = false; + + if (i == frame.size() && obuHasSize) { + // Flip the "has size" bit + obuHeader &= ~kAv1ObuHeaderHasSizeMask; + rewrittenWithoutSize = true; + } + + // write the OBU header unencrypted + processor.AddUnencryptedBytes(&obuHeader, sizeof(obuHeader)); + if (obuHasExtension) { + // write the extension byte unencrypted + processor.AddUnencryptedBytes(frame.data() + obuHeaderIndex + sizeof(obuHeader), + kObuExtensionSizeBytes); + } + + // write the OBU payload size unencrypted if it was present and we didn't rewrite + // without it + if (obuHasSize && !rewrittenWithoutSize) { + // The AMD AV1 encoder may pad LEB128 encoded sizes with a zero byte which the + // webrtc packetizer removes. To prevent the packetizer from changing the frame, + // we sanitize the size by re-writing it ourselves + uint8_t leb128Buffer[Leb128MaxSize]; + size_t additionalBytesToWrite = WriteLeb128(obuPayloadSize, leb128Buffer); + processor.AddUnencryptedBytes(leb128Buffer, additionalBytesToWrite); + } + + // add the OBU payload, encrypted + processor.AddEncryptedBytes(frame.data() + obuPayloadIndex, obuPayloadSize); + } + } + + return true; +} + +bool ValidateEncryptedFrame(OutboundFrameProcessor& processor, ArrayView frame) +{ + auto codec = processor.GetCodec(); + if (codec != Codec::H264 && codec != Codec::H265) { + return true; + } + + static_assert(kH26XNaluShortStartSequenceSize - 1 >= 0, "Padding will overflow!"); + constexpr size_t Padding = kH26XNaluShortStartSequenceSize - 1; + + const auto& unencryptedRanges = processor.GetUnencryptedRanges(); + + // H264 and H265 ciphertexts cannot contain a 3 or 4 byte start code {0, 0, 1} + // otherwise the packetizer gets confused + // and the frame we get on the decryption side will be shifted and fail to decrypt + size_t encryptedSectionStart = 0; + for (auto& range : unencryptedRanges) { + if (encryptedSectionStart == range.offset) { + encryptedSectionStart += range.size; + continue; + } + + auto start = encryptedSectionStart - std::min(encryptedSectionStart, size_t{Padding}); + auto end = std::min(range.offset + Padding, frame.size()); + if (FindNextH26XNaluIndex(frame.data() + start, end - start)) { + return false; + } + + encryptedSectionStart = range.offset + range.size; + } + + if (encryptedSectionStart == frame.size()) { + return true; + } + + auto start = encryptedSectionStart - std::min(encryptedSectionStart, size_t{Padding}); + auto end = frame.size(); + if (FindNextH26XNaluIndex(frame.data() + start, end - start)) { + return false; + } + + return true; +} + +} // namespace codec_utils +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/codec_utils.h b/src/dpp/dave/codec_utils.h new file mode 100755 index 0000000000..56ce14766c --- /dev/null +++ b/src/dpp/dave/codec_utils.h @@ -0,0 +1,22 @@ +#pragma once + +#include "common.h" +#include "frame_processors.h" +#include "array_view.h" + +namespace discord { +namespace dave { +namespace codec_utils { + +bool ProcessFrameOpus(OutboundFrameProcessor& processor, ArrayView frame); +bool ProcessFrameVp8(OutboundFrameProcessor& processor, ArrayView frame); +bool ProcessFrameVp9(OutboundFrameProcessor& processor, ArrayView frame); +bool ProcessFrameH264(OutboundFrameProcessor& processor, ArrayView frame); +bool ProcessFrameH265(OutboundFrameProcessor& processor, ArrayView frame); +bool ProcessFrameAv1(OutboundFrameProcessor& processor, ArrayView frame); + +bool ValidateEncryptedFrame(OutboundFrameProcessor& processor, ArrayView frame); + +} // namespace codec_utils +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/common.h b/src/dpp/dave/common.h new file mode 100755 index 0000000000..9a46ed9b39 --- /dev/null +++ b/src/dpp/dave/common.h @@ -0,0 +1,91 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "version.h" + +namespace mlspp::bytes_ns { +struct bytes; +}; + +namespace discord { +namespace dave { + +using UnencryptedFrameHeaderSize = uint16_t; +using TruncatedSyncNonce = uint32_t; +using MagicMarker = uint16_t; +using EncryptionKey = ::mlspp::bytes_ns::bytes; +using TransitionId = uint16_t; +using SupplementalBytesSize = uint8_t; + +enum MediaType : uint8_t { Audio, Video }; +enum Codec : uint8_t { Unknown, Opus, VP8, VP9, H264, H265, AV1 }; + +// Returned in std::variant when a message is hard-rejected and should trigger a reset +struct failed_t {}; + +// Returned in std::variant when a message is soft-rejected and should not trigger a reset +struct ignored_t {}; + +// Map of ID-key pairs. +// In ProcessCommit, this lists IDs whose keys have been added, changed, or removed; +// an empty value value means a key was removed. +using RosterMap = std::map>; + +// Return type for functions producing RosterMap or hard or soft failures +using RosterVariant = std::variant; + +constexpr MagicMarker kMarkerBytes = 0xFAFA; + +// Layout constants +constexpr size_t kAesGcm128KeyBytes = 16; +constexpr size_t kAesGcm128NonceBytes = 12; +constexpr size_t kAesGcm128TruncatedSyncNonceBytes = 4; +constexpr size_t kAesGcm128TruncatedSyncNonceOffset = + kAesGcm128NonceBytes - kAesGcm128TruncatedSyncNonceBytes; +constexpr size_t kAesGcm128TruncatedTagBytes = 8; +constexpr size_t kRatchetGenerationBytes = 1; +constexpr size_t kRatchetGenerationShiftBits = + 8 * (kAesGcm128TruncatedSyncNonceBytes - kRatchetGenerationBytes); +constexpr size_t kSupplementalBytes = + kAesGcm128TruncatedTagBytes + sizeof(SupplementalBytesSize) + sizeof(MagicMarker); +constexpr size_t kTransformPaddingBytes = 64; + +// Timing constants +constexpr auto kDefaultTransitionDuration = std::chrono::seconds(10); +constexpr auto kCryptorExpiry = std::chrono::seconds(10); + +// Behavior constants +constexpr auto kInitTransitionId = 0; +constexpr auto kDisabledVersion = 0; +constexpr auto kMaxGenerationGap = 250; +constexpr auto kMaxMissingNonces = 1000; +constexpr auto kGenerationWrap = 1 << (8 * kRatchetGenerationBytes); +constexpr auto kMaxFramesPerSecond = 50 + 2 * 60; // 50 audio frames + 2 * 60fps video streams +constexpr std::array kOpusSilencePacket = {0xF8, 0xFF, 0xFE}; + +// Utility routine for variant return types +template +inline std::optional GetOptional(V&& variant) +{ + if (auto map = std::get_if(&variant)) { + if constexpr (std::is_rvalue_reference_v) { + return std::move(*map); + } + else { + return *map; + } + } + else { + return std::nullopt; + } +} + +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/cryptor.cpp b/src/dpp/dave/cryptor.cpp new file mode 100755 index 0000000000..4adba712f1 --- /dev/null +++ b/src/dpp/dave/cryptor.cpp @@ -0,0 +1,14 @@ +#include "cryptor.h" +#include "openssl_cryptor.h" + +namespace discord { +namespace dave { + +std::unique_ptr CreateCryptor(const EncryptionKey& encryptionKey) +{ + auto cryptor = std::make_unique(encryptionKey); + return cryptor->IsValid() ? std::move(cryptor) : nullptr; +} + +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/cryptor.h b/src/dpp/dave/cryptor.h new file mode 100755 index 0000000000..8e38b0e435 --- /dev/null +++ b/src/dpp/dave/cryptor.h @@ -0,0 +1,30 @@ +#pragma once + +#include + +#include "common.h" +#include "array_view.h" + +namespace discord { +namespace dave { + +class ICryptor { +public: + virtual ~ICryptor() = default; + + virtual bool Encrypt(ArrayView ciphertextBufferOut, + ArrayView plaintextBuffer, + ArrayView nonceBuffer, + ArrayView additionalData, + ArrayView tagBufferOut) = 0; + virtual bool Decrypt(ArrayView plaintextBufferOut, + ArrayView ciphertextBuffer, + ArrayView tagBuffer, + ArrayView nonceBuffer, + ArrayView additionalData) = 0; +}; + +std::unique_ptr CreateCryptor(const EncryptionKey& encryptionKey); + +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/cryptor_manager.cpp b/src/dpp/dave/cryptor_manager.cpp new file mode 100755 index 0000000000..057fa93912 --- /dev/null +++ b/src/dpp/dave/cryptor_manager.cpp @@ -0,0 +1,184 @@ +#include "cryptor_manager.h" + +#include + +#include "key_ratchet.h" +#include "logger.h" + +#include + +using namespace std::chrono_literals; + +namespace discord { +namespace dave { + +KeyGeneration ComputeWrappedGeneration(KeyGeneration oldest, KeyGeneration generation) +{ + // Assume generation is greater than or equal to oldest, this may be wrong in a few cases but + // will be caught by the max generation gap check. + auto remainder = oldest % kGenerationWrap; + auto factor = oldest / kGenerationWrap + (generation < remainder ? 1 : 0); + return factor * kGenerationWrap + generation; +} + +BigNonce ComputeWrappedBigNonce(KeyGeneration generation, TruncatedSyncNonce nonce) +{ + // Remove the generation bits from the nonce + auto maskedNonce = nonce & ((1 << kRatchetGenerationShiftBits) - 1); + // Add the wrapped generation bits back in + return static_cast(generation) << kRatchetGenerationShiftBits | maskedNonce; +} + +CryptorManager::CryptorManager(const IClock& clock, std::unique_ptr keyRatchet) + : clock_(clock) + , keyRatchet_(std::move(keyRatchet)) + , ratchetCreation_(clock.Now()) + , ratchetExpiry_(TimePoint::max()) +{ +} + +bool CryptorManager::CanProcessNonce(KeyGeneration generation, TruncatedSyncNonce nonce) const +{ + if (!newestProcessedNonce_) { + return true; + } + + auto bigNonce = ComputeWrappedBigNonce(generation, nonce); + return bigNonce > *newestProcessedNonce_ || + std::find(missingNonces_.rbegin(), missingNonces_.rend(), bigNonce) != missingNonces_.rend(); +} + +ICryptor* CryptorManager::GetCryptor(KeyGeneration generation) +{ + CleanupExpiredCryptors(); + + if (generation < oldestGeneration_) { + DISCORD_LOG(LS_INFO) << "Received frame with old generation: " << generation + << ", oldest generation: " << oldestGeneration_; + return nullptr; + } + + if (generation > newestGeneration_ + kMaxGenerationGap) { + DISCORD_LOG(LS_INFO) << "Received frame with future generation: " << generation + << ", newest generation: " << newestGeneration_; + return nullptr; + } + + auto ratchetLifetimeSec = + std::chrono::duration_cast(clock_.Now() - ratchetCreation_).count(); + auto maxLifetimeFrames = kMaxFramesPerSecond * ratchetLifetimeSec; + auto maxLifetimeGenerations = maxLifetimeFrames >> kRatchetGenerationShiftBits; + if (generation > maxLifetimeGenerations) { + DISCORD_LOG(LS_INFO) << "Received frame with generation " << generation + << " beyond ratchet max lifetime generations: " + << maxLifetimeGenerations + << ", ratchet lifetime: " << ratchetLifetimeSec << "s"; + return nullptr; + } + + auto it = cryptors_.find(generation); + if (it == cryptors_.end()) { + // We don't have a cryptor for this generation, create one + std::tie(it, std::ignore) = cryptors_.emplace(generation, MakeExpiringCryptor(generation)); + } + + // Return a non-owning pointer to the cryptor + auto& [cryptor, expiry] = it->second; + return cryptor.get(); +} + +void CryptorManager::ReportCryptorSuccess(KeyGeneration generation, TruncatedSyncNonce nonce) +{ + auto bigNonce = ComputeWrappedBigNonce(generation, nonce); + + // Add any missing nonces to the queue + if (!newestProcessedNonce_) { + newestProcessedNonce_ = bigNonce; + } + else if (bigNonce > *newestProcessedNonce_) { + auto oldestMissingNonce = bigNonce > kMaxMissingNonces ? bigNonce - kMaxMissingNonces : 0; + + while (!missingNonces_.empty() && missingNonces_.front() < oldestMissingNonce) { + missingNonces_.pop_front(); + } + + // If we're missing a lot, we don't want to add everything since newestProcessedNonce_ + auto missingRangeStart = std::max(oldestMissingNonce, *newestProcessedNonce_ + 1); + for (auto i = missingRangeStart; i < bigNonce; ++i) { + missingNonces_.push_back(i); + } + + // Update the newest processed nonce + newestProcessedNonce_ = bigNonce; + } + else { + auto it = std::find(missingNonces_.begin(), missingNonces_.end(), bigNonce); + if (it != missingNonces_.end()) { + missingNonces_.erase(it); + } + } + + if (generation <= newestGeneration_ || cryptors_.find(generation) == cryptors_.end()) { + return; + } + DISCORD_LOG(LS_INFO) << "Reporting cryptor success, generation: " << generation; + newestGeneration_ = generation; + + // Update the expiry time for all old cryptors + const auto expiryTime = clock_.Now() + kCryptorExpiry; + for (auto& [gen, cryptor] : cryptors_) { + if (gen < newestGeneration_) { + DISCORD_LOG(LS_INFO) << "Updating expiry for cryptor, generation: " << gen; + cryptor.expiry = std::min(cryptor.expiry, expiryTime); + } + } +} + +KeyGeneration CryptorManager::ComputeWrappedGeneration(KeyGeneration generation) +{ + return ::discord::dave::ComputeWrappedGeneration(oldestGeneration_, generation); +} + +CryptorManager::ExpiringCryptor CryptorManager::MakeExpiringCryptor(KeyGeneration generation) +{ + // Get the new key from the ratchet + auto encryptionKey = keyRatchet_->GetKey(generation); + auto expiryTime = TimePoint::max(); + + // If we got frames out of order, we might have to create a cryptor for an old generation + // In that case, create it with a non-infinite expiry time as we have already transitioned + // to a newer generation + if (generation < newestGeneration_) { + DISCORD_LOG(LS_INFO) << "Creating cryptor for old generation: " << generation; + expiryTime = clock_.Now() + kCryptorExpiry; + } + else { + DISCORD_LOG(LS_INFO) << "Creating cryptor for new generation: " << generation; + } + + return {CreateCryptor(encryptionKey), expiryTime}; +} + +void CryptorManager::CleanupExpiredCryptors() +{ + for (auto it = cryptors_.begin(); it != cryptors_.end();) { + auto& [generation, cryptor] = *it; + + bool expired = cryptor.expiry < clock_.Now(); + if (expired) { + DISCORD_LOG(LS_INFO) << "Removing expired cryptor, generation: " << generation; + } + + it = expired ? cryptors_.erase(it) : ++it; + } + + while (oldestGeneration_ < newestGeneration_ && + cryptors_.find(oldestGeneration_) == cryptors_.end()) { + DISCORD_LOG(LS_INFO) << "Deleting key for old generation: " << oldestGeneration_; + keyRatchet_->DeleteKey(oldestGeneration_); + ++oldestGeneration_; + } +} + +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/cryptor_manager.h b/src/dpp/dave/cryptor_manager.h new file mode 100755 index 0000000000..02e9d3daa6 --- /dev/null +++ b/src/dpp/dave/cryptor_manager.h @@ -0,0 +1,59 @@ +#pragma once + +#include +#include +#include +#include + +#include "cryptor.h" +#include "key_ratchet.h" +#include "common.h" +#include "clock.h" + +namespace discord { +namespace dave { + +KeyGeneration ComputeWrappedGeneration(KeyGeneration oldest, KeyGeneration generation); + +using BigNonce = uint64_t; +BigNonce ComputeWrappedBigNonce(KeyGeneration generation, TruncatedSyncNonce nonce); + +class CryptorManager { +public: + using TimePoint = typename IClock::TimePoint; + + CryptorManager(const IClock& clock, std::unique_ptr keyRatchet); + + void UpdateExpiry(TimePoint expiry) { ratchetExpiry_ = expiry; } + bool IsExpired() const { return clock_.Now() > ratchetExpiry_; } + + bool CanProcessNonce(KeyGeneration generation, TruncatedSyncNonce nonce) const; + KeyGeneration ComputeWrappedGeneration(KeyGeneration generation); + + ICryptor* GetCryptor(KeyGeneration generation); + void ReportCryptorSuccess(KeyGeneration generation, TruncatedSyncNonce nonce); + +private: + struct ExpiringCryptor { + std::unique_ptr cryptor; + TimePoint expiry; + }; + + ExpiringCryptor MakeExpiringCryptor(KeyGeneration generation); + void CleanupExpiredCryptors(); + + const IClock& clock_; + std::unique_ptr keyRatchet_; + std::unordered_map cryptors_; + + TimePoint ratchetCreation_; + TimePoint ratchetExpiry_; + KeyGeneration oldestGeneration_{0}; + KeyGeneration newestGeneration_{0}; + + std::optional newestProcessedNonce_; + std::deque missingNonces_; +}; + +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/decryptor.cpp b/src/dpp/dave/decryptor.cpp new file mode 100755 index 0000000000..ab87d8240e --- /dev/null +++ b/src/dpp/dave/decryptor.cpp @@ -0,0 +1,221 @@ +#include "decryptor.h" + +#include + +#include +#include "common.h" +#include "logger.h" +#include "leb128.h" +#include "scope_exit.h" + +using namespace std::chrono_literals; + +namespace discord { +namespace dave { + +constexpr auto kStatsInterval = 10s; + +void Decryptor::TransitionToKeyRatchet(std::unique_ptr keyRatchet, + Duration transitionExpiry) +{ + DISCORD_LOG(LS_INFO) << "Transitioning to new key ratchet: " << keyRatchet.get() + << ", expiry: " << transitionExpiry.count(); + + // Update the expiry time for all existing cryptor managers + UpdateCryptorManagerExpiry(transitionExpiry); + + if (keyRatchet) { + cryptorManagers_.emplace_back(clock_, std::move(keyRatchet)); + } +} + +void Decryptor::TransitionToPassthroughMode(bool passthroughMode, Duration transitionExpiry) +{ + if (passthroughMode) { + allowPassThroughUntil_ = TimePoint::max(); + } + else { + // Update the pass through mode expiry + auto maxExpiry = clock_.Now() + transitionExpiry; + allowPassThroughUntil_ = std::min(allowPassThroughUntil_, maxExpiry); + } +} + +size_t Decryptor::Decrypt(MediaType mediaType, + ArrayView encryptedFrame, + ArrayView frame) +{ + if (mediaType != Audio && mediaType != Video) { + DISCORD_LOG(LS_WARNING) << "Decrypt failed, invalid media type: " + << static_cast(mediaType); + return 0; + } + + auto start = clock_.Now(); + + auto localFrame = GetOrCreateFrameProcessor(); + ScopeExit cleanup([&] { ReturnFrameProcessor(std::move(localFrame)); }); + + // Skip decrypting for silence frames + if (mediaType == Audio && encryptedFrame.size() == kOpusSilencePacket.size() && + std::memcmp(encryptedFrame.data(), kOpusSilencePacket.data(), kOpusSilencePacket.size()) == 0) { + DISCORD_LOG(LS_VERBOSE) << "Decrypt skipping silence of size: " << encryptedFrame.size(); + if (encryptedFrame.data() != frame.data()) { + std::memcpy(frame.data(), encryptedFrame.data(), encryptedFrame.size()); + } + return encryptedFrame.size(); + } + + // Remove any expired cryptor manager + CleanupExpiredCryptorManagers(); + + // Process the incoming frame + // This will check whether it looks like a valid encrypted frame + // and if so it will parse it into its different components + localFrame->ParseFrame(encryptedFrame); + + // If the frame is not encrypted and we can pass it through, do it + bool canUsePassThrough = allowPassThroughUntil_ > start; + if (!localFrame->IsEncrypted() && canUsePassThrough) { + if (encryptedFrame.data() != frame.data()) { + std::memcpy(frame.data(), encryptedFrame.data(), encryptedFrame.size()); + } + stats_[mediaType].passthroughCount++; + return encryptedFrame.size(); + } + + // If the frame is not encrypted and we can't pass it through, fail + if (!localFrame->IsEncrypted()) { + DISCORD_LOG(LS_INFO) + << "Decrypt failed, frame is not encrypted and pass through is disabled"; + stats_[mediaType].decryptFailureCount++; + return 0; + } + + // Try and decrypt with each valid cryptor + // reverse iterate to try the newest cryptors first + bool success = false; + for (auto it = cryptorManagers_.rbegin(); it != cryptorManagers_.rend(); ++it) { + auto& cryptorManager = *it; + success = DecryptImpl(cryptorManager, mediaType, *localFrame, frame); + if (success) { + break; + } + } + + size_t bytesWritten = 0; + if (success) { + stats_[mediaType].decryptSuccessCount++; + bytesWritten = localFrame->ReconstructFrame(frame); + } + else { + stats_[mediaType].decryptFailureCount++; + DISCORD_LOG(LS_WARNING) << "Decrypt failed, no valid cryptor found, type: " + << (mediaType ? "video" : "audio") + << ", encrypted frame size: " << encryptedFrame.size() + << ", plaintext frame size: " << frame.size() + << ", number of cryptor managers: " << cryptorManagers_.size() + << ", pass through enabled: " << (canUsePassThrough ? "yes" : "no"); + } + + auto end = clock_.Now(); + if (end > lastStatsTime_ + kStatsInterval) { + lastStatsTime_ = end; + DISCORD_LOG(LS_INFO) << "Decrypted audio: " << stats_[Audio].decryptSuccessCount + << ", video: " << stats_[Video].decryptSuccessCount + << ". Failed audio: " << stats_[Audio].decryptFailureCount + << ", video: " << stats_[Video].decryptFailureCount; + } + stats_[mediaType].decryptDuration += + std::chrono::duration_cast(end - start).count(); + + return bytesWritten; +} + +bool Decryptor::DecryptImpl(CryptorManager& cryptorManager, + MediaType mediaType, + InboundFrameProcessor& encryptedFrame, + ArrayView frame) +{ + auto tag = encryptedFrame.GetTag(); + auto truncatedNonce = encryptedFrame.GetTruncatedNonce(); + + auto authenticatedData = encryptedFrame.GetAuthenticatedData(); + auto ciphertext = encryptedFrame.GetCiphertext(); + auto plaintext = encryptedFrame.GetPlaintext(); + + // expand the truncated nonce to the full sized one needed for decryption + auto nonceBuffer = std::array(); + memcpy(nonceBuffer.data() + kAesGcm128TruncatedSyncNonceOffset, + &truncatedNonce, + kAesGcm128TruncatedSyncNonceBytes); + + auto nonceBufferView = MakeArrayView(nonceBuffer.data(), nonceBuffer.size()); + + auto generation = + cryptorManager.ComputeWrappedGeneration(truncatedNonce >> kRatchetGenerationShiftBits); + + if (!cryptorManager.CanProcessNonce(generation, truncatedNonce)) { + DISCORD_LOG(LS_INFO) << "Decrypt failed, cannot process nonce: " << truncatedNonce; + return false; + } + + // Get the cryptor for this generation + ICryptor* cryptor = cryptorManager.GetCryptor(generation); + + if (!cryptor) { + DISCORD_LOG(LS_INFO) << "Decrypt failed, no cryptor found for generation: " << generation; + return false; + } + + // perform the decryption + bool success = cryptor->Decrypt(plaintext, ciphertext, tag, nonceBufferView, authenticatedData); + stats_[mediaType].decryptAttempts++; + + if (success) { + cryptorManager.ReportCryptorSuccess(generation, truncatedNonce); + } + + return success; +} + +size_t Decryptor::GetMaxPlaintextByteSize(MediaType mediaType, size_t encryptedFrameSize) +{ + return encryptedFrameSize; +} + +void Decryptor::UpdateCryptorManagerExpiry(Duration expiry) +{ + auto maxExpiryTime = clock_.Now() + expiry; + for (auto& cryptorManager : cryptorManagers_) { + cryptorManager.UpdateExpiry(maxExpiryTime); + } +} + +void Decryptor::CleanupExpiredCryptorManagers() +{ + while (!cryptorManagers_.empty() && cryptorManagers_.front().IsExpired()) { + DISCORD_LOG(LS_INFO) << "Removing expired cryptor manager."; + cryptorManagers_.pop_front(); + } +} + +std::unique_ptr Decryptor::GetOrCreateFrameProcessor() +{ + std::lock_guard lock(frameProcessorsMutex_); + if (frameProcessors_.empty()) { + return std::make_unique(); + } + auto frameProcessor = std::move(frameProcessors_.back()); + frameProcessors_.pop_back(); + return frameProcessor; +} + +void Decryptor::ReturnFrameProcessor(std::unique_ptr frameProcessor) +{ + std::lock_guard lock(frameProcessorsMutex_); + frameProcessors_.push_back(std::move(frameProcessor)); +} + +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/decryptor.h b/src/dpp/dave/decryptor.h new file mode 100755 index 0000000000..c79ff975f2 --- /dev/null +++ b/src/dpp/dave/decryptor.h @@ -0,0 +1,74 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "codec_utils.h" +#include "common.h" +#include "cryptor.h" +#include "cryptor_manager.h" +#include "frame_processors.h" +#include "version.h" +#include "clock.h" + +namespace discord { +namespace dave { + +class IKeyRatchet; + +struct DecryptorStats { + uint64_t passthroughCount = 0; + uint64_t decryptSuccessCount = 0; + uint64_t decryptFailureCount = 0; + uint64_t decryptDuration = 0; + uint64_t decryptAttempts = 0; +}; + +class Decryptor { +public: + using Duration = std::chrono::seconds; + + void TransitionToKeyRatchet(std::unique_ptr keyRatchet, + Duration transitionExpiry = kDefaultTransitionDuration); + void TransitionToPassthroughMode(bool passthroughMode, + Duration transitionExpiry = kDefaultTransitionDuration); + + size_t Decrypt(MediaType mediaType, + ArrayView encryptedFrame, + ArrayView frame); + + size_t GetMaxPlaintextByteSize(MediaType mediaType, size_t encryptedFrameSize); + DecryptorStats GetStats(MediaType mediaType) const { return stats_[mediaType]; } + +private: + using TimePoint = IClock::TimePoint; + + bool DecryptImpl(CryptorManager& cryptor, + MediaType mediaType, + InboundFrameProcessor& encryptedFrame, + ArrayView frame); + + void UpdateCryptorManagerExpiry(Duration expiry); + void CleanupExpiredCryptorManagers(); + + std::unique_ptr GetOrCreateFrameProcessor(); + void ReturnFrameProcessor(std::unique_ptr frameProcessor); + + Clock clock_; + std::deque cryptorManagers_; + + std::mutex frameProcessorsMutex_; + std::vector> frameProcessors_; + + TimePoint allowPassThroughUntil_{TimePoint::min()}; + + TimePoint lastStatsTime_{TimePoint::min()}; + std::array stats_; +}; + +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/encryptor.cpp b/src/dpp/dave/encryptor.cpp new file mode 100755 index 0000000000..fa5ef66762 --- /dev/null +++ b/src/dpp/dave/encryptor.cpp @@ -0,0 +1,292 @@ +#include "encryptor.h" +#include +#include +#include +#include "common.h" +#include "cryptor_manager.h" +#include "logger.h" +#include "codec_utils.h" +#include "array_view.h" +#include "leb128.h" +#include "scope_exit.h" + +using namespace std::chrono_literals; + +namespace discord { +namespace dave { + +constexpr auto kStatsInterval = 10s; + +void Encryptor::SetKeyRatchet(std::unique_ptr keyRatchet) +{ + std::lock_guard lock(keyGenMutex_); + keyRatchet_ = std::move(keyRatchet); + cryptor_ = nullptr; + currentKeyGeneration_ = 0; + truncatedNonce_ = 0; +} + +void Encryptor::SetPassthroughMode(bool passthroughMode) +{ + passthroughMode_ = passthroughMode; + UpdateCurrentProtocolVersion(passthroughMode ? 0 : MaxSupportedProtocolVersion()); +} + +int Encryptor::Encrypt(MediaType mediaType, + uint32_t ssrc, + ArrayView frame, + ArrayView encryptedFrame, + size_t* bytesWritten) +{ + if (mediaType != Audio && mediaType != Video) { + DISCORD_LOG(LS_WARNING) << "Encrypt failed, invalid media type: " + << static_cast(mediaType); + return 0; + } + + if (passthroughMode_) { + // Pass frame through without encrypting + std::memcpy(encryptedFrame.data(), frame.data(), frame.size()); + *bytesWritten = frame.size(); + stats_[mediaType].passthroughCount++; + return ResultCode::Success; + } + + { + std::lock_guard lock(keyGenMutex_); + if (!keyRatchet_) { + stats_[mediaType].encryptFailureCount++; + return ResultCode::EncryptionFailure; + } + } + + auto start = std::chrono::steady_clock::now(); + auto result = ResultCode::Success; + + // write the codec identifier + auto codec = CodecForSsrc(ssrc); + + auto frameProcessor = GetOrCreateFrameProcessor(); + ScopeExit cleanup([&] { ReturnFrameProcessor(std::move(frameProcessor)); }); + + frameProcessor->ProcessFrame(frame, codec); + + const auto& unencryptedBytes = frameProcessor->GetUnencryptedBytes(); + const auto& encryptedBytes = frameProcessor->GetEncryptedBytes(); + auto& ciphertextBytes = frameProcessor->GetCiphertextBytes(); + + const auto& unencryptedRanges = frameProcessor->GetUnencryptedRanges(); + auto unencryptedRangesSize = UnencryptedRangesSize(unencryptedRanges); + + auto additionalData = MakeArrayView(unencryptedBytes.data(), unencryptedBytes.size()); + auto plaintextBuffer = MakeArrayView(encryptedBytes.data(), encryptedBytes.size()); + auto ciphertextBuffer = MakeArrayView(ciphertextBytes.data(), ciphertextBytes.size()); + + auto frameSize = encryptedBytes.size() + unencryptedBytes.size(); + auto tagBuffer = MakeArrayView(encryptedFrame.data() + frameSize, kAesGcm128TruncatedTagBytes); + + auto nonceBuffer = std::array(); + auto nonceBufferView = MakeArrayView(nonceBuffer.data(), nonceBuffer.size()); + + constexpr auto MAX_CIPHERTEXT_VALIDATION_RETRIES = 10; + + // some codecs (e.g. H26X) have packetizers that cannot handle specific byte sequences + // so we attempt up to MAX_CIPHERTEXT_VALIDATION_RETRIES to encrypt the frame + // calling into codec utils to validate the ciphertext + supplemental section + // and re-rolling the truncated nonce if it fails + + // the nonce increment will definitely change the ciphertext and the tag + // incrementing the nonce will also change the appropriate bytes + // in the tail end of the nonce + // which can remove start codes from the last 1 or 2 bytes of the nonce + // and the two bytes of the unencrypted header bytes + for (auto attempt = 1; attempt <= MAX_CIPHERTEXT_VALIDATION_RETRIES; ++attempt) { + auto [cryptor, truncatedNonce] = GetNextCryptorAndNonce(); + + if (!cryptor) { + result = ResultCode::EncryptionFailure; + break; + } + + // write the truncated nonce to our temporary full nonce array + // (since the encryption call expects a full size nonce) + std::memcpy(nonceBuffer.data() + kAesGcm128TruncatedSyncNonceOffset, + &truncatedNonce, + kAesGcm128TruncatedSyncNonceBytes); + + // encrypt the plaintext, adding the unencrypted header to the tag + bool success = cryptor->Encrypt( + ciphertextBuffer, plaintextBuffer, nonceBufferView, additionalData, tagBuffer); + + stats_[mediaType].encryptAttempts++; + stats_[mediaType].encryptMaxAttempts = + std::max(stats_[mediaType].encryptMaxAttempts, (uint64_t)attempt); + + if (!success) { + assert(false && "Failed to encrypt frame"); + result = ResultCode::EncryptionFailure; + break; + } + + auto reconstructedFrameSize = frameProcessor->ReconstructFrame(encryptedFrame); + assert(reconstructedFrameSize == frameSize && "Failed to reconstruct frame"); + + auto nonceSize = Leb128Size(truncatedNonce); + + auto truncatedNonceBuffer = MakeArrayView(tagBuffer.end(), nonceSize); + auto unencryptedRangesBuffer = + MakeArrayView(truncatedNonceBuffer.end(), unencryptedRangesSize); + auto supplementalBytesBuffer = + MakeArrayView(unencryptedRangesBuffer.end(), sizeof(SupplementalBytesSize)); + auto markerBytesBuffer = MakeArrayView(supplementalBytesBuffer.end(), sizeof(MagicMarker)); + + // write the nonce + auto res = WriteLeb128(truncatedNonce, truncatedNonceBuffer.begin()); + if (res != nonceSize) { + assert(false && "Failed to write truncated nonce"); + result = ResultCode::EncryptionFailure; + break; + } + + // write the unencrypted ranges + res = SerializeUnencryptedRanges( + unencryptedRanges, unencryptedRangesBuffer.begin(), unencryptedRangesBuffer.size()); + if (res != unencryptedRangesSize) { + assert(false && "Failed to write unencrypted ranges"); + result = ResultCode::EncryptionFailure; + break; + } + + // write the supplemental bytes size + SupplementalBytesSize supplementalBytes = + kSupplementalBytes + nonceSize + unencryptedRangesSize; + std::memcpy(supplementalBytesBuffer.data(), &supplementalBytes, sizeof(SupplementalBytesSize)); + + // write the marker bytes, ends the frame + std::memcpy(markerBytesBuffer.data(), &kMarkerBytes, sizeof(MagicMarker)); + + auto encryptedFrameBytes = reconstructedFrameSize + kAesGcm128TruncatedTagBytes + + nonceSize + unencryptedRangesSize + sizeof(SupplementalBytesSize) + sizeof(MagicMarker); + + if (codec_utils::ValidateEncryptedFrame( + *frameProcessor, MakeArrayView(encryptedFrame.data(), encryptedFrameBytes))) { + *bytesWritten = encryptedFrameBytes; + break; + } + else if (attempt >= MAX_CIPHERTEXT_VALIDATION_RETRIES) { + assert(false && "Failed to validate encrypted section for codec"); + result = ResultCode::EncryptionFailure; + break; + } + } + + auto now = std::chrono::steady_clock::now(); + stats_[mediaType].encryptDuration += + std::chrono::duration_cast(now - start).count(); + if (result == ResultCode::Success) { + stats_[mediaType].encryptSuccessCount++; + } + else { + stats_[mediaType].encryptFailureCount++; + } + + if (now > lastStatsTime_ + kStatsInterval) { + lastStatsTime_ = now; + DISCORD_LOG(LS_INFO) << "Encrypted audio: " << stats_[Audio].encryptSuccessCount + << ", video: " << stats_[Video].encryptSuccessCount + << ". Failed audio: " << stats_[Audio].encryptFailureCount + << ", video: " << stats_[Video].encryptFailureCount; + DISCORD_LOG(LS_INFO) << "Last encrypted frame, type: " + << (mediaType == Audio ? "audio" : "video") << ", ssrc: " << ssrc + << ", size: " << frame.size(); + } + + return result; +} + +size_t Encryptor::GetMaxCiphertextByteSize(MediaType mediaType, size_t frameSize) +{ + return frameSize + kSupplementalBytes + kTransformPaddingBytes; +} + +void Encryptor::AssignSsrcToCodec(uint32_t ssrc, Codec codecType) +{ + auto existingCodecIt = std::find_if( + ssrcCodecPairs_.begin(), ssrcCodecPairs_.end(), [ssrc](const SsrcCodecPair& pair) { + return pair.first == ssrc; + }); + + if (existingCodecIt == ssrcCodecPairs_.end()) { + ssrcCodecPairs_.emplace_back(ssrc, codecType); + } + else { + existingCodecIt->second = codecType; + } +} + +Codec Encryptor::CodecForSsrc(uint32_t ssrc) +{ + auto existingCodecIt = std::find_if( + ssrcCodecPairs_.begin(), ssrcCodecPairs_.end(), [ssrc](const SsrcCodecPair& pair) { + return pair.first == ssrc; + }); + + if (existingCodecIt != ssrcCodecPairs_.end()) { + return existingCodecIt->second; + } + else { + return Codec::Unknown; + } +} + +std::unique_ptr Encryptor::GetOrCreateFrameProcessor() +{ + std::lock_guard lock(frameProcessorsMutex_); + if (frameProcessors_.empty()) { + return std::make_unique(); + } + auto frameProcessor = std::move(frameProcessors_.back()); + frameProcessors_.pop_back(); + return frameProcessor; +} + +void Encryptor::ReturnFrameProcessor(std::unique_ptr frameProcessor) +{ + std::lock_guard lock(frameProcessorsMutex_); + frameProcessors_.push_back(std::move(frameProcessor)); +} + +Encryptor::CryptorAndNonce Encryptor::GetNextCryptorAndNonce() +{ + std::lock_guard lock(keyGenMutex_); + if (!keyRatchet_) { + return {nullptr, 0}; + } + + auto generation = ComputeWrappedGeneration(currentKeyGeneration_, + ++truncatedNonce_ >> kRatchetGenerationShiftBits); + + if (generation != currentKeyGeneration_ || !cryptor_) { + currentKeyGeneration_ = generation; + + auto encryptionKey = keyRatchet_->GetKey(currentKeyGeneration_); + cryptor_ = CreateCryptor(encryptionKey); + } + + return {cryptor_, truncatedNonce_}; +} + +void Encryptor::UpdateCurrentProtocolVersion(ProtocolVersion version) +{ + if (version == currentProtocolVersion_) { + return; + } + + currentProtocolVersion_ = version; + if (protocolVersionChangedCallback_) { + protocolVersionChangedCallback_(); + } +} + +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/encryptor.h b/src/dpp/dave/encryptor.h new file mode 100755 index 0000000000..f885eb8ed5 --- /dev/null +++ b/src/dpp/dave/encryptor.h @@ -0,0 +1,98 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include "codec_utils.h" +#include "common.h" +#include "cryptor.h" +#include "key_ratchet.h" +#include "frame_processors.h" +#include "version.h" + +namespace discord { +namespace dave { + +struct EncryptorStats { + uint64_t passthroughCount = 0; + uint64_t encryptSuccessCount = 0; + uint64_t encryptFailureCount = 0; + uint64_t encryptDuration = 0; + uint64_t encryptAttempts = 0; + uint64_t encryptMaxAttempts = 0; +}; + +class Encryptor { +public: + void SetKeyRatchet(std::unique_ptr keyRatchet); + void SetPassthroughMode(bool passthroughMode); + + bool HasKeyRatchet() const { return keyRatchet_ != nullptr; } + bool IsPassthroughMode() const { return passthroughMode_; } + + void AssignSsrcToCodec(uint32_t ssrc, Codec codecType); + Codec CodecForSsrc(uint32_t ssrc); + + int Encrypt(MediaType mediaType, + uint32_t ssrc, + ArrayView frame, + ArrayView encryptedFrame, + size_t* bytesWritten); + + size_t GetMaxCiphertextByteSize(MediaType mediaType, size_t frameSize); + EncryptorStats GetStats(MediaType mediaType) const { return stats_[mediaType]; } + + using ProtocolVersionChangedCallback = std::function; + void SetProtocolVersionChangedCallback(ProtocolVersionChangedCallback callback) + { + protocolVersionChangedCallback_ = std::move(callback); + } + ProtocolVersion GetProtocolVersion() const { return currentProtocolVersion_; } + +private: + std::unique_ptr GetOrCreateFrameProcessor(); + void ReturnFrameProcessor(std::unique_ptr frameProcessor); + + using CryptorAndNonce = std::pair, TruncatedSyncNonce>; + CryptorAndNonce GetNextCryptorAndNonce(); + + void UpdateCurrentProtocolVersion(ProtocolVersion version); + + enum ResultCode { + Success, + UninitializedContext, + InitializationFailure, + UnsupportedCodec, + EncryptionFailure, + FinalizationFailure, + TagAppendFailure + }; + + std::atomic_bool passthroughMode_{false}; + + std::mutex keyGenMutex_; + std::unique_ptr keyRatchet_; + std::shared_ptr cryptor_; + KeyGeneration currentKeyGeneration_{0}; + TruncatedSyncNonce truncatedNonce_{0}; + + std::mutex frameProcessorsMutex_; + std::vector> frameProcessors_; + + using SsrcCodecPair = std::pair; + std::vector ssrcCodecPairs_; + + using TimePoint = std::chrono::time_point; + TimePoint lastStatsTime_{TimePoint::min()}; + std::array stats_; + + ProtocolVersionChangedCallback protocolVersionChangedCallback_; + ProtocolVersion currentProtocolVersion_{MaxSupportedProtocolVersion()}; +}; + +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/frame_processors.cpp b/src/dpp/dave/frame_processors.cpp new file mode 100755 index 0000000000..997df3cd20 --- /dev/null +++ b/src/dpp/dave/frame_processors.cpp @@ -0,0 +1,390 @@ +#include "frame_processors.h" +#include +#include +#include +#include +#include +#include "codec_utils.h" +#include "logger.h" +#include "array_view.h" +#include "leb128.h" + +#if defined(_MSC_VER) +#include +#endif + +namespace discord { +namespace dave { + +std::pair OverflowAdd(size_t a, size_t b) +{ + size_t res; +#if defined(_MSC_VER) && defined(_M_X64) + bool didOverflow = _addcarry_u64(0, a, b, &res); +#elif defined(_MSC_VER) && defined(_M_IX86) + bool didOverflow = _addcarry_u32(0, a, b, &res); +#else + bool didOverflow = __builtin_add_overflow(a, b, &res); +#endif + return {didOverflow, res}; +} + +uint8_t UnencryptedRangesSize(const Ranges& unencryptedRanges) +{ + size_t size = 0; + for (const auto& range : unencryptedRanges) { + size += Leb128Size(range.offset); + size += Leb128Size(range.size); + } + assert(size <= std::numeric_limits::max() && + "Unencrypted ranges size exceeds 255 bytes"); + return static_cast(size); +} + +uint8_t SerializeUnencryptedRanges(const Ranges& unencryptedRanges, + uint8_t* buffer, + size_t bufferSize) +{ + auto writeAt = buffer; + auto end = buffer + bufferSize; + for (const auto& range : unencryptedRanges) { + auto rangeSize = Leb128Size(range.offset) + Leb128Size(range.size); + if (rangeSize > static_cast(end - writeAt)) { + assert(false && "Buffer is too small to serialize unencrypted ranges"); + break; + } + + writeAt += WriteLeb128(range.offset, writeAt); + writeAt += WriteLeb128(range.size, writeAt); + } + return writeAt - buffer; +} + +uint8_t DeserializeUnencryptedRanges(const uint8_t*& readAt, + const size_t bufferSize, + Ranges& unencryptedRanges) +{ + auto start = readAt; + auto end = readAt + bufferSize; + while (readAt < end) { + size_t offset = ReadLeb128(readAt, end); + if (readAt == nullptr) { + break; + } + + size_t size = ReadLeb128(readAt, end); + if (readAt == nullptr) { + break; + } + unencryptedRanges.push_back({offset, size}); + } + + if (readAt != end) { + DISCORD_LOG(LS_WARNING) << "Failed to deserialize unencrypted ranges"; + unencryptedRanges.clear(); + readAt = nullptr; + return 0; + } + + return readAt - start; +} + +bool ValidateUnencryptedRanges(const Ranges& unencryptedRanges, size_t frameSize) +{ + if (unencryptedRanges.empty()) { + return true; + } + + // validate that the ranges are in order and don't overlap + for (auto i = 0u; i < unencryptedRanges.size(); ++i) { + auto current = unencryptedRanges[i]; + // The current range should not overflow into the next range + // or if it is the last range, the end of the frame + auto maxEnd = + i + 1 < unencryptedRanges.size() ? unencryptedRanges[i + 1].offset : frameSize; + + auto [didOverflow, currentEnd] = OverflowAdd(current.offset, current.size); + if (didOverflow || currentEnd > maxEnd) { + DISCORD_LOG(LS_WARNING) + << "Unencrypted range may overlap or be out of order: current offset: " + << current.offset << ", current size: " << current.size << ", maximum end: " << maxEnd + << ", frame size: " << frameSize; + return false; + } + } + + return true; +} + +size_t Reconstruct(Ranges ranges, + const std::vector& rangeBytes, + const std::vector& otherBytes, + const ArrayView& output) +{ + size_t frameIndex = 0; + size_t rangeBytesIndex = 0; + size_t otherBytesIndex = 0; + + const auto CopyRangeBytes = [&](size_t size) { + assert(rangeBytesIndex + size <= rangeBytes.size()); + assert(frameIndex + size <= output.size()); + std::memcpy(output.data() + frameIndex, rangeBytes.data() + rangeBytesIndex, size); + rangeBytesIndex += size; + frameIndex += size; + }; + + const auto CopyOtherBytes = [&](size_t size) { + assert(otherBytesIndex + size <= otherBytes.size()); + assert(frameIndex + size <= output.size()); + std::memcpy(output.data() + frameIndex, otherBytes.data() + otherBytesIndex, size); + otherBytesIndex += size; + frameIndex += size; + }; + + for (const auto& range : ranges) { + if (range.offset > frameIndex) { + CopyOtherBytes(range.offset - frameIndex); + } + + CopyRangeBytes(range.size); + } + + if (otherBytesIndex < otherBytes.size()) { + CopyOtherBytes(otherBytes.size() - otherBytesIndex); + } + + assert(rangeBytesIndex == rangeBytes.size()); + assert(otherBytesIndex == otherBytes.size()); + assert(frameIndex <= output.size()); + + return frameIndex; +} + +void InboundFrameProcessor::Clear() +{ + isEncrypted_ = false; + originalSize_ = 0; + truncatedNonce_ = std::numeric_limits::max(); + unencryptedRanges_.clear(); + authenticated_.clear(); + ciphertext_.clear(); + plaintext_.clear(); +} + +void InboundFrameProcessor::ParseFrame(ArrayView frame) +{ + Clear(); + + constexpr auto MinSupplementalBytesSize = + kAesGcm128TruncatedTagBytes + sizeof(SupplementalBytesSize) + sizeof(MagicMarker); + if (frame.size() < MinSupplementalBytesSize) { + DISCORD_LOG(LS_WARNING) << "Encrypted frame is too small to contain min supplemental bytes"; + return; + } + + // Check the frame ends with the magic marker + auto magicMarkerBuffer = frame.end() - sizeof(MagicMarker); + if (memcmp(magicMarkerBuffer, &kMarkerBytes, sizeof(MagicMarker)) != 0) { + return; + } + + // Read the supplemental bytes size + SupplementalBytesSize supplementalBytesSize; + auto supplementalBytesSizeBuffer = magicMarkerBuffer - sizeof(SupplementalBytesSize); + assert(frame.begin() <= supplementalBytesSizeBuffer && + supplementalBytesSizeBuffer <= frame.end()); + memcpy(&supplementalBytesSize, supplementalBytesSizeBuffer, sizeof(SupplementalBytesSize)); + + // Check the frame is large enough to contain the supplemental bytes + if (frame.size() < supplementalBytesSize) { + DISCORD_LOG(LS_WARNING) << "Encrypted frame is too small to contain supplemental bytes"; + return; + } + + // Check that supplemental bytes size is large enough to contain the supplemental bytes + if (supplementalBytesSize < MinSupplementalBytesSize) { + DISCORD_LOG(LS_WARNING) + << "Supplemental bytes size is too small to contain supplemental bytes"; + return; + } + + auto supplementalBytesBuffer = frame.end() - supplementalBytesSize; + assert(frame.begin() <= supplementalBytesBuffer && supplementalBytesBuffer <= frame.end()); + + // Read the tag + tag_ = MakeArrayView(supplementalBytesBuffer, kAesGcm128TruncatedTagBytes); + + // Read the nonce + auto nonceBuffer = supplementalBytesBuffer + kAesGcm128TruncatedTagBytes; + assert(frame.begin() <= nonceBuffer && nonceBuffer <= frame.end()); + auto readAt = nonceBuffer; + auto end = supplementalBytesSizeBuffer; + truncatedNonce_ = ReadLeb128(readAt, end); + if (readAt == nullptr) { + DISCORD_LOG(LS_WARNING) << "Failed to read truncated nonce"; + return; + } + + // Read the unencrypted ranges + assert(nonceBuffer <= readAt && readAt <= end); + auto unencryptedRangesSize = end - readAt; + DeserializeUnencryptedRanges(readAt, unencryptedRangesSize, unencryptedRanges_); + if (readAt == nullptr) { + DISCORD_LOG(LS_WARNING) << "Failed to read unencrypted ranges"; + return; + } + + if (!ValidateUnencryptedRanges(unencryptedRanges_, frame.size())) { + DISCORD_LOG(LS_WARNING) << "Invalid unencrypted ranges"; + return; + } + + // This is overly aggressive but will keep reallocations to a minimum + authenticated_.reserve(frame.size()); + ciphertext_.reserve(frame.size()); + plaintext_.reserve(frame.size()); + + originalSize_ = frame.size(); + + // Split the frame into authenticated and ciphertext bytes + size_t frameIndex = 0; + for (const auto& range : unencryptedRanges_) { + auto encryptedBytes = range.offset - frameIndex; + if (encryptedBytes > 0) { + assert(frameIndex + encryptedBytes <= frame.size()); + AddCiphertextBytes(frame.data() + frameIndex, encryptedBytes); + } + + assert(range.offset + range.size <= frame.size()); + AddAuthenticatedBytes(frame.data() + range.offset, range.size); + frameIndex = range.offset + range.size; + } + auto actualFrameSize = frame.size() - supplementalBytesSize; + if (frameIndex < actualFrameSize) { + AddCiphertextBytes(frame.data() + frameIndex, actualFrameSize - frameIndex); + } + + // Make sure the plaintext buffer is the same size as the ciphertext buffer + plaintext_.resize(ciphertext_.size()); + + // We've successfully parsed the frame + // Mark the frame as encrypted + isEncrypted_ = true; +} + +size_t InboundFrameProcessor::ReconstructFrame(ArrayView frame) const +{ + if (!isEncrypted_) { + DISCORD_LOG(LS_WARNING) << "Cannot reconstruct an invalid encrypted frame"; + return 0; + } + + if (authenticated_.size() + plaintext_.size() > frame.size()) { + DISCORD_LOG(LS_WARNING) << "Frame is too small to contain the decrypted frame"; + return 0; + } + + return Reconstruct(unencryptedRanges_, authenticated_, plaintext_, frame); +} + +void InboundFrameProcessor::AddAuthenticatedBytes(const uint8_t* data, size_t size) +{ + authenticated_.resize(authenticated_.size() + size); + memcpy(authenticated_.data() + authenticated_.size() - size, data, size); +} + +void InboundFrameProcessor::AddCiphertextBytes(const uint8_t* data, size_t size) +{ + ciphertext_.resize(ciphertext_.size() + size); + memcpy(ciphertext_.data() + ciphertext_.size() - size, data, size); +} + +void OutboundFrameProcessor::Reset() +{ + codec_ = Codec::Unknown; + frameIndex_ = 0; + unencryptedBytes_.clear(); + encryptedBytes_.clear(); + unencryptedRanges_.clear(); +} + +void OutboundFrameProcessor::ProcessFrame(ArrayView frame, Codec codec) +{ + Reset(); + + codec_ = codec; + unencryptedBytes_.reserve(frame.size()); + encryptedBytes_.reserve(frame.size()); + + bool success = false; + switch (codec) { + case Codec::Opus: + success = codec_utils::ProcessFrameOpus(*this, frame); + break; + case Codec::VP8: + success = codec_utils::ProcessFrameVp8(*this, frame); + break; + case Codec::VP9: + success = codec_utils::ProcessFrameVp9(*this, frame); + break; + case Codec::H264: + success = codec_utils::ProcessFrameH264(*this, frame); + break; + case Codec::H265: + success = codec_utils::ProcessFrameH265(*this, frame); + break; + case Codec::AV1: + success = codec_utils::ProcessFrameAv1(*this, frame); + break; + default: + assert(false && "Unsupported codec for frame encryption"); + break; + } + + if (!success) { + frameIndex_ = 0; + unencryptedBytes_.clear(); + encryptedBytes_.clear(); + unencryptedRanges_.clear(); + AddEncryptedBytes(frame.data(), frame.size()); + } + + ciphertextBytes_.resize(encryptedBytes_.size()); +} + +size_t OutboundFrameProcessor::ReconstructFrame(ArrayView frame) +{ + if (unencryptedBytes_.size() + ciphertextBytes_.size() > frame.size()) { + DISCORD_LOG(LS_WARNING) << "Frame is too small to contain the encrypted frame"; + return 0; + } + + return Reconstruct(unencryptedRanges_, unencryptedBytes_, ciphertextBytes_, frame); +} + +void OutboundFrameProcessor::AddUnencryptedBytes(const uint8_t* bytes, size_t size) +{ + if (!unencryptedRanges_.empty() && + unencryptedRanges_.back().offset + unencryptedRanges_.back().size == frameIndex_) { + // extend the last range + unencryptedRanges_.back().size += size; + } + else { + // add a new range (offset, size) + unencryptedRanges_.push_back({frameIndex_, size}); + } + + unencryptedBytes_.resize(unencryptedBytes_.size() + size); + memcpy(unencryptedBytes_.data() + unencryptedBytes_.size() - size, bytes, size); + frameIndex_ += size; +} + +void OutboundFrameProcessor::AddEncryptedBytes(const uint8_t* bytes, size_t size) +{ + encryptedBytes_.resize(encryptedBytes_.size() + size); + memcpy(encryptedBytes_.data() + encryptedBytes_.size() - size, bytes, size); + frameIndex_ += size; +} + +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/frame_processors.h b/src/dpp/dave/frame_processors.h new file mode 100755 index 0000000000..1ea0bab87a --- /dev/null +++ b/src/dpp/dave/frame_processors.h @@ -0,0 +1,88 @@ +#pragma once + +#include +#include +#include +#include +#include "common.h" +#include "array_view.h" + +namespace discord { +namespace dave { + +struct Range { + size_t offset; + size_t size; +}; +using Ranges = std::vector; + +uint8_t UnencryptedRangesSize(const Ranges& unencryptedRanges); +uint8_t SerializeUnencryptedRanges(const Ranges& unencryptedRanges, + uint8_t* buffer, + size_t bufferSize); +uint8_t DeserializeUnencryptedRanges(const uint8_t*& buffer, + const size_t bufferSize, + Ranges& unencryptedRanges); +bool ValidateUnencryptedRanges(const Ranges& unencryptedRanges, size_t frameSize); + +class InboundFrameProcessor { +public: + void ParseFrame(ArrayView frame); + size_t ReconstructFrame(ArrayView frame) const; + + bool IsEncrypted() const { return isEncrypted_; } + size_t Size() const { return originalSize_; } + void Clear(); + + ArrayView GetTag() const { return tag_; } + TruncatedSyncNonce GetTruncatedNonce() const { return truncatedNonce_; } + ArrayView GetAuthenticatedData() const + { + return MakeArrayView(authenticated_.data(), authenticated_.size()); + } + ArrayView GetCiphertext() const + { + return MakeArrayView(ciphertext_.data(), ciphertext_.size()); + } + ArrayView GetPlaintext() { return MakeArrayView(plaintext_); } + +private: + void AddAuthenticatedBytes(const uint8_t* data, size_t size); + void AddCiphertextBytes(const uint8_t* data, size_t size); + + bool isEncrypted_{false}; + size_t originalSize_{0}; + ArrayView tag_; + TruncatedSyncNonce truncatedNonce_; + Ranges unencryptedRanges_; + std::vector authenticated_; + std::vector ciphertext_; + std::vector plaintext_; +}; + +class OutboundFrameProcessor { +public: + void ProcessFrame(ArrayView frame, Codec codec); + size_t ReconstructFrame(ArrayView frame); + + Codec GetCodec() const { return codec_; } + const std::vector& GetUnencryptedBytes() const { return unencryptedBytes_; } + const std::vector& GetEncryptedBytes() const { return encryptedBytes_; } + std::vector& GetCiphertextBytes() { return ciphertextBytes_; } + const Ranges& GetUnencryptedRanges() const { return unencryptedRanges_; } + + void Reset(); + void AddUnencryptedBytes(const uint8_t* bytes, size_t size); + void AddEncryptedBytes(const uint8_t* bytes, size_t size); + +private: + Codec codec_{Codec::Unknown}; + size_t frameIndex_{0}; + std::vector unencryptedBytes_; + std::vector encryptedBytes_; + std::vector ciphertextBytes_; + Ranges unencryptedRanges_; +}; + +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/key_ratchet.h b/src/dpp/dave/key_ratchet.h new file mode 100755 index 0000000000..2b5e9242bb --- /dev/null +++ b/src/dpp/dave/key_ratchet.h @@ -0,0 +1,19 @@ +#pragma once + +#include +#include "common.h" + +namespace discord { +namespace dave { + +using KeyGeneration = uint32_t; + +class IKeyRatchet { +public: + virtual ~IKeyRatchet() noexcept = default; + virtual EncryptionKey GetKey(KeyGeneration generation) noexcept = 0; + virtual void DeleteKey(KeyGeneration generation) noexcept = 0; +}; + +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/leb128.cpp b/src/dpp/dave/leb128.cpp new file mode 100755 index 0000000000..585c8ceb25 --- /dev/null +++ b/src/dpp/dave/leb128.cpp @@ -0,0 +1,60 @@ + +#include "leb128.h" + +// The following code was copied from the webrtc source code: +// https://webrtc.googlesource.com/src/+/refs/heads/main/modules/rtp_rtcp/source/leb128.cc + +namespace discord { +namespace dave { + +size_t Leb128Size(uint64_t value) +{ + int size = 0; + while (value >= 0x80) { + ++size; + value >>= 7; + } + return size + 1; +} + +uint64_t ReadLeb128(const uint8_t*& readAt, const uint8_t* end) +{ + uint64_t value = 0; + int fillBits = 0; + while (readAt != end && fillBits < 64 - 7) { + uint8_t leb128Byte = *readAt; + value |= uint64_t{leb128Byte & 0x7Fu} << fillBits; + ++readAt; + fillBits += 7; + if ((leb128Byte & 0x80) == 0) { + return value; + } + } + // Read 9 bytes and didn't find the terminator byte. Check if 10th byte + // is that terminator, however to fit result into uint64_t it may carry only + // single bit. + if (readAt != end && *readAt <= 1) { + value |= uint64_t{*readAt} << fillBits; + ++readAt; + return value; + } + // Failed to find terminator leb128 byte. + readAt = nullptr; + return 0; +} + +size_t WriteLeb128(uint64_t value, uint8_t* buffer) +{ + int size = 0; + while (value >= 0x80) { + buffer[size] = 0x80 | (value & 0x7F); + ++size; + value >>= 7; + } + buffer[size] = value; + ++size; + return size; +} + +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/leb128.h b/src/dpp/dave/leb128.h new file mode 100755 index 0000000000..97a40e1a39 --- /dev/null +++ b/src/dpp/dave/leb128.h @@ -0,0 +1,23 @@ +#pragma once + +#include +#include + +namespace discord { +namespace dave { + +constexpr size_t Leb128MaxSize = 10; + +// Returns number of bytes needed to store `value` in leb128 format. +size_t Leb128Size(uint64_t value); + +// Reads leb128 encoded value and advance read_at by number of bytes consumed. +// Sets read_at to nullptr on error. +uint64_t ReadLeb128(const uint8_t*& readAt, const uint8_t* end); + +// Encodes `value` in leb128 format. Assumes buffer has size of at least +// Leb128Size(value). Returns number of bytes consumed. +size_t WriteLeb128(uint64_t value, uint8_t* buffer); + +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/logger.cpp b/src/dpp/dave/logger.cpp new file mode 100755 index 0000000000..fec4f3a5f1 --- /dev/null +++ b/src/dpp/dave/logger.cpp @@ -0,0 +1,55 @@ +#include "logger.h" + +#include +#include +#include + +namespace discord { +namespace dave { + +std::atomic gLogSink = nullptr; + +void SetLogSink(LogSink sink) +{ + gLogSink = sink; +} + +LogStreamer::LogStreamer(LoggingSeverity severity, const char* file, int line) + : severity_(severity) + , file_(file) + , line_(line) +{ +} + +LogStreamer::~LogStreamer() +{ + std::string logLine = stream_.str(); + if (logLine.empty()) { + return; + } + + auto sink = gLogSink.load(); + if (sink) { + sink(severity_, file_, line_, logLine); + return; + } + + switch (severity_) { + case LS_VERBOSE: + case LS_INFO: + case LS_WARNING: + case LS_ERROR: { + const char* file = file_; + if (auto separator = strrchr(file, '/')) { + file = separator + 1; + } + std::cout << "(" << file << ":" << line_ << ") " << logLine << std::endl; + break; + } + case LS_NONE: + break; + } +} + +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/logger.h b/src/dpp/dave/logger.h new file mode 100755 index 0000000000..dda0fcb4b5 --- /dev/null +++ b/src/dpp/dave/logger.h @@ -0,0 +1,46 @@ +#pragma once + +#include + +#if !defined(DISCORD_LOG) +#define DISCORD_LOG_FILE_LINE(sev, file, line) ::discord::dave::LogStreamer(sev, file, line) +#define DISCORD_LOG(sev) DISCORD_LOG_FILE_LINE(::discord::dave::sev, __FILE__, __LINE__) +#endif +namespace discord { +namespace dave { + +enum LoggingSeverity { + LS_VERBOSE, + LS_INFO, + LS_WARNING, + LS_ERROR, + LS_NONE, +}; + +using LogSink = void (*)(LoggingSeverity severity, + const char* file, + int line, + const std::string& message); +void SetLogSink(LogSink sink); + +class LogStreamer { +public: + LogStreamer(LoggingSeverity severity, const char* file, int line); + ~LogStreamer(); + + template + LogStreamer& operator<<(const T& value) + { + stream_ << value; + return *this; + } + +private: + LoggingSeverity severity_; + const char* file_; + int line_; + std::ostringstream stream_; +}; + +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/mls_key_ratchet.cpp b/src/dpp/dave/mls_key_ratchet.cpp new file mode 100755 index 0000000000..060bc44473 --- /dev/null +++ b/src/dpp/dave/mls_key_ratchet.cpp @@ -0,0 +1,39 @@ +#include "mls_key_ratchet.h" + +#include + +#include "logger.h" + +namespace discord { +namespace dave { + +MlsKeyRatchet::MlsKeyRatchet(::mlspp::CipherSuite suite, bytes baseSecret) noexcept + : hashRatchet_(suite, std::move(baseSecret)) +{ +} + +MlsKeyRatchet::~MlsKeyRatchet() noexcept = default; + +EncryptionKey MlsKeyRatchet::GetKey(KeyGeneration generation) noexcept +{ + DISCORD_LOG(LS_INFO) << "Retrieving key for generation " << generation << " from HashRatchet"; + + try { + auto keyAndNonce = hashRatchet_.get(generation); + assert(keyAndNonce.key.size() >= kAesGcm128KeyBytes); + return std::move(keyAndNonce.key.as_vec()); + } + catch (const std::exception& e) { + DISCORD_LOG(LS_ERROR) << "Failed to retrieve key for generation " << generation << ": " + << e.what(); + return {}; + } +} + +void MlsKeyRatchet::DeleteKey(KeyGeneration generation) noexcept +{ + hashRatchet_.erase(generation); +} + +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/mls_key_ratchet.h b/src/dpp/dave/mls_key_ratchet.h new file mode 100755 index 0000000000..ebe3173b11 --- /dev/null +++ b/src/dpp/dave/mls_key_ratchet.h @@ -0,0 +1,22 @@ +#pragma once + +#include +#include "key_ratchet.h" + +namespace discord { +namespace dave { + +class MlsKeyRatchet : public IKeyRatchet { +public: + MlsKeyRatchet(::mlspp::CipherSuite suite, bytes baseSecret) noexcept; + ~MlsKeyRatchet() noexcept override; + + EncryptionKey GetKey(KeyGeneration generation) noexcept override; + void DeleteKey(KeyGeneration generation) noexcept override; + +private: + ::mlspp::HashRatchet hashRatchet_; +}; + +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/openssl_cryptor.cpp b/src/dpp/dave/openssl_cryptor.cpp new file mode 100755 index 0000000000..38a32e79a4 --- /dev/null +++ b/src/dpp/dave/openssl_cryptor.cpp @@ -0,0 +1,108 @@ +#include "openssl_cryptor.h" + +#include + +#include + +#include "common.h" +#include "logger.h" + +namespace discord { +namespace dave { + +void PrintSSLErrors() +{ + ERR_print_errors_cb( + [](const char* str, size_t len, void* ctx) { + DISCORD_LOG(LS_ERROR) << std::string(str, len); + return 1; + }, + nullptr); +} + +BoringSSLCryptor::BoringSSLCryptor(const EncryptionKey& encryptionKey) +{ + /*EVP_AEAD_CTX_zero(&cipherCtx_); + + auto initResult = EVP_AEAD_CTX_init(&cipherCtx_, + EVP_aead_aes_128_gcm(), + encryptionKey.data(), + encryptionKey.size(), + kAesGcm128TruncatedTagBytes, + nullptr); + + if (initResult != 1) { + DISCORD_LOG(LS_ERROR) << "Failed to initialize AEAD context"; + PrintSSLErrors(); + }*/ +} + +BoringSSLCryptor::~BoringSSLCryptor() +{ + //EVP_AEAD_CTX_cleanup(&cipherCtx_); +} + +bool BoringSSLCryptor::Encrypt(ArrayView ciphertextBufferOut, + ArrayView plaintextBuffer, + ArrayView nonceBuffer, + ArrayView additionalData, + ArrayView tagBufferOut) +{ + /*if (cipherCtx_.aead == nullptr) { + DISCORD_LOG(LS_ERROR) << "Encrypt: AEAD context is not initialized"; + return false; + } + + size_t tagSizeOut; + auto encryptResult = EVP_AEAD_CTX_seal_scatter(&cipherCtx_, + ciphertextBufferOut.data(), + tagBufferOut.data(), + &tagSizeOut, + kAesGcm128TruncatedTagBytes, + nonceBuffer.data(), + kAesGcm128NonceBytes, + plaintextBuffer.data(), + plaintextBuffer.size(), + nullptr, + 0, + additionalData.data(), + additionalData.size()); + if (encryptResult != 1) { + DISCORD_LOG(LS_ERROR) << "Failed to encrypt data"; + PrintSSLErrors(); + } + + return encryptResult == 1; + */ + return false; +} + +bool BoringSSLCryptor::Decrypt(ArrayView plaintextBufferOut, + ArrayView ciphertextBuffer, + ArrayView tagBuffer, + ArrayView nonceBuffer, + ArrayView additionalData) +{ + /*if (cipherCtx_.aead == nullptr) { + DISCORD_LOG(LS_ERROR) << "Decrypt: AEAD context is not initialized"; + return false; + } + + auto decryptResult = EVP_AEAD_CTX_open_gather(&cipherCtx_, + plaintextBufferOut.data(), + nonceBuffer.data(), + kAesGcm128NonceBytes, + ciphertextBuffer.data(), + ciphertextBuffer.size(), + tagBuffer.data(), + kAesGcm128TruncatedTagBytes, + additionalData.data(), + additionalData.size()); + + return decryptResult == 1; + */ + return false; +} + +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/openssl_cryptor.h b/src/dpp/dave/openssl_cryptor.h new file mode 100755 index 0000000000..2128c66c34 --- /dev/null +++ b/src/dpp/dave/openssl_cryptor.h @@ -0,0 +1,32 @@ +#pragma once + +#include +#include "cryptor.h" + +namespace discord { +namespace dave { + +class BoringSSLCryptor : public ICryptor { +public: + BoringSSLCryptor(const EncryptionKey& encryptionKey); + ~BoringSSLCryptor(); + + bool IsValid() const { /*return cipherCtx_.aead != nullptr;*/ return true; } + + bool Encrypt(ArrayView ciphertextBufferOut, + ArrayView plaintextBuffer, + ArrayView nonceBuffer, + ArrayView additionalData, + ArrayView tagBufferOut) override; + bool Decrypt(ArrayView plaintextBufferOut, + ArrayView ciphertextBuffer, + ArrayView tagBuffer, + ArrayView nonceBuffer, + ArrayView additionalData) override; + +private: + //EVP_AEAD_CTX cipherCtx_; +}; + +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/parameters.cpp b/src/dpp/dave/parameters.cpp new file mode 100755 index 0000000000..5ee8050dbf --- /dev/null +++ b/src/dpp/dave/parameters.cpp @@ -0,0 +1,57 @@ +#include "parameters.h" + +namespace discord { +namespace dave { +namespace mls { + +::mlspp::CipherSuite::ID CiphersuiteIDForProtocolVersion(ProtocolVersion version) noexcept +{ + return ::mlspp::CipherSuite::ID::P256_AES128GCM_SHA256_P256; +} + +::mlspp::CipherSuite CiphersuiteForProtocolVersion(ProtocolVersion version) noexcept +{ + return ::mlspp::CipherSuite{CiphersuiteIDForProtocolVersion(version)}; +} + +::mlspp::CipherSuite::ID CiphersuiteIDForSignatureVersion(SignatureVersion version) noexcept +{ + return ::mlspp::CipherSuite::ID::P256_AES128GCM_SHA256_P256; +} + +::mlspp::CipherSuite CiphersuiteForSignatureVersion(SignatureVersion version) noexcept +{ + return ::mlspp::CipherSuite{CiphersuiteIDForProtocolVersion(version)}; +} + +::mlspp::Capabilities LeafNodeCapabilitiesForProtocolVersion(ProtocolVersion version) noexcept +{ + auto capabilities = ::mlspp::Capabilities::create_default(); + + capabilities.cipher_suites = {CiphersuiteIDForProtocolVersion(version)}; + capabilities.credentials = {::mlspp::CredentialType::basic}; + + return capabilities; +} + +::mlspp::ExtensionList LeafNodeExtensionsForProtocolVersion(ProtocolVersion version) noexcept +{ + return ::mlspp::ExtensionList{}; +} + +::mlspp::ExtensionList GroupExtensionsForProtocolVersion( + ProtocolVersion version, + const ::mlspp::ExternalSender& externalSender) noexcept +{ + auto extensionList = ::mlspp::ExtensionList{}; + + extensionList.add(::mlspp::ExternalSendersExtension{{ + {externalSender.signature_key, externalSender.credential}, + }}); + + return extensionList; +} + +} // namespace mls +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/parameters.h b/src/dpp/dave/parameters.h new file mode 100755 index 0000000000..bdda62bf8d --- /dev/null +++ b/src/dpp/dave/parameters.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include +#include + +#include "version.h" + +namespace discord { +namespace dave { +namespace mls { + +::mlspp::CipherSuite::ID CiphersuiteIDForProtocolVersion(ProtocolVersion version) noexcept; +::mlspp::CipherSuite CiphersuiteForProtocolVersion(ProtocolVersion version) noexcept; +::mlspp::CipherSuite::ID CiphersuiteIDForSignatureVersion(SignatureVersion version) noexcept; +::mlspp::CipherSuite CiphersuiteForSignatureVersion(SignatureVersion version) noexcept; +::mlspp::Capabilities LeafNodeCapabilitiesForProtocolVersion(ProtocolVersion version) noexcept; +::mlspp::ExtensionList LeafNodeExtensionsForProtocolVersion(ProtocolVersion version) noexcept; +::mlspp::ExtensionList GroupExtensionsForProtocolVersion( + ProtocolVersion version, + const ::mlspp::ExternalSender& externalSender) noexcept; + +} // namespace mls +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/persisted_key_pair.cpp b/src/dpp/dave/persisted_key_pair.cpp new file mode 100755 index 0000000000..240240b696 --- /dev/null +++ b/src/dpp/dave/persisted_key_pair.cpp @@ -0,0 +1,103 @@ +#include "persisted_key_pair.h" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "parameters.h" +#include "logger.h" + +static const std::string SelfSignatureLabel = "DiscordSelfSignature"; + +static std::string MakeKeyID(const std::string& sessionID, ::mlspp::CipherSuite suite) +{ + return sessionID + "-" + std::to_string((uint16_t)suite.cipher_suite()) + "-" + + std::to_string(discord::dave::mls::KeyVersion); +} + +static std::mutex mtx; +static std::map> map; + +namespace discord { +namespace dave { +namespace mls { + +static std::shared_ptr<::mlspp::SignaturePrivateKey> GetPersistedKeyPair( + KeyPairContextType ctx, + const std::string& sessionID, + ::mlspp::CipherSuite suite) +{ + std::lock_guard lk(mtx); + + std::string id = MakeKeyID(sessionID, suite); + + if (auto it = map.find(id); it != map.end()) { + return it->second; + } + + std::shared_ptr<::mlspp::SignaturePrivateKey> ret; + + if (!ret) { + ret = ::discord::dave::mls::detail::GetGenericPersistedKeyPair(ctx, id, suite); + } + + if (!ret) { + DISCORD_LOG(LS_ERROR) << "Failed to get key in GetPersistedKeyPair"; + return nullptr; + } + + map.emplace(id, ret); + + return ret; +} + +std::shared_ptr<::mlspp::SignaturePrivateKey> GetPersistedKeyPair(KeyPairContextType ctx, + const std::string& sessionID, + ProtocolVersion version) +{ + return GetPersistedKeyPair(ctx, sessionID, CiphersuiteForProtocolVersion(version)); +} + +KeyAndSelfSignature GetPersistedPublicKey(KeyPairContextType ctx, + const std::string& sessionID, + SignatureVersion version) +{ + auto suite = CiphersuiteForSignatureVersion(version); + + auto pair = GetPersistedKeyPair(ctx, sessionID, suite); + + if (!pair) { + return {}; + } + + bytes sign_data = from_ascii(sessionID + ":") + pair->public_key.data; + + return { + pair->public_key.data.as_vec(), + std::move(pair->sign(suite, SelfSignatureLabel, sign_data).as_vec()), + }; +} + +bool DeletePersistedKeyPair(KeyPairContextType ctx, + const std::string& sessionID, + SignatureVersion version) +{ + std::string id = MakeKeyID(sessionID, CiphersuiteForSignatureVersion(version)); + + std::lock_guard lk(mtx); + + map.erase(id); + + return ::discord::dave::mls::detail::DeleteGenericPersistedKeyPair(ctx, id); +} + +} // namespace mls +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/persisted_key_pair.h b/src/dpp/dave/persisted_key_pair.h new file mode 100755 index 0000000000..2130a7146b --- /dev/null +++ b/src/dpp/dave/persisted_key_pair.h @@ -0,0 +1,60 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "parameters.h" +#include "version.h" + +namespace mlspp { +struct SignaturePrivateKey; +}; + +namespace discord { +namespace dave { +namespace mls { + +typedef const char* KeyPairContextType; + + +std::shared_ptr<::mlspp::SignaturePrivateKey> GetPersistedKeyPair(KeyPairContextType ctx, + const std::string& sessionID, + ProtocolVersion version); + +struct KeyAndSelfSignature { + std::vector key; + std::vector signature; +}; + +KeyAndSelfSignature GetPersistedPublicKey(KeyPairContextType ctx, + const std::string& sessionID, + SignatureVersion version); + +bool DeletePersistedKeyPair(KeyPairContextType ctx, + const std::string& sessionID, + SignatureVersion version); + +constexpr unsigned KeyVersion = 1; + +} // namespace mls +} // namespace dave +} // namespace discord + +namespace discord { + namespace dave { + namespace mls { + namespace detail { + std::shared_ptr<::mlspp::SignaturePrivateKey> GetGenericPersistedKeyPair(KeyPairContextType ctx, const std::string& id, ::mlspp::CipherSuite suite); + bool DeleteGenericPersistedKeyPair(KeyPairContextType ctx, const std::string& id); + } + } + } +} \ No newline at end of file diff --git a/src/dpp/dave/persisted_key_pair_generic.cpp b/src/dpp/dave/persisted_key_pair_generic.cpp new file mode 100755 index 0000000000..e648a5a427 --- /dev/null +++ b/src/dpp/dave/persisted_key_pair_generic.cpp @@ -0,0 +1,176 @@ +#include +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 +#include +#else +#include +#endif +#include + +#include +#include + +#include "parameters.h" +#include "logger.h" +#include "persisted_key_pair.h" + +static const std::string_view KeyStorageDir = "Discord Key Storage"; + +static std::filesystem::path GetKeyStorageDirectory() +{ + std::filesystem::path dir; + +#if defined(__ANDROID__) + dir = std::filesystem::path("/data/data"); + + { + std::ifstream idFile("/proc/self/cmdline", std::ios_base::in); + std::string appId; + std::getline(idFile, appId, '\0'); + dir /= appId; + } +#else // __ANDROID__ +#if defined(_WIN32) + if (const wchar_t* appdata = _wgetenv(L"LOCALAPPDATA")) { + dir = std::filesystem::path(appdata); + } +#else // _WIN32 + if (const char* xdg = getenv("XDG_CONFIG_HOME")) { + dir = std::filesystem::path(xdg); + } + else if (const char* home = getenv("HOME")) { + dir = std::filesystem::path(home); + dir /= ".config"; + } +#endif // !_WIN32 + else { + return dir; + } +#endif // !__ANDROID__ + + return dir / KeyStorageDir; +} + +namespace discord { +namespace dave { +namespace mls { +namespace detail { + +std::shared_ptr<::mlspp::SignaturePrivateKey> GetGenericPersistedKeyPair(KeyPairContextType ctx, + const std::string& id, + ::mlspp::CipherSuite suite) +{ + ::mlspp::SignaturePrivateKey ret; + std::string curstr; + std::filesystem::path dir = GetKeyStorageDirectory(); + + if (dir.empty()) { + DISCORD_LOG(LS_ERROR) << "Failed to determine key storage directory in GetPersistedKeyPair"; + return nullptr; + } + + std::error_code errc; + std::filesystem::create_directories(dir, errc); + if (errc) { + DISCORD_LOG(LS_ERROR) << "Failed to create key storage directory in GetPersistedKeyPair: " + << errc; + return nullptr; + } + + std::filesystem::path file = dir / (id + ".key"); + + if (std::filesystem::exists(file)) { + std::ifstream ifs(file, std::ios_base::in | std::ios_base::binary); + if (!ifs) { + DISCORD_LOG(LS_ERROR) << "Failed to open key in GetPersistedKeyPair"; + return nullptr; + } + + curstr = (std::stringstream() << ifs.rdbuf()).str(); + if (!ifs) { + DISCORD_LOG(LS_ERROR) << "Failed to read key in GetPersistedKeyPair"; + return nullptr; + } + + try { + ret = ::mlspp::SignaturePrivateKey::from_jwk(suite, curstr); + } + catch (std::exception& ex) { + DISCORD_LOG(LS_ERROR) << "Failed to parse key in GetPersistedKeyPair: " << ex.what(); + return nullptr; + } + } + else { + ret = ::mlspp::SignaturePrivateKey::generate(suite); + + std::string newstr = ret.to_jwk(suite); + + std::filesystem::path tmpfile = file; + tmpfile += ".tmp"; + +#ifdef _WIN32 + int fd = _wopen(tmpfile.c_str(), _O_WRONLY | _O_CREAT | _O_TRUNC, _S_IREAD | _S_IWRITE); +#else + int fd = open(tmpfile.c_str(), + O_WRONLY | O_CLOEXEC | O_NOFOLLOW | O_CREAT | O_TRUNC, + S_IRUSR | S_IWUSR); +#endif + if (fd < 0) { + DISCORD_LOG(LS_ERROR) << "Failed to open output file in GetPersistedKeyPair: " << errno + << "(" << tmpfile << ")"; + return nullptr; + } + +#ifdef _WIN32 + int wret = _write(fd, newstr.c_str(), newstr.size()); + _close(fd); +#else + ssize_t wret = write(fd, newstr.c_str(), newstr.size()); + close(fd); +#endif + if (wret < 0 || (size_t)wret != newstr.size()) { + DISCORD_LOG(LS_ERROR) << "Failed to write output file in GetPersistedKeyPair: " + << errno; + return nullptr; + } + + std::filesystem::rename(tmpfile, file, errc); + if (errc) { + DISCORD_LOG(LS_ERROR) << "Failed to rename output file in GetPersistedKeyPair: " + << errc; + return nullptr; + } + } + + if (!ret.public_key.data.empty()) { + return std::make_shared<::mlspp::SignaturePrivateKey>(std::move(ret)); + } + else { + return nullptr; + } +} + +bool DeleteGenericPersistedKeyPair(KeyPairContextType ctx, const std::string& id) +{ + std::error_code errc; + std::filesystem::path dir = GetKeyStorageDirectory(); + if (dir.empty()) { + DISCORD_LOG(LS_ERROR) << "Failed to determine key storage directory in GetPersistedKeyPair"; + return false; + } + + std::filesystem::path file = dir / (id + ".key"); + + return std::filesystem::remove(file, errc); +} + +} // namespace detail +} // namespace mls +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/scope_exit.h b/src/dpp/dave/scope_exit.h new file mode 100755 index 0000000000..f61f65f692 --- /dev/null +++ b/src/dpp/dave/scope_exit.h @@ -0,0 +1,48 @@ +#pragma once + +#include +#include +#include + +namespace discord { +namespace dave { + +class [[nodiscard]] ScopeExit final { +public: + template + explicit ScopeExit(Cleanup&& cleanup) + : cleanup_{std::forward(cleanup)} + { + } + + ScopeExit(ScopeExit&& rhs) + : cleanup_{std::move(rhs.cleanup_)} + { + rhs.cleanup_ = nullptr; + } + + ~ScopeExit() + { + if (cleanup_) { + cleanup_(); + } + } + + ScopeExit& operator=(ScopeExit&& rhs) + { + cleanup_ = std::move(rhs.cleanup_); + rhs.cleanup_ = nullptr; + return *this; + } + + void Dismiss() { cleanup_ = nullptr; } + +private: + ScopeExit(ScopeExit const&) = delete; + ScopeExit& operator=(ScopeExit const&) = delete; + + std::function cleanup_; +}; + +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/session.cpp b/src/dpp/dave/session.cpp new file mode 100755 index 0000000000..6b8bee0200 --- /dev/null +++ b/src/dpp/dave/session.cpp @@ -0,0 +1,785 @@ +#include "session.h" + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "logger.h" +#include "mls_key_ratchet.h" +#include "user_credential.h" +#include "parameters.h" +#include "persisted_key_pair.h" +#include "util.h" + +#include "openssl/evp.h" + +#define TRACK_MLS_ERROR(reason) \ + if (onMLSFailureCallback_) { \ + onMLSFailureCallback_(__FUNCTION__, reason); \ + } + +namespace discord { +namespace dave { +namespace mls { + +struct QueuedProposal { + ::mlspp::ValidatedContent content; + ::mlspp::bytes_ns::bytes ref; +}; + +Session::Session(KeyPairContextType context, + std::string authSessionId, + MLSFailureCallback callback) noexcept + : signingKeyId_(authSessionId) + , keyPairContext_(context) + , onMLSFailureCallback_(std::move(callback)) +{ + DISCORD_LOG(LS_INFO) << "Creating a new MLS session"; +} + +Session::~Session() noexcept = default; + +void Session::Init(ProtocolVersion protocolVersion, + uint64_t groupId, + std::string const& selfUserId, + std::shared_ptr<::mlspp::SignaturePrivateKey>& transientKey) noexcept +{ + Reset(); + + selfUserId_ = selfUserId; + + DISCORD_LOG(LS_INFO) << "Initializing MLS session with protocol version " << protocolVersion + << " and group ID " << groupId; + protocolVersion_ = protocolVersion; + groupId_ = std::move(BigEndianBytesFrom(groupId).as_vec()); + + InitLeafNode(selfUserId, transientKey); + + CreatePendingGroup(); +} + +void Session::Reset() noexcept +{ + DISCORD_LOG(LS_INFO) << "Resetting MLS session"; + + ClearPendingState(); + + currentState_.reset(); + outboundCachedGroupState_.reset(); + + protocolVersion_ = 0; + groupId_.clear(); +} + +void Session::SetProtocolVersion(ProtocolVersion version) noexcept +{ + if (version != protocolVersion_) { + // when we need to retain backwards compatibility + // there may be some changes to the MLS objects required here + // until then we can just update the stored version + protocolVersion_ = version; + } +} + +std::vector Session::GetLastEpochAuthenticator() const noexcept +{ + if (!currentState_) { + DISCORD_LOG(LS_ERROR) << "Cannot get epoch authenticator without an established MLS group"; + return {}; + } + + return std::move(currentState_->epoch_authenticator().as_vec()); +} + +void Session::SetExternalSender(const std::vector& marshalledExternalSender) noexcept +try { + if (currentState_) { + DISCORD_LOG(LS_ERROR) << "Cannot set external sender after joining/creating an MLS group"; + return; + } + + DISCORD_LOG(LS_INFO) << "Unmarshalling MLS external sender"; + + DISCORD_LOG(LS_INFO) << "Sender: " << ::mlspp::bytes_ns::bytes(marshalledExternalSender); + + externalSender_ = std::make_unique<::mlspp::ExternalSender>( + ::mlspp::tls::get<::mlspp::ExternalSender>(marshalledExternalSender)); + + if (!groupId_.empty()) { + CreatePendingGroup(); + } +} +catch (const std::exception& e) { + DISCORD_LOG(LS_ERROR) << "Failed to unmarshal external sender: " << e.what(); + TRACK_MLS_ERROR(e.what()); + return; +} + +std::optional> Session::ProcessProposals( + std::vector proposals, + std::set const& recognizedUserIDs) noexcept +try { + if (!pendingGroupState_ && !currentState_) { + DISCORD_LOG(LS_ERROR) + << "Cannot process proposals without any pending or established MLS group state"; + return std::nullopt; + } + + if (!stateWithProposals_) { + stateWithProposals_ = std::make_unique<::mlspp::State>( + pendingGroupState_ ? *pendingGroupState_ : *currentState_); + } + + DISCORD_LOG(LS_INFO) << "Processing MLS proposals message of " << proposals.size() << " bytes"; + + DISCORD_LOG(LS_INFO) << "Proposals: " << ::mlspp::bytes_ns::bytes(proposals); + + ::mlspp::tls::istream inStream(proposals); + + bool isRevoke = false; + inStream >> isRevoke; + + DISCORD_LOG(LS_INFO) << "Revoking: " << isRevoke; + + const auto suite = stateWithProposals_->cipher_suite(); + + if (isRevoke) { + std::vector<::mlspp::bytes_ns::bytes> refs; + inStream >> refs; + + for (const auto& ref : refs) { + bool found = false; + for (auto it = proposalQueue_.begin(); it != proposalQueue_.end(); it++) { + if (it->ref == ref) { + found = true; + proposalQueue_.erase(it); + break; + } + } + + if (!found) { + DISCORD_LOG(LS_ERROR) << "Cannot revoke unrecognized proposal ref"; + TRACK_MLS_ERROR("Unrecognized proposal revocation"); + return std::nullopt; + } + } + + stateWithProposals_ = std::make_unique<::mlspp::State>( + pendingGroupState_ ? *pendingGroupState_ : *currentState_); + + for (auto& prop : proposalQueue_) { + // success will queue the proposal, failure will throw + stateWithProposals_->handle(prop.content); + } + } + else { + std::vector<::mlspp::MLSMessage> messages; + inStream >> messages; + + for (const auto& proposalMessage : messages) { + auto validatedMessage = stateWithProposals_->unwrap(proposalMessage); + + if (!ValidateProposalMessage(validatedMessage.authenticated_content(), + *stateWithProposals_, + recognizedUserIDs)) { + return std::nullopt; + } + + // success will queue the proposal, failure will throw + stateWithProposals_->handle(validatedMessage); + + auto ref = suite.ref(validatedMessage.authenticated_content()); + + proposalQueue_.push_back({ + std::move(validatedMessage), + std::move(ref), + }); + } + } + + // generate a commit + auto commitSecret = ::mlspp::hpke::random_bytes(suite.secret_size()); + + auto commitOpts = ::mlspp::CommitOpts{ + {}, // no extra proposals + true, // inline tree in welcome + false, // do not force path + {} // default leaf node options + }; + + auto [commitMessage, welcomeMessage, newState] = + stateWithProposals_->commit(commitSecret, commitOpts, {}); + + DISCORD_LOG(LS_INFO) + << "Prepared commit/welcome/next state for MLS group from received proposals"; + + // combine the commit and welcome messages into a single buffer + auto outStream = ::mlspp::tls::ostream(); + outStream << commitMessage; + + // keep a copy of the commit, we can check incoming pending group commit later for a match + pendingGroupCommit_ = std::make_unique<::mlspp::MLSMessage>(std::move(commitMessage)); + + // if there were any add proposals in this commit, then we also include the welcome message + if (welcomeMessage.secrets.size() > 0) { + outStream << welcomeMessage; + } + + // cache the outbound state in case we're the winning sender + outboundCachedGroupState_ = std::make_unique<::mlspp::State>(std::move(newState)); + + DISCORD_LOG(LS_INFO) << "Output: " << ::mlspp::bytes_ns::bytes(outStream.bytes()); + + return outStream.bytes(); +} +catch (const std::exception& e) { + DISCORD_LOG(LS_ERROR) << "Failed to parse MLS proposals: " << e.what(); + TRACK_MLS_ERROR(e.what()); + return std::nullopt; +} + +bool Session::IsRecognizedUserID(const ::mlspp::Credential& cred, + std::set const& recognizedUserIDs) const +{ + std::string uid = UserCredentialToString(cred, protocolVersion_); + if (uid.empty()) { + DISCORD_LOG(LS_ERROR) << "Attempted to verify credential of unexpected type"; + return false; + } + + if (recognizedUserIDs.find(uid) == recognizedUserIDs.end()) { + DISCORD_LOG(LS_ERROR) << "Attempted to verify credential for unrecognized user ID: " << uid; + return false; + } + + return true; +} + +bool Session::ValidateProposalMessage(::mlspp::AuthenticatedContent const& message, + ::mlspp::State const& targetState, + std::set const& recognizedUserIDs) const +{ + if (message.wire_format != ::mlspp::WireFormat::mls_public_message) { + DISCORD_LOG(LS_ERROR) << "MLS proposal message must be PublicMessage"; + TRACK_MLS_ERROR("Invalid proposal wire format"); + return false; + } + + if (message.content.epoch != targetState.epoch()) { + DISCORD_LOG(LS_ERROR) << "MLS proposal message must be for current epoch (" + << message.content.epoch << " != " << targetState.epoch() << ")"; + TRACK_MLS_ERROR("Proposal epoch mismatch"); + return false; + } + + if (message.content.content_type() != ::mlspp::ContentType::proposal) { + DISCORD_LOG(LS_ERROR) << "ProcessProposals called with non-proposal message"; + TRACK_MLS_ERROR("Unexpected message type"); + return false; + } + + if (message.content.sender.sender_type() != ::mlspp::SenderType::external) { + DISCORD_LOG(LS_ERROR) << "MLS proposal must be from external sender"; + TRACK_MLS_ERROR("Unexpected proposal sender type"); + return false; + } + + const auto& proposal = ::mlspp::tls::var::get<::mlspp::Proposal>(message.content.content); + switch (proposal.proposal_type()) { + case ::mlspp::ProposalType::add: { + const auto& credential = + ::mlspp::tls::var::get<::mlspp::Add>(proposal.content).key_package.leaf_node.credential; + if (!IsRecognizedUserID(credential, recognizedUserIDs)) { + DISCORD_LOG(LS_ERROR) << "MLS add proposal must be for recognized user"; + TRACK_MLS_ERROR("Unexpected user ID in add proposal"); + return false; + } + break; + } + case ::mlspp::ProposalType::remove: + // Remove proposals are always allowed (mlspp will validate that it's a recognized user) + break; + default: + DISCORD_LOG(LS_ERROR) << "MLS proposal must be add or remove"; + TRACK_MLS_ERROR("Unexpected proposal type"); + return false; + } + + return true; +} + +bool Session::CanProcessCommit(const ::mlspp::MLSMessage& commit) noexcept +{ + if (!stateWithProposals_) { + return false; + } + + if (commit.group_id() != groupId_) { + DISCORD_LOG(LS_ERROR) << "MLS commit message was for unexpected group"; + return false; + } + + return true; +} + +RosterVariant Session::ProcessCommit(std::vector commit) noexcept +try { + DISCORD_LOG(LS_INFO) << "Processing commit"; + DISCORD_LOG(LS_INFO) << "Commit: " << ::mlspp::bytes_ns::bytes(commit); + + auto commitMessage = ::mlspp::tls::get<::mlspp::MLSMessage>(commit); + + if (!CanProcessCommit(commitMessage)) { + DISCORD_LOG(LS_ERROR) << "ProcessCommit called with unprocessable MLS commit"; + return ignored_t{}; + } + + // in case we're the sender of this commit + // we need to pull the cached state from our outbound cache + std::optional<::mlspp::State> optionalCachedState = std::nullopt; + if (outboundCachedGroupState_) { + optionalCachedState = *(outboundCachedGroupState_.get()); + } + + auto newState = stateWithProposals_->handle(commitMessage, optionalCachedState); + + if (!newState) { + DISCORD_LOG(LS_ERROR) << "MLS commit handling did not produce a new state"; + return failed_t{}; + } + + DISCORD_LOG(LS_INFO) << "Successfully processed MLS commit, updating state; our leaf index is " + << newState->index().val << "; current epoch is " << newState->epoch(); + + RosterMap ret = ReplaceState(std::make_unique<::mlspp::State>(std::move(*newState))); + + // reset the outbound cached group since we handled the commit for this epoch + outboundCachedGroupState_.reset(); + + ClearPendingState(); + + return ret; +} +catch (const std::exception& e) { + DISCORD_LOG(LS_ERROR) << "Failed to process MLS commit: " << e.what(); + TRACK_MLS_ERROR(e.what()); + return failed_t{}; +} + +std::optional Session::ProcessWelcome( + std::vector welcome, + std::set const& recognizedUserIDs) noexcept +try { + if (!HasCryptographicStateForWelcome()) { + DISCORD_LOG(LS_ERROR) << "Missing local cyrpto state necessary to process MLS welcome"; + return std::nullopt; + } + + if (!externalSender_) { + DISCORD_LOG(LS_ERROR) << "Cannot process MLS welcome without an external sender"; + return std::nullopt; + } + + if (currentState_) { + DISCORD_LOG(LS_ERROR) << "Cannot process MLS welcome after joining/creating an MLS group"; + return std::nullopt; + } + + DISCORD_LOG(LS_INFO) << "Processing welcome: " << ::mlspp::bytes_ns::bytes(welcome); + + // unmarshal the incoming welcome + auto unmarshalledWelcome = ::mlspp::tls::get<::mlspp::Welcome>(welcome); + + // construct the state from the unmarshalled welcome + auto newState = std::make_unique<::mlspp::State>( + *joinInitPrivateKey_, + *selfHPKEPrivateKey_, + *selfSigPrivateKey_, + *joinKeyPackage_, + unmarshalledWelcome, + std::nullopt, + std::map<::mlspp::bytes_ns::bytes, ::mlspp::bytes_ns::bytes>()); + + // perform application-level verification of the new state + if (!VerifyWelcomeState(*newState, recognizedUserIDs)) { + DISCORD_LOG(LS_ERROR) << "Group received in MLS welcome is not valid"; + + return std::nullopt; + } + + DISCORD_LOG(LS_INFO) << "Successfully welcomed to MLS Group, our leaf index is " + << newState->index().val << "; current epoch is " << newState->epoch(); + + // make the verified state our new (and only) state + RosterMap ret = ReplaceState(std::move(newState)); + + // clear out any pending state for creating/joining a group + ClearPendingState(); + + return ret; +} +catch (const std::exception& e) { + DISCORD_LOG(LS_ERROR) << "Failed to create group state from MLS welcome: " << e.what(); + TRACK_MLS_ERROR(e.what()); + return std::nullopt; +} + +RosterMap Session::ReplaceState(std::unique_ptr<::mlspp::State>&& state) +{ + RosterMap newRoster; + for (const ::mlspp::LeafNode& node : state->roster()) { + if (node.credential.type() != ::mlspp::CredentialType::basic) { + continue; + } + + const auto& cred = node.credential.template get<::mlspp::BasicCredential>(); + + newRoster[FromBigEndianBytes(cred.identity)] = node.signature_key.data.as_vec(); + } + + RosterMap changeMap; + + std::set_difference(newRoster.begin(), + newRoster.end(), + roster_.begin(), + roster_.end(), + std::inserter(changeMap, changeMap.end())); + + struct MissingItemWrapper { + RosterMap& changeMap_; + + using iterator = RosterMap::iterator; + using const_iterator = RosterMap::const_iterator; + using value_type = RosterMap::value_type; + + iterator insert(const_iterator it, const value_type& value) + { + return changeMap_.try_emplace(it, value.first, std::vector{}); + } + + iterator begin() { return changeMap_.begin(); } + + iterator end() { return changeMap_.end(); } + }; + + MissingItemWrapper wrapper{changeMap}; + + std::set_difference(roster_.begin(), + roster_.end(), + newRoster.begin(), + newRoster.end(), + std::inserter(wrapper, wrapper.end())); + + roster_ = std::move(newRoster); + currentState_ = std::move(state); + + return changeMap; +} + +bool Session::HasCryptographicStateForWelcome() const noexcept +{ + return joinKeyPackage_ && joinInitPrivateKey_ && selfSigPrivateKey_ && selfHPKEPrivateKey_; +} + +bool Session::VerifyWelcomeState(::mlspp::State const& state, + std::set const& recognizedUserIDs) const +{ + if (!externalSender_) { + DISCORD_LOG(LS_ERROR) << "Cannot verify MLS welcome without an external sender"; + TRACK_MLS_ERROR("Missing external sender when processing Welcome"); + return false; + } + + auto ext = state.extensions().template find(); + if (!ext) { + DISCORD_LOG(LS_ERROR) << "MLS welcome missing external senders extension"; + TRACK_MLS_ERROR("Welcome message missing external sender extension"); + return false; + } + + if (ext->senders.size() != 1) { + DISCORD_LOG(LS_ERROR) << "MLS welcome lists unexpected number of external senders: " + << ext->senders.size(); + TRACK_MLS_ERROR("Welcome message lists unexpected external sender count"); + return false; + } + + if (ext->senders.front() != *externalSender_) { + DISCORD_LOG(LS_ERROR) << "MLS welcome lists unexpected external sender"; + TRACK_MLS_ERROR("Welcome message lists unexpected external sender"); + return false; + } + + // TODO: Until we leverage revocation in the protocol + // if we re-enable this change we will refuse welcome messages + // because someone was previously supposed to be added but disconnected + // before all in-flight proposals were handled. + + for (const auto& leaf : state.roster()) { + if (!IsRecognizedUserID(leaf.credential, recognizedUserIDs)) { + DISCORD_LOG(LS_ERROR) << "MLS welcome lists unrecognized user ID"; + // TRACK_MLS_ERROR("Welcome message lists unrecognized user ID"); + // return false; + } + } + + return true; +} + +void Session::InitLeafNode(std::string const& selfUserId, + std::shared_ptr<::mlspp::SignaturePrivateKey>& transientKey) noexcept +try { + auto ciphersuite = CiphersuiteForProtocolVersion(protocolVersion_); + + if (!transientKey) { + if (!signingKeyId_.empty()) { + transientKey = GetPersistedKeyPair(keyPairContext_, signingKeyId_, protocolVersion_); + if (!transientKey) { + DISCORD_LOG(LS_ERROR) << "Did not receive MLS signature private key from " + "GetPersistedKeyPair; aborting"; + return; + } + } + else { + transientKey = std::make_shared<::mlspp::SignaturePrivateKey>( + ::mlspp::SignaturePrivateKey::generate(ciphersuite)); + } + } + + selfSigPrivateKey_ = transientKey; + + auto selfCredential = CreateUserCredential(selfUserId, protocolVersion_); + + selfHPKEPrivateKey_ = + std::make_unique<::mlspp::HPKEPrivateKey>(::mlspp::HPKEPrivateKey::generate(ciphersuite)); + + selfLeafNode_ = + std::make_unique<::mlspp::LeafNode>(ciphersuite, + selfHPKEPrivateKey_->public_key, + selfSigPrivateKey_->public_key, + std::move(selfCredential), + LeafNodeCapabilitiesForProtocolVersion(protocolVersion_), + ::mlspp::Lifetime::create_default(), + LeafNodeExtensionsForProtocolVersion(protocolVersion_), + *selfSigPrivateKey_); + + DISCORD_LOG(LS_INFO) << "Created MLS leaf node"; +} +catch (const std::exception& e) { + DISCORD_LOG(LS_INFO) << "Failed to initialize MLS leaf node: " << e.what(); + TRACK_MLS_ERROR(e.what()); +} + +void Session::ResetJoinKeyPackage() noexcept +try { + if (!selfLeafNode_) { + DISCORD_LOG(LS_ERROR) << "Cannot initialize join key package without a leaf node"; + return; + } + + auto ciphersuite = CiphersuiteForProtocolVersion(protocolVersion_); + + joinInitPrivateKey_ = + std::make_unique<::mlspp::HPKEPrivateKey>(::mlspp::HPKEPrivateKey::generate(ciphersuite)); + + joinKeyPackage_ = + std::make_unique<::mlspp::KeyPackage>(ciphersuite, + joinInitPrivateKey_->public_key, + *selfLeafNode_, + LeafNodeExtensionsForProtocolVersion(protocolVersion_), + *selfSigPrivateKey_); + + DISCORD_LOG(LS_INFO) << "Generated key package: " + << ::mlspp::bytes_ns::bytes(::mlspp::tls::marshal(*joinKeyPackage_)); +} +catch (const std::exception& e) { + DISCORD_LOG(LS_ERROR) << "Failed to initialize join key package: " << e.what(); + TRACK_MLS_ERROR(e.what()); +} + +void Session::CreatePendingGroup() noexcept +try { + if (groupId_.empty()) { + DISCORD_LOG(LS_ERROR) << "Cannot create MLS group without a group ID"; + return; + } + + if (!externalSender_) { + DISCORD_LOG(LS_ERROR) << "Cannot create MLS group without ExternalSender"; + return; + } + + if (!selfLeafNode_) { + DISCORD_LOG(LS_ERROR) << "Cannot create MLS group without self leaf node"; + return; + } + + DISCORD_LOG(LS_INFO) << "Creating a pending MLS group"; + + auto ciphersuite = CiphersuiteForProtocolVersion(protocolVersion_); + + pendingGroupState_ = std::make_unique<::mlspp::State>( + groupId_, + ciphersuite, + *selfHPKEPrivateKey_, + *selfSigPrivateKey_, + *selfLeafNode_, + GroupExtensionsForProtocolVersion(protocolVersion_, *externalSender_)); + + DISCORD_LOG(LS_INFO) << "Created a pending MLS group"; +} +catch (const std::exception& e) { + DISCORD_LOG(LS_ERROR) << "Failed to create MLS group: " << e.what(); + TRACK_MLS_ERROR(e.what()); + return; +} + +std::vector Session::GetMarshalledKeyPackage() noexcept +try { + // key packages are not meant to be re-used + // so every time the client asks for a key package we create a new one + ResetJoinKeyPackage(); + + if (!joinKeyPackage_) { + DISCORD_LOG(LS_ERROR) << "Cannot marshal an uninitialized key package"; + return {}; + } + + return ::mlspp::tls::marshal(*joinKeyPackage_); +} +catch (const std::exception& e) { + DISCORD_LOG(LS_ERROR) << "Failed to marshal join key package: " << e.what(); + TRACK_MLS_ERROR(e.what()); + return {}; +} + +std::unique_ptr Session::GetKeyRatchet(std::string const& userId) const noexcept +{ + if (!currentState_) { + DISCORD_LOG(LS_ERROR) << "Cannot get key ratchet without an established MLS group"; + return nullptr; + } + + // change the string user ID to a little endian 64 bit user ID + auto u64userId = strtoull(userId.c_str(), nullptr, 10); + auto userIdBytes = ::mlspp::bytes_ns::bytes(sizeof(u64userId)); + memcpy(userIdBytes.data(), &u64userId, sizeof(u64userId)); + + // generate the base secret for the hash ratchet + auto baseSecret = + currentState_->do_export(Session::USER_MEDIA_KEY_BASE_LABEL, userIdBytes, kAesGcm128KeyBytes); + + // this assumes the MLS ciphersuite produces a kAesGcm128KeyBytes sized key + // would need to be updated to a different ciphersuite if there's a future mismatch + return std::make_unique(currentState_->cipher_suite(), std::move(baseSecret)); +} + +void Session::GetPairwiseFingerprint(uint16_t version, + std::string const& userId, + PairwiseFingerprintCallback callback) const noexcept +try { + if (!currentState_ || !selfSigPrivateKey_) { + throw std::invalid_argument("No established MLS group"); + } + + uint64_t u64RemoteUserId = strtoull(userId.c_str(), nullptr, 10); + uint64_t u64SelfUserId = strtoull(selfUserId_.c_str(), nullptr, 10); + + auto it = roster_.find(u64RemoteUserId); + if (it == roster_.end()) { + throw std::invalid_argument("Unknown user ID: " + userId); + } + + ::mlspp::tls::ostream toHash1; + ::mlspp::tls::ostream toHash2; + + toHash1 << version; + toHash1.write_raw(it->second); + toHash1 << u64RemoteUserId; + + toHash2 << version; + toHash2.write_raw(selfSigPrivateKey_->public_key.data); + toHash2 << u64SelfUserId; + + std::vector> keyData = { + toHash1.bytes(), + toHash2.bytes(), + }; + + std::sort(keyData.begin(), keyData.end()); + + std::thread([callback = std::move(callback), + data = ::mlspp::bytes_ns::bytes(std::move(keyData[0])) + keyData[1]] { + static constexpr uint8_t salt[] = { + 0x24, + 0xca, + 0xb1, + 0x7a, + 0x7a, + 0xf8, + 0xec, + 0x2b, + 0x82, + 0xb4, + 0x12, + 0xb9, + 0x2d, + 0xab, + 0x19, + 0x2e, + }; + + constexpr uint64_t N = 16384, r = 8, p = 2, max_mem = 32 * 1024 * 1024; + constexpr size_t hash_len = 64; + + std::vector out(hash_len); + + int ret = EVP_PBE_scrypt((const char*)data.data(), + data.size(), + salt, + sizeof(salt), + N, + r, + p, + max_mem, + out.data(), + out.size()); + + if (ret == 1) { + callback(out); + } + else { + callback({}); + } + }).detach(); +} +catch (const std::exception& e) { + DISCORD_LOG(LS_ERROR) << "Failed to generate pairwise fingerprint: " << e.what(); + callback({}); +} + +void Session::ClearPendingState() +{ + pendingGroupState_.reset(); + pendingGroupCommit_.reset(); + + joinInitPrivateKey_.reset(); + joinKeyPackage_.reset(); + + selfHPKEPrivateKey_.reset(); + + selfLeafNode_.reset(); + + stateWithProposals_.reset(); + proposalQueue_.clear(); +} + +} // namespace mls +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/session.h b/src/dpp/dave/session.h new file mode 100755 index 0000000000..21059739d6 --- /dev/null +++ b/src/dpp/dave/session.h @@ -0,0 +1,133 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "persisted_key_pair.h" +#include "key_ratchet.h" +#include "version.h" + +namespace mlspp { +struct AuthenticatedContent; +struct Credential; +struct ExternalSender; +struct HPKEPrivateKey; +struct KeyPackage; +struct LeafNode; +struct MLSMessage; +struct SignaturePrivateKey; +class State; +} // namespace mlspp + +namespace discord { +namespace dave { +namespace mls { + +struct QueuedProposal; + +class Session { +public: + using MLSFailureCallback = std::function; + + Session(KeyPairContextType context, + std::string authSessionId, + MLSFailureCallback callback) noexcept; + + ~Session() noexcept; + + void Init(ProtocolVersion version, + uint64_t groupId, + std::string const& selfUserId, + std::shared_ptr<::mlspp::SignaturePrivateKey>& transientKey) noexcept; + void Reset() noexcept; + + void SetProtocolVersion(ProtocolVersion version) noexcept; + ProtocolVersion GetProtocolVersion() const noexcept { return protocolVersion_; } + + std::vector GetLastEpochAuthenticator() const noexcept; + + void SetExternalSender(std::vector const& externalSenderPackage) noexcept; + + std::optional> ProcessProposals( + std::vector proposals, + std::set const& recognizedUserIDs) noexcept; + + RosterVariant ProcessCommit(std::vector commit) noexcept; + + std::optional ProcessWelcome( + std::vector welcome, + std::set const& recognizedUserIDs) noexcept; + + std::vector GetMarshalledKeyPackage() noexcept; + + std::unique_ptr GetKeyRatchet(std::string const& userId) const noexcept; + + using PairwiseFingerprintCallback = std::function const&)>; + + void GetPairwiseFingerprint(uint16_t version, + std::string const& userId, + PairwiseFingerprintCallback callback) const noexcept; + +private: + void InitLeafNode(std::string const& selfUserId, + std::shared_ptr<::mlspp::SignaturePrivateKey>& transientKey) noexcept; + void ResetJoinKeyPackage() noexcept; + + void CreatePendingGroup() noexcept; + + bool HasCryptographicStateForWelcome() const noexcept; + + bool IsRecognizedUserID(const ::mlspp::Credential& cred, + std::set const& recognizedUserIDs) const; + bool ValidateProposalMessage(::mlspp::AuthenticatedContent const& message, + ::mlspp::State const& targetState, + std::set const& recognizedUserIDs) const; + bool VerifyWelcomeState(::mlspp::State const& state, + std::set const& recognizedUserIDs) const; + + bool CanProcessCommit(const ::mlspp::MLSMessage& commit) noexcept; + + RosterMap ReplaceState(std::unique_ptr<::mlspp::State>&& state); + + void ClearPendingState(); + + inline static const std::string USER_MEDIA_KEY_BASE_LABEL = "Discord Secure Frames v0"; + + ProtocolVersion protocolVersion_; + std::vector groupId_; + std::string signingKeyId_; + std::string selfUserId_; + KeyPairContextType keyPairContext_{nullptr}; + + std::unique_ptr<::mlspp::LeafNode> selfLeafNode_; + std::shared_ptr<::mlspp::SignaturePrivateKey> selfSigPrivateKey_; + std::unique_ptr<::mlspp::HPKEPrivateKey> selfHPKEPrivateKey_; + + std::unique_ptr<::mlspp::HPKEPrivateKey> joinInitPrivateKey_; + std::unique_ptr<::mlspp::KeyPackage> joinKeyPackage_; + + std::unique_ptr<::mlspp::ExternalSender> externalSender_; + + std::unique_ptr<::mlspp::State> pendingGroupState_; + std::unique_ptr<::mlspp::MLSMessage> pendingGroupCommit_; + + std::unique_ptr<::mlspp::State> outboundCachedGroupState_; + + std::unique_ptr<::mlspp::State> currentState_; + RosterMap roster_; + + std::unique_ptr<::mlspp::State> stateWithProposals_; + std::list proposalQueue_; + + MLSFailureCallback onMLSFailureCallback_{}; +}; + +} // namespace mls +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/user_credential.cpp b/src/dpp/dave/user_credential.cpp new file mode 100755 index 0000000000..0254a07d9e --- /dev/null +++ b/src/dpp/dave/user_credential.cpp @@ -0,0 +1,33 @@ +#include "user_credential.h" +#include +#include "util.h" + +namespace discord { +namespace dave { +namespace mls { + +::mlspp::Credential CreateUserCredential(const std::string& userId, ProtocolVersion version) +{ + // convert the string user ID to a big endian uint64_t + auto userID = std::stoull(userId); + auto credentialBytes = BigEndianBytesFrom(userID); + + return ::mlspp::Credential::basic(credentialBytes); +} + +std::string UserCredentialToString(const ::mlspp::Credential& cred, ProtocolVersion version) +{ + if (cred.type() != ::mlspp::CredentialType::basic) { + return ""; + } + + const auto& basic = cred.template get<::mlspp::BasicCredential>(); + + auto uidVal = FromBigEndianBytes(basic.identity); + + return std::to_string(uidVal); +} + +} // namespace mls +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/user_credential.h b/src/dpp/dave/user_credential.h new file mode 100755 index 0000000000..6ad2476eaf --- /dev/null +++ b/src/dpp/dave/user_credential.h @@ -0,0 +1,16 @@ +#pragma once + +#include +#include +#include "version.h" + +namespace discord { +namespace dave { +namespace mls { + +::mlspp::Credential CreateUserCredential(const std::string& userId, ProtocolVersion version); +std::string UserCredentialToString(const ::mlspp::Credential& cred, ProtocolVersion version); + +} // namespace mls +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/util.cpp b/src/dpp/dave/util.cpp new file mode 100755 index 0000000000..39447dca8e --- /dev/null +++ b/src/dpp/dave/util.cpp @@ -0,0 +1,34 @@ +#include "util.h" + +namespace discord { +namespace dave { +namespace mls { + +::mlspp::bytes_ns::bytes BigEndianBytesFrom(uint64_t value) noexcept +{ + auto buffer = ::mlspp::bytes_ns::bytes(); + buffer.reserve(sizeof(value)); + + for (int i = sizeof(value) - 1; i >= 0; --i) { + buffer.push_back(static_cast(value >> (i * 8))); + } + + return buffer; +} + +uint64_t FromBigEndianBytes(const ::mlspp::bytes_ns::bytes& buffer) noexcept +{ + uint64_t val = 0; + + if (buffer.size() <= sizeof(val)) { + for (uint8_t byte : buffer) { + val = (val << 8) | byte; + } + } + + return val; +} + +} // namespace mls +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/util.h b/src/dpp/dave/util.h new file mode 100755 index 0000000000..279a48e1b5 --- /dev/null +++ b/src/dpp/dave/util.h @@ -0,0 +1,16 @@ +#pragma once + +#include + +#include + +namespace discord { +namespace dave { +namespace mls { + +::mlspp::bytes_ns::bytes BigEndianBytesFrom(uint64_t value) noexcept; +uint64_t FromBigEndianBytes(const ::mlspp::bytes_ns::bytes& value) noexcept; + +} // namespace mls +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/version.cpp b/src/dpp/dave/version.cpp new file mode 100755 index 0000000000..7fbdf286c3 --- /dev/null +++ b/src/dpp/dave/version.cpp @@ -0,0 +1,14 @@ +#include "version.h" + +namespace discord { +namespace dave { + +constexpr ProtocolVersion CurrentDaveProtocolVersion = 1; + +ProtocolVersion MaxSupportedProtocolVersion() +{ + return CurrentDaveProtocolVersion; +} + +} // namespace dave +} // namespace discord diff --git a/src/dpp/dave/version.h b/src/dpp/dave/version.h new file mode 100755 index 0000000000..bca35b975c --- /dev/null +++ b/src/dpp/dave/version.h @@ -0,0 +1,14 @@ +#pragma once + +#include + +namespace discord { +namespace dave { + +using ProtocolVersion = uint16_t; +using SignatureVersion = uint8_t; + +ProtocolVersion MaxSupportedProtocolVersion(); + +} // namespace dave +} // namespace discord