Skip to content

Commit

Permalink
Implement if for SPIR-V
Browse files Browse the repository at this point in the history
  • Loading branch information
RobDangerous committed Aug 20, 2024
1 parent a72b4fc commit 066bb7f
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 19 deletions.
76 changes: 68 additions & 8 deletions Sources/backends/spirv.c
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ typedef enum spirv_opcode {
SPIRV_OPCODE_EXECUTION_MODE = 16,
SPIRV_OPCODE_CAPABILITY = 17,
SPIRV_OPCODE_TYPE_VOID = 19,
SPIRV_OPCODE_TYPE_BOOL = 20,
SPIRV_OPCODE_TYPE_INT = 21,
SPIRV_OPCODE_TYPE_FLOAT = 22,
SPIRV_OPCODE_TYPE_VECTOR = 23,
Expand All @@ -142,6 +143,9 @@ typedef enum spirv_opcode {
SPIRV_OPCODE_DECORATE = 71,
SPIRV_OPCODE_MEMBER_DECORATE = 72,
SPIRV_OPCODE_COMPOSITE_CONSTRUCT = 80,
SPIRV_OPCODE_F_ORD_LESS_THAN = 184,
SPIRV_OPCODE_SELECTION_MERGE = 247,
SPIRV_OPCODE_BRANCH_CONDITIONAL = 250,
SPIRV_OPCODE_RETURN = 253,
SPIRV_OPCODE_LABEL = 248
} spirv_opcode;
Expand Down Expand Up @@ -210,6 +214,8 @@ typedef enum builtin { BUILTIN_POSITION = 0 } builtin;

typedef enum storage_class { STORAGE_CLASS_INPUT = 1, STORAGE_CLASS_OUTPUT = 3, STORAGE_CLASS_FUNCTION = 7, STORAGE_CLASS_NONE = 9999 } storage_class;

typedef enum selection_control { SELECTION_CONTROL_NONE = 0, SELCTION_CONTROL_FLATTEN = 1, SELECTION_CONTROL_DONT_FLATTEN = 2 } selection_control;

typedef enum function_control { FUNCTION_CONTROL_NONE } function_control;

typedef enum execution_mode { EXECUTION_MODE_ORIGIN_UPPER_LEFT = 7 } execution_mode;
Expand Down Expand Up @@ -361,6 +367,14 @@ static spirv_id write_type_int(instructions_buffer *instructions, uint32_t width
return int_type;
}

static spirv_id write_type_bool(instructions_buffer *instructions) {
spirv_id bool_type = allocate_index();

uint32_t operands[] = {bool_type.id};
write_instruction(instructions, WORD_COUNT(operands), SPIRV_OPCODE_TYPE_BOOL, operands);
return bool_type;
}

static spirv_id write_type_struct(instructions_buffer *instructions, spirv_id *types, uint16_t types_size) {
spirv_id struct_type = allocate_index();

Expand Down Expand Up @@ -394,6 +408,7 @@ static spirv_id spirv_uint_type;
static spirv_id spirv_float2_type;
static spirv_id spirv_float3_type;
static spirv_id spirv_float4_type;
static spirv_id spirv_bool_type;

typedef struct complex_type {
type_id type;
Expand Down Expand Up @@ -460,6 +475,7 @@ static void write_base_types(instructions_buffer *constants, type_id vertex_inpu

spirv_uint_type = write_type_int(constants, 32, false);
spirv_int_type = write_type_int(constants, 32, true);
spirv_bool_type = write_type_bool(constants);
}

static void write_types(instructions_buffer *constants, function *main) {
Expand Down Expand Up @@ -550,19 +566,24 @@ static spirv_id write_op_function(instructions_buffer *instructions, spirv_id re
return result;
}

static spirv_id write_label(instructions_buffer *instructions) {
static spirv_id write_op_label(instructions_buffer *instructions) {
spirv_id result = allocate_index();

uint32_t operands[] = {result.id};
write_instruction(instructions, WORD_COUNT(operands), SPIRV_OPCODE_LABEL, operands);
return result;
}

static void write_return(instructions_buffer *instructions) {
static void write_op_label_preallocated(instructions_buffer *instructions, spirv_id result) {
uint32_t operands[] = {result.id};
write_instruction(instructions, WORD_COUNT(operands), SPIRV_OPCODE_LABEL, operands);
}

static void write_op_return(instructions_buffer *instructions) {
write_simple_instruction(instructions, SPIRV_OPCODE_RETURN);
}

static void write_function_end(instructions_buffer *instructions) {
static void write_op_function_end(instructions_buffer *instructions) {
write_simple_instruction(instructions, SPIRV_OPCODE_FUNCTION_END);
}

Expand Down Expand Up @@ -633,6 +654,26 @@ static spirv_id write_op_composite_construct(instructions_buffer *instructions,
return result;
}

static spirv_id write_op_f_ord_less_than(instructions_buffer *instructions, spirv_id type, spirv_id operand1, spirv_id operand2) {
spirv_id result = allocate_index();

uint32_t operands[] = {type.id, result.id, operand1.id, operand2.id};

write_instruction(instructions, WORD_COUNT(operands), SPIRV_OPCODE_F_ORD_LESS_THAN, operands);

return result;
}

static void write_op_selection_merge(instructions_buffer *instructions, spirv_id merge_block, selection_control control) {
uint32_t operands[] = {merge_block.id, (uint32_t)control};
write_instruction(instructions, WORD_COUNT(operands), SPIRV_OPCODE_SELECTION_MERGE, operands);
}

static void write_op_branch_conditional(instructions_buffer *instructions, spirv_id condition, spirv_id pass, spirv_id fail) {
uint32_t operands[] = {condition.id, pass.id, fail.id};
write_instruction(instructions, WORD_COUNT(operands), SPIRV_OPCODE_BRANCH_CONDITIONAL, operands);
}

static spirv_id write_op_variable(instructions_buffer *instructions, spirv_id result_type, storage_class storage) {
spirv_id result = allocate_index();

Expand Down Expand Up @@ -676,7 +717,7 @@ static size_t input_vars_count = 0;

static void write_function(instructions_buffer *instructions, function *f, spirv_id function_id, shader_stage stage, bool main, type_id input, type_id output) {
write_op_function_preallocated(instructions, void_type, FUNCTION_CONTROL_NONE, void_function_type, function_id);
write_label(instructions);
write_op_label(instructions);

debug_context context = {0};
check(f->block != NULL, context, "Function block missing");
Expand Down Expand Up @@ -885,17 +926,36 @@ static void write_function(instructions_buffer *instructions, function *f, spirv
error(context, "Type unsupported for input in SPIR-V");
}
}
write_return(instructions);
write_op_return(instructions);
}
else if (stage == SHADER_STAGE_FRAGMENT && main) {
spirv_id object =
write_op_load(instructions, convert_type_to_spirv_id(o->op_return.var.type.type), convert_kong_index_to_spirv_id(o->op_return.var.index));
write_op_store(instructions, output_var, object);
write_return(instructions);
write_op_return(instructions);
}
ends_with_return = true;
break;
}
case OPCODE_LESS: {
spirv_id result = write_op_f_ord_less_than(instructions, spirv_bool_type, convert_kong_index_to_spirv_id(o->op_binary.left.index),
convert_kong_index_to_spirv_id(o->op_binary.right.index));
hmput(index_map, o->op_binary.result.index, result);
break;
}
case OPCODE_IF: {
write_op_selection_merge(instructions, convert_kong_index_to_spirv_id(o->op_if.end_id), SELECTION_CONTROL_NONE);

write_op_branch_conditional(instructions, convert_kong_index_to_spirv_id(o->op_if.condition.index), convert_kong_index_to_spirv_id(o->op_if.start_id),
convert_kong_index_to_spirv_id(o->op_if.end_id));

break;
}
case OPCODE_BLOCK_START:
case OPCODE_BLOCK_END: {
write_op_label_preallocated(instructions, convert_kong_index_to_spirv_id(o->op_block.id));
break;
}
default: {
debug_context context = {0};
error(context, "Opcode not implemented for SPIR-V");
Expand All @@ -911,10 +971,10 @@ static void write_function(instructions_buffer *instructions, function *f, spirv
// TODO
}
else {
write_return(instructions);
write_op_return(instructions);
}
}
write_function_end(instructions);
write_op_function_end(instructions);
}

static void write_functions(instructions_buffer *instructions, function *main, spirv_id entry_point, shader_stage stage, type_id input, type_id output) {
Expand Down
57 changes: 46 additions & 11 deletions Sources/compiler.c
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,16 @@ variable allocate_variable(type_ref type, variable_kind kind) {
return v;
}

void emit_op(opcodes *code, opcode *o) {
opcode *emit_op(opcodes *code, opcode *o) {
assert(code->size + o->size < OPCODES_SIZE);

uint8_t *location = &code->o[code->size];

memcpy(&code->o[code->size], o, o->size);

code->size += o->size;

return (opcode *)location;
}

#define OP_SIZE(op, opmember) offsetof(opcode, opmember) + sizeof(o.opmember)
Expand Down Expand Up @@ -519,7 +525,12 @@ variable emit_expression(opcodes *code, block *parent, expression *e) {
}
}

void emit_statement(opcodes *code, block *parent, statement *statement) {
typedef struct block_ids {
uint64_t start;
uint64_t end;
} block_ids;

static block_ids emit_statement(opcodes *code, block *parent, statement *statement) {
switch (statement->kind) {
case STATEMENT_EXPRESSION:
emit_expression(code, parent, statement->expression);
Expand Down Expand Up @@ -555,13 +566,16 @@ void emit_statement(opcodes *code, block *parent, statement *statement) {

o.op_if.condition = initial_condition;

emit_op(code, &o);
opcode *written_opcode = emit_op(code, &o);

previous_conditions[previous_conditions_size].condition = initial_condition;
previous_conditions_size += 1;
}

emit_statement(code, parent, statement->iffy.if_block);
block_ids ids = emit_statement(code, parent, statement->iffy.if_block);

written_opcode->op_if.start_id = ids.start;
written_opcode->op_if.end_id = ids.end;
}

for (uint16_t i = 0; i < statement->iffy.else_size; ++i) {
variable current_condition;
Expand Down Expand Up @@ -598,7 +612,7 @@ void emit_statement(opcodes *code, block *parent, statement *statement) {

opcode o;
o.type = OPCODE_IF;
o.size = OP_SIZE(o, op_if) + sizeof(variable) * i;
o.size = OP_SIZE(o, op_if);

if (statement->iffy.else_tests[i] != NULL) {
variable v = emit_expression(code, parent, statement->iffy.else_tests[i]);
Expand Down Expand Up @@ -627,9 +641,14 @@ void emit_statement(opcodes *code, block *parent, statement *statement) {
o.op_if.condition = summed_condition;
}

emit_op(code, &o);
{
opcode *written_opcode = emit_op(code, &o);

block_ids ids = emit_statement(code, parent, statement->iffy.else_blocks[i]);

emit_statement(code, parent, statement->iffy.else_blocks[i]);
written_opcode->op_if.start_id = ids.start;
written_opcode->op_if.end_id = ids.end;
}
}

break;
Expand Down Expand Up @@ -702,10 +721,17 @@ void emit_statement(opcodes *code, block *parent, statement *statement) {
statement->block.vars.v[i].variable_id = var.index;
}

uint64_t start_block_id = next_variable_id;
++next_variable_id;

uint64_t end_block_id = next_variable_id;
++next_variable_id;

{
opcode o;
o.type = OPCODE_BLOCK_START;
o.size = OP_SIZE(o, op_nothing);
o.op_block.id = start_block_id;
o.size = OP_SIZE(o, op_block);
emit_op(code, &o);
}

Expand All @@ -716,11 +742,15 @@ void emit_statement(opcodes *code, block *parent, statement *statement) {
{
opcode o;
o.type = OPCODE_BLOCK_END;
o.size = OP_SIZE(o, op_nothing);
o.op_block.id = end_block_id;
o.size = OP_SIZE(o, op_block);
emit_op(code, &o);
}

break;
block_ids ids;
ids.start = start_block_id;
ids.end = end_block_id;
return ids;
}
case STATEMENT_LOCAL_VARIABLE: {
opcode o;
Expand Down Expand Up @@ -754,6 +784,11 @@ void emit_statement(opcodes *code, block *parent, statement *statement) {
break;
}
}

block_ids ids;
ids.start = 0;
ids.end = 0;
return ids;
}

void convert_globals(void) {
Expand Down
5 changes: 5 additions & 0 deletions Sources/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,15 @@ typedef struct opcode {
} op_binary;
struct {
variable condition;
uint64_t start_id;
uint64_t end_id;
} op_if;
struct {
variable condition;
} op_while;
struct {
uint64_t id;
} op_block;
struct {
uint8_t nothing;
} op_nothing;
Expand Down

0 comments on commit 066bb7f

Please sign in to comment.