Skip to content

Commit

Permalink
fix celu's alpha
Browse files Browse the repository at this point in the history
  • Loading branch information
hejunchao committed Aug 25, 2023
1 parent 7dad301 commit 4b61e92
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 22 deletions.
2 changes: 1 addition & 1 deletion tests/kernels/test_lrn.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"lhs_shape":[[1, 3, 16, 16], [1, 3, 8, 8]],
"lhs_shape":[[1, 3, 16, 16], [1, 3, 8, 8], [1, 3, 24, 24], [1, 3, 4, 4]],
"lhs_type":["dt_float32"]
}
6 changes: 3 additions & 3 deletions tests/kernels/test_reduce_window2D.json
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
{
"lhs_type":["dt_float32"],
"lhs_shape":[[1, 3, 16, 16]],
"lhs_shape":[[1, 3, 16, 16], [1, 3, 16, 24]],
"dilations":[[1, 1]],
"filter": [[3, 3]],
"stride": [[1, 1]],
"filter": [[3, 3], [9, 9]],
"stride": [[1, 1], [2, 2]],
"onnxPads":[[1, 1, 1, 1], [0, 0, 0, 0]]
}
40 changes: 28 additions & 12 deletions tests/kernels/test_reverse_sequence.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ class ReverseSequenceTest : public KernelTest,

auto typecode = GetDataType("lhs_type");
auto l_shape = GetShapeArray("i_shape");
seqLens_array = GetAxesArray("seqLens");
batch_axis = GetNumber("batch_axis");
time_axis = 1;

input =
hrt::create(typecode, l_shape, host_runtime_tensor::pool_cpu_only)
Expand All @@ -47,6 +50,9 @@ class ReverseSequenceTest : public KernelTest,

protected:
runtime_tensor input;
axes_t seqLens_array;
int64_t batch_axis;
int64_t time_axis;
};

INSTANTIATE_TEST_SUITE_P(ReverseSequence, ReverseSequenceTest,
Expand All @@ -56,15 +62,19 @@ TEST_P(ReverseSequenceTest, ReverseSequence) {
auto l_ort = runtime_tensor_2_ort_tensor(input);

// expected
size_t seqLens_size = seqLens_array.size();
int64_t *seqLens_array_ptr =
(int64_t *)malloc(seqLens_size * sizeof(int64_t));
std::copy(seqLens_array.begin(), seqLens_array.end(), seqLens_array_ptr);
size_t size = 0;
int64_t seqLens_array[] = {1, 2, 3, 4};
auto seqLens = hrt::create(dt_int64, {4},
{reinterpret_cast<gsl::byte *>(seqLens_array),
sizeof(seqLens_array)},
true, host_runtime_tensor::pool_cpu_only)
.expect("create tensor failed");
auto seqLens =
hrt::create(dt_int64, {seqLens_size},
{reinterpret_cast<gsl::byte *>(seqLens_array_ptr),
seqLens_size * sizeof(int64_t)},
true, host_runtime_tensor::pool_cpu_only)
.expect("create tensor failed");
auto output_ort = ortki_ReverseSequence(
l_ort, runtime_tensor_2_ort_tensor(seqLens), 1, 0);
l_ort, runtime_tensor_2_ort_tensor(seqLens), batch_axis, time_axis);
void *ptr_ort = tensor_buffer(output_ort, &size);
dims_t shape(tensor_rank(output_ort));
tensor_shape(output_ort, reinterpret_cast<int64_t *>(shape.data()));
Expand All @@ -74,14 +84,14 @@ TEST_P(ReverseSequenceTest, ReverseSequence) {
.expect("create tensor failed");

// actual
int64_t batch_axis_array[] = {1};
int64_t batch_axis_array[] = {batch_axis};
auto batch_axis =
hrt::create(dt_int64, {1},
{reinterpret_cast<gsl::byte *>(batch_axis_array),
sizeof(batch_axis_array)},
true, host_runtime_tensor::pool_cpu_only)
.expect("create tensor failed");
int64_t time_axis_array[] = {0};
int64_t time_axis_array[] = {time_axis};
auto time_axis =
hrt::create(dt_int64, {1},
{reinterpret_cast<gsl::byte *>(time_axis_array),
Expand All @@ -104,19 +114,25 @@ TEST_P(ReverseSequenceTest, ReverseSequence) {
print_runtime_tensor(expected);
}

// compare
// compare
EXPECT_TRUE(result);
}

int main(int argc, char *argv[]) {
READY_TEST_CASE_GENERATE()
FOR_LOOP(lhs_type, i)
FOR_LOOP(i_shape, j)
SPLIT_ELEMENT(lhs_type, i)
FOR_LOOP(lhs_type, i)
FOR_LOOP(seqLens, k)
FOR_LOOP(batch_axis, l)
SPLIT_ELEMENT(i_shape, j)
SPLIT_ELEMENT(lhs_type, i)
SPLIT_ELEMENT(seqLens, k)
SPLIT_ELEMENT(batch_axis, l)
WRITE_SUB_CASE()
FOR_LOOP_END()
FOR_LOOP_END()
FOR_LOOP_END()
FOR_LOOP_END()

::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
Expand Down
6 changes: 4 additions & 2 deletions tests/kernels/test_reverse_sequence.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
{
"i_shape":[[4, 4]],
"lhs_type":["dt_float32"]
"i_shape":[[2, 4, 2, 2]],
"lhs_type":["dt_float32"],
"seqLens":[[1, 1], [1, 2], [2, 2], [3, 3]],
"batch_axis":[0]
}
8 changes: 4 additions & 4 deletions tests/kernels/test_slice.json
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
{
"lhs_type":["dt_int32"],
"input_shape":[[2, 3, 4, 5], [1, 4, 5, 6], [1, 1, 1, 120], [2, 2, 5, 6], [1, 1, 2, 60]],
"value1": [[0, 0, 0, 0]],
"value2": [[1, 1, 1, 5]],
"value3": [[0, 1, 2, 3]],
"value4": [[1, 1, 1, 1]]
"value1": [[0, 0, 0, 0], [1, 1, 1, 1]],
"value2": [[1, 1, 1, 5], [2, 2, 2, 2]],
"value3": [[0, 1, 2, 3], [2, 3, 2, 3]],
"value4": [[1, 1, 1, 1], [1, 2, 3, 4]]
}

0 comments on commit 4b61e92

Please sign in to comment.