Skip to content

Commit

Permalink
SmallMatrix: Structured binding support (#4189)
Browse files Browse the repository at this point in the history
  • Loading branch information
WeiqunZhang authored Oct 12, 2024
1 parent e64ffef commit 62c2a81
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
19 changes: 19 additions & 0 deletions Src/Base/AMReX_SmallMatrix.H
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <algorithm>
#include <initializer_list>
#include <iostream>
#include <tuple>
#include <type_traits>

namespace amrex {
Expand Down Expand Up @@ -388,6 +389,14 @@ namespace amrex {
return r;
}

template <int N, std::enable_if_t<(N<NRows*NCols),int> = 0>
[[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
constexpr T const& get () const { return m_mat[N]; }

template <int N, std::enable_if_t<(N<NRows*NCols),int> = 0>
[[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
constexpr T& get () { return m_mat[N]; }

private:
T m_mat[NRows*NCols];
};
Expand Down Expand Up @@ -447,6 +456,16 @@ namespace amrex {
using SmallRowVector = SmallMatrix<T,1,N,Order::F,StartIndex>;
}

template <class T, int NRows, int NCols, amrex::Order ORDER, int StartIndex>
struct std::tuple_size<amrex::SmallMatrix<T,NRows,NCols,ORDER,StartIndex> >
: std::integral_constant<std::size_t,NRows*NCols> {};

template <std::size_t N, class T, int NRows, int NCols, amrex::Order ORDER, int StartIndex>
struct std::tuple_element<N, amrex::SmallMatrix<T,NRows,NCols,ORDER,StartIndex> >
{
using type = T;
};

#endif

/*
Expand Down
31 changes: 31 additions & 0 deletions Tests/SmallMatrix/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,21 @@ int main (int argc, char* argv[])
b.setVal(-1);
AMREX_ALWAYS_ASSERT(a.dot(b) == -30);
}
{
SmallVector<int, 3> v{10,20,30};
auto const& [x,y,z] = v;
AMREX_ALWAYS_ASSERT(x == 10 && y == 20 && z == 30);

auto& [a,b,c] = v;
a = 100; b = 200; c = 300;
AMREX_ALWAYS_ASSERT(v[0] == 100 && v[1] == 200 && v[2] == 300);

auto const [i,j,k] = v;
AMREX_ALWAYS_ASSERT(i == 100 && j == 200 && k == 300);

auto [d,e,f] = v;
AMREX_ALWAYS_ASSERT(d == 100 && e == 200 && f == 300);
}

// 1-based indexing
{
Expand Down Expand Up @@ -271,5 +286,21 @@ int main (int argc, char* argv[])
b.setVal(-1);
AMREX_ALWAYS_ASSERT(a.dot(b) == -30);
}
{
SmallVector<int, 3, 1> v{10,20,30};
auto const& [x,y,z] = v;
AMREX_ALWAYS_ASSERT(x == 10 && y == 20 && z == 30);

auto& [a,b,c] = v;
a = 100; b = 200; c = 300;
AMREX_ALWAYS_ASSERT(v[1] == 100 && v[2] == 200 && v[3] == 300);

auto const [i,j,k] = v;
AMREX_ALWAYS_ASSERT(i == 100 && j == 200 && k == 300);

auto [d,e,f] = v;
AMREX_ALWAYS_ASSERT(d == 100 && e == 200 && f == 300);
}

amrex::Finalize();
}

0 comments on commit 62c2a81

Please sign in to comment.