From 5581a669ec6915c533ab7cd5ecda2734b9e9392c Mon Sep 17 00:00:00 2001 From: "seockho.kim" Date: Thu, 26 Sep 2024 16:31:24 +0900 Subject: [PATCH] [runtime] Update circle schema to support RmsNorm and RoPE operation This commit updates circle schema header file to support RmsNorm and RoPE ONE-DCO-1.0-Signed-off-by: Seockho Kim seockho.kim@samsung.com --- .../include/circle_schema_generated.h | 394 +++++++++++++++++- 1 file changed, 385 insertions(+), 9 deletions(-) diff --git a/runtime/libs/circle-schema/include/circle_schema_generated.h b/runtime/libs/circle-schema/include/circle_schema_generated.h index b481875e9cd..bf68cc1c185 100644 --- a/runtime/libs/circle-schema/include/circle_schema_generated.h +++ b/runtime/libs/circle-schema/include/circle_schema_generated.h @@ -667,6 +667,14 @@ struct InstanceNormOptions; struct InstanceNormOptionsBuilder; struct InstanceNormOptionsT; +struct RmsNormOptions; +struct RmsNormOptionsBuilder; +struct RmsNormOptionsT; + +struct RoPEOptions; +struct RoPEOptionsBuilder; +struct RoPEOptionsT; + struct OperatorCode; struct OperatorCodeBuilder; struct OperatorCodeT; @@ -1097,6 +1105,8 @@ inline const char *EnumNameCompressionType(CompressionType e) enum BuiltinOperator : int32_t { + BuiltinOperator_ROPE = -7, + BuiltinOperator_RMS_NORM = -6, BuiltinOperator_GRU = -5, BuiltinOperator_BCQ_GATHER = -4, BuiltinOperator_BCQ_FULLY_CONNECTED = -3, @@ -1307,13 +1317,15 @@ enum BuiltinOperator : int32_t BuiltinOperator_DILATE = 203, BuiltinOperator_STABLEHLO_RNG_BIT_GENERATOR = 204, BuiltinOperator_REDUCE_WINDOW = 205, - BuiltinOperator_MIN = BuiltinOperator_GRU, + BuiltinOperator_MIN = BuiltinOperator_ROPE, BuiltinOperator_MAX = BuiltinOperator_REDUCE_WINDOW }; -inline const BuiltinOperator (&EnumValuesBuiltinOperator())[210] +inline const BuiltinOperator (&EnumValuesBuiltinOperator())[212] { - static const BuiltinOperator values[] = {BuiltinOperator_GRU, + static const BuiltinOperator values[] = {BuiltinOperator_ROPE, + BuiltinOperator_RMS_NORM, + BuiltinOperator_GRU, BuiltinOperator_BCQ_GATHER, BuiltinOperator_BCQ_FULLY_CONNECTED, BuiltinOperator_INSTANCE_NORM, @@ -1528,7 +1540,9 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[210] inline const char *const *EnumNamesBuiltinOperator() { - static const char *const names[212] = {"GRU", + static const char *const names[214] = {"ROPE", + "RMS_NORM", + "GRU", "BCQ_GATHER", "BCQ_FULLY_CONNECTED", "INSTANCE_NORM", @@ -1745,9 +1759,9 @@ inline const char *const *EnumNamesBuiltinOperator() inline const char *EnumNameBuiltinOperator(BuiltinOperator e) { - if (::flatbuffers::IsOutRange(e, BuiltinOperator_GRU, BuiltinOperator_REDUCE_WINDOW)) + if (::flatbuffers::IsOutRange(e, BuiltinOperator_ROPE, BuiltinOperator_REDUCE_WINDOW)) return ""; - const size_t index = static_cast(e) - static_cast(BuiltinOperator_GRU); + const size_t index = static_cast(e) - static_cast(BuiltinOperator_ROPE); return EnumNamesBuiltinOperator()[index]; } @@ -1880,6 +1894,8 @@ enum BuiltinOptions : uint8_t BuiltinOptions_BitcastOptions = 124, BuiltinOptions_BitwiseXorOptions = 125, BuiltinOptions_RightShiftOptions = 126, + BuiltinOptions_RoPEOptions = 249, + BuiltinOptions_RmsNormOptions = 250, BuiltinOptions_GRUOptions = 251, BuiltinOptions_BCQGatherOptions = 252, BuiltinOptions_BCQFullyConnectedOptions = 253, @@ -1888,7 +1904,7 @@ enum BuiltinOptions : uint8_t BuiltinOptions_MAX = BuiltinOptions_InstanceNormOptions }; -inline const BuiltinOptions (&EnumValuesBuiltinOptions())[131] +inline const BuiltinOptions (&EnumValuesBuiltinOptions())[133] { static const BuiltinOptions values[] = {BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -2017,6 +2033,8 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[131] BuiltinOptions_BitcastOptions, BuiltinOptions_BitwiseXorOptions, BuiltinOptions_RightShiftOptions, + BuiltinOptions_RoPEOptions, + BuiltinOptions_RmsNormOptions, BuiltinOptions_GRUOptions, BuiltinOptions_BCQGatherOptions, BuiltinOptions_BCQFullyConnectedOptions, @@ -2275,8 +2293,8 @@ inline const char *const *EnumNamesBuiltinOptions() "", "", "", - "", - "", + "RoPEOptions", + "RmsNormOptions", "GRUOptions", "BCQGatherOptions", "BCQFullyConnectedOptions", @@ -2928,6 +2946,16 @@ template <> struct BuiltinOptionsTraits static const BuiltinOptions enum_value = BuiltinOptions_RightShiftOptions; }; +template <> struct BuiltinOptionsTraits +{ + static const BuiltinOptions enum_value = BuiltinOptions_RoPEOptions; +}; + +template <> struct BuiltinOptionsTraits +{ + static const BuiltinOptions enum_value = BuiltinOptions_RmsNormOptions; +}; + template <> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_GRUOptions; @@ -3583,6 +3611,16 @@ template <> struct BuiltinOptionsUnionTraits static const BuiltinOptions enum_value = BuiltinOptions_RightShiftOptions; }; +template <> struct BuiltinOptionsUnionTraits +{ + static const BuiltinOptions enum_value = BuiltinOptions_RoPEOptions; +}; + +template <> struct BuiltinOptionsUnionTraits +{ + static const BuiltinOptions enum_value = BuiltinOptions_RmsNormOptions; +}; + template <> struct BuiltinOptionsUnionTraits { static const BuiltinOptions enum_value = BuiltinOptions_GRUOptions; @@ -5100,6 +5138,29 @@ struct BuiltinOptionsUnion ? reinterpret_cast(value) : nullptr; } + circle::RoPEOptionsT *AsRoPEOptions() + { + return type == BuiltinOptions_RoPEOptions ? reinterpret_cast(value) + : nullptr; + } + const circle::RoPEOptionsT *AsRoPEOptions() const + { + return type == BuiltinOptions_RoPEOptions + ? reinterpret_cast(value) + : nullptr; + } + circle::RmsNormOptionsT *AsRmsNormOptions() + { + return type == BuiltinOptions_RmsNormOptions + ? reinterpret_cast(value) + : nullptr; + } + const circle::RmsNormOptionsT *AsRmsNormOptions() const + { + return type == BuiltinOptions_RmsNormOptions + ? reinterpret_cast(value) + : nullptr; + } circle::GRUOptionsT *AsGRUOptions() { return type == BuiltinOptions_GRUOptions ? reinterpret_cast(value) @@ -6141,6 +6202,34 @@ inline const char *EnumNameReduceWindowFunction(ReduceWindowFunction e) return EnumNamesReduceWindowFunction()[index]; } +enum RoPEMode : int32_t +{ + RoPEMode_GPT_NEOX = 0, + RoPEMode_GPT_J = 1, + RoPEMode_MIN = RoPEMode_GPT_NEOX, + RoPEMode_MAX = RoPEMode_GPT_J +}; + +inline const RoPEMode (&EnumValuesRoPEMode())[2] +{ + static const RoPEMode values[] = {RoPEMode_GPT_NEOX, RoPEMode_GPT_J}; + return values; +} + +inline const char *const *EnumNamesRoPEMode() +{ + static const char *const names[3] = {"GPT_NEOX", "GPT_J", nullptr}; + return names; +} + +inline const char *EnumNameRoPEMode(RoPEMode e) +{ + if (::flatbuffers::IsOutRange(e, RoPEMode_GPT_NEOX, RoPEMode_GPT_J)) + return ""; + const size_t index = static_cast(e); + return EnumNamesRoPEMode()[index]; +} + enum CustomOptionsFormat : int8_t { CustomOptionsFormat_FLEXBUFFERS = 0, @@ -17780,6 +17869,132 @@ ::flatbuffers::Offset CreateInstanceNormOptions(::flatbuffers::FlatBufferBuilder &_fbb, const InstanceNormOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct RmsNormOptionsT : public ::flatbuffers::NativeTable +{ + typedef RmsNormOptions TableType; + float epsilon = 0.0f; +}; + +struct RmsNormOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef RmsNormOptionsT NativeTableType; + typedef RmsNormOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_EPSILON = 4 + }; + float epsilon() const { return GetField(VT_EPSILON, 0.0f); } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField(verifier, VT_EPSILON, 4) && + verifier.EndTable(); + } + RmsNormOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(RmsNormOptionsT *_o, + const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset + Pack(::flatbuffers::FlatBufferBuilder &_fbb, const RmsNormOptionsT *_o, + const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct RmsNormOptionsBuilder +{ + typedef RmsNormOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_epsilon(float epsilon) + { + fbb_.AddElement(RmsNormOptions::VT_EPSILON, epsilon, 0.0f); + } + explicit RmsNormOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset +CreateRmsNormOptions(::flatbuffers::FlatBufferBuilder &_fbb, float epsilon = 0.0f) +{ + RmsNormOptionsBuilder builder_(_fbb); + builder_.add_epsilon(epsilon); + return builder_.Finish(); +} + +::flatbuffers::Offset +CreateRmsNormOptions(::flatbuffers::FlatBufferBuilder &_fbb, const RmsNormOptionsT *_o, + const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct RoPEOptionsT : public ::flatbuffers::NativeTable +{ + typedef RoPEOptions TableType; + circle::RoPEMode mode = circle::RoPEMode_GPT_NEOX; +}; + +struct RoPEOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table +{ + typedef RoPEOptionsT NativeTableType; + typedef RoPEOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE + { + VT_MODE = 4 + }; + circle::RoPEMode mode() const + { + return static_cast(GetField(VT_MODE, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const + { + return VerifyTableStart(verifier) && VerifyField(verifier, VT_MODE, 4) && + verifier.EndTable(); + } + RoPEOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(RoPEOptionsT *_o, + const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset + Pack(::flatbuffers::FlatBufferBuilder &_fbb, const RoPEOptionsT *_o, + const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct RoPEOptionsBuilder +{ + typedef RoPEOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_mode(circle::RoPEMode mode) + { + fbb_.AddElement(RoPEOptions::VT_MODE, static_cast(mode), 0); + } + explicit RoPEOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) + { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() + { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset +CreateRoPEOptions(::flatbuffers::FlatBufferBuilder &_fbb, + circle::RoPEMode mode = circle::RoPEMode_GPT_NEOX) +{ + RoPEOptionsBuilder builder_(_fbb); + builder_.add_mode(mode); + return builder_.Finish(); +} + +::flatbuffers::Offset +CreateRoPEOptions(::flatbuffers::FlatBufferBuilder &_fbb, const RoPEOptionsT *_o, + const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public ::flatbuffers::NativeTable { typedef OperatorCode TableType; @@ -18700,6 +18915,18 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table ? static_cast(builtin_options()) : nullptr; } + const circle::RoPEOptions *builtin_options_as_RoPEOptions() const + { + return builtin_options_type() == circle::BuiltinOptions_RoPEOptions + ? static_cast(builtin_options()) + : nullptr; + } + const circle::RmsNormOptions *builtin_options_as_RmsNormOptions() const + { + return builtin_options_type() == circle::BuiltinOptions_RmsNormOptions + ? static_cast(builtin_options()) + : nullptr; + } const circle::GRUOptions *builtin_options_as_GRUOptions() const { return builtin_options_type() == circle::BuiltinOptions_GRUOptions @@ -19723,6 +19950,18 @@ Operator::builtin_options_as() const return builtin_options_as_RightShiftOptions(); } +template <> +inline const circle::RoPEOptions *Operator::builtin_options_as() const +{ + return builtin_options_as_RoPEOptions(); +} + +template <> +inline const circle::RmsNormOptions *Operator::builtin_options_as() const +{ + return builtin_options_as_RmsNormOptions(); +} + template <> inline const circle::GRUOptions *Operator::builtin_options_as() const { @@ -29000,6 +29239,91 @@ CreateInstanceNormOptions(::flatbuffers::FlatBufferBuilder &_fbb, const Instance return circle::CreateInstanceNormOptions(_fbb, _epsilon, _fused_activation_function); } +inline RmsNormOptionsT * +RmsNormOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const +{ + auto _o = std::unique_ptr(new RmsNormOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void RmsNormOptions::UnPackTo(RmsNormOptionsT *_o, + const ::flatbuffers::resolver_function_t *_resolver) const +{ + (void)_o; + (void)_resolver; + { + auto _e = epsilon(); + _o->epsilon = _e; + } +} + +inline ::flatbuffers::Offset +RmsNormOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const RmsNormOptionsT *_o, + const ::flatbuffers::rehasher_function_t *_rehasher) +{ + return CreateRmsNormOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset +CreateRmsNormOptions(::flatbuffers::FlatBufferBuilder &_fbb, const RmsNormOptionsT *_o, + const ::flatbuffers::rehasher_function_t *_rehasher) +{ + (void)_rehasher; + (void)_o; + struct _VectorArgs + { + ::flatbuffers::FlatBufferBuilder *__fbb; + const RmsNormOptionsT *__o; + const ::flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _epsilon = _o->epsilon; + return circle::CreateRmsNormOptions(_fbb, _epsilon); +} + +inline RoPEOptionsT *RoPEOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const +{ + auto _o = std::unique_ptr(new RoPEOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void RoPEOptions::UnPackTo(RoPEOptionsT *_o, + const ::flatbuffers::resolver_function_t *_resolver) const +{ + (void)_o; + (void)_resolver; + { + auto _e = mode(); + _o->mode = _e; + } +} + +inline ::flatbuffers::Offset +RoPEOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const RoPEOptionsT *_o, + const ::flatbuffers::rehasher_function_t *_rehasher) +{ + return CreateRoPEOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset +CreateRoPEOptions(::flatbuffers::FlatBufferBuilder &_fbb, const RoPEOptionsT *_o, + const ::flatbuffers::rehasher_function_t *_rehasher) +{ + (void)_rehasher; + (void)_o; + struct _VectorArgs + { + ::flatbuffers::FlatBufferBuilder *__fbb; + const RoPEOptionsT *__o; + const ::flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _mode = _o->mode; + return circle::CreateRoPEOptions(_fbb, _mode); +} + inline OperatorCodeT * OperatorCode::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { @@ -30834,6 +31158,16 @@ inline bool VerifyBuiltinOptions(::flatbuffers::Verifier &verifier, const void * auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_RoPEOptions: + { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_RmsNormOptions: + { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } case BuiltinOptions_GRUOptions: { auto ptr = reinterpret_cast(obj); @@ -31514,6 +31848,16 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_RoPEOptions: + { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_RmsNormOptions: + { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } case BuiltinOptions_GRUOptions: { auto ptr = reinterpret_cast(obj); @@ -32176,6 +32520,16 @@ BuiltinOptionsUnion::Pack(::flatbuffers::FlatBufferBuilder &_fbb, auto ptr = reinterpret_cast(value); return CreateRightShiftOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_RoPEOptions: + { + auto ptr = reinterpret_cast(value); + return CreateRoPEOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_RmsNormOptions: + { + auto ptr = reinterpret_cast(value); + return CreateRmsNormOptions(_fbb, ptr, _rehasher).Union(); + } case BuiltinOptions_GRUOptions: { auto ptr = reinterpret_cast(value); @@ -32897,6 +33251,16 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) new circle::RightShiftOptionsT(*reinterpret_cast(u.value)); break; } + case BuiltinOptions_RoPEOptions: + { + value = new circle::RoPEOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_RmsNormOptions: + { + value = new circle::RmsNormOptionsT(*reinterpret_cast(u.value)); + break; + } case BuiltinOptions_GRUOptions: { value = new circle::GRUOptionsT(*reinterpret_cast(u.value)); @@ -33685,6 +34049,18 @@ inline void BuiltinOptionsUnion::Reset() delete ptr; break; } + case BuiltinOptions_RoPEOptions: + { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_RmsNormOptions: + { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } case BuiltinOptions_GRUOptions: { auto ptr = reinterpret_cast(value);