Skip to content

Commit

Permalink
Merge pull request #33 from slothy-optimizer/tutorial
Browse files Browse the repository at this point in the history
Thanks @mkannwischer !
  • Loading branch information
hanno-becker authored Mar 20, 2024
2 parents 4c727f5 + 3ced776 commit 2ce4fa5
Show file tree
Hide file tree
Showing 26 changed files with 7,866 additions and 9 deletions.
15 changes: 15 additions & 0 deletions .github/workflows/test_basic.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,21 @@ jobs:
- name: Run examples
run: |
python3 example.py --dry-run
# CI job that runs the tutorial scripts. It only runs when the triggering
# event carries the 'needs-ci' label or the pull request was opened by one of
# the logins listed in the `if` expression. The tutorial script is launched
# from a subshell that first changes into the tutorial/ directory.
tutorial:
if: ${{ github.event.label.name == 'needs-ci' ||
github.event.pull_request.user.login == 'hanno-becker' ||
github.event.pull_request.user.login == 'dop-amin' ||
github.event.pull_request.user.login == 'mkannwischer'
}}
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Install python dependencies
run: |
python -m pip install -r requirements.txt
- name: Run tutorial
run: |
(cd tutorial && ./tutorial_all.sh)
examples_basic:
if: ${{ github.event.label.name == 'needs-ci' ||
github.event.pull_request.user.login == 'hanno-becker' ||
Expand Down
63 changes: 63 additions & 0 deletions example.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,62 @@ def core(self, slothy):
slothy.config.sw_pipelining.halving_heuristic_periodic = True
slothy.optimize_loop("layer345_loop")

class AArch64Example0(Example):
    """Example wrapper for the ``aarch64_simple0`` input.

    The input file stem is ``aarch64_simple0`` with ``_<var>`` appended when
    ``var`` is non-empty; the example name is that stem followed by the
    label registered for ``target`` in ``target_label_dict``.
    """

    def __init__(self, var="", arch=AArch64_Neon, target=Target_CortexA55):
        stem = "aarch64_simple0"
        if var != "":
            stem += f"_{var}"
        label = target_label_dict[target]

        super().__init__(stem, stem + f"_{label}", rename=True, arch=arch, target=target)

    def core(self, slothy):
        """Configure ``slothy`` and optimize the whole input."""
        cfg = slothy.config
        cfg.variable_size = True
        cfg.constraints.stalls_first_attempt = 32
        slothy.optimize()

class AArch64Example1(Example):
    """Example wrapper for the ``aarch64_simple0_macros`` input.

    The input file stem is ``aarch64_simple0_macros`` with ``_<var>`` appended
    when ``var`` is non-empty; the example name is that stem followed by the
    label registered for ``target`` in ``target_label_dict``.
    """

    def __init__(self, var="", arch=AArch64_Neon, target=Target_CortexA55):
        stem = "aarch64_simple0_macros"
        if var != "":
            stem += f"_{var}"
        label = target_label_dict[target]

        super().__init__(stem, stem + f"_{label}", rename=True, arch=arch, target=target)

    def core(self, slothy):
        """Configure ``slothy`` and optimize the code between the ``start`` and ``end`` labels."""
        cfg = slothy.config
        cfg.variable_size = True
        cfg.constraints.stalls_first_attempt = 32
        slothy.optimize(start="start", end="end")


class AArch64Example2(Example):
    """Example wrapper for the ``aarch64_simple0_loop`` input.

    The input file stem is ``aarch64_simple0_loop`` with ``_<var>`` appended
    when ``var`` is non-empty; the example name is that stem followed by the
    label registered for ``target`` in ``target_label_dict``.
    """

    def __init__(self, var="", arch=AArch64_Neon, target=Target_CortexA55):
        stem = "aarch64_simple0_loop"
        if var != "":
            stem += f"_{var}"
        label = target_label_dict[target]

        super().__init__(stem, stem + f"_{label}", rename=True, arch=arch, target=target)

    def core(self, slothy):
        """Enable software pipelining and optimize the loop labelled ``start``.

        Preamble and postamble optimization are both switched off.
        """
        cfg = slothy.config
        cfg.variable_size = True
        cfg.constraints.stalls_first_attempt = 32

        pipelining = cfg.sw_pipelining
        pipelining.enabled = True
        pipelining.optimize_preamble = False
        pipelining.optimize_postamble = False

        slothy.optimize_loop("start")



class ntt_kyber_123_4567(Example):
def __init__(self, var="", arch=AArch64_Neon, target=Target_CortexA55, timeout=None):
Expand Down Expand Up @@ -1197,6 +1253,13 @@ def main():
Example2(),
Example3(),

AArch64Example0(),
AArch64Example0(target=Target_CortexA72),
AArch64Example1(),
AArch64Example1(target=Target_CortexA72),
AArch64Example2(),
AArch64Example2(target=Target_CortexA72),

CRT(),

ntt_n256_l6_s32("bar"),
Expand Down
24 changes: 24 additions & 0 deletions examples/naive/aarch64/aarch64_simple0.s
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Straight-line kernel: loads four 16-byte vectors from [x0], combines them
// pairwise as (v8, v9) and (v10, v11), and stores them back. x0 is advanced
// by 64 bytes by the first store.
// NOTE(review): presumably a Barrett-style modular multiply feeding an
// NTT-like butterfly - confirm against the algorithm description.
// NOTE(review): v8-v11 are callee-saved (low 64 bits) under AAPCS64 and are
// not preserved here - confirm before embedding this in a callable function.

// Lane constants: v0 <- [x1], v1 <- [x2].
ldr q0, [x1, #0]
ldr q1, [x2, #0]

// Four consecutive 16-byte vectors starting at x0.
ldr q8, [x0]
ldr q9, [x0, #1*16]
ldr q10, [x0, #2*16]
ldr q11, [x0, #3*16]

// v24 = v9 * v0.h[0] - sqrdmulh(v9, v0.h[1]) * v1.h[0]  (v9 is overwritten
// by the sqrdmulh result), then (v8, v9) <- (v8 + v24, v8 - v24).
// The sub comes first because the add overwrites v8.
mul v24.8h, v9.8h, v0.h[0]
sqrdmulh v9.8h, v9.8h, v0.h[1]
mls v24.8h, v9.8h, v1.h[0]
sub v9.8h, v8.8h, v24.8h
add v8.8h, v8.8h, v24.8h

// Same computation for the pair (v10, v11), reusing v24 as scratch.
mul v24.8h, v11.8h, v0.h[0]
sqrdmulh v11.8h, v11.8h, v0.h[1]
mls v24.8h, v11.8h, v1.h[0]
sub v11.8h, v10.8h, v24.8h
add v10.8h, v10.8h, v24.8h

// The first store post-increments x0 by 64; the remaining stores therefore
// use negative offsets to reach the other three original slots.
str q8, [x0], #4*16
str q9, [x0, #-3*16]
str q10, [x0, #-2*16]
str q11, [x0, #-1*16]
55 changes: 55 additions & 0 deletions examples/naive/aarch64/aarch64_simple0_loop.s
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Register aliases. The q-names are the 128-bit load/store views of the same
// registers that the v-names expose for vector arithmetic.
qdata0 .req q8
qdata1 .req q9
qdata2 .req q10
qdata3 .req q11

// Load targets for the two constant vectors.
qtwiddle .req q0
qmodulus .req q1

// Vector views of the four data registers (same registers as qdata0-3).
data0 .req v8
data1 .req v9
data2 .req v10
data3 .req v11

// Vector views of the constants; the macros select individual 16-bit lanes.
twiddle .req v0
modulus .req v1

// Scratch register handed to the butterfly macro as its tmp operand.
tmp .req v12

// Pointer registers. NOTE(review): x2 is also aliased as `count` further
// down in this file.
data_ptr .req x0
twiddle_ptr .req x1
modulus_ptr .req x2

// barmul out, in, twiddle, modulus
//   Per 16-bit lane:
//     out = in * twiddle.h[0] - sqrdmulh(in, twiddle.h[1]) * modulus.h[0]
//   Clobbers \in, which is overwritten with the sqrdmulh result.
//   NOTE(review): presumably a Barrett-style modular multiplication where
//   twiddle.h[1] is a precomputed companion of twiddle.h[0] - confirm.
.macro barmul out, in, twiddle, modulus
mul \out.8h, \in.8h, \twiddle.h[0]
sqrdmulh \in.8h, \in.8h, \twiddle.h[1]
mls \out.8h, \in.8h, \modulus.h[0]
.endm

// butterfly data0, data1, tmp, twiddle, modulus
//   tmp   = barmul(data1)      (data1 is clobbered by barmul)
//   data1 = data0 - tmp
//   data0 = data0 + tmp
//   The sub is issued before the add because the add overwrites data0.
.macro butterfly data0, data1, tmp, twiddle, modulus
barmul \tmp, \data1, \twiddle, \modulus
sub \data1.8h, \data0.8h, \tmp.8h
add \data0.8h, \data0.8h, \tmp.8h
.endm

// x2 is shared: it holds modulus_ptr until the modulus vector has been
// loaded and is then reused as the loop counter. The `mov count` must
// therefore stay after the `ldr qmodulus`.
count .req x2
ldr qtwiddle, [twiddle_ptr, #0]
ldr qmodulus, [modulus_ptr, #0]
mov count, #16
start:

// Each iteration processes four consecutive 16-byte vectors at data_ptr.
ldr qdata0, [data_ptr, #0*16]
ldr qdata1, [data_ptr, #1*16]
ldr qdata2, [data_ptr, #2*16]
ldr qdata3, [data_ptr, #3*16]

butterfly data0, data1, tmp, twiddle, modulus
butterfly data2, data3, tmp, twiddle, modulus

// The first store post-increments data_ptr by 64; the remaining stores use
// negative offsets to reach the other three slots of this iteration.
str qdata0, [data_ptr], #4*16
str qdata1, [data_ptr, #-3*16]
str qdata2, [data_ptr, #-2*16]
str qdata3, [data_ptr, #-1*16]

// cbnz tests the register itself, so the non-flag-setting sub is sufficient
// and leaves NZCV untouched.
sub count, count, #1
cbnz count, start
55 changes: 55 additions & 0 deletions examples/naive/aarch64/aarch64_simple0_macros.s
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Register aliases. The q-names are the 128-bit load/store views of the same
// registers that the v-names expose for vector arithmetic.
qdata0 .req q8
qdata1 .req q9
qdata2 .req q10
qdata3 .req q11

// Load targets for the two constant vectors.
qtwiddle .req q0
qmodulus .req q1

// Vector views of the four data registers (same registers as qdata0-3).
data0 .req v8
data1 .req v9
data2 .req v10
data3 .req v11

// Vector views of the constants; the macros select individual 16-bit lanes.
twiddle .req v0
modulus .req v1

// Scratch register handed to the butterfly macro as its tmp operand.
tmp .req v12

// Pointer registers. NOTE(review): x2 is also aliased as `count` further
// down in this file.
data_ptr .req x0
twiddle_ptr .req x1
modulus_ptr .req x2

// barmul out, in, twiddle, modulus
//   Per 16-bit lane:
//     out = in * twiddle.h[0] - sqrdmulh(in, twiddle.h[1]) * modulus.h[0]
//   Clobbers \in, which is overwritten with the sqrdmulh result.
//   NOTE(review): presumably a Barrett-style modular multiplication where
//   twiddle.h[1] is a precomputed companion of twiddle.h[0] - confirm.
.macro barmul out, in, twiddle, modulus
mul \out.8h, \in.8h, \twiddle.h[0]
sqrdmulh \in.8h, \in.8h, \twiddle.h[1]
mls \out.8h, \in.8h, \modulus.h[0]
.endm

// butterfly data0, data1, tmp, twiddle, modulus
//   tmp   = barmul(data1)      (data1 is clobbered by barmul)
//   data1 = data0 - tmp
//   data0 = data0 + tmp
//   The sub is issued before the add because the add overwrites data0.
.macro butterfly data0, data1, tmp, twiddle, modulus
barmul \tmp, \data1, \twiddle, \modulus
sub \data1.8h, \data0.8h, \tmp.8h
add \data0.8h, \data0.8h, \tmp.8h
.endm

// NOTE(review): `count` is declared here but never referenced in this file;
// it names x2, the same register as modulus_ptr.
count .req x2

// The `start` and `end` labels bracket the straight-line region; example.py's
// AArch64Example1 passes these two label names to slothy.optimize().
start:

// Constant vectors, loaded through the pointer aliases.
ldr qtwiddle, [twiddle_ptr, #0]
ldr qmodulus, [modulus_ptr, #0]

// Four consecutive 16-byte vectors at data_ptr.
ldr qdata0, [data_ptr, #0*16]
ldr qdata1, [data_ptr, #1*16]
ldr qdata2, [data_ptr, #2*16]
ldr qdata3, [data_ptr, #3*16]

butterfly data0, data1, tmp, twiddle, modulus
butterfly data2, data3, tmp, twiddle, modulus

// The first store post-increments data_ptr by 64; the remaining stores use
// negative offsets to reach the other three slots.
str qdata0, [data_ptr], #4*16
str qdata1, [data_ptr, #-3*16]
str qdata2, [data_ptr, #-2*16]
str qdata3, [data_ptr, #-1*16]

end:
125 changes: 125 additions & 0 deletions examples/opt/aarch64/aarch64_simple0_loop_opt_a55.s
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
// Register aliases. The q-names are the 128-bit load/store views of the same
// registers that the v-names expose for vector arithmetic.
qdata0 .req q8
qdata1 .req q9
qdata2 .req q10
qdata3 .req q11

// Load targets for the two constant vectors.
qtwiddle .req q0
qmodulus .req q1

// Vector views of the four data registers (same registers as qdata0-3).
data0 .req v8
data1 .req v9
data2 .req v10
data3 .req v11

// Vector views of the constants.
twiddle .req v0
modulus .req v1

// Scratch register alias (v12).
tmp .req v12

// Pointer registers. NOTE(review): x2 is also aliased as `count` further
// down in this file.
data_ptr .req x0
twiddle_ptr .req x1
modulus_ptr .req x2

// barmul out, in, twiddle, modulus
//   Per 16-bit lane:
//     out = in * twiddle.h[0] - sqrdmulh(in, twiddle.h[1]) * modulus.h[0]
//   Clobbers \in, which is overwritten with the sqrdmulh result.
//   The instruction stream of this file does not invoke the macro; the
//   equivalent instructions appear in expanded form.
.macro barmul out, in, twiddle, modulus
mul \out.8h, \in.8h, \twiddle.h[0]
sqrdmulh \in.8h, \in.8h, \twiddle.h[1]
mls \out.8h, \in.8h, \modulus.h[0]
.endm

// butterfly data0, data1, tmp, twiddle, modulus
//   tmp   = barmul(data1)      (data1 is clobbered by barmul)
//   data1 = data0 - tmp
//   data0 = data0 + tmp
//   The sub is issued before the add because the add overwrites data0.
//   The instruction stream of this file does not invoke the macro; the
//   equivalent instructions appear in expanded form.
.macro butterfly data0, data1, tmp, twiddle, modulus
barmul \tmp, \data1, \twiddle, \modulus
sub \data1.8h, \data0.8h, \tmp.8h
add \data0.8h, \data0.8h, \tmp.8h
.endm

// x2 is shared: it holds modulus_ptr until the modulus vector has been
// loaded and is then reused as the loop counter.
count .req x2
ldr qtwiddle, [twiddle_ptr, #0]
ldr qmodulus, [modulus_ptr, #0]
mov count, #16
// Preamble: the load from offset #16 and the sqrdmulh, which the loop body
// performs for the *next* iteration (marked `e` there), are executed here
// once for the first iteration.
ldr q3, [x0, #16]
sqrdmulh v7.8H, v3.8H, v0.H[1]
// The final iteration is completed by the code after the loop, so the loop
// itself runs one time fewer than the original trip count.
sub count, count, #1
// Loop body with interleaved iterations. On entry v3 holds the vector from
// offset #16 of the current block and v7 its sqrdmulh by v0.h[1]; both are
// produced by the previous pass (or by the code preceding the loop).
// Trailing annotations: the marker column appears to be the instruction's
// index in the "original source code" listing below; `*` marks an instruction
// of the current iteration and `e` one executed early for the next iteration.
start:
mul v3.8H, v3.8H, v0.H[0] // ....*.............
// gap // ..................
ldr q19, [x0, #48] // ...*..............
// gap // ..................
// gap // ..................
// gap // ..................
ldr q15, [x0, #0] // *.................
// gap // ..................
// gap // ..................
// gap // ..................
mls v3.8H, v7.8H, v1.H[0] // ......*...........
// gap // ..................
mul v13.8H, v19.8H, v0.H[0] // .........*........
// gap // ..................
sqrdmulh v19.8H, v19.8H, v0.H[1] // ..........*.......
// gap // ..................
ldr q7, [x0, #32] // ..*...............
// gap // ..................
// gap // ..................
// gap // ..................
sub v17.8H, v15.8H, v3.8H // .......*..........
// gap // ..................
add v10.8H, v15.8H, v3.8H // ........*.........
// gap // ..................
mls v13.8H, v19.8H, v1.H[0] // ...........*......
// gap // ..................
// Issued before x0 is advanced, hence offset #16 rather than #-3*16.
str q17, [x0, #16] // ...............*..
// gap // ..................
// Early load for the next iteration: #80 = #16 + 64, because x0 has not yet
// been post-incremented at this point.
ldr q3, [x0, #80] // .e................
// gap // ..................
// gap // ..................
// gap // ..................
add v15.8H, v7.8H, v13.8H // .............*....
// gap // ..................
// Post-increment: x0 advances by 64 here; later stores use negative offsets.
str q10, [x0], #4*16 // ..............*...
// gap // ..................
sub v13.8H, v7.8H, v13.8H // ............*.....
// gap // ..................
str q15, [x0, #-32] // ................*.
// gap // ..................
sqrdmulh v7.8H, v3.8H, v0.H[1] // .....e............
// gap // ..................
str q13, [x0, #-16] // .................*
// gap // ..................

// original source code
// ldr q8, [x0, #0*16] // .......|.*...............
// ldr q9, [x0, #1*16] // e......|..........e......
// ldr q10, [x0, #2*16] // .......|.....*...........
// ldr q11, [x0, #3*16] // .......|*................
// mul v12.8h, v9.8h, v0.h[0] // .......*.................
// sqrdmulh v9.8h, v9.8h, v0.h[1] // .....e.|...............e.
// mls v12.8h, v9.8h, v1.h[0] // .......|..*..............
// sub v9.8h, v8.8h, v12.8h // .......|......*..........
// add v8.8h, v8.8h, v12.8h // .......|.......*.........
// mul v12.8h, v11.8h, v0.h[0] // .......|...*.............
// sqrdmulh v11.8h, v11.8h, v0.h[1] // .......|....*............
// mls v12.8h, v11.8h, v1.h[0] // .......|........*........
// sub v11.8h, v10.8h, v12.8h // ...*...|.............*...
// add v10.8h, v10.8h, v12.8h // .*.....|...........*.....
// str q8, [x0], #4*16 // ..*....|............*....
// str q9, [x0, #-3*16] // .......|.........*.......
// str q10, [x0, #-2*16] // ....*..|..............*..
// str q11, [x0, #-1*16] // ......*|................*

sub count, count, #1
cbnz count, start
// Postamble: completes the final iteration. This is the loop body's
// instruction sequence without the two instructions the loop executes early
// for a following iteration (the load from offset #80 and the sqrdmulh into
// v7); v3 and v7 are consumed as left by the last loop pass.
mul v3.8H, v3.8H, v0.H[0]
ldr q19, [x0, #48]
ldr q15, [x0, #0]
mls v3.8H, v7.8H, v1.H[0]
mul v13.8H, v19.8H, v0.H[0]
sqrdmulh v19.8H, v19.8H, v0.H[1]
ldr q7, [x0, #32]
sub v17.8H, v15.8H, v3.8H
add v10.8H, v15.8H, v3.8H
mls v13.8H, v19.8H, v1.H[0]
str q17, [x0, #16]
add v15.8H, v7.8H, v13.8H
// Post-increment: x0 advances by 64 here; the last two stores use negative offsets.
str q10, [x0], #4*16
sub v13.8H, v7.8H, v13.8H
str q15, [x0, #-32]
str q13, [x0, #-16]
Loading

0 comments on commit 2ce4fa5

Please sign in to comment.