Skip to content

Commit

Permalink
Major speed improvement
Browse files Browse the repository at this point in the history
  • Loading branch information
garth-wells committed Nov 15, 2024
1 parent 1868dd3 commit eec2489
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 69 deletions.
152 changes: 86 additions & 66 deletions cpp/basix/dof-transformations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#include <span>
#include <tuple>

#include <iostream>

using namespace basix;

namespace stdex
Expand All @@ -39,8 +41,7 @@ int find_first_subentity(cell::type cell_type, cell::type entity_type)
{
const int edim = cell::topological_dimension(entity_type);
std::vector<cell::type> entities = cell::subentity_types(cell_type)[edim];
if (auto it = std::ranges::find(entities, entity_type);
it != entities.end())
if (auto it = std::ranges::find(entities, entity_type); it != entities.end())
{
return std::distance(entities.begin(), it);
}
Expand Down Expand Up @@ -93,9 +94,7 @@ mapinfo_t<T> get_mapinfo(cell::type cell_type)
{
mapinfo_t<T> mapinfo;
auto& data = mapinfo.try_emplace(cell::type::interval).first->second;
auto map = [](auto pt) -> std::array<T, 3> {
return {pt[1], pt[0], 0.0};
};
auto map = [](auto pt) -> std::array<T, 3> { return {pt[1], pt[0], 0.0}; };
stdex::mdarray<T, stdex::extents<std::size_t, 2, 2>> J(
stdex::extents<std::size_t, 2, 2>{}, {0., 1., 1., 0.});

Expand All @@ -109,9 +108,8 @@ mapinfo_t<T> get_mapinfo(cell::type cell_type)
{
mapinfo_t<T> mapinfo;
auto& data = mapinfo.try_emplace(cell::type::interval).first->second;
auto map = [](auto pt) -> std::array<T, 3> {
return {1 - pt[0], pt[1], 0};
};
auto map
= [](auto pt) -> std::array<T, 3> { return {1 - pt[0], pt[1], 0}; };
stdex::mdarray<T, stdex::extents<std::size_t, 2, 2>> J(
stdex::extents<std::size_t, 2, 2>{}, {-1., 0., 0., 1.});

Expand All @@ -126,9 +124,8 @@ mapinfo_t<T> get_mapinfo(cell::type cell_type)
mapinfo_t<T> mapinfo;
{
auto& data = mapinfo.try_emplace(cell::type::interval).first->second;
auto map = [](auto pt) -> std::array<T, 3> {
return {pt[0], pt[2], pt[1]};
};
auto map
= [](auto pt) -> std::array<T, 3> { return {pt[0], pt[2], pt[1]}; };
stdex::mdarray<T, stdex::extents<std::size_t, 3, 3>> J(
stdex::extents<std::size_t, 3, 3>{},
{1., 0., 0., 0., 0., 1., 0., 1., 0.});
Expand All @@ -143,9 +140,8 @@ mapinfo_t<T> get_mapinfo(cell::type cell_type)
{
auto& data = mapinfo.try_emplace(cell::type::triangle).first->second;
{
auto map = [](auto pt) -> std::array<T, 3> {
return {pt[2], pt[0], pt[1]};
};
auto map
= [](auto pt) -> std::array<T, 3> { return {pt[2], pt[0], pt[1]}; };
stdex::mdarray<T, stdex::extents<std::size_t, 3, 3>> J(
stdex::extents<std::size_t, 3, 3>{},
{0., 1., 0., 0., 0., 1., 1., 0., 0.});
Expand All @@ -157,9 +153,8 @@ mapinfo_t<T> get_mapinfo(cell::type cell_type)
data.push_back(std::tuple(map, J, detJ, K));
}
{
auto map = [](auto pt) -> std::array<T, 3> {
return {pt[0], pt[2], pt[1]};
};
auto map
= [](auto pt) -> std::array<T, 3> { return {pt[0], pt[2], pt[1]}; };
stdex::mdarray<T, stdex::extents<std::size_t, 3, 3>> J(
stdex::extents<std::size_t, 3, 3>{},
{1., 0., 0., 0., 0., 1., 0., 1., 0.});
Expand All @@ -179,9 +174,8 @@ mapinfo_t<T> get_mapinfo(cell::type cell_type)
mapinfo_t<T> mapinfo;
{
auto& data = mapinfo.try_emplace(cell::type::interval).first->second;
auto map = [](auto pt) -> std::array<T, 3> {
return {1 - pt[0], pt[1], pt[2]};
};
auto map = [](auto pt) -> std::array<T, 3>
{ return {1 - pt[0], pt[1], pt[2]}; };
stdex::mdarray<T, stdex::extents<std::size_t, 3, 3>> J(
stdex::extents<std::size_t, 3, 3>{},
{-1., 0., 0., 0., 1., 0., 0., 0., 1.});
Expand All @@ -196,9 +190,8 @@ mapinfo_t<T> get_mapinfo(cell::type cell_type)
{
auto& data = mapinfo.try_emplace(cell::type::quadrilateral).first->second;
{
auto map = [](auto pt) -> std::array<T, 3> {
return {1 - pt[1], pt[0], pt[2]};
};
auto map = [](auto pt) -> std::array<T, 3>
{ return {1 - pt[1], pt[0], pt[2]}; };
stdex::mdarray<T, stdex::extents<std::size_t, 3, 3>> J(
stdex::extents<std::size_t, 3, 3>{},
{0., 1., 0., -1., 0., 0., 0., 0., 1.});
Expand All @@ -210,9 +203,8 @@ mapinfo_t<T> get_mapinfo(cell::type cell_type)
data.push_back(std::tuple(map, J, detJ, K));
}
{
auto map = [](auto pt) -> std::array<T, 3> {
return {pt[1], pt[0], pt[2]};
};
auto map
= [](auto pt) -> std::array<T, 3> { return {pt[1], pt[0], pt[2]}; };
stdex::mdarray<T, stdex::extents<std::size_t, 3, 3>> J(
stdex::extents<std::size_t, 3, 3>{},
{0., 1., 0., 1., 0., 0., 0., 0., 1.});
Expand All @@ -231,9 +223,8 @@ mapinfo_t<T> get_mapinfo(cell::type cell_type)
mapinfo_t<T> mapinfo;
{
auto& data = mapinfo.try_emplace(cell::type::interval).first->second;
auto map = [](auto pt) -> std::array<T, 3> {
return {1 - pt[0], pt[1], pt[2]};
};
auto map = [](auto pt) -> std::array<T, 3>
{ return {1 - pt[0], pt[1], pt[2]}; };
stdex::mdarray<T, stdex::extents<std::size_t, 3, 3>> J(
stdex::extents<std::size_t, 3, 3>{},
{-1., 0., 0., 0., 1., 0., 0., 0., 1.});
Expand All @@ -246,9 +237,8 @@ mapinfo_t<T> get_mapinfo(cell::type cell_type)
{
auto& data = mapinfo.try_emplace(cell::type::triangle).first->second;
{
auto map = [](auto pt) -> std::array<T, 3> {
return {1 - pt[1] - pt[0], pt[0], pt[2]};
};
auto map = [](auto pt) -> std::array<T, 3>
{ return {1 - pt[1] - pt[0], pt[0], pt[2]}; };
stdex::mdarray<T, stdex::extents<std::size_t, 3, 3>> J(
stdex::extents<std::size_t, 3, 3>{},
{0., 1., 0., -1., -1., 0., 0., 0., 1.});
Expand All @@ -259,9 +249,8 @@ mapinfo_t<T> get_mapinfo(cell::type cell_type)
data.push_back(std::tuple(map, J, detJ, K));
}
{
auto map = [](auto pt) -> std::array<T, 3> {
return {pt[1], pt[0], pt[2]};
};
auto map
= [](auto pt) -> std::array<T, 3> { return {pt[1], pt[0], pt[2]}; };
stdex::mdarray<T, stdex::extents<std::size_t, 3, 3>> J(
stdex::extents<std::size_t, 3, 3>{},
{0., 1., 0., 1., 0., 0., 0., 0., 1.});
Expand All @@ -275,9 +264,8 @@ mapinfo_t<T> get_mapinfo(cell::type cell_type)
{
auto& data = mapinfo.try_emplace(cell::type::quadrilateral).first->second;
{
auto map = [](auto pt) -> std::array<T, 3> {
return {1 - pt[2], pt[1], pt[0]};
};
auto map = [](auto pt) -> std::array<T, 3>
{ return {1 - pt[2], pt[1], pt[0]}; };
stdex::mdarray<T, stdex::extents<std::size_t, 3, 3>> J(
stdex::extents<std::size_t, 3, 3>{},
{0., 0., 1., 0., 1., 0., -1., 0., 0.});
Expand All @@ -288,9 +276,8 @@ mapinfo_t<T> get_mapinfo(cell::type cell_type)
data.push_back(std::tuple(map, J, detJ, K));
}
{ // scope
auto map = [](auto pt) -> std::array<T, 3> {
return {pt[2], pt[1], pt[0]};
};
auto map
= [](auto pt) -> std::array<T, 3> { return {pt[2], pt[1], pt[0]}; };
stdex::mdarray<T, stdex::extents<std::size_t, 3, 3>> J(
stdex::extents<std::size_t, 3, 3>{},
{0., 0., 1., 0., 1., 0., 1., 0., 0.});
Expand All @@ -309,9 +296,8 @@ mapinfo_t<T> get_mapinfo(cell::type cell_type)
mapinfo_t<T> mapinfo;
{
auto& data = mapinfo.try_emplace(cell::type::interval).first->second;
auto map = [](auto pt) -> std::array<T, 3> {
return {1 - pt[0], pt[1], pt[2]};
};
auto map = [](auto pt) -> std::array<T, 3>
{ return {1 - pt[0], pt[1], pt[2]}; };
stdex::mdarray<T, stdex::extents<std::size_t, 3, 3>> J(
stdex::extents<std::size_t, 3, 3>{},
{-1., 0., 0., 0., 1., 0., 0., 0., 1.});
Expand All @@ -325,9 +311,8 @@ mapinfo_t<T> get_mapinfo(cell::type cell_type)
{
auto& data = mapinfo.try_emplace(cell::type::quadrilateral).first->second;
{
auto map = [](auto pt) -> std::array<T, 3> {
return {1 - pt[1], pt[0], pt[2]};
};
auto map = [](auto pt) -> std::array<T, 3>
{ return {1 - pt[1], pt[0], pt[2]}; };
stdex::mdarray<T, stdex::extents<std::size_t, 3, 3>> J(
stdex::extents<std::size_t, 3, 3>{},
{0., 1., 0., -1., 0., 0., 0., 0., 1.});
Expand All @@ -338,9 +323,8 @@ mapinfo_t<T> get_mapinfo(cell::type cell_type)
data.push_back(std::tuple(map, J, detJ, K));
}
{
auto map = [](auto pt) -> std::array<T, 3> {
return {pt[1], pt[0], pt[2]};
};
auto map
= [](auto pt) -> std::array<T, 3> { return {pt[1], pt[0], pt[2]}; };
stdex::mdarray<T, stdex::extents<std::size_t, 3, 3>> J(
stdex::extents<std::size_t, 3, 3>{},
{0., 1., 0., 1., 0., 0., 0., 0., 1.});
Expand All @@ -355,9 +339,8 @@ mapinfo_t<T> get_mapinfo(cell::type cell_type)
{
auto& data = mapinfo.try_emplace(cell::type::triangle).first->second;
{
auto map = [](auto pt) -> std::array<T, 3> {
return {1 - pt[2] - pt[0], pt[1], pt[0]};
};
auto map = [](auto pt) -> std::array<T, 3>
{ return {1 - pt[2] - pt[0], pt[1], pt[0]}; };
stdex::mdarray<T, stdex::extents<std::size_t, 3, 3>> J(
stdex::extents<std::size_t, 3, 3>{},
{0., 0., 1., 0., 1., 0., -1., 0., -1.});
Expand All @@ -368,9 +351,8 @@ mapinfo_t<T> get_mapinfo(cell::type cell_type)
data.push_back(std::tuple(map, J, detJ, K));
}
{
auto map = [](auto pt) -> std::array<T, 3> {
return {pt[2], pt[1], pt[0]};
};
auto map
= [](auto pt) -> std::array<T, 3> { return {pt[2], pt[1], pt[0]}; };
stdex::mdarray<T, stdex::extents<std::size_t, 3, 3>> J(
stdex::extents<std::size_t, 3, 3>{},
{0., 0., 1., 0., 1., 0., 1., 0., 0.});
Expand Down Expand Up @@ -426,7 +408,7 @@ std::pair<std::vector<T>, std::array<std::size_t, 2>> compute_transformation(
mdarray_t<T, 2> mapped_pts(pts.extents());
for (std::size_t p = 0; p < mapped_pts.extent(0); ++p)
{
auto mp = map_point(
std::array<T, 3> mp = map_point(
std::span(pts.data_handle() + p * pts.extent(1), pts.extent(1)));
for (std::size_t k = 0; k < mapped_pts.extent(1); ++k)
mapped_pts(p, k) = mp[k];
Expand All @@ -439,18 +421,56 @@ std::pair<std::vector<T>, std::array<std::size_t, 2>> compute_transformation(
mdspan_t<const T, 2> polyset_vals(polyset_vals_b.data(), polyset_shape[1],
polyset_shape[2]);

mdarray_t<T, 3> tabulated_data(npts, total_ndofs, vs);
std::vector<T> tabulated_data_b(npts * total_ndofs * vs);
mdspan_t<T, 3> tabulated_data(tabulated_data_b.data(), npts, total_ndofs, vs);
// mdarray_t<T, 3> tabulated_data(npts, total_ndofs, vs);

// std::vector<T> result_b(polyset_vals.extent(1) * coeffs.extent(0));
// mdspan_t<T, 2> result(result_b.data(), polyset_vals.extent(1),
// coeffs.extent(0));

std::vector<T> resultT_b(polyset_vals.extent(1) * coeffs.extent(0));
mdspan_t<T, 2> resultT(resultT_b.data(), coeffs.extent(0),
polyset_vals.extent(1));
// mdarray_t<T, 2> result(polyset_vals.extent(1), coeffs.extent(0));

// std::cout << "3: transform: " << vs << ", " << coeffs.extent(0) << ", "
// << polyset_vals.extent(1) << ", " << polyset_vals.extent(0)
// << std::endl;

std::vector<T> coeffs_b(coeffs.extent(0) * polyset_vals.extent(0));
mdspan_t<T, 2> _coeffs(coeffs_b.data(), coeffs.extent(0),
polyset_vals.extent(0));
for (std::size_t j = 0; j < vs; ++j)
{
mdarray_t<T, 2> result(polyset_vals.extent(1), coeffs.extent(0));
// std::fill(result_b.begin(), result_b.end(), 0);

for (std::size_t k0 = 0; k0 < coeffs.extent(0); ++k0)
for (std::size_t k1 = 0; k1 < polyset_vals.extent(1); ++k1)
for (std::size_t k2 = 0; k2 < polyset_vals.extent(0); ++k2)
result(k1, k0) += coeffs(k0, k2 + psize * j) * polyset_vals(k2, k1);
for (std::size_t k2 = 0; k2 < polyset_vals.extent(0); ++k2)
_coeffs(k0, k2) = coeffs(k0, k2 + psize * j);

// R^T = coeffs * polyset_vals
// for (std::size_t k0 = 0; k0 < coeffs.extent(0); ++k0) // big
// for (std::size_t k1 = 0; k1 < polyset_vals.extent(1); ++k1)
// for (std::size_t k2 = 0; k2 < polyset_vals.extent(0); ++k2) // big
// result(k1, k0) += _coeffs(k0, k2) * polyset_vals(k2, k1);
// result(k1, k0) += coeffs(k0, k2 + psize * j) * polyset_vals(k2, k1);

// r^t: coeffs.extent(0) x polyset_vals.extent(1) [k0, k1]
// c: coeffs.extent(1) x polyset_vals.extent(0) [k0, k2]
// p: polyset_vals.extent(0) x polyset_vals.extent(1) [k2, k1]

math::dot(_coeffs, polyset_vals, resultT);

// std::cout << "4: transform" << std::endl;

for (std::size_t k0 = 0; k0 < resultT.extent(1); ++k0)
for (std::size_t k1 = 0; k1 < resultT.extent(0); ++k1)
tabulated_data(k0, k1, j) = resultT(k1, k0);

for (std::size_t k0 = 0; k0 < result.extent(0); ++k0)
for (std::size_t k1 = 0; k1 < result.extent(1); ++k1)
tabulated_data(k0, k1, j) = result(k0, k1);
// for (std::size_t k0 = 0; k0 < result.extent(0); ++k0)
// for (std::size_t k1 = 0; k1 < result.extent(1); ++k1)
// tabulated_data(k0, k1, j) = result(k0, k1);
}

// push forward
Expand All @@ -460,7 +480,7 @@ std::pair<std::vector<T>, std::array<std::size_t, 2>> compute_transformation(
for (std::size_t i = 0; i < npts; ++i)
{
mdspan_t<const T, 2> tab(
tabulated_data.data()
tabulated_data_b.data()
+ i * tabulated_data.extent(1) * tabulated_data.extent(2),
tabulated_data.extent(1), tabulated_data.extent(2));

Expand Down
1 change: 1 addition & 0 deletions cpp/basix/finite-element.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <concepts>
#include <limits>
#include <numeric>

#define str_macro(X) #X
#define str(X) str_macro(X)

Expand Down
18 changes: 15 additions & 3 deletions test/test_lagrange.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,11 +951,23 @@ def test_dof_transformations_tetrahedron(degree):

@pytest.mark.parametrize(
"celltype",
[basix.CellType.interval, basix.CellType.triangle, basix.CellType.tetrahedron],
[
basix.CellType.interval,
basix.CellType.triangle,
basix.CellType.tetrahedron,
basix.CellType.quadrilateral,
basix.CellType.hexahedron,
],
)
@pytest.mark.parametrize(
"dtype",
[
np.float32,
np.float64,
],
)
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_lagrange_transform(celltype, dtype):
for p in range(1, 15):
for p in range(1, 12):
e = basix.create_element(
basix.ElementFamily.P, celltype, p, basix.LagrangeVariant.gll_warped, dtype=dtype
)
Expand Down

0 comments on commit eec2489

Please sign in to comment.