From 066bb7fc992a06c71ae41861020bb2bb213a9aec Mon Sep 17 00:00:00 2001 From: Robert Konrad Date: Tue, 20 Aug 2024 19:16:36 +0200 Subject: [PATCH] Implement if for SPIR-V --- Sources/backends/spirv.c | 76 +++++++++++++++++++++++++++++++++++----- Sources/compiler.c | 57 ++++++++++++++++++++++++------ Sources/compiler.h | 5 +++ 3 files changed, 119 insertions(+), 19 deletions(-) diff --git a/Sources/backends/spirv.c b/Sources/backends/spirv.c index e622541..7a388bb 100644 --- a/Sources/backends/spirv.c +++ b/Sources/backends/spirv.c @@ -126,6 +126,7 @@ typedef enum spirv_opcode { SPIRV_OPCODE_EXECUTION_MODE = 16, SPIRV_OPCODE_CAPABILITY = 17, SPIRV_OPCODE_TYPE_VOID = 19, + SPIRV_OPCODE_TYPE_BOOL = 20, SPIRV_OPCODE_TYPE_INT = 21, SPIRV_OPCODE_TYPE_FLOAT = 22, SPIRV_OPCODE_TYPE_VECTOR = 23, @@ -142,6 +143,9 @@ typedef enum spirv_opcode { SPIRV_OPCODE_DECORATE = 71, SPIRV_OPCODE_MEMBER_DECORATE = 72, SPIRV_OPCODE_COMPOSITE_CONSTRUCT = 80, + SPIRV_OPCODE_F_ORD_LESS_THAN = 184, + SPIRV_OPCODE_SELECTION_MERGE = 247, + SPIRV_OPCODE_BRANCH_CONDITIONAL = 250, SPIRV_OPCODE_RETURN = 253, SPIRV_OPCODE_LABEL = 248 } spirv_opcode; @@ -210,6 +214,8 @@ typedef enum builtin { BUILTIN_POSITION = 0 } builtin; typedef enum storage_class { STORAGE_CLASS_INPUT = 1, STORAGE_CLASS_OUTPUT = 3, STORAGE_CLASS_FUNCTION = 7, STORAGE_CLASS_NONE = 9999 } storage_class; +typedef enum selection_control { SELECTION_CONTROL_NONE = 0, SELCTION_CONTROL_FLATTEN = 1, SELECTION_CONTROL_DONT_FLATTEN = 2 } selection_control; + typedef enum function_control { FUNCTION_CONTROL_NONE } function_control; typedef enum execution_mode { EXECUTION_MODE_ORIGIN_UPPER_LEFT = 7 } execution_mode; @@ -361,6 +367,14 @@ static spirv_id write_type_int(instructions_buffer *instructions, uint32_t width return int_type; } +static spirv_id write_type_bool(instructions_buffer *instructions) { + spirv_id bool_type = allocate_index(); + + uint32_t operands[] = {bool_type.id}; + write_instruction(instructions, WORD_COUNT(operands), SPIRV_OPCODE_TYPE_BOOL, operands); + return bool_type; +} + static spirv_id write_type_struct(instructions_buffer *instructions, spirv_id *types, uint16_t types_size) { spirv_id struct_type = allocate_index(); @@ -394,6 +408,7 @@ static spirv_id spirv_uint_type; static spirv_id spirv_float2_type; static spirv_id spirv_float3_type; static spirv_id spirv_float4_type; +static spirv_id spirv_bool_type; typedef struct complex_type { type_id type; @@ -460,6 +475,7 @@ static void write_base_types(instructions_buffer *constants, type_id vertex_inpu spirv_uint_type = write_type_int(constants, 32, false); spirv_int_type = write_type_int(constants, 32, true); + spirv_bool_type = write_type_bool(constants); } static void write_types(instructions_buffer *constants, function *main) { @@ -550,7 +566,7 @@ static spirv_id write_op_function(instructions_buffer *instructions, spirv_id re return result; } -static spirv_id write_label(instructions_buffer *instructions) { +static spirv_id write_op_label(instructions_buffer *instructions) { spirv_id result = allocate_index(); uint32_t operands[] = {result.id}; @@ -558,11 +574,16 @@ static spirv_id write_label(instructions_buffer *instructions) { return result; } -static void write_return(instructions_buffer *instructions) { +static void write_op_label_preallocated(instructions_buffer *instructions, spirv_id result) { + uint32_t operands[] = {result.id}; + write_instruction(instructions, WORD_COUNT(operands), SPIRV_OPCODE_LABEL, operands); +} + +static void write_op_return(instructions_buffer *instructions) { write_simple_instruction(instructions, SPIRV_OPCODE_RETURN); } -static void write_function_end(instructions_buffer *instructions) { +static void write_op_function_end(instructions_buffer *instructions) { write_simple_instruction(instructions, SPIRV_OPCODE_FUNCTION_END); } @@ -633,6 +654,26 @@ static spirv_id write_op_composite_construct(instructions_buffer *instructions, return result; } +static spirv_id write_op_f_ord_less_than(instructions_buffer *instructions, spirv_id type, spirv_id operand1, spirv_id operand2) { + spirv_id result = allocate_index(); + + uint32_t operands[] = {type.id, result.id, operand1.id, operand2.id}; + + write_instruction(instructions, WORD_COUNT(operands), SPIRV_OPCODE_F_ORD_LESS_THAN, operands); + + return result; +} + +static void write_op_selection_merge(instructions_buffer *instructions, spirv_id merge_block, selection_control control) { + uint32_t operands[] = {merge_block.id, (uint32_t)control}; + write_instruction(instructions, WORD_COUNT(operands), SPIRV_OPCODE_SELECTION_MERGE, operands); +} + +static void write_op_branch_conditional(instructions_buffer *instructions, spirv_id condition, spirv_id pass, spirv_id fail) { + uint32_t operands[] = {condition.id, pass.id, fail.id}; + write_instruction(instructions, WORD_COUNT(operands), SPIRV_OPCODE_BRANCH_CONDITIONAL, operands); +} + static spirv_id write_op_variable(instructions_buffer *instructions, spirv_id result_type, storage_class storage) { spirv_id result = allocate_index(); @@ -676,7 +717,7 @@ static size_t input_vars_count = 0; static void write_function(instructions_buffer *instructions, function *f, spirv_id function_id, shader_stage stage, bool main, type_id input, type_id output) { write_op_function_preallocated(instructions, void_type, FUNCTION_CONTROL_NONE, void_function_type, function_id); - write_label(instructions); + write_op_label(instructions); debug_context context = {0}; check(f->block != NULL, context, "Function block missing"); @@ -885,17 +926,36 @@ static void write_function(instructions_buffer *instructions, function *f, spirv error(context, "Type unsupported for input in SPIR-V"); } } - write_return(instructions); + write_op_return(instructions); } else if (stage == SHADER_STAGE_FRAGMENT && main) { spirv_id object = write_op_load(instructions, convert_type_to_spirv_id(o->op_return.var.type.type), convert_kong_index_to_spirv_id(o->op_return.var.index)); write_op_store(instructions, output_var, object); - write_return(instructions); + write_op_return(instructions); } ends_with_return = true; break; } + case OPCODE_LESS: { + spirv_id result = write_op_f_ord_less_than(instructions, spirv_bool_type, convert_kong_index_to_spirv_id(o->op_binary.left.index), + convert_kong_index_to_spirv_id(o->op_binary.right.index)); + hmput(index_map, o->op_binary.result.index, result); + break; + } + case OPCODE_IF: { + write_op_selection_merge(instructions, convert_kong_index_to_spirv_id(o->op_if.end_id), SELECTION_CONTROL_NONE); + + write_op_branch_conditional(instructions, convert_kong_index_to_spirv_id(o->op_if.condition.index), convert_kong_index_to_spirv_id(o->op_if.start_id), + convert_kong_index_to_spirv_id(o->op_if.end_id)); + + break; + } + case OPCODE_BLOCK_START: + case OPCODE_BLOCK_END: { + write_op_label_preallocated(instructions, convert_kong_index_to_spirv_id(o->op_block.id)); + break; + } default: { debug_context context = {0}; error(context, "Opcode not implemented for SPIR-V"); @@ -911,10 +971,10 @@ static void write_function(instructions_buffer *instructions, function *f, spirv // TODO } else { - write_return(instructions); + write_op_return(instructions); } } - write_function_end(instructions); + write_op_function_end(instructions); } static void write_functions(instructions_buffer *instructions, function *main, spirv_id entry_point, shader_stage stage, type_id input, type_id output) { diff --git a/Sources/compiler.c b/Sources/compiler.c index a6e2c49..772cd0a 100644 --- a/Sources/compiler.c +++ b/Sources/compiler.c @@ -93,10 +93,16 @@ variable allocate_variable(type_ref type, variable_kind kind) { return v; } -void emit_op(opcodes *code, opcode *o) { +opcode *emit_op(opcodes *code, opcode *o) { assert(code->size + o->size < OPCODES_SIZE); + + uint8_t *location = &code->o[code->size]; + memcpy(&code->o[code->size], o, o->size); + code->size += o->size; + + return (opcode *)location; } #define OP_SIZE(op, opmember) offsetof(opcode, opmember) + sizeof(o.opmember) @@ -519,7 +525,12 @@ variable emit_expression(opcodes *code, block *parent, expression *e) { } } -void emit_statement(opcodes *code, block *parent, statement *statement) { +typedef struct block_ids { + uint64_t start; + uint64_t end; +} block_ids; + +static block_ids emit_statement(opcodes *code, block *parent, statement *statement) { switch (statement->kind) { case STATEMENT_EXPRESSION: emit_expression(code, parent, statement->expression); @@ -555,13 +566,16 @@ void emit_statement(opcodes *code, block *parent, statement *statement) { o.op_if.condition = initial_condition; - emit_op(code, &o); + opcode *written_opcode = emit_op(code, &o); previous_conditions[previous_conditions_size].condition = initial_condition; previous_conditions_size += 1; - } - emit_statement(code, parent, statement->iffy.if_block); + block_ids ids = emit_statement(code, parent, statement->iffy.if_block); + + written_opcode->op_if.start_id = ids.start; + written_opcode->op_if.end_id = ids.end; + } for (uint16_t i = 0; i < statement->iffy.else_size; ++i) { variable current_condition; @@ -598,7 +612,7 @@ void emit_statement(opcodes *code, block *parent, statement *statement) { opcode o; o.type = OPCODE_IF; - o.size = OP_SIZE(o, op_if) + sizeof(variable) * i; + o.size = OP_SIZE(o, op_if); if (statement->iffy.else_tests[i] != NULL) { variable v = emit_expression(code, parent, statement->iffy.else_tests[i]); @@ -627,9 +641,14 @@ void emit_statement(opcodes *code, block *parent, statement *statement) { o.op_if.condition = summed_condition; } - emit_op(code, &o); + { + opcode *written_opcode = emit_op(code, &o); + + block_ids ids = emit_statement(code, parent, statement->iffy.else_blocks[i]); - emit_statement(code, parent, statement->iffy.else_blocks[i]); + written_opcode->op_if.start_id = ids.start; + written_opcode->op_if.end_id = ids.end; + } } break; @@ -702,10 +721,17 @@ void emit_statement(opcodes *code, block *parent, statement *statement) { statement->block.vars.v[i].variable_id = var.index; } + uint64_t start_block_id = next_variable_id; + ++next_variable_id; + + uint64_t end_block_id = next_variable_id; + ++next_variable_id; + { opcode o; o.type = OPCODE_BLOCK_START; - o.size = OP_SIZE(o, op_nothing); + o.op_block.id = start_block_id; + o.size = OP_SIZE(o, op_block); emit_op(code, &o); } @@ -716,11 +742,15 @@ void emit_statement(opcodes *code, block *parent, statement *statement) { { opcode o; o.type = OPCODE_BLOCK_END; - o.size = OP_SIZE(o, op_nothing); + o.op_block.id = end_block_id; + o.size = OP_SIZE(o, op_block); emit_op(code, &o); } - break; + block_ids ids; + ids.start = start_block_id; + ids.end = end_block_id; + return ids; } case STATEMENT_LOCAL_VARIABLE: { opcode o; @@ -754,6 +784,11 @@ void emit_statement(opcodes *code, block *parent, statement *statement) { break; } } + + block_ids ids; + ids.start = 0; + ids.end = 0; + return ids; } void convert_globals(void) { diff --git a/Sources/compiler.h b/Sources/compiler.h index d894009..0b75c90 100644 --- a/Sources/compiler.h +++ b/Sources/compiler.h @@ -101,10 +101,15 @@ typedef struct opcode { } op_binary; struct { variable condition; + uint64_t start_id; + uint64_t end_id; } op_if; struct { variable condition; } op_while; + struct { + uint64_t id; + } op_block; struct { uint8_t nothing; } op_nothing;