From 0a179c69249bf3e3e8529a7951e00b68c6bc7a13 Mon Sep 17 00:00:00 2001 From: liushiyun Date: Fri, 22 Nov 2024 09:58:17 +0800 Subject: [PATCH] NTT expand --- .../benchmark_test/benchmark_ntt_expand.cpp | 12 ++++---- ntt/test/ctest/test_ntt_expand.cpp | 30 ++++--------------- run_docker.sh | 11 ------- 3 files changed, 11 insertions(+), 42 deletions(-) delete mode 100644 run_docker.sh diff --git a/ntt/test/benchmark_test/benchmark_ntt_expand.cpp b/ntt/test/benchmark_test/benchmark_ntt_expand.cpp index 773370b3de..dfb413a49e 100644 --- a/ntt/test/benchmark_test/benchmark_ntt_expand.cpp +++ b/ntt/test/benchmark_test/benchmark_ntt_expand.cpp @@ -13,7 +13,7 @@ namespace nncase::ntt { } template -void benchmark_ntt_expand_NoPack(T init_low, T init_high) { +void benchmark_ntt_expand_nopack(T init_low, T init_high) { std::string pack_mode = "NoPack"; constexpr size_t warmup_size = 10; #if __riscv @@ -50,7 +50,7 @@ void benchmark_ntt_expand_NoPack(T init_low, T init_high) { } template -void benchmark_ntt_expand_NoPack1(T init_low, T init_high) { +void benchmark_ntt_expand_nopack1(T init_low, T init_high) { std::string pack_mode = "NoPack"; constexpr size_t warmup_size = 10; #if __riscv @@ -87,7 +87,7 @@ void benchmark_ntt_expand_NoPack1(T init_low, T init_high) { } template -void benchmark_ntt_expand_2D_pack(T init_low, T init_high) { +void benchmark_ntt_expand_pack(T init_low, T init_high) { std::string pack_mode = "Pack"; constexpr size_t warmup_size = 10; #if __riscv @@ -137,18 +137,18 @@ int main(int argc, char *argv[]) { constexpr size_t M1 = 1; constexpr size_t P1 = 2; - benchmark_ntt_expand_NoPack(-10.f, 10.f); + benchmark_ntt_expand_nopack(-10.f, 10.f); constexpr size_t M2 = 1024; constexpr size_t N2 = 1; constexpr size_t P2 = 2048; - benchmark_ntt_expand_NoPack1(-10.f, 10.f); + benchmark_ntt_expand_nopack1(-10.f, 10.f); constexpr size_t M3 = 32; constexpr size_t N3 = 1; constexpr size_t P3 = 2; constexpr size_t VLEN3 = 4; - benchmark_ntt_expand_2D_pack(0.0f, 1.0f); + benchmark_ntt_expand_pack(-10.f, 10.f); return 0; } \ No newline at end of file diff --git a/ntt/test/ctest/test_ntt_expand.cpp b/ntt/test/ctest/test_ntt_expand.cpp index 1ca2ee0544..74ba54ead5 100644 --- a/ntt/test/ctest/test_ntt_expand.cpp +++ b/ntt/test/ctest/test_ntt_expand.cpp @@ -45,14 +45,8 @@ TEST(ExpandTestFloat, NoPack) { auto ort_input = NttTest::ntt2ort(*ntt_input); int64_t target_shape[] = {M, K}; int64_t shape_size = 2; - - // 创建一维的 shape_tensor 形状 int64_t shape[] = {shape_size}; - - // 创建 shape_tensor auto shape_tensor = make_tensor(reinterpret_cast(target_shape), DataType_INT64, shape, 1); - - // 调用 Expand 操作 auto ort_output = ortki_Expand(ort_input, shape_tensor); // compare @@ -68,35 +62,27 @@ TEST(ExpandTestFloat, NoPack1) { float min_input = static_cast(-10); float max_input = static_cast(10); - // 定义输入和输出张量类型 + // init using input_tensor_type = ntt::tensor>; using output_tensor_type = ntt::tensor>; - - // 初始化输入张量 std::unique_ptr ntt_input(new input_tensor_type); NttTest::init_tensor(*ntt_input, min_input, max_input); - // 执行 expand 操作 + // ntt std::unique_ptr ntt_output1(new output_tensor_type); ntt::expand(*ntt_input, *ntt_output1); - // 将输入张量转换为 ORT 格式 + // ort auto ort_input = NttTest::ntt2ort(*ntt_input); int64_t target_shape[] = {M, K}; int64_t shape_size = 2; int64_t shape[] = {shape_size}; - - // 创建 shape_tensor auto shape_tensor = make_tensor(reinterpret_cast(target_shape), DataType_INT64, shape, 1); - - // 调用 Expand 操作 auto ort_output = ortki_Expand(ort_input, shape_tensor); - // 将 ORT 输出转换回 NTT 格式 + // compare std::unique_ptr ntt_output2(new output_tensor_type); NttTest::ort2ntt(ort_output, *ntt_output2); - - // 比较结果 EXPECT_TRUE(NttTest::compare_tensor(*ntt_output1, *ntt_output2)); } @@ -122,18 +108,12 @@ TEST(ExpandTestFloat, Pack_M_K) { int64_t target_shape[] = {32, 2}; int64_t shape_size = 2; int64_t shape[] = {shape_size}; - - // 创建 shape_tensor auto shape_tensor = make_tensor(reinterpret_cast(target_shape), DataType_INT64, shape, 1); - - // 调用 Expand 操作 auto ort_output = ortki_Expand(ort_input, shape_tensor); - - // 将 ORT 输出转换回 NTT 格式 std::unique_ptr ntt_output2(new output_tensor_type); NttTest::ort2ntt(ort_output, *ntt_output2); - // 比较结果 + // compare EXPECT_TRUE(NttTest::compare_tensor(*ntt_output1, *ntt_output2)); } diff --git a/run_docker.sh b/run_docker.sh deleted file mode 100644 index 6cd0c1c5ad..0000000000 --- a/run_docker.sh +++ /dev/null @@ -1,11 +0,0 @@ -#!/bin/bash - -# 设置变量 -IMAGE_NAME="liushiyun/nncase:latest" -MOUNT_POINT="/nncase" -CURRENT_DIR=$(pwd) - -# 启动容器 -docker run -it \ - -v $CURRENT_DIR:$MOUNT_POINT \ - $IMAGE_NAME \ No newline at end of file