Skip to content

Commit

Permalink
[cases] add eval.ntt*
Browse files Browse the repository at this point in the history
  • Loading branch information
SharzyL committed Nov 20, 2024
1 parent 9c51fd9 commit d136389
Show file tree
Hide file tree
Showing 10 changed files with 591 additions and 3 deletions.
2 changes: 1 addition & 1 deletion nix/t1/mill-modules.nix
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ let
./../../common.sc
];
};
millDepsHash = "sha256-XvGLNLOC7OEwfC7SB5zBdB64VjROBkwgIcHx+9FHmSs=";
millDepsHash = "sha256-W/76pfggzH4tsKdjY1JNOqp4IW7DBtgN8ibsFKAYa0g=";
nativeBuildInputs = [ dependencies.setupHook ];
};

Expand Down
36 changes: 36 additions & 0 deletions tests/eval/_ntt/default.nix
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
{ linkerScript
, makeBuilder
, t1main
}:

let
builder = makeBuilder { casePrefix = "eval"; };
build_ntt = caseName /* must be consistent with attr name */ : main_src: kernel_src:
builder {
caseName = caseName;

src = ./.;

passthru.featuresRequired = { };

buildPhase = ''
runHook preBuild
$CC -T${linkerScript} \
${main_src} ${kernel_src} \
${t1main} \
-o $pname.elf
runHook postBuild
'';

meta.description = "test case 'ntt'";
};

in {
ntt_128 = build_ntt "ntt_128" ./ntt.c ./ntt_128_main.c;
ntt_256 = build_ntt "ntt_256" ./ntt.c ./ntt_256_main.c;
ntt_512 = build_ntt "ntt_512" ./ntt.c ./ntt_512_main.c;
ntt_1024 = build_ntt "ntt_1024" ./ntt.c ./ntt_1024_main.c;
ntt_mem_1024 = build_ntt "ntt_mem_1024" ./ntt_mem.c ./ntt_1024_main.c;
}
24 changes: 24 additions & 0 deletions tests/eval/_ntt/gen_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import random

def main():
vlen = 4096
l = 10
n = 1 << l
assert n <= vlen // 4
p = 12289 # p is prime and n | p - 1
g = 11 # primitive root of p
assert (p - 1) % n == 0
w = (g ** ((p - 1) // n)) % p # now w^n == 1 mod p by Fermat's little theorem
print(w)

twindle_list = []
for _ in range(l):
twindle_list.append(w)
w = (w * w) % p
print(twindle_list)

a = [random.randrange(p) for _ in range(n)]
print(a)

if __name__ == '__main__':
main()
82 changes: 82 additions & 0 deletions tests/eval/_ntt/ntt.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#include <assert.h>
#include <stdio.h>

// array is of length n=2^l, p is a prime number
// roots is of length l, where g = roots[0] satisfies that
// g^(2^l) == 1 mod p and g^(2^(l-1)) == -1 mod p
// roots[i] = g^(2^i) (hence roots[l - 1] = -1)
//
// 32bit * n <= VLEN * 8 => n <= VLEN / 4
void ntt(const int *array, int l, const int *twindle, int p, int *dst) {
// prepare an array of permutation indices
assert(l <= 16);

int n = 1 << l;
int g = twindle[0];

// registers:
// v8-15: array
// v16-24: loaded elements (until vrgather)
// v4-7: permutation index (until vrgather)
// v16-24: coefficients
int vlenb;
asm("csrr %0, vlenb" : "=r"(vlenb));
int elements_in_vreg = vlenb * 2;
assert(elements_in_vreg >= n);

asm("vsetvli zero, %0, e16, m4, tu, mu\n"
"vid.v v4\n"
:
: "r"(n));

// prepare the permutation list
for (int k = 0; 2 * k <= l; k++) {
asm("vand.vx v8, v4, %0\n"
"vsub.vv v4, v4, v8\n"
"vsll.vx v8, v8, %1\n" // get the k-th digit and shift left

"vand.vx v12, v4, %2\n"
"vsub.vv v4, v4, v12\n"
"vsrl.vx v12, v12, %1\n" // get the (l-k)-th digit and shift right

"vor.vv v4, v4, v8\n"
"vor.vv v4, v4, v12\n"
:
: "r"(1 << k), "r"(l - 1 - 2 * k), "r"(1 << (l - k)));
}

asm("vsetvli zero, %0, e32, m8, tu, mu\n"
"vle32.v v16, 0(%1)\n"
"vrgatherei16.vv v8, v16, v4\n"

// set v16 to all 1
"vxor.vv v16, v16, v16\n"
"vadd.vi v16, v16, 1\n"
:
: "r"(n), "r"(array));

for (int k = 0; k < l; k++) {
asm( // prepare coefficients in v16-23
"vid.v v24\n" // v24-31[i] = i
"vand.vx v24, v24, %1\n" // v24-31[i] = i & (1 << k)
"vmsne.vi v0, v24, 0\n" // vm0[i] = i & (1 << k) != 0
"vmul.vx v16, v16, %2, v0.t\n" // v16-23[i] = w^(???)

// prepare shifted elements in v24-31
"vslideup.vx v24, v8, %3\n" // shift the first 2^(l-k) elements to tail
"vsetvli zero, %3, e32, m8, tu, mu\n" // last n - 2^(l-k) elements
"vslidedown.vx v24, v8, %4\n"

// mul and add
"vsetvli zero, %0, e32, m8, tu, mu\n"
"vmul.vv v24, v24, v16\n"
"vrem.vx v24, v24, %5\n"
"vadd.vv v8, v8, v24\n" // TODO: will it overflow?
:
: "r"(n), /* %1 */ "r"(1 << k), /* %2 */ "r"(twindle[l - 1 - k]),
/* %3 */ "r"(n - (1 << (l - k))),
/* %4 */ "r"(1 << (l - k)), /* %5 */ "r"(p));
}
asm("vse32.v v8, 0(%0)\n" : : "r"(dst));
}

131 changes: 131 additions & 0 deletions tests/eval/_ntt/ntt_1024_main.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
// requires VLEN >= 4096

#include <stdio.h>

void ntt(const int *array, int l, const int *twindle, int p, int *dst);

void test() {
const int l = 10;
const int n = 1024;
const int arr[1024] = {
9997, 6362, 7134, 11711, 5849, 9491, 5972, 4164, 5894, 11069,
7697, 8319, 2077, 12086, 10239, 5394, 4898, 1370, 1205, 2997,
5274, 4625, 11983, 1789, 3645, 7666, 12128, 10883, 7376, 8883,
2321, 1889, 2026, 8059, 2741, 865, 1785, 9955, 2395, 9330,
11465, 7383, 9649, 11285, 3647, 578, 1158, 9936, 12019, 11114,
7894, 4832, 10148, 10363, 11388, 9122, 10758, 2642, 4171, 10586,
1194, 5280, 3055, 9220, 10577, 9046, 1284, 7915, 10213, 6902,
3777, 9896, 429, 7730, 7429, 8666, 10887, 11255, 2437, 7782,
1327, 7010, 4009, 1038, 9466, 5352, 1473, 10067, 11753, 2019,
8472, 7665, 2679, 5070, 2248, 3044, 10301, 10671, 2092, 1069,
9032, 9131, 11715, 6662, 3423, 10027, 5436, 4259, 999, 3316,
11164, 5597, 6578, 800, 8242, 6952, 2288, 1481, 6770, 11948,
8938, 10813, 11107, 1362, 4510, 9388, 8840, 10557, 6206, 7808,
7131, 1394, 2604, 1509, 689, 5222, 8867, 9934, 7165, 6099,
3229, 1263, 4414, 12212, 4963, 9236, 9040, 6062, 11163, 8169,
4575, 6097, 3006, 1, 1384, 12039, 5445, 11355, 12197, 9182,
10085, 9295, 8890, 10651, 1540, 9061, 10222, 2524, 2213, 6974,
2066, 7348, 7444, 173, 7529, 3884, 3531, 4312, 640, 5352,
5880, 3985, 781, 10165, 1106, 8114, 6043, 8202, 10617, 3060,
11173, 11521, 6933, 9540, 11782, 2284, 6462, 3740, 2581, 126,
508, 12165, 4956, 8045, 9379, 5250, 8148, 6539, 4891, 11252,
5041, 9969, 8524, 9892, 4058, 10580, 10025, 9748, 8829, 4438,
468, 4773, 1657, 1348, 10055, 7192, 9556, 5919, 5690, 6153,
6270, 4938, 6206, 1003, 596, 11173, 9858, 4825, 7940, 794,
7477, 10146, 7203, 4729, 5741, 4603, 1806, 7034, 8772, 10435,
10777, 1359, 630, 11059, 8005, 225, 10355, 9226, 4449, 11236,
680, 8615, 6828, 5502, 10082, 5491, 4346, 7831, 5429, 1253,
6662, 9415, 584, 9362, 8452, 1937, 3271, 6852, 6573, 7706,
1229, 8535, 3786, 6441, 7230, 533, 5778, 6436, 11728, 7896,
785, 7591, 9061, 6149, 10403, 9079, 10837, 9776, 7850, 7870,
5008, 5319, 541, 315, 9973, 5055, 7111, 8399, 614, 10495,
9441, 10946, 449, 6965, 7980, 11475, 9321, 2256, 8998, 4321,
11269, 4744, 5021, 11981, 7947, 7695, 4000, 1140, 2895, 3419,
159, 5370, 10899, 3288, 12007, 8894, 7923, 7366, 11534, 5214,
10461, 11199, 10965, 3739, 5507, 8882, 10725, 9649, 1144, 9153,
5573, 878, 11115, 5677, 5970, 7221, 8614, 4703, 9394, 11660,
8423, 6621, 11112, 10945, 527, 5019, 5396, 10049, 6770, 3406,
2967, 3890, 2441, 4682, 6026, 617, 7316, 2627, 4456, 8925,
2388, 11354, 4554, 10543, 2610, 10688, 1150, 2556, 4278, 431,
9260, 3545, 12215, 631, 4407, 8145, 1403, 8523, 1982, 12073,
950, 7671, 31, 1299, 9003, 11690, 5637, 6761, 5235, 5722,
11858, 2210, 7870, 11608, 8884, 8550, 4776, 4998, 4270, 8850,
12111, 240, 5674, 3845, 5057, 1608, 48, 2760, 8612, 278,
5633, 9505, 3730, 1971, 8637, 8659, 894, 8594, 4221, 6783,
5664, 9506, 2811, 11058, 4475, 2912, 2289, 2136, 7899, 6065,
5259, 2230, 6793, 4280, 3140, 1721, 8333, 11216, 5383, 7139,
10711, 1017, 2001, 10911, 1750, 162, 11775, 10575, 1646, 8322,
175, 10156, 3635, 4893, 2207, 3234, 4380, 1900, 5493, 3082,
10058, 9948, 10752, 7044, 10073, 11210, 8362, 9268, 8694, 1438,
761, 10180, 6570, 6349, 9028, 10495, 4756, 9332, 8348, 4995,
6933, 4351, 111, 1610, 7410, 960, 11972, 2853, 3551, 1423,
9073, 7328, 7803, 7591, 3547, 964, 7327, 7357, 3352, 9415,
7393, 5739, 11960, 4303, 2250, 4026, 9362, 2004, 853, 10393,
4433, 3021, 7803, 2610, 3780, 8299, 1970, 11031, 10118, 308,
3432, 11166, 9976, 569, 1344, 7369, 12097, 1005, 2415, 7435,
2685, 5458, 10746, 392, 426, 1015, 9258, 1151, 4957, 4200,
12077, 2777, 308, 717, 12162, 7328, 2534, 4327, 10539, 11256,
7448, 10860, 7970, 11475, 6069, 4387, 11635, 7366, 2936, 5476,
8097, 2867, 3190, 7533, 5373, 10352, 8159, 5735, 10998, 3075,
10214, 10094, 11536, 2967, 4624, 11742, 9299, 5344, 9317, 8656,
4692, 12008, 4161, 9114, 2469, 251, 11478, 9766, 843, 6217,
8053, 11029, 9887, 5541, 10365, 6291, 10649, 8440, 172, 9521,
116, 12205, 2770, 8357, 8172, 1320, 4, 2834, 3823, 2879,
10188, 4974, 380, 4279, 10235, 5379, 5379, 11037, 9767, 12116,
4150, 7059, 3138, 7590, 5572, 1361, 11572, 3025, 2734, 1012,
3974, 10605, 2533, 6360, 4466, 680, 270, 6194, 8800, 10708,
6327, 5218, 7130, 3073, 5815, 3950, 11849, 3707, 3192, 1406,
676, 975, 2649, 4904, 161, 792, 10023, 4604, 7491, 1174,
747, 12139, 8595, 4933, 3610, 11754, 2648, 909, 9984, 10440,
3929, 8443, 7723, 4698, 1266, 7234, 3598, 2380, 5972, 11194,
9470, 840, 7368, 1626, 5808, 1883, 3314, 6771, 3564, 3146,
743, 10912, 8204, 7195, 5580, 1376, 6366, 6529, 4247, 5104,
5745, 4231, 8300, 7618, 6933, 1241, 277, 551, 10811, 2163,
10481, 11841, 10709, 9664, 10019, 10521, 3400, 4179, 4589, 1961,
6740, 2785, 10196, 8943, 3621, 1180, 8317, 8350, 6758, 3720,
4157, 8131, 4658, 8954, 7026, 9860, 3108, 1006, 9807, 632,
9359, 5535, 8837, 6506, 4205, 1582, 4644, 3885, 5106, 3772,
7830, 4472, 4361, 8529, 9463, 825, 9438, 11990, 4998, 5703,
11138, 5835, 1858, 2308, 1526, 6541, 4857, 585, 8344, 8893,
6536, 1324, 4263, 265, 6381, 8780, 4783, 12098, 10832, 10986,
7327, 7156, 4435, 2430, 1162, 5473, 1602, 1219, 5435, 1868,
8655, 1693, 531, 1889, 7801, 5060, 114, 8715, 10198, 5578,
11574, 10608, 4704, 2476, 4014, 2888, 11601, 7989, 9154, 463,
1206, 2159, 4238, 5734, 7393, 8704, 10369, 308, 7805, 9498,
8644, 11031, 6876, 9446, 7302, 5492, 343, 12078, 11143, 674,
1223, 5279, 470, 4091, 6788, 120, 8981, 9126, 3119, 1562,
10144, 7379, 11688, 1969, 2332, 5613, 2181, 456, 6469, 2622,
11073, 8755, 6536, 375, 3053, 11435, 5193, 4215, 4596, 5145,
8969, 9431, 6894, 6009, 5261, 277, 2507, 1547, 4765, 2207,
6527, 10342, 10440, 6321, 5628, 1722, 7693, 3291, 9392, 5906,
5003, 9013, 10003, 3233, 6551, 10508, 3380, 1030, 3868, 11869,
9858, 9338, 12240, 4671, 3832, 1353, 8888, 3898, 11022, 7442,
11936, 6211, 6142, 7656, 7859, 11772, 116, 6966, 7915, 4903,
6023, 4518, 1155, 2172, 5690, 4241, 9428, 3696, 3735, 3467,
495, 6040, 12019, 10346, 8531, 3713, 2431, 4551, 5070, 5932,
8769, 2413, 5942, 2753, 2600, 11963, 11106, 10875, 6799, 3426,
458, 6126, 8785, 1730, 6994, 5757, 8224, 9043, 8939, 9013,
4686, 7680, 1133, 6033, 6376, 8697, 793, 8639, 4831, 3535,
561, 5483, 8341, 10355, 1411, 5853, 5834, 3689, 1943, 10890,
1693, 1302, 5519, 9392, 9549, 3191, 597, 84, 9477, 3948,
2093, 8565, 10618, 1305, 4570, 4275, 9557, 557, 768, 4047,
4215, 2567, 9480, 4248, 10029, 11156, 4477, 12152, 4108, 3109,
2634, 3972, 5921, 373};
const int twindle[10] = {10302, 3400, 8340, 12149, 7311,
5860, 4134, 8246, 1479, 12288};
const int p = 12289;
int dst[1024];
ntt(arr, l, twindle, p, dst);

for (int i = 0; i < n; i++) {
printf("%d", dst[i]);
if ((i + 1) % 8 == 0) {
printf("\n");
} else {
printf(" ");
}
}
}


int main() { test(); }
51 changes: 51 additions & 0 deletions tests/eval/_ntt/ntt_128_main.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// requires VLEN >= 512

#include <stdio.h>

void ntt(const int *array, int l, const int *twindle, int p, int *dst);

void test() {
const int l = 8;
const int n = 256;
const int arr[256] = {
9997, 6362, 7134, 11711, 5849, 9491, 5972, 4164, 5894, 11069,
7697, 8319, 2077, 12086, 10239, 5394, 4898, 1370, 1205, 2997,
5274, 4625, 11983, 1789, 3645, 7666, 12128, 10883, 7376, 8883,
2321, 1889, 2026, 8059, 2741, 865, 1785, 9955, 2395, 9330,
11465, 7383, 9649, 11285, 3647, 578, 1158, 9936, 12019, 11114,
7894, 4832, 10148, 10363, 11388, 9122, 10758, 2642, 4171, 10586,
1194, 5280, 3055, 9220, 10577, 9046, 1284, 7915, 10213, 6902,
3777, 9896, 429, 7730, 7429, 8666, 10887, 11255, 2437, 7782,
1327, 7010, 4009, 1038, 9466, 5352, 1473, 10067, 11753, 2019,
8472, 7665, 2679, 5070, 2248, 3044, 10301, 10671, 2092, 1069,
9032, 9131, 11715, 6662, 3423, 10027, 5436, 4259, 999, 3316,
11164, 5597, 6578, 800, 8242, 6952, 2288, 1481, 6770, 11948,
8938, 10813, 11107, 1362, 4510, 9388, 8840, 10557, 6206, 7808,
7131, 1394, 2604, 1509, 689, 5222, 8867, 9934, 7165, 6099,
3229, 1263, 4414, 12212, 4963, 9236, 9040, 6062, 11163, 8169,
4575, 6097, 3006, 1, 1384, 12039, 5445, 11355, 12197, 9182,
10085, 9295, 8890, 10651, 1540, 9061, 10222, 2524, 2213, 6974,
2066, 7348, 7444, 173, 7529, 3884, 3531, 4312, 640, 5352,
5880, 3985, 781, 10165, 1106, 8114, 6043, 8202, 10617, 3060,
11173, 11521, 6933, 9540, 11782, 2284, 6462, 3740, 2581, 126,
508, 12165, 4956, 8045, 9379, 5250, 8148, 6539, 4891, 11252,
5041, 9969, 8524, 9892, 4058, 10580, 10025, 9748, 8829, 4438,
468, 4773, 1657, 1348, 10055, 7192, 9556, 5919, 5690, 6153,
6270, 4938, 6206, 1003, 596, 11173, 9858, 4825, 7940, 794,
7477, 10146, 7203, 4729, 5741, 4603, 1806, 7034, 8772, 10435,
10777, 1359, 630, 11059, 8005, 225};
const int twindle[8] = {10302, 3400, 8340, 12149, 7311,
5860, 4134, 8246};
const int p = 12289;
int dst[256];
ntt(arr, l, twindle, p, dst);

for (int i = 0; i < n; i++) {
printf("%d", dst[i]);
if ((i + 1) % 8 == 0) {
printf("\n");
} else {
printf(" ");
}
}
}
51 changes: 51 additions & 0 deletions tests/eval/_ntt/ntt_256_main.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// requires VLEN >= 1024

#include <stdio.h>

void ntt(const int *array, int l, const int *twindle, int p, int *dst);

void test() {
const int l = 8;
const int n = 256;
const int arr[256] = {
9997, 6362, 7134, 11711, 5849, 9491, 5972, 4164, 5894, 11069,
7697, 8319, 2077, 12086, 10239, 5394, 4898, 1370, 1205, 2997,
5274, 4625, 11983, 1789, 3645, 7666, 12128, 10883, 7376, 8883,
2321, 1889, 2026, 8059, 2741, 865, 1785, 9955, 2395, 9330,
11465, 7383, 9649, 11285, 3647, 578, 1158, 9936, 12019, 11114,
7894, 4832, 10148, 10363, 11388, 9122, 10758, 2642, 4171, 10586,
1194, 5280, 3055, 9220, 10577, 9046, 1284, 7915, 10213, 6902,
3777, 9896, 429, 7730, 7429, 8666, 10887, 11255, 2437, 7782,
1327, 7010, 4009, 1038, 9466, 5352, 1473, 10067, 11753, 2019,
8472, 7665, 2679, 5070, 2248, 3044, 10301, 10671, 2092, 1069,
9032, 9131, 11715, 6662, 3423, 10027, 5436, 4259, 999, 3316,
11164, 5597, 6578, 800, 8242, 6952, 2288, 1481, 6770, 11948,
8938, 10813, 11107, 1362, 4510, 9388, 8840, 10557, 6206, 7808,
7131, 1394, 2604, 1509, 689, 5222, 8867, 9934, 7165, 6099,
3229, 1263, 4414, 12212, 4963, 9236, 9040, 6062, 11163, 8169,
4575, 6097, 3006, 1, 1384, 12039, 5445, 11355, 12197, 9182,
10085, 9295, 8890, 10651, 1540, 9061, 10222, 2524, 2213, 6974,
2066, 7348, 7444, 173, 7529, 3884, 3531, 4312, 640, 5352,
5880, 3985, 781, 10165, 1106, 8114, 6043, 8202, 10617, 3060,
11173, 11521, 6933, 9540, 11782, 2284, 6462, 3740, 2581, 126,
508, 12165, 4956, 8045, 9379, 5250, 8148, 6539, 4891, 11252,
5041, 9969, 8524, 9892, 4058, 10580, 10025, 9748, 8829, 4438,
468, 4773, 1657, 1348, 10055, 7192, 9556, 5919, 5690, 6153,
6270, 4938, 6206, 1003, 596, 11173, 9858, 4825, 7940, 794,
7477, 10146, 7203, 4729, 5741, 4603, 1806, 7034, 8772, 10435,
10777, 1359, 630, 11059, 8005, 225};
const int twindle[8] = {10302, 3400, 8340, 12149, 7311,
5860, 4134, 8246};
const int p = 12289;
int dst[256];
ntt(arr, l, twindle, p, dst);

for (int i = 0; i < n; i++) {
printf("%d", dst[i]);
if ((i + 1) % 8 == 0) {
printf("\n");
} else {
printf(" ");
}
}
}
Loading

0 comments on commit d136389

Please sign in to comment.