diff --git a/build/ultratiny.flx b/build/ultratiny.flx index 67de758b..0e733dfa 100644 --- a/build/ultratiny.flx +++ b/build/ultratiny.flx @@ -44,21 +44,53 @@ class X var data: int } -class Y: X +class Y : X { init(c: int) : super(c: c + 30) { } } -class Z: Y +class Z : Y { - init(c: int) : super(c: c - 60) { } + init(c: int) : super(c: c - 90) { } } +// issue #2: class inheritance is source-order-dependent (ie. the definition of the base class must +// appear before the definition of any children) + +// classes are a Bad Idea (tm) ): + +class P +{ + var z: int + + init(x: int = 3, y: int = 7) + { + this.z = x * y + } + + virtual fn foo(x: &Y) => printf("P::foo(%d)\n", x.data) +} + +class Q : P +{ + init() : super(x: 5, y: 8) { } + + // override fn foo(x: int) => printf("Q::foo(%d)\n", x) +} + +class R : Q +{ + init() : super() { } + + override fn foo(x: &X) => printf("R::foo(%d)\n", x.data) +} + + class A { init() { } - virtual fn lol(a: Y) -> &Y + virtual fn foo(a: &Y) -> &Y { std::io::println("A::foo()") return alloc Y(c: 471) @@ -69,36 +101,14 @@ class B : A { init() : super() { } - override fn foo(a: Y) -> &Y + override fn foo(a: &X) -> &Z { std::io::println("B::foo()") return alloc Z(c: 748) - // return 3 } } -// import std::io as _ - -// @entry fn main() -// { -// println("hello, world!") -// } -// issue #0: we basically don't really even check for overriding methods properly. -// issue #1: co/contra-variance of return and parameter types for virtual methods -// issue #2: crashes when functions have default arguments?? - -// classes are a Bad Idea (tm) ): - -class Tmp -{ - var z: int - - init(x: int = 3, y: int = 7) - { - this.z = x * y - } -} @entry fn main() @@ -113,10 +123,8 @@ class Tmp // let q = Tmp(k: 3, m: 7).meth() - // let q = B().foo(Y(c: 1)).data - - let q = foo(x: 7) - std::io::println("thing: %", q) + let q = B().foo(alloc Y(c: 1)).data + printf("q = %d\n", q) } diff --git a/source/codegen/autocasting.cpp b/source/codegen/autocasting.cpp index 56b40f42..85dedef6 100644 --- a/source/codegen/autocasting.cpp +++ b/source/codegen/autocasting.cpp @@ -176,7 +176,7 @@ namespace cgn result = ret; } else if(fromType->isPointerType() && target->isPointerType() && fromType->getPointerElementType()->isClassType() - && fromType->getPointerElementType()->toClassType()->isInParentHierarchy(target->getPointerElementType())) + && fromType->getPointerElementType()->toClassType()->hasParent(target->getPointerElementType())) { auto ret = this->irb.PointerTypeCast(from, target); result = ret; diff --git a/source/codegen/classes.cpp b/source/codegen/classes.cpp index a2982523..f6e7737b 100644 --- a/source/codegen/classes.cpp +++ b/source/codegen/classes.cpp @@ -87,7 +87,8 @@ CGResult sst::ClassDefn::_codegen(cgn::CodegenState* cs, fir::Type* infer) // set our vtable if(clsty->getVirtualMethodCount() > 0) { - auto vtable = cs->irb.PointerTypeCast(cs->irb.AddressOf(cs->module->getOrCreateVirtualTableForClass(clsty), false), fir::Type::getInt8Ptr()); + auto vtable = cs->irb.PointerTypeCast(cs->irb.AddressOf(cs->module->getOrCreateVirtualTableForClass(clsty), false), + fir::Type::getInt8Ptr()); cs->irb.SetVtable(self, vtable); } diff --git a/source/fir/Types/ClassType.cpp b/source/fir/Types/ClassType.cpp index ad933f19..5c6ada21 100644 --- a/source/fir/Types/ClassType.cpp +++ b/source/fir/Types/ClassType.cpp @@ -240,7 +240,7 @@ namespace fir } - bool ClassType::isInParentHierarchy(Type* base) + bool ClassType::hasParent(Type* base) { auto target = dcast(ClassType, base); if(!target) return false; @@ -294,24 +294,74 @@ namespace fir this->reverseVirtualMethodMap = this->baseClass->reverseVirtualMethodMap; } - void ClassType::addVirtualMethod(Function* method) + + // expects the self param to be removed already!!! + // note: this one doesn't check if the return types are compatible; we expect typechecking to have already + // verified that, and we don't store the return type in the class virtual method map anyway. + static bool _areTypeListsVirtuallyCompatible(const std::vector& base, const std::vector& fn) { - //* what this does is compare the arguments without the first parameter, - //* since that's going to be the self parameter, and that's going to be different - auto withoutself = [](std::vector p) -> std::vector { - p.erase(p.begin()); + // parameters must be contravariant, ie. fn must take more general types than base + // return type must be covariant, ie. fn must return a more specific type than base. - return p; - }; + // duh + if(base.size() != fn.size()) + return false; - auto matching = [&withoutself](const std::vector& a, FunctionType* ft) -> bool { - auto bp = withoutself(ft->getArgumentTypes()); + // drop the first argument. + for(auto [ base, derv ] : util::zip(base, fn)) + { + if(base == derv) + continue; - //* note: we don't call withoutself on 'a' because we expect that to already have been done - //* before it was added. - return Type::areTypeListsEqual(a, bp); - }; + if(!derv->isPointerType() || !derv->getPointerElementType()->isClassType() + || !base->isPointerType() || !base->getPointerElementType()->isClassType()) + { + return false; + } + + auto bc = base->getPointerElementType()->toClassType(); + auto dc = derv->getPointerElementType()->toClassType(); + + if(!bc->hasParent(dc)) + { + debuglogln("%s is not a parent of %s", dc->str(), bc->str()); + return false; + } + } + + return true; + } + bool ClassType::areMethodsVirtuallyCompatible(FunctionType* base, FunctionType* fn) + { + bool ret = _areTypeListsVirtuallyCompatible(util::drop(base->getArgumentTypes(), 1), util::drop(fn->getArgumentTypes(), 1)); + + if(!ret) + return false; + + auto baseRet = base->getReturnType(); + auto fnRet = fn->getReturnType(); + + // ok now check the return type. + if(baseRet == fnRet) + return true; + + if(baseRet->isPointerType() && baseRet->getPointerElementType()->isClassType() + && fnRet->isPointerType() && fnRet->getPointerElementType()->isClassType()) + { + auto br = baseRet->getPointerElementType()->toClassType(); + auto dr = fnRet->getPointerElementType()->toClassType(); + + return dr->hasParent(br); + } + else + { + return false; + } + } + + void ClassType::addVirtualMethod(Function* method) + { //* note: the 'reverse' virtual method map is to allow us, at translation time, to easily create the vtable without //* unnecessary searching. When we set a base class, we copy its 'reverse' map; thus, if we don't override anything, //* our vtable will just refer to the methods in the base class. @@ -319,16 +369,17 @@ namespace fir //* but if we do override something, we just set the method in our 'reverse' map, which is what we'll use to build //* the vtable. simple? - auto list = method->getType()->toFunctionType()->getArgumentTypes(); + auto list = util::drop(method->getType()->toFunctionType()->getArgumentTypes(), 1); // check every member of the current mapping -- not the fastest method i admit. bool found = false; for(auto vm : this->virtualMethodMap) { - if(vm.first.first == method->getName().name && matching(vm.first.second, method->getType()->toFunctionType())) + if(vm.first.first == method->getName().name + && _areTypeListsVirtuallyCompatible(vm.first.second, list)) { found = true; - this->virtualMethodMap[{ method->getName().name, withoutself(list) }] = vm.second; + this->virtualMethodMap[{ method->getName().name, list }] = vm.second; this->reverseVirtualMethodMap[vm.second] = method; break; } @@ -337,7 +388,7 @@ namespace fir if(!found) { // just make a new one. - this->virtualMethodMap[{ method->getName().name, withoutself(list) }] = this->virtualMethodCount; + this->virtualMethodMap[{ method->getName().name, list }] = this->virtualMethodCount; this->reverseVirtualMethodMap[this->virtualMethodCount] = method; this->virtualMethodCount++; } diff --git a/source/fir/Types/Type.cpp b/source/fir/Types/Type.cpp index a63170b9..d4b69f45 100644 --- a/source/fir/Types/Type.cpp +++ b/source/fir/Types/Type.cpp @@ -127,7 +127,7 @@ namespace fir } //* note: we don't need to check that 'to' is a class type, because if it's not then the parent check will fail anyway. else if(from->isPointerType() && to->isPointerType() && from->getPointerElementType()->isClassType() - && from->getPointerElementType()->toClassType()->isInParentHierarchy(to->getPointerElementType())) + && from->getPointerElementType()->toClassType()->hasParent(to->getPointerElementType())) { // cast from a derived class pointer to a base class pointer return 2; diff --git a/source/frontend/errors.cpp b/source/frontend/errors.cpp index 050c7a67..0e367532 100644 --- a/source/frontend/errors.cpp +++ b/source/frontend/errors.cpp @@ -155,7 +155,7 @@ static std::string getSingleContext(const Location& loc, const std::string& unde -std::string __error_gen_internal(const Location& loc, const std::string& msg, const char* type, bool context) +std::string __error_gen_internal(const Location& loc, const std::string& msg, const char* type, bool context, bool multipart) { std::string ret; @@ -194,6 +194,7 @@ std::string __error_gen_internal(const Location& loc, const std::string& msg, co ret += getSingleContext(loc, underlineColour) + "\n"; } + if(!multipart) ret += "\n"; return ret; } @@ -223,10 +224,10 @@ static size_t strprinterrf(const char* fmt, Ts... ts) return (size_t) fprintf(stderr, "%s", strprintf(fmt, ts...).c_str()); } -template -static void outputWithoutContext(const char* type, const Location& loc, const char* fmt, Ts... ts) +// template +static void outputWithoutContext(const char* type, const Location& loc, const char* s, bool multi) { - strprinterrf("%s", __error_gen_internal(loc, strprintf(fmt, ts...), type, false)); + strprinterrf("%s", __error_gen_internal(loc, s, type, false, multi)); } @@ -237,7 +238,8 @@ static void outputWithoutContext(const char* type, const Location& loc, const ch void BareError::post() { - if(!this->msg.empty()) outputWithoutContext(typestr(this->type).c_str(), Location(), this->msg.c_str()); + if(!this->msg.empty()) + outputWithoutContext(typestr(this->type).c_str(), Location(), this->msg.c_str(), !this->subs.empty()); for(auto other : this->subs) other->post(); @@ -248,9 +250,9 @@ void SimpleError::post() { if(!this->msg.empty()) { - outputWithoutContext(typestr(this->type).c_str(), this->loc, this->msg.c_str()); + outputWithoutContext(typestr(this->type).c_str(), this->loc, this->msg.c_str(), !this->subs.empty()); strprinterrf("%s%s%s", this->wordsBeforeContext, this->wordsBeforeContext.size() > 0 ? "\n" : "", - this->printContext ? getSingleContext(this->loc, this->type == MsgType::Note ? COLOUR_BLUE_BOLD : COLOUR_RED_BOLD) + "\n\n" : ""); + this->printContext ? getSingleContext(this->loc, this->type == MsgType::Note ? COLOUR_BLUE_BOLD : COLOUR_RED_BOLD) + "\n" : ""); } for(auto other : this->subs) @@ -260,7 +262,7 @@ void SimpleError::post() void ExampleMsg::post() { - outputWithoutContext(typestr(this->type).c_str(), Location(), "for example:"); + outputWithoutContext(typestr(this->type).c_str(), Location(), "for example:", !this->subs.empty()); strprinterrf("%s\n\n", getSingleContext(Location(), COLOUR_BLUE_BOLD, this->example)); for(auto other : this->subs) @@ -470,7 +472,7 @@ void OverloadError::post() [[noreturn]] void doTheExit(bool trace) { - fprintf(stderr, "there were errors, compilation cannot continue\n"); + fprintf(stderr, "\nthere were errors, compilation cannot continue\n"); if(frontend::getAbortOnError()) abort(); else exit(-1); diff --git a/source/frontend/parser/expr.cpp b/source/frontend/parser/expr.cpp index 042afc6b..1a1c739c 100644 --- a/source/frontend/parser/expr.cpp +++ b/source/frontend/parser/expr.cpp @@ -1019,11 +1019,17 @@ namespace parser if(st.front() == TT::LParen) { + auto leftloc = st.loc(); + st.pop(); ret->args = parseCallArgumentList(st); if(ret->args.empty()) - info(st.loc(), "empty argument list in alloc expression () can be omitted"); + { + // parseCallArgumentList consumes the closing ) + auto tmp = Location::unionOf(leftloc, st.ploc()); + info(tmp, "empty argument list in alloc expression () can be omitted"); + } } diff --git a/source/include/errors.h b/source/include/errors.h index 9f15ae04..00659151 100644 --- a/source/include/errors.h +++ b/source/include/errors.h @@ -86,12 +86,12 @@ namespace frontend } -std::string __error_gen_internal(const Location& loc, const std::string& msg, const char* type, bool context); +std::string __error_gen_internal(const Location& loc, const std::string& msg, const char* type, bool context, bool multiPart); template std::string __error_gen(const Location& loc, const char* msg, const char* type, bool, Ts&&... ts) { - return __error_gen_internal(loc, tinyformat::format(msg, ts...), type, true); + return __error_gen_internal(loc, tinyformat::format(msg, ts...), type, true, false); } diff --git a/source/include/ir/type.h b/source/include/ir/type.h index 804a346f..7dca8229 100644 --- a/source/include/ir/type.h +++ b/source/include/ir/type.h @@ -762,7 +762,7 @@ namespace fir Function* getCopyConstructor(); Function* getMoveConstructor(); - bool isInParentHierarchy(Type* base); + bool hasParent(Type* base); void addVirtualMethod(Function* method); size_t getVirtualMethodIndex(const std::string& name, FunctionType* ft); @@ -818,6 +818,9 @@ namespace fir static ClassType* createWithoutBody(const Identifier& name); static ClassType* create(const Identifier& name, const std::vector>& members, const std::vector& methods, const std::vector& inits); + + // returns true if 'fn' is a valid virtual override of 'base'. deals with co/contra-variance + static bool areMethodsVirtuallyCompatible(FunctionType* base, FunctionType* fn); }; diff --git a/source/include/utils.h b/source/include/utils.h index d265a6a8..c811b0cc 100644 --- a/source/include/utils.h +++ b/source/include/utils.h @@ -180,6 +180,30 @@ namespace util return std::vector(v.begin() + std::min(num, v.size()), v.end()); } + template + std::vector> cartesian(const std::vector& a, const std::vector& b) + { + std::vector> ret; + + for(size_t i = 0; i < a.size(); i++) + for(size_t k = 0; k < b.size(); k++) + ret.push_back({ a[i], b[k] }); + + return ret; + } + + template + std::vector> zip(const std::vector& a, const std::vector& b) + { + assert(a.size() == b.size()); + + std::vector> ret; + for(size_t i = 0; i < a.size(); i++) + ret.push_back({ a[i], b[i] }); + + return ret; + } + inline std::string join(const std::vector& list, const std::string& sep) { if(list.empty()) return ""; diff --git a/source/typecheck/classes.cpp b/source/typecheck/classes.cpp index 48752a45..34b55a60 100644 --- a/source/typecheck/classes.cpp +++ b/source/typecheck/classes.cpp @@ -184,81 +184,113 @@ TCResult ast::ClassDefn::typecheck(sst::TypecheckState* fs, fir::Type* infer, co currently i think we error, and we probably don't check for the return type at all? */ - for(auto m : this->methods) { - if(m->name == "init") - error(m, "cannot have methods named 'init' in a class; to create an initialiser, omit the 'fn' keyword."); + //* check for what would be called 'method hiding' in c++, and also valid overrides. + // TODO: make an error note about co/contra-variance for param/return types. right now it just complains and it's vague af. + auto checkAgainstBaseClasses = [](sst::ClassDefn* cls, sst::FunctionDefn* meth) -> auto { - auto res = m->generateDeclaration(fs, cls, { }); - if(res.isParametric()) - continue; - - auto decl = dcast(sst::FunctionDefn, res.defn()); - iceAssert(decl); + auto checkSingleMethod = [](sst::ClassDefn* cls, sst::FunctionDefn* self, sst::FunctionDefn* bf, bool* matchedName) -> bool { - defn->methods.push_back(decl); + // ok -- issue is that we cannot compare the method signatures directly -- because the method will take the 'self' of its + // respective class, meaning they won't be duplicates. so, we must compare without the first parameter. + auto compareMethodSignatures = [](fir::FunctionType* a, fir::FunctionType* b) -> bool { - //* check for what would be called 'method hiding' in c++ -- ie. methods in the derived class with exactly the same type signature as - //* the base class method. + // well the order is important!! + return fir::ClassType::areMethodsVirtuallyCompatible(a, b); + }; - // TODO: code dupe with the field hiding thing we have above. simplify?? - std::function checkDupe = [&fs](sst::ClassDefn* cls, sst::FunctionDefn* meth) -> auto { - while(cls) - { - for(auto bf : cls->methods) + if(bf->id.name == self->id.name) { - // ok -- issue is that we cannot compare the method signatures directly -- because the method will take the 'self' of its - // respective class, meaning they won't be duplicates. so, we must compare without the first parameter. - auto compareMethodSignatures = [&fs](const std::vector& a, const std::vector& b) -> bool { - return fs->isDuplicateOverload(util::drop(a, 1), util::drop(b, 1)); - }; + *matchedName |= true; + + if(!compareMethodSignatures(bf->type->toFunctionType(), self->type->toFunctionType())) + return false; + + // check for virtual functions. + //* note: we don't need to care if 'bf' is the base method, because if we are 'isOverride', then we are also + //* 'isVirtual'. - if(bf->id.name == meth->id.name && compareMethodSignatures(bf->params, meth->params)) + // nice comprehensive error messages, I hope. + if(!self->isOverride) { - // check for virtual functions. - //* note: we don't need to care if 'bf' is the base method, because if we are 'isOverride', then we are also - //* 'isVirtual'. + auto err = SimpleError::make(self->loc, "redefinition of method '%s' (with type '%s'), that exists in" + " the base class '%s'", self->id.name, self->type, cls->id); - // nice comprehensive error messages, I hope. - if(!meth->isOverride) + if(bf->isVirtual) { - auto err = SimpleError::make(meth->loc, "redefinition of method '%s' (with type '%s'), that exists in the base class '%s'", - meth->id.name, meth->type, cls->id); - - if(bf->isVirtual) - { - err->append(SimpleError::make(MsgType::Note, bf->loc, "'%s' was defined as a virtual method; to override it, use the 'override' keyword", bf->id.name)); - } - else - { - err->append( - SimpleError::make(MsgType::Note, bf->loc, - "'%s' was previously defined in the base class as a non-virtual method here:", bf->id.name)->append( - BareError::make(MsgType::Note, "to override it, define '%s' as a virtual method", bf->id.name) - ) - ); - } - - err->postAndQuit(); + err->append(SimpleError::make(MsgType::Note, bf->loc, "'%s' was defined as a virtual method; to override it, use the 'override' keyword", bf->id.name)); } - else if(!bf->isVirtual) + else { - SimpleError::make(meth->loc, "cannot override non-virtual method '%s'", bf->id.name) - ->append(SimpleError::make(MsgType::Note, bf->loc, - "'%s' was previously defined in the base class as a non-virtual method here:", bf->id.name) - )->append(BareError::make(MsgType::Note, "to override it, define '%s' as a virtual method", bf->id.name)) - ->postAndQuit(); + err->append( + SimpleError::make(MsgType::Note, bf->loc, "'%s' was previously defined in the base class '%s'" + " as a non-virtual method here:", bf->id.name, cls->id.name + )->append(BareError::make(MsgType::Note, "to override it, define '%s' as a virtual method", + bf->id.name) + ) + ); } + + err->postAndQuit(); } + else if(!bf->isVirtual) + { + SimpleError::make(self->loc, "cannot override non-virtual method '%s'", bf->id.name) + ->append(SimpleError::make(MsgType::Note, bf->loc, + "'%s' was previously defined in the base class '%s' as a non-virtual method here:", bf->id.name, cls->id.name) + )->append(BareError::make(MsgType::Note, "to override it, define '%s' as a virtual method", bf->id.name)) + ->postAndQuit(); + } + + return true; } + return false; + }; + + bool matchedSig = false; + bool matchedName = false; + while(cls) + { + for(auto bf : cls->methods) + matchedSig |= checkSingleMethod(cls, meth, bf, &matchedName); + cls = cls->baseClass; } + + if(meth->isOverride && !matchedSig) + { + if(matchedName && !matchedSig) + { + error(meth, "invalid override: no method named '%s' in any base class with a signature matching" + " (or compatible with) '%s'", meth->id.name, meth->type->str()); + } + else if(!matchedName) + { + error(meth, "invalid override: no method in any base class named '%s'", meth->id.name); + } + } }; - checkDupe(defn->baseClass, decl); + for(auto m : this->methods) + { + if(m->name == "init") + error(m, "cannot have methods named 'init' in a class; to create an initialiser, omit the 'fn' keyword."); + + auto res = m->generateDeclaration(fs, cls, { }); + if(res.isParametric()) + continue; + + auto decl = dcast(sst::FunctionDefn, res.defn()); + iceAssert(decl); + + defn->methods.push_back(decl); + + checkAgainstBaseClasses(defn->baseClass, decl); + } } + { // make the constructors for(auto it : this->initialisers) diff --git a/source/typecheck/resolver/resolver.cpp b/source/typecheck/resolver/resolver.cpp index 948c8eb9..61d32c97 100644 --- a/source/typecheck/resolver/resolver.cpp +++ b/source/typecheck/resolver/resolver.cpp @@ -336,7 +336,7 @@ namespace resolver //* here we're just checking that 'ty' and 'self' are part of the same class hierarchy -- we don't really care about the method //* that we resolve being at the lowest or highest level of that hierarchy. - if(!ty->isInParentHierarchy(self) && !self->isInParentHierarchy(ty)) + if(!ty->hasParent(self) && !self->hasParent(ty)) { virt = false; break; diff --git a/source/typecheck/typecheckstate.cpp b/source/typecheck/typecheckstate.cpp index 5de57804..5de0920b 100644 --- a/source/typecheck/typecheckstate.cpp +++ b/source/typecheck/typecheckstate.cpp @@ -468,7 +468,8 @@ namespace sst if(fir::Type::areTypeListsEqual(util::map(a->params, [](const auto& p) -> fir::Type* { return p.type; }), util::map(b->params, [](const auto& p) -> fir::Type* { return p.type; }))) { - errs->append(BareError::make(MsgType::Note, "functions cannot be overloaded based on argument names alone")); + errs->append(BareError::make(MsgType::Note, "functions cannot be overloaded over argument names or" + " return types alone")); } }