Skip to content

Commit

Permalink
Merge branch 'master' into kmp5/feature/cp-bcd
Browse files Browse the repository at this point in the history
  • Loading branch information
kmp5VT authored Aug 17, 2023
2 parents dbd76aa + 33d2218 commit 4008a5e
Showing 1 changed file with 31 additions and 32 deletions.
63 changes: 31 additions & 32 deletions unittest/ztensor_cp_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ TEST_CASE("ZCP") {
using btas::CP_RALS;

// double epsilon = fmax(1e-10, std::numeric_limits<double>::epsilon());
double epsilon = 4e-5;
double epsilon = 1e-5;

ztensor Z3(3, 2, 4);
std::ifstream inp3(__dirname + "/z-mat3D.txt");
Expand Down Expand Up @@ -72,98 +72,97 @@ TEST_CASE("ZCP") {
std::complex<double> norm3 = sqrt(dot(Z3, Z3));
std::complex<double> norm32 = sqrt(dot(Z33, Z33));

zconv_class conv(1e-4);

zconv_class conv(1e-3);

// ALS tests
{
SECTION("ALS MODE = 3, Finite error") {
CP_ALS<ztensor, zconv_class> A1(Z3);
conv.set_norm(norm3.real());
double diff = A1.compute_error(conv, 1e-9, 1, 50, false, 0, 1e4, false, true);
double diff = A1.compute_error(conv, 1e-6, 1, 11,false,0,100);
CHECK(std::abs(diff) <= epsilon);
}
SECTION("ALS MODE = 3, Finite rank") {
CP_ALS<ztensor, zconv_class> A1(Z3);
conv.set_norm(norm3.real());
double diff = A1.compute_rank(99, conv);
double diff = A1.compute_rank(11, conv, 1, false, 0, 100);
CHECK(std::abs(diff) <= epsilon);
}
#if BTAS_ENABLE_TUCKER_CP_UT
SECTION("ALS MODE = 3, Tucker + CP") {
auto d = Z3;
btas::TUCKER_CP_ALS<ztensor, zconv_class> A1(d, 1e-3);
conv.set_norm(norm3.real());
double diff = A1.compute_rank(25, conv, 1, false, 0, 1e4, false, false, true);
double diff = A1.compute_rank(6, conv, 1, false, 0, 100);
CHECK(std::abs(diff) <= epsilon);
conv.verbose(false);
}
#endif

SECTION("ALS MODE = 4, Finite error") {
CP_ALS<ztensor, zconv_class> A1(Z4);
conv.set_norm(norm4.real());
double diff = A1.compute_error(conv, 1e-9, 1, 120, false, 0, 1e4, false, true);
double diff = A1.compute_error(conv, 1e-2, 1, 100, true, 57);
CHECK(std::abs(diff) <= epsilon);
}
SECTION("ALS MODE = 4, Finite rank") {
CP_ALS<ztensor, zconv_class> A1(Z4);
conv.set_norm(norm4.real());
double diff = A1.compute_rank(120, conv);
double diff = A1.compute_rank(57, conv, 1, true, 57);
CHECK(std::abs(diff) <= epsilon);
}
#if BTAS_ENABLE_TUCKER_CP_UT
SECTION("ALS MODE = 4, Tucker + CP") {
auto d = Z4;
btas::TUCKER_CP_ALS<ztensor, zconv_class> A1(d, 1e-3);
conv.set_norm(norm4.real());
double diff = A1.compute_rank(120, conv, 1, false, 0, 1e4, false, false, true);
double diff = A1.compute_rank(58, conv, 1, true, 58);
CHECK(std::abs(diff) <= epsilon);
}
#endif
}
// RALS TESTS
{
SECTION("RALS MODE = 3, Finite rank"){
SECTION("RALS MODE = 3, Finite rank") {
CP_RALS<ztensor, zconv_class> A1(Z3);
conv.set_norm(norm3.real());
double diff =A1.compute_rank(20, conv, 1, false, 0, 100, false, false, true);
double diff = A1.compute_rank(12, conv);
CHECK(std::abs(diff) <= epsilon);
}
SECTION("RALS MODE = 3, Finite error"){
SECTION("RALS MODE = 3, Finite error") {
CP_RALS<ztensor, zconv_class> A1(Z3);
conv.set_norm(norm3.real());
double diff = A1.compute_error(conv, 1e-9, 1, 30, false, 0, 1e4, false, true);
double diff = A1.compute_error(conv, 1e-2, 1, 13, true, 12);
CHECK(std::abs(diff) <= epsilon);
}
#if BTAS_ENABLE_TUCKER_CP_UT
SECTION("RALS MODE = 3, Tucker + CP"){
SECTION("RALS MODE = 3, Tucker + CP") {
auto d = Z3;
btas::TUCKER_CP_RALS<ztensor, zconv_class > A1(d, 1e-3);
btas::TUCKER_CP_RALS<ztensor, zconv_class> A1(d, 1e-3);
conv.set_norm(norm3.real());
double diff = A1.compute_rank(35, conv, 1, false, 0, 100, false, false, true);
double diff = A1.compute_rank(13, conv, 1, true, 13);
CHECK(std::abs(diff) <= epsilon);
}
#endif
SECTION("RALS MODE = 4, Finite rank"){
CP_RALS<ztensor, zconv_class> A1(Z4);
conv.set_norm(norm4.real());
double diff = A1.compute_rank(120, conv);
CHECK(std::abs(diff) <= epsilon);
}
SECTION("RALS MODE = 4, Finite error"){
SECTION("RALS MODE = 4, Finite rank") {
CP_RALS<ztensor, zconv_class> A1(Z4);
conv.set_norm(norm4.real());
double diff = A1.compute_error(conv, 1e-9, 1, 120);
double diff = A1.compute_rank(65, conv, 1, true, 65);
CHECK(std::abs(diff) <= epsilon);
}
SECTION("RALS MODE = 4, Finite error"){
CP_RALS<ztensor, zconv_class> A1(Z4);
conv.set_norm(norm4.real());
double diff = A1.compute_error(conv, 1e-2, 1, 67, true, 65);
CHECK(std::abs(diff) <= epsilon);
}
#if BTAS_ENABLE_TUCKER_CP_UT
SECTION("RALS MODE = 4, Tucker + CP"){
auto d = Z4;
btas::TUCKER_CP_RALS<ztensor, zconv_class > A1(d, 1e-3);
conv.set_norm(norm4.real());
double diff = A1.compute_rank(120, conv, 1, false, 0, 100, false, false, true);
CHECK(std::abs(diff) <= epsilon);
}
SECTION("RALS MODE = 4, Tucker + CP"){
auto d = Z4;
btas::TUCKER_CP_RALS<ztensor, zconv_class > A1(d, 1e-3);
conv.set_norm(norm4.real());
double diff = A1.compute_rank(67, conv, 1, true, 67);
CHECK(std::abs(diff) <= epsilon);
}
#endif
}
}
Expand Down

0 comments on commit 4008a5e

Please sign in to comment.