Skip to content

Commit

Permalink
legacy check compat., add bad_arg tests (#1218)
Browse files Browse the repository at this point in the history
  • Loading branch information
TorreZuk committed Apr 28, 2022
1 parent ae86376 commit c922737
Show file tree
Hide file tree
Showing 9 changed files with 352 additions and 41 deletions.
47 changes: 30 additions & 17 deletions clients/gtest/trtri_gtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@

namespace
{

enum trtri_test_type
{
TRTRI,
TRTRI_BATCHED,
TRTRI_STRIDED_BATCHED,
};

// By default, this test does not apply to any types.
// The unnamed second parameter is used for enable_if_t below.
template <typename, typename = void>
Expand All @@ -53,23 +61,22 @@ namespace
{
if(!strcmp(arg.function, "trtri"))
testing_trtri<T>(arg);
else if(!strcmp(arg.function, "trtri_bad_arg"))
testing_trtri_bad_arg<T>(arg);
else if(!strcmp(arg.function, "trtri_batched"))
testing_trtri_batched<T>(arg);
else if(!strcmp(arg.function, "trtri_batched_bad_arg"))
testing_trtri_batched_bad_arg<T>(arg);
else if(!strcmp(arg.function, "trtri_strided_batched"))
testing_trtri_strided_batched<T>(arg);
else if(!strcmp(arg.function, "trtri_strided_batched_bad_arg"))
testing_trtri_strided_batched_bad_arg<T>(arg);
else
FAIL() << "Internal error: Test called with unknown function: " << arg.function;
}
};

enum trtri_kind
{
trtri_k,
trtri_batched_k,
trtri_strided_batched_k,
};

template <trtri_kind K>
template <trtri_test_type K>
struct trtri_template : RocBLAS_Test<trtri_template<K>, trtri_testing>
{
// Filter for which types apply to this suite
Expand All @@ -81,12 +88,14 @@ namespace
// Filter for which functions apply to this suite
static bool function_filter(const Arguments& arg)
{
if(K == trtri_k)
return !strcmp(arg.function, "trtri");
else if(K == trtri_batched_k)
return !strcmp(arg.function, "trtri_batched");
if(K == TRTRI)
return !strcmp(arg.function, "trtri") || !strcmp(arg.function, "trtri_bad_arg");
else if(K == TRTRI_BATCHED)
return !strcmp(arg.function, "trtri_batched")
|| !strcmp(arg.function, "trtri_batched_bad_arg");
else
return !strcmp(arg.function, "trtri_strided_batched");
return !strcmp(arg.function, "trtri_strided_batched")
|| !strcmp(arg.function, "trtri_strided_batched_bad_arg");
}

// Google Test name suffix based on parameters
Expand All @@ -95,7 +104,11 @@ namespace
RocBLAS_TestName<trtri_template> name(arg.name);
name << rocblas_datatype2string(arg.a_type) << '_' << (char)std::toupper(arg.uplo)
<< (char)std::toupper(arg.diag) << '_' << arg.N << '_' << arg.lda;
if(K != trtri_k)

if(K == TRTRI_STRIDED_BATCHED)
name << '_' << arg.stride_a;

if(K != TRTRI)
name << '_' << arg.batch_count;

if(arg.fortran)
Expand All @@ -107,23 +120,23 @@ namespace
}
};

using trtri = trtri_template<trtri_k>;
using trtri = trtri_template<TRTRI>;
TEST_P(trtri, blas3_tensile)
{
CATCH_SIGNALS_AND_EXCEPTIONS_AS_FAILURES(
rocblas_simple_dispatch<trtri_testing>(GetParam()));
}
INSTANTIATE_TEST_CATEGORIES(trtri);

using trtri_batched = trtri_template<trtri_batched_k>;
using trtri_batched = trtri_template<TRTRI_BATCHED>;
TEST_P(trtri_batched, blas3_tensile)
{
CATCH_SIGNALS_AND_EXCEPTIONS_AS_FAILURES(
rocblas_simple_dispatch<trtri_testing>(GetParam()));
}
INSTANTIATE_TEST_CATEGORIES(trtri_batched);

using trtri_strided_batched = trtri_template<trtri_strided_batched_k>;
using trtri_strided_batched = trtri_template<TRTRI_STRIDED_BATCHED>;
TEST_P(trtri_strided_batched, blas3_tensile)
{
CATCH_SIGNALS_AND_EXCEPTIONS_AS_FAILURES(
Expand Down
10 changes: 10 additions & 0 deletions clients/gtest/trtri_gtest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,16 @@ Tests:
matrix_size: *small_matrix_size_range
fortran: [ false, true ]

- name: trtri_bad_arg
category: pre_checkin
function:
- trtri_bad_arg: *single_double_precisions_complex_real
- trtri_batched_bad_arg: *single_double_precisions_complex_real
- trtri_strided_batched_bad_arg: *single_double_precisions_complex_real
uplo: L
diag: N
fortran: [ false, true ]

- name: trtri
category: pre_checkin
function: trtri
Expand Down
53 changes: 53 additions & 0 deletions clients/include/blas3/testing_trtri.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,59 @@
#include "unit.hpp"
#include "utility.hpp"

template <typename T>
void testing_trtri_bad_arg(const Arguments& arg)
{
auto rocblas_trtri_fn = arg.fortran ? rocblas_trtri<T, true> : rocblas_trtri<T, false>;

rocblas_local_handle handle{arg};

const rocblas_int N = 100;
const rocblas_int lda = 100;

const rocblas_fill uplo = rocblas_fill_upper;
const rocblas_diagonal diag = rocblas_diagonal_non_unit;

// Allocate device memory
device_matrix<T> dA(N, N, lda);
device_matrix<T> dinvA(N, N, lda);

// Check device memory allocation
CHECK_DEVICE_ALLOCATION(dA.memcheck());
CHECK_DEVICE_ALLOCATION(dinvA.memcheck());

EXPECT_ROCBLAS_STATUS(rocblas_trtri_fn(handle, rocblas_fill_full, diag, N, dA, lda, dinvA, lda),
rocblas_status_invalid_value);

EXPECT_ROCBLAS_STATUS(
rocblas_trtri_fn(handle, uplo, (rocblas_diagonal)rocblas_side_both, N, dA, lda, dinvA, lda),
rocblas_status_invalid_value);

// check for invalid sizes
EXPECT_ROCBLAS_STATUS(rocblas_trtri_fn(handle, uplo, diag, -1, dA, lda, dinvA, lda),
rocblas_status_invalid_size);

EXPECT_ROCBLAS_STATUS(rocblas_trtri_fn(handle, uplo, diag, N, dA, lda - 1, dinvA, lda),
rocblas_status_invalid_size);

EXPECT_ROCBLAS_STATUS(rocblas_trtri_fn(handle, uplo, diag, N, dA, lda, dinvA, lda - 1),
rocblas_status_invalid_size);

// nullptr tests
EXPECT_ROCBLAS_STATUS(rocblas_trtri_fn(nullptr, uplo, diag, N, dA, lda, dinvA, lda),
rocblas_status_invalid_handle);

EXPECT_ROCBLAS_STATUS(rocblas_trtri_fn(handle, uplo, diag, N, nullptr, lda, dinvA, lda),
rocblas_status_invalid_pointer);

EXPECT_ROCBLAS_STATUS(rocblas_trtri_fn(handle, uplo, diag, N, dA, lda, nullptr, lda),
rocblas_status_invalid_pointer);

// quick return: If N==0, then all pointers can be nullptr without error
EXPECT_ROCBLAS_STATUS(rocblas_trtri_fn(handle, uplo, diag, 0, nullptr, lda, nullptr, lda),
rocblas_status_success);
}

template <typename T>
void testing_trtri(const Arguments& arg)
{
Expand Down
117 changes: 117 additions & 0 deletions clients/include/blas3/testing_trtri_batched.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,123 @@
#include "unit.hpp"
#include "utility.hpp"

template <typename T>
void testing_trtri_batched_bad_arg(const Arguments& arg)
{
auto rocblas_trtri_batched_fn
= arg.fortran ? rocblas_trtri_batched<T, true> : rocblas_trtri_batched<T, false>;

rocblas_local_handle handle{arg};

const rocblas_int N = 100;
const rocblas_int lda = 100;
const rocblas_int batch_count = 2;

const rocblas_fill uplo = rocblas_fill_upper;
const rocblas_diagonal diag = rocblas_diagonal_non_unit;

// Allocate device memory
device_batch_matrix<T> dA(N, N, lda, batch_count);
device_batch_matrix<T> dinvA(N, N, lda, batch_count);

// Check device memory allocation
CHECK_DEVICE_ALLOCATION(dA.memcheck());
CHECK_DEVICE_ALLOCATION(dinvA.memcheck());

EXPECT_ROCBLAS_STATUS(rocblas_trtri_batched_fn(handle,
rocblas_fill_full,
diag,
N,
dA.ptr_on_device(),
lda,
dinvA.ptr_on_device(),
lda,
batch_count),
rocblas_status_invalid_value);

EXPECT_ROCBLAS_STATUS(rocblas_trtri_batched_fn(handle,
uplo,
(rocblas_diagonal)rocblas_side_both,
N,
dA.ptr_on_device(),
lda,
dinvA.ptr_on_device(),
lda,
batch_count),
rocblas_status_invalid_value);

// check for invalid sizes
EXPECT_ROCBLAS_STATUS(rocblas_trtri_batched_fn(handle,
uplo,
diag,
-1,
dA.ptr_on_device(),
lda,
dinvA.ptr_on_device(),
lda,
batch_count),
rocblas_status_invalid_size);

EXPECT_ROCBLAS_STATUS(
rocblas_trtri_batched_fn(
handle, uplo, diag, N, dA.ptr_on_device(), lda, dinvA.ptr_on_device(), lda, -1),
rocblas_status_invalid_size);

EXPECT_ROCBLAS_STATUS(rocblas_trtri_batched_fn(handle,
uplo,
diag,
N,
dA.ptr_on_device(),
lda - 1,
dinvA.ptr_on_device(),
lda,
batch_count),
rocblas_status_invalid_size);

EXPECT_ROCBLAS_STATUS(rocblas_trtri_batched_fn(handle,
uplo,
diag,
N,
dA.ptr_on_device(),
lda,
dinvA.ptr_on_device(),
lda - 1,
batch_count),
rocblas_status_invalid_size);

// nullptr tests
EXPECT_ROCBLAS_STATUS(rocblas_trtri_batched_fn(nullptr,
uplo,
diag,
N,
dA.ptr_on_device(),
lda,
dinvA.ptr_on_device(),
lda,
batch_count),
rocblas_status_invalid_handle);

EXPECT_ROCBLAS_STATUS(
rocblas_trtri_batched_fn(
handle, uplo, diag, N, nullptr, lda, dinvA.ptr_on_device(), lda, batch_count),
rocblas_status_invalid_pointer);

EXPECT_ROCBLAS_STATUS(
rocblas_trtri_batched_fn(
handle, uplo, diag, N, dA.ptr_on_device(), lda, nullptr, lda, batch_count),
rocblas_status_invalid_pointer);

// quick return: If N==0, then all pointers can be nullptr without error
EXPECT_ROCBLAS_STATUS(
rocblas_trtri_batched_fn(handle, uplo, diag, 0, nullptr, lda, nullptr, lda, batch_count),
rocblas_status_success);

// quick return: If batch_count==0, then all pointers can be nullptr without error
EXPECT_ROCBLAS_STATUS(
rocblas_trtri_batched_fn(handle, uplo, diag, N, nullptr, lda, nullptr, lda, 0),
rocblas_status_success);
}

template <typename T>
void testing_trtri_batched(const Arguments& arg)
{
Expand Down
Loading

0 comments on commit c922737

Please sign in to comment.