From 2a4f422a7ee87af5480eaad0a66c6bcdb125dab5 Mon Sep 17 00:00:00 2001
From: SharzyL
Date: Thu, 21 Nov 2024 16:18:24 +0000
Subject: [PATCH] [cases] add eval.mmm_mem (to be fixed)

---
 tests/eval/_mmm_mem/default.nix       |  32 ++++
 tests/eval/_mmm_mem/mmm_4096_vl4096.S | 254 ++++++++++++++++++++++++++
 tests/eval/_mmm_mem/mmm_main.c        |  24 +++
 tests/eval/default.nix                |   3 +-
 4 files changed, 312 insertions(+), 1 deletion(-)
 create mode 100644 tests/eval/_mmm_mem/default.nix
 create mode 100644 tests/eval/_mmm_mem/mmm_4096_vl4096.S
 create mode 100644 tests/eval/_mmm_mem/mmm_main.c

diff --git a/tests/eval/_mmm_mem/default.nix b/tests/eval/_mmm_mem/default.nix
new file mode 100644
index 000000000..55b94e77d
--- /dev/null
+++ b/tests/eval/_mmm_mem/default.nix
@@ -0,0 +1,32 @@
+{ linkerScript
+, makeBuilder
+, t1main
+}:
+
+let
+  builder = makeBuilder { casePrefix = "eval"; };
+  build_mmm = caseName /* must be consistent with attr name */ : len: kernel_src:
+    builder {
+      caseName = caseName;
+
+      src = ./.;
+
+      passthru.featuresRequired = { };
+
+      buildPhase = ''
+        runHook preBuild
+
+        $CC -T${linkerScript} -DLEN=${toString len} \
+          ${./mmm_main.c} ${kernel_src} \
+          ${t1main} \
+          -o $pname.elf
+
+        runHook postBuild
+      '';
+
+      meta.description = "test case 'mmm_mem'";
+    };
+
+in {
+  mmm_mem_4096_vl4096 = build_mmm "mmm_mem_4096_vl4096" 4096 ./mmm_4096_vl4096.S;
+}
diff --git a/tests/eval/_mmm_mem/mmm_4096_vl4096.S b/tests/eval/_mmm_mem/mmm_4096_vl4096.S
new file mode 100644
index 000000000..ece8214f8
--- /dev/null
+++ b/tests/eval/_mmm_mem/mmm_4096_vl4096.S
@@ -0,0 +1,254 @@
+.text
+.balign 16
+.globl mmm
+.type mmm,@function
+# assume VLEN >= 4096, BN = 4096, SEW = 16 * 2 = 32
+# we only support LMUL = 1 for now
+# P, A, B, AB should have 384 elements
+mmm:
+  # quite SIMD
+  li t0, 128 # in case way > 31
+  vsetvli zero, t0, e32, m1, ta, ma
+  # stride
+  li t1, 12
+  # start loop of niter + 1 times
+  li t4,0
+1:
+  # AB = B_i*A + AB
+  # !!!!!! important: lw here assumes SEW = 32
+  # T0 is used in vmacc, do not use for temp now!
+  lw t0, 0(a2)
+  addi a2, a2, 4 # advance B by a SEW
+
+  # carry for ABV_0
+  vmv.v.i v30,0
+  # loop variable
+  li t5,0
+
+  # ---
+  # macc (V=a1, VV=v10, VVN=10, ngroupreg=3)
+  # ---
+
+  # load one group of values from arg
+  # offset of one group
+  # !!! important: assume nreg = 8 and sew = 32
+  # log(8) + log(32/8) = 5
+  slli t2,t5,5
+  add t3,t2,a1
+  vlsseg3e32.v v10, (t3), t1
+  add t3,t2,a0
+  vlsseg3e32.v v20, (t3), t1
+  vmacc.vx v20, t0, v10
+  vmacc.vx v21, t0, v11
+  vmacc.vx v22, t0, v12
+  # store one group of AB
+  vssseg3e32.v v20, (t3), t1
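+
+  # note: judging from the C prototype in mmm_main.c, a0 = AB (result),
+  # a1 = A, a2 = B, a3 = P, a4 = mu; each vector element covers three
+  # consecutive 32-bit words (one 16-bit digit per word), hence the seg3
+  # loads/stores with byte stride 12
+  # the propagate loop below renormalises AB after the vmacc: every field
+  # is reduced mod 2**16 and its high half added into the next field, and
+  # the carry out of the last field is slid up one element (vslide1up)
+  # into field 0 of the following element; 128 passes let a carry ripple
+  # across the whole vector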
+
+  # ---
+  # propagate_niter
+  # ---
+
+  # start loop of niter + 1 times
+  # use T2 as outer loop index
+  li t2,0
+9:
+  # mask
+  # set TV2 for every propagate()
+  # set TV2 every time (see slide1up below)
+  li t0,65535
+  vmv.v.x v31,t0
+
+  # carry for ABV_0
+  vmv.v.i v30,0
+
+  # loop variable
+  li t5,0
+
+  # load last group of values from arg
+  # offset of last group
+  # !!! important: assume nreg = 8 and sew = 32
+  # log(8) + log(32/8) = 5
+  # LOOP2 is now ngroup - 1
+  slli t3,t5,5
+  add t3,t3,a0
+  vlsseg3e32.v v20, (t3), t1
+
+  # ---
+  # propagate (j=0, ngroupreg=3)
+  # ---
+
+  vadd.vv v20, v20, v30
+  # save carry in TV
+  vsrl.vi v30, v20, 16
+  # mod 2 ** 16
+  vand.vv v20, v20, v31
+  vadd.vv v21, v21, v30
+
+  # ---
+  # propagate (j=1, ngroupreg=3)
+  # ---
+
+  # save carry in TV
+  vsrl.vi v30, v21, 16
+  # mod 2 ** 16
+  vand.vv v21, v21, v31
+  vadd.vv v22, v22, v30
+
+  # ---
+  # propagate (j=2, ngroupreg=3)
+  # ---
+
+  # save carry in TV
+  vsrl.vi v30, v22, 16
+  # mod 2 ** 16
+  vand.vv v22, v22, v31
+  # store last group of AB
+  vssseg3e32.v v20, (t3), t1
+
+  # update carry of AB_{ntotalreg - 1} to AB_0
+  vlse32.v v20, (a0), t1
+  vslide1up.vx v31, v30, zero
+  vadd.vv v20, v20, v31
+  vsse32.v v20, (a0), t1
+  addi t2,t2,1
+  li t0,128
+  bne t2,t0,9b
+  # !!!!!! important: lw here assumes SEW = 32
+  # T0 is used in vmacc, do not use for temp now!
+  lw t0, 0(a0)
+  mul t0, t0, a4
+  # mod 2 ** 16
+  # !!!! important: here we assume SEW = 32 and XLEN = 64
+  sll t0, t0, 16
+  srl t0, t0, 16
+
+  # loop variable
+  li t5,0
+
+  # ---
+  # macc (V=a3, VV=v0, VVN=0, ngroupreg=3)
+  # ---
+
+  # load one group of values from arg
+  # offset of one group
+  # !!! important: assume nreg = 8 and sew = 32
+  # log(8) + log(32/8) = 5
+  slli t2,t5,5
+  add t3,t2,a3
+  vlsseg3e32.v v0, (t3), t1
+  add t3,t2,a0
+  vlsseg3e32.v v20, (t3), t1
+  vmacc.vx v20, t0, v0
+  vmacc.vx v21, t0, v1
+  vmacc.vx v22, t0, v2
+  # store one group of AB
+  vssseg3e32.v v20, (t3), t1
+
+  # ---
+  # propagate_niter
+  # ---
+
+  # start loop of niter + 1 times
+  # use T2 as outer loop index
+  li t2,0
+9:
+  # mask
+  # set TV2 for every propagate()
+  # set TV2 every time (see slide1up below)
+  li t0,65535
+  vmv.v.x v31,t0
+
+  # carry for ABV_0
+  vmv.v.i v30,0
+
+  # loop variable
+  li t5,0
+
+  # load last group of values from arg
+  # offset of last group
+  # !!! important: assume nreg = 8 and sew = 32
+  # log(8) + log(32/8) = 5
+  # LOOP2 is now ngroup - 1
+  slli t3,t5,5
+  add t3,t3,a0
+  vlsseg3e32.v v20, (t3), t1
+
+  # ---
+  # propagate (j=0, ngroupreg=3)
+  # ---
+
+  vadd.vv v20, v20, v30
+  # save carry in TV
+  vsrl.vi v30, v20, 16
+  # mod 2 ** 16
+  vand.vv v20, v20, v31
+  vadd.vv v21, v21, v30
+
+  # ---
+  # propagate (j=1, ngroupreg=3)
+  # ---
+
+  # save carry in TV
+  vsrl.vi v30, v21, 16
+  # mod 2 ** 16
+  vand.vv v21, v21, v31
+  vadd.vv v22, v22, v30
+
+  # ---
+  # propagate (j=2, ngroupreg=3)
+  # ---
+
+  # save carry in TV
+  vsrl.vi v30, v22, 16
+  # mod 2 ** 16
+  vand.vv v22, v22, v31
+  # store last group of AB
+  vssseg3e32.v v20, (t3), t1
+
+  # update carry of AB_{ntotalreg - 1} to AB_0
+  vlse32.v v20, (a0), t1
+  vslide1up.vx v31, v30, zero
+  vadd.vv v20, v20, v31
+  vsse32.v v20, (a0), t1
+  addi t2,t2,1
+  li t0,128
+  bne t2,t0,9b
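+
+  # note: the epilogue below looks like the usual Montgomery digit shift,
+  # i.e. AB := AB / 2**16 once q*P has been accumulated: fields 1 and 2 of
+  # every element move down one slot, and the low digit of the next
+  # element (fetched with vslide1down) becomes the new top field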
+
+  # update carry of AB_2 to AB_0
+  # since we need to subtract AB_0
+  vlse32.v v20, (a0), t1
+  # AB / word
+  vslide1down.vx v30, v20, zero
+  # do not need vsse now
+  # just store it in TV for move
+
+  # -----
+  # move
+  # -----
+
+  # move AB_1 to AB_0, AB_2 to AB_1, ... , AB_0 (in TV now) to AB_2
+  # loop variable
+  li t5,0
+  # load last group of values from arg
+  # offset of last group
+  # !!! important: assume nreg = 8 and sew = 32
+  # log(8) + log(32/8) = 5
+  # LOOP2 is now ngroup - 1
+  slli t2,t5,5
+  # then offset by 1 element
+  addi t2,t2,4
+  add t3,t2,a0
+  vlsseg2e32.v v20, (t3), t1
+  # move AB_0 to AB_2
+  vmv.v.v v22, v30
+
+  # back to original offset
+  addi t3,t3,-4
+  vssseg3e32.v v20, (t3), t1
+
+  addi t4,t4,1
+  li t0,257
+
+  bne t4,t0,1b
+
+  ret
diff --git a/tests/eval/_mmm_mem/mmm_main.c b/tests/eval/_mmm_mem/mmm_main.c
new file mode 100644
index 000000000..f7078e833
--- /dev/null
+++ b/tests/eval/_mmm_mem/mmm_main.c
@@ -0,0 +1,24 @@
+#include <stdint.h>
+#include <stdlib.h>
+#include <stdio.h>
+
+#ifndef LEN
+  // define then error, to make clangd happy without configuration
+  #define LEN 1024
+  #error "LEN not defined"
+#endif
+
+void mmm(uint32_t* r, const uint32_t* a, const uint32_t* b, const uint32_t* p, const uint32_t mu);
+
+void test() {
+  int words = (LEN) / 16 + 4;
+  uint32_t *r = (uint32_t *) malloc(words * sizeof(uint32_t));
+  uint32_t *a = (uint32_t *) malloc(words * sizeof(uint32_t));
+  uint32_t *b = (uint32_t *) malloc(words * sizeof(uint32_t));
+  uint32_t *p = (uint32_t *) malloc(words * sizeof(uint32_t));
+  uint32_t mu = 0xca1b;
+  mmm(r, a, b, p, mu);
+  for (int i = 0; i < words; i++) {
+    printf("%04X ", r[i]);
+  }
+}
diff --git a/tests/eval/default.nix b/tests/eval/default.nix
index ce8851422..20fe9b58e 100644
--- a/tests/eval/default.nix
+++ b/tests/eval/default.nix
@@ -34,5 +34,6 @@ let
   autoCases = findAndBuild ./. build;
 
   nttCases = callPackage ./_ntt { };
+  mmmCases = callPackage ./_mmm_mem { };
 in
-  autoCases // nttCases
+  autoCases // nttCases // mmmCases
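
Note on the kernel: the structure above (accumulate B_i*A, multiply the low digit by mu, accumulate q*P, shift one digit) matches radix-2**16 Montgomery multiplication, so a scalar reference along the following lines could later be used to check r once the inputs are initialised. This is only a sketch; mmm_ref, the digit count n, and the final conditional subtraction of P are assumptions, not part of the patch.

/* Hypothetical scalar reference for radix-2**16 Montgomery multiplication,
 * with n 16-bit digits stored one per uint32_t and mu = -P[0]^-1 mod 2**16.
 * Sketch only; not taken from the patch. */
#include <stdint.h>

static void mmm_ref(uint32_t *r, const uint32_t *a, const uint32_t *b,
                    const uint32_t *p, uint32_t mu, int n) {
  /* r needs n + 1 words of scratch (the harness allocates LEN/16 + 4) */
  for (int i = 0; i <= n; i++) r[i] = 0;
  for (int i = 0; i < n; i++) {
    uint32_t bi = b[i] & 0xffff;
    uint64_t carry = 0;
    for (int j = 0; j < n; j++) {           /* AB += B_i * A */
      uint64_t t = (uint64_t)r[j] + (uint64_t)bi * (a[j] & 0xffff) + carry;
      r[j] = (uint32_t)(t & 0xffff);
      carry = t >> 16;
    }
    r[n] += (uint32_t)carry;
    uint32_t q = (r[0] * mu) & 0xffff;      /* q makes AB_0 vanish mod 2**16 */
    carry = 0;
    for (int j = 0; j < n; j++) {           /* AB += q * P */
      uint64_t t = (uint64_t)r[j] + (uint64_t)q * (p[j] & 0xffff) + carry;
      r[j] = (uint32_t)(t & 0xffff);
      carry = t >> 16;
    }
    r[n] += (uint32_t)carry;
    for (int j = 0; j < n; j++) r[j] = r[j + 1];  /* AB /= 2**16 */
    r[n] = 0;
  }
  /* a final conditional subtraction of P (if r >= P) is omitted here */
}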