Skip to content

Commit

Permalink
Support more types in shader
Browse files Browse the repository at this point in the history
  • Loading branch information
iaomw committed Dec 5, 2024
1 parent 3316a44 commit 35ab479
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 6 deletions.
25 changes: 25 additions & 0 deletions zeno/include/zeno/extra/ShaderNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,31 @@ struct ShaderNode : INode {
ZENO_API ~ShaderNode() override;
};

static const inline std::map<std::string, int> TypeHint {

{"bool", 0},
{"int", 10},
{"uint", 11},

{"float", 1},
{"vec2", 2},
{"vec3", 3},
{"vec4", 4}
};

static const inline std::map<int, std::string> TypeHintReverse {

{0, "bool"},
{10, "int"},
{11, "uint"},

{1, "float"},
{2, "vec2"},
{3, "vec3"},
{4, "vec4"}
};


template <class Derived>
struct ShaderNodeClone : ShaderNode {
virtual std::shared_ptr<ShaderNode> clone() const override {
Expand Down
12 changes: 10 additions & 2 deletions zeno/src/extra/ShaderNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ ZENO_API int EmissionPass::determineType(IObject *object) {

int type = std::visit([&] (auto const &value) -> int {
using T = std::decay_t<decltype(value)>;

if constexpr (std::is_same_v<bool, T>) {
return 0;
} else if constexpr (std::is_same_v<int, T>) {
return 10;
} else if constexpr (std::is_same_v<unsigned int, T>) {
return 11;
}

if constexpr (std::is_same_v<float, T>) {
return 1;
} else if constexpr (std::is_same_v<vec2f, T>) {
Expand Down Expand Up @@ -126,8 +135,7 @@ ZENO_API std::string EmissionPass::getCommonCode() const {
}

ZENO_API std::string EmissionPass::typeNameOf(int type) const {
if (type == 1) return "float";
else return (backend == HLSL ? "float" : "vec") + std::to_string(type);
return TypeHintReverse.at(type);
}

ZENO_API std::string EmissionPass::collectDefs() const {
Expand Down
6 changes: 2 additions & 4 deletions zeno/src/nodes/mtl/ShaderAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,7 @@ static std::string dataTypeListString() {
struct ShaderInputAttr : ShaderNodeClone<ShaderInputAttr> {
virtual int determineType(EmissionPass *em) override {
auto type = get_input2<std::string>("type");
const char *tab[] = {"float", "vec2", "vec3", "vec4"};
auto idx = std::find(std::begin(tab), std::end(tab), type) - std::begin(tab);
return idx + 1;
return TypeHint.at(type);
}

virtual void emitCode(EmissionPass *em) override {
Expand All @@ -75,7 +73,7 @@ struct ShaderInputAttr : ShaderNodeClone<ShaderInputAttr> {
ZENDEFNODE(ShaderInputAttr, {
{
{"enum" + dataTypeListString(), "attr", dataTypeDefaultString()},
{"enum float vec2 vec3 vec4 bool", "type", "vec3"},
{"enum float vec2 vec3 vec4 bool int uint", "type", "vec3"},
},
{
{"shader", "out"},
Expand Down

0 comments on commit 35ab479

Please sign in to comment.