Skip to content

Commit

Permalink
ADD: update
Browse files Browse the repository at this point in the history
  • Loading branch information
T-K-233 committed Jul 14, 2024
1 parent d5bd9f4 commit 9980c95
Show file tree
Hide file tree
Showing 9 changed files with 155 additions and 102 deletions.
4 changes: 3 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ set(WRAP_SPECS_FILE "htif_wrap.specs")
set(SPECS_FILE "htif_nano.specs")
set(LIBGLOSS_DIR "$ENV{RISCV}/riscv64-unknown-elf/lib/")

set(MARCH "rv64gcv_zfh_zvfh")
set(MARCH "rv64gcv_zfh_zvfh_zvfhmin")
set(MABI "lp64d")
set(MCMODEL "medany")

Expand All @@ -60,6 +60,8 @@ target_compile_options(target-riscv INTERFACE -march=${MARCH} -mabi=${MABI} -mcm
target_compile_options(target-riscv INTERFACE -Wl,-Map=output.map -specs=${SPECS_FILE} -specs=${WRAP_SPECS_FILE})
target_compile_options(target-riscv INTERFACE -T ${LINKER_SCRIPT})

target_compile_definitions(target-riscv INTERFACE FLT16_MAX=65504.0f)

target_link_options(target-riscv INTERFACE -static)
target_link_options(target-riscv INTERFACE -march=${MARCH} -mabi=${MABI} -mcmodel=${MCMODEL})
target_link_options(target-riscv INTERFACE -Wl,-Map=output.map -specs=${SPECS_FILE} -specs=${WRAP_SPECS_FILE})
Expand Down
22 changes: 11 additions & 11 deletions nn/impl/rvv/abs.c
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,17 @@ void NN__abs_i32(size_t n, int32_t *y, size_t incy, int32_t *x, size_t incx) {
}
}

// void NN__abs_f16(size_t n, float16_t *y, size_t incy, float16_t *x, size_t incx) {
// while (n > 0) {
// size_t vl = __riscv_vsetvl_e16m1(n);
// vfloat16m1_t vec_x = __riscv_vlse16_v_f16m1(x, sizeof(float16_t) * incx, vl);
// vfloat16m1_t vec_y = __riscv_vfabs_v_f16m1(vec_x, vl);
// __riscv_vse16_v_f16m1(y, sizeof(float16_t) * incy, vec_y, vl);
// x += vl;
// y += vl;
// n -= vl;
// }
// }
/**
 * Element-wise absolute value of a half-precision vector (RVV implementation).
 *
 * Computes y[i * incy] = |x[i * incx]| for i in [0, n).
 *
 * @param n    number of elements to process
 * @param y    destination buffer, written every `incy` elements
 * @param incy destination stride, in elements
 * @param x    source buffer, read every `incx` elements
 * @param incx source stride, in elements
 */
void NN__abs_f16(size_t n, float16_t *y, size_t incy, float16_t *x, size_t incx) {
  while (n > 0) {
    // Number of 16-bit elements handled in this iteration (at most n).
    size_t vl = __riscv_vsetvl_e16m1(n);

    // Strided load/store take the stride in bytes, hence the sizeof() scaling.
    vfloat16m1_t vec_x = __riscv_vlse16_v_f16m1(x, sizeof(float16_t) * incx, vl);
    vfloat16m1_t vec_y = __riscv_vfabs_v_f16m1(vec_x, vl);
    __riscv_vsse16_v_f16m1(y, sizeof(float16_t) * incy, vec_y, vl);

    // Each iteration consumed vl strided elements, so the pointers must skip
    // vl * stride elements; advancing by vl alone is only correct for stride 1.
    x += vl * incx;
    y += vl * incy;
    n -= vl;
  }
}

void NN__abs_f32(size_t n, float *y, size_t incy, float *x, size_t incx) {
while (n > 0) {
Expand Down
24 changes: 5 additions & 19 deletions nn/impl/rvv/add.c
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,11 @@ void NN__add_i8(size_t n, int8_t *z, size_t incz, int8_t *x, size_t incx, int8_t

void NN__add_f16(size_t n, float16_t *z, size_t incz, float16_t *x, size_t incx, float16_t *y, size_t incy) {
while (n > 0) {
size_t vl;

printf("hi\n");

// size_t vl = __riscv_vsetvl_e16m1(n);
asm volatile("vsetvli %0, %1, e16, m1, ta, ma" : "=r"(vl) : "r"(n));

// vfloat16m1_t vec_x = __riscv_vlse16_v_f16m1(x, sizeof(float16_t) * incx, vl);
asm volatile("vlse16.v v24, (%0), %1" : : "r"(x), "r"(sizeof(float16_t) * incx));

// vfloat16m1_t vec_y = __riscv_vlse16_v_f16m1(y, sizeof(float16_t) * incy, vl);
asm volatile("vlse16.v v25, (%0), %1" : : "r"(y), "r"(sizeof(float16_t) * incy));

// // vfloat16m1_t vec_z = __riscv_vfadd_vv_f16m1(vec_x, vec_y, vl);
asm volatile("vfadd.vv v24, v24, v25");

// __riscv_vsse16_v_f16m1(z, sizeof(float16_t) * incz, vec_z, vl);
asm volatile("vsse16.v v24, (%0), %1" : : "r"(z), "r"(sizeof(float16_t) * incz));

size_t vl = __riscv_vsetvl_e16m1(n);
vfloat16m1_t vec_x = __riscv_vlse16_v_f16m1(x, sizeof(float16_t) * incx, vl);
vfloat16m1_t vec_y = __riscv_vlse16_v_f16m1(y, sizeof(float16_t) * incy, vl);
vfloat16m1_t vec_z = __riscv_vfadd_vv_f16m1(vec_x, vec_y, vl);
__riscv_vsse16_v_f16m1(z, sizeof(float16_t) * incz, vec_z, vl);
x += vl;
y += vl;
z += vl;
Expand Down
21 changes: 5 additions & 16 deletions nn/impl/rvv/maximum1.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,11 @@

void NN__maximum1_f16(size_t n, float16_t *y, size_t incy, float16_t *x, size_t incx, float16_t scalar) {
while (n > 0) {
size_t vl;
// size_t vl = __riscv_vsetvl_e16m1(n);
asm volatile("vsetvli %0, %1, e16, m1, ta, ma" : "=r"(vl) : "r"(n));

// vfloat16m1_t vec_x = __riscv_vlse16_v_f16m1(x, sizeof(float16_t) * incx, vl);
asm volatile("vlse16.v v26, (%0), %1" : : "r"(x), "r"(sizeof(float16_t) * incx));

// vfloat16m1_t vec_s = __riscv_vfmv_v_f_f16m1(scalar, vl);
asm volatile("vmv.v.x v25, %0" : : "r"(scalar));

// vfloat16m1_t vec_y = __riscv_vfmax_vv_f16m1(vec_x, vec_s, vl);
asm volatile("vfmax.vv v25, v26, v25");

// __riscv_vsse16_v_f16m1(y, sizeof(float16_t) * incy, vec_y, vl);
asm volatile("vsse16.v v25, (%0), %1" : : "r"(y), "r"(sizeof(float16_t) * incy));

size_t vl = __riscv_vsetvl_e16m1(n);
vfloat16m1_t vec_x = __riscv_vlse16_v_f16m1(x, sizeof(float16_t) * incx, vl);
vfloat16m1_t vec_s = __riscv_vfmv_v_f_f16m1(scalar, vl);
vfloat16m1_t vec_y = __riscv_vfmax_vv_f16m1(vec_x, vec_s, vl);
__riscv_vsse16_v_f16m1(y, sizeof(float16_t) * incy, vec_y, vl);
x += vl;
y += vl;
n -= vl;
Expand Down
21 changes: 5 additions & 16 deletions nn/impl/rvv/minimum1.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,11 @@

void NN__minimum1_f16(size_t n, float16_t *y, size_t incy, float16_t *x, size_t incx, float16_t scalar) {
while (n > 0) {
size_t vl;
// size_t vl = __riscv_vsetvl_e16m1(n);
asm volatile("vsetvli %0, %1, e16, m1, ta, ma" : "=r"(vl) : "r"(n));

// vfloat16m1_t vec_x = __riscv_vlse16_v_f16m1(x, sizeof(float16_t) * incx, vl);
asm volatile("vlse16.v v26, (%0), %1" : : "r"(x), "r"(sizeof(float16_t) * incx));

// vfloat16m1_t vec_s = __riscv_vfmv_v_f_f16m1(scalar, vl);
asm volatile("vmv.v.x v25, %0" : : "r"(scalar));

// vfloat16m1_t vec_y = __riscv_vfmin_vv_f16m1(vec_x, vec_s, vl);
asm volatile("vfmin.vv v25, v26, v25");

// __riscv_vsse16_v_f16m1(y, sizeof(float16_t) * incy, vec_y, vl);
asm volatile("vsse16.v v25, (%0), %1" : : "r"(y), "r"(sizeof(float16_t) * incy));

size_t vl = __riscv_vsetvl_e16m1(n);
vfloat16m1_t vec_x = __riscv_vlse16_v_f16m1(x, sizeof(float16_t) * incx, vl);
vfloat16m1_t vec_s = __riscv_vfmv_v_f_f16m1(scalar, vl);
vfloat16m1_t vec_y = __riscv_vfmin_vv_f16m1(vec_x, vec_s, vl);
__riscv_vsse16_v_f16m1(y, sizeof(float16_t) * incy, vec_y, vl);
x += vl;
y += vl;
n -= vl;
Expand Down
11 changes: 10 additions & 1 deletion tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,15 @@ elseif (RISCV)
endif()


target_link_libraries(tests PUBLIC nn)
# include_directories(
# ../nn
# ../nn/functional
# ../nn/impl)

find_library(LIB_TO_INCLUDE nn ./)

target_link_libraries(tests PUBLIC ${LIB_TO_INCLUDE})

# target_link_libraries(tests PUBLIC nn)
target_link_libraries(tests PUBLIC m)

6 changes: 3 additions & 3 deletions tests/src/generate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ def functional_rms_norm(x, w, eps):


test_pattern = [
# ("abs", lambda a: torch.abs(a), [("a", rand((7, 7))), ]),
# ("add", lambda a, b: a + b, [("a", rand((6, 7))), ("b", rand((6, 7))) ]),
# ("add", lambda a, b: a + b, [("a", rand((6, 7))), ("b", rand((1, 7))) ]),
("abs", lambda a: torch.abs(a), [("a", rand((7, 7))), ]),
("add", lambda a, b: a + b, [("a", rand((6, 7))), ("b", rand((6, 7))) ]),
("add", lambda a, b: a + b, [("a", rand((6, 7))), ("b", rand((1, 7))) ]),
# ("add", lambda a, b: a + b, [("a", rand((6, 7))), ("b", rand((6, 1))) ]),
# ("add", lambda a, b: a + b, [("a", rand((6, 7))), ("b", rand((7, ))) ]),
# ("add_inplace", lambda a, b: a + b, [("actual", torch.zeros((7, 7))), ("b", rand((7, 7))) ]),
Expand Down
Loading

0 comments on commit 9980c95

Please sign in to comment.