Skip to content

Commit

Permalink
Support more types in shader (#2065)
Browse files Browse the repository at this point in the history
  • Loading branch information
iaomw authored Dec 6, 2024
1 parent 771074c commit 7e06fe5
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 17 deletions.
24 changes: 24 additions & 0 deletions zeno/include/zeno/extra/ShaderNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,30 @@ struct ShaderNode : INode {
ZENO_API ShaderNode();
ZENO_API ~ShaderNode() override;
};
using ShaderDataTypeList = std::tuple<bool, int, unsigned int, float, vec2f, vec3f, vec4f>;

inline const auto ShaderDataTypeNames = std::array { "bool", "int", "uint", "float", "vec2", "vec3", "vec4" };

static const inline std::map<std::string, int> TypeHint {

{"bool", 0},
{"int", 10},
{"uint", 11},

{"float", 1},
{"vec2", 2},
{"vec3", 3},
{"vec4", 4}
};

static const inline std::map<int, std::string> TypeHintReverse = []() {

std::map<int, std::string> result {};
for (auto& [k, v] : TypeHint) {
result[v] = k;
}
return result;
} ();

template <class Derived>
struct ShaderNodeClone : ShaderNode {
Expand Down
29 changes: 16 additions & 13 deletions zeno/src/extra/ShaderNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <zeno/extra/ShaderNode.h>
#include <zeno/types/ShaderObject.h>
#include <zeno/types/NumericObject.h>
#include <zeno/utils/type_traits.h>
#include <sstream>
#include <cassert>

Expand Down Expand Up @@ -38,18 +39,21 @@ ZENO_API int EmissionPass::determineType(IObject *object) {

int type = std::visit([&] (auto const &value) -> int {
using T = std::decay_t<decltype(value)>;
if constexpr (std::is_same_v<float, T>) {
return 1;
} else if constexpr (std::is_same_v<vec2f, T>) {
return 2;
} else if constexpr (std::is_same_v<vec3f, T>) {
return 3;
} else if constexpr (std::is_same_v<vec4f, T>) {
return 4;
} else {
throw zeno::Exception("bad numeric object type: " + (std::string)typeid(T).name());
}
size_t typeIdx = 0;

zeno::static_for<0, std::tuple_size_v<ShaderDataTypeList>>([&] (auto i) {
using ThisType = std::tuple_element_t<i, ShaderDataTypeList>;

if (std::is_same_v<ThisType, T>) {
typeIdx = i;
return true;
}
return false;
});

return TypeHint.at(ShaderDataTypeNames.at(typeIdx));
}, num->value);

constmap[num] = constants.size();
constants.push_back(ConstInfo{type, num->value});
return type;
Expand Down Expand Up @@ -126,8 +130,7 @@ ZENO_API std::string EmissionPass::getCommonCode() const {
}

ZENO_API std::string EmissionPass::typeNameOf(int type) const {
if (type == 1) return "float";
else return (backend == HLSL ? "float" : "vec") + std::to_string(type);
return TypeHintReverse.at(type);
}

ZENO_API std::string EmissionPass::collectDefs() const {
Expand Down
6 changes: 2 additions & 4 deletions zeno/src/nodes/mtl/ShaderAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,7 @@ static std::string dataTypeListString() {
struct ShaderInputAttr : ShaderNodeClone<ShaderInputAttr> {
virtual int determineType(EmissionPass *em) override {
auto type = get_input2<std::string>("type");
const char *tab[] = {"float", "vec2", "vec3", "vec4"};
auto idx = std::find(std::begin(tab), std::end(tab), type) - std::begin(tab);
return idx + 1;
return TypeHint.at(type);
}

virtual void emitCode(EmissionPass *em) override {
Expand All @@ -75,7 +73,7 @@ struct ShaderInputAttr : ShaderNodeClone<ShaderInputAttr> {
ZENDEFNODE(ShaderInputAttr, {
{
{"enum" + dataTypeListString(), "attr", dataTypeDefaultString()},
{"enum float vec2 vec3 vec4 bool", "type", "vec3"},
{"enum float vec2 vec3 vec4 bool int uint", "type", "vec3"},
},
{
{"shader", "out"},
Expand Down

0 comments on commit 7e06fe5

Please sign in to comment.