From 592f86da91fa2f41391d263cc97be1e88d980610 Mon Sep 17 00:00:00 2001 From: Kate F Date: Mon, 16 Sep 2024 09:17:59 +0100 Subject: [PATCH] llvm "wasm assembly" codegen. Thanks to Damian Gryski for figuring out what we need to generate here, and to Chris Fallin for suggesting it. --- src/libfsm/print/wasm.c | 390 ++++++++++++++++++++++++++++++---------- 1 file changed, 295 insertions(+), 95 deletions(-) diff --git a/src/libfsm/print/wasm.c b/src/libfsm/print/wasm.c index ae84e4dea..6952621d4 100644 --- a/src/libfsm/print/wasm.c +++ b/src/libfsm/print/wasm.c @@ -5,6 +5,7 @@ */ #include +#include #include #include #include @@ -42,12 +43,108 @@ enum wasm_dialect { }; static void -transition(FILE *f, unsigned index, unsigned to, +print_comment(FILE *f, enum wasm_dialect dialect, + const struct fsm_options *opt, + const char *indent, + const char *fmt, ...) +{ + const char *s; + va_list ap; + + assert(f != NULL); + assert(opt != NULL); + assert(fmt != NULL); + + if (!opt->comments) { + return; + } + + switch (dialect) { + case DIALECT_S: s = "//"; break; + case DIALECT_WAT: s = ";;"; break; + } + + fprintf(f, "%s%s ", indent, s); + + va_start(ap, fmt); + (void) vfprintf(f, fmt, ap); + va_end(ap); +} + +static int +print_const(FILE *f, enum wasm_dialect dialect, unsigned u, const char *indent) +{ + (void) dialect; + + assert(f != NULL); + assert(u <= INT32_MAX); + assert(indent != NULL); + + return fprintf(f, "%si32.const %u\n", indent, u); +} + +static void +print_endpoint(FILE *f, enum wasm_dialect dialect, + const struct fsm_options *opt, unsigned char c, + const char *indent) +{ + assert(f != NULL); + assert(opt != NULL); + assert(indent != NULL); + + fprintf(f, "%si32.const %u", indent, (unsigned char) c); + + if (opt->comments) { + print_comment(f, dialect, opt, " ", ""); + fprintf(f, "\'"); + json_escputc(f, opt, c); + fprintf(f, "\'"); + } + + fprintf(f, "\n"); +} + +static void +print_table(FILE *f, enum wasm_dialect dialect, size_t n, const char *indent) +{ + size_t i; + + assert(f != NULL); + assert(indent != NULL); + + switch (dialect) { + case DIALECT_S: + fprintf(f, "%sbr_table {", indent); + for (i = 0; i < n; i++) { + fprintf(f, " %zu%s", i, i + 1 < n ? "," : ""); + } + fprintf(f, " }\n"); + break; + + case DIALECT_WAT: + fprintf(f, "%sbr_table", indent); + for (i = 0; i < n; i++) { + fprintf(f, " %zu", i); + } + fprintf(f, "\n"); + break; + } +} + +static void +print_transition(FILE *f, enum wasm_dialect dialect, + const struct fsm_options *opt, + unsigned index, unsigned to, const char *indent) { + assert(f != NULL); + assert(indent != NULL); + if (to == ERROR_STATE) { - fprintf(f, " i32.const 0\n"); - fprintf(f, " return\n"); + fprintf(f, "%si32.const %u", indent, 0); + print_comment(f, dialect, opt, " ", "error"); + fprintf(f, "\n"); + fprintf(f, "%sreturn\n", indent); return; } @@ -57,34 +154,90 @@ transition(FILE *f, unsigned index, unsigned to, } // set next state - fprintf(f, "%si32.const %u\n", indent, to); + fprintf(f, "%si32.const %u", indent, to); + print_comment(f, dialect, opt, " ", "S%u", to); + fprintf(f, "\n"); fprintf(f, "%slocal.set %u\n", indent, LOCAL_STATE); } -static void -print_endpoint(FILE *f, const struct fsm_options *opt, unsigned char c) +static int +print_if(FILE *f, enum wasm_dialect dialect, const char *type, const char *indent) { assert(f != NULL); - assert(opt != NULL); + assert(indent != NULL); - fprintf(f, " i32.const %u", (unsigned char) c); + if (type == NULL) { + return fprintf(f, "%sif\n", indent); + } - if (opt->comments) { - fprintf(f, " ;; \'"); - json_escputc(f, opt, c); - fprintf(f, "\'"); + switch (dialect) { + case DIALECT_S: + return fprintf(f, "%sif %s\n", indent, type); + + case DIALECT_WAT: + return fprintf(f, "%sif (result %s)\n", indent, type); } - fprintf(f, "\n"); + assert(!"unreached"); + abort(); +} + +static int +print_end(FILE *f, enum wasm_dialect dialect, const char *construct, const char *indent) +{ + assert(f != NULL); + assert(construct != NULL); + assert(indent != NULL); + + switch (dialect) { + case DIALECT_S: + return fprintf(f, "%send_%s\n", indent, construct); + + case DIALECT_WAT: + return fprintf(f, "%send\n", indent); + } + + assert(!"unreached"); + abort(); +} + +static int +print_endif(FILE *f, enum wasm_dialect dialect, const char *indent) +{ + assert(f != NULL); + assert(indent != NULL); + + return print_end(f, dialect, "if", indent); +} + +static int +print_endloop(FILE *f, enum wasm_dialect dialect, const char *indent) +{ + assert(f != NULL); + assert(indent != NULL); + + return print_end(f, dialect, "loop", indent); +} + +static int +print_endblock(FILE *f, enum wasm_dialect dialect, const char *indent) +{ + assert(f != NULL); + assert(indent != NULL); + + return print_end(f, dialect, "block", indent); } static void -print_range(FILE *f, const struct fsm_options *opt, - const struct ir_range *range) +print_range(FILE *f, enum wasm_dialect dialect, + const struct fsm_options *opt, + const struct ir_range *range, + const char *indent) { assert(f != NULL); assert(opt != NULL); assert(range != NULL); + assert(indent != NULL); // TODO: could identify the same ranges but upper and lowercase, // and take advantage of ascii's single-bit difference. @@ -93,24 +246,25 @@ print_range(FILE *f, const struct fsm_options *opt, // leaves a boolean on the stack if (range->end == range->start) { - fprintf(f, " local.get %u\n", LOCAL_CHAR); // get current input byte - print_endpoint(f, opt, range->start); - fprintf(f, " i32.eq\n"); + fprintf(f, "%slocal.get %u\n", indent, LOCAL_CHAR); // get current input byte + print_endpoint(f, dialect, opt, range->start, indent); + fprintf(f, "%si32.eq\n", indent); } else { - fprintf(f, " local.get %u\n", LOCAL_CHAR); // get current input byte - print_endpoint(f, opt, range->start); - fprintf(f, " i32.ge_u\n"); + fprintf(f, "%slocal.get %u\n", indent, LOCAL_CHAR); // get current input byte + print_endpoint(f, dialect, opt, range->start, indent); + fprintf(f, "%si32.ge_u\n", indent); - fprintf(f, " local.get %u\n", LOCAL_CHAR); // get current input byte - print_endpoint(f, opt, range->end); - fprintf(f, " i32.le_u\n"); + fprintf(f, "%slocal.get %u\n", indent, LOCAL_CHAR); // get current input byte + print_endpoint(f, dialect, opt, range->end, indent); + fprintf(f, "%si32.le_u\n", indent); - fprintf(f, " i32.and\n"); + fprintf(f, "%si32.and\n", indent); } } static void -print_ranges(FILE *f, const struct fsm_options *opt, +print_ranges(FILE *f, enum wasm_dialect dialect, + const struct fsm_options *opt, const struct ir_range *ranges, size_t n, bool complete) { @@ -123,7 +277,7 @@ print_ranges(FILE *f, const struct fsm_options *opt, /* A single range already leaves its own bool on the stack */ if (n == 1) { - print_range(f, opt, &ranges[0]); + print_range(f, dialect, opt, &ranges[0], " "); return; } @@ -140,9 +294,11 @@ print_ranges(FILE *f, const struct fsm_options *opt, * perhaps with IR support. */ for (k = 0; k < n; k++) { - print_range(f, opt, &ranges[k]); - fprintf(f, " if (result i32)\n"); - fprintf(f, " i32.const 1 ;; match \n"); + print_range(f, dialect, opt, &ranges[k], " "); + print_if(f, dialect, "i32", " "); + fprintf(f, " i32.const 1"); + print_comment(f, dialect, opt, " ", "match"); + fprintf(f, "\n"); // XXX: we don't want to return here, just leave a bool on the stack for (result i32) // fprintf(f, " return\n"); @@ -153,17 +309,20 @@ print_ranges(FILE *f, const struct fsm_options *opt, if (!complete) { fprintf(f, " else\n"); - fprintf(f, " i32.const 0 ;; no match \n"); + fprintf(f, " i32.const 0"); + print_comment(f, dialect, opt, " ", "no match"); + fprintf(f, "\n"); fprintf(f, " return\n"); } for (k = 0; k < n; k++) { - fprintf(f, " end\n"); + print_endif(f, dialect, " "); } } static void -print_groups(FILE *f, const struct fsm_options *opt, +print_groups(FILE *f, enum wasm_dialect dialect, + const struct fsm_options *opt, const struct ir_group *groups, size_t n, unsigned index, bool complete, unsigned default_to) @@ -176,21 +335,21 @@ print_groups(FILE *f, const struct fsm_options *opt, assert(n > 0); /* - * Here leave nothing on the stack, but transition() as side effect. + * Here leave nothing on the stack, but print_transition() as side effect. * * We prefer this as a side effect rather than accumulating a * destination on the stack, because we can skip effects for where * the destination state is unchanged from the current index. */ // TODO: explain we need another if/else chain for short-circuit evaulation of the groups -// this one doesn't return a value, because each group transition()s as a side effect -// we can't avoid the if/else chain for a single group, beause we still need to convert bool to transition() per group +// this one doesn't return a value, because each group print_transition()s as a side effect +// we can't avoid the if/else chain for a single group, beause we still need to convert bool to print_transition() per group for (j = 0; j < n; j++) { - print_ranges(f, opt, groups[j].ranges, groups[j].n, complete); + print_ranges(f, dialect, opt, groups[j].ranges, groups[j].n, complete); - fprintf(f, " if\n"); - transition(f, index, groups[j].to, " "); + print_if(f, dialect, NULL, " "); + print_transition(f, dialect, opt, index, groups[j].to, " "); if (j + 1 < n) { fprintf(f, " else\n"); @@ -199,16 +358,16 @@ print_groups(FILE *f, const struct fsm_options *opt, if (!complete && default_to != index) { fprintf(f, " else\n"); - transition(f, index, default_to, " "); + print_transition(f, dialect, opt, index, default_to, " "); } for (j = 0; j < n; j++) { - fprintf(f, " end\n"); + print_endif(f, dialect, " "); } } static int -print_state(FILE *f, +print_state(FILE *f, enum wasm_dialect dialect, const struct fsm_options *opt, const struct fsm_hooks *hooks, const struct ir_state *cs, @@ -230,10 +389,10 @@ print_state(FILE *f, } } - fprintf(f, " end\n"); + print_endblock(f, dialect, " "); fprintf(f, "\n"); - fprintf(f, " ;; S%u", index); + print_comment(f, dialect, opt, " ", "S%u", index); if (cs->example != NULL) { fprintf(f, " \""); escputs(f, opt, json_escputc, cs->example); @@ -243,47 +402,54 @@ print_state(FILE *f, switch (cs->strategy) { case IR_NONE: - fprintf(f, " ;; IR_NONE\n"); - transition(f, index, ERROR_STATE, " "); + print_comment(f, dialect, opt, " ", "IR_NONE"); + fprintf(f, "\n"); + print_transition(f, dialect, opt, index, ERROR_STATE, " "); return 0; case IR_SAME: - fprintf(f, " ;; IR_SAME\n"); - transition(f, index, cs->u.same.to, " "); + print_comment(f, dialect, opt, " ", "IR_SAME"); + fprintf(f, "\n"); + print_transition(f, dialect, opt, index, cs->u.same.to, " "); break; case IR_COMPLETE: - fprintf(f, " ;; IR_COMPLETE\n"); - print_groups(f, opt, cs->u.complete.groups, cs->u.complete.n, index, true, ERROR_STATE); + print_comment(f, dialect, opt, " ", "IR_COMPLETE"); + fprintf(f, "\n"); + print_groups(f, dialect, opt, cs->u.complete.groups, cs->u.complete.n, index, true, ERROR_STATE); break; case IR_PARTIAL: - fprintf(f, " ;; IR_PARTIAL\n"); - print_groups(f, opt, cs->u.partial.groups, cs->u.partial.n, index, false, ERROR_STATE); + print_comment(f, dialect, opt, " ", "IR_PARTIAL"); + fprintf(f, "\n"); + print_groups(f, dialect, opt, cs->u.partial.groups, cs->u.partial.n, index, false, ERROR_STATE); fprintf(f, "\n"); break; case IR_DOMINANT: - fprintf(f, " ;; IR_DOMINANT\n"); - print_groups(f, opt, cs->u.dominant.groups, cs->u.dominant.n, index, false, cs->u.dominant.mode); + print_comment(f, dialect, opt, " ", "IR_DOMINANT"); + fprintf(f, "\n"); + print_groups(f, dialect, opt, cs->u.dominant.groups, cs->u.dominant.n, index, false, cs->u.dominant.mode); break; case IR_ERROR: - fprintf(f, " ;; IR_ERROR\n"); - print_ranges(f, opt, cs->u.error.error.ranges, cs->u.error.error.n, false); - fprintf(f, " if\n"); - transition(f, index, ERROR_STATE, " "); + print_comment(f, dialect, opt, " ", "IR_ERROR"); + fprintf(f, "\n"); + print_ranges(f, dialect, opt, cs->u.error.error.ranges, cs->u.error.error.n, false); + print_if(f, dialect, NULL, " "); + print_transition(f, dialect, opt, index, ERROR_STATE, " "); if (cs->u.error.n > 0) { fprintf(f, " else\n"); - print_groups(f, opt, cs->u.error.groups, cs->u.error.n, index, true, cs->u.error.mode); + print_groups(f, dialect, opt, cs->u.error.groups, cs->u.error.n, index, true, cs->u.error.mode); } - fprintf(f, " end\n"); + print_endif(f, dialect, " "); break; case IR_TABLE: - fprintf(f, " ;; IR_TABLE\n"); + print_comment(f, dialect, opt, " ", "IR_TABLE"); + fprintf(f, "\n"); fprintf(f, " local.get %u\n", LOCAL_CHAR); // get current input byte fprintf(f, " drop\n"); // TODO: do something with it ... // TODO: would emit br_table here @@ -294,18 +460,21 @@ print_state(FILE *f, ; } - fprintf(f, " br %u ;; continue the loop, %u block%s up\n", delta - 1, delta - 1, &"s"[delta - 1 == 1]); + fprintf(f, " br %u", delta - 1); + print_comment(f, dialect, opt, " ", "continue the loop, %u block%s up", + delta - 1, &"s"[delta - 1 == 1]); + fprintf(f, "\n"); return 0; } static int fsm_print_wasm(FILE *f, + enum wasm_dialect dialect, const struct fsm_options *opt, const struct fsm_hooks *hooks, const struct ret_list *retlist, - const struct ir *ir, - enum wasm_dialect dialect) + const struct ir *ir) { const char *prefix; size_t i; @@ -325,25 +494,42 @@ fsm_print_wasm(FILE *f, } if (!opt->fragment) { - fprintf(f, "(module\n"); - fprintf(f, " (export \"%smatch\" (func 0))", prefix); - fprintf(f, " (memory 1 1)\n"); // input + switch (dialect) { + case DIALECT_S: + fprintf(f, ".global %smatch\n", prefix); + fprintf(f, ".hidden %smatch\n", prefix); + fprintf(f, ".type %smatch,@function\n", prefix); + fprintf(f, "%smatch:\n", prefix); + fprintf(f, ".functype %smatch (i32) -> (i32)\n", prefix); + fprintf(f, ".local i32, i32\n"); +// fprintf(f, "// (memory 1 1)\n"); // TODO(dgryski): I guess we don't need this line? + break; + + case DIALECT_WAT: + fprintf(f, "(module\n"); + fprintf(f, " (export \"%smatch\" (func 0))", prefix); + fprintf(f, " (memory 1 1)\n"); // input + break; + } } // TODO: export to component model, use opt.package_prefix // TODO: various IO APIs // TODO: endids - fprintf(f, " (func (param i32) (result i32) (local i32 i32)\n"); -// fprintf(f, " ;; s is in LOCAL_STR (the parameter) and we'll keep p there too\n"); -// fprintf(f, " ;; we'll cache *p in LOCAL_CHAR\n"); -// fprintf(f, "\n"); + if (dialect == DIALECT_WAT) { + fprintf(f, " (func (param i32) (result i32) (local i32 i32)\n"); +// fprintf(f, " ;; s is in LOCAL_STR (the parameter) and we'll keep p there too\n"); +// fprintf(f, " ;; we'll cache *p in LOCAL_CHAR\n"); +// fprintf(f, "\n"); + } // the current state will be in LOCAL_STATE // locals are implicitly initialized to 0 if (ir->start != 0) { - fprintf(f, " ;; start S%u\n", ir->start); - fprintf(f, " i32.const %u\n", ir->start); + print_comment(f, dialect, opt, " ", "start S%u", ir->start); + fprintf(f, "\n"); + print_const(f, dialect, ir->start, " "); fprintf(f, " local.set %u\n", LOCAL_STATE); fprintf(f, "\n"); } @@ -353,20 +539,23 @@ fsm_print_wasm(FILE *f, fprintf(f, " loop\n"); // begin the outer loop fprintf(f, "\n"); - fprintf(f, " ;; fetch *p\n"); + print_comment(f, dialect, opt, " ", "fetch *p"); + fprintf(f, "\n"); fprintf(f, " local.get %u\n", LOCAL_STR); // get address of next byte fprintf(f, " i32.load8_u 0\n"); // load byte at that address fprintf(f, " local.tee %u\n", LOCAL_CHAR); // save the current input byte and keep it on the stack fprintf(f, "\n"); - fprintf(f, " ;; *p != '\\0'\n"); + print_comment(f, dialect, opt, " ", "*p != '\\0'"); + fprintf(f, "\n"); fprintf(f, " i32.eqz\n"); // test if the byte is zero fprintf(f, " br_if 1\n"); // exit the outer block if so fprintf(f, "\n"); - fprintf(f, " ;; p++\n"); + print_comment(f, dialect, opt, " ", "p++"); + fprintf(f, "\n"); fprintf(f, " local.get %u\n", LOCAL_STR); - fprintf(f, " i32.const 1\n"); + print_const(f, dialect, 1, " "); fprintf(f, " i32.add\n"); fprintf(f, " local.set %u\n", LOCAL_STR); fprintf(f, "\n"); @@ -375,14 +564,12 @@ fsm_print_wasm(FILE *f, // fprintf(f, " ;; we need a block for each state: we'll start with a jump-table that\n"); // fprintf(f, " ;; branches out of the block which ends before the code we want to run\n"); for (i = 0; i < ir->n; i++) { - fprintf(f, " block ;; S%zu\n", i); + fprintf(f, " block"); + print_comment(f, dialect, opt, " ", "S%zu", i); + fprintf(f, "\n"); } fprintf(f, " local.get %u\n", LOCAL_STATE); - fprintf(f, " br_table"); - for (i = 0; i < ir->n; i++) { - fprintf(f, " %zu", i); - } - fprintf(f, "\n"); + print_table(f, dialect, ir->n, " "); for (i = 0; i < ir->n; i++) { if (i == ERROR_STATE) { @@ -390,13 +577,13 @@ fsm_print_wasm(FILE *f, return -1; } - if (-1 == print_state(f, opt, hooks, &ir->states[i], i, ir->n - i)) { + if (-1 == print_state(f, dialect, opt, hooks, &ir->states[i], i, ir->n - i)) { return -1; } } - fprintf(f, " end\n"); // end of loop - fprintf(f, " end\n"); // end of outer block + print_endloop(f, dialect, " "); + print_endblock(f, dialect, " "); fprintf(f, "\n"); // TODO: use retlist @@ -440,21 +627,34 @@ fsm_print_wasm(FILE *f, } fprintf(f, " local.get %u\n", LOCAL_STATE); - fprintf(f, " i32.const %zu ;; S%zu\n", i, i); + fprintf(f, " i32.const %zu", i); + print_comment(f, dialect, opt, " ", "S%zu", i); + fprintf(f, "\n"); fprintf(f, " i32.eq\n"); - fprintf(f, " if\n"); - fprintf(f, " i32.const %zu\n", i); + print_if(f, dialect, NULL, " "); + print_const(f, dialect, i, " "); fprintf(f, " return\n"); - fprintf(f, " end\n"); + print_endif(f, dialect, " "); fprintf(f, "\n"); } - fprintf(f, " i32.const 0\n"); + print_const(f, dialect, 0, " "); fprintf(f, " return\n"); - fprintf(f, " )\n"); + + switch (dialect) { + case DIALECT_S: + fprintf(f, " end_function\n"); + break; + + case DIALECT_WAT: + fprintf(f, " )\n"); + break; + } if (!opt->fragment) { - fprintf(f, ")\n"); + if (dialect == DIALECT_WAT) { + fprintf(f, ")\n"); + } } return 0; @@ -467,7 +667,7 @@ fsm_print_wasm_s(FILE *f, const struct ret_list *retlist, const struct ir *ir) { - return fsm_print_wasm(f, opt, hooks, retlist, ir, DIALECT_S); + return fsm_print_wasm(f, DIALECT_S, opt, hooks, retlist, ir); } int @@ -477,6 +677,6 @@ fsm_print_wat(FILE *f, const struct ret_list *retlist, const struct ir *ir) { - return fsm_print_wasm(f, opt, hooks, retlist, ir, DIALECT_WAT); + return fsm_print_wasm(f, DIALECT_WAT, opt, hooks, retlist, ir); }