Skip to content

Commit

Permalink
multimesh visitor can work with basic runner
Browse files Browse the repository at this point in the history
  • Loading branch information
mtao committed Oct 13, 2023
1 parent 1f081d3 commit 32fbca0
Show file tree
Hide file tree
Showing 5 changed files with 236 additions and 126 deletions.
4 changes: 2 additions & 2 deletions src/wmtk/multimesh/MultiMeshVisitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include <wmtk/Mesh.hpp>
#include <wmtk/Primitive.hpp>
#include <wmtk/simplex/Simplex.hpp>
#include <wmtk/utils/as_mesh_variant.hpp>
#include <wmtk/utils/metaprogramming/as_mesh_variant.hpp>
#include <wmtk/utils/mesh_type_from_primitive_type.hpp>


Expand Down Expand Up @@ -44,7 +44,7 @@ class MultiMeshVisitor
// if the user passed in a mesh class lets try re-invoking with a derived type
MultiMeshVisitorExecutor exec(*this, mesh, simplex);
Mesh& root = mesh.get_multi_mesh_root();
auto mesh_root_variant = wmtk::utils::as_mesh_variant(root);
auto mesh_root_variant = wmtk::utils::metaprogramming::as_mesh_variant(root);
const simplex::Simplex root_simplex = mesh.map_to_root(simplex);
std::visit([&](auto& root) { exec.execute(root, root_simplex); }, mesh_root_variant);

Expand Down
49 changes: 49 additions & 0 deletions src/wmtk/multimesh/utils/BasicMeshVariantRunner.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#pragma once

#include <wmtk/utils/metaprogramming/as_mesh_variant.hpp>
#include "CachedMeshVariantReturnValues.hpp"
namespace wmtk::multimesh::utils {
template <typename Functor, typename... OtherTypes>
class BasicMeshVariantRunner
{
public:
BasicMeshVariantRunner(Functor&& f)
: func(f)
{}
BasicMeshVariantRunner(Functor&& f, std::tuple<OtherTypes...>)
: func(f)
{}

void run(Mesh& mesh, const OtherTypes&... ts)
{
auto var = wmtk::utils::metaprogramming::as_mesh_variant(mesh);
std::visit(
[&](auto& t) {
auto& v = t.get();
return_data.add(v, func(v, ts...), ts...);
},
var);
}
void run(const Mesh& mesh, const OtherTypes&... ts)
{
auto var = wmtk::utils::metaprogramming::as_const_mesh_variant(mesh);
std::visit(
[&](auto& t) {
const auto& v = t.get();
return_data.add(v, func(v, ts...), ts...);
},
var);
}

CachedMeshVariantReturnValues<Functor, OtherTypes...> return_data;

private:
const Functor& func;
};

template <typename Functor, typename... Ts>
BasicMeshVariantRunner(Functor&& f, std::tuple<Ts...>)
-> BasicMeshVariantRunner<Functor, std::decay_t<Ts>...>;
template <typename Functor>
BasicMeshVariantRunner(Functor&& f) -> BasicMeshVariantRunner<Functor>;
} // namespace wmtk::multimesh::utils
124 changes: 124 additions & 0 deletions src/wmtk/multimesh/utils/CachedMeshVariantReturnValues.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
#pragma once
#include <map>

#include <wmtk/utils/metaprogramming/MeshVariantTraits.hpp>
#include <wmtk/utils/metaprogramming/unwrap_ref.hpp>


namespace wmtk::multimesh::utils {


// A helper class for specifying per-type return types from an input functor
// Assumes the argument is the variant type being selected form, all other
// arguments are passed in as const references
template <typename Functor, typename... Ts>
struct ReturnVariantHelper
{
};
template <typename Functor, typename... VTs, typename... Ts>
struct ReturnVariantHelper<Functor, std::variant<VTs...>, Ts...>
{
// For a specific type in the variant, get the return type
template <typename T>
using ReturnType = std::decay_t<std::invoke_result_t<
Functor,
wmtk::utils::metaprogramming::unwrap_ref_decay_t<T>&,
const Ts&...>>;

template <typename T>
using ReturnTypeConst = std::decay_t<std::invoke_result_t<
Functor,
const wmtk::utils::metaprogramming::unwrap_ref_decay_t<T>&,
const Ts&...>>;

// check what happens if we return a const ref or non-const ref
template <bool IsConst, typename T>
using ReturnType_const = std::conditional_t<IsConst, ReturnTypeConst<T>, ReturnType<T>>;

// Get an overall variant for the types
using type = std::variant<ReturnType<VTs>...>;
using const_type = std::variant<ReturnTypeConst<VTs>...>;
template <bool IsConst>
using type_const = std::variant<ReturnType_const<IsConst, VTs>...>;
};

// Interface for reading off the return values from data
template <typename Functor, typename... OtherArgumentTypes>
class CachedMeshVariantReturnValues
{
public:
using MeshVariantTraits = wmtk::utils::metaprogramming::MeshVariantTraits;
using MeshVariantType = MeshVariantTraits::ReferenceVariant;
using ConstMeshVariantType = MeshVariantTraits::ConstReferenceVariant;

using TypeHelper = ReturnVariantHelper<Functor, MeshVariantType, OtherArgumentTypes...>;
using ReturnVariant = typename TypeHelper::type;

// a pointer to an input and some other arguments
using KeyType = std::tuple<const Mesh*, OtherArgumentTypes...>;

auto get_id(const Mesh& input, const OtherArgumentTypes&... ts) const
{
// other applications might use a fancier version of get_id
return KeyType(&input, ts...);
}

// Add new data by giving the MeshType
// MeshType is used to make sure the pair of Mesh/Output is valid and to
// extract an id
template <typename MeshType, typename ReturnType>
void add(const MeshType& input, ReturnType&& return_data, const OtherArgumentTypes&... args)
{
using ReturnType_t = std::decay_t<ReturnType>;
static_assert(
!std::is_same_v<std::decay_t<MeshType>, Mesh>,
"Don't pass in a input, use variant/visitor to get its "
"derived type");
// if the user passed in a input class lets try re-invoking with a
// derived type
auto id = get_id(input, args...);
using ExpectedReturnType = typename TypeHelper::template ReturnType<MeshType>;

static_assert(
std::is_convertible_v<ReturnType_t, ExpectedReturnType>,
"Second argument should be the return value of a Functor "
"(or convertible at "
"least) ");

m_data.emplace(
id,
ReturnVariant(
std::in_place_type_t<ExpectedReturnType>{},
std::forward<ReturnType>(return_data)));
}

// let user get the variant for a specific Mesh derivate
const auto& get_variant(const Mesh& input, const OtherArgumentTypes&... ts) const
{
auto id = get_id(input, ts...);
return m_data.at(id);
}

// get the type specific input
template <typename MeshType>
auto get(const MeshType& input, const OtherArgumentTypes&... ts) const
{
static_assert(
!std::is_same_v<std::decay_t<MeshType>, Mesh>,
"Don't pass in a input, use variant/visitor to get its "
"derived type");
using ExpectedReturnType = typename TypeHelper::template ReturnType<MeshType>;

return std::get<ExpectedReturnType>(get_variant(input, ts...));
}

private:
std::map<KeyType, ReturnVariant> m_data;
};


template <typename Functor>
CachedMeshVariantReturnValues(Functor&& f) -> CachedMeshVariantReturnValues<std::decay_t<Functor>>;


} // namespace wmtk::multimesh::utils
53 changes: 48 additions & 5 deletions tests/test_mesh_variant.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#include <catch2/catch_test_macros.hpp>
#include <wmtk/EdgeMesh.hpp>
#include <wmtk/PointMesh.hpp>

#include <wmtk/Types.hpp>
#include <wmtk/multimesh/utils/BasicMeshVariantRunner.hpp>
#include <wmtk/utils/metaprogramming/as_mesh_variant.hpp>
#include "tools/TetMesh_examples.hpp"
#include "tools/TriMesh_examples.hpp"
Expand All @@ -20,8 +23,8 @@ struct DimFunctor
}
int operator()(const EdgeMesh&) const
{
spdlog::info("TriMesh");
return 2;
spdlog::info("EdgeMesh");
return 1;
}
int operator()(const TriMesh&) const
{
Expand All @@ -34,10 +37,40 @@ struct DimFunctor
return 3;
}
template <typename T>
int operator()(std::reference_wrapper<T> ref) {
return (*this)(ref.get());
}
auto operator()(std::reference_wrapper<T> ref)
{
return (*this)(ref.get());
}
template <typename T>
auto operator()(std::reference_wrapper<const T> ref)
{
return (*this)(ref.get());
}
};

struct DimFunctorDiffType
{
// the dimension of the mesh we expect to see
char operator()(const PointMesh&) const
{
spdlog::info("Mesh!");
return 0;
}
int operator()(const EdgeMesh&) const
{
spdlog::info("EdgeMesh");
return 1;
}
long operator()(const TriMesh&) const
{
spdlog::info("TriMesh");
return 2;
}
size_t operator()(const TetMesh&) const
{
spdlog::info("TetMesh");
return 3;
}
};

} // namespace
Expand All @@ -61,4 +94,14 @@ TEST_CASE("test_multi_mesh_print_visitor", "[multimesh][2D]")
CHECK(std::visit(DimFunctor{}, trivar) == 2);
CHECK(std::visit(DimFunctor{}, trimvar) == 2);
CHECK(std::visit(DimFunctor{}, tetvar) == 3);


spdlog::info("Running!");
wmtk::multimesh::utils::BasicMeshVariantRunner runner(DimFunctorDiffType{});
runner.run(mesh);
runner.run(tetmesh);

CHECK(runner.return_data.get(trimesh) == 2);
CHECK(runner.return_data.get(tetmesh) == 3);
}

Loading

0 comments on commit 32fbca0

Please sign in to comment.