From f3d86fa7fdbcf4510ed762ba75a866634ecba3d3 Mon Sep 17 00:00:00 2001 From: littlemine Date: Fri, 27 Sep 2024 21:00:01 +0800 Subject: [PATCH] smallvec to numeric (zpcjit) --- projects/CUDA/zpc | 2 +- projects/PyZpc/interop/vec_nodes.cpp | 72 ++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 1 deletion(-) diff --git a/projects/CUDA/zpc b/projects/CUDA/zpc index de7227086f..a1b42adcae 160000 --- a/projects/CUDA/zpc +++ b/projects/CUDA/zpc @@ -1 +1 @@ -Subproject commit de7227086f301191b96143533a80df1690f1a470 +Subproject commit a1b42adcaeedd7085ceebf5d0c6ad6315305f397 diff --git a/projects/PyZpc/interop/vec_nodes.cpp b/projects/PyZpc/interop/vec_nodes.cpp index d8dab8cb12..da931fb857 100644 --- a/projects/PyZpc/interop/vec_nodes.cpp +++ b/projects/PyZpc/interop/vec_nodes.cpp @@ -59,6 +59,78 @@ ZENDEFNODE(NumericToSmallVec, { {"PyZFX"}, }); +struct SmallVecToNumeric : INode { + virtual void apply() override { + const auto &smallVec = get_input("ZSSmallVec")->value; + auto ret = std::make_shared(); + std::visit( + [&ret](auto const &vec) { + using vec_t = RM_CVREF_T(vec); + using VT = zs::conditional_t, float, vec_t>; + if constexpr (zs::is_scalar_v) { + ret->set((VT)vec); + } else { + constexpr auto dim = vec_t::dim; + using VT = zs::conditional_t, + float, typename vec_t::value_type>; + if constexpr (dim == 1) { + constexpr auto dimI = vec_t::template range_t<0>::value; + if constexpr (dimI == 1) { + ret->set((VT)vec(0)); + } else { + zeno::vec tmp; + for (int d = 0; d < dimI; ++d) + tmp[d] = vec(d); + ret->set(tmp); + } + } else if constexpr (dim == 2) { + constexpr auto dimI = vec_t::template range_t<0>::value; + constexpr auto dimJ = vec_t::template range_t<1>::value; + if constexpr (dimI == 1) { + if constexpr (dimJ <= 4) { + if constexpr (dimJ == 1) { + ret->set((VT)vec(0, 0)); + } else { + zeno::vec tmp; + for (int d = 0; d < dimJ; ++d) + tmp[d] = vec(0, d); + ret->set(tmp); + } + } else { + static_assert(zs::always_false, "..."); + } + } else if constexpr (dimJ == 1) { + if constexpr (dimI <= 4) { + if constexpr (dimI == 1) { + ret->set((VT)vec(0, 0)); + } else { + zeno::vec tmp; + for (int d = 0; d < dimI; ++d) + tmp[d] = vec(d, 0); + ret->set(tmp); + } + } else { + static_assert(zs::always_false, "..."); + } + } else { + throw std::runtime_error(fmt::format( + "cannot convert a small vec of shape ({}, {}) to zeno NumericValue", dimI, dimJ)); + } + } + } + }, + smallVec); + set_output("numeric", std::move(ret)); + } +}; + +ZENDEFNODE(SmallVecToNumeric, { + {"ZSSmallVec"}, + {"numeric"}, + {}, + {"PyZFX"}, + }); + struct PrintSmallVec : INode { void apply() override { const auto &smallVec = get_input("ZSSmallVec")->value;