Skip to content

Commit

Permalink
Merge pull request #40 from slothy-optimizer/kyber_neon_ntt_no_gpr_stash
Browse files Browse the repository at this point in the history
Remove unnecessary GPR->Stack stashing in Kyber Neon NTT
  • Loading branch information
hanno-becker authored Mar 19, 2024
2 parents e02002f + 9744672 commit 4c727f5
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 236 deletions.
133 changes: 45 additions & 88 deletions paper/clean/neon/ntt_kyber_1234_567.s
Original file line number Diff line number Diff line change
Expand Up @@ -108,36 +108,6 @@
vmlaq \a, t0, consts, 0
.endm

// load_roots_123: load the two vectors root0 and root1 through r_ptr0.
// Takes no macro arguments; root0, root1 and r_ptr0 are register aliases
// that must already be defined at the expansion site.
// NOTE(review): ldr_vi / ldr_vo are helper macros defined outside this view.
// The 32 / -16 operands look like "load, then advance r_ptr0 by 32" followed
// by "load at offset -16 from the advanced pointer", i.e. two adjacent
// 16-byte entries -- confirm against the helper definitions.
.macro load_roots_123
ldr_vi root0, r_ptr0, 32
ldr_vo root1, r_ptr0, -16
.endm

// load_next_roots_45: load one vector into \root0 through \r_ptr0.
//   \root0   - destination vector register
//   \r_ptr0  - pointer register handed to ldr_vi together with the value 16
// NOTE(review): ldr_vi is defined outside this view; presumably it also
// advances \r_ptr0 by 16 bytes (one q-register) -- confirm.
.macro load_next_roots_45 root0, r_ptr0
ldr_vi \root0, \r_ptr0, 16
.endm

// load_next_roots_67: load six vectors through \r_ptr1, in the order
// root0, root0_tw, root1, root1_tw, root2, root2_tw.
//   \root0..\root2        - destination vector registers
//   \root0_tw..\root2_tw  - destination vector registers paired with them
//   \r_ptr1               - pointer register used for all six loads
// The first load uses ldr_vi with 6*16; the other five use ldr_vo with
// offsets (-6*16 + k*16) for k = 1..5.
// NOTE(review): if ldr_vi post-increments \r_ptr1 by 6*16 (helper not visible
// here), the six loads read six consecutive 16-byte entries starting at the
// pointer's value on entry -- confirm.
.macro load_next_roots_67 root0, root0_tw, root1, root1_tw, root2, root2_tw, r_ptr1
ldr_vi \root0, \r_ptr1, (6*16)
ldr_vo \root0_tw, \r_ptr1, (-6*16 + 1*16)
ldr_vo \root1, \r_ptr1, (-6*16 + 2*16)
ldr_vo \root1_tw, \r_ptr1, (-6*16 + 3*16)
ldr_vo \root2, \r_ptr1, (-6*16 + 4*16)
ldr_vo \root2_tw, \r_ptr1, (-6*16 + 5*16)
.endm

// transpose4: in-place 4x4 transpose of the 32-bit lanes held in the four
// vector registers \data0, \data1, \data2, \data3.
//   \data - name prefix; "\data\()N" pastes the digit N onto it, so the
//           aliases <data>0 .. <data>3 must exist at the expansion site.
// Overwrites the scratch vectors t0..t3.
.macro transpose4 data
// Step 1: interleave 32-bit lanes of rows (0,1) and rows (2,3).
trn1 t0.4s, \data\()0.4s, \data\()1.4s
trn2 t1.4s, \data\()0.4s, \data\()1.4s
trn1 t2.4s, \data\()2.4s, \data\()3.4s
trn2 t3.4s, \data\()2.4s, \data\()3.4s

// Step 2: interleave 64-bit halves of the scratch vectors; trn1 yields the
// output rows 0 and 1, trn2 yields the output rows 2 and 3.
trn2 \data\()2.2d, t0.2d, t2.2d
trn2 \data\()3.2d, t1.2d, t3.2d
trn1 \data\()0.2d, t0.2d, t2.2d
trn1 \data\()1.2d, t1.2d, t3.2d
.endm

.macro save_gprs // @slothy:no-unfold
sub sp, sp, #(16*6)
stp x19, x20, [sp, #16*0]
Expand Down Expand Up @@ -196,29 +166,6 @@
restore_gprs
.endm

// Twiddle table: placed in .data, aligned to 2^4 = 16 bytes, and pulled in
// from a separate file under the label `roots`.
.data
.p2align 4
roots:
#include "ntt_kyber_1234_567_twiddles.s"
.text

// Export the entry point under both the plain and the underscore-prefixed
// name (the latter presumably for object formats that prefix C symbols with
// '_', e.g. Mach-O -- confirm).
.global ntt_kyber_1234_567
.global _ntt_kyber_1234_567

// 16-byte aligned block of eight halfwords living in .text; the function
// body loads it as one 8h vector into `consts`.
// NOTE(review): -3329 is presumably the negated Kyber modulus and 20159 the
// matching Barrett constant; the remaining six lanes are zero -- confirm.
.p2align 4
const_addr: .short -3329
.short 20159
.short 0
.short 0
.short 0
.short 0
.short 0
.short 0
ntt_kyber_1234_567:
_ntt_kyber_1234_567:
push_stack

in .req x0
inp .req x1
count .req x2
Expand All @@ -228,20 +175,6 @@ _ntt_kyber_1234_567:

// General-purpose register aliases src0..src15, mapped onto x6..x21 in order.
// NOTE(review): src12 is x18, which is a platform-reserved register on some
// ABIs (e.g. Apple, Windows) -- confirm that the target platforms allow it.
// NOTE(review): src13..src15 are x19..x21, which the save_gprs macro stores
// to the stack; they have to be saved before they are written.
src0 .req x6
src1 .req x7
src2 .req x8
src3 .req x9
src4 .req x10
src5 .req x11
src6 .req x12
src7 .req x13
src8 .req x14
src9 .req x15
src10 .req x16
src11 .req x17
src12 .req x18
src13 .req x19
src14 .req x20
src15 .req x21

qform_v0 .req q0
qform_v1 .req q1
Expand Down Expand Up @@ -336,17 +269,43 @@ _ntt_kyber_1234_567:

consts .req v8

ASM_LOAD(r_ptr0, roots)
// Twiddle table: placed in .data, aligned to 2^4 = 16 bytes, and pulled in
// from a separate file under the label `roots`.
.data
.p2align 4
roots:
#include "ntt_kyber_1234_567_twiddles.s"
.text

// Export the entry point under both the plain and the underscore-prefixed
// name (the latter presumably for object formats that prefix C symbols with
// '_', e.g. Mach-O -- confirm).
.global ntt_kyber_1234_567
.global _ntt_kyber_1234_567

// 16-byte aligned block of eight halfwords living in .text; the function
// body loads it as one 8h vector into `consts`.
// NOTE(review): -3329 is presumably the negated Kyber modulus and 20159 the
// matching Barrett constant; the remaining six lanes are zero -- confirm.
.p2align 4
const_addr: .short -3329
.short 20159
.short 0
.short 0
.short 0
.short 0
.short 0
.short 0
ntt_kyber_1234_567:
_ntt_kyber_1234_567:
push_stack

ASM_LOAD(r_ptr0, roots)
ASM_LOAD(r_ptr1, roots_l456)
ASM_LOAD(xtmp, const_addr)
ld1 {consts.8h}, [xtmp]

save STACK0, in

add src0, x0, #32*0
add src8, x0, #32*8
add src1, x0, #32*8

ld1 { root0.8h, root1.8h, root2.8h, root3.8h}, [r_ptr0], #64
ldr_vo root0, r_ptr0, 0
ldr_vo root1, r_ptr0, 16
ldr_vo root2, r_ptr0, 32
ldr_vo root3, r_ptr0, 48

mov count, #2

Expand All @@ -362,14 +321,14 @@ layer1234_start:
ldr_vo data6, src0, 6*32
ldr_vo data7, src0, 7*32

ldr_vo data8, src8, 0
ldr_vo data9, src8, 1*32
ldr_vo data10, src8, 2*32
ldr_vo data11, src8, 3*32
ldr_vo data12, src8, 4*32
ldr_vo data13, src8, 5*32
ldr_vo data14, src8, 6*32
ldr_vo data15, src8, 7*32
ldr_vo data8, src1, 0
ldr_vo data9, src1, 1*32
ldr_vo data10, src1, 2*32
ldr_vo data11, src1, 3*32
ldr_vo data12, src1, 4*32
ldr_vo data13, src1, 5*32
ldr_vo data14, src1, 6*32
ldr_vo data15, src1, 7*32

ct_butterfly data0, data8, root0, 0, 1
ct_butterfly data1, data9, root0, 0, 1
Expand Down Expand Up @@ -416,23 +375,21 @@ layer1234_start:
str_vo data6, src0, -16+6*32
str_vo data7, src0, -16+7*32

str_vi data8, src8, 16
str_vo data9, src8, -16+1*32
str_vo data10, src8, -16+2*32
str_vo data11, src8, -16+3*32
str_vo data12, src8, -16+4*32
str_vo data13, src8, -16+5*32
str_vo data14, src8, -16+6*32
str_vo data15, src8, -16+7*32
str_vi data8, src1, 16
str_vo data9, src1, -16+1*32
str_vo data10, src1, -16+2*32
str_vo data11, src1, -16+3*32
str_vo data12, src1, -16+4*32
str_vo data13, src1, -16+5*32
str_vo data14, src1, -16+6*32
str_vo data15, src1, -16+7*32

subs count, count, #1
cbnz count, layer1234_start

restore inp, STACK0
mov count, #4

ASM_LOAD(r_ptr1, roots_l456)

add src0, inp, #256*0
add src1, inp, #256*1

Expand Down
3 changes: 2 additions & 1 deletion paper/clean/neon/ntt_kyber_1234_567_twiddles.s
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ roots_l0123:
.short -1583
.short -15582

.p2align 4
roots_l456:
.short 296
.short 296
Expand Down Expand Up @@ -478,4 +479,4 @@ roots_l456:
.short 6309
.short 6309
.short -11566
.short -11566
.short -11566
94 changes: 30 additions & 64 deletions paper/clean/neon/ntt_kyber_123_4567.s
Original file line number Diff line number Diff line change
Expand Up @@ -139,27 +139,6 @@
trn2 \data_out\()3.4s, \data_in\()2.4s, \data_in\()3.4s
.endm

// save_gprs: spill x19-x28 and x29 to a freshly reserved stack frame.
//
// Frame layout (16*6 = 96 bytes, addressed from the new sp):
//   [sp, #16*0 .. #16*4]  five register pairs x19/x20 .. x27/x28
//   [sp, #16*5]           x29; the upper 8 bytes of this slot stay unused,
//                         so sp only ever moves by a multiple of 16
// x30 is not saved by this macro. restore_gprs reads the same layout back
// and releases the frame.
.macro save_gprs // @slothy:no-unfold
sub sp, sp, #(16*6)
stp x19, x20, [sp, #16*0]
stp x21, x22, [sp, #16*1]
stp x23, x24, [sp, #16*2]
stp x25, x26, [sp, #16*3]
stp x27, x28, [sp, #16*4]
str x29, [sp, #16*5]
.endm

// restore_gprs: reload x19-x28 and x29 from the 96-byte frame at sp (same
// slot layout as written by save_gprs) and then release the frame by adding
// 16*6 back to sp.
.macro restore_gprs // @slothy:no-unfold
ldp x19, x20, [sp, #16*0]
ldp x21, x22, [sp, #16*1]
ldp x23, x24, [sp, #16*2]
ldp x25, x26, [sp, #16*3]
ldp x27, x28, [sp, #16*4]
ldr x29, [sp, #16*5]
add sp, sp, #(16*6)
.endm

.macro save_vregs // @slothy:no-unfold
sub sp, sp, #(16*4)
stp d8, d9, [sp, #16*0]
Expand All @@ -176,51 +155,16 @@
add sp, sp, #(16*4)
.endm

// Size in bytes of the scratch area that push_stack reserves below the saved
// registers (and pop_stack releases again).
#define STACK_SIZE 16
// Byte offset, relative to sp, of the first scratch slot; used as the `loc`
// argument of the save / restore macros.
#define STACK0 0

// restore: load register \a from the stack slot at [sp + \loc].
//   \a   - destination register
//   \loc - byte offset from sp (e.g. STACK0)
// Note the argument order (register first) is the reverse of `save`.
.macro restore a, loc // @slothy:no-unfold
ldr \a, [sp, #\loc\()]
.endm
// save: store register \a to the stack slot at [sp + \loc].
//   \loc - byte offset from sp (e.g. STACK0)
//   \a   - source register
// Note the argument order (offset first) is the reverse of `restore`.
.macro save loc, a // @slothy:no-unfold
str \a, [sp, #\loc\()]
.endm
// push_stack: function prologue. Expands save_gprs, then save_vregs, and
// finally lowers sp by STACK_SIZE bytes, so that scratch area sits at the
// lowest addresses of the frame. pop_stack undoes these steps in reverse.
.macro push_stack // @slothy:no-unfold
save_gprs
save_vregs
sub sp, sp, #STACK_SIZE
.endm

// pop_stack: function epilogue mirroring push_stack. Releases the STACK_SIZE
// scratch bytes first, then expands restore_vregs and restore_gprs, i.e. the
// reverse of the order in which push_stack saved them.
.macro pop_stack // @slothy:no-unfold
add sp, sp, #STACK_SIZE
restore_vregs
restore_gprs
.endm

// Twiddle table: placed in .data, aligned to 2^4 = 16 bytes, and pulled in
// from a separate file under the label `roots`.
.data
.p2align 4
roots:
#include "ntt_kyber_123_45_67_twiddles.s"
.text

// Export the entry point under both the plain and the underscore-prefixed
// name (the latter presumably for object formats that prefix C symbols with
// '_', e.g. Mach-O -- confirm).
.global ntt_kyber_123_4567
.global _ntt_kyber_123_4567

// 16-byte aligned block of eight halfwords living in .text; the function
// body loads it as one 8h vector into `consts`.
// NOTE(review): 3329 is presumably the Kyber modulus and 20159 the matching
// Barrett constant; the remaining six lanes are zero -- confirm.
.p2align 4
const_addr: .short 3329
.short 20159
.short 0
.short 0
.short 0
.short 0
.short 0
.short 0
ntt_kyber_123_4567:
_ntt_kyber_123_4567:
push_stack

in .req x0
inp .req x1
in_orig .req x1
count .req x2
r_ptr0 .req x3
r_ptr1 .req x4
Expand Down Expand Up @@ -318,13 +262,35 @@ _ntt_kyber_123_4567:
t2 .req v27
t3 .req v28

// Twiddle table: placed in .data, aligned to 2^4 = 16 bytes, and pulled in
// from a separate file under the label `roots`.
.data
.p2align 4
roots:
#include "ntt_kyber_123_45_67_twiddles.s"
.text

// Export the entry point under both the plain and the underscore-prefixed
// name (the latter presumably for object formats that prefix C symbols with
// '_', e.g. Mach-O -- confirm).
.global ntt_kyber_123_4567
.global _ntt_kyber_123_4567

// 16-byte aligned block of eight halfwords living in .text; the function
// body loads it as one 8h vector into `consts`.
// NOTE(review): 3329 is presumably the Kyber modulus and 20159 the matching
// Barrett constant; the remaining six lanes are zero -- confirm.
.p2align 4
const_addr: .short 3329
.short 20159
.short 0
.short 0
.short 0
.short 0
.short 0
.short 0
ntt_kyber_123_4567:
_ntt_kyber_123_4567:
push_stack

ASM_LOAD(r_ptr0, roots)
ASM_LOAD(r_ptr1, roots_l56)

ASM_LOAD(xtmp, const_addr)
ld1 {consts.8h}, [xtmp]

save STACK0, in
mov in_orig, in
mov count, #4

load_roots_123
Expand Down Expand Up @@ -368,15 +334,15 @@ layer123_start:
subs count, count, #1
cbnz count, layer123_start

restore inp, STACK0
mov in, in_orig
mov count, #8

.p2align 2
layer4567_start:
ldr_vo data0, inp, (16*0)
ldr_vo data1, inp, (16*1)
ldr_vo data2, inp, (16*2)
ldr_vo data3, inp, (16*3)
ldr_vo data0, in, (16*0)
ldr_vo data1, in, (16*1)
ldr_vo data2, in, (16*2)
ldr_vo data3, in, (16*3)

load_next_roots_45

Expand All @@ -397,7 +363,7 @@ layer4567_start:
barrett_reduce data1
barrett_reduce data2
barrett_reduce data3
st4 {data0.4S, data1.4S, data2.4S, data3.4S}, [inp], #64
st4 {data0.4S, data1.4S, data2.4S, data3.4S}, [in], #64

subs count, count, #1
cbnz count, layer4567_start
Expand Down
Loading

0 comments on commit 4c727f5

Please sign in to comment.