diff --git a/zeno/include/zeno/extra/ShaderNode.h b/zeno/include/zeno/extra/ShaderNode.h index 90cee75307..b83a032fa2 100644 --- a/zeno/include/zeno/extra/ShaderNode.h +++ b/zeno/include/zeno/extra/ShaderNode.h @@ -21,6 +21,30 @@ struct ShaderNode : INode { ZENO_API ShaderNode(); ZENO_API ~ShaderNode() override; }; +using ShaderDataTypeList = std::tuple; + +inline const auto ShaderDataTypeNames = std::array { "bool", "int", "uint", "float", "vec2", "vec3", "vec4" }; + +static const inline std::map TypeHint { + + {"bool", 0}, + {"int", 10}, + {"uint", 11}, + + {"float", 1}, + {"vec2", 2}, + {"vec3", 3}, + {"vec4", 4} +}; + +static const inline std::map TypeHintReverse = []() { + + std::map result {}; + for (auto& [k, v] : TypeHint) { + result[v] = k; + } + return result; +} (); template struct ShaderNodeClone : ShaderNode { diff --git a/zeno/src/extra/ShaderNode.cpp b/zeno/src/extra/ShaderNode.cpp index fd9ea7f14b..d3016a2a4e 100644 --- a/zeno/src/extra/ShaderNode.cpp +++ b/zeno/src/extra/ShaderNode.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include @@ -38,18 +39,21 @@ ZENO_API int EmissionPass::determineType(IObject *object) { int type = std::visit([&] (auto const &value) -> int { using T = std::decay_t; - if constexpr (std::is_same_v) { - return 1; - } else if constexpr (std::is_same_v) { - return 2; - } else if constexpr (std::is_same_v) { - return 3; - } else if constexpr (std::is_same_v) { - return 4; - } else { - throw zeno::Exception("bad numeric object type: " + (std::string)typeid(T).name()); - } + size_t typeIdx = 0; + + zeno::static_for<0, std::tuple_size_v>([&] (auto i) { + using ThisType = std::tuple_element_t; + + if (std::is_same_v) { + typeIdx = i; + return true; + } + return false; + }); + + return TypeHint.at(ShaderDataTypeNames.at(typeIdx)); }, num->value); + constmap[num] = constants.size(); constants.push_back(ConstInfo{type, num->value}); return type; @@ -126,8 +130,7 @@ ZENO_API std::string EmissionPass::getCommonCode() const { } ZENO_API std::string EmissionPass::typeNameOf(int type) const { - if (type == 1) return "float"; - else return (backend == HLSL ? "float" : "vec") + std::to_string(type); + return TypeHintReverse.at(type); } ZENO_API std::string EmissionPass::collectDefs() const { diff --git a/zeno/src/nodes/mtl/ShaderAttrs.cpp b/zeno/src/nodes/mtl/ShaderAttrs.cpp index 7af675c792..58342a9720 100644 --- a/zeno/src/nodes/mtl/ShaderAttrs.cpp +++ b/zeno/src/nodes/mtl/ShaderAttrs.cpp @@ -55,9 +55,7 @@ static std::string dataTypeListString() { struct ShaderInputAttr : ShaderNodeClone { virtual int determineType(EmissionPass *em) override { auto type = get_input2("type"); - const char *tab[] = {"float", "vec2", "vec3", "vec4"}; - auto idx = std::find(std::begin(tab), std::end(tab), type) - std::begin(tab); - return idx + 1; + return TypeHint.at(type); } virtual void emitCode(EmissionPass *em) override { @@ -75,7 +73,7 @@ struct ShaderInputAttr : ShaderNodeClone { ZENDEFNODE(ShaderInputAttr, { { {"enum" + dataTypeListString(), "attr", dataTypeDefaultString()}, - {"enum float vec2 vec3 vec4 bool", "type", "vec3"}, + {"enum float vec2 vec3 vec4 bool int uint", "type", "vec3"}, }, { {"shader", "out"},