diff --git a/plugins/wasmembed/.gitignore b/plugins/wasmembed/.gitignore new file mode 100644 index 00000000000..2e73a394e92 --- /dev/null +++ b/plugins/wasmembed/.gitignore @@ -0,0 +1 @@ +ref/ \ No newline at end of file diff --git a/plugins/wasmembed/CMakeLists.txt b/plugins/wasmembed/CMakeLists.txt index 7dff7328253..abb86d8779e 100644 --- a/plugins/wasmembed/CMakeLists.txt +++ b/plugins/wasmembed/CMakeLists.txt @@ -3,8 +3,11 @@ project(wasmembed) if(WASMEMBED) ADD_PLUGIN(wasmembed) if(MAKE_WASMEMBED) - set(CMAKE_CXX_STANDARD 20) + set(CMAKE_CXX_STANDARD_REQUIRED ON) + + add_subdirectory(component-model) + find_path(WASMTIME_CPP_API_INCLUDE_DIRS "wasmtime-cpp-api/wasmtime.hh" PATHS ${VCPKG_INSTALLED_DIR}/${VCPKG_TARGET_TRIPLET} ) @@ -21,6 +24,7 @@ if(WASMEMBED) include_directories( ${WASMTIME_CPP_API_INCLUDE_DIRS}/wasmtime-c-api ${WASMTIME_CPP_API_INCLUDE_DIRS}/wasmtime-cpp-api + ./../../system/include ./../../system/jlib ./../../rtl/eclrtl @@ -37,12 +41,12 @@ if(WASMEMBED) add_library(wasmembed SHARED wasmembed.cpp secure-enclave.cpp - abi.cpp util.cpp ) target_link_libraries(wasmembed ${WASMTIME_LIB} + component-model-cpp eclrtl jlib ) diff --git a/plugins/wasmembed/abi.cpp b/plugins/wasmembed/abi.cpp deleted file mode 100644 index 43b91339e4b..00000000000 --- a/plugins/wasmembed/abi.cpp +++ /dev/null @@ -1,269 +0,0 @@ -/* - See: https://github.com/WebAssembly/component-model/blob/main/design/mvp/CanonicalABI.md - https://github.com/WebAssembly/component-model/blob/main/design/mvp/canonical-abi/definitions.py -*/ - -#include "abi.hpp" - -#include "jexcept.hpp" - -auto UTF16_TAG = 1U << 31; - -// -/* canonical despecialize (python) ------------------------------------------------------------- - -def despecialize(t): - match t: - case Tuple(ts) : return Record([ Field(str(i), t) for i,t in enumerate(ts) ]) - case Union(ts) : return Variant([ Case(str(i), t) for i,t in enumerate(ts) ]) - case Enum(labels) : return Variant([ Case(l, None) for l in labels ]) - case Option(t) : return Variant([ Case("none", None), Case("some", t) ]) - case Result(ok, error) : return Variant([ Case("ok", ok), Case("error", error) ]) - case _ : return t - -*/ - -// template -// wasmtime::ValType despecialize(const T &t) -// { -// switch (t.kind()) -// { -// case wasmtime::ValKind::I32: -// case wasmtime::ValKind::I64: -// case wasmtime::ValKind::F32: -// case wasmtime::ValKind::F64: -// case wasmtime::ValKind::V128: -// return t.kind(); -// default: -// return wasmtime::ValType::i32(); -// } -// } - -/* canonical alignment (python) ------------------------------------------------------------- - -def alignment(t): - match despecialize(t): - case Bool() : return 1 - case S8() | U8() : return 1 - case S16() | U16() : return 2 - case S32() | U32() : return 4 - case S64() | U64() : return 8 - case Float32() : return 4 - case Float64() : return 8 - case Char() : return 4 - case String() | List(_) : return 4 - case Record(fields) : return alignment_record(fields) - case Variant(cases) : return alignment_variant(cases) - case Flags(labels) : return alignment_flags(labels) - case Own(_) | Borrow(_) : return 4 - -*/ - -// int alignment(const wasmtime::ValType &t) -// { -// switch (t.kind()) -// { -// case wasmtime::ValKind::I32: -// case wasmtime::ValKind::F32: -// return 4; -// case wasmtime::ValKind::I64: -// case wasmtime::ValKind::F64: -// return 8; -// case wasmtime::ValKind::V128: -// return 16; -// default: -// return 1; -// } -// } - -/* canonical align_to (python) ------------------------------------------------------------- - -def align_to(ptr, alignment): - return math.ceil(ptr / alignment) * alignment - -*/ - -uint32_t align_to(uint32_t ptr, uint32_t alignment) -{ - return (ptr + alignment - 1) & ~(alignment - 1); -} - -bool isAligned(uint32_t ptr, uint32_t alignment) -{ - return (ptr & (alignment - 1)) == 0; -} - -// loading --- - -/* canonical load_int (python) ------------------------------------------------------------- - -def load_int(cx, ptr, nbytes, signed = False): - return int.from_bytes(cx.opts.memory[ptr : ptr+nbytes], 'little', signed=signed) - -*/ - -template -T load_int(const wasmtime::Span &data, uint32_t ptr) -{ - T retVal = 0; - if constexpr (sizeof(T) == 1) - { - retVal = static_cast(data[ptr]); - } - else if constexpr (sizeof(T) == 2) - { - retVal = static_cast((static_cast(data[ptr + 1]) << 8) | - static_cast(data[ptr])); - } - else if constexpr (sizeof(T) == 4) - { - retVal = static_cast((static_cast(data[ptr + 3]) << 24) | - (static_cast(data[ptr + 2]) << 16) | - (static_cast(data[ptr + 1]) << 8) | - static_cast(data[ptr])); - } - else if constexpr (sizeof(T) == 8) - { - retVal = static_cast((static_cast(data[ptr + 7]) << 56) | - (static_cast(data[ptr + 6]) << 48) | - (static_cast(data[ptr + 5]) << 40) | - (static_cast(data[ptr + 4]) << 32) | - (static_cast(data[ptr + 3]) << 24) | - (static_cast(data[ptr + 2]) << 16) | - (static_cast(data[ptr + 1]) << 8) | - static_cast(data[ptr])); - } - return retVal; -} -/* canonical load_string_from_range (python) ------------------------------------------------------------- - -def load_string_from_range(cx, ptr, tagged_code_units): - match cx.opts.string_encoding: - case 'utf8': - alignment = 1 - byte_length = tagged_code_units - encoding = 'utf-8' - case 'utf16': - alignment = 2 - byte_length = 2 * tagged_code_units - encoding = 'utf-16-le' - case 'latin1+utf16': - alignment = 2 - if bool(tagged_code_units & UTF16_TAG): - byte_length = 2 * (tagged_code_units ^ UTF16_TAG) - encoding = 'utf-16-le' - else: - byte_length = tagged_code_units - encoding = 'latin-1' - - trap_if(ptr != align_to(ptr, alignment)) - trap_if(ptr + byte_length > len(cx.opts.memory)) - try: - s = cx.opts.memory[ptr : ptr+byte_length].decode(encoding) - except UnicodeError: - trap() - - return (s, cx.opts.string_encoding, tagged_code_units) - -*/ - -// More: Not currently available from the wasmtime::context object, see https://github.com/bytecodealliance/wasmtime/issues/6719 -static const std::string global_encoding = "utf8"; - -std::tuple load_string_from_range(const wasmtime::Span &data, uint32_t ptr, uint32_t tagged_code_units) -{ - std::string encoding = "utf-8"; - uint32_t byte_length = tagged_code_units; - uint32_t alignment = 1; - if (global_encoding.compare("utf8") == 0) - { - alignment = 1; - byte_length = tagged_code_units; - encoding = "utf-8"; - } - else if (global_encoding.compare("utf16") == 0) - { - alignment = 2; - byte_length = 2 * tagged_code_units; - encoding = "utf-16-le"; - } - else if (global_encoding.compare("latin1+utf16") == 0) - { - alignment = 2; - if (tagged_code_units & UTF16_TAG) - { - byte_length = 2 * (tagged_code_units ^ UTF16_TAG); - encoding = "utf-16-le"; - } - else - { - byte_length = tagged_code_units; - encoding = "latin-1"; - } - } - - if (!isAligned(ptr, alignment)) - { - throw makeStringException(3, "Invalid alignment"); - } - - if (ptr + byte_length > data.size()) - { - throw makeStringException(1, "Out of bounds"); - } - - return std::make_tuple(ptr, encoding, byte_length); -} - -/* canonical load_string (python) ------------------------------------------------------------- - -def load_string(cx, ptr): - begin = load_int(cx, ptr, 4) - tagged_code_units = load_int(cx, ptr + 4, 4) - return load_string_from_range(cx, begin, tagged_code_units) - -*/ -std::tuple load_string(const wasmtime::Span &data, uint32_t ptr) -{ - uint32_t begin = load_int(data, ptr); - uint32_t tagged_code_units = load_int(data, ptr + 4); - return load_string_from_range(data, begin, tagged_code_units); -} - -/* canonical load_list_from_range (python) ------------------------------------------------------------- - -def load_list_from_range(cx, ptr, length, elem_type): - trap_if(ptr != align_to(ptr, alignment(elem_type))) - trap_if(ptr + length * size(elem_type) > len(cx.opts.memory)) - a = [] - for i in range(length): - a.append(load(cx, ptr + i * size(elem_type), elem_type)) - return a - -*/ - -template -std::vector load_list_from_range(const wasmtime::Span &data, uint32_t ptr, uint32_t length) -{ - if (!isAligned(ptr, alignment(T{}))) - throw makeStringException(2, "Pointer is not aligned"); - if (ptr + length * sizeof(T) > data.size()) - throw makeStringException(1, "Out of bounds access"); - std::vector a; - for (uint32_t i = 0; i < length; i++) - { - a.push_back(load(data, ptr + i * sizeof(T))); - } - return a; -} - -/* canonical load_list (python) ------------------------------------------------------------- - -def load_list(cx, ptr, elem_type): - begin = load_int(cx, ptr, 4) - length = load_int(cx, ptr + 4, 4) - return load_list_from_range(cx, begin, length, elem_type) - -*/ - -// Storing --- diff --git a/plugins/wasmembed/abi.hpp b/plugins/wasmembed/abi.hpp deleted file mode 100644 index 977c7601602..00000000000 --- a/plugins/wasmembed/abi.hpp +++ /dev/null @@ -1,10 +0,0 @@ -#if __has_include() -#include -#else -#include -#include -#endif - -#include - -std::tuple load_string(const wasmtime::Span &data, uint32_t ptr); diff --git a/plugins/wasmembed/component-model/CMakeLists.txt b/plugins/wasmembed/component-model/CMakeLists.txt new file mode 100644 index 00000000000..47f2510a0b1 --- /dev/null +++ b/plugins/wasmembed/component-model/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(src) +add_subdirectory(test) diff --git a/plugins/wasmembed/component-model/src/CMakeLists.txt b/plugins/wasmembed/component-model/src/CMakeLists.txt new file mode 100644 index 00000000000..be1e3627451 --- /dev/null +++ b/plugins/wasmembed/component-model/src/CMakeLists.txt @@ -0,0 +1,42 @@ +project(component-model-cpp) + +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +set(SRC + boolean.cpp + boolean.hpp + context.hpp + context.cpp + float.hpp + float.cpp + integer.hpp + integer.cpp + list.hpp + list.cpp + load.hpp + load.cpp + string.hpp + string.cpp + traits.hpp + traits.cpp + util.hpp + util.cpp +) + +add_library(${PROJECT_NAME} STATIC + ${SRC} +) + +target_link_libraries(${PROJECT_NAME} +) + +target_compile_definitions(${PROJECT_NAME} + PRIVATE CMCPP_EXPORTS +) +target_compile_options(${PROJECT_NAME} PUBLIC "$<$:/permissive->") + +target_include_directories( + ${PROJECT_NAME} PUBLIC $ + $ +) diff --git a/plugins/wasmembed/component-model/src/boolean.cpp b/plugins/wasmembed/component-model/src/boolean.cpp new file mode 100644 index 00000000000..f0d08dc8818 --- /dev/null +++ b/plugins/wasmembed/component-model/src/boolean.cpp @@ -0,0 +1,19 @@ +#include "boolean.hpp" +#include "integer.hpp" +#include "util.hpp" + +namespace cmcpp +{ + namespace boolean + { + void store(CallContext &cx, const bool_t &v, offset ptr) + { + integer::store(cx, v, ptr); + } + + bool_t load(const CallContext &cx, uint32_t ptr) + { + return convert_int_to_bool(integer::load(cx, ptr)); + } + } +} diff --git a/plugins/wasmembed/component-model/src/boolean.hpp b/plugins/wasmembed/component-model/src/boolean.hpp new file mode 100644 index 00000000000..1c2a3edee0c --- /dev/null +++ b/plugins/wasmembed/component-model/src/boolean.hpp @@ -0,0 +1,17 @@ +#ifndef CMCPP_BOOLEAN_HPP +#define CMCPP_BOOLEAN_HPP + +#include "context.hpp" + +namespace cmcpp +{ + + namespace boolean + { + void store(CallContext &cx, const bool_t &v, offset ptr); + + bool_t load(const CallContext &cx, uint32_t ptr); + } +} + +#endif diff --git a/plugins/wasmembed/component-model/src/context.cpp b/plugins/wasmembed/component-model/src/context.cpp new file mode 100644 index 00000000000..671cd0d266a --- /dev/null +++ b/plugins/wasmembed/component-model/src/context.cpp @@ -0,0 +1,25 @@ +#include "context.hpp" + +namespace cmcpp +{ + std::unique_ptr InstanceContext::createCallContext(const GuestMemory &memory, const Encoding &guest_encoding, const GuestPostReturn &post_return) + { + auto retVal = std::make_unique(); + retVal->trap = trap; + retVal->convert = convert; + retVal->realloc = realloc; + retVal->memory = memory; + retVal->guest_encoding = guest_encoding; + retVal->post_return = post_return; + return retVal; + } + + std::unique_ptr createInstanceContext(const HostTrap &trap, HostUnicodeConversion convert, const GuestRealloc &realloc) + { + auto retVal = std::make_unique(); + retVal->trap = trap; + retVal->convert = convert; + retVal->realloc = realloc; + return retVal; + } +} \ No newline at end of file diff --git a/plugins/wasmembed/component-model/src/context.hpp b/plugins/wasmembed/component-model/src/context.hpp new file mode 100644 index 00000000000..3d7ce41d903 --- /dev/null +++ b/plugins/wasmembed/component-model/src/context.hpp @@ -0,0 +1,48 @@ +#ifndef CMCPP_CONTEXT_HPP +#define CMCPP_CONTEXT_HPP + +#if __has_include() +#include +#else +#include +#include +#endif + +#include "traits.hpp" + +#include +#include +#include + +namespace cmcpp +{ + using HostTrap = std::function; + using GuestRealloc = std::function; + using GuestMemory = std::span; + using GuestPostReturn = std::function; + using HostUnicodeConversion = std::function(char8_t *dest, const char8_t *src, uint32_t byte_len, Encoding from_encoding, Encoding to_encoding)>; + + struct CallContext + { + HostTrap trap; + GuestRealloc realloc; + GuestMemory memory; + Encoding guest_encoding; + HostUnicodeConversion convert; + std::optional post_return; + bool sync = true; + bool always_task_return = false; + }; + + struct InstanceContext + { + HostTrap trap; + HostUnicodeConversion convert; + GuestRealloc realloc; + std::unique_ptr createCallContext(const GuestMemory &memory, const Encoding &encoding = Encoding::Utf8, const GuestPostReturn &post_return = nullptr); + }; + + std::unique_ptr createInstanceContext(const HostTrap &trap, HostUnicodeConversion convert, const GuestRealloc &realloc); +} + +#endif diff --git a/plugins/wasmembed/component-model/src/float.cpp b/plugins/wasmembed/component-model/src/float.cpp new file mode 100644 index 00000000000..8664cbc6850 --- /dev/null +++ b/plugins/wasmembed/component-model/src/float.cpp @@ -0,0 +1,34 @@ +#include "float.hpp" + +namespace cmcpp +{ + namespace float_ + { + int32_t encode_float_as_i32(float32_t f) + { + return *reinterpret_cast(&f); + } + + int64_t encode_float_as_i64(float64_t f) + { + return *reinterpret_cast(&f); + } + + float32_t decode_i32_as_float(int32_t i) + { + return *reinterpret_cast(&i); + } + + float64_t decode_i64_as_float(int64_t i) + { + return *reinterpret_cast(&i); + } + + float32_t core_f32_reinterpret_i32(int32_t i) + { + float f; + std::memcpy(&f, &i, sizeof f); + return f; + } + } +} diff --git a/plugins/wasmembed/component-model/src/float.hpp b/plugins/wasmembed/component-model/src/float.hpp new file mode 100644 index 00000000000..a28ef46c5fe --- /dev/null +++ b/plugins/wasmembed/component-model/src/float.hpp @@ -0,0 +1,80 @@ +#ifndef CMCPP_FLOAT_HPP +#define CMCPP_FLOAT_HPP + +#include "context.hpp" +#include "integer.hpp" + +#include + +namespace cmcpp +{ + + namespace float_ + { + int32_t encode_float_as_i32(float32_t f); + int64_t encode_float_as_i64(float64_t f); + float32_t decode_i32_as_float(int32_t i); + float64_t decode_i64_as_float(int64_t i); + float32_t core_f32_reinterpret_i32(int32_t i); + + template + T canonicalize_nan(T f) + { + if (!std::isfinite(f)) + { + f = std::numeric_limits::quiet_NaN(); + } + return f; + } + + template + T maybe_scramble_nan(T f) + { + if (!std::isfinite(f)) + { + f = std::numeric_limits::quiet_NaN(); + } + return f; + } + + template + inline void store(CallContext &cx, const T &v, offset ptrnbytes) + { + cx.trap("store of unsupported type"); + throw std::runtime_error("trap not terminating execution"); + } + + template <> + inline void store(CallContext &cx, const float32_t &v, offset ptr) + { + integer::store(cx, encode_float_as_i32(v), ptr); + } + + template <> + inline void store(CallContext &cx, const float64_t &v, offset ptr) + { + integer::store(cx, encode_float_as_i64(v), ptr); + } + + template + T load(const CallContext &cx, offset ptr) + { + cx.trap("load of unsupported type"); + throw std::runtime_error("trap not terminating execution"); + } + + template <> + inline float32_t load(const CallContext &cx, offset ptr) + { + return decode_i32_as_float(integer::load(cx, ptr)); + } + + template <> + inline float64_t load(const CallContext &cx, offset ptr) + { + return decode_i64_as_float(integer::load(cx, ptr)); + } + } +} + +#endif diff --git a/plugins/wasmembed/component-model/src/integer.cpp b/plugins/wasmembed/component-model/src/integer.cpp new file mode 100644 index 00000000000..b69540fdd11 --- /dev/null +++ b/plugins/wasmembed/component-model/src/integer.cpp @@ -0,0 +1,9 @@ +#include "integer.hpp" + +namespace cmcpp +{ + + namespace integer + { + } +} diff --git a/plugins/wasmembed/component-model/src/integer.hpp b/plugins/wasmembed/component-model/src/integer.hpp new file mode 100644 index 00000000000..cab1394040f --- /dev/null +++ b/plugins/wasmembed/component-model/src/integer.hpp @@ -0,0 +1,58 @@ +#ifndef CMCPP_INTEGER_HPP +#define CMCPP_INTEGER_HPP + +#include "context.hpp" + +#include +#include +#include + +namespace cmcpp +{ + using offset = uint32_t; + + namespace integer + { + template + void store(CallContext &cx, const T &v, offset ptr) + { + std::memcpy(&cx.memory[ptr], &v, sizeof(T)); + } + + template + WasmValVector lower_flat_signed(const T &v, uint32_t core_bits) + { + using WasmValType = ValTrait::flat_type; + WasmValType retVal = v; + return {v}; + } + + template + T load(const CallContext &cx, offset ptr) + { + T retVal; + std::memcpy(&retVal, &cx.memory[ptr], sizeof(T)); + return retVal; + } + + template + T lift_flat_unsigned(const WasmValVectorIterator &vi, uint32_t core_width, uint32_t t_width) + { + using WasmValType = ValTrait::flat_type; + auto retVal = vi.next(); + assert(ValTrait::LOW_VALUE <= retVal && retVal < ValTrait::HIGH_VALUE); + return retVal; + } + + template + T lift_flat_signed(const WasmValVectorIterator &vi, uint32_t core_width, uint32_t t_width) + { + using WasmValType = ValTrait::flat_type; + auto retVal = static_cast(vi.next()); + assert(ValTrait::LOW_VALUE <= retVal && retVal <= ValTrait::HIGH_VALUE); + return retVal; + } + } +} + +#endif diff --git a/plugins/wasmembed/component-model/src/lift.hpp b/plugins/wasmembed/component-model/src/lift.hpp new file mode 100644 index 00000000000..83be17e5285 --- /dev/null +++ b/plugins/wasmembed/component-model/src/lift.hpp @@ -0,0 +1,64 @@ +#ifndef CMCPP_LIFT_HPP +#define CMCPP_LIFT_HPP + +#include "context.hpp" +#include "integer.hpp" +#include "float.hpp" +#include "string.hpp" +#include "list.hpp" +#include "record.hpp" +#include "util.hpp" + +namespace cmcpp +{ + template + inline T lift_flat(const CallContext &cx, const WasmValVectorIterator &vi) + { + return convert_int_to_bool(vi.next()); + } + + template + inline T lift_flat(const CallContext &cx, const WasmValVectorIterator &vi) + { + return integer::lift_flat_unsigned(vi, ValTrait::size * 8, 8); + } + + template + inline T lift_flat(const CallContext &cx, const WasmValVectorIterator &vi) + { + return integer::lift_flat_signed(vi, ValTrait::size * 8, 8); + } + + template + inline T lift_flat(const CallContext &cx, const WasmValVectorIterator &vi) + { + return float_::canonicalize_nan(vi.next()); + } + + template + inline T lift_flat(const CallContext &cx, const WasmValVectorIterator &vi) + { + return string::lift_flat(cx, vi); + } + + template + inline T lift_flat(const CallContext &cx, const WasmValVectorIterator &vi, std::size_t maybe_length) + { + return list::lift_flat::inner_type>(cx, vi, maybe_length); + } + + template + inline T lift_flat(const CallContext &cx, const WasmValVectorIterator &vi) + { + return list::lift_flat::inner_type>(cx, vi); + } + + template + T lift_flat(const CallContext &cx, const WasmValVectorIterator &vi) + { + auto x = record::lift_flat(cx, vi); + return x; + } +} + +#endif diff --git a/plugins/wasmembed/component-model/src/list.cpp b/plugins/wasmembed/component-model/src/list.cpp new file mode 100644 index 00000000000..fd100202fc2 --- /dev/null +++ b/plugins/wasmembed/component-model/src/list.cpp @@ -0,0 +1,9 @@ +#include "list.hpp" + +namespace cmcpp +{ + + namespace list + { + } +} diff --git a/plugins/wasmembed/component-model/src/list.hpp b/plugins/wasmembed/component-model/src/list.hpp new file mode 100644 index 00000000000..cd41ad03b76 --- /dev/null +++ b/plugins/wasmembed/component-model/src/list.hpp @@ -0,0 +1,133 @@ +#ifndef CMCPP_LIST_HPP +#define CMCPP_LIST_HPP + +#include "context.hpp" +#include "integer.hpp" +#include "store.hpp" +#include "load.hpp" +#include "util.hpp" + +#include +#include +#include + +namespace cmcpp +{ + using offset = uint32_t; + using size = uint32_t; + + template + inline void store(CallContext &cx, const list_t::inner_type> &v, uint32_t ptr); + + template + inline list_t::inner_type> load(const CallContext &cx, uint32_t ptr); + + namespace list + { + template + std::tuple store_into_range(CallContext &cx, const list_t &v) + { + auto elem_type = ValTrait::type; + size_t nbytes = ValTrait::size; + auto byte_length = v.size() * nbytes; + if (byte_length >= std::numeric_limits::max()) + { + throw std::runtime_error("byte_length exceeds limit"); + } + uint32_t ptr = cx.realloc(0, 0, ValTrait::alignment, byte_length); + if (ptr != align_to(ptr, ValTrait::alignment)) + { + throw std::runtime_error("ptr not aligned"); + } + if (ptr + byte_length > cx.memory.size()) + { + throw std::runtime_error("memory overflow"); + } + for (size_t i = 0; i < v.size(); ++i) + { + cmcpp::store(cx, v[i], ptr + i * nbytes); + } + return {ptr, v.size()}; + } + + template + void store(CallContext &cx, const list_t &list, offset ptr) + { + auto [begin, length] = store_into_range(cx, list); + integer::store(cx, begin, ptr); + integer::store(cx, length, ptr + 4); + } + + template + WasmValVector lower_flat(CallContext &cx, const list_t &v) + { + std::size_t maybe_length = 0; + if (maybe_length) + { + assert(maybe_length == v.size()); + WasmValVector flat; + for (auto e : v) + { + auto ef = lower_flat(cx, e); + flat.insert(flat.end(), ef.begin(), ef.end()); + } + return flat; + } + auto [ptr, length] = store_into_range(cx, v); + return {static_cast(ptr), static_cast(length)}; + } + + template + list_t load_from_range(const CallContext &cx, offset ptr, size length) + { + assert(ptr == align_to(ptr, ValTrait::alignment)); + assert(ptr + length * ValTrait::size <= cx.memory.size()); + list_t list = {}; + for (uint32_t i = 0; i < length; ++i) + { + list.push_back(cmcpp::load(cx, ptr + i * ValTrait::size)); + } + return list; + } + + template + list_t load(const CallContext &cx, offset ptr) + { + uint32_t begin = integer::load(cx, ptr); + uint32_t length = integer::load(cx, ptr + 4); + return load_from_range(cx, begin, length); + } + + template + list_t lift_flat(const CallContext &cx, const WasmValVectorIterator &vi) + { + std::size_t maybe_length = 0; + if (maybe_length) + { + list_t list = {}; + for (size_t i = 0; i < maybe_length; ++i) + { + list.push_back(lift_flat(cx, vi)); + } + return list; + } + auto ptr = vi.next(); + auto length = vi.next(); + return load_from_range(cx, ptr, length); + } + } + + template + inline void store(CallContext &cx, const list_t::inner_type> &v, uint32_t ptr) + { + list::store(cx, v, ptr); + } + + template + inline list_t::inner_type> load(const CallContext &cx, uint32_t ptr) + { + return list::load::inner_type>(cx, ptr); + } +} + +#endif diff --git a/plugins/wasmembed/component-model/src/load.cpp b/plugins/wasmembed/component-model/src/load.cpp new file mode 100644 index 00000000000..b17b8348b4f --- /dev/null +++ b/plugins/wasmembed/component-model/src/load.cpp @@ -0,0 +1,5 @@ +#include "load.hpp" + +namespace cmcpp +{ +} diff --git a/plugins/wasmembed/component-model/src/load.hpp b/plugins/wasmembed/component-model/src/load.hpp new file mode 100644 index 00000000000..b47760b8586 --- /dev/null +++ b/plugins/wasmembed/component-model/src/load.hpp @@ -0,0 +1,37 @@ +#ifndef CMCPP_LOAD_HPP +#define CMCPP_LOAD_HPP + +#include "context.hpp" +#include "integer.hpp" +#include "float.hpp" +#include "string.hpp" +#include "util.hpp" + +namespace cmcpp +{ + template + inline T load(const CallContext &cx, uint32_t ptr) + { + return convert_int_to_bool(integer::load(cx, ptr)); + } + + template + inline uint8_t load(const CallContext &cx, uint32_t ptr) + { + return integer::load(cx, ptr); + } + + template + inline float32_t load(const CallContext &cx, uint32_t ptr) + { + return float_::load(cx, ptr); + } + + template + inline string_t load(const CallContext &cx, uint32_t ptr) + { + return string::load(cx, ptr); + } +} + +#endif diff --git a/plugins/wasmembed/component-model/src/lower.hpp b/plugins/wasmembed/component-model/src/lower.hpp new file mode 100644 index 00000000000..a691454a6af --- /dev/null +++ b/plugins/wasmembed/component-model/src/lower.hpp @@ -0,0 +1,63 @@ +#ifndef CMCPP_LOWER_HPP +#define CMCPP_LOWER_HPP + +#include "context.hpp" +#include "integer.hpp" +#include "float.hpp" +#include "string.hpp" +#include "list.hpp" +#include "record.hpp" +#include "util.hpp" + +#include +#include + +namespace cmcpp +{ + template + inline WasmValVector lower_flat(CallContext &cx, const T &v) + { + return {static_cast::flat_type>(v)}; + } + + template + inline WasmValVector lower_flat(CallContext &cx, const T &v) + { + using FT = ValTrait::flat_type; + FT fv = v; + return {fv}; + } + + template + inline WasmValVector lower_flat(CallContext &cx, const T &v) + { + using FT = ValTrait::flat_type; + return integer::lower_flat_signed(v, ValTrait::size * 8); + } + + template + inline WasmValVector lower_flat(CallContext &cx, const T &v) + { + return {float_::maybe_scramble_nan(v)}; + } + + template + inline WasmValVector lower_flat(CallContext &cx, const T &v) + { + return string::lower_flat(cx, v); + } + + template + inline WasmValVector lower_flat(CallContext &cx, const T &v) + { + return list::lower_flat(cx, v); + } + + template + inline WasmValVector lower_flat(CallContext &cx, const T &v) + { + return record::lower_flat(cx, v); + } +} + +#endif diff --git a/plugins/wasmembed/component-model/src/record.cpp b/plugins/wasmembed/component-model/src/record.cpp new file mode 100644 index 00000000000..fd100202fc2 --- /dev/null +++ b/plugins/wasmembed/component-model/src/record.cpp @@ -0,0 +1,9 @@ +#include "list.hpp" + +namespace cmcpp +{ + + namespace list + { + } +} diff --git a/plugins/wasmembed/component-model/src/record.hpp b/plugins/wasmembed/component-model/src/record.hpp new file mode 100644 index 00000000000..dc4a12322ab --- /dev/null +++ b/plugins/wasmembed/component-model/src/record.hpp @@ -0,0 +1,80 @@ +#ifndef CMCPP_RECORD_HPP +#define CMCPP_RECORD_HPP + +#include "context.hpp" +#include "integer.hpp" +#include "store.hpp" +#include "load.hpp" +#include "util.hpp" + +#include +#include +#include + +namespace cmcpp +{ + namespace record + { + + template + void store(const CallContext &cx, const T&v, uint32_t ptr) + { + auto process_field = [&](auto &&field) + { + ptr = align_to(ptr, ValTrait>::alignment); + store(cx, field, ptr); + ptr += ValTrait>::size; + }; + + std::apply([&](auto &&...fields) + { (process_field(fields), ...); }, v); + } + + template + WasmValVector lower_flat(CallContext &cx, const T &v) + { + WasmValVector retVal = {}; + auto process_field = [&](auto &&field) + { + auto flat = lower_flat(cx, field); + retVal.insert(retVal.end(), flat.begin(), flat.end()); + }; + + std::apply([&](auto &&...fields) + { (process_field(fields), ...); }, v); + return retVal; + } + + template + T load(const CallContext &cx, uint32_t ptr) + { + T result; + auto process_field = [&](auto &&field) + { + ptr = align_to(ptr, ValTrait>::alignment); + field = load>(cx, ptr); + ptr += ValTrait>::size; + }; + + std::apply([&](auto &&...fields) + { (process_field(fields), ...); }, result); + return result; + } + + template + inline T lift_flat(const CallContext &cx, const WasmValVectorIterator &vi) + { + T result; + auto process_field = [&](auto &&field) + { + field = lift_flat>(cx, vi); + }; + + std::apply([&](auto &&...fields) + { (process_field(fields), ...); }, result); + return result; + } + } +} + +#endif diff --git a/plugins/wasmembed/component-model/src/store.hpp b/plugins/wasmembed/component-model/src/store.hpp new file mode 100644 index 00000000000..4ca13e3d79a --- /dev/null +++ b/plugins/wasmembed/component-model/src/store.hpp @@ -0,0 +1,41 @@ +#ifndef CMCPP_STORE_HPP +#define CMCPP_STORE_HPP + +#include "context.hpp" +#include "float.hpp" +#include "integer.hpp" +#include "string.hpp" +#include "list.hpp" +#include "util.hpp" + +#include +#include + +namespace cmcpp +{ + template + inline void store(CallContext &cx, const T &v, uint32_t ptr) + { + integer::store(cx, v, ptr); + } + + template + inline void store(CallContext &cx, const T &v, uint32_t ptr) + { + integer::store(cx, v, ptr); + } + + template + inline void store(CallContext &cx, const T &v, uint32_t ptr) + { + float_::store(cx, v, ptr); + } + + template + inline void store(CallContext &cx, const T &v, uint32_t ptr) + { + string::store(cx, v, ptr); + } +} + +#endif diff --git a/plugins/wasmembed/component-model/src/string.cpp b/plugins/wasmembed/component-model/src/string.cpp new file mode 100644 index 00000000000..cd773e6c24f --- /dev/null +++ b/plugins/wasmembed/component-model/src/string.cpp @@ -0,0 +1,287 @@ +#include "string.hpp" +#include "integer.hpp" +#include "util.hpp" + +#include +#include + +namespace cmcpp +{ + namespace string + { + const uint32_t MAX_STRING_BYTE_LENGTH = (1U << 31) - 1; + + std::pair store_string_copy(CallContext &cx, const char8_t *src, uint32_t src_code_units, uint32_t dst_code_unit_size, uint32_t dst_alignment, Encoding dst_encoding) + { + uint32_t dst_byte_length = dst_code_unit_size * src_code_units; + trap_if(cx, dst_byte_length > MAX_STRING_BYTE_LENGTH); + uint32_t ptr = cx.realloc(0, 0, dst_alignment, dst_byte_length); + trap_if(cx, ptr != align_to(ptr, dst_alignment)); + trap_if(cx, ptr + dst_byte_length > cx.memory.size()); + auto encoded = cx.convert((char8_t *)&cx.memory[ptr], src, src_code_units, cx.guest_encoding, dst_encoding); + // Python test case is a utf8 str pretending to be a utf16 --- assert(dst_byte_length == encoded.second); + return std::make_pair(ptr, src_code_units); + } + + std::pair store_string_to_utf8(CallContext &cx, Encoding src_encoding, const char8_t *src, uint32_t src_code_units, uint32_t worst_case_size) + { + assert(worst_case_size <= MAX_STRING_BYTE_LENGTH); + uint32_t ptr = cx.realloc(0, 0, 1, worst_case_size); + trap_if(cx, ptr + src_code_units > cx.memory.size()); + auto encoded = cx.convert((char8_t *)&cx.memory[ptr], src, src_code_units, src_encoding, Encoding::Utf8); + if (worst_case_size > encoded.second) + { + ptr = cx.realloc(ptr, worst_case_size, 1, encoded.second); + assert(ptr + encoded.second <= cx.memory.size()); + } + return std::make_pair(ptr, encoded.second); + } + + std::pair store_utf16_to_utf8(CallContext &cx, const char8_t *src, uint32_t src_code_units) + { + uint32_t worst_case_size = src_code_units * 3; + return store_string_to_utf8(cx, Encoding::Utf16, src, src_code_units, worst_case_size); + } + + std::pair store_latin1_to_utf8(CallContext &cx, const char8_t *src, uint32_t src_code_units) + { + uint32_t worst_case_size = src_code_units * 2; + return store_string_to_utf8(cx, Encoding::Latin1, src, src_code_units, worst_case_size); + } + + std::pair store_utf8_to_utf16(CallContext &cx, const char8_t *src, uint32_t src_code_units) + { + uint32_t worst_case_size = 2 * src_code_units; + trap_if(cx, worst_case_size > MAX_STRING_BYTE_LENGTH); + uint32_t ptr = cx.realloc(0, 0, 2, worst_case_size); + trap_if(cx, ptr != align_to(ptr, 2)); + trap_if(cx, ptr + worst_case_size > cx.memory.size()); + auto encoded = cx.convert((char8_t *)&cx.memory[ptr], src, src_code_units, Encoding::Utf8, Encoding::Utf16); + if (encoded.second < worst_case_size) + { + ptr = cx.realloc(ptr, worst_case_size, 2, encoded.second); + assert(ptr == align_to(ptr, 2)); + assert(ptr + encoded.second <= cx.memory.size()); + } + uint32_t code_units = static_cast(encoded.second / 2); + return std::make_pair(ptr, code_units); + } + + std::pair store_string_to_latin1_or_utf16(CallContext &cx, Encoding src_encoding, const char8_t *src, uint32_t src_code_units) + { + assert(src_code_units <= MAX_STRING_BYTE_LENGTH); + uint32_t ptr = cx.realloc(0, 0, 2, src_code_units); + trap_if(cx, ptr != align_to(ptr, 2)); + trap_if(cx, ptr + src_code_units > cx.memory.size()); + uint32_t dst_byte_length = 0; + for (size_t i = 0; i < src_code_units; ++i) + { + uint32_t usv = *src; + if (static_cast(usv) < (1 << 8)) + { + cx.memory[ptr + dst_byte_length] = static_cast(usv); + dst_byte_length += 1; + } + else + { + uint32_t worst_case_size = 2 * src_code_units; + if (worst_case_size > MAX_STRING_BYTE_LENGTH) + throw std::runtime_error("Worst case size exceeds maximum string byte length"); + ptr = cx.realloc(ptr, src_code_units, 2, worst_case_size); + if (ptr != align_to(ptr, 2)) + throw std::runtime_error("Pointer misaligned"); + if (ptr + worst_case_size > cx.memory.size()) + throw std::runtime_error("Out of bounds access"); + for (int j = dst_byte_length - 1; j >= 0; --j) + { + cx.memory[ptr + 2 * j] = cx.memory[ptr + j]; + cx.memory[ptr + 2 * j + 1] = 0; + } + auto encoded = cx.convert((char8_t *)&cx.memory[ptr + 2 * dst_byte_length], src, src_code_units, cx.guest_encoding, Encoding::Utf16); + if (worst_case_size > encoded.second) + { + ptr = cx.realloc(ptr, worst_case_size, 2, encoded.second); + if (ptr != align_to(ptr, 2)) + throw std::runtime_error("Pointer misaligned"); + if (ptr + encoded.second > cx.memory.size()) + throw std::runtime_error("Out of bounds access"); + } + uint32_t tagged_code_units = static_cast(encoded.second / 2) | UTF16_TAG; + return std::make_pair(ptr, tagged_code_units); + } + } + if (dst_byte_length < src_code_units) + { + ptr = cx.realloc(ptr, src_code_units, 2, dst_byte_length); + if (ptr != align_to(ptr, 2)) + throw std::runtime_error("Pointer misaligned"); + if (ptr + dst_byte_length > cx.memory.size()) + throw std::runtime_error("Out of bounds access"); + } + return std::make_pair(ptr, dst_byte_length); + } + + std::pair store_probably_utf16_to_latin1_or_utf16(CallContext &cx, const char8_t *src, uint32_t src_code_units) + { + uint32_t src_byte_length = 2 * src_code_units; + trap_if(cx, src_byte_length > MAX_STRING_BYTE_LENGTH); + uint32_t ptr = cx.realloc(0, 0, 2, src_byte_length); + trap_if(cx, ptr != align_to(ptr, 2)); + trap_if(cx, ptr + src_byte_length > cx.memory.size()); + auto encoded = cx.convert((char8_t *)&cx.memory[ptr], src, src_code_units, Encoding::Utf16, Encoding::Utf16); + const uint8_t *enc_src_ptr = &cx.memory[ptr]; + if (std::any_of(enc_src_ptr, enc_src_ptr + encoded.second, + [](uint8_t c) + { return static_cast(c) >= (1 << 8); })) + { + uint32_t tagged_code_units = static_cast(encoded.second / 2) | UTF16_TAG; + return std::make_pair(ptr, tagged_code_units); + } + uint32_t latin1_size = static_cast(encoded.second / 2); + for (uint32_t i = 0; i < latin1_size; ++i) + cx.memory[ptr + i] = cx.memory[ptr + 2 * i]; + ptr = cx.realloc(ptr, src_byte_length, 1, latin1_size); + trap_if(cx, ptr + latin1_size > cx.memory.size()); + return std::make_pair(ptr, latin1_size); + } + + std::pair store_into_range(CallContext &cx, const string_t &v) + { + Encoding src_encoding = v.encoding; + const char8_t *src = v.ptr; + const size_t src_tagged_code_units = v.byte_len; + + Encoding src_simple_encoding; + uint32_t src_code_units; + if (src_encoding == Encoding::Latin1_Utf16) + { + if (src_tagged_code_units & UTF16_TAG) + { + src_simple_encoding = Encoding::Utf16; + src_code_units = src_tagged_code_units ^ UTF16_TAG; + } + else + { + src_simple_encoding = Encoding::Latin1; + src_code_units = src_tagged_code_units; + } + } + else + { + src_simple_encoding = src_encoding; + src_code_units = src_tagged_code_units; + } + + switch (cx.guest_encoding) + { + case Encoding::Latin1: + case Encoding::Utf8: + switch (src_simple_encoding) + { + case Encoding::Utf8: + return store_string_copy(cx, src, src_code_units, 1, 1, Encoding::Utf8); + case Encoding::Utf16: + return store_utf16_to_utf8(cx, src, src_code_units); + case Encoding::Latin1: + return store_latin1_to_utf8(cx, src, src_code_units); + } + break; + case Encoding::Utf16: + switch (src_simple_encoding) + { + case Encoding::Utf8: + return store_utf8_to_utf16(cx, src, src_code_units); + case Encoding::Utf16: + return store_string_copy(cx, src, src_code_units, 2, 2, Encoding::Utf16); + case Encoding::Latin1: + return store_string_copy(cx, src, src_code_units, 2, 2, Encoding::Utf16); + } + break; + case Encoding::Latin1_Utf16: + switch (src_encoding) + { + case Encoding::Utf8: + return store_string_to_latin1_or_utf16(cx, src_encoding, src, src_code_units); + case Encoding::Utf16: + return store_string_to_latin1_or_utf16(cx, src_encoding, src, src_code_units); + case Encoding::Latin1_Utf16: + switch (src_simple_encoding) + { + case Encoding::Latin1: + return store_string_copy(cx, src, src_code_units, 1, 2, Encoding::Latin1); + case Encoding::Utf16: + return store_probably_utf16_to_latin1_or_utf16(cx, src, src_code_units); + } + } + } + assert(false); + return std::make_pair(0, 0); + } + + void store(CallContext &cx, const string_t &v, uint32_t ptr) + { + auto [begin, tagged_code_units] = store_into_range(cx, v); + integer::store(cx, begin, ptr); + integer::store(cx, tagged_code_units, ptr + 4); + } + + WasmValVector lower_flat(CallContext &cx, const string_t &v) + { + auto [ptr, packed_length] = store_into_range(cx, v); + return {(int32_t)ptr, (int32_t)packed_length}; + } + + string_t load_from_range(const CallContext &cx, uint32_t ptr, uint32_t tagged_code_units) + { + uint32_t alignment = 0; + uint32_t byte_length = 0; + Encoding encoding = Encoding::Utf8; + switch (cx.guest_encoding) + { + case Encoding::Utf8: + alignment = 1; + byte_length = tagged_code_units; + encoding = Encoding::Utf8; + break; + case Encoding::Utf16: + alignment = 2; + byte_length = 2 * tagged_code_units; + encoding = Encoding::Utf16; + break; + case Encoding::Latin1_Utf16: + alignment = 2; + if (tagged_code_units & UTF16_TAG) + { + byte_length = 2 * (tagged_code_units ^ UTF16_TAG); + encoding = Encoding::Utf16; + } + else + { + byte_length = tagged_code_units; + encoding = Encoding::Latin1; + } + break; + default: + trap_if(cx, false); + } + trap_if(cx, ptr != align_to(ptr, alignment)); + trap_if(cx, ptr + byte_length > cx.memory.size()); + return string_t{encoding, reinterpret_cast(&cx.memory[ptr]), byte_length}; + } + + string_t load(const CallContext &cx, offset offset) + { + auto begin = integer::load(cx, offset); + auto tagged_code_units = integer::load(cx, offset + 4); + return load_from_range(cx, begin, tagged_code_units); + } + + string_t lift_flat(const CallContext &cx, const WasmValVectorIterator &vi) + { + auto ptr = vi.next::flat_type_0>(); + auto packed_length = vi.next::flat_type_1>(); + return load_from_range(cx, ptr, packed_length); + } + }; + +} diff --git a/plugins/wasmembed/component-model/src/string.hpp b/plugins/wasmembed/component-model/src/string.hpp new file mode 100644 index 00000000000..51b7187ef53 --- /dev/null +++ b/plugins/wasmembed/component-model/src/string.hpp @@ -0,0 +1,19 @@ +#ifndef CMCPP_STRING_HPP +#define CMCPP_STRING_HPP + +#include "context.hpp" + +namespace cmcpp +{ + namespace string + { + std::pair store_into_range(CallContext &cx, const string_t &v); + void store(CallContext &cx, const string_t &v, uint32_t ptr); + WasmValVector lower_flat(CallContext &cx, const string_t &v); + + string_t load(const CallContext &cx, offset offset); + string_t lift_flat(const CallContext &cx, const WasmValVectorIterator &vi); + } +} + +#endif diff --git a/plugins/wasmembed/component-model/src/traits.cpp b/plugins/wasmembed/component-model/src/traits.cpp new file mode 100644 index 00000000000..57ef3da169c --- /dev/null +++ b/plugins/wasmembed/component-model/src/traits.cpp @@ -0,0 +1,6 @@ +#include "traits.hpp" + +namespace cmcpp +{ + WasmValVectorIterator::WasmValVectorIterator(const WasmValVector &v) : it(v.begin()), end(v.end()) {} +} diff --git a/plugins/wasmembed/component-model/src/traits.hpp b/plugins/wasmembed/component-model/src/traits.hpp new file mode 100644 index 00000000000..4b7d9196ad3 --- /dev/null +++ b/plugins/wasmembed/component-model/src/traits.hpp @@ -0,0 +1,456 @@ +#ifndef CMCPP_TRAITS_HPP +#define CMCPP_TRAITS_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// See canonical ABI: +// https://github.com/WebAssembly/component-model/blob/main/design/mvp/canonical-abi/definitions.py +// https://github.com/WebAssembly/component-model/blob/main/design/mvp/CanonicalABI.md + +namespace cmcpp +{ + using float32_t = float; + using float64_t = double; + + enum class WasmValType : uint8_t + { + UNKNOWN, + i32, + i64, + f32, + f64, + LAST + }; + using WasmVal = std::variant; + using WasmValVector = std::vector; + class WasmValVectorIterator + { + mutable WasmValVector::const_iterator it; + WasmValVector::const_iterator end; + + public: + WasmValVectorIterator(const WasmValVector &v); + + template + T next() const + { + assert(it != end); + return std::get(*it++); + } + }; + + template + struct WasmValTrait + { + static constexpr WasmValType type = WasmValType::UNKNOWN; + }; + + template <> + struct WasmValTrait + { + static constexpr WasmValType type = WasmValType::i32; + }; + + template <> + struct WasmValTrait + { + static constexpr WasmValType type = WasmValType::i64; + }; + + template <> + struct WasmValTrait + { + static constexpr WasmValType type = WasmValType::f32; + }; + + template <> + struct WasmValTrait + { + static constexpr WasmValType type = WasmValType::f64; + }; + + template + concept WasmValue = + WasmValTrait::type == WasmValType::i32 || + WasmValTrait::type == WasmValType::i64 || + WasmValTrait::type == WasmValType::f32 || + WasmValTrait::type == WasmValType::f64; + + // -------------------------------------------------------------------- + + enum class Encoding + { + Latin1, + Utf8, + Utf16, + Latin1_Utf16 + }; + + enum class ValType : uint8_t + { + UNKNOWN, + Bool, + S8, + U8, + S16, + U16, + S32, + U32, + S64, + U64, + F32, + F64, + Char, + String, + List, + Field, + Record, + Tuple, + Case, + Variant, + Enum, + Option, + Result, + Flags, + Own, + Borrow, + LAST + }; + + // ValTrait --------------------------------------------------------------- + template + struct ValTrait + { + static constexpr ValType type = ValType::UNKNOWN; + using inner_type = void; + static constexpr uint32_t size = 0; + static constexpr uint32_t alignment = 0; + using flat_types = void; + }; + + // Boolean -------------------------------------------------------------------- + using bool_t = bool; + + template <> + struct ValTrait + { + static constexpr ValType type = ValType::Bool; + static constexpr uint32_t size = 1; + static constexpr uint32_t alignment = 1; + using flat_type = int32_t; + }; + template + concept Boolean = ValTrait::type == ValType::Bool; + + // Numerics -------------------------------------------------------------------- + + template <> + struct ValTrait + { + static constexpr ValType type = ValType::S8; + static constexpr uint32_t size = 1; + static constexpr uint32_t alignment = 1; + using flat_type = int32_t; + + static constexpr int8_t LOW_VALUE = std::numeric_limits::lowest(); + static constexpr int8_t HIGH_VALUE = std::numeric_limits::max(); + }; + + template <> + struct ValTrait + { + static constexpr ValType type = ValType::U8; + static constexpr uint32_t size = 1; + static constexpr uint32_t alignment = 1; + using flat_type = int32_t; + + static constexpr int8_t LOW_VALUE = std::numeric_limits::lowest(); + static constexpr int8_t HIGH_VALUE = std::numeric_limits::max(); + }; + + template <> + struct ValTrait + { + static constexpr ValType type = ValType::S16; + static constexpr uint32_t size = 2; + static constexpr uint32_t alignment = 2; + using flat_type = int32_t; + + static constexpr int16_t LOW_VALUE = std::numeric_limits::lowest(); + static constexpr int16_t HIGH_VALUE = std::numeric_limits::max(); + }; + + template <> + struct ValTrait + { + static constexpr ValType type = ValType::U16; + static constexpr uint32_t size = 2; + static constexpr uint32_t alignment = 2; + using flat_type = int32_t; + + static constexpr uint16_t LOW_VALUE = std::numeric_limits::lowest(); + static constexpr uint16_t HIGH_VALUE = std::numeric_limits::max(); + }; + + template <> + struct ValTrait + { + static constexpr ValType type = ValType::S32; + static constexpr uint32_t size = 4; + static constexpr uint32_t alignment = 4; + using flat_type = int32_t; + + static constexpr int32_t LOW_VALUE = std::numeric_limits::lowest(); + static constexpr int32_t HIGH_VALUE = std::numeric_limits::max(); + }; + + template <> + struct ValTrait + { + static constexpr ValType type = ValType::U32; + static constexpr uint32_t size = 4; + static constexpr uint32_t alignment = 4; + using flat_type = int32_t; + + static constexpr uint32_t LOW_VALUE = std::numeric_limits::lowest(); + static constexpr uint32_t HIGH_VALUE = std::numeric_limits::max(); + }; + + template <> + struct ValTrait + { + static constexpr ValType type = ValType::S64; + static constexpr uint32_t size = 8; + static constexpr uint32_t alignment = 8; + using flat_type = int64_t; + + static constexpr int64_t LOW_VALUE = std::numeric_limits::lowest(); + static constexpr int64_t HIGH_VALUE = std::numeric_limits::max(); + }; + + template <> + struct ValTrait + { + static constexpr ValType type = ValType::U64; + static constexpr uint32_t size = 8; + static constexpr uint32_t alignment = 8; + using flat_type = int64_t; + + static constexpr uint64_t LOW_VALUE = std::numeric_limits::lowest(); + static constexpr uint64_t HIGH_VALUE = std::numeric_limits::max(); + }; + + template + concept Integer = ValTrait::type == ValType::S8 || ValTrait::type == ValType::S16 || ValTrait::type == ValType::S32 || ValTrait::type == ValType::S64 || + ValTrait::type == ValType::U8 || ValTrait::type == ValType::U16 || ValTrait::type == ValType::U32 || ValTrait::type == ValType::U64; + + template + concept SignedInteger = std::is_signed_v && Integer; + + template + concept UnsignedInteger = !std::is_signed_v && Integer; + + template <> + struct ValTrait + { + static constexpr ValType type = ValType::F32; + static constexpr uint32_t size = 4; + static constexpr uint32_t alignment = 4; + using flat_type = float32_t; + + static constexpr float32_t LOW_VALUE = std::numeric_limits::lowest(); + static constexpr float32_t HIGH_VALUE = std::numeric_limits::max(); + }; + + template <> + struct ValTrait + { + static constexpr ValType type = ValType::F64; + static constexpr uint32_t size = 8; + static constexpr uint32_t alignment = 8; + using flat_type = float64_t; + + static constexpr float64_t LOW_VALUE = std::numeric_limits::lowest(); + static constexpr float64_t HIGH_VALUE = std::numeric_limits::max(); + }; + + template + concept Float = ValTrait::type == ValType::F32 || ValTrait::type == ValType::F64; + + template + concept Numeric = Integer || Float; + + template + concept Byte = ValTrait::type == ValType::U8 || ValTrait::type == ValType::S8; + + template + concept HalfWord = ValTrait::type == ValType::U16 || ValTrait::type == ValType::S16; + + template + concept Word = ValTrait::type == ValType::U32 || ValTrait::type == ValType::S32 || ValTrait::type == ValType::F32; + + template + concept DoubleWord = ValTrait::type == ValType::U64 || ValTrait::type == ValType::S64 || ValTrait::type == ValType::F64; + + // Strings -------------------------------------------------------------------- + + template <> + struct ValTrait + { + static constexpr ValType type = ValType::Char; + }; + + struct string_t + { + Encoding encoding; + const char8_t *ptr; + size_t byte_len; + }; + template <> + struct ValTrait + { + static constexpr ValType type = ValType::String; + static constexpr uint32_t size = 8; + static constexpr uint32_t alignment = 4; + using flat_type_0 = int32_t; + using flat_type_1 = int32_t; + }; + template + concept String = ValTrait::type == ValType::String; + + // List -------------------------------------------------------------------- + + template + using list_t = std::vector; + template + struct ValTrait> + { + static constexpr ValType type = ValType::List; + using inner_type = T; + static constexpr std::size_t maybe_length = 0; + static constexpr uint32_t size = 8; + static constexpr uint32_t alignment = 4; + // static constexpr WasmValTypeVector() using flat_type_0 = int32_t; + using flat_type_1 = int32_t; + }; + template + concept List = ValTrait::type == ValType::List; + + // Record -------------------------------------------------------------------- + + template + concept Field = ValTrait::type != ValType::UNKNOWN; + + template + using record_t = std::tuple; + template + struct ValTrait> + { + static constexpr ValType type = ValType::Record; + using inner_type = typename std::tuple; + }; + template + concept Record = ValTrait::type == ValType::Record; + + template + T to_struct_impl(const R &t, std::index_sequence) + { + return T{std::get(t)...}; + } + + template + T to_struct(const R &t) + { + return to_struct_impl(t, std::make_index_sequence>{}); + } + + // Variant ------------------------------------------------------------------ + + template + using variant_t = std::variant; + template + struct ValTrait> + { + static constexpr ValType type = ValType::Variant; + using inner_type = typename std::variant; + }; + template + concept Variant = ValTrait::type == ValType::Variant; + + // Other -------------------------------------------------------------------- + + // template + // using tuple_t = std::tuple; + // template + // struct ValTrait> + // { + // static constexpr ValType type = ValType::Tuple; + // }; + + // template + // struct case_t + // { + // std::string label; + // std::optional v; + // }; + // template + // struct ValTrait> + // { + // static constexpr ValType type = ValType::Case; + // using inner_type = T; + // }; + + // template + // using variant_t = std::variant; + // template + // struct ValTrait> + // { + // static constexpr ValType type = ValType::Variant; + // }; + + // using enum_t = std::vector; + // template <> + // struct ValTrait + // { + // static constexpr ValType type = ValType::Enum; + // }; + + // class option_t; + // using option_ptr = std::shared_ptr; + // template <> + // struct ValTrait + // { + // static ValType type() { return ValType::Option; } + // }; + + // class result_t; + // using result_ptr = std::shared_ptr; + // template <> + // struct ValTrait + // { + // static ValType type() { return ValType::Result; } + // }; + + // class flags_t; + // using flags_ptr = std::shared_ptr; + // template <> + // struct ValTrait + // { + // static ValType type() { return ValType::Flags; } + // }; + + // -------------------------------------------------------------------- + + using offset = uint32_t; + using bytes = uint32_t; +} + +#endif diff --git a/plugins/wasmembed/component-model/src/util.cpp b/plugins/wasmembed/component-model/src/util.cpp new file mode 100644 index 00000000000..fcafe19b3d5 --- /dev/null +++ b/plugins/wasmembed/component-model/src/util.cpp @@ -0,0 +1,39 @@ +#include "util.hpp" + +namespace cmcpp +{ + + void trap_if(const CallContext &cx, bool condition, const char *message) + { + if (condition) + { + cx.trap(message); + } + } + + uint32_t align_to(uint32_t ptr, uint8_t alignment) + { + return (ptr + static_cast(alignment) - 1) & ~(static_cast(alignment) - 1); + } + + ValType despecialize(const ValType t) + { + switch (t) + { + case ValType::Tuple: + return ValType::Record; + case ValType::Enum: + return ValType::Variant; + case ValType::Option: + return ValType::Variant; + case ValType::Result: + return ValType::Variant; + } + return t; + } + + bool convert_int_to_bool(uint8_t i) + { + return i > 0; + } +} \ No newline at end of file diff --git a/plugins/wasmembed/component-model/src/util.hpp b/plugins/wasmembed/component-model/src/util.hpp new file mode 100644 index 00000000000..7b7d4e15469 --- /dev/null +++ b/plugins/wasmembed/component-model/src/util.hpp @@ -0,0 +1,18 @@ +#ifndef CMCPP_UTIL_HPP +#define CMCPP_UTIL_HPP + +#include "context.hpp" + +namespace cmcpp +{ + const uint32_t UTF16_TAG = 1U << 31; + const bool DETERMINISTIC_PROFILE = false; + + void trap_if(const CallContext &cx, bool condition, const char *message = nullptr) noexcept(false); + + uint32_t align_to(uint32_t ptr, uint8_t alignment); + + bool convert_int_to_bool(uint8_t i); +} + +#endif diff --git a/plugins/wasmembed/component-model/test/CMakeLists.txt b/plugins/wasmembed/component-model/test/CMakeLists.txt new file mode 100644 index 00000000000..312618d1f69 --- /dev/null +++ b/plugins/wasmembed/component-model/test/CMakeLists.txt @@ -0,0 +1,23 @@ +project(component-model-test) + +find_package(doctest CONFIG REQUIRED) + +include_directories( + ../src +) + +add_executable(${PROJECT_NAME} + main.cpp + host-util.hpp + host-util.cpp +) + +target_link_libraries(${PROJECT_NAME} + PRIVATE doctest::doctest + PRIVATE component-model-cpp +) + +add_test( + NAME ${PROJECT_NAME} + COMMAND $ +) diff --git a/plugins/wasmembed/component-model/test/host-util.cpp b/plugins/wasmembed/component-model/test/host-util.cpp new file mode 100644 index 00000000000..5d59933711e --- /dev/null +++ b/plugins/wasmembed/component-model/test/host-util.cpp @@ -0,0 +1,60 @@ +#include "host-util.hpp" +#include +#include +// #include "utf8.h" + +void trap(const char *msg) +{ + throw new std::runtime_error(msg); +} + +// bool isLatin1(const std::string &str) +// { +// return utf8::is_valid(str); +// } + +std::pair convert(char8_t *dest, const char8_t *src, uint32_t byte_len, Encoding from_encoding, Encoding to_encoding) +{ + switch (from_encoding) + { + case Encoding::Latin1: + case Encoding::Utf8: + switch (to_encoding) + { + case Encoding::Latin1: + case Encoding::Utf8: + std::memcpy(dest, src, byte_len); + return std::make_pair(reinterpret_cast(dest), byte_len); + case Encoding::Utf16: + case Encoding::Latin1_Utf16: + // { + // std::u16string s = utf8::utf8to16(std::string_view((const char *)src, byte_len)); + // std::memcpy(dest, s.data(), s.size() * 2); + // return std::make_pair(reinterpret_cast(dest), s.size() * 2); + // } + default: + throw std::runtime_error("Invalid encoding"); + } + break; + case Encoding::Utf16: + case Encoding::Latin1_Utf16: + switch (to_encoding) + { + case Encoding::Latin1: + case Encoding::Utf8: + // { + // std::string s = utf8::utf16to8(std::u16string_view((const char16_t *)src, byte_len)); + // std::memcpy(dest, s.data(), s.size()); + // return std::make_pair(reinterpret_cast(dest), byte_len); + // } + case Encoding::Utf16: + case Encoding::Latin1_Utf16: + std::memcpy(dest, src, byte_len); + return std::make_pair(reinterpret_cast(dest), byte_len); + default: + throw std::runtime_error("Invalid encoding"); + } + break; + } + throw std::runtime_error("Invalid encoding"); +} diff --git a/plugins/wasmembed/component-model/test/host-util.hpp b/plugins/wasmembed/component-model/test/host-util.hpp new file mode 100644 index 00000000000..1cc2103bb64 --- /dev/null +++ b/plugins/wasmembed/component-model/test/host-util.hpp @@ -0,0 +1,8 @@ +#include +#include "context.hpp" + +using namespace cmcpp; + +void trap(const char *msg = ""); +bool isLatin1(const std::string &std); +std::pair convert(char8_t *dest, const char8_t *src, uint32_t byte_len, Encoding from_encoding, Encoding to_encoding); diff --git a/plugins/wasmembed/component-model/test/main.cpp b/plugins/wasmembed/component-model/test/main.cpp new file mode 100644 index 00000000000..43bb7a31ee0 --- /dev/null +++ b/plugins/wasmembed/component-model/test/main.cpp @@ -0,0 +1,206 @@ +#include "traits.hpp" +#include "lift.hpp" +#include "lower.hpp" +#include "util.hpp" +#include "host-util.hpp" + +using namespace cmcpp; + +#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN +#include + +#include +#include +#include +#include +// #include + +class Heap +{ +private: + uint32_t last_alloc = 0; + +public: + std::vector memory; + + Heap(size_t arg) : memory(arg), last_alloc(0) + { + CHECK(true); + } + + Heap() : memory(1024 * 1024), last_alloc(0) + { + CHECK(true); + } + + uint32_t realloc(uint32_t original_ptr, size_t original_size, uint32_t alignment, size_t new_size) + { + if (original_ptr != 0 && new_size < original_size) + { + return align_to(original_ptr, alignment); + } + uint32_t ret = align_to(last_alloc, alignment); + last_alloc = ret + new_size; + if (last_alloc > memory.size()) + { + std::cout << "oom: have " << memory.size() << " need " << last_alloc << std::endl; + trap("oom"); + } + std::memcpy(&memory[ret], &memory[original_ptr], original_size); + return ret; + } +}; + +std::unique_ptr createCallContext(Heap *heap, Encoding encoding) +{ + std::unique_ptr instanceContext = std::make_unique(trap, convert, + [heap](int original_ptr, int original_size, int alignment, int new_size) -> int + { + return heap->realloc(original_ptr, original_size, alignment, new_size); + }); + return instanceContext->createCallContext(heap->memory, encoding); +} + +TEST_CASE("Boolean") +{ + Heap heap; + auto cx = createCallContext(&heap, Encoding::Utf8); + auto v = lower_flat(*cx, true); + auto b = lift_flat(*cx, v); + CHECK(b == true); + v = lower_flat(*cx, false); + b = lift_flat(*cx, v); + CHECK(b == false); +} + +template +void test_numeric(const std::unique_ptr &cx, T v = 42) +{ + auto flat_v = lower_flat(*cx, v); + auto b = lift_flat(*cx, flat_v); + CHECK(b == v); + v = ValTrait::LOW_VALUE; + flat_v = lower_flat(*cx, v); + b = lift_flat(*cx, flat_v); + CHECK(b == v); + v = ValTrait::HIGH_VALUE; + flat_v = lower_flat(*cx, v); + b = lift_flat(*cx, flat_v); + CHECK(b == v); +} + +TEST_CASE("Signed Integer") +{ + Heap heap; + auto cx = createCallContext(&heap, Encoding::Utf8); + test_numeric(cx); + test_numeric(cx); + test_numeric(cx); + test_numeric(cx); + test_numeric(cx, -42); + test_numeric(cx, -42); + test_numeric(cx, -42); + test_numeric(cx, -42); +} + +TEST_CASE("Unigned Integer") +{ + Heap heap; + auto cx = createCallContext(&heap, Encoding::Utf8); + test_numeric(cx); + test_numeric(cx); + test_numeric(cx); + test_numeric(cx); +} + +TEST_CASE("Float") +{ + Heap heap; + auto cx = createCallContext(&heap, Encoding::Utf8); + test_numeric(cx); + test_numeric(cx); + test_numeric(cx, -42); + test_numeric(cx, -42); +} + +const char *const hw = "hello world"; +TEST_CASE("String") +{ + Heap heap; + auto cx = createCallContext(&heap, Encoding::Utf8); + auto v = lower_flat(*cx, string_t{Encoding::Utf8, (const char8_t *)hw, strlen(hw)}); + auto str = lift_flat(*cx, v); + CHECK(str.encoding == Encoding::Utf8); + CHECK(str.byte_len == strlen(hw)); + CHECK(std::string((const char *)str.ptr, str.byte_len) == hw); +} + +TEST_CASE("List") +{ + Heap heap; + auto cx = createCallContext(&heap, Encoding::Utf8); + list_t strings = {string_t{Encoding::Utf8, (const char8_t *)hw, 5}, string_t{Encoding::Utf8, (const char8_t *)hw, 3}}; + auto v = lower_flat(*cx, strings); + auto strs = lift_flat>(*cx, v); + CHECK(strs.size() == 2); + CHECK(strs[0].encoding == Encoding::Utf8); + CHECK(strs[0].byte_len == 5); + CHECK(std::string((const char *)strs[0].ptr, strs[0].byte_len) == std::string(hw, strs[0].byte_len)); + CHECK(strs[1].encoding == Encoding::Utf8); + CHECK(strs[1].byte_len == 3); + CHECK(std::string((const char *)strs[1].ptr, strs[1].byte_len) == std::string(hw, strs[1].byte_len)); + v = lower_flat(*cx, strings); + strs = lift_flat>(*cx, v); + CHECK(strs.size() == 2); + CHECK(strs[0].encoding == Encoding::Utf8); + CHECK(strs[0].byte_len == 5); + CHECK(std::string((const char *)strs[0].ptr, strs[0].byte_len) == std::string(hw, strs[0].byte_len)); + CHECK(strs[1].encoding == Encoding::Utf8); + CHECK(strs[1].byte_len == 3); + CHECK(std::string((const char *)strs[1].ptr, strs[1].byte_len) == std::string(hw, strs[1].byte_len)); +} + +TEST_CASE("List2") +{ + Heap heap; + auto cx = createCallContext(&heap, Encoding::Utf8); + list_t> strings = {{string_t{Encoding::Utf8, (const char8_t *)hw, 5}, string_t{Encoding::Utf8, (const char8_t *)hw, 3}}, {string_t{Encoding::Utf8, (const char8_t *)hw, 5}, string_t{Encoding::Utf8, (const char8_t *)hw, 3}}, {string_t{Encoding::Utf8, (const char8_t *)hw, 5}, string_t{Encoding::Utf8, (const char8_t *)hw, 3}}}; + auto v = lower_flat(*cx, strings); + auto strs = lift_flat>>(*cx, v); + CHECK(strs.size() == 3); + CHECK(strs[0][0].encoding == Encoding::Utf8); + CHECK(strs[0][0].byte_len == 5); + CHECK(std::string((const char *)strs[0][0].ptr, strs[0][0].byte_len) == std::string(hw, strs[0][0].byte_len)); + CHECK(strs[1][0].encoding == Encoding::Utf8); + CHECK(strs[1][0].byte_len == 5); + CHECK(std::string((const char *)strs[0][1].ptr, strs[0][1].byte_len) == std::string(hw, strs[0][1].byte_len)); + v = lower_flat(*cx, strings); + strs = lift_flat>>(*cx, v); + CHECK(strs.size() == 3); + CHECK(strs[0][0].encoding == Encoding::Utf8); + CHECK(strs[0][0].byte_len == 5); + CHECK(std::string((const char *)strs[0][0].ptr, strs[0][0].byte_len) == std::string(hw, strs[0][0].byte_len)); + CHECK(strs[1][0].encoding == Encoding::Utf8); + CHECK(strs[1][0].byte_len == 5); + CHECK(std::string((const char *)strs[0][1].ptr, strs[0][1].byte_len) == std::string(hw, strs[0][1].byte_len)); +} + +struct MyRecord0 { + uint16_t age; + uint32_t weight; +}; + +TEST_CASE("Records") +{ + Heap heap; + auto cx = createCallContext(&heap, Encoding::Utf8); + + using R0 = record_t; + R0 r0 = {42, 43}; + auto v = lower_flat(*cx, r0); + auto rr = lift_flat(*cx, v); + CHECK(r0 == rr); + auto rr0 = to_struct(rr); + // CHECK(r2.age == rr0.age); + // CHECK(r2.weight == rr0.weight); +} diff --git a/plugins/wasmembed/secure-enclave.cpp b/plugins/wasmembed/secure-enclave.cpp index 07a76f7d656..976decf16f6 100644 --- a/plugins/wasmembed/secure-enclave.cpp +++ b/plugins/wasmembed/secure-enclave.cpp @@ -6,11 +6,14 @@ #include "eclhelper.hpp" #include "enginecontext.hpp" -#include "abi.hpp" #include "util.hpp" +#include "component-model/src/context.hpp" +#include "component-model/src/string.hpp" +#include "component-model/src/list.hpp" #include #include +#include // From deftype.hpp in common #define UNKNOWN_LENGTH 0xFFFFFFF1 @@ -24,6 +27,27 @@ { \ } while (0) #endif + +using namespace cmcpp; + +HostTrap trap = [](const char *msg) -> void +{ + throw makeStringException(100, msg); +}; + +HostUnicodeConversion convert = [](char8_t *dest, const char8_t *src, uint32_t byte_len, Encoding from_encoding, Encoding to_encoding) -> std::pair +{ + if (from_encoding == to_encoding) + { + memcpy(dest, src, byte_len); + return std::make_pair(dest, byte_len); + } + else + { + throw makeStringExceptionV(100, "Unsupported encoding conversion %d -> %d", static_cast(from_encoding), static_cast(to_encoding)); + } +}; + class WasmEngine { private: @@ -99,7 +123,7 @@ class WasmEngine wasmtime::Module getModule(const std::string &wasmName) const { - TRACE("WASM SE getModule"); + TRACE("WASM SE getModule %s", wasmName.c_str()); auto found = wasmModules.find(wasmName); if (found == wasmModules.end()) throw makeStringExceptionV(100, "Wasm module not found: %s", wasmName.c_str()); @@ -110,47 +134,19 @@ static std::unique_ptr wasmEngine = std::make_unique(); class WasmStore { -private: +protected: wasmtime::Store store; + std::optional instance; + cmcpp::GuestRealloc realloc; - std::unordered_map wasmInstances; std::unordered_map wasmMems; std::unordered_map wasmFuncs; + std::unique_ptr instanceContext; public: - WasmStore() : store(wasmEngine->engine) - { - TRACE("WASM SE WasmStore"); - } - - ~WasmStore() - { - TRACE("WASM SE ~WasmStore"); - } - - bool hasInstance(const std::string &wasmName) const - { - TRACE("WASM SE hasInstance"); - return wasmInstances.find(wasmName) != wasmInstances.end(); - } - - wasmtime::Instance getInstance(const std::string &wasmName) const + WasmStore(const std::string &wasmName) : store(wasmEngine->engine) { - TRACE("WASM SE getInstance"); - auto found = wasmInstances.find(wasmName); - if (found == wasmInstances.end()) - throw makeStringExceptionV(100, "Wasm instance not found: %s", wasmName.c_str()); - return found->second; - } - - void registerInstance(const std::string &wasmName) - { - TRACE("WASM SE registerInstance %s", wasmName.c_str()); - if (hasInstance(wasmName)) - { - throw makeStringExceptionV(100, "Wasm instance already registered: %s", wasmName.c_str()); - } - TRACE("WASM SE createInstance %s", wasmName.c_str()); + TRACE("WASM SE WasmStore::WasmStore"); auto module = wasmEngine->getModule(wasmName); try { @@ -159,31 +155,41 @@ class WasmStore auto callback = [this, wasmName](wasmtime::Caller caller, uint32_t msg, uint32_t msg_len) { - auto data = this->getData(wasmName); + auto data = getData(); auto msg_ptr = (char *)&data[msg]; std::string str(msg_ptr, msg_len); DBGLOG("from wasm: %s", str.c_str()); }; - auto host_func = linker.func_wrap("$root", "dbglog", callback).unwrap(); - - auto newInstance = linker.instantiate(store, module).unwrap(); - linker.define_instance(store, "linking2", newInstance).unwrap(); + linker.func_wrap("$root", "dbglog", callback).unwrap(); + instance = linker.instantiate(store, module).unwrap(); + linker.define_instance(store, "linking2", *instance).unwrap(); for (auto exportItem : module.exports()) { auto externType = wasmtime::ExternType::from_export(exportItem); std::string name(exportItem.name()); + TRACE("WASM SE exportItem: %s", name.c_str()); if (std::holds_alternative(externType)) { TRACE("WASM SE Exported function: %s", name.c_str()); - auto func = std::get(*newInstance.get(store, name)); - wasmFuncs.insert(std::make_pair(wasmName + "." + name, func)); + auto func = std::get(*instance->get(store, name)); + wasmFuncs.insert(std::make_pair(name, func)); + TRACE("WASM SE Exported function: %s", name.c_str()); + if (name == "cabi_realloc") + { + realloc = [this, func](int ptr, int old_size, int align, int new_size) -> int + { + auto retVal = func.call(store, {ptr, old_size, align, new_size}).unwrap(); + return retVal[0].i32(); + }; + } } else if (std::holds_alternative(externType)) { TRACE("WASM SE Exported memory: %s", name.c_str()); - auto memory = std::get(*newInstance.get(store, name)); - wasmMems.insert(std::make_pair(wasmName + "." + name, memory)); + auto memory = std::get(*instance->get(store, name)); + wasmMems.insert(std::make_pair(name, memory)); + TRACE("WASM SE Exported memory: %s", name.c_str()); } else if (std::holds_alternative(externType)) { @@ -198,7 +204,6 @@ class WasmStore TRACE("WASM SE Unknown export type"); } } - wasmInstances.insert(std::make_pair(wasmName, newInstance)); } catch (const wasmtime::Error &e) { @@ -206,44 +211,62 @@ class WasmStore } } - bool hasFunc(const std::string &qualifiedID) const + wasmtime::Instance getInstance() + { + return instance.value(); + } + + wasmtime::Span getData() + { + TRACE("WASM SE getData"); + auto found = wasmMems.find("memory"); + if (found == wasmMems.end()) + throw makeStringException(100, "Wasm memory not found"); + return found->second.data(store.context()); + } + + bool hasFunc(const std::string &name) const { - TRACE("WASM SE hasFunc"); - return wasmFuncs.find(qualifiedID) != wasmFuncs.end(); + auto found = wasmFuncs.find(name); + if (found == wasmFuncs.end()) + return false; + return true; } - wasmtime::Func getFunc(const std::string &qualifiedID) const + wasmtime::Func getFunc(const std::string &name) const { - TRACE("WASM SE getFunc"); - auto found = wasmFuncs.find(qualifiedID); + TRACE("WASM SE getFunc %s", name.c_str()); + auto found = wasmFuncs.find(name); if (found == wasmFuncs.end()) - throw makeStringExceptionV(100, "Wasm function not found: %s", qualifiedID.c_str()); + throw makeStringExceptionV(100, "Wasm function not found: %s", name.c_str()); return found->second; } - wasmtime::ValType::ListRef getFuncParams(const std::string &qualifiedID) + wasmtime::ValType::ListRef getFuncParams(const std::string &name) { - TRACE("WASM SE getFuncParams"); - auto func = getFunc(qualifiedID); + TRACE("WASM SE getFunc %s", name.c_str()); + auto func = getFunc(name); wasmtime::FuncType funcType = func.type(store.context()); return funcType->params(); } - wasmtime::ValType::ListRef getFuncResults(const std::string &qualifiedID) + wasmtime::ValType::ListRef getFuncResults(const std::string &name) { TRACE("WASM SE getFuncResults"); - auto func = getFunc(qualifiedID); + auto func = getFunc(name); wasmtime::FuncType funcType = func.type(store.context()); return funcType->results(); } - std::vector call(const std::string &qualifiedID, const std::vector ¶ms) + std::vector call(const std::string &name, const std::vector ¶ms) { - TRACE("WASM SE call"); - auto func = getFunc(qualifiedID); + TRACE("WASM SE call %s", name.c_str()); + auto func = getFunc(name); try { + TRACE("WASM SE call 111 %s", name.c_str()); auto retVal = func.call(store, params).unwrap(); + TRACE("WASM SE call 222 %s", name.c_str()); return retVal; } catch (const wasmtime::Trap &e) @@ -252,28 +275,34 @@ class WasmStore } } - std::vector callRealloc(const std::string &wasmName, const std::vector ¶ms) + std::unique_ptr createCallContext(Encoding encoding) { - TRACE("WASM SE callRealloc"); - return call(createQualifiedID(wasmName, "cabi_realloc"), params); + if (instanceContext == nullptr) + { + instanceContext = std::make_unique(trap, convert, realloc); + } + return instanceContext->createCallContext(getData(), encoding); } +}; - wasmtime::Span getData(const std::string &wasmName) +thread_local std::unordered_map> wasmStores; +WasmStore *getWasmStore(const std::string &wasmName) +{ + TRACE("WASM SE getWasmStore %s", wasmName.c_str()); + auto found = wasmStores.find(wasmName); + if (found == wasmStores.end()) { - TRACE("WASM SE getData"); - auto found = wasmMems.find(createQualifiedID(wasmName, "memory")); - if (found == wasmMems.end()) - throw makeStringExceptionV(100, "Wasm memory not found: %s", wasmName.c_str()); - return found->second.data(store.context()); + auto result = wasmStores.insert(std::make_pair(wasmName, std::make_unique(wasmName))); + return result.first->second.get(); } -}; -thread_local std::unique_ptr wasmStore = std::make_unique(); + return found->second.get(); +} class SecureFunction : public CInterfaceOf { + WasmStore *wasmStore = nullptr; std::string wasmName; std::string funcName; - std::string qualifiedID; std::vector args; std::vector wasmResults; @@ -290,12 +319,12 @@ class SecureFunction : public CInterfaceOf TRACE("WASM SE se:destructor"); // Garbage Collection --- - // Function results --- - auto gc_func_name = createQualifiedID(wasmName, "cabi_post_" + funcName); + auto gc_func_name = "cabi_post_" + funcName; if (wasmStore->hasFunc(gc_func_name)) { for (auto &result : wasmResults) { + TRACE("WASM SE se:destructor %s", gc_func_name.c_str()); wasmStore->call(gc_func_name, {result}); } } @@ -380,12 +409,9 @@ class SecureFunction : public CInterfaceOf virtual void bindUTF8Param(const char *name, size32_t chars, const char *val) { TRACE("WASM SE bindUTF8Param %s %d %s", name, chars, val); - auto bytes = rtlUtf8Size(chars, val); - auto memIdxVar = wasmStore->callRealloc(wasmName, {0, 0, 1, (int32_t)bytes}); - auto memIdx = memIdxVar[0].i32(); - auto mem = wasmStore->getData(wasmName); - memcpy(&mem[memIdx], val, bytes); - args.push_back(memIdx); + auto cx = mk_cx(); + auto [offset, bytes] = string::store_into_range(*cx, {Encoding::Utf8, (const char8_t *)val, rtlUtf8Size(chars, val)}); + args.push_back((int32_t)offset); args.push_back((int32_t)bytes); } virtual void bindUnicodeParam(const char *name, size32_t chars, const UChar *val) @@ -396,80 +422,88 @@ class SecureFunction : public CInterfaceOf rtlUnicodeToUtf8X(utfCharCount, utfText.refstr(), chars, val); bindUTF8Param(name, utfCharCount, utfText.getstr()); } - virtual void bindSetParam(const char *name, int elemType, size32_t elemSize, bool isAll, size32_t totalBytes, const void *setData) { TRACE("WASM SE bindSetParam %s %d %d %d %d %p", name, elemType, elemSize, isAll, totalBytes, setData); - throw makeStringException(200, "bindSetParam not implemented"); + if (isAll) + rtlFail(0, "wasmembed: Cannot pass ALL"); - type_vals typecode = (type_vals)elemType; - const byte *inData = (const byte *)setData; - const byte *endData = inData + totalBytes; - int numElems; - if (elemSize == UNKNOWN_LENGTH) + auto cx = mk_cx(); + switch ((type_vals)elemType) + { + case type_boolean: { - numElems = 0; - // Will need 2 passes to work out how many elements there are in the set :( + assert(elemSize == sizeof(bool_t)); + list_t bools; + const byte *inData = (const byte *)setData; + const byte *endData = inData + totalBytes; while (inData < endData) { - int thisSize; - switch (elemType) - { - case type_varstring: - thisSize = strlen((const char *)inData) + 1; - break; - case type_string: - thisSize = *(size32_t *)inData + sizeof(size32_t); - break; - case type_unicode: - thisSize = (*(size32_t *)inData) * sizeof(UChar) + sizeof(size32_t); - break; - case type_utf8: - thisSize = rtlUtf8Size(*(size32_t *)inData, inData + sizeof(size32_t)) + sizeof(size32_t); - break; - default: - rtlFail(0, "wasmembed: Unsupported parameter type"); - break; - } + size32_t thisSize = elemSize; + bools.push_back(*(const bool *)inData ? true : false); inData += thisSize; - numElems++; } - inData = (const byte *)setData; + auto [offset, size] = list::store_into_range(*cx, bools); + args.push_back(static_cast(offset)); + args.push_back(static_cast(size)); + break; } - else - numElems = totalBytes / elemSize; - - std::vector memIdxVar; - int32_t memIdx; - - switch (typecode) + case type_int: { - case type_boolean: - memIdxVar = wasmStore->callRealloc(wasmName, {0, 0, 1, (int32_t)numElems}); - memIdx = memIdxVar[0].i32(); + assert(elemSize == sizeof(int32_t)); + auto [offset, size] = list::store_into_range(*cx, list_t{(const int32_t *)setData, (const int32_t *)setData + (totalBytes / elemSize)}); + args.push_back(static_cast(offset)); + args.push_back(static_cast(size)); break; - default: - rtlFail(0, "wasmembed: Unsupported parameter type"); + } + case type_unsigned: + { + assert(elemSize == sizeof(uint32_t)); + auto [offset, size] = list::store_into_range(*cx, list_t{(const uint32_t *)setData, (const uint32_t *)setData + (totalBytes / elemSize)}); + args.push_back(static_cast(offset)); + args.push_back(static_cast(size)); break; } - - auto mem = wasmStore->getData(wasmName); - size32_t thisSize = elemSize; - for (int idx = 0; idx < numElems; idx++) + case type_real: + { + if (elemSize == sizeof(double)) + { + auto [offset, size] = list::store_into_range(*cx, list_t{(const float64_t *)setData, (const float64_t *)setData + (totalBytes / elemSize)}); + args.push_back(static_cast(offset)); + args.push_back(static_cast(size)); + } + else + { + auto [offset, size] = list::store_into_range(*cx, list_t{(const float32_t *)setData, (const float32_t *)setData + (totalBytes / elemSize)}); + args.push_back(static_cast(offset)); + args.push_back(static_cast(size)); + } + break; + } + case type_string: { - switch (typecode) + list_t strings; + const byte *inData = (const byte *)setData; + const byte *endData = inData + totalBytes; + while (inData < endData) { - case type_boolean: - mem[memIdx + idx] = *(bool *)inData; - break; - default: - rtlFail(0, "v8embed: Unsupported parameter type"); - break; + size32_t thisSize = elemSize; + if (elemSize == UNKNOWN_LENGTH) + { + thisSize = *(size32_t *)inData; + inData += sizeof(size32_t); + } + strings.push_back({Encoding::Utf8, (const char8_t *)inData, thisSize}); + inData += thisSize; } - inData += thisSize; + auto [offset, size] = list::store_into_range(*cx, strings); + args.push_back(static_cast(offset)); + args.push_back(static_cast(size)); + break; + } + default: + throw makeStringException(200, "bindSetParam not implemented"); } - args.push_back(memIdx); - args.push_back(numElems); } virtual void bindRowParam(const char *name, IOutputMetaData &metaVal, const byte *val) override @@ -515,51 +549,106 @@ class SecureFunction : public CInterfaceOf return wasmResults[0].i64(); return static_cast(wasmResults[0].i32()); } + + std::unique_ptr mk_cx() + { + TRACE("WASM SE mk_cx"); + return wasmStore->createCallContext(Encoding::Utf8); + } + virtual void getStringResult(size32_t &chars, char *&result) { TRACE("WASM SE getStringResult %zu", wasmResults.size()); auto ptr = wasmResults[0].i32(); - auto data = wasmStore->getData(wasmName); - uint32_t strPtr; - std::string encoding; - uint32_t bytes; - std::tie(strPtr, encoding, bytes) = load_string(data, ptr); - size32_t codepoints = rtlUtf8Length(bytes, &data[strPtr]); - rtlUtf8ToStrX(chars, result, codepoints, reinterpret_cast(&data[strPtr])); + auto cx = mk_cx(); + auto [encoding, strPtr, bytes] = string::load(*cx, ptr); + size32_t codePoints = rtlUtf8Length(bytes, strPtr); + rtlUtf8ToStrX(chars, result, codePoints, (const char *)strPtr); } virtual void getUTF8Result(size32_t &chars, char *&result) { TRACE("WASM SE getUTF8Result"); auto ptr = wasmResults[0].i32(); - auto data = wasmStore->getData(wasmName); - uint32_t strPtr; - std::string encoding; - uint32_t bytes; - std::tie(strPtr, encoding, bytes) = load_string(data, ptr); - chars = rtlUtf8Length(bytes, &data[strPtr]); - TRACE("WASM SE getUTF8Result %d %d", bytes, chars); + auto cx = mk_cx(); + auto [encoding, strPtr, bytes] = string::load(*cx, ptr); + chars = rtlUtf8Length(bytes, strPtr); result = (char *)rtlMalloc(bytes); - memcpy(result, &data[strPtr], bytes); + memcpy(result, strPtr, bytes); } virtual void getUnicodeResult(size32_t &chars, UChar *&result) { TRACE("WASM SE getUnicodeResult"); auto ptr = wasmResults[0].i32(); - auto data = wasmStore->getData(wasmName); - uint32_t strPtr; - std::string encoding; - uint32_t bytes; - std::tie(strPtr, encoding, bytes) = load_string(data, ptr); - unsigned numchars = rtlUtf8Length(bytes, &data[strPtr]); - rtlUtf8ToUnicodeX(chars, result, numchars, reinterpret_cast(&data[strPtr])); + auto cx = mk_cx(); + auto [encoding, strPtr, bytes] = string::load(*cx, ptr); + size32_t codePoints = rtlUtf8Length(bytes, strPtr); + rtlUtf8ToUnicodeX(chars, result, codePoints, (const char *)strPtr); } - virtual void getSetResult(bool &__isAllResult, size32_t &resultBytes, void *&result, int elemType, size32_t elemSize) + virtual void getSetResult(bool &isAllResult, size32_t &resultBytes, void *&result, int elemType, size32_t elemSize) { TRACE("WASM SE getSetResult %d %d %zu", elemType, elemSize, wasmResults.size()); + isAllResult = false; auto ptr = wasmResults[0].i32(); - auto data = wasmStore->getData(wasmName); - - throw makeStringException(200, "getSetResult not implemented"); + auto cx = mk_cx(); + switch (elemType) + { + case type_boolean: + { + auto list = list::load(*cx, ptr); + resultBytes = list.size(); + result = rtlMalloc(resultBytes); + std::copy(list.begin(), list.end(), reinterpret_cast(result)); + break; + } + case type_real: + { + if (elemSize == sizeof(float64_t)) + { + auto list = list::load(*cx, ptr); + resultBytes = list.size() * sizeof(float64_t); + result = rtlMalloc(resultBytes); + std::copy(list.begin(), list.end(), reinterpret_cast(result)); + } + else + { + assert(elemSize == sizeof(float32_t)); + auto list = list::load(*cx, ptr); + resultBytes = list.size() * sizeof(float32_t); + result = rtlMalloc(resultBytes); + std::copy(list.begin(), list.end(), reinterpret_cast(result)); + } + break; + } + case type_unsigned: + { + auto list = cmcpp::list::load(*cx, ptr); + resultBytes = list.size() * sizeof(uint32_t); + result = rtlMalloc(resultBytes); + memcpy(result, list.data(), resultBytes); + break; + } + case type_string: + { + auto list = cmcpp::list::load(*cx, ptr); + rtlRowBuilder out; + size32_t outBytes = 0; + byte *outData = NULL; + for (auto &item : list) + { + out.ensureAvailable(outBytes + item.byte_len + sizeof(size32_t)); + outData = out.getbytes() + outBytes; + *reinterpret_cast(outData) = item.byte_len; + rtlStrToStr(item.byte_len, outData + sizeof(size32_t), item.byte_len, item.ptr); + outBytes += item.byte_len + sizeof(size32_t); + } + resultBytes = outBytes; + result = out.detachdata(); + break; + } + default: + rtlFail(0, "wasmembed: Unsupported parameter type"); + break; + } } virtual IRowStream *getDatasetResult(IEngineRowAllocator *_resultAllocator) { @@ -604,23 +693,18 @@ class SecureFunction : public CInterfaceOf virtual void importFunction(size32_t lenChars, const char *qualifiedName) override { TRACE("WASM SE importFunction: %s", qualifiedName); - - qualifiedID = std::string(qualifiedName, lenChars); - std::tie(wasmName, funcName) = splitQualifiedID(qualifiedID); - - if (!wasmStore->hasInstance(wasmName)) - { - wasmStore->registerInstance(wasmName); - } + std::tie(wasmName, funcName) = splitQualifiedID({qualifiedName, lenChars}); + wasmStore = getWasmStore(wasmName); } virtual void callFunction() { - TRACE("WASM SE callFunction %s", qualifiedID.c_str()); - wasmResults = wasmStore->call(qualifiedID, args); + TRACE("WASM SE callFunction %s.%s", wasmName.c_str(), funcName.c_str()); + wasmResults = wasmStore->call(funcName, args); } }; IEmbedFunctionContext *createISecureEnclave(ICodeContext *codeCtx) { + TRACE("createISecureEnclave"); return new SecureFunction(codeCtx); } diff --git a/plugins/wasmembed/test/build.sh b/plugins/wasmembed/test/build.sh index 7f91603130c..10746280fe4 100755 --- a/plugins/wasmembed/test/build.sh +++ b/plugins/wasmembed/test/build.sh @@ -6,7 +6,7 @@ ROOT_DIR="${SCRIPT_DIR}/../../.." echo "SCRIPT_DIR: ${SCRIPT_DIR}" echo "ROOT_DIR: $ROOT_DIR" -docker build --progress plain -f "${SCRIPT_DIR}/wasm32-wasi/Dockerfile" \ +docker buildx build --progress plain "${SCRIPT_DIR}/wasm32-wasi/Dockerfile" \ -t wasm32-wasi:latest \ "${SCRIPT_DIR}/." diff --git a/plugins/wasmembed/test/hpcc-scalar-test.wit b/plugins/wasmembed/test/hpcc-scalar-test.wit index 3b632d46ea1..0e2772dbbc1 100644 --- a/plugins/wasmembed/test/hpcc-scalar-test.wit +++ b/plugins/wasmembed/test/hpcc-scalar-test.wit @@ -13,7 +13,8 @@ world hpcc-scalar-test { guests dispose all params as needed hosts call cabi_post_XXX to dispose "results" as needed */ - export bool-test: func(a: bool, b: bool) -> bool + export bool-and-test: func(a: bool, b: bool) -> bool + export bool-or-test: func(a: bool, b: bool) -> bool export float32-test: func(a: float32, b: float32) -> float32 export float64-test: func(a: float64, b: float64) -> float64 export u8-test: func(a: u8, b: u8) -> u8 @@ -24,5 +25,21 @@ world hpcc-scalar-test { export s16-test: func(a: s16, b: s16) -> s16 export s32-test: func(a: s32, b: s32) -> s32 export s64-test: func(a: s64, b: s64) -> s64 + export utf8-string-test: func(a: string, b: string) -> string + + export list-test-bool: func() -> list + export list-bool-test-bool: func(a: list) -> list + + export list-test-u32: func() -> list + export list-u32-test-u32: func(a: list) -> list + + export list-test-float32: func() -> list + export list-float32-test-float32: func(a: list) -> list + + export list-test-string: func() -> list + export list-string-test-string: func(a: list) -> list + + export list-test-list: func() -> list> + // export list-list-test-list: func(a: list>) -> list> } diff --git a/plugins/wasmembed/test/main.cpp b/plugins/wasmembed/test/main.cpp index c0d7420e5f1..e4250cc2e4b 100644 --- a/plugins/wasmembed/test/main.cpp +++ b/plugins/wasmembed/test/main.cpp @@ -1,6 +1,7 @@ #include "hpcc_scalar_test.h" #include +#include void dbglog(const std::string str) { @@ -9,10 +10,14 @@ void dbglog(const std::string str) hpcc_scalar_test_dbglog(&msg); } -bool hpcc_scalar_test_bool_test(bool a, bool b) +bool hpcc_scalar_test_bool_and_test(bool a, bool b) { return a && b; } +bool hpcc_scalar_test_bool_or_test(bool a, bool b) +{ + return a || b; +} float hpcc_scalar_test_float32_test(float a, float b) { return a + b; @@ -68,3 +73,116 @@ void hpcc_scalar_test_utf8_string_test(hpcc_scalar_test_string_t *a, hpcc_scalar dbglog(std::to_string(++tally) + ": " + r); hpcc_scalar_test_string_dup(ret, r.c_str()); } + +// Helper Functions +// void hpcc_scalar_test_list_u32_free(hpcc_scalar_test_list_u32_t *ptr); +// void hpcc_scalar_test_string_set(hpcc_scalar_test_string_t *ret, const char*s); +// void hpcc_scalar_test_string_dup(hpcc_scalar_test_string_t *ret, const char*s); +// void hpcc_scalar_test_string_free(hpcc_scalar_test_string_t *ret); + +void hpcc_scalar_test_list_test_bool(hpcc_scalar_test_list_bool_t *ret) +{ + ret->len = 8; + ret->ptr = (bool *)malloc(ret->len * sizeof(bool)); + for (size_t i = 0; i < ret->len; ++i) + { + ret->ptr[i] = i % 2 == 0; + } +} + +void hpcc_scalar_test_list_bool_test_bool(hpcc_scalar_test_list_bool_t *a, hpcc_scalar_test_list_bool_t *ret) +{ + std::vector v1(a->ptr, a->ptr + a->len); + hpcc_scalar_test_list_bool_free(a); + ret->len = v1.size(); + ret->ptr = (bool *)malloc(ret->len * sizeof(bool)); + for (size_t i = 0; i < ret->len; ++i) + { + ret->ptr[ret->len - i - 1] = v1[i]; + } +} + +void hpcc_scalar_test_list_test_u32(hpcc_scalar_test_list_u32_t *ret) +{ + ret->len = 4; + ret->ptr = (uint32_t *)malloc(ret->len * sizeof(uint32_t)); + for (size_t i = 0; i < ret->len; ++i) + { + ret->ptr[i] = i; + } +} + +void hpcc_scalar_test_list_u32_test_u32(hpcc_scalar_test_list_u32_t *a, hpcc_scalar_test_list_u32_t *ret) +{ + std::vector v1(a->ptr, a->ptr + a->len); + hpcc_scalar_test_list_u32_free(a); + ret->len = v1.size(); + ret->ptr = (uint32_t *)malloc(ret->len * sizeof(uint32_t)); + for (size_t i = 0; i < ret->len; ++i) + { + ret->ptr[ret->len - i - 1] = v1[i]; + } +} + +void hpcc_scalar_test_list_test_float32(hpcc_scalar_test_list_float32_t *ret) +{ + ret->len = 4; + ret->ptr = (float *)malloc(ret->len * sizeof(float)); + for (size_t i = 0; i < ret->len; ++i) + { + ret->ptr[i] = i + 0.33; + } +} + +void hpcc_scalar_test_list_float32_test_float32(hpcc_scalar_test_list_float32_t *a, hpcc_scalar_test_list_float32_t *ret) +{ + std::vector v1(a->ptr, a->ptr + a->len); + hpcc_scalar_test_list_float32_free(a); + ret->len = v1.size(); + ret->ptr = (float *)malloc(ret->len * sizeof(float)); + for (size_t i = 0; i < ret->len; ++i) + { + ret->ptr[ret->len - i - 1] = v1[i]; + } +} + +void hpcc_scalar_test_list_test_string(hpcc_scalar_test_list_string_t *ret) +{ + ret->len = 4; + ret->ptr = (hpcc_scalar_test_string_t *)malloc(ret->len * sizeof(hpcc_scalar_test_string_t)); + for (size_t i = 0; i < ret->len; ++i) + { + std::string str = "test-"; + str += std::to_string(i); + hpcc_scalar_test_string_dup(&ret->ptr[i], str.c_str()); + } +} + +void hpcc_scalar_test_list_string_test_string(hpcc_scalar_test_list_string_t *a, hpcc_scalar_test_list_string_t *ret) +{ + std::vector v1(a->ptr, a->ptr + a->len); + ret->len = v1.size(); + ret->ptr = (hpcc_scalar_test_string_t *)malloc(ret->len * sizeof(hpcc_scalar_test_string_t)); + for (size_t i = 0; i < ret->len; ++i) + { + hpcc_scalar_test_string_dup(&ret->ptr[ret->len - i - 1], std::string(v1[i].ptr, v1[i].len).c_str()); + } + hpcc_scalar_test_list_string_free(a); +} + +void hpcc_scalar_test_list_test_list(hpcc_scalar_test_list_list_string_t *ret) +{ + ret->len = 4; + ret->ptr = (hpcc_scalar_test_list_string_t *)malloc(ret->len * sizeof(hpcc_scalar_test_list_string_t)); + for (size_t i = 0; i < ret->len; ++i) + { + ret->ptr->len = 4; + ret->ptr->ptr = (hpcc_scalar_test_string_t *)malloc(ret->ptr->len * sizeof(hpcc_scalar_test_string_t)); + for (size_t i = 0; i < ret->ptr->len; ++i) + { + std::string str = "test-"; + str += std::to_string(i); + hpcc_scalar_test_string_dup(&ret->ptr->ptr[i], str.c_str()); + } + } +} diff --git a/plugins/wasmembed/test/wasm32-wasi/Dockerfile b/plugins/wasmembed/test/wasm32-wasi/Dockerfile index 5746ac1eb34..d8febb8cf6a 100644 --- a/plugins/wasmembed/test/wasm32-wasi/Dockerfile +++ b/plugins/wasmembed/test/wasm32-wasi/Dockerfile @@ -21,13 +21,14 @@ SHELL ["/bin/bash", "--login", "-c"] WORKDIR /hpcc-dev -ARG WIT_VERSION=0.9.0 +# List of current vertsion can be found in https://github.com/bytecodealliance/wit-bindgen/releases --- +ARG WIT_VERSION=0.36.0 RUN cargo install wasm-tools && \ cargo install --git https://github.com/bytecodealliance/wit-bindgen --tag wit-bindgen-cli-${WIT_VERSION} wit-bindgen-cli && \ curl https://wasmtime.dev/install.sh -sSf | bash -# List of current vertsion can be found in https://github.com/bytecodealliance/wit-bindgen/releases --- -ARG WASI_VERSION=20 +# List of current vertsion can be found in https://github.com/WebAssembly/wasi-sdk/releases --- +ARG WASI_VERSION=25 ARG WASI_MINOR_VERSION=0 ARG WASI_VERSION_FULL=${WASI_VERSION}.${WASI_MINOR_VERSION} RUN wget https://github.com/WebAssembly/wasi-sdk/releases/download/wasi-sdk-${WASI_VERSION}/wasi-sdk-${WASI_VERSION_FULL}-linux.tar.gz diff --git a/testing/regress/ecl/key/wasmembed.xml b/testing/regress/ecl/key/wasmembed.xml index 4400c029084..67c8097db10 100644 --- a/testing/regress/ecl/key/wasmembed.xml +++ b/testing/regress/ecl/key/wasmembed.xml @@ -100,3 +100,39 @@ true + + true + + + true + + + true + + + true + + + true + + + true + + + true + + + true + + + true + + + true + + + true + + + true + diff --git a/testing/regress/ecl/wasmembed.ecl b/testing/regress/ecl/wasmembed.ecl index 75d885e777e..818edd92f8a 100644 --- a/testing/regress/ecl/wasmembed.ecl +++ b/testing/regress/ecl/wasmembed.ecl @@ -1,6 +1,7 @@ import wasm; -boolean boolTest (boolean a, boolean b) := IMPORT(wasm, 'wasmembed.bool-test'); +boolean boolAndTest (boolean a, boolean b) := IMPORT(wasm, 'wasmembed.bool-and-test'); +boolean boolOrTest (boolean a, boolean b) := IMPORT(wasm, 'wasmembed.bool-or-test'); real4 float32Test (real4 a, real4 b) := IMPORT(wasm, 'wasmembed.float32-test'); real8 float64Test (real8 a, real8 b) := IMPORT(wasm, 'wasmembed.float64-test'); unsigned1 u8Test (unsigned1 a, unsigned1 b) := IMPORT(wasm, 'wasmembed.u8-test'); @@ -11,19 +12,45 @@ integer1 s8Test (integer1 a, integer1 b) := IMPORT(wasm, 'wasmembed.s8-test'); integer2 s16Test (integer2 a, integer2 b) := IMPORT(wasm, 'wasmembed.s16-test'); integer4 s32Test (integer4 a, integer4 b) := IMPORT(wasm, 'wasmembed.s32-test'); integer8 s64Test (integer8 a, integer8 b) := IMPORT(wasm, 'wasmembed.s64-test'); -string stringTest (string a, string b) := IMPORT(wasm, 'wasmembed.string-test'); -string12 string5Test (string5 a, string5 b) := IMPORT(wasm, 'wasmembed.string-test'); -varstring varstringTest (varstring a, varstring b) := IMPORT(wasm, 'wasmembed.string-test'); -unicode12 unicode5Test (unicode5 a, unicode5 b) := IMPORT(wasm, 'wasmembed.string-test'); -unicode unicodeTest (unicode a, unicode b) := IMPORT(wasm, 'wasmembed.string-test'); -utf8_12 utf8_5Test (utf8_5 a, utf8_5 b) := IMPORT(wasm, 'wasmembed.string-test'); -utf8 utf8Test (utf8 a, utf8 b) := IMPORT(wasm, 'wasmembed.string-test'); + +string stringTest (string a, string b) := IMPORT(wasm, 'wasmembed.utf8-string-test'); +string12 string5Test (string5 a, string5 b) := IMPORT(wasm, 'wasmembed.utf8-string-test'); +varstring varstringTest (varstring a, varstring b) := IMPORT(wasm, 'wasmembed.utf8-string-test'); +unicode12 unicode5Test (unicode5 a, unicode5 b) := IMPORT(wasm, 'wasmembed.utf8-string-test'); +unicode unicodeTest (unicode a, unicode b) := IMPORT(wasm, 'wasmembed.utf8-string-test'); +utf8_12 utf8_5Test (utf8_5 a, utf8_5 b) := IMPORT(wasm, 'wasmembed.utf8-string-test'); +utf8 utf8Test (utf8 a, utf8 b) := IMPORT(wasm, 'wasmembed.utf8-string-test'); + +set of boolean listTestBoolean () := IMPORT(wasm, 'wasmembed.list-test-bool'); +set of boolean listBooleanTestBoolean (set of boolean a) := IMPORT(wasm, 'wasmembed.list-bool-test-bool'); +set of unsigned4 listTestU32 () := IMPORT(wasm, 'wasmembed.list-test-u32'); +set of unsigned4 listU32TestU32 (set of unsigned4 a) := IMPORT(wasm, 'wasmembed.list-u32-test-u32'); +set of real4 listTestFloat32 () := IMPORT(wasm, 'wasmembed.list-test-float32'); +set of real4 listFloat32TestFloat32 (set of real4 a) := IMPORT(wasm, 'wasmembed.list-float32-test-float32'); +set of string listTestString () := IMPORT(wasm, 'wasmembed.list-test-string'); +set of string listStringTestString (set of string a) := IMPORT(wasm, 'wasmembed.list-string-test-string'); + +listTestBoolean() = [true, false, true, false, true, false, true, false]; +listBooleanTestBoolean([false, false, true, true, false, true]) = [true, false, true, true, false, false]; + +listTestFloat32() = [(real4)0.33, (real4)1.33, (real4)2.33, (real4)3.33]; +listFloat32TestFloat32([1.23, 2.34, 3.45, 4.56]) = [(real4)4.56, (real4)3.45, (real4)2.34, (real4)1.23]; + +listTestU32() = [0, 1, 2, 3]; +listU32TestU32([0, 1, 2, 3]) = [3, 2, 1, 0]; + +listTestString()=['test-0', 'test-1', 'test-2', 'test-3']; +listStringTestString(['aaa', 'bbb', 'ccc', 'ddd', 'eee', 'fff', 'ggg', 'hhh', 'iii']) = ['iii', 'hhh', 'ggg', 'fff', 'eee', 'ddd', 'ccc', 'bbb', 'aaa']; // '--- bool ---'; -boolTest(false, false) = (false AND false); -boolTest(false, true) = (false AND true); -boolTest(true, false) = (true AND false); -boolTest(true, true) = (true AND true); +boolAndTest(false, false) = (false AND false); +boolAndTest(false, true) = (false AND true); +boolAndTest(true, false) = (true AND false); +boolAndTest(true, true) = (true AND true); +boolOrTest(false, false) = (false OR false); +boolOrTest(false, true) = (false OR true); +boolOrTest(true, false) = (true OR false); +boolOrTest(true, true) = (true OR true); // '--- float ---'; ROUND(float32Test((real4)1234.1234, (real4)2345.2345), 3) = ROUND((real4)((real4)1234.1234 + (real4)2345.2345), 3); float64Test(123456789.123456789, 23456789.23456789) = (real8)((real8)123456789.123456789 + (real8)23456789.23456789); diff --git a/testing/regress/ecl/wasmembed.wasm b/testing/regress/ecl/wasmembed.wasm old mode 100644 new mode 100755 index f2f32c33c95..c9534660e81 Binary files a/testing/regress/ecl/wasmembed.wasm and b/testing/regress/ecl/wasmembed.wasm differ diff --git a/vcpkg.json.in b/vcpkg.json.in index f2e23d68ff5..0e64ec61702 100644 --- a/vcpkg.json.in +++ b/vcpkg.json.in @@ -53,6 +53,7 @@ "platform": "@VCPKG_ELASTICSTACK_CLIENT@" }, "curl", + "doctest", { "name": "elasticlient", "platform": "@VCPKG_ELASTICSTACK_CLIENT@ & !windows"