
Commit b511b9f
json
hejunchao committed Aug 29, 2023
1 parent cd64295 commit b511b9f
Showing 3 changed files with 34 additions and 24 deletions.
23 changes: 12 additions & 11 deletions tests/kernels/test_gather.cpp
@@ -37,15 +37,19 @@ class GatherTest : public KernelTest,
.expect("create tensor failed");
init_tensor(input);

- int64_t indices_array[] = {0, 0, 1, 1};
- indices = hrt::create(dt_int64, {2, 2},
+ int64_t indices_array[] = {0, 0, -1, -1};
+ indices = hrt::create(dt_int64, {4},
{reinterpret_cast<gsl::byte *>(indices_array),
sizeof(indices_array)},
true, host_runtime_tensor::pool_cpu_only)
.expect("create tensor failed");

- batchDims_value = value;
- int64_t batchDims_array[1] = {value};
+ batchDims_value = value >= 0
+     ? (size_t)value >= shape.size() ? -1 : value
+     : -(size_t)value > shape.size() ? -1
+                                     : value;
+
+ int64_t batchDims_array[1] = {batchDims_value};
batchDims = hrt::create(dt_int64, dims_t{1},
{reinterpret_cast<gsl::byte *>(batchDims_array),
sizeof(batchDims_array)},
@@ -68,13 +72,10 @@ INSTANTIATE_TEST_SUITE_P(
dt_int8, dt_int16, dt_uint8, dt_uint16,
dt_uint32, dt_float16, dt_float64,
dt_bfloat16, dt_boolean),
- testing::Values(dims_t{
-     2,
-     2} /*, dims_t{3, 5},
-           dims_t{2, 3, 1}, dims_t{5, 7, 5},
-           dims_t{5, 4, 3, 2}, dims_t{5, 5, 7, 7},
-           dims_t{2, 3, 3, 5}*/),
- testing::Values(-1, 0, 1)));
+ testing::Values(dims_t{2, 3, 5, 7}, dims_t{2, 2},
+                 dims_t{2, 3, 1}, dims_t{5, 5, 7, 7},
+                 dims_t{11}),
+ testing::Values(-1, 0, 1, -2, -3, 2, 3, -4)));

TEST_P(GatherTest, gather) {
auto input_ort = runtime_tensor_2_ort_tensor(input);
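The new batch-dims handling folds any out-of-range value back to -1 before the batchDims tensor is built, which appears intended to keep the widened parameter list testing::Values(-1, 0, 1, -2, -3, 2, 3, -4) within range for whatever rank is being tested. A minimal standalone sketch of that nested ternary (illustrative only; normalize_batch_dims is a hypothetical helper, not part of the nncase test harness):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Mirrors the nested ternary in the diff: non-negative values must be
// smaller than the rank, negative values may reach back at most `rank`
// dimensions; anything else falls back to -1.
static int64_t normalize_batch_dims(int64_t value, size_t rank) {
    return value >= 0 ? ((size_t)value >= rank ? -1 : value)
                      : (-(size_t)value > rank ? -1 : value);
}

int main() {
    std::vector<int64_t> shape = {2, 3, 5, 7}; // rank 4, one of the new test shapes
    for (int64_t v : {-4, -3, -2, -1, 0, 1, 2, 3})
        std::printf("value %2lld -> batchDims %2lld\n", (long long)v,
                    (long long)normalize_batch_dims(v, shape.size()));
    return 0;
}
```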
31 changes: 18 additions & 13 deletions tests/kernels/test_get_item.cpp
@@ -22,16 +22,21 @@
#include <nncase/runtime/stackvm/opcode.h>
#include <ortki/operators.h>

+ #define TEST_CASE_NAME "test_get_item"
+
using namespace nncase;
using namespace nncase::runtime;
using namespace ortki;

class GetItemTest
: public KernelTest,
- public ::testing::TestWithParam<std::tuple<nncase::typecode_t, dims_t>> {
+ public ::testing::TestWithParam<std::tuple<int>> {
public:
void SetUp() override {
- auto &&[typecode, l_shape] = GetParam();
+ READY_SUBCASE()
+
+ auto l_shape = GetShapeArray("lhs_shape");
+ auto typecode = GetDataType("lhs_type");

input =
hrt::create(typecode, l_shape, host_runtime_tensor::pool_cpu_only)
@@ -46,8 +51,7 @@ class GetItemTest
};

INSTANTIATE_TEST_SUITE_P(get_item, GetItemTest,
- testing::Combine(testing::Values(dt_float32),
-                  testing::Values(dims_t{1})));
+ testing::Combine(testing::Range(0, MAX_CASE_NUM)));

TEST_P(GetItemTest, get_item) {

@@ -62,19 +66,11 @@ TEST_P(GetItemTest, get_item) {
true, host_runtime_tensor::pool_cpu_only)
.expect("create tensor failed");

- int64_t shape_ort[] = {1};
- auto shape = hrt::create(dt_int64, {1},
-                          {reinterpret_cast<gsl::byte *>(shape_ort),
-                           sizeof(shape_ort)},
-                          true, host_runtime_tensor::pool_cpu_only)
-                  .expect("create tensor failed");
-
auto get_item_output =
kernels::stackvm::get_item(input.impl(), index.impl())
.expect("get_item failed");

- auto output = kernels::stackvm::reshape(get_item_output, shape.impl())
-                   .expect("get_item failed");
+ auto output = get_item_output;
runtime_tensor actual(output.as<tensor>().expect("as tensor failed"));

bool result = is_same_tensor(expected, actual) ||
@@ -92,6 +88,15 @@ }
}

int main(int argc, char *argv[]) {
+ READY_TEST_CASE_GENERATE()
+ FOR_LOOP(lhs_shape, j)
+ FOR_LOOP(lhs_type, i)
+ SPLIT_ELEMENT(lhs_shape, j)
+ SPLIT_ELEMENT(lhs_type, i)
+ WRITE_SUB_CASE()
+ FOR_LOOP_END()
+ FOR_LOOP_END()
+
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
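The get_item test now pulls its parameters from tests/kernels/test_get_item.json (shown below) through the kernel-test macros (READY_SUBCASE, FOR_LOOP, SPLIT_ELEMENT, WRITE_SUB_CASE), and INSTANTIATE_TEST_SUITE_P simply iterates testing::Range(0, MAX_CASE_NUM). A rough, framework-free sketch of that pattern, with SubCase and the loops invented purely for illustration:

```cpp
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

// Illustrative stand-in for one generated sub-case.
struct SubCase {
    std::vector<int64_t> lhs_shape;
    std::string lhs_type;
};

int main() {
    // Values mirroring tests/kernels/test_get_item.json.
    std::vector<std::vector<int64_t>> shapes = {{1}};
    std::vector<std::string> types = {"dt_float32"};

    // Cross product: every shape paired with every type becomes one sub-case,
    // roughly what the nested FOR_LOOP / WRITE_SUB_CASE calls write out.
    std::vector<SubCase> cases;
    for (const auto &s : shapes)
        for (const auto &t : types)
            cases.push_back({s, t});

    // MAX_CASE_NUM would then track the number of generated sub-cases, and each
    // GetItemTest instance presumably looks up its case by GetParam() in SetUp().
    for (size_t i = 0; i < cases.size(); ++i)
        std::printf("case %zu: lhs_type=%s, lhs_shape has %zu dims\n", i,
                    cases[i].lhs_type.c_str(), cases[i].lhs_shape.size());
    return 0;
}
```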
4 changes: 4 additions & 0 deletions tests/kernels/test_get_item.json
@@ -0,0 +1,4 @@
+ {
+ "lhs_shape":[[1]],
+ "lhs_type":["dt_float32"]
+ }
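Since the sub-cases presumably come from the cross product of these lists, widening the JSON directly scales coverage. A purely hypothetical expansion (shapes and types borrowed from the gather test above; whether every combination is meaningful for get_item would still need checking):

```json
{
    "lhs_shape": [[1], [2, 2], [2, 3, 5, 7]],
    "lhs_type": ["dt_float32", "dt_float16"]
}
```

That would yield 3 × 2 = 6 sub-cases in place of the current single one.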
