Skip to content

Commit

Permalink
json
Browse files Browse the repository at this point in the history
  • Loading branch information
hejunchao committed Aug 23, 2023
1 parent 90144fe commit d6ec003
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 18 deletions.
57 changes: 39 additions & 18 deletions tests/kernels/test_prelu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,22 @@
#include <nncase/runtime/stackvm/opcode.h>
#include <ortki/operators.h>

#define TEST_CASE_NAME "test_prelu"

using namespace nncase;
using namespace nncase::runtime;
using namespace ortki;
using slope_t = itlib::small_vector<float, 4>;

class PreluTest : public KernelTest,
public ::testing::TestWithParam<
std::tuple<nncase::typecode_t, dims_t, slope_t>> {
public ::testing::TestWithParam<std::tuple<int>> {
public:
void SetUp() override {
auto &&[typecode, l_shape, slope_value] = GetParam();
READY_SUBCASE()

auto l_shape = GetShapeArray("lhs_shape");
auto typecode = GetDataType("lhs_type");
auto slope_value = GetSlopeArray("slope");

input =
hrt::create(typecode, l_shape, host_runtime_tensor::pool_cpu_only)
Expand All @@ -49,26 +54,30 @@ class PreluTest : public KernelTest,

void TearDown() override {}

slope_t GetSlopeArray(const char *key) {
assert(_document[key].IsArray());
Value &array = _document[key];
size_t arraySize = array.Size();
slope_t cArray(arraySize);
for (rapidjson::SizeType i = 0; i < arraySize; i++) {
if (array[i].IsFloat()) {
cArray[i] = array[i].GetFloat();
} else {
std::cout << "Invalid JSON format. Expected unsigned float "
"values in the array."
<< std::endl;
}
}
return cArray;
}

protected:
runtime_tensor input;
slope_t slope;
};

INSTANTIATE_TEST_SUITE_P(
Prelu, PreluTest,
testing::Combine(
testing::Values(dt_float32),
testing::Values(dims_t{1, 3, 16, 16}, dims_t{1}, dims_t{8, 8},
dims_t{1, 4, 16}, dims_t{1, 3, 24, 24}),
testing::Values(
slope_t{0.2f}, slope_t{0.1f}, slope_t{0.3f},
slope_t{0.2f, 0.1f, 0.3f, 0.2f, 0.1f, 0.3f, 0.2f, 0.1f,
0.3f, 0.2f, 0.1f, 0.3f, 0.2f, 0.1f, 0.3f, 0.2f,
0.1f, 0.3f, 0.2f, 0.1f, 0.3f, 0.2f, 0.1f, 0.3f},
slope_t{0.1f, 0.2f, 0.2f, 0.4f, 0.2f, 0.2f, 0.3f, 0.8f},
slope_t{0.1f, 0.2f, 0.2f, 0.2f, 0.1f, 0.2f, 0.2f, 0.4f, 0.1f, 0.2f,
0.2f, 0.8f, 0.2f, 0.12f, 0.2f, 0.21f},
slope_t{0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.1f, 0.2f})));
INSTANTIATE_TEST_SUITE_P(Prelu, PreluTest,
testing::Combine(testing::Range(0, MAX_CASE_NUM)));

TEST_P(PreluTest, Prelu) {
auto l_ort = runtime_tensor_2_ort_tensor(input);
Expand Down Expand Up @@ -113,6 +122,18 @@ TEST_P(PreluTest, Prelu) {
}

int main(int argc, char *argv[]) {
READY_TEST_CASE_GENERATE()
FOR_LOOP(lhs_shape, i)
FOR_LOOP(lhs_type, j)
FOR_LOOP(slope, k)
SPLIT_ELEMENT(lhs_shape, i)
SPLIT_ELEMENT(lhs_type, j)
SPLIT_ELEMENT(slope, k)
WRITE_SUB_CASE()
FOR_LOOP_END()
FOR_LOOP_END()
FOR_LOOP_END()

::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
5 changes: 5 additions & 0 deletions tests/kernels/test_prelu.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"lhs_shape":[[1, 3, 16, 16], [1, 3, 16], [8, 8], [1, 4, 16], [1], [1, 3, 24, 24]],
"lhs_type":["dt_float32"],
"slope":[[0.2], [0.1], [0.3], [0.2, 0.1, 0.3, 0.2, 0.1, 0.3, 0.2, 0.1, 0.3, 0.2, 0.1, 0.3, 0.2, 0.1, 0.3, 0.2, 0.1, 0.3, 0.2, 0.1, 0.3, 0.2, 0.1, 0.3], [0.1, 0.2, 0.2, 0.4, 0.2, 0.2, 0.3, 0.8], [0.1, 0.2, 0.2, 0.2, 0.1, 0.2, 0.2, 0.4, 0.1, 0.2, 0.2, 0.8, 0.2, 0.12, 0.2, 0.21], [0.1, 0.2, 0.3, 0.1, 0.2, 0.3, 0.1, 0.2]]
}

0 comments on commit d6ec003

Please sign in to comment.