Skip to content

Commit

Permalink
NTT expand
Browse files Browse the repository at this point in the history
  • Loading branch information
liushiyun committed Nov 22, 2024
1 parent 638878e commit 0a179c6
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 42 deletions.
12 changes: 6 additions & 6 deletions ntt/test/benchmark_test/benchmark_ntt_expand.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace nncase::ntt {
}

template <typename T, size_t M, size_t P>
void benchmark_ntt_expand_NoPack(T init_low, T init_high) {
void benchmark_ntt_expand_nopack(T init_low, T init_high) {
std::string pack_mode = "NoPack";
constexpr size_t warmup_size = 10;
#if __riscv
Expand Down Expand Up @@ -50,7 +50,7 @@ void benchmark_ntt_expand_NoPack(T init_low, T init_high) {
}

template <typename T, size_t M, size_t N, size_t P>
void benchmark_ntt_expand_NoPack1(T init_low, T init_high) {
void benchmark_ntt_expand_nopack1(T init_low, T init_high) {
std::string pack_mode = "NoPack";
constexpr size_t warmup_size = 10;
#if __riscv
Expand Down Expand Up @@ -87,7 +87,7 @@ void benchmark_ntt_expand_NoPack1(T init_low, T init_high) {
}

template <typename T, size_t M, size_t N, size_t P, size_t VLEN>
void benchmark_ntt_expand_2D_pack(T init_low, T init_high) {
void benchmark_ntt_expand_pack(T init_low, T init_high) {
std::string pack_mode = "Pack";
constexpr size_t warmup_size = 10;
#if __riscv
Expand Down Expand Up @@ -137,18 +137,18 @@ int main(int argc, char *argv[]) {

constexpr size_t M1 = 1;
constexpr size_t P1 = 2;
benchmark_ntt_expand_NoPack<float, M1, P1>(-10.f, 10.f);
benchmark_ntt_expand_nopack<float, M1, P1>(-10.f, 10.f);

constexpr size_t M2 = 1024;
constexpr size_t N2 = 1;
constexpr size_t P2 = 2048;
benchmark_ntt_expand_NoPack1<float, M2, N2, P2>(-10.f, 10.f);
benchmark_ntt_expand_nopack1<float, M2, N2, P2>(-10.f, 10.f);

constexpr size_t M3 = 32;
constexpr size_t N3 = 1;
constexpr size_t P3 = 2;
constexpr size_t VLEN3 = 4;
benchmark_ntt_expand_2D_pack<float, M3, N3, P3, VLEN3>(0.0f, 1.0f);
benchmark_ntt_expand_pack<float, M3, N3, P3, VLEN3>(-10.f, 10.f);

return 0;
}
30 changes: 5 additions & 25 deletions ntt/test/ctest/test_ntt_expand.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,8 @@ TEST(ExpandTestFloat, NoPack) {
auto ort_input = NttTest::ntt2ort(*ntt_input);
int64_t target_shape[] = {M, K};
int64_t shape_size = 2;

// 创建一维的 shape_tensor 形状
int64_t shape[] = {shape_size};

// 创建 shape_tensor
auto shape_tensor = make_tensor(reinterpret_cast<void*>(target_shape), DataType_INT64, shape, 1);

// 调用 Expand 操作
auto ort_output = ortki_Expand(ort_input, shape_tensor);

// compare
Expand All @@ -68,35 +62,27 @@ TEST(ExpandTestFloat, NoPack1) {
float min_input = static_cast<float>(-10);
float max_input = static_cast<float>(10);

// 定义输入和输出张量类型
// init
using input_tensor_type = ntt::tensor<float, ntt::fixed_shape<M>>;
using output_tensor_type = ntt::tensor<float, ntt::fixed_shape<M, K>>;

// 初始化输入张量
std::unique_ptr<input_tensor_type> ntt_input(new input_tensor_type);
NttTest::init_tensor(*ntt_input, min_input, max_input);

// 执行 expand 操作
// ntt
std::unique_ptr<output_tensor_type> ntt_output1(new output_tensor_type);
ntt::expand(*ntt_input, *ntt_output1);

// 将输入张量转换为 ORT 格式
// ort
auto ort_input = NttTest::ntt2ort(*ntt_input);
int64_t target_shape[] = {M, K};
int64_t shape_size = 2;
int64_t shape[] = {shape_size};

// 创建 shape_tensor
auto shape_tensor = make_tensor(reinterpret_cast<void*>(target_shape), DataType_INT64, shape, 1);

// 调用 Expand 操作
auto ort_output = ortki_Expand(ort_input, shape_tensor);

// 将 ORT 输出转换回 NTT 格式
// compare
std::unique_ptr<output_tensor_type> ntt_output2(new output_tensor_type);
NttTest::ort2ntt(ort_output, *ntt_output2);

// 比较结果
EXPECT_TRUE(NttTest::compare_tensor(*ntt_output1, *ntt_output2));
}

Expand All @@ -122,18 +108,12 @@ TEST(ExpandTestFloat, Pack_M_K) {
int64_t target_shape[] = {32, 2};
int64_t shape_size = 2;
int64_t shape[] = {shape_size};

// 创建 shape_tensor
auto shape_tensor = make_tensor(reinterpret_cast<void*>(target_shape), DataType_INT64, shape, 1);

// 调用 Expand 操作
auto ort_output = ortki_Expand(ort_input, shape_tensor);

// 将 ORT 输出转换回 NTT 格式
std::unique_ptr<output_tensor_type> ntt_output2(new output_tensor_type);
NttTest::ort2ntt(ort_output, *ntt_output2);

// 比较结果
// compare
EXPECT_TRUE(NttTest::compare_tensor(*ntt_output1, *ntt_output2));
}

Expand Down
11 changes: 0 additions & 11 deletions run_docker.sh

This file was deleted.

0 comments on commit 0a179c6

Please sign in to comment.