diff --git a/Sources/backends/cstyle.c b/Sources/backends/cstyle.c index 6b7d347..f6c621a 100644 --- a/Sources/backends/cstyle.c +++ b/Sources/backends/cstyle.c @@ -157,6 +157,12 @@ void cstyle_write_opcode(char *code, size_t *offset, opcode *o, type_string_func o->op_binary.result.index, o->op_binary.left.index, o->op_binary.right.index); break; } + case OPCODE_MOD: { + indent(code, offset, *indentation); + *offset += sprintf(&code[*offset], "%s _%" PRIu64 " = _%" PRIu64 " %% _%" PRIu64 ";\n", type_string(o->op_binary.result.type.type), + o->op_binary.result.index, o->op_binary.left.index, o->op_binary.right.index); + break; + } case OPCODE_EQUALS: { indent(code, offset, *indentation); *offset += sprintf(&code[*offset], "%s _%" PRIu64 " = _%" PRIu64 " == _%" PRIu64 ";\n", type_string(o->op_binary.result.type.type), @@ -205,6 +211,12 @@ void cstyle_write_opcode(char *code, size_t *offset, opcode *o, type_string_func o->op_binary.result.index, o->op_binary.left.index, o->op_binary.right.index); break; } + case OPCODE_XOR: { + indent(code, offset, *indentation); + *offset += sprintf(&code[*offset], "%s _%" PRIu64 " = _%" PRIu64 " ^ _%" PRIu64 ";\n", type_string(o->op_binary.result.type.type), + o->op_binary.result.index, o->op_binary.left.index, o->op_binary.right.index); + break; + } case OPCODE_IF: { indent(code, offset, *indentation); *offset += sprintf(&code[*offset], "if (_%" PRIu64 ")\n", o->op_if.condition.index); diff --git a/Sources/backends/hlsl.c b/Sources/backends/hlsl.c index 86b5ae0..61b126f 100644 --- a/Sources/backends/hlsl.c +++ b/Sources/backends/hlsl.c @@ -418,6 +418,18 @@ static void write_root_signature(char *hlsl, size_t *offset) { *offset += sprintf(&hlsl[*offset], "\")]\n"); } +static type_id payload_types[256]; +static size_t payload_types_count = 0; + +static bool is_payload_type(type_id t) { + for (size_t payload_index = 0; payload_index < payload_types_count; ++payload_index) { + if (payload_types[payload_index] == t) { + return true; + } + } + return false; +} + static void write_functions(char *hlsl, size_t *offset, shader_stage stage, function *main, function **rayshaders, size_t rayshaders_count) { function *functions[256]; size_t functions_size = 0; @@ -435,6 +447,83 @@ static void write_functions(char *hlsl, size_t *offset, shader_stage stage, func find_referenced_functions(rayshaders[rayshader_index], functions, &functions_size); } + // find payloads + for (size_t i = 0; i < functions_size; ++i) { + function *f = functions[i]; + + uint8_t *data = f->code.o; + size_t size = f->code.size; + + size_t index = 0; + while (index < size) { + opcode *o = (opcode *)&data[index]; + switch (o->type) { + case OPCODE_CALL: { + if (o->op_call.func == add_name("trace_ray")) { + debug_context context = {0}; + check(o->op_call.parameters_size == 3, context, "trace_ray requires three parameters"); + + type_id payload_type = o->op_call.parameters[2].type.type; + + bool found = false; + for (size_t payload_index = 0; payload_index < payload_types_count; ++payload_index) { + if (payload_types[payload_index] == payload_type) { + found = true; + break; + } + } + + if (!found) { + payload_types[payload_types_count] = payload_type; + payload_types_count += 1; + } + } + } + } + index += o->size; + } + } + + // function declarations + for (size_t i = 0; i < functions_size; ++i) { + function *f = functions[i]; + + if (f != main && !is_raygen_shader(f) && !is_raymiss_shader(f) && !is_rayclosesthit_shader(f) && !is_rayintersection_shader(f) && + !is_rayanyhit_shader(f)) { + + uint64_t parameter_ids[256] = {0}; + for (uint8_t parameter_index = 0; parameter_index < f->parameters_size; ++parameter_index) { + for (size_t i = 0; i < f->block->block.vars.size; ++i) { + if (f->parameter_names[parameter_index] == f->block->block.vars.v[i].name) { + parameter_ids[parameter_index] = f->block->block.vars.v[i].variable_id; + break; + } + } + } + + *offset += sprintf(&hlsl[*offset], "%s %s(", type_string(f->return_type.type), get_name(f->name)); + for (uint8_t parameter_index = 0; parameter_index < f->parameters_size; ++parameter_index) { + char *payload_prefix = ""; + if (is_payload_type(f->parameter_types[parameter_index].type)) { + payload_prefix = "inout "; + } + + if (parameter_index == 0) { + + *offset += sprintf(&hlsl[*offset], "%s%s _%" PRIu64, payload_prefix, type_string(f->parameter_types[parameter_index].type), + parameter_ids[parameter_index]); + } + else { + *offset += sprintf(&hlsl[*offset], ", %s%s _%" PRIu64, payload_prefix, type_string(f->parameter_types[parameter_index].type), + parameter_ids[parameter_index]); + } + } + *offset += sprintf(&hlsl[*offset], ");\n"); + } + } + + *offset += sprintf(&hlsl[*offset], "\n"); + for (size_t i = 0; i < functions_size; ++i) { function *f = functions[i]; @@ -706,11 +795,18 @@ static void write_functions(char *hlsl, size_t *offset, shader_stage stage, func else { *offset += sprintf(&hlsl[*offset], "%s %s(", type_string(f->return_type.type), get_name(f->name)); for (uint8_t parameter_index = 0; parameter_index < f->parameters_size; ++parameter_index) { + char *payload_prefix = ""; + if (is_payload_type(f->parameter_types[parameter_index].type)) { + payload_prefix = "inout "; + } + if (parameter_index == 0) { - *offset += sprintf(&hlsl[*offset], "%s _%" PRIu64, type_string(f->parameter_types[parameter_index].type), parameter_ids[parameter_index]); + *offset += sprintf(&hlsl[*offset], "%s%s _%" PRIu64, payload_prefix, type_string(f->parameter_types[parameter_index].type), + parameter_ids[parameter_index]); } else { - *offset += sprintf(&hlsl[*offset], ", %s _%" PRIu64, type_string(f->parameter_types[parameter_index].type), parameter_ids[parameter_index]); + *offset += sprintf(&hlsl[*offset], ", %s%s _%" PRIu64, payload_prefix, type_string(f->parameter_types[parameter_index].type), + parameter_ids[parameter_index]); } } *offset += sprintf(&hlsl[*offset], ") {\n"); @@ -823,19 +919,40 @@ static void write_functions(char *hlsl, size_t *offset, shader_stage stage, func check(o->op_call.parameters_size == 0, context, "group_index can not have a parameter"); *offset += sprintf(&hlsl[*offset], "%s _%" PRIu64 " = _kong_group_index;\n", type_string(o->op_call.var.type.type), o->op_call.var.index); } + else if (o->op_call.func == add_name("instance_id")) { + check(o->op_call.parameters_size == 0, context, "instance_id can not have a parameter"); + *offset += sprintf(&hlsl[*offset], "%s _%" PRIu64 " = InstanceID();\n", type_string(o->op_call.var.type.type), o->op_call.var.index); + } else if (o->op_call.func == add_name("world_ray_direction")) { - check(o->op_call.parameters_size == 0, context, "group_index can not have a parameter"); + check(o->op_call.parameters_size == 0, context, "world_ray_direction can not have a parameter"); *offset += sprintf(&hlsl[*offset], "%s _%" PRIu64 " = WorldRayDirection();\n", type_string(o->op_call.var.type.type), o->op_call.var.index); } + else if (o->op_call.func == add_name("world_ray_origin")) { + check(o->op_call.parameters_size == 0, context, "world_ray_origin can not have a parameter"); + *offset += sprintf(&hlsl[*offset], "%s _%" PRIu64 " = WorldRayOrigin();\n", type_string(o->op_call.var.type.type), o->op_call.var.index); + } + else if (o->op_call.func == add_name("ray_length")) { + check(o->op_call.parameters_size == 0, context, "ray_length can not have a parameter"); + *offset += sprintf(&hlsl[*offset], "%s _%" PRIu64 " = RayTCurrent();\n", type_string(o->op_call.var.type.type), o->op_call.var.index); + } else if (o->op_call.func == add_name("ray_index")) { - check(o->op_call.parameters_size == 0, context, "group_index can not have a parameter"); + check(o->op_call.parameters_size == 0, context, "ray_index can not have a parameter"); *offset += sprintf(&hlsl[*offset], "%s _%" PRIu64 " = DispatchRaysIndex();\n", type_string(o->op_call.var.type.type), o->op_call.var.index); } else if (o->op_call.func == add_name("ray_dimensions")) { - check(o->op_call.parameters_size == 0, context, "group_index can not have a parameter"); + check(o->op_call.parameters_size == 0, context, "ray_dimensions can not have a parameter"); *offset += sprintf(&hlsl[*offset], "%s _%" PRIu64 " = DispatchRaysDimensions();\n", type_string(o->op_call.var.type.type), o->op_call.var.index); } + else if (o->op_call.func == add_name("object_to_world3x3")) { + check(o->op_call.parameters_size == 0, context, "object_to_world3x3 can not have a parameter"); + *offset += sprintf(&hlsl[*offset], "%s _%" PRIu64 " = (float3x3)ObjectToWorld4x3();\n", type_string(o->op_call.var.type.type), + o->op_call.var.index); + } + else if (o->op_call.func == add_name("primitive_index")) { + check(o->op_call.parameters_size == 0, context, "primitive_index can not have a parameter"); + *offset += sprintf(&hlsl[*offset], "%s _%" PRIu64 " = PrimitiveIndex();\n", type_string(o->op_call.var.type.type), o->op_call.var.index); + } else if (o->op_call.func == add_name("trace_ray")) { check(o->op_call.parameters_size == 3, context, "trace_ray requires three parameters"); *offset += sprintf(&hlsl[*offset], "TraceRay(_%" PRIu64 ", RAY_FLAG_NONE, 0xFF, 0, 0, 0, _%" PRIu64 ", _%" PRIu64 ");\n", diff --git a/Sources/compiler.c b/Sources/compiler.c index 7d2f6df..03ae19a 100644 --- a/Sources/compiler.c +++ b/Sources/compiler.c @@ -122,7 +122,8 @@ variable emit_expression(opcodes *code, block *parent, expression *e) { case OPERATOR_LESS: case OPERATOR_LESS_EQUAL: case OPERATOR_AND: - case OPERATOR_OR: { + case OPERATOR_OR: + case OPERATOR_XOR: { variable right_var = emit_expression(code, parent, right); variable left_var = emit_expression(code, parent, left); type_ref t; @@ -156,6 +157,9 @@ variable emit_expression(opcodes *code, block *parent, expression *e) { case OPERATOR_OR: o.type = OPCODE_OR; break; + case OPERATOR_XOR: + o.type = OPCODE_XOR; + break; default: { debug_context context = {0}; error(context, "Unexpected operator"); @@ -172,7 +176,8 @@ variable emit_expression(opcodes *code, block *parent, expression *e) { case OPERATOR_MINUS: case OPERATOR_PLUS: case OPERATOR_DIVIDE: - case OPERATOR_MULTIPLY: { + case OPERATOR_MULTIPLY: + case OPERATOR_MOD: { variable right_var = emit_expression(code, parent, right); variable left_var = emit_expression(code, parent, left); variable result_var = allocate_variable(e->type, VARIABLE_LOCAL); @@ -191,6 +196,9 @@ variable emit_expression(opcodes *code, block *parent, expression *e) { case OPERATOR_MULTIPLY: o.type = OPCODE_MULTIPLY; break; + case OPERATOR_MOD: + o.type = OPCODE_MOD; + break; default: { debug_context context = {0}; error(context, "Unexpected operator"); @@ -208,10 +216,6 @@ variable emit_expression(opcodes *code, block *parent, expression *e) { debug_context context = {0}; error(context, "! is not a binary operator"); } - case OPERATOR_MOD: { - debug_context context = {0}; - error(context, "not implemented"); - } case OPERATOR_ASSIGN: case OPERATOR_MINUS_ASSIGN: case OPERATOR_PLUS_ASSIGN: @@ -402,6 +406,8 @@ variable emit_expression(opcodes *code, block *parent, expression *e) { } case OPERATOR_OR: error(context, "not implemented"); + case OPERATOR_XOR: + error(context, "not implemented"); case OPERATOR_AND: error(context, "not implemented"); case OPERATOR_MOD: diff --git a/Sources/compiler.h b/Sources/compiler.h index ee9c58a..1624796 100644 --- a/Sources/compiler.h +++ b/Sources/compiler.h @@ -36,6 +36,7 @@ typedef struct opcode { OPCODE_CALL, OPCODE_MULTIPLY, OPCODE_DIVIDE, + OPCODE_MOD, OPCODE_ADD, OPCODE_SUB, OPCODE_EQUALS, @@ -46,6 +47,7 @@ typedef struct opcode { OPCODE_LESS_EQUAL, OPCODE_AND, OPCODE_OR, + OPCODE_XOR, OPCODE_IF, OPCODE_WHILE_START, OPCODE_WHILE_CONDITION, diff --git a/Sources/functions.c b/Sources/functions.c index d4e759f..9fbf38f 100644 --- a/Sources/functions.c +++ b/Sources/functions.c @@ -35,6 +35,15 @@ static void add_func_float3_float_float_float(char *name) { f->block = NULL; } +static void add_func_float(char *name) { + function_id func = add_function(add_name(name)); + function *f = get_function(func); + init_type_ref(&f->return_type, add_name("float")); + f->return_type.type = find_type_by_ref(&f->return_type); + f->parameters_size = 0; + f->block = NULL; +} + static void add_func_float3(char *name) { function_id func = add_function(add_name(name)); function *f = get_function(func); @@ -44,6 +53,24 @@ static void add_func_float3(char *name) { f->block = NULL; } +static void add_func_float3x3(char *name) { + function_id func = add_function(add_name(name)); + function *f = get_function(func); + init_type_ref(&f->return_type, add_name("float3x3")); + f->return_type.type = find_type_by_ref(&f->return_type); + f->parameters_size = 0; + f->block = NULL; +} + +static void add_func_uint(char *name) { + function_id func = add_function(add_name(name)); + function *f = get_function(func); + init_type_ref(&f->return_type, add_name("uint")); + f->return_type.type = find_type_by_ref(&f->return_type); + f->parameters_size = 0; + f->block = NULL; +} + static void add_func_uint3(char *name) { function_id func = add_function(add_name(name)); function *f = get_function(func); @@ -89,6 +116,24 @@ static void add_func_float3_float3(char *name) { f->block = NULL; } +static void add_func_float3_float3_float3(char *name) { + function_id func = add_function(add_name(name)); + function *f = get_function(func); + init_type_ref(&f->return_type, add_name("float3")); + f->return_type.type = find_type_by_ref(&f->return_type); + + f->parameter_names[0] = add_name("a"); + init_type_ref(&f->parameter_types[0], add_name("float3")); + f->parameter_types[0].type = find_type_by_ref(&f->parameter_types[0]); + + f->parameter_names[1] = add_name("b"); + init_type_ref(&f->parameter_types[1], add_name("float3")); + f->parameter_types[1].type = find_type_by_ref(&f->parameter_types[1]); + + f->parameters_size = 2; + f->block = NULL; +} + static void add_func_void_uint_uint(char *name) { function_id func = add_function(add_name(name)); function *f = get_function(func); @@ -531,9 +576,12 @@ void functions_init(void) { add_func_uint3("group_thread_id"); add_func_uint3("dispatch_thread_id"); add_func_int("group_index"); + add_func_int("instance_id"); add_func_float3_float_float_float("lerp"); + add_func_float3("world_ray_origin"); add_func_float3("world_ray_direction"); + add_func_float("ray_length"); add_func_float3_float3("normalize"); add_func_float_float("saturate"); add_func_float_float("sin"); @@ -541,6 +589,13 @@ void functions_init(void) { add_func_float_float2("length"); add_func_uint3("ray_index"); add_func_float3("ray_dimensions"); + add_func_float_float("frac"); + add_func_float3x3("object_to_world3x3"); + add_func_float3_float3_float3("reflect"); + add_func_uint("primitive_index"); + add_func_float3_float3("abs"); + add_func_float3_float3_float3("dot"); + add_func_float3_float3("saturate"); add_func_void_uint_uint("set_mesh_output_counts"); diff --git a/Sources/kong.c b/Sources/kong.c index 0d4ce93..29137d7 100644 --- a/Sources/kong.c +++ b/Sources/kong.c @@ -293,7 +293,8 @@ void resolve_types_in_expression(statement *parent, expression *e) { case OPERATOR_LESS: case OPERATOR_LESS_EQUAL: case OPERATOR_OR: - case OPERATOR_AND: { + case OPERATOR_AND: + case OPERATOR_XOR: { e->type.type = bool_id; break; } @@ -371,6 +372,7 @@ void resolve_types_in_expression(statement *parent, expression *e) { case OPERATOR_DIVIDE: case OPERATOR_MULTIPLY: case OPERATOR_OR: + case OPERATOR_XOR: case OPERATOR_AND: case OPERATOR_MOD: case OPERATOR_ASSIGN: diff --git a/Sources/parser.c b/Sources/parser.c index c70bd08..e39295e 100644 --- a/Sources/parser.c +++ b/Sources/parser.c @@ -527,7 +527,7 @@ static expression *parse_logical(state_t *state) { while (!done) { if (current(state).kind == TOKEN_OPERATOR) { operatorr op = current(state).op; - if (op == OPERATOR_OR || op == OPERATOR_AND) { + if (op == OPERATOR_OR || op == OPERATOR_AND || op == OPERATOR_XOR) { advance_state(state); expression *right = parse_equality(state); expression *expression = expression_allocate(); diff --git a/Sources/tokenizer.c b/Sources/tokenizer.c index 4e6a673..ac58aeb 100644 --- a/Sources/tokenizer.c +++ b/Sources/tokenizer.c @@ -17,7 +17,8 @@ static bool is_num(char ch, char chch) { } static bool is_op(char ch) { - return ch == '&' || ch == '|' || ch == '+' || ch == '-' || ch == '*' || ch == '/' || ch == '=' || ch == '!' || ch == '<' || ch == '>' || ch == '%'; + return ch == '&' || ch == '|' || ch == '+' || ch == '-' || ch == '*' || ch == '/' || ch == '=' || ch == '!' || ch == '<' || ch == '>' || ch == '%' || + ch == '^'; } static bool is_whitespace(char ch) { @@ -454,6 +455,11 @@ tokens tokenize(const char *filename, const char *source) { token.op = OPERATOR_OR; tokens_add(&tokens, token); } + else if (tokenizer_buffer_equals(&buffer, "^")) { + token token = token_create(TOKEN_OPERATOR, &state); + token.op = OPERATOR_XOR; + tokens_add(&tokens, token); + } else if (tokenizer_buffer_equals(&buffer, "&&")) { token token = token_create(TOKEN_OPERATOR, &state); token.op = OPERATOR_AND; diff --git a/Sources/tokenizer.h b/Sources/tokenizer.h index fef1d1d..f9f8a99 100644 --- a/Sources/tokenizer.h +++ b/Sources/tokenizer.h @@ -18,6 +18,7 @@ typedef enum operatorr { OPERATOR_MULTIPLY, OPERATOR_NOT, OPERATOR_OR, + OPERATOR_XOR, OPERATOR_AND, OPERATOR_MOD, OPERATOR_ASSIGN,