diff --git a/.semaphore/semaphore.yml b/.semaphore/semaphore.yml new file mode 100644 index 00000000..de6ea01b --- /dev/null +++ b/.semaphore/semaphore.yml @@ -0,0 +1,20 @@ +version: v1.0 +name: flax +agent: + machine: + type: e1-standard-2 + os_image: ubuntu1804 +blocks: + - name: "main" + task: + jobs: + - name: main + commands: + - checkout + - echo "LLVM 7" + - sudo echo "deb http://apt.llvm.org/bionic/ llvm-toolchain-bionic-7 main" | sudo tee -a /etc/apt/sources.list + - sudo echo "deb-src http://apt.llvm.org/bionic/ llvm-toolchain-bionic-7 main" | sudo tee -a /etc/apt/sources.list + - sudo wget -O - http://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add - + - sudo apt -y update + - sudo apt-get -o Dpkg::Options::="--force-overwrite" --allow-unauthenticated -y install -y llvm-7 llvm-7-dev libllvm7 libmpfr-dev libmpfr6 + - CXX=g++-8 CC=gcc-8 LLVM_CONFIG=llvm-config-7 make -j2 test diff --git a/.travis.yml b/.travis.yml index b9acd8a2..476530e2 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,7 +4,7 @@ compiler: clang matrix: include: - os: linux - dist: trusty + dist: xenial sudo: false - os: osx osx_image: xcode9.4 @@ -12,25 +12,25 @@ matrix: addons: apt: sources: - - llvm-toolchain-trusty-6.0 + - llvm-toolchain-xenial-7 - ubuntu-toolchain-r-test packages: - g++-8 - - llvm-6.0 - - llvm-6.0-dev - - libllvm6.0 + - llvm-7 + - llvm-7-dev + - libllvm7 - libmpfr-dev - libmpfr4 homebrew: packages: - - llvm@6 + - llvm@7 update: true script: - if [ "$TRAVIS_OS_NAME" == "osx" ]; - then PATH="/usr/local/opt/llvm@6/bin:$PATH" LLVM_CONFIG=llvm-config make tester; - else CXX=g++-$GCC_VERSION CC=gcc-$GCC_VERSION make ci; + then PATH="/usr/local/opt/llvm@7/bin:$PATH" LLVM_CONFIG=llvm-config make tester; + else CXX=g++-$GCC_VERSION CC=gcc-$GCC_VERSION LLVM_CONFIG=llvm-config-7 make ci; fi notifications: diff --git a/README.md b/README.md index 47748069..184d0f37 100644 --- a/README.md +++ b/README.md @@ -95,11 +95,16 @@ do { ### Building the Flax compiler +#### Dependencies #### +- LLVM 7, mostly due to their obsession with changing the IR interface every damn version (6 does not work, 8 does not work) +- GMP/MPIR +- MPFR + #### macOS / Linux - Flax uses a makefile; most likely some form of GNU-compatible `make` will work. -- LLVM needs to be installed. On macOS, `brew install llvm@6` should work. (note: llvm 7 and up seems to have changed some JIT-related things) +- LLVM needs to be installed. On macOS, `brew install llvm@7` should work. (note: llvm 8 and up seems to have changed some JIT-related things) - For macOS people, simply call `make`. - Linux people, call `make linux`. - A *C++17*-compatible compiler should be used. diff --git a/build/run-test.bat b/build/run-test.bat index 409810ab..08bb2ee5 100644 --- a/build/run-test.bat +++ b/build/run-test.bat @@ -17,6 +17,10 @@ IF /I "%1"=="debugopt" ( ) -ninja -C %buildDir% && cls && %buildDir%\flaxc.exe -Ox -sysroot build\sysroot -run build\%2.flx +ninja -C %buildDir% && cls && %buildDir%\flaxc.exe -Ox -sysroot build\sysroot -run build\%2.flx %3 %4 -ENDLOCAL \ No newline at end of file +IF /I "%1"=="release" ( + copy %buildDir%\flaxc.exe build\sysroot\usr\local\bin\ >NUL +) + +ENDLOCAL diff --git a/build/supertiny.flx b/build/supertiny.flx index afcbbca7..c7578c58 100644 --- a/build/supertiny.flx +++ b/build/supertiny.flx @@ -8,14 +8,61 @@ import libc as _ // import std::io // import std::map +import std::opt + +@raw union ipv4 +{ + _: struct { + _: @raw union { + bytes2: [u8: 4] + raw3: u32 + } + } + + _: struct { + bytes: [u8: 4] + } + + _: struct { + raw2: u32 + } + + raw: u32 +} +struct foo +{ + x: int + y: str +} + @entry fn main() { + do { + var addr: ipv4 + addr.raw3 = 0xff01a8c0; + + printf("%d.%d.%d.%d\n", addr.bytes[0], addr.bytes[1], addr.bytes[2], addr.bytes[3]); + } + + + /* do { + var x: @raw union { + bar: f64 + foo: i64 + } + + x.bar = 3.14159 + + printf("x = %.2lf\n", x.bar) + printf("x = %d\n", x.foo) + } */ } + /* import std::opt @@ -36,6 +83,14 @@ import std:: map /* + !!! bugs (2) !!! + !!! 22/03/19 + { + 1. we currently check whether you are trying to put a refcounted type as a field of a raw union. however! + you can circumvent this check if you nest the field into inner types, or if you use transparent fields!! + } + + ! bugs !!! ! 02/12 { @@ -96,6 +151,17 @@ import std:: map even just enforcing that we get either a string literal of exactly length 1, or an i8 type, at compile-time. sure, we could just abort when passed a slice != length 1, but then why bother being statically typed?? } + + 16. operator overload errors { + something like this: + + failed to resolve candidate for 'operator == (type_1, type_2)' + callsite was here: 'foo == bar' + potential candidates: 'operator == (type_1, type_3)' + + ofc potential candidates should be limited to ones where at least one of the types + are similar. + } */ diff --git a/build/tests/anytest.flx b/build/tests/anytest.flx index 72b7cf0c..b4aacff4 100644 --- a/build/tests/anytest.flx +++ b/build/tests/anytest.flx @@ -11,11 +11,11 @@ var glob: any struct Large { - var a: i64 - var b: i64 - var c: i64 - var d: i64 - var e: i64 + a: i64 + b: i64 + c: i64 + d: i64 + e: i64 } public fn doAnyTest() diff --git a/build/tests/linkedlist.flx b/build/tests/linkedlist.flx index c6aa8308..9e2ac2be 100644 --- a/build/tests/linkedlist.flx +++ b/build/tests/linkedlist.flx @@ -9,10 +9,10 @@ class LinkedList { struct Node { - var prev: &Node - var next: &Node + prev: &Node + next: &Node - var data: T + data: T } var head: &Node diff --git a/build/tests/using.flx b/build/tests/using.flx index 3918e632..ae06b32b 100644 --- a/build/tests/using.flx +++ b/build/tests/using.flx @@ -23,8 +23,10 @@ public fn doUsingTest() } do { - struct xxx + class xxx { + init() { } + var k: T enum Foo: int { diff --git a/flax.vcxproj b/flax.vcxproj index 1d460384..3c2cf07b 100644 --- a/flax.vcxproj +++ b/flax.vcxproj @@ -107,7 +107,7 @@ /ignore:4099 - LLVMCore.lib;LLVMSupport.lib;LLVMTarget.lib;LLVMPasses.lib;LLVMAnalysis.lib;LLVMGlobalISel.lib;LLVMLibDriver.lib;LLVMLinker.lib;LLVMipo.lib;LLVMBinaryFormat.lib;LLVMMC.lib;LLVMMCJIT.lib;LLVMMCParser.lib;LLVMMCDisassembler.lib;LLVMObject.lib;LLVMScalarOpts.lib;LLVMVectorize.lib;LLVMCodegen.lib;LLVMTablegen.lib;LLVMBitReader.lib;LLVMBitWriter.lib;LLVMInstrumentation.lib;LLVMRuntimeDyld.lib;LLVMInstCombine.lib;LLVMInterpreter.lib;LLVMExecutionEngine.lib;LLVMSelectionDAG.lib;LLVMTransformUtils.lib;LLVMDebugInfoCodeView.lib;LLVMDebugInfoDWARF.lib;LLVMDebugInfoMSF.lib;LLVMDebugInfoPDB.lib;LLVMAsmPrinter.lib;LLVMX86AsmPrinter.lib;LLVMProfileData.lib;LLVMX86AsmParser.lib;LLVMX86Info.lib;LLVMX86CodeGen.lib;LLVMX86Utils.lib;LLVMX86Desc.lib;LLVMDlltoolDriver.lib;mpfr.lib;mpir.lib;%(AdditionalDependencies) + LLVMCore.lib;LLVMSupport.lib;LLVMTarget.lib;LLVMPasses.lib;LLVMAnalysis.lib;LLVMGlobalISel.lib;LLVMLibDriver.lib;LLVMLinker.lib;LLVMipo.lib;LLVMBinaryFormat.lib;LLVMMC.lib;LLVMMCJIT.lib;LLVMOrcJIT.lib;LLVMMCParser.lib;LLVMMCDisassembler.lib;LLVMObject.lib;LLVMScalarOpts.lib;LLVMVectorize.lib;LLVMCodegen.lib;LLVMTablegen.lib;LLVMBitReader.lib;LLVMBitWriter.lib;LLVMInstrumentation.lib;LLVMRuntimeDyld.lib;LLVMInstCombine.lib;LLVMInterpreter.lib;LLVMExecutionEngine.lib;LLVMSelectionDAG.lib;LLVMTransformUtils.lib;LLVMDebugInfoCodeView.lib;LLVMDebugInfoDWARF.lib;LLVMDebugInfoMSF.lib;LLVMDebugInfoPDB.lib;LLVMAsmPrinter.lib;LLVMX86AsmPrinter.lib;LLVMProfileData.lib;LLVMX86AsmParser.lib;LLVMX86Info.lib;LLVMX86CodeGen.lib;LLVMX86Utils.lib;LLVMX86Desc.lib;LLVMDlltoolDriver.lib;mpfr.lib;mpir.lib;%(AdditionalDependencies) LIBCMTD;%(IgnoreSpecificDefaultLibraries) true Default @@ -150,7 +150,7 @@ rem exit 0 /ignore:4099 - LLVMCore.lib;LLVMSupport.lib;LLVMTarget.lib;LLVMPasses.lib;LLVMAnalysis.lib;LLVMGlobalISel.lib;LLVMLibDriver.lib;LLVMLinker.lib;LLVMipo.lib;LLVMBinaryFormat.lib;LLVMMC.lib;LLVMMCJIT.lib;LLVMMCParser.lib;LLVMMCDisassembler.lib;LLVMObject.lib;LLVMScalarOpts.lib;LLVMVectorize.lib;LLVMCodegen.lib;LLVMTablegen.lib;LLVMBitReader.lib;LLVMBitWriter.lib;LLVMInstrumentation.lib;LLVMRuntimeDyld.lib;LLVMInstCombine.lib;LLVMInterpreter.lib;LLVMExecutionEngine.lib;LLVMSelectionDAG.lib;LLVMTransformUtils.lib;LLVMDebugInfoCodeView.lib;LLVMDebugInfoDWARF.lib;LLVMDebugInfoMSF.lib;LLVMDebugInfoPDB.lib;LLVMAsmPrinter.lib;LLVMX86AsmPrinter.lib;LLVMProfileData.lib;LLVMX86AsmParser.lib;LLVMX86Info.lib;LLVMX86CodeGen.lib;LLVMX86Utils.lib;LLVMX86Desc.lib;LLVMDlltoolDriver.lib;mpfr.lib;mpir.lib;%(AdditionalDependencies) + LLVMCore.lib;LLVMSupport.lib;LLVMTarget.lib;LLVMPasses.lib;LLVMAnalysis.lib;LLVMGlobalISel.lib;LLVMLibDriver.lib;LLVMLinker.lib;LLVMipo.lib;LLVMBinaryFormat.lib;LLVMMC.lib;LLVMMCJIT.lib;LLVMOrcJIT.lib;LLVMMCParser.lib;LLVMMCDisassembler.lib;LLVMObject.lib;LLVMScalarOpts.lib;LLVMVectorize.lib;LLVMCodegen.lib;LLVMTablegen.lib;LLVMBitReader.lib;LLVMBitWriter.lib;LLVMInstrumentation.lib;LLVMRuntimeDyld.lib;LLVMInstCombine.lib;LLVMInterpreter.lib;LLVMExecutionEngine.lib;LLVMSelectionDAG.lib;LLVMTransformUtils.lib;LLVMDebugInfoCodeView.lib;LLVMDebugInfoDWARF.lib;LLVMDebugInfoMSF.lib;LLVMDebugInfoPDB.lib;LLVMAsmPrinter.lib;LLVMX86AsmPrinter.lib;LLVMProfileData.lib;LLVMX86AsmParser.lib;LLVMX86Info.lib;LLVMX86CodeGen.lib;LLVMX86Utils.lib;LLVMX86Desc.lib;LLVMDlltoolDriver.lib;mpfr.lib;mpir.lib;%(AdditionalDependencies) LIBCMTD;%(IgnoreSpecificDefaultLibraries) false Default @@ -193,7 +193,7 @@ rem exit 0 /ignore:4099 - LLVMCore.lib;LLVMSupport.lib;LLVMTarget.lib;LLVMPasses.lib;LLVMAnalysis.lib;LLVMGlobalISel.lib;LLVMLibDriver.lib;LLVMLinker.lib;LLVMipo.lib;LLVMBinaryFormat.lib;LLVMMC.lib;LLVMMCJIT.lib;LLVMMCParser.lib;LLVMMCDisassembler.lib;LLVMObject.lib;LLVMScalarOpts.lib;LLVMVectorize.lib;LLVMCodegen.lib;LLVMTablegen.lib;LLVMBitReader.lib;LLVMBitWriter.lib;LLVMInstrumentation.lib;LLVMRuntimeDyld.lib;LLVMInstCombine.lib;LLVMInterpreter.lib;LLVMExecutionEngine.lib;LLVMSelectionDAG.lib;LLVMTransformUtils.lib;LLVMDebugInfoCodeView.lib;LLVMDebugInfoDWARF.lib;LLVMDebugInfoMSF.lib;LLVMDebugInfoPDB.lib;LLVMAsmPrinter.lib;LLVMX86AsmPrinter.lib;LLVMProfileData.lib;LLVMX86AsmParser.lib;LLVMX86Info.lib;LLVMX86CodeGen.lib;LLVMX86Utils.lib;LLVMX86Desc.lib;LLVMDlltoolDriver.lib;mpfr.lib;mpir.lib;%(AdditionalDependencies) + LLVMCore.lib;LLVMSupport.lib;LLVMTarget.lib;LLVMPasses.lib;LLVMAnalysis.lib;LLVMGlobalISel.lib;LLVMLibDriver.lib;LLVMLinker.lib;LLVMipo.lib;LLVMBinaryFormat.lib;LLVMMC.lib;LLVMMCJIT.lib;LLVMOrcJIT.lib;LLVMMCParser.lib;LLVMMCDisassembler.lib;LLVMObject.lib;LLVMScalarOpts.lib;LLVMVectorize.lib;LLVMCodegen.lib;LLVMTablegen.lib;LLVMBitReader.lib;LLVMBitWriter.lib;LLVMInstrumentation.lib;LLVMRuntimeDyld.lib;LLVMInstCombine.lib;LLVMInterpreter.lib;LLVMExecutionEngine.lib;LLVMSelectionDAG.lib;LLVMTransformUtils.lib;LLVMDebugInfoCodeView.lib;LLVMDebugInfoDWARF.lib;LLVMDebugInfoMSF.lib;LLVMDebugInfoPDB.lib;LLVMAsmPrinter.lib;LLVMX86AsmPrinter.lib;LLVMProfileData.lib;LLVMX86AsmParser.lib;LLVMX86Info.lib;LLVMX86CodeGen.lib;LLVMX86Utils.lib;LLVMX86Desc.lib;LLVMDlltoolDriver.lib;mpfr.lib;mpir.lib;%(AdditionalDependencies) LIBCMTD;%(IgnoreSpecificDefaultLibraries) false Default @@ -271,4 +271,4 @@ rem exit 0 - \ No newline at end of file + diff --git a/issues.md b/issues.md index 83539ef0..2a128bba 100644 --- a/issues.md +++ b/issues.md @@ -11,9 +11,6 @@ Note: this is just a personal log of outstanding issues, shorter rants/ramblings 3. Optional arguments. -4. Public and private imports (ie. do we re-export our imports (currently the default), or do we keep them to ourselves (the new default) - - 5. String operators diff --git a/makefile b/makefile index 4879c425..1db9db64 100644 --- a/makefile +++ b/makefile @@ -4,7 +4,7 @@ -WARNINGS := -Wno-unused-parameter -Wno-sign-conversion -Wno-padded -Wno-conversion -Wno-shadow -Wno-missing-noreturn -Wno-unused-macros -Wno-switch-enum -Wno-deprecated -Wno-format-nonliteral -Wno-trigraphs -Wno-unused-const-variable +WARNINGS := -Wno-unused-parameter -Wno-sign-conversion -Wno-padded -Wno-conversion -Wno-shadow -Wno-missing-noreturn -Wno-unused-macros -Wno-switch-enum -Wno-deprecated -Wno-format-nonliteral -Wno-trigraphs -Wno-unused-const-variable -Wno-deprecated-declarations CLANGWARNINGS := -Wno-undefined-func-template -Wno-comma -Wno-nullability-completeness -Wno-redundant-move -Wno-nested-anon-types -Wno-gnu-anonymous-struct -Wno-reserved-id-macro -Wno-extra-semi -Wno-gnu-zero-variadic-macro-arguments -Wno-shift-sign-overflow -Wno-exit-time-destructors -Wno-global-constructors -Wno-c++98-compat-pedantic -Wno-documentation-unknown-command -Wno-weak-vtables -Wno-c++98-compat @@ -38,12 +38,11 @@ CXXDEPS := $(CXXSRC:.cpp=.cpp.d) NUMFILES := $$(($(words $(CXXSRC)) + $(words $(CSRC)))) - - +DEFINES := -D__USE_MINGW_ANSI_STDIO=1 SANITISE := -CXXFLAGS += -std=c++1z -O0 -g -c -Wall -frtti -fexceptions -fno-omit-frame-pointer -Wno-old-style-cast $(SANITISE) -CFLAGS += -std=c11 -O0 -g -c -Wall -fno-omit-frame-pointer -Wno-overlength-strings $(SANITISE) +CXXFLAGS += -std=c++1z -O0 -g -c -Wall -frtti -fexceptions -fno-omit-frame-pointer -Wno-old-style-cast $(SANITISE) $(DEFINES) +CFLAGS += -std=c11 -O0 -g -c -Wall -fno-omit-frame-pointer -Wno-overlength-strings $(SANITISE) $(DEFINES) LDFLAGS += $(SANITISE) @@ -64,11 +63,7 @@ TESTSRC := build/tester.flx -include $(CXXDEPS) -.PHONY: copylibs jit compile clean build osx linux ci prep satest tiny osxflags - -prep: - @# echo C++ compiler is: $(CXX) - @mkdir -p $(dir $(OUTPUT)) +.PHONY: copylibs jit compile clean build osx linux ci satest tiny osxflags osxflags: CXXFLAGS += -march=native -fmodules -Weverything -Xclang -fcolor-diagnostics $(SANITISE) $(CLANGWARNINGS) osxflags: CFLAGS += -fmodules -Xclang -fcolor-diagnostics $(SANITISE) $(CLANGWARNINGS) @@ -76,17 +71,17 @@ osxflags: CFLAGS += -fmodules -Xclang -fcolor-diagnostics $(SANITISE) $(CLANGWAR osxflags: -osx: prep jit osxflags +osx: jit osxflags -satest: prep osxflags build +satest: osxflags build @$(OUTPUT) $(FLXFLAGS) -run build/standalone.flx -tester: prep osxflags build +tester: osxflags build @$(OUTPUT) $(FLXFLAGS) -run build/tester.flx -ci: prep test +ci: test -linux: prep jit +linux: jit jit: build @$(OUTPUT) $(FLXFLAGS) -run -o $(SUPERTINYBIN) $(SUPERTINYSRC) @@ -100,7 +95,9 @@ test: build gltest: build @$(OUTPUT) $(FLXFLAGS) -run -framework GLUT -framework OpenGL -lsdl2 -o $(GLTESTBIN) $(GLTESTSRC) -build: $(OUTPUT) copylibs +build1: + +build: build1 $(OUTPUT) copylibs # built build/%.flx: build @@ -117,7 +114,8 @@ copylibs: $(FLXSRC) $(OUTPUT): $(PRECOMP_GCH) $(CXXOBJ) $(COBJ) @printf "# linking\n" - @$(CXX) -o $@ $(CXXOBJ) $(COBJ) $(shell $(LLVM_CONFIG) --cxxflags --ldflags --system-libs --libs core engine native linker bitwriter lto vectorize all-targets object) -lmpfr -lgmp $(LDFLAGS) -lpthread + @mkdir -p $(dir $(OUTPUT)) + @$(CXX) -o $@ $(CXXOBJ) $(COBJ) $(shell $(LLVM_CONFIG) --cxxflags --ldflags --system-libs --libs core engine native linker bitwriter lto vectorize all-targets object orcjit) -lmpfr -lgmp $(LDFLAGS) -lpthread %.cpp.o: %.cpp diff --git a/meson.build b/meson.build index c328135f..36c7e4ea 100644 --- a/meson.build +++ b/meson.build @@ -5,17 +5,15 @@ the_compiler = meson.get_compiler('c') # buildKind = get_option('buildtype') # message('build type is ' + buildKind) -# libKind = 'DebugNoSyms' - if get_option('buildtype') == 'debug' - libKind = 'DebugNoSyms' + libKind = 'Debug' else libKind = 'Release' endif mpir_root_dir = 'D:/Projects/lib/mpir' mpfr_root_dir = 'D:/Projects/lib/mpfr' -llvm_root_dir = 'D:/Projects/lib/llvm' +llvm_root_dir = 'D:/Projects/lib/llvm/7.0.1/' mpir_hdr_dir = mpir_root_dir + '/' + libKind + '/include/' mpfr_hdr_dir = mpfr_root_dir + '/' + libKind + '/include/' @@ -164,6 +162,7 @@ source_files = files([ 'source/fir/Types/ArraySliceType.cpp', 'source/fir/Types/PrimitiveType.cpp', 'source/fir/Types/FunctionType.cpp', + 'source/fir/Types/RawUnionType.cpp', 'source/fir/Types/PointerType.cpp', 'source/fir/Types/SingleTypes.cpp', 'source/fir/Types/StructType.cpp', @@ -206,7 +205,7 @@ mpir_dep = declare_dependency(version: '3.0.0', include_directories: include_dir mpfr_dep = declare_dependency(version: '4.0.0', include_directories: include_directories(mpfr_hdr_dir), dependencies: the_compiler.find_library('mpfr', dirs: mpfr_lib_dir)) -llvm_dep = declare_dependency(version: '6.0.0', include_directories: include_directories(llvm_hdr_dir), +llvm_dep = declare_dependency(version: '7.0.1', include_directories: include_directories(llvm_hdr_dir), dependencies: [ mpfr_dep, mpir_dep, # the_compiler.find_library('LLVM_all', dirs: llvm_lib_dir), the_compiler.find_library('LLVMMC', dirs: llvm_lib_dir), @@ -236,7 +235,7 @@ llvm_dep = declare_dependency(version: '6.0.0', include_directories: include_dir the_compiler.find_library('LLVMX86CodeGen', dirs: llvm_lib_dir), the_compiler.find_library('LLVMProfileData', dirs: llvm_lib_dir), the_compiler.find_library('LLVMInstCombine', dirs: llvm_lib_dir), - the_compiler.find_library('LLVMInterpreter', dirs: llvm_lib_dir), + # the_compiler.find_library('LLVMInterpreter', dirs: llvm_lib_dir), the_compiler.find_library('LLVMRuntimeDyld', dirs: llvm_lib_dir), # the_compiler.find_library('LLVMDebugInfoPDB', dirs: llvm_lib_dir), # the_compiler.find_library('LLVMDebugInfoMSF', dirs: llvm_lib_dir), diff --git a/source/backend/llvm/jit.cpp b/source/backend/llvm/jit.cpp index e0854982..9c060dbf 100644 --- a/source/backend/llvm/jit.cpp +++ b/source/backend/llvm/jit.cpp @@ -7,11 +7,22 @@ namespace backend { LLVMJit::LLVMJit(llvm::TargetMachine* tm) : - objectLayer([]() -> auto { return std::make_shared(); }), - compileLayer(this->objectLayer, llvm::orc::SimpleCompiler(*tm)) - { - this->targetMachine = std::unique_ptr(tm); + targetMachine(tm), + symbolResolver(llvm::orc::createLegacyLookupResolver(this->execSession, [&](const std::string& name) -> llvm::JITSymbol { + if(auto sym = this->compileLayer.findSymbol(name, false)) return sym; + else if(auto err = sym.takeError()) return std::move(err); + if(auto symaddr = llvm::RTDyldMemoryManager::getSymbolAddressInProcess(name)) + return llvm::JITSymbol(symaddr, llvm::JITSymbolFlags::Exported); + else + return llvm::JITSymbol(nullptr); + }, [](llvm::Error err) { llvm::cantFail(std::move(err), "lookupFlags failed"); })), + dataLayout(this->targetMachine->createDataLayout()), + objectLayer(this->execSession, [this](llvm::orc::VModuleKey) -> auto { + return llvm::orc::RTDyldObjectLinkingLayer::Resources { + std::make_shared(), this->symbolResolver }; }), + compileLayer(this->objectLayer, llvm::orc::SimpleCompiler(*this->targetMachine.get())) + { llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr); } @@ -20,23 +31,12 @@ namespace backend return this->targetMachine.get(); } - LLVMJit::ModuleHandle_t LLVMJit::addModule(std::shared_ptr mod) + LLVMJit::ModuleHandle_t LLVMJit::addModule(std::unique_ptr mod) { - auto resolver = llvm::orc::createLambdaResolver([&](const std::string& name) -> auto { - if(auto sym = this->compileLayer.findSymbol(name, false)) - return sym; - - else - return llvm::JITSymbol(nullptr); - }, [](const std::string& name) -> auto { - if(auto symaddr = llvm::RTDyldMemoryManager::getSymbolAddressInProcess(name)) - return llvm::JITSymbol(symaddr, llvm::JITSymbolFlags::Exported); - - else - return llvm::JITSymbol(nullptr); - }); + auto vmod = this->execSession.allocateVModule(); + llvm::cantFail(this->compileLayer.addModule(vmod, std::move(mod))); - return llvm::cantFail(this->compileLayer.addModule(mod, std::move(resolver))); + return vmod; } void LLVMJit::removeModule(LLVMJit::ModuleHandle_t mod) @@ -48,7 +48,7 @@ namespace backend { std::string mangledName; llvm::raw_string_ostream out(mangledName); - llvm::Mangler::getNameWithPrefix(out, name, this->targetMachine->createDataLayout()); + llvm::Mangler::getNameWithPrefix(out, name, this->dataLayout); return this->compileLayer.findSymbol(out.str(), false); } diff --git a/source/backend/llvm/linker.cpp b/source/backend/llvm/linker.cpp index 22b0d75a..854802da 100644 --- a/source/backend/llvm/linker.cpp +++ b/source/backend/llvm/linker.cpp @@ -26,6 +26,7 @@ #include "llvm/Transforms/IPO.h" #include "llvm/IR/LLVMContext.h" #include "llvm/Analysis/Passes.h" +#include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Support/TargetSelect.h" @@ -37,6 +38,7 @@ #include "llvm/Support/TargetRegistry.h" #include "llvm/Support/ToolOutputFile.h" #include "llvm/Support/DynamicLibrary.h" +#include "llvm/Transforms/InstCombine/InstCombine.h" #include "llvm/ExecutionEngine/ExecutionEngine.h" #ifdef _MSC_VER @@ -103,7 +105,7 @@ namespace backend this->entryFunction = mainModule->getFunction(this->compiledData.module->getEntryFunction()->getName().mangled()); - this->linkedModule = std::shared_ptr(mainModule); + this->linkedModule = std::unique_ptr(mainModule); this->finaliseGlobalConstructors(); // ok, move some shit into here because llvm is fucking retarded @@ -160,7 +162,6 @@ namespace backend fpm.add(llvm::createFlattenCFGPass()); fpm.add(llvm::createScalarizerPass()); fpm.add(llvm::createSinkingPass()); - fpm.add(llvm::createInstructionSimplifierPass()); fpm.add(llvm::createDeadStoreEliminationPass()); fpm.add(llvm::createMemCpyOptPass()); @@ -221,7 +222,7 @@ namespace backend llvm::sys::fs::OpenFlags of = (llvm::sys::fs::OpenFlags) 0; llvm::raw_fd_ostream rso(oname.c_str(), e, of); - llvm::WriteBitcodeToFile(this->linkedModule.get(), rso); + llvm::WriteBitcodeToFile(*this->linkedModule.get(), rso); rso.close(); _printTiming(ts, "writing bitcode file"); @@ -441,7 +442,7 @@ namespace backend { // auto p = prof::Profile(PROFGROUP_LLVM, "llvm_emit_object"); llvm::legacy::PassManager pm = llvm::legacy::PassManager(); - targetMachine->addPassesToEmitFile(pm, *rawStream, llvm::TargetMachine::CodeGenFileType::CGFT_ObjectFile); + targetMachine->addPassesToEmitFile(pm, *rawStream, rawStream, llvm::TargetMachine::CodeGenFileType::CGFT_ObjectFile); pm.run(*this->linkedModule); } @@ -583,11 +584,11 @@ namespace backend if(this->entryFunction) { #if 1 + auto name = this->entryFunction->getName().str(); this->jitInstance = new LLVMJit(this->targetMachine); - this->jitInstance->addModule(this->linkedModule); + this->jitInstance->addModule(std::move(this->linkedModule)); - auto name = this->entryFunction->getName().str(); auto entryaddr = this->jitInstance->getSymbolAddress(name); ret = (int (*)(int, const char**)) entryaddr; diff --git a/source/backend/llvm/translator.cpp b/source/backend/llvm/translator.cpp index ba757cda..30835ddb 100644 --- a/source/backend/llvm/translator.cpp +++ b/source/backend/llvm/translator.cpp @@ -309,6 +309,27 @@ namespace backend return createdTypes[ut->getTypeName()]; } + else if(type->isRawUnionType()) + { + auto ut = type->toRawUnionType(); + + if(createdTypes.find(ut->getTypeName()) != createdTypes.end()) + return createdTypes[ut->getTypeName()]; + + auto dl = llvm::DataLayout(mod); + + size_t maxSz = 0; + for(auto v : ut->getVariants()) + maxSz = std::max(maxSz, (size_t) dl.getTypeAllocSize(typeToLlvm(v.second, mod))); + + iceAssert(maxSz > 0); + createdTypes[ut->getTypeName()] = llvm::StructType::create(gc, { + // llvm::ArrayType::get(llvm::Type::getInt8Ty(gc), maxSz) + llvm::IntegerType::getIntNTy(gc, maxSz * CHAR_BIT) + }, ut->getTypeName().mangled()); + + return createdTypes[ut->getTypeName()]; + } else if(type->isPolyPlaceholderType()) { error("llvm: Unfulfilled polymorphic placeholder type '%s'", type); @@ -688,7 +709,6 @@ namespace backend for(auto type : firmod->_getNamedTypes()) { // should just automatically create it. - // if(!isGenericInAnyWay(type.second)) typeToLlvm(type.second, module); } @@ -696,30 +716,32 @@ namespace backend { llvm::Constant* fn = 0; + //* in LLVM 7, the intrinsics changed to no longer specify the alignment + //* so, the arugments are: [ ptr, ptr, size, is_volatile ] if(intr.first.str() == "memcpy") { llvm::FunctionType* ft = llvm::FunctionType::get(llvm::Type::getVoidTy(gc), { llvm::Type::getInt8PtrTy(gc), - llvm::Type::getInt8PtrTy(gc), llvm::Type::getInt64Ty(gc), llvm::Type::getInt32Ty(gc), llvm::Type::getInt1Ty(gc) }, false); + llvm::Type::getInt8PtrTy(gc), llvm::Type::getInt64Ty(gc), llvm::Type::getInt1Ty(gc) }, false); fn = module->getOrInsertFunction("llvm.memcpy.p0i8.p0i8.i64", ft); } else if(intr.first.str() == "memmove") { llvm::FunctionType* ft = llvm::FunctionType::get(llvm::Type::getVoidTy(gc), { llvm::Type::getInt8PtrTy(gc), - llvm::Type::getInt8PtrTy(gc), llvm::Type::getInt64Ty(gc), llvm::Type::getInt32Ty(gc), llvm::Type::getInt1Ty(gc) }, false); + llvm::Type::getInt8PtrTy(gc), llvm::Type::getInt64Ty(gc), llvm::Type::getInt1Ty(gc) }, false); fn = module->getOrInsertFunction("llvm.memmove.p0i8.p0i8.i64", ft); } else if(intr.first.str() == "memset") { llvm::FunctionType* ft = llvm::FunctionType::get(llvm::Type::getVoidTy(gc), { llvm::Type::getInt8PtrTy(gc), - llvm::Type::getInt8Ty(gc), llvm::Type::getInt64Ty(gc), llvm::Type::getInt32Ty(gc), llvm::Type::getInt1Ty(gc) }, false); + llvm::Type::getInt8Ty(gc), llvm::Type::getInt64Ty(gc), llvm::Type::getInt1Ty(gc) }, false); fn = module->getOrInsertFunction("llvm.memset.p0i8.i64", ft); } else if(intr.first.str() == "memcmp") { - // in line with the rest, take 5 arguments, the last 2 being alignment and isvolatile. + // in line with the rest, take 4 arguments. (this is our own "intrinsic") llvm::FunctionType* ft = llvm::FunctionType::get(llvm::Type::getInt32Ty(gc), { llvm::Type::getInt8PtrTy(gc), - llvm::Type::getInt8PtrTy(gc), llvm::Type::getInt64Ty(gc), llvm::Type::getInt32Ty(gc), llvm::Type::getInt1Ty(gc) }, false); + llvm::Type::getInt8PtrTy(gc), llvm::Type::getInt64Ty(gc), llvm::Type::getInt1Ty(gc) }, false); fn = llvm::Function::Create(ft, llvm::GlobalValue::LinkageTypes::InternalLinkage, "fir.intrinsic.memcmp", module); llvm::Function* func = llvm::cast(fn); @@ -2307,6 +2329,19 @@ namespace backend break; } + case fir::OpKind::RawUnion_GEP: + { + iceAssert(inst->operands.size() == 2); + llvm::Value* a = getUndecayedOperand(inst, 0); + llvm::Type* target = typeToLlvm(inst->operands[1]->getType(), module); + + iceAssert(a->getType()->isPointerTy() && a->getType()->getPointerElementType()->isStructTy()); + auto ptr = builder.CreateConstGEP2_32(a->getType()->getPointerElementType(), a, 0, 0); + ptr = builder.CreatePointerCast(ptr, target->getPointerTo()); + + addValueToMap(ptr, inst->realOutput); + break; + } diff --git a/source/codegen/alloc.cpp b/source/codegen/alloc.cpp index b18afd58..ad9492cf 100644 --- a/source/codegen/alloc.cpp +++ b/source/codegen/alloc.cpp @@ -98,7 +98,7 @@ static fir::Value* performAllocation(cgn::CodegenState* cs, sst::AllocOp* alloc, auto value = cs->getDefaultValue(type); - if(cs->isRefCountedType(type)) + if(fir::isRefCountedType(type)) cs->addRefCountedValue(value); cs->autoAssignRefCountedValue(ptr, value, true, true); diff --git a/source/codegen/autocasting.cpp b/source/codegen/autocasting.cpp index 74c2fbe8..e0a05707 100644 --- a/source/codegen/autocasting.cpp +++ b/source/codegen/autocasting.cpp @@ -29,14 +29,14 @@ namespace cgn } else { - if(ty->getMinBits() <= fir::Type::getInt64()->getBitWidth() - 1) - return fir::ConstantInt::getInt64(cn->getInt64()); + if(ty->getMinBits() <= fir::Type::getInt64()->getBitWidth() - 1) + return fir::ConstantInt::getInt64(cn->getInt64()); - else if(ty->isSigned() && ty->getMinBits() <= fir::Type::getUint64()->getBitWidth()) - return fir::ConstantInt::getUint64(cn->getUint64()); + else if(ty->isSigned() && ty->getMinBits() <= fir::Type::getUint64()->getBitWidth()) + return fir::ConstantInt::getUint64(cn->getUint64()); - else - error("int overflow"); + else + error("int overflow"); } } @@ -65,7 +65,7 @@ namespace cgn if(target->toPrimitiveType()->getBitWidth() < ty->getMinBits()) { // TODO: actually do what we say. - warn(cs->loc(), "Casting literal to type '%s' will cause an overflow; resulting value will be the limit of the casted type", + warn(cs->loc(), "Casting literal to type '%s' will cause an overflow; value will be truncated bitwise to fit", target); } @@ -240,7 +240,7 @@ namespace cgn } else { - if(this->isRefCountedType(result->getType())) + if(fir::isRefCountedType(result->getType())) this->addRefCountedValue(result); return result; diff --git a/source/codegen/builtin.cpp b/source/codegen/builtin.cpp index 840ec7d3..cefae85c 100644 --- a/source/codegen/builtin.cpp +++ b/source/codegen/builtin.cpp @@ -52,7 +52,7 @@ CGResult sst::BuiltinDotOp::_codegen(cgn::CodegenState* cs, fir::Type* infer) auto ret = cs->irb.Call(clonef, cs->irb.CreateSliceFromSAA(res.value, false), fir::ConstantInt::getInt64(0)); - iceAssert(cs->isRefCountedType(ret->getType())); + iceAssert(fir::isRefCountedType(ret->getType())); cs->addRefCountedValue(ret); return CGResult(ret); diff --git a/source/codegen/call.cpp b/source/codegen/call.cpp index 7440bd1d..9b96b4bf 100644 --- a/source/codegen/call.cpp +++ b/source/codegen/call.cpp @@ -67,9 +67,10 @@ static std::vector _codegenAndArrangeFunctionCallArguments(cgn::Cod auto vr = arg.value->codegen(cs, inf); auto val = vr.value; - // ! ACHTUNG ! - // TODO: is this actually necessary, or will we end up leaking memory??? - if(cs->isRefCountedType(val->getType())) + //* arguments are added to the refcounting list in the function, + //* so we need to "pre-increment" the refcount here, so it does not + //* get freed when the function returns. + if(fir::isRefCountedType(val->getType())) cs->incrementRefCount(val); @@ -262,7 +263,7 @@ CGResult sst::FunctionCall::_codegen(cgn::CodegenState* cs, fir::Type* infer) } // do the refcounting if we need to - if(cs->isRefCountedType(ret->getType())) + if(fir::isRefCountedType(ret->getType())) cs->addRefCountedValue(ret); return CGResult(ret); @@ -342,10 +343,13 @@ CGResult sst::ExprCall::_codegen(cgn::CodegenState* cs, fir::Type* infer) auto ft = fn->getType()->toFunctionType(); - if(ft->getArgumentTypes().size() != this->arguments.size() && !ft->isVariadicFunc()) + if(ft->getArgumentTypes().size() != this->arguments.size()) { - error(this, "Mismatched number of arguments; expected %zu, but %zu were given", - ft->getArgumentTypes().size(), this->arguments.size()); + if((!ft->isVariadicFunc() && !ft->isCStyleVarArg()) || this->arguments.size() < ft->getArgumentTypes().size()) + { + error(this, "Mismatched number of arguments; expected %zu, but %zu were given", + ft->getArgumentTypes().size(), this->arguments.size()); + } } std::vector fcas = util::map(this->arguments, [](sst::Expr* arg) -> FnCallArgument { diff --git a/source/codegen/constructor.cpp b/source/codegen/constructor.cpp index 1b98de36..af424e8a 100644 --- a/source/codegen/constructor.cpp +++ b/source/codegen/constructor.cpp @@ -39,7 +39,7 @@ fir::Value* cgn::CodegenState::getConstructedStructValue(fir::StructType* str, c // if(names) iceAssert(i == str->getElementCount()); } - if(this->isRefCountedType(str)) + if(fir::isRefCountedType(str)) this->addRefCountedValue(value); return value; @@ -128,7 +128,7 @@ CGResult sst::ClassConstructorCall::_codegen(cgn::CodegenState* cs, fir::Type* i cs->constructClassWithArguments(cls->toClassType(), this->target, cs->irb.AddressOf(self, true), this->arguments, true); // auto value = cs->irb.Dereference(self); - if(cs->isRefCountedType(cls)) + if(fir::isRefCountedType(cls)) cs->addRefCountedValue(self); return CGResult(self); diff --git a/source/codegen/controlflow.cpp b/source/codegen/controlflow.cpp index 9e69a6bb..30941d00 100644 --- a/source/codegen/controlflow.cpp +++ b/source/codegen/controlflow.cpp @@ -238,7 +238,7 @@ CGResult sst::ReturnStmt::_codegen(cgn::CodegenState* cs, fir::Type* infer) if(this->value) { auto v = this->value->codegen(cs, this->expectedType).value; - if(cs->isRefCountedType(v->getType())) + if(fir::isRefCountedType(v->getType())) cs->incrementRefCount(v); doBlockEndThings(cs, cs->getCurrentCFPoint(), cs->getCurrentBlockPoint()); diff --git a/source/codegen/destructure.cpp b/source/codegen/destructure.cpp index a265927d..0b84d2a7 100644 --- a/source/codegen/destructure.cpp +++ b/source/codegen/destructure.cpp @@ -32,7 +32,7 @@ static void handleDefn(cgn::CodegenState* cs, sst::VarDefn* defn, CGResult res) //* also, since the vardefn adds itself to the counting stack, when it dies we will get decremented. //* however, this cannot be allowed to happen, because we want a copy and not a move. - if(cs->isRefCountedType(res->getType())) + if(fir::isRefCountedType(res->getType())) { cs->addRefCountedValue(res.value); cs->incrementRefCount(res.value); diff --git a/source/codegen/dotop.cpp b/source/codegen/dotop.cpp index 8686a28d..d4422e1b 100644 --- a/source/codegen/dotop.cpp +++ b/source/codegen/dotop.cpp @@ -7,6 +7,10 @@ #include "typecheck.h" +static bool isAutoDereferencable(fir::Type* t) +{ + return (t->isStructType() || t->isClassType() || t->isRawUnionType()); +} static CGResult getAppropriateValuePointer(cgn::CodegenState* cs, sst::Expr* user, sst::Expr* lhs, fir::Type** baseType) { @@ -15,10 +19,10 @@ static CGResult getAppropriateValuePointer(cgn::CodegenState* cs, sst::Expr* use fir::Value* retv = 0; - if(restype->isStructType() || restype->isClassType()) + if(isAutoDereferencable(restype)) { auto t = res.value->getType(); - iceAssert(t->isStructType() || t->isClassType()); + iceAssert(isAutoDereferencable(t)); retv = res.value; *baseType = restype; @@ -28,9 +32,9 @@ static CGResult getAppropriateValuePointer(cgn::CodegenState* cs, sst::Expr* use retv = res.value; *baseType = restype; } - else if(restype->isPointerType() && (restype->getPointerElementType()->isStructType() || restype->getPointerElementType()->isClassType())) + else if(restype->isPointerType() && isAutoDereferencable(restype->getPointerElementType())) { - iceAssert(res.value->getType()->getPointerElementType()->isStructType() || res.value->getType()->getPointerElementType()->isClassType()); + iceAssert(isAutoDereferencable(res.value->getType()->getPointerElementType())); retv = cs->irb.Dereference(res.value); *baseType = restype->getPointerElementType(); @@ -57,8 +61,6 @@ CGResult sst::MethodDotOp::_codegen(cgn::CodegenState* cs, fir::Type* infer) // basically what we need to do is just get the pointer fir::Type* sty = 0; auto res = getAppropriateValuePointer(cs, this, this->lhs, &sty); - // if(!res.pointer) - // res.pointer = cs->irb.ImmutStackAlloc(sty, res.value); // then we insert it as the first argument auto rv = new sst::RawValueExpr(this->loc, res.value->getType()->getMutablePointerTo()); @@ -90,15 +92,70 @@ CGResult sst::FieldDotOp::_codegen(cgn::CodegenState* cs, fir::Type* infer) fir::Type* sty = 0; auto res = getAppropriateValuePointer(cs, this, this->lhs, &sty); - if(!res->islorclvalue()) + + // TODO: clean up the code dupe here + if(this->isTransparentField) { - // use extractvalue. - return CGResult(cs->irb.ExtractValueByName(res.value, this->rhsIdent)); + iceAssert(this->lhs->type->isRawUnionType() || this->lhs->type->isStructType()); + if(this->lhs->type->isRawUnionType()) + { + fir::Value* field = 0; + if(res->islorclvalue()) + { + field = cs->irb.GetRawUnionFieldByType(res.value, this->type); + } + else + { + auto addr = cs->irb.ImmutStackAlloc(this->lhs->type, res.value); + field = cs->irb.GetRawUnionFieldByType(addr, this->type); + } + + return CGResult(field); + } + else + { + if(res->islorclvalue()) + { + // ok, at this point it's just a normal, instance field. + return CGResult(cs->irb.StructGEP(res.value, this->indexOfTransparentField)); + } + else + { + // use extractvalue. + return CGResult(cs->irb.ExtractValue(res.value, { this->indexOfTransparentField })); + } + } } else { - // ok, at this point it's just a normal, instance field. - return CGResult(cs->irb.GetStructMember(res.value, this->rhsIdent)); + if(this->lhs->type->isRawUnionType()) + { + fir::Value* field = 0; + if(res->islorclvalue()) + { + field = cs->irb.GetRawUnionField(res.value, this->rhsIdent); + } + else + { + auto addr = cs->irb.ImmutStackAlloc(this->lhs->type, res.value); + field = cs->irb.GetRawUnionField(addr, this->rhsIdent); + } + + return CGResult(field); + } + else + { + if(res->islorclvalue()) + { + // ok, at this point it's just a normal, instance field. + return CGResult(cs->irb.GetStructMember(res.value, this->rhsIdent)); + } + else + { + // use extractvalue. + return CGResult(cs->irb.ExtractValueByName(res.value, this->rhsIdent)); + } + } } } diff --git a/source/codegen/function.cpp b/source/codegen/function.cpp index 45b6b630..84da8765 100644 --- a/source/codegen/function.cpp +++ b/source/codegen/function.cpp @@ -142,7 +142,7 @@ CGResult sst::ArgumentDefn::_codegen(cgn::CodegenState* cs, fir::Type* infer) auto fn = cs->getCurrentFunction(); auto arg = cs->irb.CreateConstLValue(fn->getArgumentWithName(this->id.name), this->id.name); - if(cs->isRefCountedType(arg->getType())) + if(fir::isRefCountedType(arg->getType())) cs->addRefCountedValue(arg); // ok... diff --git a/source/codegen/glue/arrays.cpp b/source/codegen/glue/arrays.cpp index cfcb146f..7468fcee 100644 --- a/source/codegen/glue/arrays.cpp +++ b/source/codegen/glue/arrays.cpp @@ -563,7 +563,7 @@ namespace array iceAssert(freefn); // only when we free, do we loop through our array and decrement its refcount. - if(cs->isRefCountedType(elmtype)) + if(fir::isRefCountedType(elmtype)) { auto ctrp = cs->irb.StackAlloc(fir::Type::getInt64()); cs->irb.WritePtr(zv, ctrp); diff --git a/source/codegen/glue/saa_common.cpp b/source/codegen/glue/saa_common.cpp index b496025e..70a8e1fb 100644 --- a/source/codegen/glue/saa_common.cpp +++ b/source/codegen/glue/saa_common.cpp @@ -30,7 +30,7 @@ namespace saa_common static fir::Function* generateIncrementArrayRefCountInLoopFunction(CodegenState* cs, fir::Type* elm) { - iceAssert(cs->isRefCountedType(elm)); + iceAssert(fir::isRefCountedType(elm)); auto fname = "__loop_incr_rc_" + elm->str(); fir::Function* retfn = cs->module->getFunction(Identifier(fname, IdKind::Name)); @@ -178,7 +178,7 @@ namespace saa_common fir::Function* memcpyf = cs->module->getIntrinsicFunction("memmove"); cs->irb.Call(memcpyf, { newptr, cs->irb.PointerTypeCast(cs->irb.PointerAdd(oldptr, - startIndex), fir::Type::getMutInt8Ptr()), bytecount, fir::ConstantInt::getInt32(0), fir::ConstantBool::get(false) }); + startIndex), fir::Type::getMutInt8Ptr()), bytecount, fir::ConstantBool::get(false) }); #if DEBUG_ARRAY_ALLOCATION | DEBUG_STRING_ALLOCATION { @@ -224,7 +224,7 @@ namespace saa_common fir::Function* memcpyf = cs->module->getIntrinsicFunction("memmove"); cs->irb.Call(memcpyf, { newptr, cs->irb.PointerTypeCast(cs->irb.PointerAdd(oldptr, - startIndex), fir::Type::getMutInt8Ptr()), bytecount, fir::ConstantInt::getInt32(0), fir::ConstantBool::get(false) }); + startIndex), fir::Type::getMutInt8Ptr()), bytecount, fir::ConstantBool::get(false) }); } else { @@ -382,7 +382,7 @@ namespace saa_common cs->irb.Call(memcpyf, { cs->irb.PointerTypeCast(cs->irb.PointerAdd(lhsbuf, lhslen), fir::Type::getMutInt8Ptr()), cs->irb.PointerTypeCast(rhsbuf, fir::Type::getMutInt8Ptr()), rhsbytecount, - fir::ConstantInt::getInt32(0), fir::ConstantBool::get(false) + fir::ConstantBool::get(false) }); // null terminator @@ -396,7 +396,7 @@ namespace saa_common lhs = cs->irb.SetSAALength(lhs, cs->irb.Add(lhslen, rhslen)); // handle refcounting - if(cs->isRefCountedType(getSAAElm(saa))) + if(fir::isRefCountedType(getSAAElm(saa))) { auto incrfn = generateIncrementArrayRefCountInLoopFunction(cs, getSAAElm(saa)); iceAssert(incrfn); @@ -521,11 +521,11 @@ namespace saa_common auto rhsbytecount = cs->irb.Multiply(rhslen, cs->irb.Sizeof(getSAAElm(saa)), "rhsbytecount"); cs->irb.Call(memcpyf, { rawbuf, rawlhsbuf, - lhsbytecount, fir::ConstantInt::getInt32(0), fir::ConstantBool::get(false) + lhsbytecount, fir::ConstantBool::get(false) }); cs->irb.Call(memcpyf, { cs->irb.PointerAdd(rawbuf, lhsbytecount), rawrhsbuf, - rhsbytecount, fir::ConstantInt::getInt32(0), fir::ConstantBool::get(false) + rhsbytecount, fir::ConstantBool::get(false) }); // if it's a string, again, null terminator. @@ -533,7 +533,7 @@ namespace saa_common { cs->irb.WritePtr(fir::ConstantInt::getInt8(0), cs->irb.PointerAdd(rawbuf, cs->irb.Add(lhsbytecount, rhsbytecount))); } - else if(cs->isRefCountedType(getSAAElm(saa))) + else if(fir::isRefCountedType(getSAAElm(saa))) { auto incrfn = generateIncrementArrayRefCountInLoopFunction(cs, getSAAElm(saa)); iceAssert(incrfn); @@ -670,7 +670,6 @@ namespace saa_common cs->irb.setCurrentBlock(doExpansion); { - // TODO: is it faster to times 3 divide by 2, or do FP casts and times 1.5? auto newlen = cs->irb.Divide(cs->irb.Multiply(minsz, getCI(3)), getCI(2), "mul1.5"); // call realloc. handles the null case as well, which is nice. diff --git a/source/codegen/literals.cpp b/source/codegen/literals.cpp index d7920edc..5d54cdf6 100644 --- a/source/codegen/literals.cpp +++ b/source/codegen/literals.cpp @@ -36,7 +36,7 @@ CGResult sst::LiteralArray::_codegen(cgn::CodegenState* cs, fir::Type* infer) if(this->type->isArrayType()) { auto elmty = this->type->toArrayType()->getElementType(); - if(cs->isRefCountedType(elmty)) + if(fir::isRefCountedType(elmty)) error(this, "Cannot have refcounted type in array literal"); std::vector vals; diff --git a/source/codegen/refcounting.cpp b/source/codegen/refcounting.cpp index 6be6c57b..cb6bba89 100644 --- a/source/codegen/refcounting.cpp +++ b/source/codegen/refcounting.cpp @@ -128,7 +128,7 @@ namespace cgn // warn(this->loc(), "hi (%d)", rhs->islorclvalue()); - if(this->isRefCountedType(rhs->getType())) + if(fir::isRefCountedType(rhs->getType())) { if(rhs->islorclvalue()) this->performRefCountingAssignment(lhs, rhs, isinit); @@ -167,12 +167,10 @@ namespace cgn template void doRefCountOfAggregateType(CodegenState* cs, T* type, fir::Value* value, bool incr) { - // iceAssert(cgi->isRefCountedType(type)); - size_t i = 0; for(auto m : type->getElements()) { - if(cs->isRefCountedType(m)) + if(fir::isRefCountedType(m)) { fir::Value* mem = cs->irb.ExtractValue(value, { i }); @@ -223,16 +221,6 @@ namespace cgn } else if(type->isArrayType()) { - // fir::ArrayType* at = type->toArrayType(); - // for(size_t i = 0; i < at->getArraySize(); i++) - // { - // fir::Value* elm = cs->irb.ExtractValue(type, { i }); - // iceAssert(cs->isRefCountedType(elm->getType())); - - // if(incr) cs->incrementRefCount(elm); - // else cs->decrementRefCount(elm); - // } - error("no array"); } else if(type->isAnyType()) @@ -253,62 +241,15 @@ namespace cgn void CodegenState::incrementRefCount(fir::Value* val) { - iceAssert(this->isRefCountedType(val->getType())); + iceAssert(fir::isRefCountedType(val->getType())); _doRefCount(this, val, true); } void CodegenState::decrementRefCount(fir::Value* val) { - iceAssert(this->isRefCountedType(val->getType())); + iceAssert(fir::isRefCountedType(val->getType())); _doRefCount(this, val, false); } - - - - - - bool CodegenState::isRefCountedType(fir::Type* type) - { - // strings, and structs with rc inside - if(type->isStructType()) - { - for(auto m : type->toStructType()->getElements()) - { - if(this->isRefCountedType(m)) - return true; - } - - return false; - } - else if(type->isClassType()) - { - for(auto m : type->toClassType()->getElements()) - { - if(this->isRefCountedType(m)) - return true; - } - - return false; - } - else if(type->isTupleType()) - { - for(auto m : type->toTupleType()->getElements()) - { - if(this->isRefCountedType(m)) - return true; - } - - return false; - } - else if(type->isArrayType()) // note: no slices, because slices don't own memory - { - return this->isRefCountedType(type->getArrayElementType()); - } - else - { - return type->isStringType() || type->isAnyType() || type->isDynamicArrayType(); - } - } } diff --git a/source/codegen/unions.cpp b/source/codegen/unions.cpp index 0f29ffe4..c84d7fda 100644 --- a/source/codegen/unions.cpp +++ b/source/codegen/unions.cpp @@ -7,9 +7,6 @@ CGResult sst::UnionDefn::_codegen(cgn::CodegenState* cs, fir::Type* infer) { - cs->pushLoc(this); - defer(cs->popLoc()); - // there's actually nothing to do. // nothing at all. @@ -18,9 +15,12 @@ CGResult sst::UnionDefn::_codegen(cgn::CodegenState* cs, fir::Type* infer) CGResult sst::UnionVariantDefn::_codegen(cgn::CodegenState* cs, fir::Type* infer) { - cs->pushLoc(this); - defer(cs->popLoc()); + return CGResult(0); +} +CGResult sst::RawUnionDefn::_codegen(cgn::CodegenState* cs, fir::Type* infer) +{ + // again, does nothing. return CGResult(0); } diff --git a/source/codegen/variable.cpp b/source/codegen/variable.cpp index 9e95b70a..a0510c86 100644 --- a/source/codegen/variable.cpp +++ b/source/codegen/variable.cpp @@ -32,8 +32,6 @@ CGResult sst::VarDefn::_codegen(cgn::CodegenState* cs, fir::Type* infer) if(auto it = cs->typeDefnMap.find(this->type); it != cs->typeDefnMap.end()) it->second->codegen(cs); - // bool refcounted = cs->isRefCountedType(this->type); - if(this->global) { auto rest = cs->enterGlobalInitFunction(); @@ -111,7 +109,7 @@ void cgn::CodegenState::addVariableUsingStorage(sst::VarDefn* var, fir::Value* a if(val.value) this->autoAssignRefCountedValue(alloc, val.value, /* isInitial: */ true, /* performStore: */ !var->immutable); - if(this->isRefCountedType(var->type)) + if(fir::isRefCountedType(var->type)) this->addRefCountedValue(alloc); } diff --git a/source/fir/IRBuilder.cpp b/source/fir/IRBuilder.cpp index 5d698b46..cd076681 100644 --- a/source/fir/IRBuilder.cpp +++ b/source/fir/IRBuilder.cpp @@ -1265,6 +1265,37 @@ namespace fir } + Value* IRBuilder::GetRawUnionFieldByType(Value* lval, Type* type, const std::string& vname) + { + if(!lval->islorclvalue()) + error("cannot do raw union ops on non-lvalue"); + + if(!lval->getType()->isRawUnionType()) + error("'%s' is not a raw union type!", lval->getType()); + + Instruction* instr = make_instr(OpKind::RawUnion_GEP, false, this->currentBlock, type, { lval, ConstantValue::getZeroValue(type) }); + + auto ret = this->addInstruction(instr, ""); + ret->setKind(lval->kind); + + return ret; + } + + Value* IRBuilder::GetRawUnionField(Value* lval, const std::string& field, const std::string& vname) + { + if(!lval->islorclvalue()) + error("cannot do raw union ops on non-lvalue"); + + if(!lval->getType()->isRawUnionType()) + error("'%s' is not a raw union type!", lval->getType()); + + auto rut = lval->getType()->toRawUnionType(); + if(!rut->hasVariant(field)) + error("union '%s' does not have a field '%s'", rut->getTypeName(), field); + + auto ty = rut->getVariant(field); + return this->GetRawUnionFieldByType(lval, ty, vname); + } @@ -1290,17 +1321,16 @@ namespace fir if(!structPtr->islorclvalue()) error("cannot do GEP on non-lvalue"); - if(StructType* st = dcast(StructType, structPtr->getType())) - { - return this->addInstruction(doGEPOnCompoundType(this->currentBlock, st, structPtr, memberIndex), vname); - } - if(ClassType* st = dcast(ClassType, structPtr->getType())) + //* note: we do not allow raw gep (by index) into classes, because V T A B L E + if(structPtr->getType()->isStructType()) { - return this->addInstruction(doGEPOnCompoundType(this->currentBlock, st, structPtr, memberIndex), vname); + return this->addInstruction(doGEPOnCompoundType(this->currentBlock, structPtr->getType()->toStructType(), + structPtr, memberIndex), vname); } - else if(TupleType* tt = dcast(TupleType, structPtr->getType())) + else if(structPtr->getType()->isTupleType()) { - return this->addInstruction(doGEPOnCompoundType(this->currentBlock, tt, structPtr, memberIndex), vname); + return this->addInstruction(doGEPOnCompoundType(this->currentBlock, structPtr->getType()->toTupleType(), + structPtr, memberIndex), vname); } else { @@ -1308,7 +1338,7 @@ namespace fir } } - Value* IRBuilder::GetStructMember(Value* ptr, std::string memberName) + Value* IRBuilder::GetStructMember(Value* ptr, const std::string& memberName) { if(!ptr->islorclvalue()) error("cannot do GEP on non-lvalue"); @@ -1536,7 +1566,6 @@ namespace fir args.push_back(fir::ConstantInt::getInt64(id + ofs)); - // note: no sideeffects, since we return a new aggregate Instruction* instr = make_instr(OpKind::Value_ExtractValue, false, this->currentBlock, et, args); return this->addInstruction(instr, vname); } diff --git a/source/fir/Instruction.cpp b/source/fir/Instruction.cpp index 04eace87..f6ff2179 100644 --- a/source/fir/Instruction.cpp +++ b/source/fir/Instruction.cpp @@ -172,6 +172,8 @@ namespace fir case OpKind::Union_GetVariantID: instrname = "get_union.id"; break; case OpKind::Union_SetVariantID: instrname = "set_union.id"; break; + case OpKind::RawUnion_GEP: instrname = "raw_union_gep"; break; + case OpKind::Value_AddressOf: instrname = "addrof"; break; case OpKind::Value_Store: instrname = "store"; break; case OpKind::Value_Dereference: instrname = "dereferece"; break; diff --git a/source/fir/Module.cpp b/source/fir/Module.cpp index 48264e32..15211170 100644 --- a/source/fir/Module.cpp +++ b/source/fir/Module.cpp @@ -353,21 +353,21 @@ namespace fir { name = Identifier("memcpy", IdKind::Name); ft = FunctionType::get({ fir::Type::getMutInt8Ptr(), fir::Type::getInt8Ptr(), - fir::Type::getInt64(), fir::Type::getInt32(), fir::Type::getBool() }, + fir::Type::getInt64(), fir::Type::getBool() }, fir::Type::getVoid()); } else if(id == "memmove") { name = Identifier("memmove", IdKind::Name); ft = FunctionType::get({ fir::Type::getMutInt8Ptr(), fir::Type::getMutInt8Ptr(), - fir::Type::getInt64(), fir::Type::getInt32(), fir::Type::getBool() }, + fir::Type::getInt64(), fir::Type::getBool() }, fir::Type::getVoid()); } else if(id == "memset") { name = Identifier("memset", IdKind::Name); ft = FunctionType::get({ fir::Type::getMutInt8Ptr(), fir::Type::getInt8(), - fir::Type::getInt64(), fir::Type::getInt32(), fir::Type::getBool() }, + fir::Type::getInt64(), fir::Type::getBool() }, fir::Type::getVoid()); } else if(id == "memcmp") @@ -377,7 +377,7 @@ namespace fir name = Identifier("memcmp", IdKind::Name); ft = FunctionType::get({ fir::Type::getInt8Ptr(), fir::Type::getInt8Ptr(), - fir::Type::getInt64(), fir::Type::getInt32(), fir::Type::getBool() }, + fir::Type::getInt64(), fir::Type::getBool() }, fir::Type::getInt32()); } else if(id == "roundup_pow2") diff --git a/source/fir/Types/ClassType.cpp b/source/fir/Types/ClassType.cpp index 211e182e..4becc201 100644 --- a/source/fir/Types/ClassType.cpp +++ b/source/fir/Types/ClassType.cpp @@ -26,7 +26,7 @@ namespace fir const std::vector& methods, const std::vector& inits) { if(auto it = typeCache.find(name); it != typeCache.end()) - error("Class with name '%s' already exists", name.str()); + error("class with name '%s' already exists", name.str()); else return (typeCache[name] = new ClassType(name, members, methods, inits)); @@ -304,7 +304,7 @@ namespace fir } else { - error("No such method named '%s' matching signature '%s' in virtual method table of class '%s'", + error("no such method named '%s' matching signature '%s' in virtual method table of class '%s'", name, (Type*) ft, this->getTypeName().name); } } diff --git a/source/fir/Types/EnumType.cpp b/source/fir/Types/EnumType.cpp index ad9174a0..eea39f11 100644 --- a/source/fir/Types/EnumType.cpp +++ b/source/fir/Types/EnumType.cpp @@ -73,7 +73,7 @@ namespace fir EnumType* EnumType::get(const Identifier& name, Type* caseType) { if(auto it = typeCache.find(name); it != typeCache.end()) - error("Enum with name '%s' already exists", name.str()); + error("enum with name '%s' already exists", name.str()); else return (typeCache[name] = new EnumType(name, caseType)); diff --git a/source/fir/Types/RawUnionType.cpp b/source/fir/Types/RawUnionType.cpp new file mode 100644 index 00000000..59ea993f --- /dev/null +++ b/source/fir/Types/RawUnionType.cpp @@ -0,0 +1,107 @@ +// RawUnionType.cpp +// Copyright (c) 2019, zhiayang@gmail.com +// Licensed under the Apache License Version 2.0. + + +#include "errors.h" +#include "ir/type.h" + +#include "pts.h" + +namespace fir +{ + // structs + RawUnionType::RawUnionType(const Identifier& name, const util::hash_map& mems) + : Type(TypeKind::RawUnion) + { + this->unionName = name; + this->setBody(mems); + } + + static util::hash_map typeCache; + RawUnionType* RawUnionType::create(const Identifier& name, const util::hash_map& mems) + { + if(auto it = typeCache.find(name); it != typeCache.end()) + error("Union with name '%s' already exists", name.str()); + + else + return (typeCache[name] = new RawUnionType(name, mems)); + } + + RawUnionType* RawUnionType::createWithoutBody(const Identifier& name) + { + return RawUnionType::create(name, { }); + } + + + + + + + // various + std::string RawUnionType::str() + { + return "raw_union(" + this->unionName.name + ")"; + } + + std::string RawUnionType::encodedStr() + { + return this->unionName.str(); + } + + + bool RawUnionType::isTypeEqual(Type* other) + { + if(other->kind != TypeKind::Union) + return false; + + return (this->unionName == other->toRawUnionType()->unionName); + } + + + + // struct stuff + Identifier RawUnionType::getTypeName() + { + return this->unionName; + } + + size_t RawUnionType::getVariantCount() + { + return this->variants.size(); + } + + util::hash_map RawUnionType::getVariants() + { + return this->variants; + } + + bool RawUnionType::hasVariant(const std::string& name) + { + return this->variants.find(name) != this->variants.end(); + } + + Type* RawUnionType::getVariant(const std::string& name) + { + if(auto it = this->variants.find(name); it != this->variants.end()) + return it->second; + + else + error("no variant named '%s' in union '%s'", name, this->getTypeName().str()); + } + + + + void RawUnionType::setBody(const util::hash_map& members) + { + this->variants = members; + } + + fir::Type* RawUnionType::substitutePlaceholders(const util::hash_map& subst) + { + if(this->containsPlaceholders()) + error("not supported!"); + + return this; + } +} \ No newline at end of file diff --git a/source/fir/Types/StructType.cpp b/source/fir/Types/StructType.cpp index 4965ef9c..189444f7 100644 --- a/source/fir/Types/StructType.cpp +++ b/source/fir/Types/StructType.cpp @@ -23,7 +23,7 @@ namespace fir StructType* StructType::create(const Identifier& name, const std::vector>& members, bool packed) { if(auto it = typeCache.find(name); it != typeCache.end()) - error("Struct with name '%s' already exists", name.str()); + error("struct with name '%s' already exists", name.str()); else return (typeCache[name] = new StructType(name, members, packed)); diff --git a/source/fir/Types/Type.cpp b/source/fir/Types/Type.cpp index 2ff05514..f8c1cdf4 100644 --- a/source/fir/Types/Type.cpp +++ b/source/fir/Types/Type.cpp @@ -519,6 +519,12 @@ namespace fir return static_cast(this); } + RawUnionType* Type::toRawUnionType() + { + if(this->kind != TypeKind::RawUnion) error("not raw union type"); + return static_cast(this); + } + AnyType* Type::toAnyType() { if(this->kind != TypeKind::Any) error("not any type"); @@ -602,6 +608,11 @@ namespace fir return this->isIntegerType() && this->toPrimitiveType()->isSigned(); } + bool Type::isUnsignedIntType() + { + return this->isIntegerType() && !this->toPrimitiveType()->isSigned(); + } + bool Type::isFunctionType() { return this->kind == TypeKind::Function; @@ -662,6 +673,11 @@ namespace fir return this->kind == TypeKind::Union; } + bool Type::isRawUnionType() + { + return this->kind == TypeKind::RawUnion; + } + bool Type::isAnyType() { return this->kind == TypeKind::Any; @@ -1006,6 +1022,41 @@ namespace fir return getAggregateSize(tys); } + else if(type->isUnionType() ) + { + auto ut = type->toUnionType(); + + size_t maxSz = 0; + for(auto v : ut->getVariants()) + { + if(!v.second->getInteriorType()->isVoidType()) + maxSz = std::max(maxSz, getSizeOfType(v.second->getInteriorType())); + } + + if(maxSz > 0) + { + return getAggregateSize({ Type::getInt64(), ArrayType::get(Type::getInt8(), maxSz) }); + } + else + { + return getAggregateSize({ Type::getInt64() }); + } + } + else if(type->isRawUnionType()) + { + auto ut = type->toRawUnionType(); + + size_t maxSz = 0; + for(auto v : ut->getVariants()) + maxSz = std::max(maxSz, getSizeOfType(v.second)); + + iceAssert(maxSz > 0); + return getAggregateSize({ ArrayType::get(Type::getInt8(), maxSz) }); + } + else if(type->isUnionVariantType()) + { + return getSizeOfType(type->toUnionVariantType()->getInteriorType()); + } else { error("cannot get size of unsupported type '%s'", type); @@ -1017,6 +1068,50 @@ namespace fir if(type->isArrayType()) return getAlignmentOfType(type->getArrayElementType()); else return getSizeOfType(type); } + + + bool isRefCountedType(Type* type) + { + // strings, and structs with rc inside + if(type->isStructType()) + { + for(auto m : type->toStructType()->getElements()) + { + if(isRefCountedType(m)) + return true; + } + + return false; + } + else if(type->isClassType()) + { + for(auto m : type->toClassType()->getElements()) + { + if(isRefCountedType(m)) + return true; + } + + return false; + } + else if(type->isTupleType()) + { + for(auto m : type->toTupleType()->getElements()) + { + if(isRefCountedType(m)) + return true; + } + + return false; + } + else if(type->isArrayType()) // note: no slices, because slices don't own memory + { + return isRefCountedType(type->getArrayElementType()); + } + else + { + return type->isStringType() || type->isAnyType() || type->isDynamicArrayType(); + } + } } diff --git a/source/fir/Types/UnionType.cpp b/source/fir/Types/UnionType.cpp index 30c2ce6e..0e429539 100644 --- a/source/fir/Types/UnionType.cpp +++ b/source/fir/Types/UnionType.cpp @@ -21,7 +21,7 @@ namespace fir UnionType* UnionType::create(const Identifier& name, const util::hash_map>& mems) { if(auto it = typeCache.find(name); it != typeCache.end()) - error("Union with name '%s' already exists", name.str()); + error("union with name '%s' already exists", name.str()); else return (typeCache[name] = new UnionType(name, mems)); diff --git a/source/frontend/errors.cpp b/source/frontend/errors.cpp index 1a18fddf..d68a7995 100644 --- a/source/frontend/errors.cpp +++ b/source/frontend/errors.cpp @@ -163,8 +163,8 @@ std::string __error_gen_internal(const Location& loc, const std::string& msg, co // bool empty = strcmp(type, "") == 0; // bool dobold = strcmp(type, "note") != 0; - // todo: do we want to truncate the file path? - // we're doing it now, might want to change (or use a flag) + //? do we want to truncate the file path? + //? we're doing it now, might want to change (or use a flag) std::string filename = frontend::getFilenameFromPath(loc.fileID == 0 ? "(unknown)" : frontend::getFilenameFromID(loc.fileID)); diff --git a/source/frontend/parser/expr.cpp b/source/frontend/parser/expr.cpp index 4ebf559b..3a2e35bd 100644 --- a/source/frontend/parser/expr.cpp +++ b/source/frontend/parser/expr.cpp @@ -140,11 +140,18 @@ namespace parser case TT::Continue: return parseContinue(st); + case TT::Attr_Raw: + st.eat(); + if(st.front() != TT::Union) + expectedAfter(st.loc(), "'union'", "'@raw' while parsing statement", st.front().str()); + + return parseUnion(st, /* isRaw: */ true, /* nameless: */ false); + case TT::Union: - return parseUnion(st); + return parseUnion(st, /* isRaw: */ false, /* nameless: */ false); case TT::Struct: - return parseStruct(st); + return parseStruct(st, /* nameless: */ false); case TT::Class: return parseClass(st); @@ -231,16 +238,79 @@ namespace parser { switch(t) { - // () and [] have the same precedence. - // not sure if this should stay -- works for now. + /* + ! ACHTUNG ! + * DOCUMENT THIS SOMEWHERE PROPERLY!!! * + + due to how we handle identifiers and scope paths (foo::bar), function calls must have higher precedence + than scope resolution. + + this might seem counter-intuitive (i should be resolving the complete identifier first, then calling it with + ()! why would it be any other way???), this is a sad fact of how the typechecker works. + + as it stands, identifiers are units; paths consist of multiple identifiers in a DotOp with the :: operator, which + is left-associative. so for something like foo::bar::qux, it's ((foo)::(bar))::qux. + + to resolve a function call, eg. foo::bar::qux(), the DotOp is laid out as [foo::bar]::[qux()] (with [] for grouping), + instead of the 'intuitive' [[foo::bar]::qux](). the reason for this design was the original rewrite goal of not + relying on string manipulation in the compiler; having identifiers contain :: would have been counter to that + goal. + + (note: currently we are forced to manipulate ::s in the pts->fir type converter!!) + + the current typechecker will thus first find a namespace 'foo', and within that a namespace 'bar', and within that + a function 'qux'. this is opposed to finding a namespace 'foo', then 'bar', then an identifier 'qux', and leaving + that to be resolved later. + + also, another potential issue is how we deal with references to functions (ie. function pointers). our resolver + for ExprCall is strictly less advanced than that for a normal FunctionCall (for reasons i can't reCALL (lmao)), so + we would prefer to return a FuncCall rather than an ExprCall. + + + this model could be re-architected without a *major* rewrite, but it would be a non-trivial task and a considerable amount + of work and debugging. for reference: + + 1. make Idents be able to refer to entire paths; just a datastructure change + 2. make :: have higher precedence than (), to take advantage of (1) + 3. parse ExprCall and FuncCall identically -- they should both just be a NewCall, with an Expr as the callee + + 4. in typechecking a NewCall, just call ->typecheck() on the LHS; the current implementation of Ident::typecheck + returns an sst::VarRef, which has an sst::VarDefn field which we can use. + + 4a. if the Defn was a VarDefn, they cannot overload, and we can just do what we currently do for ExprCall. + if it was a function defn, then things get more complicated. + + 5. the Defn was a FuncDefn. currently Ident cannot return more than one Defn in the event of ambiguous results (eg. + when overloading!), which means we are unable to properly do overload resolution! (duh) we need to make a mechanism + for Ident to return a list of Defns. + + potential solution (A): make an sst::AmbiguousDefn struct that itself holds a list of Defns. the VarRef returned by + Ident would then return that in the ->def field. this would only happen when the target is function; i presume we + have existing mechanisms to detect invalid "overload" scenarios. + + back to (4), we should in theory be able to resolve functions from a list of defns. + + + the problem with this is that while it might seem like a simple 5.5-step plan, a lot of the supporting resolver functions + need to change, and effort is better spent elsewhere tbh + + for now we just stick to parsing () at higher precedence than ::. + */ + + case TT::LParen: - case TT::LSquare: - return 2000; + return 9001; // very funny - case TT::Period: case TT::DoubleColon: + return 5000; + + case TT::Period: + return 1500; + + case TT::LSquare: return 1000; + // unary ! // unary +/- // bitwise ~ @@ -1100,16 +1170,16 @@ namespace parser case TT::Attr_Raw: st.pop(); if(st.front() == TT::StringLiteral) - return parseString(st, true); + return parseString(st, /* isRaw: */ true); else if(st.front() == TT::LSquare) - return parseArray(st, true); + return parseArray(st, /* isRaw: */ true); else if(st.front() == TT::Alloc) - return parseAlloc(st, true); + return parseAlloc(st, /* isRaw: */ true); else - expectedAfter(st, "one of string-literal, array, or alloc", "@raw", st.front().str()); + expectedAfter(st, "one of string-literal, array, or alloc", "'@raw' while parsing expression", st.front().str()); case TT::StringLiteral: return parseString(st, false); diff --git a/source/frontend/parser/function.cpp b/source/frontend/parser/function.cpp index 7b7b1398..befea371 100644 --- a/source/frontend/parser/function.cpp +++ b/source/frontend/parser/function.cpp @@ -172,7 +172,7 @@ namespace parser if(st.front() != TT::StringLiteral) expectedAfter(st.loc(), "string literal", "'as' in foreign function declaration", st.front().str()); - ffn->realName = st.eat().str(); + ffn->realName = parseStringEscapes(st.loc(), st.eat().str()); } else { diff --git a/source/frontend/parser/literal.cpp b/source/frontend/parser/literal.cpp index 15d1acbe..f10ae11f 100644 --- a/source/frontend/parser/literal.cpp +++ b/source/frontend/parser/literal.cpp @@ -8,11 +8,19 @@ #include "mpool.h" +#include "utf8rewind/include/utf8rewind/utf8rewind.h" + #include using namespace ast; using namespace lexer; +// #ifdef _WIN32 +// #define PLATFORM_NEWLINE "\r\n" +// #else +// #define PLATFORM_NEWLINE "\n" +// #endif + using TT = lexer::TokenType; namespace parser { @@ -24,39 +32,115 @@ namespace parser return util::pool(st.ploc(), t.str()); } - LitString* parseString(State& st, bool israw) + static std::string parseHexEscapes(const Location& loc, std::string_view sv, size_t* ofs) { - iceAssert(st.front() == TT::StringLiteral); - auto t = st.eat(); + if(sv[0] == 'x') + { + if(sv.size() < 3) + error(loc, "malformed escape sequence: unexpected end of string"); + + if(!isxdigit(sv[1]) || !isxdigit(sv[2])) + error(loc, "malformed escape sequence: non-hex character in \\x escape"); + + // ok then. + char s[2] = { sv[1], sv[2] }; + char val = std::stol(s, /* pos: */ 0, /* base: */ 16); + + *ofs = 3; + return std::string(&val, 1); + } + else if(sv[0] == 'u') + { + if(sv.size() < 3) + error(loc, "malformed escape sequence: unexpected end of string"); + + sv.remove_prefix(1); + if(sv[0] != '{') + error(loc, "malformed escape sequence: expected '{' after \\u"); - // do replacement here, instead of in the lexer. - std::string tmp = t.str(); + sv.remove_prefix(1); + + std::string digits; + size_t i = 0; + for(i = 0; i < sv.size(); i++) + { + if(sv[i] == '}') break; + + if(!isxdigit(sv[i])) + error(loc, "malformed escape sequence: non-hex character '%c' inside \\u{...}", sv[i]); + + if(digits.size() == 8) + error(loc, "malformed escape sequence: too many digits inside \\u{...}; up to 8 are allowed"); + + digits += sv[i]; + } + + if(sv[i] != '}') + error(loc, "malformed escape sequence: expcected '}' to end codepoint escape"); + + uint32_t codepoint = std::stol(digits, /* pos: */ 0, /* base: */ 16); + + char output[8] = { 0 }; + int err = 0; + auto sz = utf32toutf8(&codepoint, 4, output, 8, &err); + if(err != UTF8_ERR_NONE) + error(loc, "invalid utf32 codepoint!"); + + *ofs = 3 + digits.size(); + return std::string(output, sz); + } + else + { + iceAssert("wtf yo" && 0); + } + } + + std::string parseStringEscapes(const Location& loc, const std::string& str) + { std::stringstream ss; - for(size_t i = 0; i < tmp.length(); i++) + for(size_t i = 0; i < str.length(); i++) { - if(tmp[i] == '\\') + if(str[i] == '\\') { i++; - switch(tmp[i]) + switch(str[i]) { - // todo: handle hex sequences and stuff - case 'n': ss << '\n'; break; - case 'b': ss << '\b'; break; - case 'r': ss << '\r'; break; - case 't': ss << '\t'; break; - case '"': ss << '\"'; break; - case '\\': ss << '\\'; break; - default: ss << std::string("\\") + tmp[i]; break; + case 'n': ss << "\n"; break; + case 'b': ss << "\b"; break; + case 'r': ss << "\r"; break; + case 't': ss << "\t"; break; + case '"': ss << "\""; break; + case '\\': ss << "\\"; break; + + case 'x': // fallthrough + case 'u': { + size_t ofs = 0; + ss << parseHexEscapes(loc, std::string_view(str.c_str() + i, str.size() - i), &ofs); + i += ofs - 1; + break; + } + + default: + ss << std::string("\\") + str[i]; + break; } } else { - ss << tmp[i]; + ss << str[i]; } } - return util::pool(st.ploc(), ss.str(), israw); + return ss.str(); + } + + LitString* parseString(State& st, bool israw) + { + iceAssert(st.front() == TT::StringLiteral); + auto t = st.eat(); + + return util::pool(st.ploc(), parseStringEscapes(st.ploc(), t.str()), israw); } LitArray* parseArray(State& st, bool israw) diff --git a/source/frontend/parser/toplevel.cpp b/source/frontend/parser/toplevel.cpp index ab4f962c..8e8a4d08 100644 --- a/source/frontend/parser/toplevel.cpp +++ b/source/frontend/parser/toplevel.cpp @@ -275,6 +275,9 @@ namespace parser st.eat(); } + // throw in all anonymous types to the top level + root->statements.insert(root->statements.begin(), st.anonymousTypeDefns.begin(), st.anonymousTypeDefns.end()); + return root; } diff --git a/source/frontend/parser/type.cpp b/source/frontend/parser/type.cpp index 79a16a24..780a8b76 100644 --- a/source/frontend/parser/type.cpp +++ b/source/frontend/parser/type.cpp @@ -13,72 +13,6 @@ using namespace lexer; namespace parser { using TT = lexer::TokenType; - StructDefn* parseStruct(State& st) - { - iceAssert(st.front() == TT::Struct); - st.eat(); - - if(st.front() != TT::Identifier) - expectedAfter(st, "identifier", "'struct'", st.front().str()); - - StructDefn* defn = util::pool(st.loc()); - defn->name = st.eat().str(); - - // check for generic function - if(st.front() == TT::LAngle) - { - st.eat(); - // parse generic - if(st.front() == TT::RAngle) - error(st, "empty type parameter lists are not allowed"); - - defn->generics = parseGenericTypeList(st); - } - - st.skipWS(); - if(st.front() != TT::LBrace) - expectedAfter(st, "'{'", "'struct'", st.front().str()); - - st.enterStructBody(); - - auto blk = parseBracedBlock(st); - for(auto s : blk->statements) - { - if(auto v = dcast(VarDefn, s)) - { - if(v->type == pts::InferredType::get()) - error(v, "struct fields must have types explicitly specified"); - - else if(v->initialiser) - error(v->initialiser, "struct fields cannot have inline initialisers"); - - defn->fields.push_back(v); - } - else if(auto f = dcast(FuncDefn, s)) - { - defn->methods.push_back(f); - } - else if(auto t = dcast(TypeDefn, s)) - { - defn->nestedTypes.push_back(t); - } - else if(dcast(InitFunctionDefn, s)) - { - error(s, "structs cannot have user-defined initialisers"); - } - else - { - error(s, "unsupported expression or statement in struct body"); - } - } - - for(auto s : blk->deferredStatements) - error(s, "unsupported expression or statement in struct body"); - - st.leaveStructBody(); - return defn; - } - ClassDefn* parseClass(State& st) { @@ -185,17 +119,144 @@ namespace parser + StructDefn* parseStruct(State& st, bool nameless) + { + static size_t anon_counter = 0; + + iceAssert(st.front() == TT::Struct); + st.eat(); + + StructDefn* defn = util::pool(st.loc()); + if(nameless) + { + defn->name = strprintf("__anon_struct_%zu", anon_counter++); + } + else + { + if(st.front() == TT::LBrace) + error(st, "declared structs (in non-type usage) must be named"); + + else if(st.front() != TT::Identifier) + expectedAfter(st, "identifier", "'struct'", st.front().str()); + + else + defn->name = st.eat().str(); + } + + + // check for generic function + if(st.front() == TT::LAngle) + { + st.eat(); + // parse generic + if(st.front() == TT::RAngle) + error(st, "empty type parameter lists are not allowed"); + + defn->generics = parseGenericTypeList(st); + } + + // unions don't inherit stuff (for now????) so we don't check for it. + + st.skipWS(); + if(st.eat() != TT::LBrace) + expectedAfter(st.ploc(), "opening brace", "'struct'", st.front().str()); + + st.enterStructBody(); + st.skipWS(); + + size_t index = 0; + while(st.front() != TT::RBrace) + { + st.skipWS(); + + if(st.front() == TT::Identifier) + { + auto loc = st.loc(); + std::string name = st.eat().str(); + + // we can't check for duplicates when it's transparent, duh + // we'll collapse and collect and check during typechecking. + if(name != "_") + { + if(auto it = std::find_if(defn->fields.begin(), defn->fields.end(), [&name](const auto& p) -> bool { + return std::get<0>(p) == name; + }); it != defn->fields.end()) + { + SimpleError::make(loc, "duplicate field '%s' in struct definition", name) + ->append(SimpleError::make(MsgType::Note, std::get<1>(*it), "field '%s' previously defined here:", name)) + ->postAndQuit(); + } + } + + if(st.eat() != TT::Colon) + error(st.ploc(), "expected type specifier after field name in struct"); + + pts::Type* type = parseType(st); + defn->fields.push_back(std::make_tuple(name, loc, type)); + + if(st.front() == TT::Equal) + error(st.loc(), "struct fields cannot have initialisers"); + } + else if(st.front() == TT::Func) + { + // ok parse a func as usual + defn->methods.push_back(parseFunction(st)); + } + else if(st.front() == TT::Var || st.front() == TT::Val) + { + error(st.loc(), "struct fields are declared as 'name: type'; val/let is omitted"); + } + else if(st.front() == TT::Static) + { + error(st.loc(), "structs cannot have static declarations"); + } + else if(st.front() == TT::NewLine || st.front() == TT::Semicolon) + { + st.pop(); + } + else if(st.front() == TT::RBrace) + { + break; + } + else + { + error(st.loc(), "unexpected token '%s' inside struct body", st.front().str()); + } + + index++; + } + + iceAssert(st.front() == TT::RBrace); + st.eat(); + + st.leaveStructBody(); + return defn; + } + - UnionDefn* parseUnion(State& st) + UnionDefn* parseUnion(State& st, bool israw, bool nameless) { + static size_t anon_counter = 0; iceAssert(st.front() == TT::Union); st.eat(); - if(st.front() != TT::Identifier) - expectedAfter(st, "identifier", "'union'", st.front().str()); - UnionDefn* defn = util::pool(st.loc()); - defn->name = st.eat().str(); + if(nameless) + { + defn->name = strprintf("__anon_union_%zu", anon_counter++); + } + else + { + if(st.front() == TT::LBrace) + error(st, "declared unions (in non-type usage) must be named"); + + else if(st.front() != TT::Identifier) + expectedAfter(st, "identifier", "'union'", st.front().str()); + + else + defn->name = st.eat().str(); + } + // check for generic function if(st.front() == TT::LAngle) @@ -214,37 +275,63 @@ namespace parser if(st.eat() != TT::LBrace) expectedAfter(st.ploc(), "opening brace", "'union'", st.front().str()); + st.skipWS(); + if(st.front() == TT::RBrace) + error(st, "union must contain at least one variant"); + size_t index = 0; while(st.front() != TT::RBrace) { st.skipWS(); if(st.front() != TT::Identifier) - expected(st.loc(), "identifier inside union body", st.front().str()); + { + if(st.front() == TT::Var || st.front() == TT::Val) + error(st.loc(), "union fields are declared as 'name: type'; val/let is omitted"); + + else + expected(st.loc(), "identifier inside union body", st.front().str()); + } auto loc = st.loc(); pts::Type* type = 0; std::string name = st.eat().str(); - if(auto it = defn->cases.find(name); it != defn->cases.end()) - { - SimpleError::make(loc, "duplicate variant '%s' in union definition", name) - ->append(SimpleError::make(MsgType::Note, std::get<1>(it->second), "variant '%s' previously defined here:", name)) - ->postAndQuit(); - } - + // to improve code flow, handle the type first. if(st.front() == TT::Colon) { st.eat(); type = parseType(st); } - else if(st.front() != TT::NewLine) + else { - error(st.loc(), "expected newline after union variant"); + if(israw) + error(st.loc(), "raw unions cannot have empty variants (must have a type)"); + + else if(st.front() != TT::NewLine) + error(st.loc(), "expected newline after union variant"); } - if(type == nullptr) type = pts::NamedType::create(loc, VOID_TYPE_STRING); - defn->cases[name] = { index, loc, type }; + if(name == "_") + { + if(!israw) + error(loc, "transparent fields can only be present in raw unions"); + + iceAssert(type); + defn->transparentFields.push_back({ loc, type }); + } + else + { + if(auto it = defn->cases.find(name); it != defn->cases.end()) + { + SimpleError::make(loc, "duplicate variant '%s' in union definition", name) + ->append(SimpleError::make(MsgType::Note, std::get<1>(it->second), "variant '%s' previously defined here:", name)) + ->postAndQuit(); + } + + if(type == nullptr) type = pts::NamedType::create(loc, VOID_TYPE_STRING); + defn->cases[name] = { index, loc, type }; + } // do some things if(st.front() == TT::NewLine || st.front() == TT::Semicolon) @@ -266,6 +353,7 @@ namespace parser iceAssert(st.front() == TT::RBrace); st.eat(); + defn->israw = israw; return defn; } @@ -323,7 +411,7 @@ namespace parser } else if(hadValue) { - // todo: remove this restriction maybe + //? this is mostly because we don't want to deal with auto-incrementing stuff error(st.loc(), "enumeration cases must either all have no values, or all have values; a mix is not allowed."); } @@ -589,6 +677,27 @@ namespace parser return util::pool(loc, types); } } + else if(st.front() == TT::Struct) + { + auto str = parseStruct(st, /* nameless: */ true); + st.anonymousTypeDefns.push_back(str); + + return pts::NamedType::create(str->loc, str->name); + } + else if(st.front() == TT::Union || (st.front() == TT::Attr_Raw && st.lookahead(1) == TT::Union)) + { + bool israw = st.front() == TT::Attr_Raw; + if(israw) st.eat(); + + auto unn = parseUnion(st, israw, /* nameless: */ true); + st.anonymousTypeDefns.push_back(unn); + + return pts::NamedType::create(unn->loc, unn->name); + } + else if(st.front() == TT::Class) + { + error(st, "classes cannot be defined anonymously"); + } else { error(st, "unexpected token '%s' while parsing type", st.front().str()); @@ -596,7 +705,7 @@ namespace parser } - + // PAM == PolyArgMapping PolyArgMapping_t parsePAMs(State& st, bool* failed) { iceAssert(st.front() == TT::Exclamation && st.lookahead(1) == TT::LAngle); diff --git a/source/include/ast.h b/source/include/ast.h index 59f2619a..4c897cbe 100644 --- a/source/include/ast.h +++ b/source/include/ast.h @@ -391,13 +391,8 @@ namespace ast std::vector bases; - std::vector fields; + std::vector> fields; std::vector methods; - - std::vector staticFields; - std::vector staticMethods; - - std::vector nestedTypes; }; struct ClassDefn : TypeDefn @@ -451,7 +446,10 @@ namespace ast virtual TCResult typecheck(sst::TypecheckState* fs, fir::Type* infer, const TypeParamMap_t& gmaps) override; virtual TCResult generateDeclaration(sst::TypecheckState* fs, fir::Type* infer, const TypeParamMap_t& gmaps) override; + bool israw = false; util::hash_map> cases; + + std::vector> transparentFields; }; struct TypeExpr : Expr diff --git a/source/include/backends/llvm.h b/source/include/backends/llvm.h index 15aae065..750751d6 100644 --- a/source/include/backends/llvm.h +++ b/source/include/backends/llvm.h @@ -59,20 +59,26 @@ namespace backend struct LLVMJit { - typedef llvm::orc::IRCompileLayer::ModuleHandleT ModuleHandle_t; + // typedef llvm::orc::IRCompileLayer::ModuleHandleT ModuleHandle_t; + + using ModuleHandle_t = llvm::orc::VModuleKey; LLVMJit(llvm::TargetMachine* tm); llvm::TargetMachine* getTargetMachine(); void removeModule(ModuleHandle_t mod); - ModuleHandle_t addModule(std::shared_ptr mod); + ModuleHandle_t addModule(std::unique_ptr mod); llvm::JITSymbol findSymbol(const std::string& name); llvm::JITTargetAddress getSymbolAddress(const std::string& name); private: - llvm::orc::RTDyldObjectLinkingLayer objectLayer; + llvm::orc::ExecutionSession execSession; std::unique_ptr targetMachine; + std::shared_ptr symbolResolver; + + llvm::DataLayout dataLayout; + llvm::orc::RTDyldObjectLinkingLayer objectLayer; llvm::orc::IRCompileLayer compileLayer; }; @@ -99,7 +105,7 @@ namespace backend llvm::Function* entryFunction = 0; llvm::TargetMachine* targetMachine = 0; - std::shared_ptr linkedModule; + std::unique_ptr linkedModule; LLVMJit* jitInstance = 0; }; diff --git a/source/include/ir/instruction.h b/source/include/ir/instruction.h index 80e1eb45..2f0c4dc2 100644 --- a/source/include/ir/instruction.h +++ b/source/include/ir/instruction.h @@ -162,6 +162,8 @@ namespace fir Union_GetVariantID, Union_SetVariantID, + RawUnion_GEP, + Branch_UnCond, Branch_Cond, diff --git a/source/include/ir/irbuilder.h b/source/include/ir/irbuilder.h index 826aa96e..9aa29b91 100644 --- a/source/include/ir/irbuilder.h +++ b/source/include/ir/irbuilder.h @@ -94,9 +94,15 @@ namespace fir Value* StackAlloc(Type* type, const std::string& vname = ""); Value* ImmutStackAlloc(Type* type, Value* initialValue, const std::string& vname = ""); + // given an l or cl value of raw_union type, return an l or cl value of the correct type of the field + Value* GetRawUnionField(Value* lval, const std::string& field, const std::string& vname = ""); + + // same as the above, but give it a type instead -- this is hacky cos it's not checked. + // the backend will just do some pointer magic regardless, so it doesn't really matter. + Value* GetRawUnionFieldByType(Value* lval, Type* type, const std::string& vname = ""); // equivalent to GEP(ptr*, 0, memberIndex) - Value* GetStructMember(Value* ptr, std::string memberName); + Value* GetStructMember(Value* ptr, const std::string& memberName); Value* StructGEP(Value* structPtr, size_t memberIndex, const std::string& vname = ""); // equivalent to GEP(ptr*, index) diff --git a/source/include/ir/type.h b/source/include/ir/type.h index dc1aa0e2..a6185a95 100644 --- a/source/include/ir/type.h +++ b/source/include/ir/type.h @@ -29,6 +29,7 @@ namespace fir struct StringType; struct PointerType; struct FunctionType; + struct RawUnionType; struct PrimitiveType; struct ArraySliceType; struct DynamicArrayType; @@ -44,6 +45,7 @@ namespace fir Type* getBestFitTypeForConstant(ConstantNumberType* cnt); int getCastDistance(Type* from, Type* to); + bool isRefCountedType(Type* ty); enum class TypeKind { @@ -63,6 +65,7 @@ namespace fir String, Pointer, Function, + RawUnion, Primitive, ArraySlice, DynamicArray, @@ -108,6 +111,7 @@ namespace fir UnionVariantType* toUnionVariantType(); ArraySliceType* toArraySliceType(); PrimitiveType* toPrimitiveType(); + RawUnionType* toRawUnionType(); FunctionType* toFunctionType(); PointerType* toPointerType(); StructType* toStructType(); @@ -130,6 +134,7 @@ namespace fir bool isClassType(); bool isStructType(); bool isPackedStruct(); + bool isRawUnionType(); bool isUnionVariantType(); bool isRangeType(); @@ -143,6 +148,7 @@ namespace fir bool isIntegerType(); bool isFunctionType(); bool isSignedIntType(); + bool isUnsignedIntType(); bool isFloatingPointType(); bool isArraySliceType(); @@ -583,6 +589,40 @@ namespace fir }; + + struct RawUnionType : Type + { + friend struct Type; + + Identifier getTypeName(); + size_t getVariantCount(); + + bool hasVariant(const std::string& name); + Type* getVariant(const std::string& name); + util::hash_map getVariants(); + + void setBody(const util::hash_map& variants); + + virtual std::string str() override; + virtual std::string encodedStr() override; + virtual bool isTypeEqual(Type* other) override; + virtual fir::Type* substitutePlaceholders(const util::hash_map& subst) override; + + + virtual ~RawUnionType() override { } + protected: + + RawUnionType(const Identifier& id, const util::hash_map& variants); + + Identifier unionName; + util::hash_map variants; + + public: + static RawUnionType* create(const Identifier& id, const util::hash_map& variants); + static RawUnionType* createWithoutBody(const Identifier& id); + }; + + struct StructType : Type { friend struct Type; diff --git a/source/include/parser_internal.h b/source/include/parser_internal.h index 46878927..3a9c7732 100644 --- a/source/include/parser_internal.h +++ b/source/include/parser_internal.h @@ -224,6 +224,8 @@ namespace parser return this->loc(); } + std::vector anonymousTypeDefns; + std::string currentFilePath; util::hash_map binaryOps; @@ -241,6 +243,8 @@ namespace parser const lexer::TokenList& tokens; }; + std::string parseStringEscapes(const Location& loc, const std::string& str); + std::string parseOperatorTokens(State& st); pts::Type* parseType(State& st); @@ -266,10 +270,11 @@ namespace parser ast::EnumDefn* parseEnum(State& st); ast::ClassDefn* parseClass(State& st); - ast::UnionDefn* parseUnion(State& st); - ast::StructDefn* parseStruct(State& st); ast::StaticDecl* parseStaticDecl(State& st); + ast::StructDefn* parseStruct(State& st, bool nameless); + ast::UnionDefn* parseUnion(State& st, bool israw, bool nameless); + ast::Expr* parseDollarExpr(State& st); ast::InitFunctionDefn* parseInitFunction(State& st); diff --git a/source/include/sst.h b/source/include/sst.h index ef1ffce0..fc0998b9 100644 --- a/source/include/sst.h +++ b/source/include/sst.h @@ -417,6 +417,9 @@ namespace sst Expr* lhs = 0; std::string rhsIdent; bool isMethodRef = false; + + bool isTransparentField = false; + size_t indexOfTransparentField = 0; }; struct MethodDotOp : Expr @@ -679,6 +682,7 @@ namespace sst virtual CGResult _codegen(cgn::CodegenState* cs, fir::Type* inferred = 0) override { return CGResult(0); } TypeDefn* parentType = 0; + bool isTransparentField = false; }; struct ClassInitialiserDefn : FunctionDefn @@ -752,6 +756,18 @@ namespace sst }; + struct RawUnionDefn : TypeDefn + { + RawUnionDefn(const Location& l) : TypeDefn(l) { this->readableName = "raw union definition"; } + ~RawUnionDefn() { } + + virtual std::string getKind() override { return "raw union"; } + virtual CGResult _codegen(cgn::CodegenState* cs, fir::Type* inferred = 0) override; + + util::hash_map fields; + std::vector transparentFields; + }; + struct UnionVariantDefn; diff --git a/source/main.cpp b/source/main.cpp index e7341e53..542b9772 100644 --- a/source/main.cpp +++ b/source/main.cpp @@ -12,6 +12,7 @@ #include "mpool.h" #include "allocator.h" + struct timer { timer(double* t) : out(t) { start = std::chrono::high_resolution_clock::now(); } diff --git a/source/misc/identifier.cpp b/source/misc/identifier.cpp index 1c80daea..d2dfd8eb 100644 --- a/source/misc/identifier.cpp +++ b/source/misc/identifier.cpp @@ -7,11 +7,6 @@ #include "sst.h" -//* so how this works is, instead of having to do manual checks and error-posting, since everyone uses this TCResult system, -//* when we try to unwrap something we got from a typecheck and it's actually an error, we just post the error and quit. -//? in theory this should work like it used to, probably. -// TODO: investigate?? - sst::Stmt* TCResult::stmt() const { if(this->_kind == RK::Error) diff --git a/source/typecheck/arithmetic.cpp b/source/typecheck/arithmetic.cpp index 5438fe04..1fa60215 100644 --- a/source/typecheck/arithmetic.cpp +++ b/source/typecheck/arithmetic.cpp @@ -216,9 +216,6 @@ TCResult ast::BinaryOp::typecheck(sst::TypecheckState* fs, fir::Type* inferred) iceAssert(!Operator::isAssignment(this->op)); - // TODO: infer the types properly for literal numbers - // this has always been a thorn, dammit - auto l = this->left->typecheck(fs, inferred).expr(); sst::Expr* r = 0; diff --git a/source/typecheck/call.cpp b/source/typecheck/call.cpp index 85a3266a..9b665378 100644 --- a/source/typecheck/call.cpp +++ b/source/typecheck/call.cpp @@ -119,14 +119,14 @@ sst::Expr* ast::ExprCall::typecheckWithArguments(sst::TypecheckState* fs, const iceAssert(target); if(!target->type->isFunctionType()) - error(this->callee, "expression with non-function-type '%s' cannot be called"); + error(this->callee, "expression with non-function-type '%s' cannot be called", target->type); auto ft = target->type->toFunctionType(); auto [ dist, errs ] = sst::resolver::computeOverloadDistance(this->loc, util::map(ft->getArgumentTypes(), [](fir::Type* t) -> auto { return fir::LocatedType(t, Location()); }), util::map(arguments, [](const FnCallArgument& fca) -> fir::LocatedType { return fir::LocatedType(fca.value->type, fca.loc); - }), false); + }), target->type->toFunctionType()->isCStyleVarArg()); if(errs != nullptr || dist == -1) { diff --git a/source/typecheck/dotop.cpp b/source/typecheck/dotop.cpp index 70bfb91e..5664e4fd 100644 --- a/source/typecheck/dotop.cpp +++ b/source/typecheck/dotop.cpp @@ -62,6 +62,116 @@ static ErrorMsg* wrongDotOpError(ErrorMsg* e, sst::StructDefn* str, const Locati +struct search_result_t +{ + search_result_t() { } + search_result_t(fir::Type* t, size_t i, bool tr) : type(t), fieldIdx(i), isTransparent(tr) { } + + fir::Type* type = 0; + size_t fieldIdx = 0; + bool isTransparent = false; +}; + +static std::vector searchTransparentFields(sst::TypecheckState* fs, std::vector stack, + const std::vector& fields, const Location& loc, const std::string& name) +{ + // search for them by name first, instead of doing a super-depth-first-search. + for(auto df : fields) + { + if(df->id.name == name) + { + stack.push_back(search_result_t(df->type, 0, false)); + return stack; + } + } + + + size_t idx = 0; + for(auto df : fields) + { + if(df->isTransparentField) + { + auto ty = df->type; + assert(ty->isRawUnionType() || ty->isStructType()); + + auto defn = fs->typeDefnMap[ty]; + iceAssert(defn); + + std::vector flds; + if(auto str = dcast(sst::StructDefn, defn); str) + flds = str->fields; + + else if(auto unn = dcast(sst::RawUnionDefn, defn); unn) + flds = util::map(util::pairs(unn->fields), [](const auto& x) -> auto { return x.second; }) + unn->transparentFields; + + else + error(loc, "what kind of type is this? '%s'", ty); + + stack.push_back(search_result_t(ty, idx, true)); + auto ret = searchTransparentFields(fs, stack, flds, loc, name); + + if(!ret.empty()) return ret; + else stack.pop_back(); + } + + idx += 1; + } + + // if we've reached the end of the line, return nothing. + return { }; +} + + +static sst::FieldDotOp* resolveFieldNameDotOp(sst::TypecheckState* fs, sst::Expr* lhs, const std::vector& fields, + const Location& loc, const std::string& name) +{ + for(auto df : fields) + { + if(df->id.name == name) + { + auto ret = util::pool(loc, df->type); + ret->lhs = lhs; + ret->rhsIdent = name; + + return ret; + } + } + + // sad. search for the field, recursively, in transparent members. + auto ops = searchTransparentFields(fs, { }, fields, loc, name); + if(ops.empty()) + return nullptr; + + // ok, now we just need to make a link of fielddotops... + sst::Expr* cur = lhs; + for(const auto& x : ops) + { + auto op = util::pool(loc, x.type); + + op->lhs = cur; + op->isTransparentField = x.isTransparent; + op->indexOfTransparentField = x.fieldIdx; + + // don't set a name if we're transparent. + op->rhsIdent = (x.isTransparent ? "" : name); + + cur = op; + } + + auto ret = dcast(sst::FieldDotOp, cur); + assert(ret); + + return ret; +} + + + + + + + + + @@ -345,13 +455,6 @@ static sst::Expr* doExpressionDotOp(sst::TypecheckState* fs, ast::DotOperator* d // else: fallthrough } - // TODO: plug in extensions here. - - - if(!type->isStructType() && !type->isClassType()) - { - error(dotop->right, "unsupported right-side expression for dot operator on type '%s'", type); - } // ok. @@ -360,7 +463,6 @@ static sst::Expr* doExpressionDotOp(sst::TypecheckState* fs, ast::DotOperator* d if(auto str = dcast(sst::StructDefn, defn)) { - // right. if(auto fc = dcast(ast::FunctionCall, dotop->right)) { @@ -473,23 +575,13 @@ static sst::Expr* doExpressionDotOp(sst::TypecheckState* fs, ast::DotOperator* d else if(auto fld = dcast(ast::Ident, dotop->right)) { auto name = fld->name; - { auto copy = str; while(copy) { - for(auto f : copy->fields) - { - if(f->id.name == name) - { - auto ret = util::pool(dotop->loc, f->type); - ret->lhs = lhs; - ret->rhsIdent = name; - - return ret; - } - } + auto hmm = resolveFieldNameDotOp(fs, lhs, copy->fields, dotop->loc, name); + if(hmm) return hmm; // ok, we didn't find it. if(auto cls = dcast(sst::ClassDefn, copy); cls) @@ -574,10 +666,33 @@ static sst::Expr* doExpressionDotOp(sst::TypecheckState* fs, ast::DotOperator* d error(dotop->right, "unsupported right-side expression for dot-operator on type '%s'", str->id.name); } } + else if(auto rnn = dcast(sst::RawUnionDefn, defn)) + { + if(auto fld = dcast(ast::Ident, dotop->right)) + { + auto flds = util::map(util::pairs(rnn->fields), [](const auto& x) -> auto { return x.second; }) + rnn->transparentFields; + auto hmm = resolveFieldNameDotOp(fs, lhs, flds, dotop->loc, fld->name); + if(hmm) + { + return hmm; + } + else + { + // ok we didn't return. this is a raw union so extensions R NOT ALLOWED!! (except methods maybe) + error(fld, "union '%s' has no member named '%s'", rnn->id.name, fld->name); + } + } + else + { + error(dotop->right, "unsupported right-side expression for dot-operator on type '%s'", defn->id.name); + } + } else { error(lhs, "unsupported left-side expression (with type '%s') for dot-operator", lhs->type); } + + // TODO: plug in extensions here!! } @@ -736,7 +851,7 @@ static sst::Expr* doStaticDotOp(sst::TypecheckState* fs, ast::DotOperator* dot, // we should be able to pass in the infer value such that it works properly // eg. let x: Foo = Foo.none SimpleError::make(dot->right->loc, - "unable to resolve type parameters for polymorphic union '%s' using variant '%s' (which has no values)", + "could not infer type parameters for polymorphic union '%s' using variant '%s' ", unn->id.name, name)->append(SimpleError::make(MsgType::Note, unn->variants[name]->loc, "variant was defined here:"))->postAndQuit(); } else if(wasfncall && unn->type->toUnionType()->getVariants()[name]->getInteriorType()->isVoidType()) diff --git a/source/typecheck/literals.cpp b/source/typecheck/literals.cpp index 802d9559..f0448597 100644 --- a/source/typecheck/literals.cpp +++ b/source/typecheck/literals.cpp @@ -15,28 +15,46 @@ TCResult ast::LitNumber::typecheck(sst::TypecheckState* fs, fir::Type* infer) fs->pushLoc(this); defer(fs->popLoc()); - // set base = 0 to autodetect. - auto number = mpfr::mpreal(this->num, mpfr_get_default_prec(), /* base: */ 0); + // i don't think mpfr auto-detects base, LMAO + int base = 10; + if(this->num.find("0x") == 0 || this->num.find("0X") == 0) + base = 16; + + auto number = mpfr::mpreal(this->num, mpfr_get_default_prec(), base); bool sgn = mpfr::signbit(number); bool flt = !mpfr::isint(number); - //* this is the stupidest thing. + size_t bits = 0; + if(flt) + { + // fuck it lah. + bits = sizeof(double) * CHAR_BIT; + } + else + { + auto m_ptr = number.mpfr_ptr(); + auto m_rnd = MPFR_RNDN; + if(mpfr_fits_sshort_p(m_ptr, m_rnd)) + bits = sizeof(short) * CHAR_BIT; + + else if(mpfr_fits_sint_p(m_ptr, m_rnd)) + bits = sizeof(int) * CHAR_BIT; - // mpfr's 'get_min_prec' returns the number of bits required to store the significand (eg. for 1.413x10^-2, it is 1.413). - // so you'd think that, for example, given '1024', it would return '10', given that 2^10 == 1024. - // no, it returns '1', because you only need one bit -- the 10th bit -- to get the value 1024. - // which is fucking stupid. + else if(mpfr_fits_slong_p(m_ptr, m_rnd)) + bits = sizeof(long) * CHAR_BIT; - // so what we do here is we change the last digit of the number to be '9'. this effectively forces the first - // bit of the entire number to be set (we don't use '1' because we don't want to make the number smaller - // -- eg. 1024 would become 1021, which would only need 9 bits to store -- versus 1029) + else if(mpfr_fits_intmax_p(m_ptr, m_rnd)) + bits = sizeof(intmax_t) * CHAR_BIT; - // in this way we force mpfr to return the real number of bits required to store the entire thing properly. + else if(!sgn && mpfr_fits_uintmax_p(m_ptr, m_rnd)) + bits = sizeof(uintmax_t) * CHAR_BIT; - size_t bits = mpfr_min_prec(mpfr::mpreal(this->num.substr(0, this->num.size() - 1) + "9").mpfr_ptr()); + else // lmao + bits = SIZE_MAX; + } auto ret = util::pool(this->loc, (infer && infer->isPrimitiveType()) ? infer : fir::ConstantNumberType::get(sgn, flt, bits)); - ret->num = mpfr::mpreal(this->num); + ret->num = number; return TCResult(ret); } diff --git a/source/typecheck/polymorph/driver.cpp b/source/typecheck/polymorph/driver.cpp index 60667b7e..7adeaa53 100644 --- a/source/typecheck/polymorph/driver.cpp +++ b/source/typecheck/polymorph/driver.cpp @@ -125,8 +125,10 @@ namespace poly size_t i = 0; for(auto f : str->fields) { - fieldset.insert(f->name); - fieldNames[f->name] = i++; + auto nm = std::get<0>(f); + + fieldset.insert(nm); + fieldNames[nm] = i++; } } @@ -143,7 +145,8 @@ namespace poly { auto idx = fieldNames[s.first]; - target[idx] = fir::LocatedType(convertPtsType(fs, str->generics, str->fields[idx]->type, session), str->fields[idx]->loc); + target[idx] = fir::LocatedType(convertPtsType(fs, str->generics, std::get<2>(str->fields[idx]), session), + std::get<1>(str->fields[idx])); given[idx] = fir::LocatedType(input[s.second].value->type, input[s.second].loc); } diff --git a/source/typecheck/polymorph/instantiator.cpp b/source/typecheck/polymorph/instantiator.cpp index 79122de8..9becb2df 100644 --- a/source/typecheck/polymorph/instantiator.cpp +++ b/source/typecheck/polymorph/instantiator.cpp @@ -161,10 +161,6 @@ namespace poly fs->pushGenericContext(); defer(fs->popGenericContext()); - //* allowFail is only allowed to forgive a failure when we're checking for type conformance to protocols or something like that. - //* we generally don't look into type or function bodies when checking stuff, and it'd be hard to check for something like this (eg. - //* T passes all the checks, but causes some kind of type-checking failure when substituted in) - // TODO: ??? do we want this to be the behaviour ??? for(auto map : mappings) { int ptrs = 0; diff --git a/source/typecheck/structs.cpp b/source/typecheck/structs.cpp index bf662113..df462b8f 100644 --- a/source/typecheck/structs.cpp +++ b/source/typecheck/structs.cpp @@ -37,16 +37,78 @@ static void _checkFieldRecursion(sst::TypecheckState* fs, fir::Type* strty, fir: for(auto f : field->toStructType()->getElements()) _checkFieldRecursion(fs, field, f, floc, seeing); } + else if(field->isRawUnionType()) + { + for(auto f : field->toRawUnionType()->getVariants()) + _checkFieldRecursion(fs, field, f.second, floc, seeing); + } // ok, we should be fine...? } +// used in typecheck/unions.cpp and typecheck/classes.cpp void checkFieldRecursion(sst::TypecheckState* fs, fir::Type* strty, fir::Type* field, const Location& floc) { std::set seeing; _checkFieldRecursion(fs, strty, field, floc, seeing); } +static void _checkTransparentFieldRedefinition(sst::TypecheckState* fs, sst::TypeDefn* defn, const std::vector& fields, + util::hash_map& seen) +{ + for(auto fld : fields) + { + if(fld->isTransparentField) + { + auto ty = fld->type; + if(!ty->isRawUnionType() && !ty->isStructType()) + { + // you can't have a transparentl field if it's not an aggregate type, lmao + error(fld, "transparent fields must have either a struct or raw-union type."); + } + + auto defn = fs->typeDefnMap[ty]; + iceAssert(defn); + + std::vector flds; + if(auto str = dcast(sst::StructDefn, defn); str) + flds = str->fields; + + else if(auto unn = dcast(sst::RawUnionDefn, defn); unn) + flds = util::map(util::pairs(unn->fields), [](const auto& x) -> auto { return x.second; }) + unn->transparentFields; + + else + error(fs->loc(), "what kind of type is this? '%s'", ty); + + _checkTransparentFieldRedefinition(fs, defn, flds, seen); + } + else + { + if(auto it = seen.find(fld->id.name); it != seen.end()) + { + SimpleError::make(fld->loc, "redefinition of transparently accessible field '%s'", fld->id.name) + ->append(SimpleError::make(MsgType::Note, it->second, "previous definition was here:")) + ->postAndQuit(); + } + else + { + seen[fld->id.name] = fld->loc; + } + } + } +} + +void checkTransparentFieldRedefinition(sst::TypecheckState* fs, sst::TypeDefn* defn, const std::vector& fields) +{ + util::hash_map seen; + _checkTransparentFieldRedefinition(fs, defn, fields, seen); +} + + + + + + @@ -92,21 +154,6 @@ TCResult ast::StructDefn::generateDeclaration(sst::TypecheckState* fs, fir::Type fs->typeDefnMap[str] = defn; } - - auto oldscope = fs->getCurrentScope(); - fs->teleportToScope(defn->id.scope); - fs->pushTree(defn->id.name); - { - for(auto t : this->nestedTypes) - { - t->realScope = this->realScope + defn->id.name; - t->generateDeclaration(fs, 0, { }); - } - } - - fs->popTree(); - fs->teleportToScope(oldscope); - this->genericVersions.push_back({ defn, fs->getGenericContextStack() }); return TCResult(defn); } @@ -134,34 +181,24 @@ TCResult ast::StructDefn::typecheck(sst::TypecheckState* fs, fir::Type* infer, c fs->teleportToScope(defn->id.scope); fs->pushTree(defn->id.name); - std::vector> tys; - for(auto t : this->nestedTypes) - { - auto tcr = t->typecheck(fs); - if(tcr.isParametric()) continue; - if(tcr.isError()) error(t, "failed to generate declaration for nested type '%s' in struct '%s'", t->name, this->name); - - auto st = dcast(sst::TypeDefn, tcr.defn()); - iceAssert(st); - - defn->nestedTypes.push_back(st); - } - - - //* this is a slight misnomer, since we only 'enter' the struct body when generating methods. - //* for all intents and purposes, static methods (aka functions) don't really need any special - //* treatment anyway, apart from living in a special namespace -- so this should really be fine. fs->enterStructBody(defn); { for(auto f : this->fields) { - auto v = dcast(sst::StructFieldDefn, f->typecheck(fs).defn()); + auto vdef = util::pool(std::get<1>(f)); + vdef->immut = false; + vdef->name = std::get<0>(f); + vdef->initialiser = nullptr; + vdef->type = std::get<2>(f); + + auto v = dcast(sst::StructFieldDefn, vdef->typecheck(fs).defn()); iceAssert(v); - if(v->init) error(v, "struct fields cannot have inline initialisers"); + if(v->id.name == "_") + v->isTransparentField = true; defn->fields.push_back(v); tys.push_back({ v->id.name, v->type }); @@ -185,38 +222,12 @@ TCResult ast::StructDefn::typecheck(sst::TypecheckState* fs, fir::Type* infer, c for(auto m : this->methods) m->typecheck(fs, str, { }); } - fs->leaveStructBody(); + checkTransparentFieldRedefinition(fs, defn, defn->fields); - // do static things. - { - for(auto f : this->staticFields) - { - auto v = dcast(sst::VarDefn, f->typecheck(fs).defn()); - iceAssert(v); - - defn->staticFields.push_back(v); - } - - // same deal so we can call them out of order. - for(auto m : this->staticMethods) - { - // infer is 0 because this is a static thing - auto res = m->generateDeclaration(fs, str, { }); - if(res.isParametric()) - error(m, "static methods of a type cannot be polymorphic (for now???)"); - auto decl = dcast(sst::FunctionDefn, res.defn()); - iceAssert(decl); - - defn->staticMethods.push_back(decl); - } + fs->leaveStructBody(); - for(auto m : this->staticMethods) - { - m->typecheck(fs, 0, { }); - } - } str->setBody(tys); diff --git a/source/typecheck/typecheckstate.cpp b/source/typecheck/typecheckstate.cpp index 875215b7..889e6b09 100644 --- a/source/typecheck/typecheckstate.cpp +++ b/source/typecheck/typecheckstate.cpp @@ -380,6 +380,8 @@ namespace sst + + std::vector TypecheckState::getDefinitionsWithName(const std::string& name, StateTree* tree) { if(tree == 0) diff --git a/source/typecheck/unions.cpp b/source/typecheck/unions.cpp index 97642973..f0485f5d 100644 --- a/source/typecheck/unions.cpp +++ b/source/typecheck/unions.cpp @@ -9,6 +9,13 @@ #include "ir/type.h" #include "mpool.h" +// defined in typecheck/structs.cpp +void checkFieldRecursion(sst::TypecheckState* fs, fir::Type* strty, fir::Type* field, const Location& floc); +void checkTransparentFieldRedefinition(sst::TypecheckState* fs, sst::TypeDefn* defn, const std::vector& fields); + + + + TCResult ast::UnionDefn::generateDeclaration(sst::TypecheckState* fs, fir::Type* infer, const TypeParamMap_t& gmaps) { fs->pushLoc(this); @@ -19,13 +26,18 @@ TCResult ast::UnionDefn::generateDeclaration(sst::TypecheckState* fs, fir::Type* else if(ret) return TCResult(ret); auto defnname = util::typeParamMapToString(this->name, gmaps); - auto defn = util::pool(this->loc); + + sst::TypeDefn* defn = 0; + if(this->israw) defn = util::pool(this->loc); + else defn = util::pool(this->loc); + defn->id = Identifier(defnname, IdKind::Type); defn->id.scope = this->realScope; defn->visibility = this->visibility; defn->original = this; - defn->type = fir::UnionType::createWithoutBody(defn->id); + if(this->israw) defn->type = fir::RawUnionType::createWithoutBody(defn->id); + else defn->type = fir::UnionType::createWithoutBody(defn->id); fs->checkForShadowingOrConflictingDefinition(defn, [](sst::TypecheckState* fs, sst::Defn* other) -> bool { return true; }); @@ -47,53 +59,144 @@ TCResult ast::UnionDefn::typecheck(sst::TypecheckState* fs, fir::Type* infer, co if(tcr.isParametric()) return tcr; else if(tcr.isError()) error(this, "failed to generate declaration for union '%s'", this->name); - auto defn = dcast(sst::UnionDefn, tcr.defn()); - iceAssert(defn); - if(this->finishedTypechecking.find(defn) != this->finishedTypechecking.end()) - return TCResult(defn); + if(this->finishedTypechecking.find(tcr.defn()) != this->finishedTypechecking.end()) + return TCResult(tcr.defn()); auto oldscope = fs->getCurrentScope(); - fs->teleportToScope(defn->id.scope); - fs->pushTree(defn->id.name); + fs->teleportToScope(tcr.defn()->id.scope); + fs->pushTree(tcr.defn()->id.name); - util::hash_map> vars; - std::vector> vdefs; - for(auto variant : this->cases) + sst::TypeDefn* ret = 0; + if(this->israw) { - vars[variant.first] = { std::get<0>(variant.second), (std::get<2>(variant.second) - ? fs->convertParserTypeToFIR(std::get<2>(variant.second)) : fir::Type::getVoid()) + auto defn = dcast(sst::RawUnionDefn, tcr.defn()); + iceAssert(defn); + + //* in many ways raw unions resemble structs rather than tagged unions + //* and since we are using sst::StructFieldDefn for the variants, we will need + //* to enter the struct body. + + fs->enterStructBody(defn); + auto unionTy = defn->type->toRawUnionType(); + + util::hash_map types; + util::hash_map fields; + + + + auto make_field = [fs, unionTy](const std::string& name, const Location& loc, pts::Type* ty) -> sst::StructFieldDefn* { + + auto vdef = util::pool(loc); + vdef->immut = false; + vdef->name = name; + vdef->initialiser = nullptr; + vdef->type = ty; + + auto sfd = dcast(sst::StructFieldDefn, vdef->typecheck(fs).defn()); + iceAssert(sfd); + + if(fir::isRefCountedType(sfd->type)) + error(sfd, "reference-counted type '%s' cannot be a member of a raw union", sfd->type); + + checkFieldRecursion(fs, unionTy, sfd->type, sfd->loc); + return sfd; }; + std::vector tfields; + std::vector allFields; + + for(auto variant : this->cases) + { + auto sfd = make_field(variant.first, std::get<1>(variant.second), std::get<2>(variant.second)); + iceAssert(sfd); + + fields[sfd->id.name] = sfd; + types[sfd->id.name] = sfd->type; + + allFields.push_back(sfd); + } - auto vdef = util::pool(std::get<1>(variant.second)); - vdef->parentUnion = defn; - vdef->variantName = variant.first; - vdef->id = Identifier(defn->id.name + "::" + variant.first, IdKind::Name); - vdef->id.scope = fs->getCurrentScope(); - vdefs.push_back({ vdef, std::get<0>(variant.second) }); + // do the transparent fields + { + size_t tfn = 0; + for(auto [ loc, pty ] : this->transparentFields) + { + auto sfd = make_field(strprintf("__transparent_field_%zu", tfn++), loc, pty); + iceAssert(sfd); - fs->stree->addDefinition(variant.first, vdef); + sfd->isTransparentField = true; - defn->variants[variant.first] = vdef; + // still add to the types, cos we need to compute sizes and stuff + types[sfd->id.name] = sfd->type; + tfields.push_back(sfd); + allFields.push_back(sfd); + } + } + + checkTransparentFieldRedefinition(fs, defn, allFields); + + + defn->fields = fields; + defn->transparentFields = tfields; + + + + + unionTy->setBody(types); + + fs->leaveStructBody(); + ret = defn; } + else + { + auto defn = dcast(sst::UnionDefn, tcr.defn()); + iceAssert(defn); + + util::hash_map> vars; + std::vector> vdefs; - auto unionTy = defn->type->toUnionType(); - unionTy->setBody(vars); + iceAssert(this->transparentFields.empty()); - // in a bit of stupidity, we need to set the type of each definition properly. - for(const auto& [ uvd, id ] : vdefs) - uvd->type = unionTy->getVariant(id); + for(auto variant : this->cases) + { + vars[variant.first] = { std::get<0>(variant.second), (std::get<2>(variant.second) + ? fs->convertParserTypeToFIR(std::get<2>(variant.second)) : fir::Type::getVoid()) + }; - this->finishedTypechecking.insert(defn); + auto vdef = util::pool(std::get<1>(variant.second)); + vdef->parentUnion = defn; + vdef->variantName = variant.first; + vdef->id = Identifier(defn->id.name + "::" + variant.first, IdKind::Name); + vdef->id.scope = fs->getCurrentScope(); + + vdefs.push_back({ vdef, std::get<0>(variant.second) }); + + fs->stree->addDefinition(variant.first, vdef); + + defn->variants[variant.first] = vdef; + } + + auto unionTy = defn->type->toUnionType(); + unionTy->setBody(vars); + + // in a bit of stupidity, we need to set the type of each definition properly. + for(const auto& [ uvd, id ] : vdefs) + uvd->type = unionTy->getVariant(id); + + ret = defn; + } + + iceAssert(ret); + this->finishedTypechecking.insert(ret); fs->popTree(); fs->teleportToScope(oldscope); - return TCResult(defn); + return TCResult(ret); } diff --git a/source/typecheck/variable.cpp b/source/typecheck/variable.cpp index bbb27fff..ab0b1a24 100644 --- a/source/typecheck/variable.cpp +++ b/source/typecheck/variable.cpp @@ -357,7 +357,8 @@ TCResult ast::VarDefn::typecheck(sst::TypecheckState* fs, fir::Type* infer) } } - fs->stree->addDefinition(this->name, defn); + if(this->name != "_") + fs->stree->addDefinition(this->name, defn); // store the place where we were defined. if(fs->isInFunctionBody())