Skip to content

Commit

Permalink
Optimize NTT ops. (#1221)
Browse files Browse the repository at this point in the history
* Optimize log/rsqrt/cosh for rvv with ulp.

* add rsqrt precison version for x86

* change reduce test order for icache

* add some strategy for performance test

* remove the ugly code just now

* Add ulp ctest case for acos and optimize acos for riscv64.

* to avoid timeouts

* reduce repeat num for test

* change test mode for reduce

* add roofline for x86 reduce

* remove typo in test

* Add ulp test case for asin and optimize it with rvv.

* adjust mamul benchmark shape

* revise the value for pi as ulp

* update for rsqrt as ulp

* Optimize cos for rvv and update roofline.

* add acos for ulp x86

* update opt for asin

* add cos for x86 about ulp

* update roofline for matmul

* Optimize sin for rvv and update roofline.

* change to std sqrt for x86 as rsqrt ulp

* change the target for ulp test

* Update ctest cases for unary.

- Add acosh, but there is something wrong, debug it later.
- Add sqrt, x86 use _mm256_sqrt_ps.
- Modify ulp ref from ortki to c/c++ math library.
- cos/sin ulp test for x86/riscv64 only.

* change ortki sin to std::sin for ulp

* add clamp for ntt support

* add max min scalar vector version

* Add ctest and rvv roofline for reduce.

* change usless file

* add opt for max and min

* Add all ctest for reduce.

* Add reduce sum/max/min and optimize mean.

* add opt for pack m&n

* Modify ctest and add benchmark test for clamp.

* add special template for binary

* Apply code-format changes

* Add clamp into primitive_ops to optimize rvv.

* add unary special version template

* Apply code-format changes

* add clamp roofline for x86

* add ntt profiler func

* change style for Info

* add more info for profiler

* support markdown style

* update roofline info for reduce@x86

* add unroll attribute for gcc&clang

* change reality for compiler

* support more kernels

* unroll loop byhand

* add unroll num for x86 and riscv

* unrool loop

* update unroll to support 2 inputs

* change reality for loop unrool

* add support for unary

* Add gcc 14.2.0 and vlen config support for rvv.

* open unrool function for x86

* update roofline for x86

* update unary roofline considering ldst

* update roofline for x86

* revise bug for 2dvector

* Apply code-format changes

* revise change for special template

* Apply code-format changes

* Update rvv roofline for binary/unary/matmul.

* update roofline for x86 as fma

* adjust unroll num  as special situation

* revise tpyo

* better performance for x86

* Modify floor_mod and fix matmul outer_product for rvv.

* Apply code-format changes

* adjust x86 unroll by case

* more readable code

* Update to riscv64 gcc 14.2.0

* Use latest rvv impl of exp in stackvm.

* update roofline info for x86 pack k

* revise for compiler opt problem

* add volatile for matmul output

* Fix performance regression of both binary and unary for rvv.

* Optimize x86 inner_product

* change for pack k benchmark

* [ntt.x86] Remove unroll for outer_product & mma

* Apply code-format changes

* fallback to check if there is wrong

* [ntt.x86] Reorder mma from m,k to k,m

* recover reduce

* try ubuntu 22.04

* opt for pack K matmul

* opt for pack MK and K

* Apply code-format changes

* remove useless code

* Revert "try ubuntu 22.04"

This reverts commit ee71469.

* add tmate session.

* some opt for matmul

* Revert "add tmate session."

This reverts commit 801b8e0.

* Try to disable loop unroll to fix reduce abort.

* Apply code-format changes

* Remove redundant fp16 code.

---------

Co-authored-by: guodongliang <[email protected]>
Co-authored-by: uranus0515 <[email protected]>
Co-authored-by: zhangyang2057 <[email protected]>
Co-authored-by: sunnycase <[email protected]>
Co-authored-by: sunnycase <[email protected]>
  • Loading branch information
6 people authored Sep 7, 2024
1 parent 3eef84e commit b72ba28
Show file tree
Hide file tree
Showing 53 changed files with 4,690 additions and 2,383 deletions.
9 changes: 5 additions & 4 deletions .github/workflows/runtime-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ jobs:
run: |
conan install . --build=missing -s build_type=${{matrix.config.buildType}} -pr:a=toolchains/${{matrix.config.name}}.profile.jinja -o "&:runtime=True" -o "&:python=True" -o "&:tests=True"
cmake --preset conan-release
- name: Build & Install
run: |
cmake --build build/${{matrix.config.buildType}} --config ${{matrix.config.buildType}}
Expand Down Expand Up @@ -86,7 +86,7 @@ jobs:
strategy:
matrix:
config:
- { name: riscv64-unknown-linux, toolchain: riscv64-unknown-linux, toolchain_env: RISCV_ROOT_PATH, toolchain_file: riscv64-unknown-linux-gnu-12.0.1, qemu: qemu-riscv64, loader_args: '-cpu;rv64,v=true,Zfh=true,vlen=128,elen=64,vext_spec=v1.0;-L', cmakeArgs: '', buildType: Release }
- { name: riscv64-unknown-linux, toolchain: riscv64-unknown-linux, toolchain_env: RISCV_ROOT_PATH, toolchain_file: riscv64-unknown-linux_gnu_14.2.0, qemu: qemu-riscv64, loader_args: '-cpu;rv64,v=true,Zfh=true,vlen=128,elen=64,vext_spec=v1.0;-L', cmakeArgs: '', buildType: Release }

steps:
- uses: actions/checkout@v3
Expand All @@ -100,8 +100,9 @@ jobs:
- name: Install toolchain and QEMU
shell: bash
run: |
wget https://dav.sunnycase.moe/d/ci/nncase/${{matrix.config.toolchain_file}}.tar.xz -O toolchain.tar.xz
sudo tar xf toolchain.tar.xz -C $GITHUB_WORKSPACE
wget https://dav.sunnycase.moe/d/ci/nncase/${{matrix.config.toolchain_file}}.tar.zst -O toolchain.tar.zst
sudo apt install zstd
sudo tar -I zstd -xf toolchain.tar.zst -C $GITHUB_WORKSPACE
echo "${{matrix.config.toolchain_env}}=$GITHUB_WORKSPACE/${{matrix.config.toolchain_file}}" >> $GITHUB_ENV
wget https://dav.sunnycase.moe/d/ci/nncase/${{matrix.config.qemu}}.tgz -O qemu.tgz
Expand Down
6 changes: 6 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ if(ENABLE_DUMP_MEM)
add_definitions(-DDUMP_MEM)
endif()

set(NTT_UNROOL_NUM 1 CACHE STRING "Set the unroll number for loop unrolling")
add_compile_definitions(NTT_UNROOL_NUM=${NTT_UNROOL_NUM})

# Workaroud for riscv toolchain auto vectorization bugs in O3
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES
"(riscv32)|(riscv64)")
Expand All @@ -87,6 +90,9 @@ else ()
set(CMAKE_INSTALL_RPATH "$ORIGIN")
endif()

set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)

if (MSVC)
add_definitions(/D_SILENCE_ALL_CXX17_DEPRECATION_WARNINGS /D_CRT_SECURE_NO_WARNINGS /DNOMINMAX)
add_compile_options(/wd4267 /wd4251 /wd4244 /FC /utf-8 /W3 /WX -Wno-unused-function -Wno-unused-command-line-argument)
Expand Down
85 changes: 83 additions & 2 deletions src/Native/include/nncase/ntt/arch/riscv64/arch_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,94 @@
* limitations under the License.
*/
#pragma once
#include "../../native_vector.h"

#ifdef __riscv_vector
#include <riscv_vector.h>

#ifndef __riscv_v_fixed_vlen
#error "-mrvv-vector-bits=zvl must be specified in toolchain compiler option."
#endif

#ifndef NTT_VLEN
#define NTT_VLEN __riscv_v_min_vlen
#define NTT_VLEN __riscv_v_fixed_vlen
#endif

#ifndef NTT_VL
#define NTT_VL(sew, lmul) ((NTT_VLEN) / (sew) * (lmul))
#endif

#define NTT_VL(vlen, sew, lmul) ((vlen) / (sew) * (lmul))
// rvv fixed type
#define REGISTER_RVV_FIXED_TYPE_WITH_LMUL(lmul) \
typedef vint8m##lmul##_t fixed_vint8m##lmul##_t \
__attribute__((riscv_rvv_vector_bits(lmul * NTT_VLEN))); \
typedef vuint8m##lmul##_t fixed_vuint8m##lmul##_t \
__attribute__((riscv_rvv_vector_bits(lmul * NTT_VLEN))); \
typedef vint16m##lmul##_t fixed_vint16m##lmul##_t \
__attribute__((riscv_rvv_vector_bits(lmul * NTT_VLEN))); \
typedef vuint16m##lmul##_t fixed_vuint16m##lmul##_t \
__attribute__((riscv_rvv_vector_bits(lmul * NTT_VLEN))); \
typedef vint32m##lmul##_t fixed_vint32m##lmul##_t \
__attribute__((riscv_rvv_vector_bits(lmul * NTT_VLEN))); \
typedef vuint32m##lmul##_t fixed_vuint32m##lmul##_t \
__attribute__((riscv_rvv_vector_bits(lmul * NTT_VLEN))); \
typedef vint64m##lmul##_t fixed_vint64m##lmul##_t \
__attribute__((riscv_rvv_vector_bits(lmul * NTT_VLEN))); \
typedef vuint64m##lmul##_t fixed_vuint64m##lmul##_t \
__attribute__((riscv_rvv_vector_bits(lmul * NTT_VLEN))); \
typedef vfloat32m##lmul##_t fixed_vfloat32m##lmul##_t \
__attribute__((riscv_rvv_vector_bits(lmul * NTT_VLEN))); \
typedef vfloat64m##lmul##_t fixed_vfloat64m##lmul##_t \
__attribute__((riscv_rvv_vector_bits(lmul * NTT_VLEN)));

REGISTER_RVV_FIXED_TYPE_WITH_LMUL(1)
REGISTER_RVV_FIXED_TYPE_WITH_LMUL(2)
REGISTER_RVV_FIXED_TYPE_WITH_LMUL(4)
REGISTER_RVV_FIXED_TYPE_WITH_LMUL(8)

// rvv native vector
#define NTT_DEFINE_NATIVE_VECTOR_WITH_LMUL(lmul) \
NTT_BEGIN_DEFINE_NATIVE_VECTOR_DEFAULT( \
int8_t, fixed_vint8m##lmul##_t, lmul *NTT_VLEN / 8 / sizeof(int8_t)) \
NTT_END_DEFINE_NATIVE_VECTOR() \
NTT_BEGIN_DEFINE_NATIVE_VECTOR_DEFAULT(uint8_t, fixed_vuint8m##lmul##_t, \
lmul *NTT_VLEN / 8 / \
sizeof(uint8_t)) \
NTT_END_DEFINE_NATIVE_VECTOR() \
NTT_BEGIN_DEFINE_NATIVE_VECTOR_DEFAULT(int16_t, fixed_vint16m##lmul##_t, \
lmul *NTT_VLEN / 8 / \
sizeof(int16_t)) \
NTT_END_DEFINE_NATIVE_VECTOR() \
NTT_BEGIN_DEFINE_NATIVE_VECTOR_DEFAULT(uint16_t, fixed_vuint16m##lmul##_t, \
lmul *NTT_VLEN / 8 / \
sizeof(uint16_t)) \
NTT_END_DEFINE_NATIVE_VECTOR() \
NTT_BEGIN_DEFINE_NATIVE_VECTOR_DEFAULT(int32_t, fixed_vint32m##lmul##_t, \
lmul *NTT_VLEN / 8 / \
sizeof(int32_t)) \
NTT_END_DEFINE_NATIVE_VECTOR() \
NTT_BEGIN_DEFINE_NATIVE_VECTOR_DEFAULT(uint32_t, fixed_vuint32m##lmul##_t, \
lmul *NTT_VLEN / 8 / \
sizeof(uint32_t)) \
NTT_END_DEFINE_NATIVE_VECTOR() \
NTT_BEGIN_DEFINE_NATIVE_VECTOR_DEFAULT(int64_t, fixed_vint64m##lmul##_t, \
lmul *NTT_VLEN / 8 / \
sizeof(int64_t)) \
NTT_END_DEFINE_NATIVE_VECTOR() \
NTT_BEGIN_DEFINE_NATIVE_VECTOR_DEFAULT(uint64_t, fixed_vuint64m##lmul##_t, \
lmul *NTT_VLEN / 8 / \
sizeof(uint64_t)) \
NTT_END_DEFINE_NATIVE_VECTOR() \
NTT_BEGIN_DEFINE_NATIVE_VECTOR_DEFAULT(float, fixed_vfloat32m##lmul##_t, \
lmul *NTT_VLEN / 8 / sizeof(float)) \
NTT_END_DEFINE_NATIVE_VECTOR() \
NTT_BEGIN_DEFINE_NATIVE_VECTOR_DEFAULT(double, fixed_vfloat64m##lmul##_t, \
lmul *NTT_VLEN / 8 / \
sizeof(double)) \
NTT_END_DEFINE_NATIVE_VECTOR()

NTT_DEFINE_NATIVE_VECTOR_WITH_LMUL(1)
NTT_DEFINE_NATIVE_VECTOR_WITH_LMUL(2)
NTT_DEFINE_NATIVE_VECTOR_WITH_LMUL(4)
NTT_DEFINE_NATIVE_VECTOR_WITH_LMUL(8)
#endif
Loading

0 comments on commit b72ba28

Please sign in to comment.