Skip to content

Commit

Permalink
Integer determinant calculation (#80)
Browse files Browse the repository at this point in the history
* Add integer determinant overload

* Add typetraits and make template more readable

* Add new swap function and change naming.

* Added determinant specilization via custom type traits

* Removed unnecessary include

* Update impl/NotSoBasicLinearAlgebra.h

Co-authored-by: tomstewart89 <[email protected]>

---------

Co-authored-by: Nils Mueller <[email protected]>
Co-authored-by: tomstewart89 <[email protected]>
  • Loading branch information
3 people authored Jun 10, 2024
1 parent 0080e10 commit 55b01ed
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 14 deletions.
1 change: 1 addition & 0 deletions BasicLinearAlgebra.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,5 +177,6 @@ using DownCast = MatrixBase<DerivedType, DerivedType::Rows, DerivedType::Cols, t

} // namespace BLA

#include "impl/Types.h"
#include "impl/BasicLinearAlgebra.h"
#include "impl/NotSoBasicLinearAlgebra.h"
88 changes: 74 additions & 14 deletions impl/NotSoBasicLinearAlgebra.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,20 @@

namespace BLA
{
template <typename T>
inline void bla_swap(T &a, T &b)
template <typename ParentType, typename Dtype>
void Swap(MatrixBase<ParentType, ParentType::Rows, ParentType::Cols, Dtype> &A,
MatrixBase<ParentType, ParentType::Rows, ParentType::Cols, Dtype> &B)
{
T tmp = a;
a = b;
b = tmp;
Dtype tmp;
for (int i = 0; i < ParentType::Rows; i++)
{
for (int j = 0; j < ParentType::Cols; j++)
{
tmp = A(i, j);
A(i, j) = B(i, j);
B(i, j) = tmp;
}
}
}

template <typename ParentTypeA, typename ParentTypeB, int Cols>
Expand Down Expand Up @@ -135,14 +143,19 @@ LUDecomposition<ParentType> LUDecompose(MatrixBase<ParentType, Dim, Dim, typenam

if (j != argmax)
{
for (int k = 0; k < Dim; ++k)
{
bla_swap(A(argmax, k), A(j, k));
}
auto row_argmax = A.Row(argmax);
auto row_j = A.Row(j);
Swap(row_argmax, row_j);

decomp.parity = -decomp.parity;

bla_swap(idx[j], idx[argmax]);
// swap indices
{
auto tmp = idx[j];
idx[j] = idx[argmax];
idx[argmax] = tmp;
}

row_scale[argmax] = row_scale[j];
}

Expand Down Expand Up @@ -320,14 +333,16 @@ Matrix<Dim, Dim, typename ParentType::DType> Inverse(
return out;
}

template <typename ParentType, int Dim>
typename ParentType::DType Determinant(const MatrixBase<ParentType, Dim, Dim, typename ParentType::DType> &A)
// LU-Decomposition only works for floating point numbers. Use Bareiss algorithm for (signed) integer types.
template <typename ParentType, typename Dtype, int Dim>
typename Types::enable_if<Types::is_floating_point<Dtype>::value, Dtype>::type
DeterminantLUDecomposition(const MatrixBase<ParentType, Dim, Dim, Dtype> &A)
{
Matrix<Dim, Dim, typename ParentType::DType> A_copy = A;
Matrix<Dim, Dim, Dtype> A_copy = A;

auto decomp = LUDecompose(A_copy);

typename ParentType::DType det = decomp.parity;
Dtype det = decomp.parity;

for (int i = 0; i < Dim; ++i)
{
Expand All @@ -337,6 +352,51 @@ typename ParentType::DType Determinant(const MatrixBase<ParentType, Dim, Dim, ty
return det;
}

// Bareiss algorithm works for all (signed) types, but for floating-point numbers LU-Decomposition is faster.
template <typename ParentType, typename Dtype, int Dim>
typename Types::enable_if<Types::is_signed<Dtype>::value, Dtype>::type
DeterminantBareissAlgorithm(const MatrixBase<ParentType, Dim, Dim, Dtype> &A)
{
Matrix<Dim, Dim, Dtype> A_copy = A;

int sign = 1;
Dtype prev = 1;

for (int i = 0; i < Dim; i++)
{
if (A_copy(i, i) == 0)
{
int j = i + 1;
for (; j < Dim; j++)
{
if (A_copy(j, i) != 0) break;
}
if (j == Dim) return 0;
auto row_i = A_copy.Row(i);
auto row_j = A_copy.Row(j);
Swap(row_i, row_j);
sign = - sign;
}
for (int j = i + 1; j < Dim; j++)
{
for (int k = i + 1; k < Dim; k++)
{
A_copy(j, k) = (A_copy(j, k) * A_copy(i, i) - A_copy(j, i) * A_copy(i, k)) / prev;
}
}
prev = A_copy(i, i);
}
return sign * A_copy(Dim - 1, Dim - 1);
}

template <typename ParentType, typename Dtype, int Dim>
typename Types::enable_if<Types::is_floating_point<Dtype>::value, Dtype>::type
Determinant(const MatrixBase<ParentType, Dim, Dim, Dtype> &A) { return DeterminantLUDecomposition(A); }

template <typename ParentType, typename Dtype, int Dim>
typename Types::enable_if<Types::is_signed_integer<Dtype>::value, Dtype>::type
Determinant(const MatrixBase<ParentType, Dim, Dim, Dtype> &A) { return DeterminantBareissAlgorithm(A); }

template <typename DerivedType>
typename DerivedType::DType Norm(const DownCast<DerivedType> &A)
{
Expand Down
46 changes: 46 additions & 0 deletions impl/Types.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#pragma once

namespace BLA
{
// This namespace exists because the header "typetraits" is not implemented in every Arduino environment.
namespace Types
{
template<class T, class U> struct is_same { static constexpr bool value = false; };
template<class T> struct is_same<T, T> { static constexpr bool value = true; };

template<class T> struct remove_const { typedef T type; };
template<class T> struct remove_const<const T> { typedef T type; };

template<class T>
struct is_floating_point
{
static constexpr bool value =
is_same<float, typename remove_const<T>::type>::value ||
is_same<double, typename remove_const<T>::type>::value ||
is_same<long double, typename remove_const<T>::type>::value;
};

template<class T>
struct is_signed_integer
{
static constexpr bool value =
is_same<signed char, typename remove_const<T>::type>::value ||
is_same<signed short, typename remove_const<T>::type>::value ||
is_same<signed int, typename remove_const<T>::type>::value ||
is_same<signed long, typename remove_const<T>::type>::value ||
is_same<signed long long, typename remove_const<T>::type>::value;
};

template<class T>
struct is_signed
{
static constexpr bool value =
is_floating_point<T>::value ||
is_signed_integer<T>::value;
};

template<bool, typename T = void> struct enable_if {};
template<typename T> struct enable_if<true, T> { typedef T type; };
}
} // namespace BLA

9 changes: 9 additions & 0 deletions test/test_linear_algebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,15 @@ TEST(Arithmetic, Determinant)
BLA::Matrix<3, 3> singular = {1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0};

EXPECT_FLOAT_EQ(Determinant(singular), 0.0);

BLA::Matrix<4, 4, int16_t> C = {8, 5, 5, 8, 3, 1, 3, 2, 1, 1, 3, 0, 3, 3, 5, 9};
int16_t det_C_expected = -140;

EXPECT_EQ(Determinant(C), det_C_expected);

BLA::Matrix<3, 3, int> singular_int = {1, 0, 0, 1, 0, 0, 1, 0, 0};

EXPECT_EQ(Determinant(singular_int), 0);
}

template <typename SparseMatA, typename SparseMatB, int OutTableSize = 100>
Expand Down

0 comments on commit 55b01ed

Please sign in to comment.