Skip to content

Commit

Permalink
Merge pull request #156 from dop-amin/simplify-naive-ex
Browse files Browse the repository at this point in the history
Simplify Naive Examples
  • Loading branch information
mkannwischer authored Jan 13, 2025
2 parents 09a9d3b + 5b47a4b commit 9ff2713
Show file tree
Hide file tree
Showing 11 changed files with 2,651 additions and 2,327 deletions.
27 changes: 10 additions & 17 deletions example.py
Original file line number Diff line number Diff line change
Expand Up @@ -1605,7 +1605,7 @@ def __init__(self, var="", arch=Arch_Armv7M, target=Target_CortexM7, timeout=Non
def core(self, slothy):
slothy.config.constraints.stalls_first_attempt = 16

slothy.config.unsafe_address_offset_fixup = False
slothy.config.unsafe_address_offset_fixup = True


slothy.config.variable_size = True
Expand All @@ -1616,12 +1616,12 @@ def core(self, slothy):
slothy.config.sw_pipelining.optimize_postamble = True
slothy.config.sw_pipelining.allow_pre = True

slothy.optimize_loop("layer123_loop")
slothy.optimize_loop("layer123_loop", forced_loop_type=Arch_Armv7M.BranchLoop)
slothy.optimize_loop("layer456_first_loop")
slothy.optimize_loop("layer456_loop")

slothy.config.inputs_are_outputs = True
slothy.optimize_loop("layer78_loop")
slothy.optimize_loop("layer78_loop", forced_loop_type=Arch_Armv7M.BranchLoop)

class pointwise_montgomery_dilithium(Example):
def __init__(self, var="", arch=Arch_Armv7M, target=Target_CortexM7, timeout=None):
Expand Down Expand Up @@ -1814,6 +1814,7 @@ def core(self, slothy):
slothy.config.constraints.stalls_first_attempt = 32

r = slothy.config.reserved_regs
r.add("r1")
r = r.union(f"s{i}" for i in range(31)) # reserve FPR
slothy.config.reserved_regs = r

Expand All @@ -1825,13 +1826,12 @@ def core(self, slothy):
slothy.config.variable_size = True
slothy.config.split_heuristic = True
slothy.config.timeout = 360 # Not more than 2min per step
slothy.config.split_heuristic_factor = 1
slothy.config.visualize_expected_performance = False
slothy.config.split_heuristic_factor = 4
slothy.config.split_heuristic_factor = 5
slothy.config.split_heuristic_stepsize = 0.15
slothy.optimize_loop("layer1234_loop")
slothy.optimize_loop("layer1234_loop", forced_loop_type=Arch_Armv7M.BranchLoop)
slothy.config.split_heuristic_optimize_seam = 6
slothy.optimize_loop("layer1234_loop")
slothy.optimize_loop("layer1234_loop", forced_loop_type=Arch_Armv7M.BranchLoop)

slothy.config.outputs = ["r14"]

Expand Down Expand Up @@ -2179,12 +2179,11 @@ def core(self, slothy):
slothy.config.variable_size = True

r = slothy.config.reserved_regs
r.add("r14")
slothy.config.reserved_regs = r

slothy.config.sw_pipelining.enabled = True
slothy.config.constraints.stalls_first_attempt = 16
slothy.optimize_loop("1")
slothy.optimize_loop("1", forced_loop_type=Arch_Armv7M.BranchLoop)

class basemul_acc_32_16_kyber(Example):
def __init__(self, var="", arch=Arch_Armv7M, target=Target_CortexM7, timeout=None):
Expand Down Expand Up @@ -2278,14 +2277,10 @@ def core(self, slothy):
slothy.config.inputs_are_outputs = True
slothy.config.variable_size = True

r = slothy.config.reserved_regs
r.add("r14")
slothy.config.reserved_regs = r

slothy.config.unsafe_address_offset_fixup = False
slothy.config.sw_pipelining.enabled = True
slothy.config.constraints.stalls_first_attempt = 16
slothy.optimize_loop("1")
slothy.optimize_loop("1", forced_loop_type=Arch_Armv7M.BranchLoop)

class add_kyber(Example):
def __init__(self, var="", arch=Arch_Armv7M, target=Target_CortexM7, timeout=None):
Expand Down Expand Up @@ -2484,16 +2479,14 @@ def __init__(self, var="", arch=Arch_Armv7M, target=Target_CortexM7, timeout=Non
def core(self, slothy):
slothy.config.inputs_are_outputs = True
slothy.config.variable_size = True
slothy.config.outputs = ["r14"]
slothy.config.unsafe_address_offset_fixup = False
r = slothy.config.reserved_regs
r.add("r14")
r = r.union(f"s{i}" for i in range(32)) # reserve FPR
slothy.config.reserved_regs = r

slothy.config.sw_pipelining.enabled = True
slothy.config.constraints.stalls_first_attempt = 16
slothy.optimize_loop("1")
slothy.optimize_loop("1", forced_loop_type=Arch_Armv7M.BranchLoop)

class matacc_kyber(Example):
def __init__(self, var="", arch=Arch_Armv7M, target=Target_CortexM7, timeout=None):
Expand Down
24 changes: 12 additions & 12 deletions examples/naive/armv7m/basemul_acc_32_32_kyber.s
Original file line number Diff line number Diff line change
Expand Up @@ -32,31 +32,31 @@ basemul_asm_acc_opt_32_32:

movw loop, #64
1:
ldr poly0, [aptr], #8
ldr poly1, [bptr], #8
ldr poly0, [aptr], #4
ldr poly1, [bptr], #4
ldr.w res0, [rptr_tmp]
ldr tmp2, [aprimeptr], #8
ldr tmp2, [aprimeptr], #4
ldr.w res1, [rptr_tmp, #4]

// (poly0_t * zeta) * poly1_t + poly0_b * poly0_t + res
smlad tmp2, tmp2, poly1, res0
str tmp2, [rptr_tmp], #16
str tmp2, [rptr_tmp], #4

// poly1_t * poly0_b + poly1_b * poly0_t + res
smladx tmp, poly0, poly1, res1
str tmp, [rptr_tmp, #-12]
str tmp, [rptr_tmp], #4

ldr poly0, [aptr, #-4]
ldr poly1, [bptr, #-4]
ldr res0, [rptr_tmp, #-8]
ldr tmp2, [aprimeptr, #-4]
ldr res1, [rptr_tmp, #-4]
ldr poly0, [aptr], #4
ldr poly1, [bptr], #4
ldr.w res0, [rptr_tmp]
ldr tmp2, [aprimeptr], #4
ldr.w res1, [rptr_tmp, #4]

smlad tmp2, tmp2, poly1, res0
str tmp2, [rptr_tmp, #-8]
str tmp2, [rptr_tmp], #4

smladx tmp, poly0, poly1, res1
str tmp, [rptr_tmp, #-4]
str tmp, [rptr_tmp], #4

subs.w loop, loop, #1
bne.w 1b
Expand Down
37 changes: 11 additions & 26 deletions examples/naive/armv7m/frombytes_mul_acc_32_16_kyber.s
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

.macro doublebasemul_frombytes_asm_acc_32_16 rptr_tmp, rptr, bptr, zeta, poly0, poly1, poly3, res0, tmp, q, qa, qinv
ldr \poly0, [\bptr], #8
ldr \res0, [\rptr_tmp], #16 // @slothy:core=True
ldr \res0, [\rptr_tmp], #16 // @slothy:core=True // @slothy:before=cmp

smulwt \tmp, \zeta, \poly1
smlabt \tmp, \tmp, \q, \qa
Expand Down Expand Up @@ -72,7 +72,7 @@ frombytes_mul_asm_acc_32_16:
push {r4-r11, r14}

rptr .req r0
bptr .req r1
bptr .req r3
aptr .req r2
zetaptr .req r3
t0 .req r4
Expand All @@ -85,43 +85,28 @@ frombytes_mul_asm_acc_32_16:
qinv .req r11
zeta .req r12
ctr .req r14
rptr_tmp .req r3
rptr_tmp .req r1

movw qa, #26632
movt q, #3329
### qinv=0x6ba8f301
movw qinv, #62209
movt qinv, #27560

vmov s2, zetaptr
vmov s1, r1
ldr.w rptr_tmp, [sp, #9*4] // load rptr_tmp from stack
vmov s1, rptr_tmp

add ctr, rptr_tmp, #64*4*4
1:
ldr.w zeta, [zetaptr], #4
deserialize aptr, tmp, tmp2, tmp3, t0, t1
vmov tmp, s2
ldr zeta, [tmp], #4
vmov s2, tmp
vmov s2, zetaptr
vmov bptr, s1
doublebasemul_frombytes_asm_acc_32_16 rptr_tmp, rptr, bptr, zeta, tmp3, t0, t1, tmp, tmp2, q, qa, qinv
cmp.w rptr_tmp, ctr
vmov s1, bptr // @slothy:core=True
cmp.w rptr_tmp, ctr // @slothy:id=cmp
vmov zetaptr, s2
bne.w 1b

// Original code
// ldr.w tmp, [sp, #9*4] // load rptr_tmp from stack
// vmov s1, tmp
// vmov s2, zetaptr
// add ctr, tmp, #64*4*4
// 1:
// vmov zetaptr, s2
// ldr.w zeta, [zetaptr], #4
// deserialize aptr, tmp, tmp2, tmp3, t0, t1
// vmov s2, zetaptr
// vmov rptr_tmp, s1
// doublebasemul_frombytes_asm_acc_32_16 rptr_tmp, rptr, bptr, zeta, tmp3, t0, t1, tmp, tmp2, q, qa, qinv
// vmov s1, rptr_tmp
// cmp.w rptr_tmp, ctr
// bne.w 1b

pop {r4-r11, pc}

.size frombytes_mul_asm_acc_32_16, .-frombytes_mul_asm_acc_32_16
6 changes: 3 additions & 3 deletions examples/naive/armv7m/frombytes_mul_acc_kyber.s
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
// r[1] in upper half of tmp2
pkhtb \tmp, \tmp2, \tmp, asr #16
uadd16 \res0, \res0, \tmp
str \res0, [\rptr], #8 // @slothy:core=True
str \res0, [\rptr], #8 // @slothy:core=True // @slothy:before=cmp

neg \zeta, \zeta

Expand Down Expand Up @@ -101,13 +101,13 @@ frombytes_mul_asm_acc:
movt qinv, #27560

add ctr, rptr, #64*4*2
vmov s0, ctr
1:
ldr.w zeta, [zetaptr], #4
deserialize aptr, tmp, tmp2, tmp3, t0, t1
vmov s0, ctr
doublebasemul_frombytes_asm_acc rptr, bptr, zeta, tmp3, t0, t1, ctr, tmp, tmp2, q, qa, qinv
vmov ctr, s0
cmp.w rptr, ctr
cmp.w rptr, ctr // @slothy:id=cmp
bne.w 1b

pop {r4-r11, pc}
Expand Down
68 changes: 34 additions & 34 deletions examples/naive/armv7m/intt_dilithium_123_456_78.s
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,9 @@ pqcrystals_dilithium_invntt_tomont:
str.w pol5, [ptr_p, #5*distance/4]
str.w pol6, [ptr_p, #6*distance/4]
str.w pol7, [ptr_p, #7*distance/4]
str.w pol0, [ptr_p], #strincr
str.w pol0, [ptr_p], #strincr // @slothy:before=cmp
vmov temp_l, s9
cmp.w ptr_p, temp_l
cmp.w ptr_p, temp_l // @slothy:id=cmp
bne.w layer123_loop

sub ptr_p, #32*strincr
Expand All @@ -248,21 +248,21 @@ pqcrystals_dilithium_invntt_tomont:
ldr.w pol3, [ptr_p, #7*distance2/4]
_3_layer_inv_butterfly_light_fast_first pol0, pol1, pol2, pol3, pol4, pol5, pol6, pol7, s2, s3, s4, s5, s6, s7, s8, zeta, qinv, q, temp_h, temp_l

ldr.w pol0, [ptr_p], #128
ldr pol1, [ptr_p, #1*distance2/4-128]
ldr pol2, [ptr_p, #2*distance2/4-128]
ldr pol3, [ptr_p, #3*distance2/4-128]
ldr.w pol0, [ptr_p]
ldr pol1, [ptr_p, #1*distance2/4]
ldr pol2, [ptr_p, #2*distance2/4]
ldr pol3, [ptr_p, #3*distance2/4]
_3_layer_inv_butterfly_light_fast_second pol0, pol1, pol2, pol3, pol4, pol5, pol6, pol7, s2, s3, s4, s5, s6, s7, s8, zeta, qinv, q, temp_h, temp_l

str pol1, [ptr_p, #1*distance2/4-128]
str pol2, [ptr_p, #2*distance2/4-128]
str pol3, [ptr_p, #3*distance2/4-128]
str.w pol5, [ptr_p, #5*distance2/4-128]
str.w pol6, [ptr_p, #6*distance2/4-128]
str.w pol7, [ptr_p, #7*distance2/4-128]
str pol0, [ptr_p, #-128]
str.w pol4, [ptr_p], #128
//add.w ptr_p, #strincr2
str pol1, [ptr_p, #1*distance2/4]
str pol2, [ptr_p, #2*distance2/4]
str pol3, [ptr_p, #3*distance2/4]
str.w pol4, [ptr_p, #4*distance2/4]
str.w pol5, [ptr_p, #5*distance2/4]
str.w pol6, [ptr_p, #6*distance2/4]
str.w pol7, [ptr_p, #7*distance2/4]
str pol0, [ptr_p]
add.w ptr_p, ptr_p, #strincr2

vmov temp_l, s10
cmp.w ptr_p, temp_l
Expand All @@ -281,26 +281,26 @@ pqcrystals_dilithium_invntt_tomont:
vldm ptr_zeta!, {s2-s8}
vmov s0, ptr_zeta
layer456_loop:
ldr.w pol0, [ptr_p], #128
ldr pol1, [ptr_p, #1*distance2/4-128]
ldr pol2, [ptr_p, #2*distance2/4-128]
ldr pol3, [ptr_p, #3*distance2/4-128]
ldr.w pol4, [ptr_p, #4*distance2/4-128]
ldr.w pol5, [ptr_p, #5*distance2/4-128]
ldr.w pol6, [ptr_p, #6*distance2/4-128]
ldr.w pol7, [ptr_p, #7*distance2/4-128]
ldr.w pol0, [ptr_p]
ldr pol1, [ptr_p, #1*distance2/4]
ldr pol2, [ptr_p, #2*distance2/4]
ldr pol3, [ptr_p, #3*distance2/4]
ldr.w pol4, [ptr_p, #4*distance2/4]
ldr.w pol5, [ptr_p, #5*distance2/4]
ldr.w pol6, [ptr_p, #6*distance2/4]
ldr.w pol7, [ptr_p, #7*distance2/4]

_3_layer_inv_CT_32 pol0, pol1, pol2, pol3, pol4, pol5, pol6, pol7, s2, s3, s4, s5, s6, s7, s8, zeta, qinv, q, temp_h, temp_l

str pol1, [ptr_p, #1*distance2/4-128]
str pol2, [ptr_p, #2*distance2/4-128]
str pol3, [ptr_p, #3*distance2/4-128]
str.w pol5, [ptr_p, #5*distance2/4-128]
str.w pol6, [ptr_p, #6*distance2/4-128]
str.w pol7, [ptr_p, #7*distance2/4-128]
str pol0, [ptr_p, #-128]
str.w pol4, [ptr_p], #128
//add.w ptr_p, #strincr2
str pol1, [ptr_p, #1*distance2/4]
str pol2, [ptr_p, #2*distance2/4]
str pol3, [ptr_p, #3*distance2/4]
str.w pol4, [ptr_p, #4*distance2/4]
str.w pol5, [ptr_p, #5*distance2/4]
str.w pol6, [ptr_p, #6*distance2/4]
str.w pol7, [ptr_p, #7*distance2/4]
str pol0, [ptr_p]
add.w ptr_p, ptr_p, #strincr2

vmov temp_l, s10
cmp.w ptr_p, temp_l
Expand Down Expand Up @@ -342,10 +342,10 @@ pqcrystals_dilithium_invntt_tomont:
str.w pol1, [ptr_p, #256]
str.w pol2, [ptr_p, #512]
str.w pol3, [ptr_p, #768]
str pol0, [ptr_p], #strincr3 // @slothy:core
str pol0, [ptr_p], #strincr3 // @slothy:core // @slothy:before=cmp

vmov cntr, s9
cmp.w ptr_p, cntr
cmp.w ptr_p, cntr // @slothy:id=cmp
bne.w layer78_loop

//restore registers
Expand Down
6 changes: 2 additions & 4 deletions examples/naive/armv7m/ntt_769_dilithium.s
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ small_ntt_asm_769:
// s24: tmp
// s25: twiddle_ptr
vmov s24, tmp
vmov s25, twiddle_ptr
layer1234_loop:
// load a1, a3, ..., a15
vmov s23, poly
Expand Down Expand Up @@ -251,10 +250,10 @@ small_ntt_asm_769:
uadd16 tmp, poly0, poly1
usub16 twiddle1, poly0, poly1
str.w twiddle1, [poly, #offset]
str.w tmp, [poly], #4 // @slothy:core
str.w tmp, [poly], #4 // @slothy:core // @slothy:before=cmp

vmov tmp, s24
cmp.w poly, tmp
cmp.w poly, tmp // @slothy:id=cmp
bne.w layer1234_loop

sub.w poly, #8*strincr
Expand All @@ -266,7 +265,6 @@ small_ntt_asm_769:

add.w tmp, poly, #strincr2*16
vmov s13, tmp
vmov twiddle_ptr, s25
layer567_loop:
vmov s23, poly
load poly, poly0, poly1, poly2, poly3, #0, #distance2/4, #2*distance2/4, #3*distance2/4
Expand Down
Loading

0 comments on commit 9ff2713

Please sign in to comment.