Skip to content

Commit

Permalink
enhance primitiveunfuse
Browse files Browse the repository at this point in the history
  • Loading branch information
littlemine committed Dec 11, 2023
1 parent 95478c8 commit 49e1ea5
Showing 1 changed file with 186 additions and 129 deletions.
315 changes: 186 additions & 129 deletions projects/CUDA/utils/Primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1041,162 +1041,219 @@ ZENDEFNODE(PrimitiveFuse, {
{"zs_geom"},
});

/// @note duplicate vertices shared by multiple groups
struct PrimitiveUnfuse : INode {
void apply() override {
using namespace zs;
auto prim = get_input<PrimitiveObject>("prim");
auto tag = get_input2<std::string>("partition_tag");
static std::shared_ptr<PrimitiveObject> unfuse_primitive(std::shared_ptr<PrimitiveObject> prim, std::string tag) {
using namespace zs;
constexpr auto space = execspace_e::openmp;
auto pol = omp_exec();

constexpr auto space = execspace_e::openmp;
auto pol = omp_exec();
auto &verts = prim->verts;
const auto &pos = verts.values;

auto &verts = prim->verts;
const auto &pos = verts.values;
auto &tris = prim->tris;
const auto &triIds = tris.values;
const bool hasTris = tris.size() > 0;

auto &tris = prim->tris;
const auto &triIds = tris.values;
const bool hasTris = tris.size() > 0;
auto &polys = prim->polys;
const auto &loops = prim->loops;
const bool hasLoops = polys.size() > 1;

auto &polys = prim->polys;
const auto &loops = prim->loops;
const bool hasLoops = polys.size() > 1;

if ((hasTris ^ hasLoops) == 0)
throw std::runtime_error("The input mesh must either own active triangle topology or loop topology.");
if ((hasTris ^ hasLoops) == 0)
throw std::runtime_error("The input mesh must either own active triangle topology or loop topology.");

std::vector<std::set<int>> groupsPerVertex(pos.size());
if (hasTris) {
const auto &triGroups = tris.attr<int>(tag);
std::vector<Mutex> mtxs(pos.size());
pol(zip(triIds, triGroups), [&mtxs, &groupsPerVertex](auto tri, int groupNo) {
for (int d = 0; d != 3; ++d) {
int vi = tri[d];
auto &mtx = mtxs[vi];
auto &group = groupsPerVertex[vi];
{
mtxs[vi].lock();
group.insert(groupNo);
mtxs[vi].unlock();
}
};
});
} else {
const auto &polyGroups = polys.attr<int>(tag);
std::vector<Mutex> mtxs(pos.size());
pol(zip(polys.values, polyGroups), [&mtxs, &groupsPerVertex, &loops](zeno::vec2i poly, int groupNo) {
auto st = poly[0];
auto ed = st + poly[1];
for (; st != ed; ++st) {
int vi = loops.values[st];

auto &mtx = mtxs[vi];
auto &group = groupsPerVertex[vi];
{
mtxs[vi].lock();
group.insert(groupNo);
mtxs[vi].unlock();
}
std::vector<std::set<int>> groupsPerVertex(pos.size());
if (hasTris) {
const auto &triGroups = tris.attr<int>(tag);
std::vector<Mutex> mtxs(pos.size());
pol(zip(triIds, triGroups), [&mtxs, &groupsPerVertex](auto tri, int groupNo) {
for (int d = 0; d != 3; ++d) {
int vi = tri[d];
auto &mtx = mtxs[vi];
auto &group = groupsPerVertex[vi];
{
mtxs[vi].lock();
group.insert(groupNo);
mtxs[vi].unlock();
}
});
};
});
} else {
const auto &polyGroups = polys.attr<int>(tag);
std::vector<Mutex> mtxs(pos.size());
pol(zip(polys.values, polyGroups), [&mtxs, &groupsPerVertex, &loops](zeno::vec2i poly, int groupNo) {
auto st = poly[0];
auto ed = st + poly[1];
for (; st != ed; ++st) {
int vi = loops.values[st];

auto &mtx = mtxs[vi];
auto &group = groupsPerVertex[vi];
{
mtxs[vi].lock();
group.insert(groupNo);
mtxs[vi].unlock();
}
}
});
}

std::vector<int> numGroupsPerVertex(pos.size() + 1), ptrs(pos.size() + 1);
pol(zip(numGroupsPerVertex, groupsPerVertex), [](int &num, const std::set<int> &g) { num = g.size(); });
exclusive_scan(pol, std::begin(numGroupsPerVertex), std::end(numGroupsPerVertex), std::begin(ptrs));

auto numEntries = ptrs.back();
std::vector<int> inds(numEntries);
pol(enumerate(groupsPerVertex), [&inds, &ptrs](int vi, const std::set<int> &groups) {
auto st = ptrs[vi], ed = ptrs[vi + 1];
for (auto groupNo : groups) {
inds[st++] = groupNo; // the first group does not need to change
}
});

std::vector<int> numGroupsPerVertex(pos.size() + 1), ptrs(pos.size() + 1);
pol(zip(numGroupsPerVertex, groupsPerVertex), [](int &num, const std::set<int> &g) { num = g.size(); });
exclusive_scan(pol, std::begin(numGroupsPerVertex), std::end(numGroupsPerVertex), std::begin(ptrs));
auto resPrim = std::make_shared<PrimitiveObject>();

auto numEntries = ptrs.back();
std::vector<int> inds(numEntries);
pol(enumerate(groupsPerVertex), [&inds, &ptrs](int vi, const std::set<int> &groups) {
auto st = ptrs[vi], ed = ptrs[vi + 1];
for (auto groupNo : groups) {
inds[st++] = groupNo; // the first group does not need to change
}
resPrim->verts.resize(numEntries);
auto &resVerts = resPrim->verts;
auto &resPos = resVerts.values;
verts.foreach_attr<AttrAcceptAll>([&](const auto &key, const auto &arr) {
using T = std::decay_t<decltype(arr[0])>;
resVerts.add_attr<T>(key);
});
pol(range(pos.size()), [&ptrs, &verts, &resVerts, &resPos](int vi) {
auto st = ptrs[vi], ed = ptrs[vi + 1];
for (int j = st; j != ed; ++j)
resPos[j] = verts.values[vi];
resVerts.foreach_attr<AttrAcceptAll>([&](const auto &key, auto &dst) {
using T = std::decay_t<decltype(dst[0])>;
const auto &src = verts.attr<T>(key);
for (int j = st; j != ed; ++j)
dst[j] = src[vi];
});
});

auto resPrim = std::make_shared<PrimitiveObject>();
if (hasTris) {
const auto &triGroups = tris.attr<int>(tag);

resPrim->verts.resize(numEntries);
auto &resVerts = resPrim->verts;
auto &resPos = resVerts.values;
verts.foreach_attr<AttrAcceptAll>([&](const auto &key, const auto &arr) {
resPrim->tris.resize(tris.size());
auto &resTris = resPrim->tris;
auto &resTriIds = resTris.values;
tris.foreach_attr<AttrAcceptAll>([&](const auto &key, const auto &arr) {
using T = std::decay_t<decltype(arr[0])>;
resVerts.add_attr<T>(key);
resTris.add_attr<T>(key) = arr;
});
pol(range(pos.size()), [&ptrs, &verts, &resVerts, &resPos](int vi) {
auto st = ptrs[vi], ed = ptrs[vi + 1];
for (int j = st; j != ed; ++j)
resPos[j] = verts.values[vi];
resVerts.foreach_attr<AttrAcceptAll>([&](const auto &key, auto &dst) {
using T = std::decay_t<decltype(dst[0])>;
const auto &src = verts.attr<T>(key);
for (int j = st; j != ed; ++j)
dst[j] = src[vi];
});
pol(range(tris.size()), [&](int f) {
auto groupNo = triGroups[f];
for (int d = 0; d != 3; ++d) {
int vi = tris[f][d];
int st = ptrs[vi], ed = ptrs[vi + 1];
for (; st != ed; ++st) {
if (groupNo == inds[st])
break;
}
resTriIds[f][d] = st;
}
});
} else {
const auto &polyGroups = polys.attr<int>(tag);
bool uvExist = prim->uvs.size() > 0 && loops.has_attr("uvs");

if (hasTris) {
const auto &triGroups = tris.attr<int>(tag);
resPrim->polys.resize(polys.size());
resPrim->loops.resize(loops.size());
auto &resPolys = resPrim->polys;
auto &resLoops = resPrim->loops;

resPrim->tris.resize(tris.size());
auto &resTris = resPrim->tris;
auto &resTriIds = resTris.values;
tris.foreach_attr<AttrAcceptAll>([&](const auto &key, const auto &arr) {
using T = std::decay_t<decltype(arr[0])>;
resTris.add_attr<T>(key) = arr;
});
pol(range(tris.size()), [&](int f) {
auto groupNo = triGroups[f];
for (int d = 0; d != 3; ++d) {
int vi = tris[f][d];
int st = ptrs[vi], ed = ptrs[vi + 1];
for (; st != ed; ++st) {
if (groupNo == inds[st])
break;
}
resTriIds[f][d] = st;
resPolys.values = polys.values;
polys.foreach_attr<AttrAcceptAll>([&](const auto &key, const auto &arr) {
using T = std::decay_t<decltype(arr[0])>;
resPolys.add_attr<T>(key) = arr;
});

loops.foreach_attr<AttrAcceptAll>([&](const auto &key, const auto &arr) {
using T = std::decay_t<decltype(arr[0])>;
resLoops.add_attr<T>(key) = arr;
});
// resLoops.values = loops.values;
pol(zip(polys.values, polyGroups), [&](zeno::vec2i poly, int groupNo) {
auto st = poly[0];
auto ed = st + poly[1];
for (; st != ed; ++st) {
int vi = loops.values[st];

int l = ptrs[vi], r = ptrs[vi + 1];
for (; l != r; ++l) {
if (groupNo == inds[l])
break;
}
});
} else {
const auto &polyGroups = polys.attr<int>(tag);
bool uvExist = prim->uvs.size() > 0 && loops.has_attr("uvs");
resLoops.values[st] = l;
}
});
}
return resPrim;
}

resPrim->polys.resize(polys.size());
resPrim->loops.resize(loops.size());
auto &resPolys = resPrim->polys;
auto &resLoops = resPrim->loops;
/// @note duplicate vertices shared by multiple groups
struct PrimitiveUnfuse : INode {
void apply() override {
auto prim = get_input<PrimitiveObject>("prim");
auto tag = get_input2<std::string>("partition_tag");

resPolys.values = polys.values;
polys.foreach_attr<AttrAcceptAll>([&](const auto &key, const auto &arr) {
using T = std::decay_t<decltype(arr[0])>;
resPolys.add_attr<T>(key) = arr;
});
auto resPrim = unfuse_primitive(prim, tag);

loops.foreach_attr<AttrAcceptAll>([&](const auto &key, const auto &arr) {
using T = std::decay_t<decltype(arr[0])>;
resLoops.add_attr<T>(key) = arr;
});
// resLoops.values = loops.values;
pol(zip(polys.values, polyGroups), [&](zeno::vec2i poly, int groupNo) {
auto st = poly[0];
auto ed = st + poly[1];
for (; st != ed; ++st) {
int vi = loops.values[st];
bool toList = get_input2<bool>("to_list");

int l = ptrs[vi], r = ptrs[vi + 1];
for (; l != r; ++l) {
if (groupNo == inds[l])
break;
if (toList) {
constexpr auto space = zs::execspace_e::openmp;
auto pol = zs::omp_exec();

auto &verts = resPrim->verts;
const auto &pos = verts.values;

auto &tris = resPrim->tris;
const bool hasTris = tris.size() > 0;

auto &polys = resPrim->polys;
const bool hasLoops = polys.size() > 1;

auto &vertGroups = resPrim->verts.add_attr<int>(tag);

if (hasTris) {
const auto &triIds = tris.values;
const auto &triGroups = tris.attr<int>(tag);

pol(zs::zip(triIds, triGroups), [&vertGroups](auto tri, int groupNo) {
for (int d = 0; d != 3; ++d) {
int vi = tri[d];
vertGroups[vi] = groupNo;
}
resLoops.values[st] = l;
}
});
});
} else {
const auto &loops = resPrim->loops.values;
const auto &polyGroups = polys.attr<int>(tag);

pol(zs::zip(polys.values, polyGroups), [&vertGroups, &loops](zeno::vec2i poly, int groupNo) {
auto st = poly[0];
auto ed = st + poly[1];
for (; st != ed; ++st) {
int vi = loops[st];
vertGroups[vi] = groupNo;
}
});
}

auto primList = primUnmergeVerts(resPrim.get(), tag);
auto listPrim = std::make_shared<ListObject>();
for (auto &primPtr : primList) {
listPrim->arr.push_back(std::move(primPtr));
}
set_output("partitioned_prim", std::move(listPrim));
} else {
set_output("partitioned_prim", std::move(resPrim));
}
set_output("partitioned_prim", std::move(resPrim));
}
};
ZENDEFNODE(PrimitiveUnfuse, {
{{"PrimitiveObject", "prim"}, {"string", "partition_tag", "triangle_index"}},
{{"PrimitiveObject", "prim"},
{"string", "partition_tag", "triangle_index"},
{"bool", "to_list", "false"}},
{
{"PrimitiveObject", "partitioned_prim"},
},
Expand Down

0 comments on commit 49e1ea5

Please sign in to comment.