From 35ab4797debf61d2c78766210300446fb7e567c9 Mon Sep 17 00:00:00 2001 From: iaomw Date: Thu, 5 Dec 2024 20:35:07 +0800 Subject: [PATCH] Support more types in shader --- zeno/include/zeno/extra/ShaderNode.h | 25 +++++++++++++++++++++++++ zeno/src/extra/ShaderNode.cpp | 12 ++++++++++-- zeno/src/nodes/mtl/ShaderAttrs.cpp | 6 ++---- 3 files changed, 37 insertions(+), 6 deletions(-) diff --git a/zeno/include/zeno/extra/ShaderNode.h b/zeno/include/zeno/extra/ShaderNode.h index 90cee75307..5710b1aa2c 100644 --- a/zeno/include/zeno/extra/ShaderNode.h +++ b/zeno/include/zeno/extra/ShaderNode.h @@ -22,6 +22,31 @@ struct ShaderNode : INode { ZENO_API ~ShaderNode() override; }; +static const inline std::map TypeHint { + + {"bool", 0}, + {"int", 10}, + {"uint", 11}, + + {"float", 1}, + {"vec2", 2}, + {"vec3", 3}, + {"vec4", 4} +}; + +static const inline std::map TypeHintReverse { + + {0, "bool"}, + {10, "int"}, + {11, "uint"}, + + {1, "float"}, + {2, "vec2"}, + {3, "vec3"}, + {4, "vec4"} +}; + + template struct ShaderNodeClone : ShaderNode { virtual std::shared_ptr clone() const override { diff --git a/zeno/src/extra/ShaderNode.cpp b/zeno/src/extra/ShaderNode.cpp index fd9ea7f14b..a24e724bbf 100644 --- a/zeno/src/extra/ShaderNode.cpp +++ b/zeno/src/extra/ShaderNode.cpp @@ -38,6 +38,15 @@ ZENO_API int EmissionPass::determineType(IObject *object) { int type = std::visit([&] (auto const &value) -> int { using T = std::decay_t; + + if constexpr (std::is_same_v) { + return 0; + } else if constexpr (std::is_same_v) { + return 10; + } else if constexpr (std::is_same_v) { + return 11; + } + if constexpr (std::is_same_v) { return 1; } else if constexpr (std::is_same_v) { @@ -126,8 +135,7 @@ ZENO_API std::string EmissionPass::getCommonCode() const { } ZENO_API std::string EmissionPass::typeNameOf(int type) const { - if (type == 1) return "float"; - else return (backend == HLSL ? "float" : "vec") + std::to_string(type); + return TypeHintReverse.at(type); } ZENO_API std::string EmissionPass::collectDefs() const { diff --git a/zeno/src/nodes/mtl/ShaderAttrs.cpp b/zeno/src/nodes/mtl/ShaderAttrs.cpp index 7af675c792..58342a9720 100644 --- a/zeno/src/nodes/mtl/ShaderAttrs.cpp +++ b/zeno/src/nodes/mtl/ShaderAttrs.cpp @@ -55,9 +55,7 @@ static std::string dataTypeListString() { struct ShaderInputAttr : ShaderNodeClone { virtual int determineType(EmissionPass *em) override { auto type = get_input2("type"); - const char *tab[] = {"float", "vec2", "vec3", "vec4"}; - auto idx = std::find(std::begin(tab), std::end(tab), type) - std::begin(tab); - return idx + 1; + return TypeHint.at(type); } virtual void emitCode(EmissionPass *em) override { @@ -75,7 +73,7 @@ struct ShaderInputAttr : ShaderNodeClone { ZENDEFNODE(ShaderInputAttr, { { {"enum" + dataTypeListString(), "attr", dataTypeDefaultString()}, - {"enum float vec2 vec3 vec4 bool", "type", "vec3"}, + {"enum float vec2 vec3 vec4 bool int uint", "type", "vec3"}, }, { {"shader", "out"},