diff --git a/.gitignore b/.gitignore index 1761ea1..17b9b3d 100644 --- a/.gitignore +++ b/.gitignore @@ -61,7 +61,7 @@ CMakeCache.txt CMakeFiles CMakeScripts Testing -Makefile +# Makefile cmake_install.cmake install_manifest.txt compile_commands.json @@ -85,3 +85,4 @@ _deps *tmp* KariLang +bin/ diff --git a/CMakeLists.txt b/CMakeLists.txt index fb67e78..8d63233 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,21 +1,79 @@ cmake_minimum_required(VERSION 3.20) -project(KariLang C) +project(KariLang) -set(CMAKE_C_STANDARD 11) +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +add_definitions(-Wall) find_package(BISON 3.0 REQUIRED) -bison_target(PARSER "src/parser.yy" "src/parser.tab.c" DEFINES_FILE "src/parser.tab.h" COMPILE_FLAGS "-Wall -Wcounterexamples") +bison_target(PARSER "src/Parser.yy" "${CMAKE_CURRENT_BINARY_DIR}/Parser.tab.cc" DEFINES_FILE "${CMAKE_CURRENT_BINARY_DIR}/Parser.tab.hh" COMPILE_FLAGS "-Wall -Wcounterexamples") find_package(FLEX 2.0 REQUIRED) -flex_target(LEXER "src/lexer.l" "src/lex.yy.c") +flex_target(LEXER "src/Lexer.l" "${CMAKE_CURRENT_BINARY_DIR}/Lex.yy.cc") + +find_package(LLVM 17 REQUIRED CONFIG) + +message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}") +message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") + +include_directories( + ${CMAKE_CURRENT_BINARY_DIR} + src/ + ${LLVM_INCLUDE_DIRS} + ) + +add_definitions(${LLVM_DEFINITIONS_LIST}) +separate_arguments(LLVM_DEFINITIONS_LIST NATIVE_COMMAND ${LLVM_DEFINITIONS}) + +add_executable(KariLang src/Main.cc + src/AST.cc + src/AST.hh + src/Compile.cc + src/Compile.hh + src/JIT.cc + src/JIT.hh + src/Parser.hh + src/PCH.hh + src/Utils.hh + ${BISON_PARSER_OUTPUTS} + ${FLEX_LEXER_OUTPUTS}) + +target_precompile_headers(KariLang PUBLIC src/PCH.hh) + +# getting LLVM libraries to link with +set(KARILANG_LLVM_COMPONENTS core support mcjit orcjit native asmparser asmprinter irreader) + + +# Enable the native target +if ("${LLVM_NATIVE_ARCH}" STREQUAL "AArch64") + set(WITH_TARGET_AARCH64 yes) +endif() + +if ("${LLVM_NATIVE_ARCH}" STREQUAL "X86") + set(WITH_TARGET_X86 yes) +endif() + +if (WITH_TARGET_AARCH64) + if (NOT ("${LLVM_TARGETS_TO_BUILD}" MATCHES "AArch64")) + message(FATAL_ERROR "The selected LLVM library doesn't have support for AArch64 targets") + endif() + + list(APPEND KARILANG_LLVM_COMPONENTS aarch64info aarch64utils aarch64desc aarch64asmparser aarch64codegen aarch64disassembler) + add_definitions("-DHAVE_TARGET_AARCH64=1") +endif() + +if (WITH_TARGET_X86) + if (NOT ("${LLVM_TARGETS_TO_BUILD}" MATCHES "X86")) + message(FATAL_ERROR "The selected LLVM library doesn't have support for X86 targets") + endif() + + list(APPEND KARILANG_LLVM_COMPONENTS x86info x86desc x86codegen x86asmparser x86disassembler) + add_definitions("-DHAVE_TARGET_X86=1") +endif() + +llvm_map_components_to_libnames(llvm_libs ${KARILANG_LLVM_COMPONENTS}) -add_executable(KariLang src/main.c - src/semantics.c - src/interpreter.c - src/parser.tab.c - src/lex.yy.c - src/DS.h - src/common.h - src/cli_interpreter.h) +target_link_libraries(KariLang + "${llvm_libs}") diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..93d149b --- /dev/null +++ b/Makefile @@ -0,0 +1,46 @@ +BUILD_DIR=bin +SRC_DIR=src + +LEXER_FILE = $(SRC_DIR)/Lexer.l +LEXER_OUTPUT = $(BUILD_DIR)/Lexer.yy.cc + +PARSER_FILE = $(SRC_DIR)/Parser.yy +PARSER_OUTPUT = $(BUILD_DIR)/Parser.tab.cc + +SOURCE_FILES = \ + $(SRC_DIR)/Main.cc \ + $(SRC_DIR)/AST.cc \ + $(SRC_DIR)/Compile.cc \ + $(SRC_DIR)/JIT.cc + +HEADER_FILES = \ + $(SRC_DIR)/Utils.hh \ + $(SRC_DIR)/AST.hh \ + $(SRC_DIR)/JIT.hh \ + $(SRC_DIR)/Parser.hh + +INCLUDE_OPTIONS = \ + -I$(SRC_DIR) \ + -I$(BUILD_DIR) + +C = clang +CXX = clang++ +BUILD_OPTIONS = \ + -Wall \ + -g \ + `llvm-config --cxxflags --ldflags --system-libs --libs all` + # -fsanitize=address \ + # -fno-omit-frame-pointer + +$(BUILD_DIR)/KariLang: $(SOURCE_FILES) $(LEXER_OUTPUT) $(PARSER_OUTPUT) $(HEADER_FILES) + $(CXX) $(INCLUDE_OPTIONS) $(SOURCE_FILES) $(LEXER_OUTPUT) $(PARSER_OUTPUT) $(BUILD_OPTIONS) -o $@ + +$(LEXER_OUTPUT): $(LEXER_FILE) $(PARSER_OUTPUT) + flex -o $@ $(LEXER_FILE) + +$(PARSER_OUTPUT): $(PARSER_FILE) + bison -H $^ -o $@ + +.PHONY: clean +clean: + rm bin/* && rm tmp/*.o diff --git a/README.md b/README.md index 011f98a..1223570 100644 --- a/README.md +++ b/README.md @@ -57,22 +57,14 @@ from the [releases](https://github.com/Vipul-Cariappa/KariLang/releases) page. If you want to compile from source: -Go to src directory ```bash -cd src/ +mkdir bin && cd bin +cmake .. -DCMAKE_BUILD_TYPE=Release +make -j4 ``` -Compile the parser -```bash -bison -Wall -Wcounterexamples -H ./parser.y -``` - -Compiler the lexer -```bash -flex ./lexer.l -``` +While development to speed up compilation time use `lld` linker -Compiler the language ```bash -cc -Wall -g ./main.c ./semantics.c ./interpreter.c ./lex.yy.c ./parser.tab.c -o ./KariLang +cmake .. -DCMAKE_BUILD_TYPE=Debug -DCMAKE_CXX_FLAGS="-fuse-ld=lld" ``` diff --git a/runtime/main.c b/runtime/main.c new file mode 100644 index 0000000..dbdf2ad --- /dev/null +++ b/runtime/main.c @@ -0,0 +1,21 @@ +#include +#include + +#define RED "\x1B[31m" +#define RESET "\x1B[0m" + +int ____karilang_main(int); + +int main(int argc, char *argv[]) { + if (argc != 2) { + fprintf(stderr, "KariLang programs take one integer input and outputs " + "an integer value.\n" RED "Error: " RESET + "Integer input missing.\n"); + return 1; + } + + int input = atoi(argv[1]); + printf("Input: %d\nOutput: %d\n", input, ____karilang_main(input)); + + return 0; +} diff --git a/src/AST.cc b/src/AST.cc new file mode 100644 index 0000000..ad24355 --- /dev/null +++ b/src/AST.cc @@ -0,0 +1,484 @@ +#include "AST.hh" + +#define GOOD_SEMANTICS(t) \ + { \ + this->result_type = t; \ + this->semantics_verified = true; \ + this->semantics_correct = true; \ + return true; \ + } + +#define BAD_SEMANTICS_MSG(msg) \ + { \ + std::cerr << "Semantic Error: " << msg << "\n"; \ + this->result_type = TYPE::INT_T; \ + this->semantics_verified = true; \ + this->semantics_correct = false; \ + return false; \ + } + +#define BAD_SEMANTICS() \ + { \ + this->result_type = TYPE::INT_T; \ + this->semantics_verified = true; \ + this->semantics_correct = false; \ + return false; \ + } + +TYPE Expression::deduce_result_type() { + switch (type) { + case INTEGER_EXP: + return TYPE::INT_T; + case BOOLEAN_EXP: + return TYPE::BOOL_T; + case VARIABLE_EXP: + if (globals_ast.find(std::get(value)) != globals_ast.end()) + return globals_ast.at(std::get(value))->type; + return TYPE::INT_T; /* Hopefully sematic verification will catch the + error */ + case UNARY_OP_EXP: + return std::get>(value) + ->deduce_result_type(); + case BINARY_OP_EXP: + return std::get>(value) + ->deduce_result_type(); + case IF_EXP: + return std::get>(value) + ->deduce_result_type(); + case FUNCTION_CALL_EXP: + return std::get>(value) + ->deduce_result_type(); + } +} + +bool Expression::verify_semantics( + TYPE result_type, + std::unordered_map> + &functions_ast, + std::unordered_map> &globals_ast, + std::unordered_map &context) { + switch (type) { + case INTEGER_EXP: + if (result_type == TYPE::INT_T) + GOOD_SEMANTICS(result_type) + BAD_SEMANTICS_MSG("Expected type " << ToString(result_type) + << " but got int") + case BOOLEAN_EXP: + if (result_type == TYPE::BOOL_T) + GOOD_SEMANTICS(result_type) + BAD_SEMANTICS_MSG("Expected type " << ToString(result_type) + << " but got bool") + case VARIABLE_EXP: + if (globals_ast.find(std::get(value)) != + globals_ast.end()) { + if (globals_ast.at(std::get(value))->type == + result_type) + GOOD_SEMANTICS(result_type) + else + BAD_SEMANTICS_MSG( + "Variable type " + << ToString( + globals_ast.at(std::get(value))->type) + << " but " << ToString(result_type) << "expected") + } + if (context.find(std::get(value)) != context.end()) { + if (context.at(std::get(value)) == result_type) + GOOD_SEMANTICS(result_type) + else + BAD_SEMANTICS_MSG( + "Variable type " + << ToString(context.at(std::get(value))) + << " but " << ToString(result_type) << "expected") + } + BAD_SEMANTICS_MSG("Could not find \"" << std::get(value) + << "\" variable") + case UNARY_OP_EXP: + if (std::get>(value)->verify_semantics( + result_type, functions_ast, globals_ast, context)) + GOOD_SEMANTICS(result_type); + BAD_SEMANTICS() + case BINARY_OP_EXP: + if (std::get>(value)->verify_semantics( + result_type, functions_ast, globals_ast, context)) + GOOD_SEMANTICS(result_type); + BAD_SEMANTICS() + case IF_EXP: + if (std::get>(value)->verify_semantics( + result_type, functions_ast, globals_ast, context)) + GOOD_SEMANTICS(result_type); + BAD_SEMANTICS() + case FUNCTION_CALL_EXP: + if (std::get>(value)->verify_semantics( + result_type, functions_ast, globals_ast, context)) + GOOD_SEMANTICS(result_type); + BAD_SEMANTICS() + } +} + +std::variant Expression::interpret( + std::unordered_map> + &functions_ast, + std::unordered_map> &globals_ast, + std::unordered_map> &context) { + switch (type) { + case INTEGER_EXP: + return std::get(value); + case BOOLEAN_EXP: + return std::get(value); + case VARIABLE_EXP: + if (globals_ast.find(std::get(value)) != + globals_ast.end()) { + return globals_ast.at(std::get(value)) + ->interpret(functions_ast, globals_ast, context); + } else { + return context.at(std::get(value)); + } + case UNARY_OP_EXP: + return std::get>(value)->interpret( + functions_ast, globals_ast, context); + case BINARY_OP_EXP: + return std::get>(value)->interpret( + functions_ast, globals_ast, context); + case IF_EXP: + return std::get>(value)->interpret( + functions_ast, globals_ast, context); + case FUNCTION_CALL_EXP: + return std::get>(value)->interpret( + functions_ast, globals_ast, context); + } +} + +TYPE UnaryOperator::deduce_result_type() { + switch (op_type) { + case NOT_OP: + return TYPE::BOOL_T; + case NEG_OP: + return TYPE::INT_T; + } +} + +bool UnaryOperator::verify_semantics( + TYPE result_type, + std::unordered_map> + &functions_ast, + std::unordered_map> &globals_ast, + std::unordered_map &context) { + switch (op_type) { + case NOT_OP: + if (result_type == TYPE::BOOL_T) { + if (fst->verify_semantics(TYPE::BOOL_T, functions_ast, globals_ast, + context)) + GOOD_SEMANTICS(result_type) + BAD_SEMANTICS() + } + BAD_SEMANTICS_MSG("Expected " << ToString(result_type) + << " but got bool (unary not operator)") + case NEG_OP: + if (result_type == TYPE::INT_T) { + if (fst->verify_semantics(TYPE::INT_T, functions_ast, globals_ast, + context)) + GOOD_SEMANTICS(result_type) + BAD_SEMANTICS() + } + BAD_SEMANTICS_MSG("Expected " + << ToString(result_type) + << " but got int (unary negation operator)") + } +} + +std::variant UnaryOperator::interpret( + std::unordered_map> + &functions_ast, + std::unordered_map> &globals_ast, + std::unordered_map> &context) { + switch (op_type) { + case NOT_OP: + return std::get( + fst->interpret(functions_ast, globals_ast, context)) + ? false + : true; + case NEG_OP: + return -( + std::get(fst->interpret(functions_ast, globals_ast, context))); + } +} + +TYPE BinaryOperator::deduce_result_type() { + switch (op_type) { + case ADD_OP: + case MUL_OP: + case DIV_OP: + case MOD_OP: + return TYPE::INT_T; + case AND_OP: + case OR_OP: + case EQS_OP: + case NEQ_OP: + case GT_OP: + case GTE_OP: + case LT_OP: + case LTE_OP: + return TYPE::BOOL_T; + } +} + +bool BinaryOperator::verify_semantics( + TYPE result_type, + std::unordered_map> + &functions_ast, + std::unordered_map> &globals_ast, + std::unordered_map &context) { + switch (op_type) { + case ADD_OP: + case MUL_OP: + case DIV_OP: + case MOD_OP: + if (result_type == TYPE::INT_T) { + if (fst->verify_semantics(TYPE::INT_T, functions_ast, globals_ast, + context) && + snd->verify_semantics(TYPE::INT_T, functions_ast, globals_ast, + context)) + GOOD_SEMANTICS(result_type) + BAD_SEMANTICS() + } + BAD_SEMANTICS_MSG("Expected " + << ToString(result_type) + << " but got int (binary arithmetic operator)") + case AND_OP: + case OR_OP: + if (result_type == TYPE::BOOL_T) { + if (fst->verify_semantics(TYPE::BOOL_T, functions_ast, globals_ast, + context) && + snd->verify_semantics(TYPE::BOOL_T, functions_ast, globals_ast, + context)) + GOOD_SEMANTICS(result_type) + BAD_SEMANTICS() + } + BAD_SEMANTICS_MSG("Expected " + << ToString(result_type) + << " but got bool (binary logical operator)") + case EQS_OP: + case NEQ_OP: + case GT_OP: + case GTE_OP: + case LT_OP: + case LTE_OP: + if (result_type == TYPE::BOOL_T) { + if (fst->verify_semantics(TYPE::INT_T, functions_ast, globals_ast, + context) && + snd->verify_semantics(TYPE::INT_T, functions_ast, globals_ast, + context)) + GOOD_SEMANTICS(result_type) + BAD_SEMANTICS() + } + BAD_SEMANTICS_MSG("Expected " + << ToString(result_type) + << " but got bool (binary comparision operator)") + } +} + +std::variant BinaryOperator::interpret( + std::unordered_map> + &functions_ast, + std::unordered_map> &globals_ast, + std::unordered_map> &context) { + switch (op_type) { + case ADD_OP: + return (std::get( + fst->interpret(functions_ast, globals_ast, context))) + + (std::get( + snd->interpret(functions_ast, globals_ast, context))); + case MUL_OP: + return (std::get( + fst->interpret(functions_ast, globals_ast, context))) * + (std::get( + snd->interpret(functions_ast, globals_ast, context))); + case DIV_OP: + return (std::get( + fst->interpret(functions_ast, globals_ast, context))) / + (std::get( + snd->interpret(functions_ast, globals_ast, context))); + case MOD_OP: + return (std::get( + fst->interpret(functions_ast, globals_ast, context))) % + (std::get( + snd->interpret(functions_ast, globals_ast, context))); + case AND_OP: + return (std::get( + fst->interpret(functions_ast, globals_ast, context))) && + (std::get( + snd->interpret(functions_ast, globals_ast, context))); + case OR_OP: + return (std::get( + fst->interpret(functions_ast, globals_ast, context))) || + (std::get( + snd->interpret(functions_ast, globals_ast, context))); + case EQS_OP: + return (std::get( + fst->interpret(functions_ast, globals_ast, context))) == + (std::get( + snd->interpret(functions_ast, globals_ast, context))); + case NEQ_OP: + return (std::get( + fst->interpret(functions_ast, globals_ast, context))) != + (std::get( + snd->interpret(functions_ast, globals_ast, context))); + case GT_OP: + return (std::get( + fst->interpret(functions_ast, globals_ast, context))) > + (std::get( + snd->interpret(functions_ast, globals_ast, context))); + case GTE_OP: + return (std::get( + fst->interpret(functions_ast, globals_ast, context))) >= + (std::get( + snd->interpret(functions_ast, globals_ast, context))); + case LT_OP: + return (std::get( + fst->interpret(functions_ast, globals_ast, context))) < + (std::get( + snd->interpret(functions_ast, globals_ast, context))); + case LTE_OP: + return (std::get( + fst->interpret(functions_ast, globals_ast, context))) <= + (std::get( + snd->interpret(functions_ast, globals_ast, context))); + } +} + +TYPE IfOperator::deduce_result_type() { return yes->deduce_result_type(); } + +bool IfOperator::verify_semantics( + TYPE result_type, + std::unordered_map> + &functions_ast, + std::unordered_map> &globals_ast, + std::unordered_map &context) { + if (!cond->verify_semantics(TYPE::BOOL_T, functions_ast, globals_ast, + context)) + BAD_SEMANTICS_MSG("Condition does not yield a boolean value"); + if (yes->verify_semantics(result_type, functions_ast, globals_ast, + context) && + no->verify_semantics(result_type, functions_ast, globals_ast, context)) + GOOD_SEMANTICS(result_type) + BAD_SEMANTICS_MSG( + "Conditional branches does not yield the expected type of " + << ToString(result_type)); +} + +std::variant IfOperator::interpret( + std::unordered_map> + &functions_ast, + std::unordered_map> &globals_ast, + std::unordered_map> &context) { + return std::get(cond->interpret(functions_ast, globals_ast, context)) + ? yes->interpret(functions_ast, globals_ast, context) + : no->interpret(functions_ast, globals_ast, context); +} + +TYPE FunctionCall::deduce_result_type() { + if (functions_ast.find(function_name) != functions_ast.end()) { + std::unique_ptr &func = functions_ast.at(function_name); + return func->return_type; + } + return TYPE::INT_T; /* Hopefully sematic verification will catch the error + */ +} + +bool FunctionCall::verify_semantics( + TYPE result_type, + std::unordered_map> + &functions_ast, + std::unordered_map> &globals_ast, + std::unordered_map &context) { + if (functions_ast.find(function_name) != functions_ast.end()) { + std::unique_ptr &func = functions_ast.at(function_name); + if (func->args_type.size() != args.size()) + BAD_SEMANTICS_MSG("Expected " + << func->args_type.size() + << " number of arguments but supplied " + << args.size() << " arguments") + + if (func->return_type != result_type) + BAD_SEMANTICS_MSG("Expected return type to be " + << ToString(result_type) + << " but actual return type is " + << ToString(func->return_type)) + + for (size_t i = 0; i < args.size(); i++) { + if (!args[i]->verify_semantics(func->args_type[i], functions_ast, + globals_ast, context)) + BAD_SEMANTICS_MSG("\t in function arguments"); + } + GOOD_SEMANTICS(result_type); + } + BAD_SEMANTICS_MSG("Could not find \"" << function_name << "\" function") +} + +std::variant FunctionCall::interpret( + std::unordered_map> + &functions_ast, + std::unordered_map> &globals_ast, + std::unordered_map> &context) { + // TODO: change context's type to std::unique_ptr + // that will automatically change to lazy evaluation of arguments + std::unordered_map> my_context; + std::unique_ptr &func = functions_ast.at(function_name); + for (size_t i = 0; i < func->args_name.size(); i++) { + my_context.insert( + {func->args_name.at(i), + args.at(i)->interpret(functions_ast, globals_ast, context)}); + } + return func->interpret(functions_ast, globals_ast, my_context); +} + +bool ValueDef::verify_semantics( + std::unordered_map> + &functions_ast, + std::unordered_map> &globals_ast) { + std::unordered_map r; + if (expression->verify_semantics(type, functions_ast, globals_ast, r)) { + semantics_verified = true; + semantics_correct = true; + return true; + } + semantics_verified = true; + semantics_correct = false; + return false; +} + +std::variant ValueDef::interpret( + std::unordered_map> + &functions_ast, + std::unordered_map> &globals_ast, + std::unordered_map> &context) { + return expression->interpret(functions_ast, globals_ast, context); +} + +bool FunctionDef::verify_semantics( + std::unordered_map> + &functions_ast, + std::unordered_map> &globals_ast) { + // creating a context + std::unordered_map context; + for (size_t i = 0; i < args_name.size(); i++) { + context.insert({args_name[i], args_type[i]}); + } + if (expression->verify_semantics(return_type, functions_ast, globals_ast, + context)) { + semantics_verified = true; + semantics_correct = true; + return true; + } + semantics_verified = true; + semantics_correct = false; + return false; +} + +std::variant FunctionDef::interpret( + std::unordered_map> + &functions_ast, + std::unordered_map> &globals_ast, + std::unordered_map> &context) { + return expression->interpret(functions_ast, globals_ast, context); +} diff --git a/src/AST.hh b/src/AST.hh new file mode 100644 index 0000000..f273fd2 --- /dev/null +++ b/src/AST.hh @@ -0,0 +1,504 @@ +#pragma once + +#include "PCH.hh" +#include "Utils.hh" + +enum EXPRESSION_TYPE { + INTEGER_EXP, + BOOLEAN_EXP, + VARIABLE_EXP, + UNARY_OP_EXP, + BINARY_OP_EXP, + IF_EXP, + FUNCTION_CALL_EXP, +}; + +enum TYPE { + BOOL_T, + INT_T, +}; + +inline std::string ToString(TYPE type) { + switch (type) { + case BOOL_T: + return "bool"; + case INT_T: + return "int"; + } +} + +enum UNARY_OPERATOR { + NOT_OP, + NEG_OP, +}; + +enum BINARY_OPERATOR { + ADD_OP, + MUL_OP, + DIV_OP, + MOD_OP, + AND_OP, + OR_OP, + EQS_OP, + NEQ_OP, + GT_OP, + GTE_OP, + LT_OP, + LTE_OP, +}; + +inline std::string ToString(BINARY_OPERATOR op) { + switch (op) { + case ADD_OP: + return "+"; + case MUL_OP: + return "*"; + case DIV_OP: + return "/"; + case MOD_OP: + return "%"; + case AND_OP: + return "&&"; + case OR_OP: + return "||"; + case EQS_OP: + return "=="; + case NEQ_OP: + return "!="; + case GT_OP: + return ">"; + break; + case GTE_OP: + return ">="; + case LT_OP: + return "<"; + case LTE_OP: + return "<="; + } +} + +class ValueDef; +class FunctionDef; + +class BaseExpression; +class UnaryOperator; +class BinaryOperator; +class IfOperator; +class Expression; + +class ValueDef { + public: + TYPE type; + std::string name; + std::unique_ptr expression; + + bool semantics_verified = false; + bool semantics_correct = false; + + inline ValueDef(TYPE type, std::string name, + std::unique_ptr expression) + : type(type), name(name), expression(std::move(expression)) {} + + inline virtual ~ValueDef() = default; + ValueDef &operator=(ValueDef &&other) = default; + + inline friend std::ostream &operator<<(std::ostream &os, + ValueDef const &m) { + + return os << "valdef " << m.name << ": " << ToString(m.type) << " = " + << m.expression << ";"; + } + + inline static std::unique_ptr + from(TYPE type, std::string name, std::unique_ptr expression) { + std::unique_ptr result( + new ValueDef(type, name, std::move(expression))); + return result; + } + + bool verify_semantics( + std::unordered_map> + &functions_ast, + std::unordered_map> + &globals_ast); + + std::variant interpret( + std::unordered_map> + &functions_ast, + std::unordered_map> &globals_ast, + std::unordered_map> &context); + + llvm::Value *generate_llvm_ir(); +}; + +class FunctionDef { + public: + TYPE return_type; + std::string name; + std::vector args_name; + std::vector args_type; + std::unique_ptr expression; + + bool semantics_verified = false; + bool semantics_correct = false; + + inline FunctionDef() {} + + inline virtual ~FunctionDef() = default; + FunctionDef &operator=(FunctionDef &&other) = default; + + inline friend std::ostream &operator<<(std::ostream &os, + FunctionDef const &m) { + os << "funcdef " << m.name; + for (size_t i = 0; i < m.args_name.size(); i++) { + os << " " << m.args_name[i] << ": " << m.args_type[i]; + } + os << " -> " << ToString(m.return_type) << " =\n\t" << m.expression + << ";"; + return os; + } + + inline void add_argument(TYPE arg_type, std::string arg_name) { + this->args_name.push_back(arg_name); + this->args_type.push_back(arg_type); + } + + inline void set_info(std::string name, TYPE return_type, + std::unique_ptr expression) { + this->name = name; + this->return_type = return_type; + this->expression = std::move(expression); + } + + inline static std::unique_ptr from() { + std::unique_ptr result(new FunctionDef()); + return result; + } + + bool verify_semantics( + std::unordered_map> + &functions_ast, + std::unordered_map> + &globals_ast); + + std::variant interpret( + std::unordered_map> + &functions_ast, + std::unordered_map> &globals_ast, + std::unordered_map> &context); + + llvm::Value *generate_llvm_ir(); +}; + +class BaseExpression { + public: + TYPE result_type; // type of the computed result + bool semantics_verified = false; + bool semantics_correct = false; + + inline virtual ~BaseExpression() = default; + BaseExpression &operator=(BaseExpression &&other) = default; + + virtual bool verify_semantics( + TYPE result_type, + std::unordered_map> + &functions_ast, + std::unordered_map> &globals_ast, + std::unordered_map &context) = 0; + + virtual std::variant interpret( + std::unordered_map> + &functions_ast, + std::unordered_map> &globals_ast, + std::unordered_map> &context) = 0; + + virtual llvm::Value *generate_llvm_ir() = 0; + virtual TYPE deduce_result_type() = 0; +}; + +class UnaryOperator : public BaseExpression { + public: + std::unique_ptr fst; + UNARY_OPERATOR op_type; + + inline UnaryOperator(std::unique_ptr fst, + UNARY_OPERATOR op_type) + : fst(std::move(fst)), op_type(op_type) {} + + inline virtual ~UnaryOperator() = default; + UnaryOperator &operator=(UnaryOperator &&other) = default; + + inline friend std::ostream &operator<<(std::ostream &os, + UnaryOperator const &m) { + return os << (m.op_type == NOT_OP ? "!(" : "-(") << m.fst << ")"; + } + + virtual bool verify_semantics( + TYPE result_type, + std::unordered_map> + &functions_ast, + std::unordered_map> &globals_ast, + std::unordered_map &context) override; + virtual std::variant interpret( + std::unordered_map> + &functions_ast, + std::unordered_map> &globals_ast, + std::unordered_map> &context) + override; + virtual llvm::Value *generate_llvm_ir() override; + virtual TYPE deduce_result_type() override; +}; + +class BinaryOperator : public BaseExpression { + public: + std::unique_ptr fst; + std::unique_ptr snd; + BINARY_OPERATOR op_type; + + inline BinaryOperator(std::unique_ptr fst, + std::unique_ptr snd, + BINARY_OPERATOR op_type) + : fst(std::move(fst)), snd(std::move(snd)), op_type(op_type) {} + + inline virtual ~BinaryOperator() = default; + BinaryOperator &operator=(BinaryOperator &&other) = default; + + inline friend std::ostream &operator<<(std::ostream &os, + BinaryOperator const &m) { + + return os << "(" << m.fst << ") " << ToString(m.op_type) << " (" + << m.snd << ")"; + } + + virtual bool verify_semantics( + TYPE result_type, + std::unordered_map> + &functions_ast, + std::unordered_map> &globals_ast, + std::unordered_map &context) override; + virtual std::variant interpret( + std::unordered_map> + &functions_ast, + std::unordered_map> &globals_ast, + std::unordered_map> &context) + override; + virtual llvm::Value *generate_llvm_ir() override; + virtual TYPE deduce_result_type() override; +}; + +class IfOperator : public BaseExpression { + public: + std::unique_ptr cond; + std::unique_ptr yes; + std::unique_ptr no; + + inline IfOperator(std::unique_ptr cond, + std::unique_ptr yes, + std::unique_ptr no) + : cond(std::move(cond)), yes(std::move(yes)), no(std::move(no)) {} + + inline virtual ~IfOperator() = default; + IfOperator &operator=(IfOperator &&other) = default; + + inline friend std::ostream &operator<<(std::ostream &os, + IfOperator const &m) { + return os << "if (" << m.cond << ") then " << m.yes << " else " << m.no; + } + + virtual bool verify_semantics( + TYPE result_type, + std::unordered_map> + &functions_ast, + std::unordered_map> &globals_ast, + std::unordered_map &context) override; + virtual std::variant interpret( + std::unordered_map> + &functions_ast, + std::unordered_map> &globals_ast, + std::unordered_map> &context) + override; + virtual llvm::Value *generate_llvm_ir() override; + virtual TYPE deduce_result_type() override; +}; + +class FunctionCall : public BaseExpression { + public: + std::string function_name; + std::vector> args; + + inline FunctionCall() {} + + inline FunctionCall(FunctionCall &&fc) { + this->function_name = fc.function_name; + this->args = std::move(fc.args); + } + + inline virtual ~FunctionCall() = default; + FunctionCall &operator=(FunctionCall &&other) = default; + + inline friend std::ostream &operator<<(std::ostream &os, + FunctionCall const &m) { + os << m.function_name; + for (const std::unique_ptr &i : m.args) { + os << " (" << i << ")"; + } + return os; + } + + inline void set_function_name(std::string name) { + this->function_name = name; + } + + inline void add_argument(std::unique_ptr arg) { + this->args.push_back(std::move(arg)); + } + + virtual bool verify_semantics( + TYPE result_type, + std::unordered_map> + &functions_ast, + std::unordered_map> &globals_ast, + std::unordered_map &context) override; + virtual std::variant interpret( + std::unordered_map> + &functions_ast, + std::unordered_map> &globals_ast, + std::unordered_map> &context) + override; + virtual llvm::Value *generate_llvm_ir() override; + virtual TYPE deduce_result_type() override; +}; + +class Expression : public BaseExpression { + public: + std::variant, + std::unique_ptr, std::unique_ptr, + std::unique_ptr> + value; + EXPRESSION_TYPE type; + + inline Expression(int value) : value(value), type(INTEGER_EXP) {} + inline Expression(bool value) : value(value), type(BOOLEAN_EXP) {} + inline Expression(std::string value) : value(value), type(VARIABLE_EXP) {} + inline Expression(std::unique_ptr value) + : value(std::move(value)), type(UNARY_OP_EXP) {} + inline Expression(std::unique_ptr value) + : value(std::move(value)), type(BINARY_OP_EXP) {} + inline Expression(std::unique_ptr value) + : value(std::move(value)), type(IF_EXP) {} + inline Expression(FunctionCall value) { + this->type = FUNCTION_CALL_EXP; + std::unique_ptr f(new FunctionCall(std::move(value))); + this->value = std::move(f); + } + + inline virtual ~Expression() = default; + Expression &operator=(Expression &&other) = default; + + inline friend std::ostream &operator<<(std::ostream &os, + Expression const &m) { + switch (m.type) { + case INTEGER_EXP: + return os << std::get(m.value); + break; + case BOOLEAN_EXP: + return os << std::get(m.value); + break; + case VARIABLE_EXP: + return os << std::get(m.value); + break; + case UNARY_OP_EXP: + return os << std::get>(m.value); + break; + case BINARY_OP_EXP: + return os << std::get>(m.value); + break; + case IF_EXP: + return os << std::get>(m.value); + break; + case FUNCTION_CALL_EXP: + return os << std::get>(m.value); + break; + } + } + + inline static std::unique_ptr from(int value) { + std::unique_ptr result(new Expression(value)); + return result; + } + + inline static std::unique_ptr from(bool value) { + std::unique_ptr result(new Expression(value)); + return result; + } + + inline static std::unique_ptr from(std::string value) { + std::unique_ptr result(new Expression(value)); + return result; + } + + inline static std::unique_ptr from(FunctionCall fc) { + std::unique_ptr result(new Expression(std::move(fc))); + return result; + } + + inline static std::unique_ptr + from(std::unique_ptr fst, UNARY_OPERATOR type) { + std::unique_ptr unary_op( + new UnaryOperator(std::move(fst), type)); + std::unique_ptr result(new Expression(std::move(unary_op))); + return result; + } + + inline static std::unique_ptr + from(std::unique_ptr lhs, std::unique_ptr rhs, + BINARY_OPERATOR type) { + std::unique_ptr binary_op( + new BinaryOperator(std::move(lhs), std::move(rhs), type)); + std::unique_ptr result( + new Expression(std::move(binary_op))); + return result; + } + + inline static std::unique_ptr + from(std::unique_ptr cond, std::unique_ptr yes, + std::unique_ptr no) { + std::unique_ptr if_op( + new IfOperator(std::move(cond), std::move(yes), std::move(no))); + std::unique_ptr result(new Expression(std::move(if_op))); + return result; + } + + virtual bool verify_semantics( + TYPE result_type, + std::unordered_map> + &functions_ast, + std::unordered_map> &globals_ast, + std::unordered_map &context) override; + virtual std::variant interpret( + std::unordered_map> + &functions_ast, + std::unordered_map> &globals_ast, + std::unordered_map> &context) + override; + virtual llvm::Value *generate_llvm_ir() override; + virtual TYPE deduce_result_type() override; +}; + +inline std::ostream & +print(std::ostream &os, + std::unordered_map> &ast) { + for (auto &i : ast) { + os << i.second << "\n"; + } + return os; +} + +inline std::ostream & +print(std::ostream &os, + std::unordered_map> &ast) { + for (auto &i : ast) { + os << i.second << "\n"; + } + return os; +} diff --git a/src/Compile.cc b/src/Compile.cc new file mode 100644 index 0000000..858c017 --- /dev/null +++ b/src/Compile.cc @@ -0,0 +1,357 @@ +#include "Compile.hh" +#include "AST.hh" +#include "PCH.hh" +#include "Utils.hh" + +std::unique_ptr TheContext; +std::unique_ptr> Builder; +std::unique_ptr TheModule; +std::map NamedValues; +std::unique_ptr TheFPM; + +llvm::Value *LogErrorV(const char *str) { + std::cerr << str; + return nullptr; +} + +llvm::Value *Expression::generate_llvm_ir() { + switch (type) { + case INTEGER_EXP: + return llvm::ConstantInt::get( + *TheContext, llvm::APInt(32, std::get(value), true)); + case BOOLEAN_EXP: + return llvm::ConstantInt::get( + *TheContext, llvm::APInt(1, std::get(value), false)); + case VARIABLE_EXP: { + if (NamedValues[std::get(value)]) + return NamedValues[std::get(value)]; + + // all variables (except arguments for now) are no args function which + // return the datatype as the variable + llvm::Function *CalleeF = + TheModule->getFunction(std::get(value)); + + // ???: below is required for JIT + auto f = globals_ast.find(std::get(value)); + if ((!CalleeF) && (f != globals_ast.end())) { + CalleeF = FunctionPrototype::generate_llvm_ir(f->second->name, {}, + {}, f->second->type); + } + + std::vector ArgsV; + + return Builder->CreateCall(CalleeF, ArgsV, "variable_funccall"); + } + case UNARY_OP_EXP: + return std::get>(value) + ->generate_llvm_ir(); + case BINARY_OP_EXP: + return std::get>(value) + ->generate_llvm_ir(); + case IF_EXP: + return std::get>(value)->generate_llvm_ir(); + case FUNCTION_CALL_EXP: + return std::get>(value) + ->generate_llvm_ir(); + } +} + +llvm::Value *UnaryOperator::generate_llvm_ir() { + llvm::Value *L = fst->generate_llvm_ir(); + if (!L) + return nullptr; + switch (op_type) { + case NEG_OP: { + llvm::Value *R = + llvm::ConstantInt::get(*TheContext, llvm::APInt(32, 0, true)); + if (!R) + return nullptr; + return Builder->CreateSub(R, L, "neg_op"); + } + case NOT_OP: + return Builder->CreateNeg(L, "not_op"); + } +} + +llvm::Value *BinaryOperator::generate_llvm_ir() { + llvm::Value *L = fst->generate_llvm_ir(); + llvm::Value *R = snd->generate_llvm_ir(); + if ((!L) || (!R)) + return nullptr; + + switch (op_type) { + case ADD_OP: + return Builder->CreateAdd(L, R, "add_op"); + case MUL_OP: + return Builder->CreateMul(L, R, "mul_op"); + case DIV_OP: + return Builder->CreateSDiv(L, R, "div_op"); + case MOD_OP: + return Builder->CreateSRem(L, R, "mod_op"); + case AND_OP: + return Builder->CreateAnd(L, R, "and_op"); + case OR_OP: + return Builder->CreateOr(L, R, "or_op"); + case EQS_OP:; + return Builder->CreateICmpEQ(L, R, "eqs_op"); + case NEQ_OP: + return Builder->CreateICmpNE(L, R, "neq_op"); + case GT_OP: + return Builder->CreateICmpSGT(L, R, "gt_op"); + case GTE_OP: + return Builder->CreateICmpSGE(L, R, "gte_op"); + case LT_OP: + return Builder->CreateICmpSLT(L, R, "lt_op"); + case LTE_OP: + return Builder->CreateICmpSLE(L, R, "lte_op"); + } +} + +llvm::Value *IfOperator::generate_llvm_ir() { + llvm::Value *Cond = cond->generate_llvm_ir(); + if (!Cond) + return nullptr; + + llvm::Function *TheFunction = Builder->GetInsertBlock()->getParent(); + + // create a temporary variable to store the result of if operation + llvm::IRBuilder<> TmpB(&TheFunction->getEntryBlock(), + TheFunction->getEntryBlock().begin()); + llvm::AllocaInst *if_op_result; + switch (result_type) { + case BOOL_T: + if_op_result = Builder->CreateAlloca(llvm::Type::getInt1Ty(*TheContext), + nullptr, "if_op_result"); + break; + case INT_T: + if_op_result = TmpB.CreateAlloca(llvm::Type::getInt32Ty(*TheContext), + nullptr, "if_op_result_ptr"); + break; + } + + // Create blocks for the then and else cases + llvm::BasicBlock *ThenBB = + llvm::BasicBlock::Create(*TheContext, "if_then", TheFunction); + llvm::BasicBlock *ElseBB = llvm::BasicBlock::Create(*TheContext, "if_else"); + llvm::BasicBlock *MergeBB = + llvm::BasicBlock::Create(*TheContext, "if_continued"); + + Builder->CreateCondBr(Cond, ThenBB, ElseBB); + + // Emit then value + Builder->SetInsertPoint(ThenBB); + if (!Builder->CreateStore(yes->generate_llvm_ir(), if_op_result)) + return nullptr; + Builder->CreateBr(MergeBB); + + ThenBB = Builder->GetInsertBlock(); + TheFunction->insert(TheFunction->end(), ElseBB); + + // Emit else value + Builder->SetInsertPoint(ElseBB); + if (!Builder->CreateStore(no->generate_llvm_ir(), if_op_result)) + return nullptr; + Builder->CreateBr(MergeBB); + + // Emit merge value + ElseBB = Builder->GetInsertBlock(); + TheFunction->insert(TheFunction->end(), MergeBB); + Builder->SetInsertPoint(MergeBB); + + return Builder->CreateLoad(llvm::Type::getInt32Ty(*TheContext), + if_op_result, "if_op_result"); +} + +llvm::Value *FunctionCall::generate_llvm_ir() { + llvm::Function *CalleeF = TheModule->getFunction(function_name); + // ???: below is required for JIT + auto f = functions_ast.find(function_name); + if ((!CalleeF) && (f != functions_ast.end())) { + CalleeF = FunctionPrototype::generate_llvm_ir( + f->second->name, f->second->args_name, f->second->args_type, + f->second->return_type); + } + + std::vector ArgsV; + for (size_t i = 0; i < args.size(); i++) { + ArgsV.push_back(args.at(i)->generate_llvm_ir()); + if (!ArgsV.back()) + return nullptr; + } + return Builder->CreateCall(CalleeF, ArgsV, "funccall"); +} + +llvm::Value *ValueDef::generate_llvm_ir() { + std::vector ArgV; + llvm::FunctionType *FT; + switch (type) { + case TYPE::INT_T: + FT = llvm::FunctionType::get(llvm::Type::getInt32Ty(*TheContext), ArgV, + false); + break; + case TYPE::BOOL_T: + FT = llvm::FunctionType::get(llvm::Type::getInt1Ty(*TheContext), ArgV, + false); + break; + } + llvm::Function *TheFunction = llvm::Function::Create( + FT, llvm::Function::ExternalLinkage, name, TheModule.get()); + + // Create a new basic block to start insertion into. + llvm::BasicBlock *BB = llvm::BasicBlock::Create( + *TheContext, std::string(name) + "_entry", TheFunction); + Builder->SetInsertPoint(BB); + + if (llvm::Value *RetVal = expression->generate_llvm_ir()) { + // Finish off the function. + Builder->CreateRet(RetVal); + + // TheFunction->print(llvm::errs(), nullptr); + + // Validate the generated code, checking for consistency. + assert(!verifyFunction(*TheFunction)); + + TheFPM->run(*TheFunction); + return TheFunction; + } + + // Error reading body, remove function. + TheFunction->eraseFromParent(); + return nullptr; +} + +llvm::Value *FunctionDef::generate_llvm_ir() { + // First, check for an existing function + llvm::Function *TheFunction; + if (name == "main") + TheFunction = TheModule->getFunction("____karilang_main"); + else + TheFunction = TheModule->getFunction(name); + + if (!TheFunction) + return (llvm::Function *)LogErrorV( + "Function prototype not found. " + "Prototype should be defined before body is defined."); + if (!TheFunction->empty()) + return (llvm::Function *)LogErrorV("Function cannot be redefined"); + + // Create a new basic block to start insertion into. + llvm::BasicBlock *BB = llvm::BasicBlock::Create( + *TheContext, std::string(name) + "_entry", TheFunction); + Builder->SetInsertPoint(BB); + + // Record the function arguments in the NamedValues map. + NamedValues.clear(); + for (auto &Arg : TheFunction->args()) + NamedValues[std::string(Arg.getName())] = &Arg; + + if (llvm::Value *RetVal = expression->generate_llvm_ir()) { + // Finish off the function. + Builder->CreateRet(RetVal); + + // TheFunction->print(llvm::errs(), nullptr); + + // Validate the generated code, checking for consistency. + assert(!verifyFunction(*TheFunction)); + + TheFPM->run(*TheFunction); + return TheFunction; + } + + // Error reading body, remove function. + TheFunction->eraseFromParent(); + return nullptr; +} + +int Compile(const std::string filename, + const std::unordered_map> + &functions_ast, + const std::unordered_map> + &globals_ast) { + // Open a new context and module. + TheContext = std::make_unique(); + TheModule = std::make_unique("LLVM Compiler", *TheContext); + + // Create a new pass manager attached to it. + TheFPM = + std::make_unique(TheModule.get()); + + // Do simple "peephole" optimizations and bit-twiddling optzns. + TheFPM->add(llvm::createInstructionCombiningPass()); + // Re-associate expressions. + TheFPM->add(llvm::createReassociatePass()); + // Eliminate Common SubExpressions. + TheFPM->add(llvm::createGVNPass()); + // Simplify the control flow graph (deleting unreachable blocks, etc). + TheFPM->add(llvm::createCFGSimplificationPass()); + // tail call optimization + TheFPM->add(llvm::createTailCallEliminationPass()); + + TheFPM->doInitialization(); + + // Create a new builder for the module. + Builder = std::make_unique>(*TheContext); + + // Generate LLVM IR + for (auto &i : functions_ast) + // First create prototype for all the functions + FunctionPrototype::generate_llvm_ir(i.second->name, i.second->args_name, + i.second->args_type, + i.second->return_type); + for (auto &i : functions_ast) + i.second->generate_llvm_ir(); + for (auto &i : globals_ast) + i.second->generate_llvm_ir(); + + // Print generated IR + // TheModule->print(llvm::errs(), nullptr); + + // Getting target type + std::string TargetTriple = llvm::sys::getDefaultTargetTriple(); + // std::cout << TargetTriple << std::endl; + + // Initializing target + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmParser(); + llvm::InitializeNativeTargetAsmPrinter(); + + std::string Error; + const llvm::Target *Target = + llvm::TargetRegistry::lookupTarget(TargetTriple, Error); + if (!Target) { + llvm::errs() << Error; + return 1; + } + + // Configure other target related types + std::string CPU = "generic"; + std::string Features = ""; + llvm::TargetOptions opt; + llvm::TargetMachine *TargetMachine = Target->createTargetMachine( + TargetTriple, CPU, Features, opt, llvm::Reloc::PIC_); + TheModule->setDataLayout(TargetMachine->createDataLayout()); + TheModule->setTargetTriple(TargetTriple); + + // Compile + std::filesystem::path output_file(filename); + output_file.replace_extension("o"); + + std::error_code EC; + llvm::raw_fd_ostream dest(output_file.c_str(), EC, llvm::sys::fs::OF_None); + if (EC) { + llvm::errs() << "Could not open file: " << EC.message(); + return 1; + } + + llvm::legacy::PassManager pass; + llvm::CodeGenFileType FileType = llvm::CGFT_ObjectFile; + if (TargetMachine->addPassesToEmitFile(pass, dest, nullptr, FileType)) { + llvm::errs() << "TargetMachine can't emit a file of this type "; + return 1; + } + + pass.run(*TheModule); + dest.flush(); + + return 0; +} diff --git a/src/Compile.hh b/src/Compile.hh new file mode 100644 index 0000000..1f74332 --- /dev/null +++ b/src/Compile.hh @@ -0,0 +1,63 @@ +#include "AST.hh" +#include "PCH.hh" + +extern std::unique_ptr TheContext; +extern std::unique_ptr> Builder; +extern std::unique_ptr TheModule; +extern std::map NamedValues; +extern std::unique_ptr TheFPM; + +int Compile(const std::string filename, + const std::unordered_map> + &functions_ast, + const std::unordered_map> + &globals_ast); + +namespace FunctionPrototype { +inline llvm::Function * +generate_llvm_ir(const std::string &name, + const std::vector &args_name, + const std::vector &args_type, TYPE return_type) { + // setup arguments type + std::vector ArgV; + for (size_t i = 0; i < args_name.size(); i++) { + switch (args_type.at(i)) { + case TYPE::INT_T: + ArgV.push_back(llvm::Type::getInt32Ty(*TheContext)); + break; + case TYPE::BOOL_T: + ArgV.push_back(llvm::Type::getInt1Ty(*TheContext)); + break; + } + } + + // setup return type + llvm::FunctionType *FT; + switch (return_type) { + case TYPE::INT_T: + FT = llvm::FunctionType::get(llvm::Type::getInt32Ty(*TheContext), ArgV, + false); + break; + case TYPE::BOOL_T: + FT = llvm::FunctionType::get(llvm::Type::getInt1Ty(*TheContext), ArgV, + false); + break; + } + + // create the function prototype + llvm::Function *F; + if (name == "main") + F = llvm::Function::Create(FT, llvm::Function::ExternalLinkage, + "____karilang_main", TheModule.get()); + else + F = llvm::Function::Create(FT, llvm::Function::ExternalLinkage, name, + TheModule.get()); + + // setup argument names + unsigned Idx = 0; + for (auto &Arg : F->args()) + Arg.setName(args_name.at(Idx++)); + + return F; +} +} // namespace FunctionPrototype diff --git a/src/DS.h b/src/DS.h deleted file mode 100644 index 8c9234a..0000000 --- a/src/DS.h +++ /dev/null @@ -1,502 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include - -#define DS_ARRAY_DEC(name, TYPE) \ - typedef struct _##name##_array_t name##_array_t; \ - name##_array_t *name##_array_new(); \ - size_t name##_array_size(const name##_array_t *arr); \ - bool name##_array_append(name##_array_t *arr, TYPE val); \ - TYPE name##_array_pop(name##_array_t *arr); \ - bool name##_array_setat(name##_array_t *arr, TYPE val, \ - const size_t index); \ - TYPE name##_array_getat(const name##_array_t *arr, const size_t index); \ - TYPE *name##_array_get_ptr_at(const name##_array_t *arr, \ - const size_t index); \ - bool name##_array_clear(name##_array_t *arr); - -#define DS_LIST_DEC(name, TYPE) \ - typedef struct _##name##_list_t name##_list_t; \ - name##_list_t *name##_list_new(); \ - size_t name##_list_size(const name##_list_t *li); \ - bool name##_list_append(name##_list_t *li, TYPE val); \ - TYPE name##_list_pop(name##_list_t *li); \ - bool name##_list_setat(name##_list_t *li, TYPE val, const size_t index); \ - TYPE name##_list_getat(const name##_list_t *li, const size_t index); \ - TYPE *name##_list_get_ptr_at(const name##_list_t *li, const size_t index); \ - bool name##_list_clear(name##_list_t *li); - -#define DS_TABLE_DEC(name, TYPE) \ - typedef struct _##name##_table_t name##_table_t; \ - name##_table_t *name##_table_new(size_t size); \ - bool name##_table_insert(name##_table_t *tb, const char *key, TYPE value); \ - TYPE name##_table_get(name##_table_t *tb, const char *key); \ - TYPE *name##_table_get_ptr(name##_table_t *tb, const char *key); \ - size_t name##_table_size(name##_table_t *tb); \ - bool name##_table_delete(name##_table_t *tb, const char *key); \ - TYPE *name##_table_iter_next(name##_table_t *tb, char **key); \ - void name##_table_iter(name##_table_t *tb); \ - bool name##_table_clear(name##_table_t *tb); - -#define DS_ARRAY_DEF(name, TYPE, delFunc) \ - struct _##name##_array_t { \ - TYPE *array; \ - size_t size; \ - size_t capacity; \ - }; \ - typedef struct _##name##_array_t name##_array_t; \ - \ - name##_array_t *name##_array_new() { \ - name##_array_t *arr = malloc(sizeof(name##_array_t)); \ - if (!arr) { \ - errno = ENOMEM; \ - return NULL; \ - } \ - \ - arr->array = calloc(4, sizeof(TYPE)); \ - if (!arr->array) { \ - free(arr); \ - errno = ENOMEM; \ - return NULL; \ - } \ - \ - arr->capacity = 4; \ - arr->size = 0; \ - return arr; \ - } \ - \ - size_t name##_array_size(const name##_array_t *arr) { return arr->size; } \ - \ - bool name##_array_append(name##_array_t *arr, TYPE val) { \ - if (arr->size < arr->capacity) { \ - arr->array[arr->size] = val; \ - arr->size++; \ - return true; \ - } else { \ - arr->array = \ - realloc(arr->array, (arr->capacity * 2 * sizeof(TYPE))); \ - if (arr->array) { \ - arr->capacity *= 2; \ - arr->array[arr->size] = val; \ - arr->size++; \ - return true; \ - } \ - errno = ENOMEM; \ - return false; \ - } \ - } \ - \ - TYPE name##_array_pop(name##_array_t *arr) { \ - if (arr->size == 0) { \ - errno = EINVAL; \ - return (TYPE){0}; \ - } \ - \ - TYPE last_value = arr->array[--(arr->size)]; \ - \ - if (arr->size * 2 < arr->capacity) { \ - arr->array = \ - realloc(arr->array, (arr->capacity / 2) * sizeof(TYPE)); \ - arr->capacity /= 2; \ - } \ - return last_value; \ - } \ - \ - bool name##_array_setat(name##_array_t *arr, TYPE val, \ - const size_t index) { \ - if (index < arr->size) { \ - if (delFunc) { \ - delFunc(arr->array + index); \ - } \ - arr->array[index] = val; \ - return true; \ - } \ - errno = EINVAL; \ - return false; \ - } \ - \ - TYPE name##_array_getat(const name##_array_t *arr, const size_t index) { \ - if (index < arr->size) { \ - return arr->array[index]; \ - } \ - errno = EINVAL; \ - return (TYPE){0}; \ - } \ - \ - TYPE *name##_array_get_ptr_at(const name##_array_t *arr, \ - const size_t index) { \ - if (index < arr->size) { \ - return arr->array + index; \ - } \ - errno = EINVAL; \ - return NULL; \ - } \ - \ - bool name##_array_clear(name##_array_t *arr) { \ - if (delFunc) { \ - for (size_t i = 0; i < arr->size; i++) { \ - delFunc(arr->array + i); \ - } \ - } \ - free(arr->array); \ - free(arr); \ - return true; \ - } - -#define DS_ARRAY_FOREACH(arr, name) \ - typeof(*(arr->array)) name = arr->array[0]; \ - for (size_t _i = 0; _i++ < arr->size; name = arr->array[_i]) - -#define DS_ARRAY_PTR_FOREACH(arr, name) \ - typeof(*(arr->array)) *name = arr->array + 0; \ - for (size_t _i = 0; _i++ < arr->size; name = arr->array + _i) - -#define DS_LIST_DEF(name, TYPE, delFunc) \ - typedef struct _##name##_list_node name##_list_node; \ - struct _##name##_list_node { \ - TYPE element; \ - name##_list_node *next; \ - name##_list_node *previous; \ - }; \ - \ - typedef struct _##name##_list_t name##_list_t; \ - struct _##name##_list_t { \ - name##_list_node *first; \ - name##_list_node *last; \ - size_t size; \ - }; \ - \ - name##_list_t *name##_list_new() { \ - name##_list_t *li = malloc(sizeof(name##_list_t)); \ - if (!li) { \ - errno = ENOMEM; \ - return NULL; \ - } \ - \ - li->first = NULL; \ - li->last = NULL; \ - li->size = 0; \ - return li; \ - } \ - \ - size_t name##_list_size(const name##_list_t *li) { return li->size; } \ - \ - bool name##_list_append(name##_list_t *li, TYPE val) { \ - name##_list_node *new = malloc(sizeof(name##_list_node)); \ - if (!new) { \ - errno = ENOMEM; \ - return false; \ - } \ - \ - new->element = val; \ - new->next = NULL; \ - \ - if (li->size > 0) { \ - name##_list_node *last = li->last; \ - assert(last->next == NULL); \ - last->next = new; \ - new->previous = last; \ - } else { \ - li->first = new; \ - new->previous = NULL; \ - } \ - \ - li->last = new; \ - li->size++; \ - return true; \ - } \ - \ - TYPE name##_list_pop(name##_list_t *li) { \ - name##_list_node *last = li->last; \ - last->previous->next = NULL; \ - li->last = last->previous; \ - li->size--; \ - \ - TYPE element = last->element; \ - free(last); \ - return element; \ - } \ - \ - bool name##_list_setat(name##_list_t *li, TYPE val, const size_t index) { \ - if (index >= li->size) { \ - errno = EINVAL; \ - return false; \ - } \ - \ - name##_list_node *node = li->first; \ - for (size_t i = 0; i < index; i++) { \ - node = node->next; \ - } \ - \ - if (delFunc) { \ - delFunc(&(node->element)); \ - } \ - \ - node->element = val; \ - return true; \ - } \ - \ - TYPE name##_list_getat(const name##_list_t *li, const size_t index) { \ - if (index >= li->size) { \ - errno = EINVAL; \ - return (TYPE){0}; \ - } \ - \ - name##_list_node *node = li->first; \ - for (size_t i = 0; i < index; i++) { \ - node = node->next; \ - } \ - return node->element; \ - } \ - \ - TYPE *name##_list_get_ptr_at(const name##_list_t *li, \ - const size_t index) { \ - if (index >= li->size) { \ - errno = EINVAL; \ - return NULL; \ - } \ - \ - name##_list_node *node = li->first; \ - for (size_t i = 0; i < index; i++) { \ - node = node->next; \ - } \ - return &(node->element); \ - } \ - \ - bool name##_list_clear(name##_list_t *li) { \ - name##_list_node *node = li->first; \ - for (size_t i = 0; i < li->size; i++) { \ - name##_list_node *next = node->next; \ - if (delFunc) { \ - delFunc(&(node->element)); \ - } \ - free(node); \ - node = next; \ - } \ - free(li); \ - return true; \ - } - -#define DS_LIST_FOREACH(li, name) \ - typeof(li->first) name##_node = li->first; \ - typeof(li->first->element) name = name##_node->element; \ - for (size_t(name##_i) = 0, _i = 0; ++(name##_i) <= li->size; \ - name##_node = name##_node->next, \ - name = name##_node ? name##_node->element \ - : (typeof(li->first->element)){0}, \ - _i++) - -#define DS_LIST_PTR_FOREACH(li, name) \ - typeof(li->first) name##_node = li->first; \ - typeof(li->first->element) *name = &(name##_node->element); \ - for (size_t(name##_i) = 0, _i = 0; ++(name##_i) <= li->size; \ - name##_node = name##_node->next, \ - name = name##_node ? &(name##_node->element) : NULL, _i++) - -#define IMPLEMENT_HASH_FUNCTION \ - size_t hash_function(const char *str) { \ - /* TODO: not calculate length here */ \ - size_t hash = 0xcbf29ce484222325; \ - for (size_t i = 0; i < strlen(str); i++) { \ - hash *= 0x100000001b3; \ - hash ^= str[i]; \ - } \ - return hash; \ - } - -#define DS_TABLE_DEF(name, TYPE, delFunc) \ - typedef struct __##name##_hash_table_list_node \ - _##name##_hash_table_list_node; \ - \ - struct __##name##_hash_table_list_node { \ - struct { \ - const char *key; \ - TYPE value; \ - } item_pair; \ - struct __##name##_hash_table_list_node *next; \ - }; \ - \ - struct _##name##_table_t { \ - size_t count; \ - size_t array_length; \ - size_t current_iter_array_index; /* Used for iteration */ \ - _##name##_hash_table_list_node \ - *current_iter_list_node; /* Used for iteration */ \ - _##name##_hash_table_list_node *table_list; \ - }; \ - \ - name##_table_t *name##_table_new(size_t size) { \ - name##_table_t *tb = calloc(1, sizeof(name##_table_t)); \ - if (!tb) { \ - errno = ENOMEM; \ - return NULL; \ - } \ - \ - tb->table_list = calloc(size, sizeof(_##name##_hash_table_list_node)); \ - if (!tb->table_list) { \ - free(tb); \ - errno = ENOMEM; \ - return NULL; \ - } \ - \ - tb->array_length = size; \ - return tb; \ - } \ - \ - void name##_table_iter(name##_table_t *tb) { \ - tb->current_iter_array_index = -1; \ - tb->current_iter_list_node = NULL; \ - } \ - \ - TYPE *name##_table_iter_next(name##_table_t *tb, char **key) { \ - if ((tb->current_iter_list_node) && \ - (tb->current_iter_list_node->next)) { \ - tb->current_iter_list_node = tb->current_iter_list_node->next; \ - *key = (char *)tb->current_iter_list_node->item_pair.key; \ - return &(tb->current_iter_list_node->item_pair.value); \ - } \ - \ - while (++tb->current_iter_array_index < tb->array_length) { \ - if (tb->table_list[tb->current_iter_array_index].item_pair.key) { \ - tb->current_iter_list_node = \ - tb->table_list + tb->current_iter_array_index; \ - *key = (char *)tb->table_list[tb->current_iter_array_index] \ - .item_pair.key; \ - return &(tb->table_list[tb->current_iter_array_index] \ - .item_pair.value); \ - } \ - } \ - return NULL; \ - } \ - \ - bool name##_table_insert(name##_table_t *tb, const char *key, \ - TYPE value) { \ - size_t hash = hash_function(key) % tb->array_length; \ - \ - _##name##_hash_table_list_node *item_list_node = \ - tb->table_list + hash; \ - \ - if (!(item_list_node->item_pair.key)) { \ - item_list_node->item_pair.key = key; \ - item_list_node->item_pair.value = value; \ - tb->count++; \ - return true; \ - } \ - \ - TYPE *val = name##_table_get_ptr(tb, key); \ - if (val) { \ - errno = EINVAL; \ - return false; \ - } \ - \ - errno = 0; \ - \ - while (item_list_node->next) { \ - item_list_node = item_list_node->next; \ - } \ - \ - item_list_node->next = \ - calloc(1, sizeof(_##name##_hash_table_list_node)); \ - if (!item_list_node->next) { \ - errno = ENOMEM; \ - return false; \ - } \ - \ - item_list_node->next->item_pair.key = key; \ - item_list_node->next->item_pair.value = value; \ - tb->count++; \ - return true; \ - } \ - \ - TYPE name##_table_get(name##_table_t *tb, const char *key) { \ - size_t hash = hash_function(key) % tb->array_length; \ - \ - _##name##_hash_table_list_node *item_list_node = \ - tb->table_list + hash; \ - \ - if (!(item_list_node->item_pair.key)) { \ - errno = EINVAL; \ - return (TYPE){0}; \ - } \ - \ - do { \ - if (!strcmp(key, item_list_node->item_pair.key)) { \ - return item_list_node->item_pair.value; \ - } \ - item_list_node = item_list_node->next; \ - } while (item_list_node); \ - \ - errno = EINVAL; \ - return (TYPE){0}; \ - } \ - \ - TYPE *name##_table_get_ptr(name##_table_t *tb, const char *key) { \ - size_t hash = hash_function(key) % tb->array_length; \ - \ - _##name##_hash_table_list_node *item_list_node = \ - tb->table_list + hash; \ - \ - if (!(item_list_node->item_pair.key)) { \ - errno = EINVAL; \ - return NULL; \ - } \ - \ - do { \ - if (!strcmp(key, item_list_node->item_pair.key)) { \ - return &(item_list_node->item_pair.value); \ - } \ - item_list_node = item_list_node->next; \ - } while (item_list_node); \ - \ - errno = EINVAL; \ - return NULL; \ - } \ - \ - size_t name##_table_size(name##_table_t *tb) { return tb->count; } \ - \ - bool name##_table_delete(name##_table_t *tb, const char *key) { \ - size_t hash = hash_function(key) % tb->array_length; \ - \ - _##name##_hash_table_list_node *item_list_node = \ - tb->table_list + hash; \ - \ - if (!(item_list_node->item_pair.key)) { \ - errno = EINVAL; \ - return false; \ - } \ - \ - bool is_first = true; \ - \ - do { \ - if (!strcmp(key, item_list_node->item_pair.key)) { \ - delFunc(item_list_node->item_pair.value); \ - \ - void *to_free = item_list_node->next; \ - if (item_list_node->next) { \ - *item_list_node = *item_list_node->next; \ - if (!is_first) { \ - free(to_free); \ - } \ - } else { \ - *item_list_node = (_##name##_hash_table_list_node){0}; \ - } \ - tb->count--; \ - return true; \ - } \ - \ - is_first = false; \ - item_list_node = item_list_node->next; \ - } while (item_list_node); \ - \ - errno = EINVAL; \ - return false; \ - } \ - \ - bool name##_table_clear(name##_table_t *tb) { \ - /* TODO: Implement */ \ - return true; \ - } diff --git a/src/JIT.cc b/src/JIT.cc new file mode 100644 index 0000000..5a8ca21 --- /dev/null +++ b/src/JIT.cc @@ -0,0 +1,103 @@ +#include "JIT.hh" +#include "AST.hh" +#include "Compile.hh" +#include "PCH.hh" + +llvm::ExitOnError ExitOnErr; +std::unique_ptr TheJIT; + +namespace jit { +int jit() { + // Open a new context and module. + TheContext = std::make_unique(); + TheModule = std::make_unique("LLVM JIT", *TheContext); + TheModule->setDataLayout(TheJIT->getDataLayout()); + + // Create a new builder for the module. + Builder = std::make_unique>(*TheContext); + + // Create a new pass manager attached to it. + TheFPM = + std::make_unique(TheModule.get()); + + // Do simple "peephole" optimizations and bit-twiddling optzns. + TheFPM->add(llvm::createInstructionCombiningPass()); + // Re-associate expressions. + TheFPM->add(llvm::createReassociatePass()); + // Eliminate Common SubExpressions. + TheFPM->add(llvm::createGVNPass()); + // Simplify the control flow graph (deleting unreachable blocks, etc). + TheFPM->add(llvm::createCFGSimplificationPass()); + // tail call optimization + TheFPM->add(llvm::createTailCallEliminationPass()); + + TheFPM->doInitialization(); + return 0; +} + +void JIT_Expression(std::unique_ptr exp) { + // Creating prototype + // setup arguments type + std::vector ArgV; + + // setup return type + llvm::FunctionType *FT; + switch (exp->result_type) { + case TYPE::BOOL_T: + FT = llvm::FunctionType::get(llvm::Type::getInt1Ty(*TheContext), ArgV, + false); + break; + case TYPE::INT_T: + default: + FT = llvm::FunctionType::get(llvm::Type::getInt32Ty(*TheContext), ArgV, + false); + break; + } + // creating function + llvm::Function *F = llvm::Function::Create( + FT, llvm::Function::ExternalLinkage, + "___anonymous_expression_evaluator_function", TheModule.get()); + + llvm::BasicBlock *BB = llvm::BasicBlock::Create(*TheContext, "_entry", F); + Builder->SetInsertPoint(BB); + + llvm::Value *RetVal = exp->generate_llvm_ir(); + // Finish off the function. + Builder->CreateRet(RetVal); + // F->print(llvm::errs(), nullptr); + // Validate the generated code, checking for consistency. + assert(!verifyFunction(*F)); + + // Create a ResourceTracker to track JIT'd memory allocated to our + // anonymous expression -- that way we can free it after executing. + llvm::orc::ResourceTrackerSP RT = + TheJIT->getMainJITDylib().createResourceTracker(); + + llvm::orc::ThreadSafeModule TSM = llvm::orc::ThreadSafeModule( + std::move(TheModule), std::move(TheContext)); + ExitOnErr(TheJIT->addModule(std::move(TSM), RT)); + jit(); + + // Search the JIT for the __anon_expr symbol. + llvm::orc::ExecutorSymbolDef ExprSymbol = + ExitOnErr(TheJIT->lookup("___anonymous_expression_evaluator_function")); + + // Get the symbol's address and cast it to the right type (takes no + // arguments, returns a double) so we can call it as a native function. + switch (exp->result_type) { + case INT_T: { + int (*FP)() = ExprSymbol.getAddress().toPtr(); + std::cout << FP() << "\n"; + break; + } + case BOOL_T: { + bool (*FP)() = ExprSymbol.getAddress().toPtr(); + std::cout << (FP() ? "true" : "false") << "\n"; + break; + } + } + + // Delete the anonymous expression module from the JIT. + ExitOnErr(RT->remove()); +} +} // namespace jit diff --git a/src/JIT.hh b/src/JIT.hh new file mode 100644 index 0000000..971f15f --- /dev/null +++ b/src/JIT.hh @@ -0,0 +1,84 @@ +#pragma once + +#include "AST.hh" +#include "PCH.hh" + +class KariLangJIT { + private: + std::unique_ptr ES; + + llvm::DataLayout DL; + llvm::orc::MangleAndInterner Mangle; + + llvm::orc::RTDyldObjectLinkingLayer ObjectLayer; + llvm::orc::IRCompileLayer CompileLayer; + + llvm::orc::JITDylib &MainJD; + + public: + KariLangJIT(std::unique_ptr ES, + llvm::orc::JITTargetMachineBuilder JTMB, llvm::DataLayout DL) + : ES(std::move(ES)), DL(std::move(DL)), Mangle(*this->ES, this->DL), + ObjectLayer( + *this->ES, + []() { return std::make_unique(); }), + CompileLayer(*this->ES, ObjectLayer, + std::make_unique( + std::move(JTMB))), + MainJD(this->ES->createBareJITDylib("
")) { + MainJD.addGenerator(cantFail( + llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess( + DL.getGlobalPrefix()))); + if (JTMB.getTargetTriple().isOSBinFormatCOFF()) { + ObjectLayer.setOverrideObjectFlagsWithResponsibilityFlags(true); + ObjectLayer.setAutoClaimResponsibilityForObjectSymbols(true); + } + } + + ~KariLangJIT() { + if (auto Err = ES->endSession()) + ES->reportError(std::move(Err)); + } + + static llvm::Expected> Create() { + auto EPC = llvm::orc::SelfExecutorProcessControl::Create(); + if (!EPC) + return EPC.takeError(); + + auto ES = + std::make_unique(std::move(*EPC)); + + llvm::orc::JITTargetMachineBuilder JTMB( + ES->getExecutorProcessControl().getTargetTriple()); + + auto DL = JTMB.getDefaultDataLayoutForTarget(); + if (!DL) + return DL.takeError(); + + return std::make_unique(std::move(ES), std::move(JTMB), + std::move(*DL)); + } + + const llvm::DataLayout &getDataLayout() const { return DL; } + + llvm::orc::JITDylib &getMainJITDylib() { return MainJD; } + + llvm::Error addModule(llvm::orc::ThreadSafeModule TSM, + llvm::orc::ResourceTrackerSP RT = nullptr) { + if (!RT) + RT = MainJD.getDefaultResourceTracker(); + return CompileLayer.add(RT, std::move(TSM)); + } + + llvm::Expected lookup(llvm::StringRef Name) { + return ES->lookup({&MainJD}, Mangle(Name.str())); + } +}; + +extern std::unique_ptr TheJIT; +extern llvm::ExitOnError ExitOnErr; + +namespace jit { +void JIT_Expression(std::unique_ptr exp); +int jit(); +} // namespace jit diff --git a/src/Lexer.l b/src/Lexer.l new file mode 100644 index 0000000..cee22be --- /dev/null +++ b/src/Lexer.l @@ -0,0 +1,51 @@ +%option noyywrap + +%{ + #include "Parser.tab.hh" + yy::parser::location_type loc; + #define YY_USER_ACTION \ + do { \ + loc.columns(yyleng); \ + } while (false); +%} + +%% +"exit" { exit(0); } +"valdef" { return yy::parser::make_KW_VALDEF(loc); } +"funcdef" { return yy::parser::make_KW_FUNCDEF(loc); } +"bool" { return yy::parser::make_KW_BOOL(loc); } +"int" { return yy::parser::make_KW_INT(loc); } +"true" { return yy::parser::make_KW_TRUE(loc); } +"false" { return yy::parser::make_KW_FALSE(loc); } +"if" { return yy::parser::make_KW_IF(loc); } +"then" { return yy::parser::make_KW_THEN(loc); } +"else" { return yy::parser::make_KW_ELSE(loc); } +";" { return yy::parser::make_STATEMENT_END(loc); } +":" { return yy::parser::make_TYPE_OF(loc); } +"," { return yy::parser::make_COMMA(loc); } +"+" { return yy::parser::make_PLUS(loc); } +"-" { return yy::parser::make_MINUS(loc); } +"*" { return yy::parser::make_MULTIPLY(loc); } +"/" { return yy::parser::make_DIVIDE(loc); } +"%" { return yy::parser::make_MODULO(loc); } +"&&" { return yy::parser::make_AND(loc); } +"||" { return yy::parser::make_OR(loc); } +"!" { return yy::parser::make_NOT(loc); } +"==" { return yy::parser::make_EQUALS(loc); } +"!=" { return yy::parser::make_NOT_EQUALS(loc); } +">" { return yy::parser::make_GREATER(loc); } +"<" { return yy::parser::make_LESSER(loc); } +">=" { return yy::parser::make_GREATER_EQUALS(loc); } +"<=" { return yy::parser::make_LESSER_EQUALS(loc); } +"->" { return yy::parser::make_RETURN(loc); } +"(" { return yy::parser::make_OPEN_BRACKETS(loc); } +")" { return yy::parser::make_CLOSE_BRACKETS(loc); } +"=" { return yy::parser::make_ASSIGN(loc); } +[0-9]* { return yy::parser::make_INTEGER(std::stoi(yytext), loc); } +[a-zA-Z_][0-9a-zA-Z_]* { return yy::parser::make_IDENTIFIER(yytext, loc); } +[ \t]+ { loc.step(); } +[\n] { loc.lines(yyleng); loc.step(); } +\/\/.+ { loc.step(); } +. { std::cerr << "unexpected character: " << yytext << '\n'; num_errors += 1; } +<> { return yy::parser::make_EOF(loc); } +%% diff --git a/src/Main.cc b/src/Main.cc new file mode 100644 index 0000000..9e00446 --- /dev/null +++ b/src/Main.cc @@ -0,0 +1,100 @@ +#include "AST.hh" +#include "Compile.hh" +#include "JIT.hh" + +std::unordered_map> functions_ast; +std::unordered_map> globals_ast; + +int parse( + std::string filename, LANGUAGE_ACTION_TYPE comp_flags, + std::unordered_map> + &functions_ast, + std::unordered_map> &globals_ast); + +int interpret( + std::unordered_map> + &functions_ast, + std::unordered_map> &globals_ast, + int input); + +int main(int argc, char *argv[]) { + int status; + + if ((argc == 3) && (std::string(argv[2]) == "-c")) { + status = parse(argv[1], LANGUAGE_ACTION_TYPE::COMPILE_FILE, + functions_ast, globals_ast); + } else if ((argc == 2) && ((std::string(argv[1]) == "-jit") || + (std::string(argv[1]) == "--jit"))) { + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + llvm::InitializeNativeTargetAsmParser(); + TheJIT = ExitOnErr(KariLangJIT::Create()); + jit::jit(); + std::cout << ">>> "; + return parse("", LANGUAGE_ACTION_TYPE::INTERACTIVE_COMPILE, + functions_ast, globals_ast); + } else if (argc == 1) { + std::cout << ">>> "; + return parse("", LANGUAGE_ACTION_TYPE::INTERACTIVE_INTERPRET, + functions_ast, globals_ast); + } else { + std::cerr << "File and input required to execute the program\n"; + std::cerr << "Or pass \"-c\" flag to compile\n"; + return 1; + } + + if (status) { + std::cerr << "Error while parsing. Exiting\n"; + return 1; + } + + // verify semantics + for (auto &i : functions_ast) { + if (!i.second->verify_semantics(functions_ast, globals_ast)) { + std::cerr << "Invalid semantics for function " << i.first << "\n\t" + << i.second << "\n"; + return 1; + } + } + for (auto &i : globals_ast) { + if (!i.second->verify_semantics(functions_ast, globals_ast)) { + std::cerr << "Invalid semantics for variable " << i.first << "\n\t" + << i.second << "\n"; + return 1; + } + } + + if (std::string(argv[2]) == "-c") + return Compile(argv[1], functions_ast, globals_ast); + return interpret(functions_ast, globals_ast, std::stoi(argv[2])); +} + +int interpret( + std::unordered_map> + &functions_ast, + std::unordered_map> &globals_ast, + int input) { + // interpret + if (functions_ast.find("main") != functions_ast.end()) { + std::unique_ptr &func_main = functions_ast.at("main"); + + if ((func_main->args_name.size() != 1) && + (func_main->args_type.at(0) != INT_T) && + (func_main->return_type != INT_T)) { + std::cerr << "main function can take only 1 int type argument and " + "should return int\n"; + return 1; + } + + std::unordered_map> context; + context.insert({func_main->args_name.at(0), input}); + int res = std::get( + func_main->interpret(functions_ast, globals_ast, context)); + std::cout << "Input: " << input << "\nOutput: " << res << "\n"; + } else { + std::cerr << "Error: Could not find main function\n"; + return 1; + } + + return 0; +} diff --git a/src/PCH.hh b/src/PCH.hh new file mode 100644 index 0000000..a714703 --- /dev/null +++ b/src/PCH.hh @@ -0,0 +1,47 @@ +#pragma once + +#include "llvm/ADT/StringRef.h" +#include "llvm/ExecutionEngine/Orc/CompileUtils.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h" +#include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h" +#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" +#include "llvm/ExecutionEngine/Orc/Shared/ExecutorSymbolDef.h" +#include "llvm/ExecutionEngine/SectionMemoryManager.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Verifier.h" +#include "llvm/MC/TargetRegistry.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Target/TargetOptions.h" +#include "llvm/TargetParser/Host.h" +#include "llvm/Transforms/InstCombine/InstCombine.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/GVN.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include diff --git a/src/Parser.hh b/src/Parser.hh new file mode 100644 index 0000000..e3eea03 --- /dev/null +++ b/src/Parser.hh @@ -0,0 +1,117 @@ +#include "AST.hh" +#include "Compile.hh" +#include "JIT.hh" +#include "PCH.hh" +#include "Utils.hh" + +#define PROMPT "\n>>> " + +inline void handle_expressions(LANGUAGE_ACTION_TYPE flags, + std::unique_ptr exp) { + + if (flags & KARILANG_INTERACTIVE) { + std::unordered_map semantics_context; + TYPE result_type = exp->deduce_result_type(); + if (!exp->verify_semantics(result_type, functions_ast, globals_ast, + semantics_context)) { + std::cerr << "Invalid semantics for the given expression\n" + << PROMPT; + return; + } + } else { + std::cerr << "Cannot have expressions at top level\n"; + return; + } + + if (flags == INTERACTIVE_INTERPRET) { + std::unordered_map> + interpret_context; + + std::variant res = + exp->interpret(functions_ast, globals_ast, interpret_context); + switch (exp->result_type) { + case INT_T: + std::cout << std::get(res); + break; + case BOOL_T: + std::cout << (std::get(res) ? "true" : "false"); + break; + } + + } else { + jit::JIT_Expression(std::move(exp)); + } + + if (flags & KARILANG_INTERACTIVE) + std::cout << PROMPT; +} + +inline void handle_valdef(LANGUAGE_ACTION_TYPE flags, + std::unique_ptr _val) { + const std::string &name = _val->name; + + if (globals_ast.find(name) != globals_ast.end()) { + std::cerr << "Cannot rewrite variable. Consider changing " + "variable name\n"; + if (flags & KARILANG_INTERACTIVE) + std::cout << PROMPT; + } + + globals_ast.insert({_val->name, std::move(_val)}); + std::unique_ptr &val = globals_ast.at(name); + + // verify semantics if interactive + if (flags & KARILANG_INTERACTIVE) { + if (!val->verify_semantics(functions_ast, globals_ast)) { + std::cerr << "Invalid semantics for the given variable\n" << PROMPT; + globals_ast.erase(name); + return; + } + } + + if (flags == INTERACTIVE_COMPILE) { + val->generate_llvm_ir(); + ExitOnErr(TheJIT->addModule(llvm::orc::ThreadSafeModule( + std::move(TheModule), std::move(TheContext)))); + jit::jit(); + } + + if (flags & KARILANG_INTERACTIVE) + std::cout << val << "\n" << PROMPT; +} + +inline void handle_funcdef(LANGUAGE_ACTION_TYPE flags, + std::unique_ptr _func) { + const std::string &name = _func->name; + + if (functions_ast.find(name) != functions_ast.end()) { + std::cerr << "Cannot rewrite variable. Consider changing " + "function name\n"; + if (flags & KARILANG_INTERACTIVE) + std::cout << PROMPT; + } + + functions_ast.insert({_func->name, std::move(_func)}); + std::unique_ptr &func = functions_ast.at(name); + + // verify semantics if interactive + if (flags & KARILANG_INTERACTIVE) { + if (!func->verify_semantics(functions_ast, globals_ast)) { + std::cerr << "Invalid semantics for the given function\n" << PROMPT; + functions_ast.erase(name); + return; + } + } + + if (flags == INTERACTIVE_COMPILE) { + FunctionPrototype::generate_llvm_ir(func->name, func->args_name, + func->args_type, func->return_type); + func->generate_llvm_ir(); + ExitOnErr(TheJIT->addModule(llvm::orc::ThreadSafeModule( + std::move(TheModule), std::move(TheContext)))); + jit::jit(); + } + + if (flags & KARILANG_INTERACTIVE) + std::cout << func << "\n" << PROMPT; +} diff --git a/src/Parser.yy b/src/Parser.yy new file mode 100644 index 0000000..0c422da --- /dev/null +++ b/src/Parser.yy @@ -0,0 +1,167 @@ +%language "C++" +%require "3.2" + + +%define api.value.type variant +%define api.token.constructor +%define api.token.prefix {TOK_} +%define parse.error verbose +%define parse.trace + +%locations + +%param { int& num_errors } +%param { LANGUAGE_ACTION_TYPE comp_flags } +%param { std::unordered_map> &functions_ast } +%param { std::unordered_map> &globals_ast } + +%code provides { + #define YY_DECL \ + yy::parser::symbol_type yylex( \ + int& num_errors, \ + LANGUAGE_ACTION_TYPE comp_flags, \ + std::unordered_map> &functions_ast, \ + std::unordered_map> &globals_ast \ + ) + + YY_DECL; +} + +%code requires { + #include "Parser.hh" +} + +%code { + #include + extern FILE *yyin; +} + +%printer { yyo << $$; } ; +%printer { yyo << $$; } ; +%printer { yyo << $$; } >; + +%token KW_VALDEF +%token KW_FUNCDEF +%token KW_BOOL +%token KW_INT +%token KW_TRUE +%token KW_FALSE +%token KW_IF +%token KW_THEN +%token KW_ELSE +%token COMMA +%token STATEMENT_END +%token TYPE_OF +%token ASSIGN +%token PLUS +%token MINUS +%token MULTIPLY +%token DIVIDE +%token MODULO +%token OPEN_BRACKETS +%token CLOSE_BRACKETS +%token AND +%token OR +%token NOT +%token EQUALS +%token NOT_EQUALS +%token GREATER +%token GREATER_EQUALS +%token LESSER +%token LESSER_EQUALS +%token RETURN +%token INTEGER +%token IDENTIFIER + +%token EOF 0 "end-of-file" + +%type > basic_expression; +%type > expression; +%type function_call_arguments; +%type > value_definition; +%type > function_definition; +%type > function_definition_arguments; + +%precedence KW_ELSE +%left AND OR +%left EQUALS NOT_EQUALS GREATER GREATER_EQUALS LESSER LESSER_EQUALS +%left PLUS +%left MULTIPLY DIVIDE +%left MODULO +%precedence MINUS +%precedence NOT + +%% +input: %empty + | input expression STATEMENT_END { handle_expressions(comp_flags, std::move($2)); } + | input value_definition { handle_valdef(comp_flags, std::move($2)); } + | input function_definition { handle_funcdef(comp_flags, std::move($2)); } + | input error STATEMENT_END { + if (comp_flags & KARILANG_INTERACTIVE) + std::cout << PROMPT; + else + exit(1); + }; + +function_definition: KW_FUNCDEF IDENTIFIER function_definition_arguments RETURN KW_BOOL ASSIGN expression STATEMENT_END { ($3)->set_info($2, BOOL_T, std::move($7)); $$ = std::move($3); } + | KW_FUNCDEF IDENTIFIER function_definition_arguments RETURN KW_INT ASSIGN expression STATEMENT_END { ($3)->set_info($2, INT_T, std::move($7)); $$ = std::move($3); }; + +function_definition_arguments: IDENTIFIER TYPE_OF KW_BOOL { $$ = FunctionDef::from(); ($$)->add_argument(BOOL_T, $1); } + | IDENTIFIER TYPE_OF KW_INT { $$ = FunctionDef::from(); ($$)->add_argument(INT_T, $1); } + | IDENTIFIER TYPE_OF KW_BOOL function_definition_arguments { ($4)->add_argument(BOOL_T, $1); $$ = std::move($4); } + | IDENTIFIER TYPE_OF KW_INT function_definition_arguments { ($4)->add_argument(INT_T, $1); $$ = std::move($4); } + | OPEN_BRACKETS function_definition_arguments CLOSE_BRACKETS { $$ = std::move($2); }; + +value_definition: KW_VALDEF IDENTIFIER TYPE_OF KW_BOOL ASSIGN expression STATEMENT_END { $$ = ValueDef::from(BOOL_T, $2, std::move($6)); } + | KW_VALDEF IDENTIFIER TYPE_OF KW_INT ASSIGN expression STATEMENT_END { $$ = ValueDef::from(INT_T, $2, std::move($6)); }; + +basic_expression: IDENTIFIER { $$ = Expression::from($1); } + | INTEGER { $$ = Expression::from($1); } + | KW_TRUE { $$ = Expression::from(true); } + | KW_FALSE { $$ = Expression::from(false); } + | OPEN_BRACKETS expression CLOSE_BRACKETS { $$ = std::move($2); }; + +expression: basic_expression { $$ = std::move($1); } + | AND basic_expression basic_expression { $$ = Expression::from(std::move($2), std::move($3), AND_OP); } + | OR basic_expression basic_expression { $$ = Expression::from(std::move($2), std::move($3), OR_OP); } + | NOT basic_expression { $$ = Expression::from(std::move($2), NOT_OP); } + | PLUS basic_expression basic_expression { $$ = Expression::from(std::move($2), std::move($3), ADD_OP); } + | MULTIPLY basic_expression basic_expression { $$ = Expression::from(std::move($2), std::move($3), MUL_OP); } + | DIVIDE basic_expression basic_expression { $$ = Expression::from(std::move($2), std::move($3), DIV_OP); } + | MODULO basic_expression basic_expression { $$ = Expression::from(std::move($2), std::move($3), MOD_OP); } + | MINUS basic_expression { $$ = Expression::from(std::move($2), NEG_OP); } + | EQUALS basic_expression basic_expression { $$ = Expression::from(std::move($2), std::move($3), EQS_OP); } + | NOT_EQUALS basic_expression basic_expression { $$ = Expression::from(std::move($2), std::move($3), NEQ_OP); } + | GREATER basic_expression basic_expression { $$ = Expression::from(std::move($2), std::move($3), GT_OP); } + | GREATER_EQUALS basic_expression basic_expression { $$ = Expression::from(std::move($2), std::move($3), GTE_OP); } + | LESSER basic_expression basic_expression { $$ = Expression::from(std::move($2), std::move($3), LT_OP); } + | LESSER_EQUALS basic_expression basic_expression { $$ = Expression::from(std::move($2), std::move($3), LTE_OP); } + | KW_IF expression KW_THEN expression KW_ELSE expression { $$ = Expression::from(std::move($2), std::move($4), std::move($6)); } + | IDENTIFIER function_call_arguments { ($2).set_function_name($1); $$ = Expression::from(std::move($2)); }; + +function_call_arguments: basic_expression { $$ = FunctionCall(); ($$).add_argument(std::move($1)); } + | basic_expression function_call_arguments { ($2).add_argument(std::move($1)); $$ = std::move($2); }; +%% + +void yy::parser::error(const location_type& loc, const std::string& s) { + std::cerr << loc << ": " << s << '\n'; +} + +int parse( + std::string filename, LANGUAGE_ACTION_TYPE comp_flags, + std::unordered_map> &functions_ast, + std::unordered_map> &globals_ast) { + if (filename != "") { + FILE *file = fopen(filename.c_str(), "r"); + if (file == NULL) { + std::cerr << "Could not open file " << filename << "\n"; + return 1; + } + yyin = file; + } + + auto num_errors = 0; + yy::parser parser(num_errors, comp_flags, functions_ast, globals_ast); + auto status = parser.parse(); + return status; +} diff --git a/src/Utils.hh b/src/Utils.hh new file mode 100644 index 0000000..fa3dd32 --- /dev/null +++ b/src/Utils.hh @@ -0,0 +1,25 @@ +#pragma once + +#include "PCH.hh" + +template +std::ostream &operator<<(std::ostream &os, const std::unique_ptr &x) { + return os << *x; +} + +class FunctionDef; +class ValueDef; + +extern std::unordered_map> + functions_ast; +extern std::unordered_map> globals_ast; + +#define KARILANG_INTERACTIVE 0b1 +#define KARILANG_COMPILED 0b10 + +enum LANGUAGE_ACTION_TYPE { + INTERPRET_FILE = 0, + INTERACTIVE_INTERPRET = KARILANG_INTERACTIVE, + COMPILE_FILE = KARILANG_COMPILED, + INTERACTIVE_COMPILE = KARILANG_INTERACTIVE ^ KARILANG_COMPILED, +}; diff --git a/src/cli_interpreter.h b/src/cli_interpreter.h deleted file mode 100644 index 0bd93c3..0000000 --- a/src/cli_interpreter.h +++ /dev/null @@ -1,95 +0,0 @@ -#include "common.h" -#include - -typedef struct { - size_t arglen; - Argument *args; -} Context; - -ExpressionResult evaluate_expression(Expression *exp, Context *cxt); -bool verify_expression_type(Expression *exp, Type type, Context *cxt); -bool verify_ast_semantics(AST *tree); -static inline int my_print(FILE *file, const char *msg, ...); - -static inline bool cli_interpret(AST tree) { - if (tree.type == AST_VARIABLE) { - if (ast_table_get_ptr(ast, tree.value.var->name)) { - ast_table_delete(ast, tree.value.var->name); - global_table_delete(globals, tree.value.var->name); - errno = 0; - } - if (!ast_table_insert(ast, tree.value.var->name, tree)) { - my_print(stderr, "AST insertion Error\n"); - errno = 0; - return false; - } - } else if (tree.type == AST_FUNCTION) { - if (ast_table_get_ptr(ast, tree.value.func->funcname)) { - ast_table_delete(ast, tree.value.func->funcname); - // ???: is the below deletion necessary here - global_table_delete(globals, tree.value.func->funcname); - errno = 0; - } - if (!ast_table_insert(ast, tree.value.func->funcname, tree)) { - my_print(stderr, "AST insertion Error\n"); - errno = 0; - return false; - } - } else { - if (verify_expression_type(tree.value.exp, BOOL, NULL)) { - my_print(stdout, evaluate_expression(tree.value.exp, NULL).boolean - ? "true\n" - : "false\n"); - return true; - } - if (verify_expression_type(tree.value.exp, INT, NULL)) { - my_print(stdout, "%d\n", - evaluate_expression(tree.value.exp, NULL).integer); - return true; - } - - my_print(stderr, - "Error While Evaluation Expression\nSemantic Error: %s\n", - semantic_error_msg); - semantic_error_msg[0] = 0; - return false; - } - - if (!verify_ast_semantics(&tree)) { - my_print(stderr, "Semantic Error: %s\n", semantic_error_msg); - semantic_error_msg[0] = 0; - return false; - } - - if (tree.type == AST_VARIABLE) { - assert(global_table_insert( - globals, tree.value.var->name, - (global){.value = - evaluate_expression(tree.value.var->expression, NULL), - .type = tree.value.var->type})); - } - return true; -} - -static inline int my_print(FILE *file, const char *msg, ...) { - va_list args; - va_start(args, msg); - int len = 0; - - if ((file == stdout) && (STDOUT_REDIRECT_STRING)) { - if (STDOUT_REDIRECT_STRING) - len = vsnprintf(STDOUT_REDIRECT_STRING, STDOUT_STRING_LENGTH, msg, - args); - else - len = vprintf(msg, args); - } else { - if (STDERR_REDIRECT_STRING) - len = vsnprintf(STDERR_REDIRECT_STRING, STDERR_STRING_LENGTH, msg, - args); - else - len = vfprintf(stderr, msg, args); - } - - va_end(args); - return len; -} diff --git a/src/common.h b/src/common.h deleted file mode 100644 index e6fde48..0000000 --- a/src/common.h +++ /dev/null @@ -1,475 +0,0 @@ -#pragma once - -#define YYERROR_VERBOSE 1 - -#include "DS.h" -#include -#include -#include - -extern FILE *yyin; -extern int yylex(void); -extern int yyparse(void); -extern void yyerror(char const *s); -extern int yylineno; -extern int column; -extern char *yytext; -extern const char *filename; - -extern bool cli_interpretation_mode; - -extern char syntax_error_msg[]; - -typedef enum { - UNDEFINED, - INTEGER_EXPRESSION, - VARIABLE_EXPRESSION, - BOOLEAN_EXPRESSION, - PLUS_EXPRESSION, - MINUS_EXPRESSION, - MULTIPLY_EXPRESSION, - DIVIDE_EXPRESSION, - MODULO_EXPRESSION, - AND_EXPRESSION, - OR_EXPRESSION, - NOT_EXPRESSION, - EQUALS_EXPRESSION, - NOT_EQUALS_EXPRESSION, - GREATER_EXPRESSION, - GREATER_EQUALS_EXPRESSION, - LESSER_EXPRESSION, - LESSER_EQUALS_EXPRESSION, - IF_EXPRESSION, - FUNCTION_CALL_EXPRESSION, -} ExpressionType; - -typedef enum { - BOOL, - INT, -} Type; - -typedef union _ExpressionValue ExpressionValue; -typedef struct _Expression Expression; -typedef struct _Variable Variable; -typedef struct _Function Function; - -struct _Variable { - Type type; - const char *name; - Expression *expression; -}; - -typedef struct { - const char *name; - Type type; -} Argument; - -struct _Function { - const char *funcname; - Type return_type; - Expression *expression; - size_t arglen; - Argument args[]; -}; - -union _ExpressionValue { - int integer; - const char *variable; - bool boolean; - struct { - Expression *fst; - } unary; - struct { - Expression *fst; - Expression *snd; - } binary; - struct { - Expression *condition; - Expression *yes; - Expression *no; - } if_statement; - struct { - const char *funcname; - size_t arglen; - Expression **args; - } function_call; -}; - -struct _Expression { - // Type result_type; - ExpressionType type; - ExpressionValue value; -}; - -/* AST for Semantic Analysis and Evaluation */ - -typedef enum { - AST_VARIABLE, - AST_FUNCTION, - AST_EXPRESSION, -} AST_TYPE; - -struct _AST { - AST_TYPE type; - bool semantically_correct; - union { - Function *func; - Variable *var; - Expression *exp; - } value; -}; - -typedef struct _AST AST; - -DS_TABLE_DEC(ast, AST); - -extern ast_table_t *ast; - -typedef union { - int integer; - bool boolean; -} ExpressionResult; - -typedef struct { - Type type; - ExpressionResult value; -} global; - -size_t hash_function(const char *str); - -static inline void free_global(global gb) {} - -DS_TABLE_DEC(global, global); - -extern global_table_t *globals; - -extern char semantic_error_msg[]; -bool verify_semantics(); - -extern char runtime_error_msg[]; -bool interpret(int input, int *output); - -#define STDOUT_STRING_LENGTH 500 -#define STDERR_STRING_LENGTH 500 - -extern char *STDOUT_REDIRECT_STRING; -extern char *STDERR_REDIRECT_STRING; - -// FIXME: Check if memory allocations fail - -static inline const char *const Type_to_string(Type type) { - // FIXME: should not be static inlined? - switch (type) { - case BOOL: - return "bool"; - case INT: - return "int"; - default: - return "UNDEFINED TYPE"; - } -} - -static inline Function *make_function() { - return calloc(1, sizeof(Function) + sizeof(Argument)); -} - -static inline Function *set_function_name(Function *func, - const char *funcname) { - func->funcname = funcname; - return func; -} - -static inline Function *set_function_return_value(Function *func, Type type, - Expression *exp) { - func->return_type = type; - func->expression = exp; - return func; -} - -static inline Function *add_function_argument(Function *func, - const char *argname, Type type) { - if (func->arglen == 0) { - func->args[0] = (Argument){.type = type, .name = argname}; - func->arglen = 1; - return func; - } - func = - realloc(func, sizeof(Function) + sizeof(Argument) * (func->arglen + 1)); - func->args[func->arglen] = (Argument){.type = type, .name = argname}; - func->arglen += 1; - return func; -} - -static inline Expression *make_function_call_expression() { - Expression *result = malloc(sizeof(Expression)); - *result = (Expression){.type = FUNCTION_CALL_EXPRESSION, - .value.function_call.args = - calloc(1, sizeof(Expression *))}; - return result; -} - -static inline Expression * -set_function_call_name_expression(Expression *func, const char *funcname) { - assert(func->type == FUNCTION_CALL_EXPRESSION); - - func->value.function_call.funcname = funcname; - return func; -} - -static inline Expression * -add_function_call_argument_expression(Expression *func, Expression *exp) { - assert(func->type == FUNCTION_CALL_EXPRESSION); - -#define FUNC func->value.function_call - if (FUNC.arglen == 0) { - FUNC.args[0] = exp; - FUNC.arglen = 1; - return func; - } - FUNC.args = realloc(FUNC.args, sizeof(Expression *) * (FUNC.arglen + 1)); - FUNC.args[FUNC.arglen] = exp; - FUNC.arglen += 1; -#undef FUNC - return func; -} - -static inline Variable *make_variable(const char *varname, Type type, - Expression *exp) { - Variable *result = malloc(sizeof(Variable)); - *result = (Variable){.type = type, .name = varname, .expression = exp}; - return result; -} - -static inline Expression *make_integer_expression(int n) { - Expression *result = malloc(sizeof(Expression)); - *result = (Expression){.type = INTEGER_EXPRESSION, .value.integer = n}; - return result; -} - -static inline Expression *make_variable_expression(const char *varname) { - Expression *result = malloc(sizeof(Expression)); - *result = - (Expression){.type = VARIABLE_EXPRESSION, .value.variable = varname}; - return result; -} - -static inline Expression *make_boolean_expression(bool b) { - Expression *result = malloc(sizeof(Expression)); - *result = (Expression){.type = BOOLEAN_EXPRESSION, .value.boolean = b}; - return result; -} - -static inline Expression * -make_binary_expression(Expression *fst, Expression *snd, ExpressionType type) { - Expression *result = malloc(sizeof(Expression)); - *result = (Expression){ - .type = type, .value.binary.fst = fst, .value.binary.snd = snd}; - return result; -} - -static inline Expression *make_unary_expression(Expression *fst, - ExpressionType type) { - Expression *result = malloc(sizeof(Expression)); - *result = (Expression){.type = type, .value.unary.fst = fst}; - return result; -} - -static inline Expression *make_if_expression(Expression *condition, - Expression *yes, Expression *no) { - Expression *result = malloc(sizeof(Expression)); - *result = (Expression){.type = IF_EXPRESSION, - .value.if_statement.condition = condition, - .value.if_statement.yes = yes, - .value.if_statement.no = no}; - return result; -} - -static inline void print_expression(Expression *exp) { - // FIXME: should not be static inlined? - ExpressionType type = exp->type; - ExpressionValue value = exp->value; - - switch (type) { - case INTEGER_EXPRESSION: - printf("%d", value.integer); - break; - case VARIABLE_EXPRESSION: - printf("VariableName: %s", value.variable); - break; - case BOOLEAN_EXPRESSION: - printf("%s", value.boolean ? "true" : "false"); - break; - case PLUS_EXPRESSION: - printf("("); - print_expression(value.binary.fst); - printf(" + "); - print_expression(value.binary.snd); - printf(")"); - break; - case MINUS_EXPRESSION: - printf("(-"); - print_expression(value.unary.fst); - printf(")"); - break; - case MULTIPLY_EXPRESSION: - printf("("); - print_expression(value.binary.fst); - printf(" * "); - print_expression(value.binary.snd); - printf(")"); - break; - case DIVIDE_EXPRESSION: - printf("("); - print_expression(value.binary.fst); - printf(" / "); - print_expression(value.binary.snd); - printf(")"); - break; - case MODULO_EXPRESSION: - printf("("); - print_expression(value.binary.fst); - printf(" %% "); - print_expression(value.binary.snd); - printf(")"); - break; - case AND_EXPRESSION: - printf("("); - print_expression(value.binary.fst); - printf(" && "); - print_expression(value.binary.snd); - printf(")"); - break; - case OR_EXPRESSION: - printf("("); - print_expression(value.binary.fst); - printf(" || "); - print_expression(value.binary.snd); - printf(")"); - break; - case NOT_EXPRESSION: - printf("(!"); - print_expression(value.unary.fst); - printf(")"); - break; - case EQUALS_EXPRESSION: - printf("("); - print_expression(value.binary.fst); - printf(" == "); - print_expression(value.binary.snd); - printf(")"); - break; - case NOT_EQUALS_EXPRESSION: - printf("("); - print_expression(value.binary.fst); - printf(" != "); - print_expression(value.binary.snd); - printf(")"); - break; - case GREATER_EXPRESSION: - printf("("); - print_expression(value.binary.fst); - printf(" > "); - print_expression(value.binary.snd); - printf(")"); - break; - case GREATER_EQUALS_EXPRESSION: - printf("("); - print_expression(value.binary.fst); - printf(" >= "); - print_expression(value.binary.snd); - printf(")"); - break; - case LESSER_EXPRESSION: - printf("("); - print_expression(value.binary.fst); - printf(" < "); - print_expression(value.binary.snd); - printf(")"); - break; - case LESSER_EQUALS_EXPRESSION: - printf("("); - print_expression(value.binary.fst); - printf(" <= "); - print_expression(value.binary.snd); - printf(")"); - break; - case IF_EXPRESSION: - printf("Conditions: "); - print_expression(value.if_statement.condition); - printf("\n\tIfTrue: "); - print_expression(value.if_statement.yes); - printf("\n\tIfFalse: "); - print_expression(value.if_statement.no); - break; - case FUNCTION_CALL_EXPRESSION: - printf("FunctionCallTo: %s", value.function_call.funcname); - for (int i = value.function_call.arglen - 1; i >= 0; i--) { - printf("\n\tArg: "); - print_expression(value.function_call.args[i]); - } - break; - default: - printf("Found Undefined Expression Type"); - } -} - -static inline void print_variable(Variable *var) { - // FIXME: should not be static inlined? - printf("VariableName: %s, VariableType: %s, Expression: ", var->name, - Type_to_string(var->type)); - print_expression(var->expression); -} - -static inline void print_arguments(Argument *args, size_t arglen) { - for (int i = arglen - 1; i >= 0; i--) { - printf("\tArgumentName: %s, ArgumentType: %s\n", args[i].name, - Type_to_string(args[i].type)); - } -} - -static inline void print_function(Function *func) { - // FIXME: should not be static inlined? - printf("FunctionName: %s, FunctionReturnType: %s\n", func->funcname, - Type_to_string(func->return_type)); - print_arguments(func->args, func->arglen); - printf("\t\tExpression: "); - print_expression(func->expression); -} - -static inline void print_ast_table(ast_table_t *ast) { - // FIXME: should not be static inlined? - char *key; - ast_table_iter(ast); - AST *tree = ast_table_iter_next(ast, &key); - for (size_t i = 0; i++ < ast_table_size(ast); - tree = ast_table_iter_next(ast, &key)) { - switch (tree->type) { - case AST_FUNCTION: - print_function(tree->value.func); - break; - case AST_VARIABLE: - print_variable(tree->value.var); - break; - case AST_EXPRESSION: - // TODO: print expression - break; - } - printf("\n"); - } -} - -static inline void clear_expression(Expression *exp) { - // TODO: implement -} - -static inline void clear_variable(Variable *var) { - // TODO: implement -} - -static inline void clear_function(Function *func) { - // TODO: implement -} - -static inline void clear_ast(AST tree) { - // TODO: implement -} diff --git a/src/interpreter.c b/src/interpreter.c deleted file mode 100644 index cb2c340..0000000 --- a/src/interpreter.c +++ /dev/null @@ -1,227 +0,0 @@ -#include "DS.h" -#include "common.h" -#include -#include -#include -#include - -#define ERROR_MSG_LEN 500 - -char runtime_error_msg[ERROR_MSG_LEN]; - -struct _context { - const char *var_name; - ExpressionResult var_value; -}; - -typedef struct { - size_t len; - struct _context *variable; -} Context; - -DS_TABLE_DEF(global, global, free_global); - -ExpressionResult evaluate_expression(Expression *exp, Context *cxt); -ExpressionResult execute_function_call(Function *func, Expression **args, - Context *cxt); - -global_table_t *globals; - -bool interpret(int input, int *output) { - // initialize global variables table - globals = global_table_new(100); - - Function *main_func = NULL; - char *key; - AST *tree; - ast_table_iter(ast); - - while (NULL != (tree = ast_table_iter_next(ast, &key))) { - switch (tree->type) { - case AST_VARIABLE: - if (global_table_get_ptr(globals, tree->value.var->name)) { - break; - } - errno = 0; - - assert(global_table_insert( - globals, tree->value.var->name, - (global){.value = evaluate_expression( - tree->value.var->expression, NULL), - .type = tree->value.var->type})); - break; - - case AST_FUNCTION: - if (!strcmp(tree->value.func->funcname, "main")) { - main_func = tree->value.func; - if (main_func->return_type != INT) { - snprintf(runtime_error_msg, ERROR_MSG_LEN, "%s", - "'main' function should return an integer"); - return false; - } - if ((main_func->arglen != 1) || - (main_func->args[0].type != INT)) { - snprintf( - runtime_error_msg, ERROR_MSG_LEN, "%s", - "'main' function should have only 1 integer argument"); - return false; - } - } - break; - case AST_EXPRESSION: - snprintf(syntax_error_msg, ERROR_MSG_LEN, "Internal Error"); - // TODO: clean memory - return false; - } - } - - if (!main_func) { - snprintf(runtime_error_msg, ERROR_MSG_LEN, "%s", - "Could not find 'main' function"); - return false; - } - - struct _context args[1] = { - {.var_name = main_func->args[0].name, - .var_value = (ExpressionResult){.integer = input}}}; - Context cxt = {.len = 1, .variable = args}; - - *output = evaluate_expression(main_func->expression, &cxt).integer; - return true; -} - -ExpressionResult evaluate_expression(Expression *exp, Context *cxt) { - switch (exp->type) { - case INTEGER_EXPRESSION: - return (ExpressionResult){.integer = exp->value.integer}; - case BOOLEAN_EXPRESSION: - return (ExpressionResult){.boolean = exp->value.boolean}; - case VARIABLE_EXPRESSION: { - // check the context - if (cxt) { - for (size_t i = 0; i < cxt->len; i++) { - if (!strcmp(cxt->variable[i].var_name, exp->value.variable)) - return cxt->variable[i].var_value; - } - } - - // check the global variables - global *v = global_table_get_ptr(globals, exp->value.variable); - if (v) - return v->value; - errno = 0; - AST *tree = ast_table_get_ptr(ast, exp->value.variable); - if (tree) { - ExpressionResult result_exp = - evaluate_expression(tree->value.var->expression, NULL); - assert(global_table_insert( - globals, tree->value.var->name, - (global){.type = tree->value.var->type, .value = result_exp})); - - return result_exp; - } - - goto error; - } - case PLUS_EXPRESSION: - return (ExpressionResult){ - .integer = evaluate_expression(exp->value.binary.fst, cxt).integer + - evaluate_expression(exp->value.binary.snd, cxt).integer}; - case MINUS_EXPRESSION: - return (ExpressionResult){ - .integer = - -(evaluate_expression(exp->value.unary.fst, cxt).integer)}; - case MULTIPLY_EXPRESSION: - return (ExpressionResult){ - .integer = evaluate_expression(exp->value.binary.fst, cxt).integer * - evaluate_expression(exp->value.binary.snd, cxt).integer}; - case DIVIDE_EXPRESSION: - return (ExpressionResult){ - .integer = evaluate_expression(exp->value.binary.fst, cxt).integer / - evaluate_expression(exp->value.binary.snd, cxt).integer}; - case MODULO_EXPRESSION: - return (ExpressionResult){ - .integer = evaluate_expression(exp->value.binary.fst, cxt).integer % - evaluate_expression(exp->value.binary.snd, cxt).integer}; - case AND_EXPRESSION: - return (ExpressionResult){ - .boolean = - evaluate_expression(exp->value.binary.fst, cxt).boolean && - evaluate_expression(exp->value.binary.snd, cxt).boolean}; - case OR_EXPRESSION: - return (ExpressionResult){ - .boolean = - evaluate_expression(exp->value.binary.fst, cxt).boolean || - evaluate_expression(exp->value.binary.snd, cxt).boolean}; - case NOT_EXPRESSION: - return (ExpressionResult){ - .boolean = - !(evaluate_expression(exp->value.unary.fst, cxt).boolean)}; - case EQUALS_EXPRESSION: - return (ExpressionResult){ - .boolean = - evaluate_expression(exp->value.binary.fst, cxt).integer == - evaluate_expression(exp->value.binary.snd, cxt).integer}; - case NOT_EQUALS_EXPRESSION: - return (ExpressionResult){ - .boolean = - evaluate_expression(exp->value.binary.fst, cxt).integer != - evaluate_expression(exp->value.binary.snd, cxt).integer}; - case GREATER_EXPRESSION: - return (ExpressionResult){ - .boolean = evaluate_expression(exp->value.binary.fst, cxt).integer > - evaluate_expression(exp->value.binary.snd, cxt).integer}; - case GREATER_EQUALS_EXPRESSION: - return (ExpressionResult){ - .boolean = - evaluate_expression(exp->value.binary.fst, cxt).integer >= - evaluate_expression(exp->value.binary.snd, cxt).integer}; - case LESSER_EXPRESSION: - return (ExpressionResult){ - .boolean = evaluate_expression(exp->value.binary.fst, cxt).integer < - evaluate_expression(exp->value.binary.snd, cxt).integer}; - case LESSER_EQUALS_EXPRESSION: - return (ExpressionResult){ - .boolean = - evaluate_expression(exp->value.binary.fst, cxt).integer <= - evaluate_expression(exp->value.binary.snd, cxt).integer}; - case IF_EXPRESSION: - if (evaluate_expression(exp->value.if_statement.condition, cxt).boolean) - return evaluate_expression(exp->value.if_statement.yes, cxt); - return evaluate_expression(exp->value.if_statement.no, cxt); - case FUNCTION_CALL_EXPRESSION: { - Function *f = ast_table_get(ast, exp->value.function_call.funcname) - .value.func; /* ???: direct dereferening the pointer - with checking for NULL. Should be safe - because of the semantic checker */ - return execute_function_call(f, exp->value.function_call.args, cxt); - } - default: - error: - fprintf(stderr, "Error Encounter while interpreting"); - exit(1); - } -} - -ExpressionResult execute_function_call(Function *func, Expression **args, - Context *cxt) { - Context new_context = {.len = func->arglen}; - - new_context.variable = calloc(func->arglen, sizeof(struct _context)); - if (!new_context.variable) { - fprintf(stderr, "Error Encounter while interpreting (Memory Error)"); - exit(1); - } - - for (size_t i = 0; i < new_context.len; i++) { - new_context.variable[i] = - (struct _context){.var_name = func->args[i].name, - .var_value = evaluate_expression(args[i], cxt)}; - } - - ExpressionResult result = - evaluate_expression(func->expression, &new_context); - free(new_context.variable); - - return result; -} diff --git a/src/lexer.l b/src/lexer.l deleted file mode 100644 index 6fa6748..0000000 --- a/src/lexer.l +++ /dev/null @@ -1,57 +0,0 @@ -%{ - #include "parser.tab.h" - #include - #include - - #ifdef _WIN32 - #define YY_NO_UNISTD_H 1 - #endif - - static int next_column = 1; - int column = 1; - - #define HANDLE_COLUMN \ - column = next_column; \ - next_column += strlen(yytext) -%} - -%option noyywrap noinput nounput yylineno - -%% -"valdef" { HANDLE_COLUMN; return KW_VALDEF; } -"funcdef" { HANDLE_COLUMN; return KW_FUNCDEF; } -"bool" { HANDLE_COLUMN; return KW_BOOL; } -"int" { HANDLE_COLUMN; return KW_INT; } -"true" { HANDLE_COLUMN; return KW_TRUE; } -"false" { HANDLE_COLUMN; return KW_FALSE; } -"if" { HANDLE_COLUMN; return KW_IF; } -"then" { HANDLE_COLUMN; return KW_THEN; } -"else" { HANDLE_COLUMN; return KW_ELSE; } -";" { HANDLE_COLUMN; return STATEMENT_END; } -":" { HANDLE_COLUMN; return TYPE_OF; } -"," { HANDLE_COLUMN; return COMMA; } -"+" { HANDLE_COLUMN; return PLUS; } -"-" { HANDLE_COLUMN; return MINUS; } -"*" { HANDLE_COLUMN; return MULTIPLY; } -"/" { HANDLE_COLUMN; return DIVIDE; } -"%" { HANDLE_COLUMN; return MODULO; } -"&&" { HANDLE_COLUMN; return AND; } -"||" { HANDLE_COLUMN; return OR; } -"!" { HANDLE_COLUMN; return NOT; } -"==" { HANDLE_COLUMN; return EQUALS; } -"!=" { HANDLE_COLUMN; return NOT_EQUALS; } -">" { HANDLE_COLUMN; return GREATER; } -"<" { HANDLE_COLUMN; return LESSER; } -">=" { HANDLE_COLUMN; return GREATER_EQUALS; } -"<=" { HANDLE_COLUMN; return LESSER_EQUALS; } -"->" { HANDLE_COLUMN; return RETURN; } -"(" { HANDLE_COLUMN; return OPEN_BRACKETS; } -")" { HANDLE_COLUMN; return CLOSE_BRACKETS; } -"=" { HANDLE_COLUMN; return ASSIGN; } -[0-9]* { HANDLE_COLUMN; yylval.integer = atoi(yytext); return INTEGER; } -[a-zA-Z_][0-9a-zA-Z_]* { HANDLE_COLUMN; yylval.identifier = strdup(yytext); return IDENTIFIER; } -[ \t]+ { HANDLE_COLUMN; } -[\n] { HANDLE_COLUMN; next_column = 1; } -\/\/.+ { ; } -. { HANDLE_COLUMN; /* TODO: handle error */ } -%% diff --git a/src/main.c b/src/main.c deleted file mode 100644 index 3d58a12..0000000 --- a/src/main.c +++ /dev/null @@ -1,135 +0,0 @@ -#include "common.h" -#include -#include - -void *yy_scan_string(const char *); - -char *STDOUT_REDIRECT_STRING; -char *STDERR_REDIRECT_STRING; - -IMPLEMENT_HASH_FUNCTION; -DS_TABLE_DEF(ast, AST, clear_ast); - -ast_table_t *ast; -const char *filename; - -bool cli_interpretation_mode = false; -int interactive_interpretation(); -int file_interpretation(const char *file_name, int input); - -int main(int argc, char *argv[]) { - STDOUT_REDIRECT_STRING = NULL; - STDERR_REDIRECT_STRING = NULL; - - if (argc == 1) { - return interactive_interpretation(); - } - - if (argc != 3) { - fprintf(stderr, "File and input required to execute the program\n"); - return 1; - } - - return file_interpretation(argv[1], atoi(argv[2])); -} - -int interactive_interpretation() { - cli_interpretation_mode = true; - ast = ast_table_new(100); - globals = global_table_new(100); - char new_input_prompt[] = ">>> "; - char continue_input_prompt[] = " "; - - // STDOUT_REDIRECT_STRING = calloc(STDOUT_STRING_LENGTH, 1); - // STDERR_REDIRECT_STRING = calloc(STDERR_STRING_LENGTH, 1); - - static char string[500]; - - char *prompt = new_input_prompt; - int input_length = 0; - - while (true) { - printf("%s", prompt); - if (!fgets(string + input_length, 500, stdin)) { - fprintf(stderr, "Error while getting input\n"); - return 1; - } - if ((!strcmp("exit\n", string)) || (!strcmp("exit;\n", string))) { - return 0; - } - - input_length = 0; - bool get_more_input = false; - for (int i = 0; i < 500; i++) { - if (string[i] == 0) - break; - if (string[i] == ';') - get_more_input = false; - else if ((string[i] != ' ') && (string[i] != '\t') && - (string[i] != '\n')) - get_more_input = true; - - input_length++; - } - - if (get_more_input) { - prompt = continue_input_prompt; - continue; - } else { - prompt = new_input_prompt; - input_length = 0; - } - - yy_scan_string(string); - yyparse(); - - // if (STDOUT_REDIRECT_STRING[0]) { - // fprintf(stdout, ":: %s", STDOUT_REDIRECT_STRING); - // STDOUT_REDIRECT_STRING[0] = 0; - // } - // if (STDERR_REDIRECT_STRING[0]) { - // fprintf(stderr, ":: %s", STDERR_REDIRECT_STRING); - // STDERR_REDIRECT_STRING[0] = 0; - // } - } -} - -int file_interpretation(const char *file_name, int input) { - filename = file_name; - - FILE *file = fopen(filename, "r"); - if (file == NULL) { - fprintf(stderr, "Could not open file \"%s\"\n", filename); - return 0; - } - - /* Initialization of Variables and Functions Table */ - ast = ast_table_new(100); - - /* Parsing */ - yyin = file; - if (yyparse()) { - fclose(file); - fprintf(stderr, "%s\n", syntax_error_msg); - return 1; - } - - fclose(file); - - /* Sematic Analysis */ - if (!verify_semantics()) { - fprintf(stderr, "Semantic Error: %s\n", semantic_error_msg); - return 1; - } - - /* Interpreting */ - int output; - if (!interpret(input, &output)) { - fprintf(stderr, "Runtime Error: %s\n", runtime_error_msg); - return 1; - } - - printf("Input: %d\nOutput: %d\n", input, output); - - return 0; -} diff --git a/src/parser.yy b/src/parser.yy deleted file mode 100644 index 65e9c40..0000000 --- a/src/parser.yy +++ /dev/null @@ -1,137 +0,0 @@ -%{ - #include "common.h" - #include "cli_interpreter.h" - - #define ERROR_MSG_LEN 500 - char syntax_error_msg[ERROR_MSG_LEN]; -%} - -%union { - int integer; - char *identifier; - struct _Expression *expression; - struct _Variable *variable; - struct _Function *function; -} - -%token KW_VALDEF -%token KW_FUNCDEF -%token KW_BOOL -%token KW_INT -%token KW_TRUE -%token KW_FALSE -%token KW_IF -%token KW_THEN -%token KW_ELSE -%token COMMA -%token STATEMENT_END -%token TYPE_OF -%token ASSIGN -%token PLUS -%token MINUS -%token MULTIPLY -%token DIVIDE -%token MODULO -%token OPEN_BRACKETS -%token CLOSE_BRACKETS -%token AND -%token OR -%token NOT -%token EQUALS -%token NOT_EQUALS -%token GREATER -%token GREATER_EQUALS -%token LESSER -%token LESSER_EQUALS -%token RETURN -%token INTEGER -%token IDENTIFIER - -%type function_definition; -%type function_definition_arguments; -%type value_definition; -%type expression; -%type basic_expression; -%type function_call_arguments; - -/* %precedence KW_ELSE -%left AND OR -%left EQUALS NOT_EQUALS GREATER GREATER_EQUALS LESSER LESSER_EQUALS -%left PLUS -%left MULTIPLY DIVIDE -%left MODULO -%precedence MINUS -%precedence NOT */ - -%% -input: %empty - | input expression STATEMENT_END { - if (cli_interpretation_mode) { - cli_interpret((AST){.type = AST_EXPRESSION, .value.exp = $2}); - } - else { - yyerror("Standalone expression are not allowed\n"); - // ???: Don't exit here - return 1; - } - } - | input value_definition { - if (cli_interpretation_mode) { - cli_interpret((AST){.type = AST_VARIABLE, .value.var = $2}); - } - else { - assert(ast_table_insert(ast, ($2)->name, (AST){.type = AST_VARIABLE, .value.var = $2})); - } - } - | input function_definition { - if (cli_interpretation_mode) { - cli_interpret((AST){.type = AST_FUNCTION, .value.func = $2}); - } else { - assert(ast_table_insert(ast, ($2)->funcname, (AST){.type = AST_FUNCTION, .value.func = $2})); - } - }; - -function_definition: KW_FUNCDEF IDENTIFIER function_definition_arguments RETURN KW_BOOL ASSIGN expression STATEMENT_END { $$ = set_function_return_value(set_function_name($3, $2), BOOL, $7); } - | KW_FUNCDEF IDENTIFIER function_definition_arguments RETURN KW_INT ASSIGN expression STATEMENT_END { $$ = set_function_return_value(set_function_name($3, $2), INT, $7);}; - -function_definition_arguments: IDENTIFIER TYPE_OF KW_BOOL { $$ = add_function_argument(make_function(), $1, BOOL); } - | IDENTIFIER TYPE_OF KW_INT { $$ = add_function_argument(make_function(), $1, INT); } - | IDENTIFIER TYPE_OF KW_BOOL function_definition_arguments { $$ = add_function_argument($4, $1, BOOL); } - | IDENTIFIER TYPE_OF KW_INT function_definition_arguments { $$ = add_function_argument($4, $1, INT); } - | OPEN_BRACKETS function_definition_arguments CLOSE_BRACKETS { $$ = $2; }; - -value_definition: KW_VALDEF IDENTIFIER TYPE_OF KW_BOOL ASSIGN expression STATEMENT_END { $$ = make_variable($2, BOOL, $6); } - | KW_VALDEF IDENTIFIER TYPE_OF KW_INT ASSIGN expression STATEMENT_END { $$ = make_variable($2, INT, $6); }; - -basic_expression: IDENTIFIER { $$ = make_variable_expression($1); } - | INTEGER { $$ = make_integer_expression($1); } - | KW_TRUE { $$ = make_boolean_expression(true); } - | KW_FALSE { $$ = make_boolean_expression(false); } - | OPEN_BRACKETS expression CLOSE_BRACKETS { $$ = $2; }; - -expression: basic_expression { $$ = $1; } - | AND basic_expression basic_expression { $$ = make_binary_expression($2, $3, AND_EXPRESSION); } - | OR basic_expression basic_expression { $$ = make_binary_expression($2, $3, OR_EXPRESSION); } - | NOT basic_expression { $$ = make_unary_expression($2, NOT_EXPRESSION); } - | PLUS basic_expression basic_expression { $$ = make_binary_expression($2, $3, PLUS_EXPRESSION); } - | MULTIPLY basic_expression basic_expression { $$ = make_binary_expression($2, $3, MULTIPLY_EXPRESSION); } - | DIVIDE basic_expression basic_expression { $$ = make_binary_expression($2, $3, DIVIDE_EXPRESSION); } - | MODULO basic_expression basic_expression { $$ = make_binary_expression($2, $3, MODULO_EXPRESSION); } - | MINUS basic_expression { $$ = make_unary_expression($2, MINUS_EXPRESSION); } - | EQUALS basic_expression basic_expression { $$ = make_binary_expression($2, $3, EQUALS_EXPRESSION); } - | NOT_EQUALS basic_expression basic_expression { $$ = make_binary_expression($2, $3, NOT_EQUALS_EXPRESSION); } - | GREATER basic_expression basic_expression { $$ = make_binary_expression($2, $3, GREATER_EXPRESSION); } - | GREATER_EQUALS basic_expression basic_expression { $$ = make_binary_expression($2, $3, GREATER_EQUALS_EXPRESSION); } - | LESSER basic_expression basic_expression { $$ = make_binary_expression($2, $3, LESSER_EXPRESSION); } - | LESSER_EQUALS basic_expression basic_expression { $$ = make_binary_expression($2, $3, LESSER_EQUALS_EXPRESSION); } - | KW_IF expression KW_THEN expression KW_ELSE expression { $$ = make_if_expression($2, $4, $6); } - | IDENTIFIER function_call_arguments { $$ = set_function_call_name_expression($2, $1); }; - -function_call_arguments: basic_expression { $$ = add_function_call_argument_expression(make_function_call_expression(), $1); } - | basic_expression function_call_arguments { $$ = add_function_call_argument_expression($2, $1); }; -%% - -void yyerror(char const *str) { - snprintf(syntax_error_msg, ERROR_MSG_LEN, - "ERROR: %s in %s:%d:%d", str, filename, yylineno, column); -} diff --git a/src/semantics.c b/src/semantics.c deleted file mode 100644 index 09ecb4b..0000000 --- a/src/semantics.c +++ /dev/null @@ -1,250 +0,0 @@ -#include "common.h" -#include -#include -#include - -#define ERROR_MSG_LEN 500 - -typedef struct { - size_t arglen; - Argument *args; -} Context; - -bool verify_function_semantics(Function *func); -bool verify_variable_semantics(Variable *var); -bool verify_function_call_arguments(Function *func, Expression **args, - size_t arglen, Context *cxt); -bool verify_expression_type(Expression *exp, Type type, Context *cxt); -bool verify_ast_semantics(AST *tree); - -char semantic_error_msg[ERROR_MSG_LEN] = {0}; - -// TODO: improve error message with line number - -bool verify_semantics() { - char *key; - AST *tree; - ast_table_iter(ast); - - while (NULL != (tree = ast_table_iter_next(ast, &key))) { - if (tree->semantically_correct) - continue; - if (!verify_ast_semantics(tree)) - return false; - } - - return true; -} - -bool verify_ast_semantics(AST *tree) { - if (tree->semantically_correct) - return true; - - switch (tree->type) { - case AST_FUNCTION: - if (!verify_function_semantics(tree->value.func)) { - return false; - } - break; - case AST_VARIABLE: - if (!verify_variable_semantics(tree->value.var)) { - return false; - } - break; - case AST_EXPRESSION: - snprintf(semantic_error_msg, ERROR_MSG_LEN, "Internal Error"); - return false; - } - tree->semantically_correct = true; - return tree; -} - -bool verify_function_call_arguments(Function *func, Expression **args, - size_t arglen, Context *cxt) { - if (func->arglen != arglen) { - snprintf(semantic_error_msg, ERROR_MSG_LEN, - "Parameter count and argument count mismatch"); - return false; - } - - for (size_t i = 0; i < func->arglen; i++) { - if (!verify_expression_type(args[i], func->args[i].type, cxt)) { - snprintf(semantic_error_msg, ERROR_MSG_LEN, - "Expected argument type %s, but got other type", - Type_to_string(func->args[i].type)); - return false; - } - } - return true; -} - -bool verify_expression_type(Expression *exp, Type type, Context *cxt) { - switch (exp->type) { - case INTEGER_EXPRESSION: - if (type == INT) - return true; - goto expected_int_error; - case VARIABLE_EXPRESSION: { - AST *variable_ast = ast_table_get_ptr(ast, exp->value.variable); - if (variable_ast) { - if (variable_ast->type == AST_VARIABLE) { - if (variable_ast->value.var->type == type) { - return true; - } - snprintf(semantic_error_msg, ERROR_MSG_LEN, - "%s has type %s, but expected %s type", - exp->value.variable, - Type_to_string(variable_ast->value.var->type), - Type_to_string(type)); - return false; - } - snprintf(semantic_error_msg, ERROR_MSG_LEN, - "%s is not a variable definition", - exp->value.variable); - return false; - } - - if (!cxt) - goto variable_not_found_error; - - // TODO: try changing the below to something other then linear search - for (size_t i = 0; i < cxt->arglen; i++) { - if (!strcmp(exp->value.variable, cxt->args[i].name)) { - if (cxt->args[i].type == type) { - return true; - } - } - } - - variable_not_found_error: - snprintf(semantic_error_msg, ERROR_MSG_LEN, - "Could not find %s's variable definition", - exp->value.variable); - return false; - } - case BOOLEAN_EXPRESSION: - if (type == BOOL) - return true; - goto expected_bool_error; - case PLUS_EXPRESSION: - case MULTIPLY_EXPRESSION: - case DIVIDE_EXPRESSION: - case MODULO_EXPRESSION: - if (type == INT) { - if (verify_expression_type(exp->value.binary.fst, INT, cxt) && - verify_expression_type(exp->value.binary.snd, INT, cxt)) - return true; - } - goto expected_int_error; - case MINUS_EXPRESSION: - if (type == INT) { - if (verify_expression_type(exp->value.unary.fst, INT, cxt)) - return true; - } - goto expected_int_error; - case AND_EXPRESSION: - case OR_EXPRESSION: - if (type == BOOL) { - if (verify_expression_type(exp->value.binary.fst, BOOL, cxt) && - verify_expression_type(exp->value.binary.snd, BOOL, cxt)) - return true; - } - goto expected_bool_error; - case NOT_EXPRESSION: - if (type == BOOL) { - if (verify_expression_type(exp->value.unary.fst, BOOL, cxt)) - return true; - } - goto expected_bool_error; - case EQUALS_EXPRESSION: - case NOT_EQUALS_EXPRESSION: - case GREATER_EXPRESSION: - case GREATER_EQUALS_EXPRESSION: - case LESSER_EXPRESSION: - case LESSER_EQUALS_EXPRESSION: - if (type == BOOL) { - if (verify_expression_type(exp->value.binary.fst, INT, cxt) && - verify_expression_type(exp->value.binary.snd, INT, cxt)) - return true; - goto expected_int_error; - } - goto expected_bool_error; - case IF_EXPRESSION: - if (verify_expression_type(exp->value.if_statement.condition, BOOL, - cxt)) { - if (type == INT) { - if (verify_expression_type(exp->value.if_statement.yes, INT, - cxt) && - verify_expression_type(exp->value.if_statement.no, INT, - cxt)) - return true; - goto expected_int_error; - } - if (type == BOOL) { - if (verify_expression_type(exp->value.if_statement.yes, BOOL, - cxt) && - verify_expression_type(exp->value.if_statement.no, BOOL, - cxt)) - return true; - goto expected_bool_error; - } - } - goto expected_bool_error; - case FUNCTION_CALL_EXPRESSION: { - AST *func_ast = - ast_table_get_ptr(ast, exp->value.function_call.funcname); - if (!func_ast) { - snprintf(semantic_error_msg, ERROR_MSG_LEN, - "Could not find function %s", - exp->value.function_call.funcname); - return false; - } - if (func_ast->type != AST_FUNCTION) { - snprintf(semantic_error_msg, ERROR_MSG_LEN, - "%s is not a function", - exp->value.variable); - return false; - } - - if (func_ast->value.func->return_type == type) { - if (verify_function_call_arguments( - func_ast->value.func, exp->value.function_call.args, - exp->value.function_call.arglen, cxt)) - return true; - return false; - } - snprintf(semantic_error_msg, ERROR_MSG_LEN, - "Function return type is not the expected type"); - return false; - } - - default: - printf("This should not be printed\n"); - return false; - } - -expected_bool_error: - snprintf(semantic_error_msg, ERROR_MSG_LEN, - "Expected bool type but got other type"); - return false; - -expected_int_error: - snprintf(semantic_error_msg, ERROR_MSG_LEN, - "Expected int type but got other type"); - return false; -} - -bool verify_function_semantics(Function *func) { - Context cxt = {.arglen = func->arglen, .args = func->args}; - return verify_expression_type(func->expression, func->return_type, &cxt); -} - -bool verify_variable_semantics(Variable *var) { - if (verify_expression_type(var->expression, var->type, NULL)) - return true; - - snprintf(semantic_error_msg, ERROR_MSG_LEN, - "Expected %s to be %s type but got other type", var->name, - Type_to_string(var->type)); - return false; -}