From e3fec071ce420361839cfc559db69c032b39a57f Mon Sep 17 00:00:00 2001 From: Xiaochun Tong Date: Sat, 9 Nov 2024 14:56:46 -0500 Subject: [PATCH] generic structs --- luisa_lang/_builtin_decor.py | 103 ++++++++++++++++++++++--------- luisa_lang/codegen/cpp.py | 40 +++++++++--- luisa_lang/hir.py | 110 +++++++++++++++++++++++++++------ luisa_lang/lang_builtins.py | 49 ++++++++------- luisa_lang/parse.py | 115 +++++++++++++++++++++++------------ 5 files changed, 303 insertions(+), 114 deletions(-) diff --git a/luisa_lang/_builtin_decor.py b/luisa_lang/_builtin_decor.py index b8a75fa..622bcf8 100644 --- a/luisa_lang/_builtin_decor.py +++ b/luisa_lang/_builtin_decor.py @@ -135,7 +135,7 @@ def make_builtin(): def builtin(s: str) -> Callable[[_F], _F]: - def wrapper(func: _F) -> _F: + def wrapper(func: _F) -> _F: setattr(func, "__luisa_builtin__", s) return func return wrapper @@ -148,7 +148,6 @@ def _intrinsic_impl(*args, **kwargs) -> Any: ) - class _ObjKind(Enum): BUILTIN_TYPE = auto() STRUCT = auto() @@ -156,7 +155,7 @@ class _ObjKind(Enum): KERNEL = auto() -def _make_func_template(f: Callable[..., Any], func_name: str, func_globals: Dict[str, Any], self_type: Optional[hir.Type] = None): +def _make_func_template(f: Callable[..., Any], func_name: str, func_globals: Dict[str, Any], foreign_type_var_ns: Dict[TypeVar, hir.Type | hir.ComptimeValue], self_type: Optional[hir.Type] = None): # parsing_ctx = _parse.ParsingContext(func_name, func_globals) # func_sig_parser = _parse.FuncParser(func_name, f, parsing_ctx, self_type) # func_sig = func_sig_parser.parsed_func @@ -165,7 +164,7 @@ def _make_func_template(f: Callable[..., Any], func_name: str, func_globals: Dic func_sig = classinfo.parse_func_signature(f, func_globals, []) func_sig_converted, sig_parser = parse.convert_func_signature( - func_sig, func_name, func_globals, {}, {}, self_type) + func_sig, func_name, func_globals, foreign_type_var_ns, {}, self_type) implicit_type_params = sig_parser.implicit_type_params implicit_generic_params: Set[hir.GenericParameter] = set() for p in implicit_type_params.values(): @@ -173,7 +172,8 @@ def _make_func_template(f: Callable[..., Any], func_name: str, func_globals: Dic implicit_generic_params.add(p.param) def parsing_func(args: hir.FunctionTemplateResolvingArgs) -> hir.FunctionLike: - type_var_ns: Dict[TypeVar, hir.Type | hir.ComptimeValue] = {} + type_var_ns: Dict[TypeVar, hir.Type | + hir.ComptimeValue] = foreign_type_var_ns.copy() mapped_implicit_type_params: Dict[str, hir.Type] = dict() if is_generic: @@ -206,10 +206,14 @@ def parsing_func(args: hir.FunctionTemplateResolvingArgs) -> hir.FunctionLike: return func_parser.parse_body() params = [v[0] for v in func_sig.args] is_generic = len(func_sig_converted.generic_params) > 0 - # print(f"func {func_name} is_generic: {is_generic}") + # print( + # f"func {func_name} is_generic: {is_generic} {func_sig_converted.generic_params}") return hir.FunctionTemplate(func_name, params, parsing_func, is_generic) + + _TT = TypeVar('_TT') + def _dsl_func_impl(f: _TT, kind: _ObjKind, attrs: Dict[str, Any]) -> _TT: import sourceinspect assert inspect.isfunction(f), f"{f} is not a function" @@ -220,7 +224,7 @@ def _dsl_func_impl(f: _TT, kind: _ObjKind, attrs: Dict[str, Any]) -> _TT: func_globals: Any = getattr(f, "__globals__", {}) if kind == _ObjKind.FUNC: - template = _make_func_template(f, func_name, func_globals) + template = _make_func_template(f, func_name, func_globals, {}) ctx.functions[f] = template setattr(f, "__luisa_func__", template) return typing.cast(_TT, f) @@ -236,32 +240,73 @@ def _dsl_struct_impl(cls: type[_TT], attrs: Dict[str, Any]) -> type[_TT]: cls_info = class_typeinfo(cls) globalns = _get_cls_globalns(cls) globalns[cls.__name__] = cls - - def get_ir_type(var_ty: VarType) -> hir.Type: - if isinstance(var_ty, (UnionType, classinfo.AnyType, classinfo.SelfType)): - raise RuntimeError("Struct fields cannot be UnionType") - if isinstance(var_ty, TypeVar): - raise NotImplementedError() - if isinstance(var_ty, GenericInstance): + type_var_to_generic_param: Dict[TypeVar, hir.GenericParameter] = {} + for type_var in cls_info.type_vars: + type_var_to_generic_param[type_var] = hir.GenericParameter( + type_var.__name__, cls.__qualname__) + + def parse_fields(tp: parse.TypeParser, self_ty: hir.Type): + fields: List[Tuple[str, hir.Type]] = [] + for name, field in cls_info.fields.items(): + field_ty = tp.parse_type(field) + if field_ty is None: + raise hir.TypeInferenceError( + None, f"Cannot infer type for field {name} of {cls.__name__}") + fields.append((name, field_ty)) + if isinstance(self_ty, hir.StructType): + self_ty.fields = fields + elif isinstance(self_ty, hir.BoundType): + assert isinstance(self_ty.instantiated, hir.StructType) + self_ty.instantiated.fields = fields + else: raise NotImplementedError() - return ctx.types[var_ty] - fields: List[Tuple[str, hir.Type]] = [] - for name, field in cls_info.fields.items(): - fields.append((name, get_ir_type(field))) - ir_ty = hir.StructType( - f'{cls.__name__}_{unique_hash(cls.__qualname__)}', cls.__qualname__, fields) + def parse_methods(type_var_ns: Dict[TypeVar, hir.Type | Any], self_ty: hir.Type): + for name in cls_info.methods: + method_object = getattr(cls, name) + template = _make_func_template( + method_object, get_full_name(method_object), globalns, type_var_ns, self_type=self_ty) + if isinstance(self_ty, hir.BoundType): + assert isinstance(self_ty.instantiated, hir.StructType) + self_ty.instantiated.methods[name] = template + else: + self_ty.methods[name] = template + + ir_ty: hir.Type = hir.StructType( + f'{cls.__name__}_{unique_hash(cls.__qualname__)}', cls.__qualname__, []) + type_parser = parse.TypeParser( + cls.__qualname__, globalns, {}, ir_ty, 'parse') + + parse_fields(type_parser, ir_ty) + is_generic = len(cls_info.type_vars) > 0 + if is_generic: + def monomorphization_func(args: List[hir.Type | Any]) -> hir.Type: + assert isinstance(ir_ty, hir.ParametricType) + type_var_ns = {} + if len(args) != len(cls_info.type_vars): + raise hir.TypeInferenceError( + None, f"Expected {len(cls_info.type_vars)} type arguments but got {len(args)}") + for i, arg in enumerate(args): + type_var_ns[cls_info.type_vars[i]] = arg + hash_s = unique_hash(f'{cls.__qualname__}_{args}') + inner_ty = hir.StructType( + f'{cls.__name__}_{hash_s}M', f'{cls.__qualname__}[{",".join([str(a) for a in args])}]', []) + mono_self_ty = hir.BoundType(ir_ty, args, inner_ty) + mono_type_parser = parse.TypeParser( + cls.__qualname__, globalns, type_var_ns, mono_self_ty, 'instantiate') + parse_fields(mono_type_parser, mono_self_ty) + parse_methods(type_var_ns, mono_self_ty) + return inner_ty + ir_ty = hir.ParametricType( + list(type_var_to_generic_param.values()), ir_ty, monomorphization_func) + else: + pass ctx.types[cls] = ir_ty - - for name, method in cls_info.methods.items(): - method_object = getattr(cls, name) - template = _make_func_template( - method_object, get_full_name(method_object), globalns, self_type=ir_ty) - ir_ty.methods[name] = template + if not is_generic: + parse_methods({},ir_ty) return cls - def _dsl_decorator_impl(obj: _TT, kind: _ObjKind, attrs: Dict[str, Any]) -> _TT: if kind == _ObjKind.STRUCT: assert isinstance(obj, type), f"{obj} is not a type" @@ -288,8 +333,10 @@ def volume(self) -> float: """ return _dsl_decorator_impl(cls, _ObjKind.STRUCT, {}) + _KernelType = TypeVar("_KernelType", bound=Callable[..., None]) + @overload def kernel(f: _KernelType) -> _KernelType: ... @@ -353,4 +400,4 @@ def impl(f: _F) -> _F: def decorator(f): return impl(f) - return decorator \ No newline at end of file + return decorator diff --git a/luisa_lang/codegen/cpp.py b/luisa_lang/codegen/cpp.py index 0c3b0b7..e8ed251 100644 --- a/luisa_lang/codegen/cpp.py +++ b/luisa_lang/codegen/cpp.py @@ -2,7 +2,7 @@ from luisa_lang import hir from luisa_lang.utils import unique_hash, unwrap from luisa_lang.codegen import CodeGen, ScratchBuffer -from typing import Any, Callable, Dict, Set, Tuple, Union +from typing import Any, Callable, Dict, Optional, Set, Tuple, Union from luisa_lang.hir import get_dsl_func @@ -54,6 +54,13 @@ def do(): self.impl.writeln('};') return name return do() + case hir.BoundType(): + assert ty.instantiated + return self.gen(ty.instantiated) + case hir.FunctionType(): + return '' + case hir.TypeConstructorType(): + return '' case _: raise NotImplementedError(f"unsupported type: {ty}") @@ -144,6 +151,9 @@ def mangle_impl(self, obj: Union[hir.Type, hir.FunctionLike]) -> str: case hir.TupleType(): elements = [self.mangle(e) for e in obj.elements] return f"T{unique_hash(''.join(elements))}" + case hir.BoundType(): + assert obj.instantiated + return self.mangle(obj.instantiated) case _: raise NotImplementedError(f"unsupported object: {obj}") @@ -300,6 +310,8 @@ def impl() -> None: ty = self.base.type_cache.gen(expr.type) self.body.writeln( f"{ty} v{vid}{{ {','.join(self.gen_expr(e) for e in expr.args)} }};") + case hir.TypeValue(): + pass case _: raise NotImplementedError( f"unsupported expression: {expr}") @@ -308,7 +320,7 @@ def impl() -> None: self.node_map[expr] = f'v{vid}' return f'v{vid}' - def gen_node(self, node: hir.Node): + def gen_node(self, node: hir.Node) -> Optional[hir.BasicBlock]: match node: case hir.Return() as ret: @@ -331,7 +343,7 @@ def gen_node(self, node: hir.Node): self.body.indent += 1 self.gen_bb(if_stmt.else_body) self.body.indent -= 1 - self.gen_bb(if_stmt.merge) + return if_stmt.merge case hir.Break(): self.body.writeln("__loop_break = true; break;") case hir.Continue(): @@ -349,7 +361,7 @@ def gen_node(self, node: hir.Node): if (loop_break) break; update(); } - + """ self.body.writeln("while(true) {") self.body.indent += 1 @@ -368,6 +380,7 @@ def gen_node(self, node: hir.Node): self.gen_bb(loop.update) self.body.indent -= 1 self.body.writeln("}") + return loop.merge case hir.Alloca() as alloca: vid = self.new_vid() assert alloca.type @@ -378,12 +391,23 @@ def gen_node(self, node: hir.Node): self.gen_expr(node) case hir.Member() | hir.Index(): pass + return None def gen_bb(self, bb: hir.BasicBlock): - self.body.writeln(f"{{ // BasicBlock Begin {bb.span}") - for node in bb.nodes: - self.gen_node(node) - self.body.writeln(f"}} // BasicBlock End {bb.span}") + + while True: + loop_again = False + old_bb = bb + self.body.writeln(f"// BasicBlock Begin {bb.span}") + for i, node in enumerate(bb.nodes): + if (next := self.gen_node(node)) and next is not None: + assert i == len(bb.nodes) - 1 + loop_again = True + bb = next + break + self.body.writeln(f"// BasicBlock End {old_bb.span}") + if not loop_again: + break def gen_locals(self): for local in self.func.locals: diff --git a/luisa_lang/hir.py b/luisa_lang/hir.py index ee4e834..96c69e8 100644 --- a/luisa_lang/hir.py +++ b/luisa_lang/hir.py @@ -31,7 +31,7 @@ FunctionTemplateResolvingArgs = List[Tuple[str, - Union['Type', 'ComptimeValue']]] + Union['Type', Any]]] """ [Function parameter name, Type or Value]. The reason for using parameter name instead of GenericParameter is that python supports passing type[T] as a parameter, @@ -52,7 +52,7 @@ class FunctionTemplate: """ parsing_func: FunctionTemplateResolvingFunc __resolved: Dict[Tuple[Tuple[str, - Union['Type', 'ComptimeValue']], ...], FunctionLike] + Union['Type', Any]], ...], FunctionLike] is_generic: bool name: str params: List[str] @@ -451,10 +451,16 @@ def __init__(self, name: str, display_name: str, fields: List[Tuple[str, Type] self._fields = fields self.display_name = display_name self._field_dict = {name: ty for name, ty in fields} + @property def fields(self) -> List[Tuple[str, Type]]: return self._fields + + @fields.setter + def fields(self, value: List[Tuple[str, Type]]) -> None: + self._fields = value + self._field_dict = {name: ty for name, ty in value} def size(self) -> int: return sum(field.size() for _, field in self.fields) @@ -604,6 +610,7 @@ def __str__(self) -> str: def __repr__(self) -> str: return f"SymbolicType({self.param})" +MonomorphizationFunc = Callable[[List[Type | Any]], Type] class ParametricType(Type): """ @@ -611,11 +618,27 @@ class ParametricType(Type): """ params: List[GenericParameter] body: Type + monomorphification_cache: Dict[Tuple[Union['Type', Any], ...], 'Type'] + monomorphification_func: Optional[MonomorphizationFunc] - def __init__(self, params: List[GenericParameter], body: Type) -> None: + def __init__(self, params: List[GenericParameter], + body: Type, + monomorphification_func: MonomorphizationFunc | None = None) -> None: super().__init__() self.params = params self.body = body + self.monomorphification_func = monomorphification_func + self.monomorphification_cache = {} + + def instantiate(self, args: List[Union[Type, Any]]) -> 'Type': + keys = tuple(args) + if keys in self.monomorphification_cache: + return self.monomorphification_cache[keys] + if self.monomorphification_func is not None: + ty = self.monomorphification_func(args) + self.monomorphification_cache[keys] = ty + return ty + raise RuntimeError("monomorphification_func is not set") def size(self) -> int: raise RuntimeError("ParametricType has no size") @@ -627,10 +650,11 @@ def __eq__(self, value: object) -> bool: return ( isinstance(value, ParametricType) and value.params == self.params + and value.body == self.body ) def __hash__(self) -> int: - return hash((ParametricType, tuple(self.params))) + return hash((ParametricType, tuple(self.params), self.body)) class BoundType(Type): @@ -638,11 +662,14 @@ class BoundType(Type): An instance of a parametric type, e.g. Foo[int] """ generic: ParametricType - args: List[Union[Type, 'SymbolicConstant']] + args: List[Union[Type, Any]] + instantiated: Optional[Type] - def __init__(self, generic: ParametricType, args: List[Union[Type, 'SymbolicConstant']]) -> None: + def __init__(self, generic: ParametricType, args: List[Union[Type, Any]], instantiated: Optional[Type]=None) -> None: + super().__init__() self.generic = generic self.args = args + self.instantiated = instantiated def size(self) -> int: raise RuntimeError("don't call size on BoundedType") @@ -657,6 +684,41 @@ def __eq__(self, value: object) -> bool: and value.args == self.args ) + def __hash__(self): + return hash((BoundType, self.generic, tuple(self.args))) + + @override + def member(self, field) -> Optional['Type']: + if self.instantiated is not None: + return self.instantiated.member(field) + else: + raise RuntimeError("member access on uninstantiated BoundType") + + @override + def method(self, name) -> Optional[FunctionLike | FunctionTemplate]: + if self.instantiated is not None: + return self.instantiated.method(name) + else: + raise RuntimeError("method access on uninstantiated BoundType") + +class TypeConstructorType(Type): + inner: Type + + def __init__(self, inner: Type) -> None: + super().__init__() + self.inner = inner + + def size(self) -> int: + raise RuntimeError("TypeConstructorType has no size") + + def align(self) -> int: + raise RuntimeError("TypeConstructorType has no align") + + def __eq__(self, value: object) -> bool: + return isinstance(value, TypeConstructorType) and value.inner == self.inner + + def __hash__(self) -> int: + return hash((TypeConstructorType, self.inner)) class FunctionType(Type): func_like: FunctionLike | FunctionTemplate @@ -841,9 +903,9 @@ def __hash__(self) -> int: return hash(self.value) -class Ctor(Value): +class TypeValue(Value): def __init__(self, ty: Type, span: Optional[Span] = None) -> None: - super().__init__(ty, span) + super().__init__(TypeConstructorType(ty), span) class Alloca(Ref): @@ -1129,7 +1191,7 @@ def unify(a: Type | ComptimeValue, b: Type | ComptimeValue): return unify(mapping[a.param], b) if isinstance(b, GenericFloatType) or isinstance(b, GenericIntType): raise TypeInferenceError(None, - "float/int literal cannot be used to infer generic type directly, wrap it with a concrete type") + f"float/int literal cannot be used to infer generic type for `{a.param.name}` directly, wrap it with a concrete type") mapping[a.param] = b return case VectorType(): @@ -1155,14 +1217,16 @@ def unify(a: Type | ComptimeValue, b: Type | ComptimeValue): None, f"expected {a}, got {b}") unify(a.element, b.element) case TupleType(): - if not isinstance(b, TupleType): - raise TypeInferenceError( - None, f"expected {a}, got {b}") - if len(a.elements) != len(b.elements): - raise TypeInferenceError( - None, f"expected {a}, got {b}") - for ea, eb in zip(a.elements, b.elements): - unify(ea, eb) + def do() -> None: + if not isinstance(b, TupleType): + raise TypeInferenceError( + None, f"expected {a}, got {b}") + if len(a.elements) != len(b.elements): + raise TypeInferenceError( + None, f"expected {a}, got {b}") + for ea, eb in zip(a.elements, b.elements): + unify(ea, eb) + do() case StructType(): raise RuntimeError( "StructType should not appear in match_template_args") @@ -1178,7 +1242,17 @@ def unify(a: Type | ComptimeValue, b: Type | ComptimeValue): # None, f"field name mismatch,expected {a}, got {b}") # unify(ta, tb) case BoundType(): - raise NotImplementedError() + def do() -> None: + if not isinstance(b, BoundType): + raise TypeInferenceError( + None, f"{b} is not a BoundType") + if len(a.args) != len(b.args): + raise TypeInferenceError( + None, f"expected {len(a.args)} arguments, got {len(b.args)}") + for ea, eb in zip(a.args, b.args): + unify(ea, eb) + unify(a.generic.body, b.generic.body) + do() case ParametricType(): raise RuntimeError( "ParametricType should not appear in match_template_args") diff --git a/luisa_lang/lang_builtins.py b/luisa_lang/lang_builtins.py index e3b4093..33dced4 100644 --- a/luisa_lang/lang_builtins.py +++ b/luisa_lang/lang_builtins.py @@ -22,8 +22,8 @@ from luisa_lang._builtin_decor import builtin, builtin_type, _intrinsic_impl from luisa_lang import parse -_T = TypeVar("_T") -_N = TypeVar("_N", int, u32, u64) +T = TypeVar("T") +N = TypeVar("N", int, u32, u64) @builtin("dispatch_id") @@ -42,7 +42,7 @@ def block_id() -> uint3: @builtin("cast") -def cast(target: type[_T], value: Any) -> _T: +def cast(target: type[T], value: Any) -> T: """ Attempt to convert the value to the target type. """ @@ -50,7 +50,7 @@ def cast(target: type[_T], value: Any) -> _T: @builtin("bitcast") -def bitcast(target: type[_T], value: Any) -> _T: +def bitcast(target: type[T], value: Any) -> T: return _intrinsic_impl() @@ -88,7 +88,7 @@ def comptime(src: str) -> Any: ... @overload -def comptime(a: _T) -> _T: ... +def comptime(a: T) -> T: ... def comptime(a): @@ -115,7 +115,7 @@ def unroll(range_: Sequence[int]) -> Sequence[int]: @builtin("address_of") -def address_of(a: _T) -> 'Pointer[_T]': +def address_of(a: T) -> 'Pointer[T]': return _intrinsic_impl() # class StaticEval: @@ -147,54 +147,59 @@ def typeof(value: Any) -> hir.Type: # "Array", [hir.TypeParameter(_t, bound=[])], hir.ArrayType(_t, _n) # ) # ) -class Array(Generic[_T, _N]): +class Array(Generic[T, N]): def __init__(self) -> None: return _intrinsic_impl() - def __getitem__(self, index: int | u32 | u64) -> _T: + def __getitem__(self, index: int | u32 | u64) -> T: return _intrinsic_impl() - def __setitem__(self, index: int | u32 | u64, value: _T) -> None: + def __setitem__(self, index: int | u32 | u64, value: T) -> None: return _intrinsic_impl() def __len__(self) -> u32 | u64: return _intrinsic_impl() +def __buffer_ty(): + t = hir.GenericParameter("T", "luisa_lang.lang") + return hir.ParametricType( + [t], hir.OpaqueType("Buffer"), None + ) -# @_builtin_type( -# hir.ParametricType( -# "Buffer", [hir.TypeParameter(_t, bound=[])], hir.OpaqueType("Buffer") -# ) +# @builtin_type( +# # hir.ParametricType( +# # "Buffer", [hir.TypeParameter(_t, bound=[])], hir.OpaqueType("Buffer") +# # ) # ) -class Buffer(Generic[_T]): - def __getitem__(self, index: int | u32 | u64) -> _T: +class Buffer(Generic[T]): + def __getitem__(self, index: int | u32 | u64) -> T: return _intrinsic_impl() - def __setitem__(self, index: int | u32 | u64, value: _T) -> None: + def __setitem__(self, index: int | u32 | u64, value: T) -> None: return _intrinsic_impl() def __len__(self) -> u32 | u64: return _intrinsic_impl() -# @_builtin_type( +# @builtin_type( # hir.ParametricType( # "Pointer", [hir.TypeParameter(_t, bound=[])], hir.PointerType(_t) # ) # ) -class Pointer(Generic[_T]): - def __getitem__(self, index: int | i32 | i64 | u32 | u64) -> _T: +class Pointer(Generic[T]): + def __getitem__(self, index: int | i32 | i64 | u32 | u64) -> T: return _intrinsic_impl() - def __setitem__(self, index: int | i32 | i64 | u32 | u64, value: _T) -> None: + def __setitem__(self, index: int | i32 | i64 | u32 | u64, value: T) -> None: return _intrinsic_impl() @property - def value(self) -> _T: + def value(self) -> T: return _intrinsic_impl() @value.setter - def value(self, value: _T) -> None: + def value(self, value: T) -> None: return _intrinsic_impl() diff --git a/luisa_lang/parse.py b/luisa_lang/parse.py index 2a6fcdb..f276cc6 100644 --- a/luisa_lang/parse.py +++ b/luisa_lang/parse.py @@ -41,6 +41,9 @@ def is_valid_comptime_value_in_dsl_code(value: Any) -> bool: return False +ParsingMode = Literal['parse', 'instantiate'] + + class TypeParser: ctx_name: str globalns: Dict[str, Any] @@ -49,19 +52,40 @@ class TypeParser: generic_params: List[hir.GenericParameter] generic_param_to_type_var: Dict[hir.GenericParameter, typing.TypeVar] implicit_type_params: Dict[str, hir.Type] + mode: ParsingMode - def __init__(self, ctx_name: str, globalns: Dict[str, Any], type_var_ns: Dict[typing.TypeVar, hir.Type | ComptimeValue], self_type: Optional[Type] = None) -> None: + def __init__(self, ctx_name: str, globalns: Dict[str, Any], type_var_ns: Dict[typing.TypeVar, hir.Type | ComptimeValue], self_type: Optional[Type], mode: ParsingMode) -> None: self.globalns = globalns self.self_type = self_type self.type_var_ns = type_var_ns self.ctx_name = ctx_name self.generic_params = [] self.generic_param_to_type_var = {} + self.mode = mode def parse_type(self, ty: classinfo.VarType) -> Optional[hir.Type]: match ty: case classinfo.GenericInstance(): - raise NotImplementedError() + origin = ty.origin + ir_ty = self.parse_type(origin) + if not ir_ty: + raise RuntimeError( + f"Type {origin} is not a valid DSL type") + if not isinstance(ir_ty, hir.ParametricType): + raise RuntimeError( + f"Type {origin} is not a parametric type but is supplied with type arguments") + if len(ir_ty.params) != len(ty.args): + raise RuntimeError( + f"Type {origin} expects {len(ir_ty.params)} type arguments, got {len(ty.args)}") + type_args = [self.parse_type(arg) for arg in ty.args] + if any(arg is None for arg in type_args): + raise RuntimeError( + "failed to parse type arguments") + if self.mode == 'instantiate': + instantiated = ir_ty.instantiate(type_args) + else: + instantiated = None + return hir.BoundType(ir_ty, cast(List[hir.Type | hir.SymbolicConstant], type_args), instantiated) case classinfo.TypeVar(): # print(f'{ty} @ {id(ty)} {ty.__name__} in {self.type_var_ns}? : {ty in self.type_var_ns}') if ty in self.type_var_ns: @@ -96,13 +120,13 @@ def convert_func_signature(signature: classinfo.MethodType, type_var_ns: Dict[typing.TypeVar, hir.Type | ComptimeValue], implicit_type_params: Dict[str, hir.Type], self_type: Optional[Type], - mode: Literal['parse', 'instantiate'] = 'parse' + mode: ParsingMode = 'parse' ) -> Tuple[hir.FunctionSignature, TypeParser]: """ implicit_type_params: Tuple[List[Tuple[str, classinfo.VarType]], classinfo.VarType] """ - type_parser = TypeParser(ctx_name, globalns, type_var_ns, self_type) + type_parser = TypeParser(ctx_name, globalns, type_var_ns, self_type, mode) type_parser.implicit_type_params = implicit_type_params params: List[Var] = [] for arg in signature.args: @@ -160,7 +184,8 @@ def __init__(self, name: str, globalns: Dict[str, Any], type_var_ns: Dict[typing.TypeVar, hir.Type | ComptimeValue], self_type: Optional[Type]) -> None: - self.type_parser = TypeParser(name, globalns, type_var_ns, self_type) + self.type_parser = TypeParser( + name, globalns, type_var_ns, self_type, 'instantiate') self.name = name self.func = func self.signature = signature @@ -191,8 +216,8 @@ def parse_type(self, ty: classinfo.VarType) -> Optional[hir.Type]: raise RuntimeError(f"Type {t} is not resolved") return t - def convert_constexpr(self, value: ComptimeValue, span: Optional[hir.Span] = None) -> Optional[hir.Value]: - value = value.value + def convert_constexpr(self, comptime_val: ComptimeValue, span: Optional[hir.Span] = None) -> Optional[hir.Value]: + value = comptime_val.value if isinstance(value, int): return hir.Constant(value, type=hir.GenericIntType()) elif isinstance(value, float): @@ -216,7 +241,8 @@ def convert_constexpr(self, value: ComptimeValue, span: Optional[hir.Span] = Non if dsl_type is None: raise hir.ParsingError( span, f"expected DSL type but got {value}") - return hir.Ctor(dsl_type) + + return hir.TypeValue(dsl_type) return None def parse_const(self, const: ast.Constant) -> hir.Value: @@ -236,6 +262,12 @@ def parse_const(self, const: ast.Constant) -> hir.Value: const, f"unsupported constant type {type(value)}, wrap it in lc.comptime(...) if you intead to use it as a compile-time expression") def convert_any_to_value(self, a: Any, span: hir.Span | None) -> hir.Value | ComptimeValue: + if isinstance(a, typing.TypeVar): + if a in self.type_var_ns: + v = self.type_var_ns[a] + if isinstance(v, hir.Type): + return hir.TypeValue(v) + return self.convert_any_to_value(v, span) if not isinstance(a, ComptimeValue): a = ComptimeValue(a, None) if a.value in SPECIAL_FUNCTIONS: @@ -341,6 +373,27 @@ def parse_access(self, expr: ast.Subscript | ast.Attribute) -> hir.Value | Compt if isinstance(value, ComptimeValue): raise hir.ParsingError( expr, "attempt to access comptime value in DSL code; wrap it in lc.comptime(...) if you intead to use it as a compile-time expression") + if isinstance(value, hir.TypeValue): + type_args: List[hir.Type] = [] + + def parse_type_arg(expr: ast.expr) -> hir.Type: + type_annotation = self.eval_expr(expr) + type_hint = classinfo.parse_type_hint(type_annotation) + ty = self.parse_type(type_hint) + assert ty + return ty + + match expr.slice: + case ast.Tuple(): + for e in expr.slice.elts: + type_args.append(parse_type_arg(e)) + case _: + type_args.append(parse_type_arg(expr.slice)) + # print(f"Type args: {type_args}") + assert isinstance(value.type, hir.TypeConstructorType) and isinstance(value.type.inner, hir.ParametricType) + return hir.TypeValue( + hir.BoundType(value.type.inner, type_args, value.type.inner.instantiate(type_args))) + assert value.type index = self.parse_expr(expr.slice) index = self.convert_to_value(index, span) @@ -493,7 +546,7 @@ def make_int(i: int) -> hir.Value: def parse_call(self, expr: ast.Call) -> hir.Value | ComptimeValue: func = self.parse_expr(expr.func) - + span = hir.Span.from_ast(expr) if isinstance(func, hir.Ref): raise hir.ParsingError(expr, f"function expected") elif isinstance(func, ComptimeValue): @@ -510,14 +563,23 @@ def collect_args() -> List[hir.Value | hir.Ref]: arg, hir.Span.from_ast(expr.args[i])) return cast(List[hir.Value | hir.Ref], args) - if isinstance(func, hir.Ctor): - cls = func.type + if isinstance(func.type, hir.TypeConstructorType): + # TypeConstructorType is unique for each type + # so if any value has this type, it must be referring to the same underlying type + # even if it comes from a very complex expression, it's still fine + cls = func.type.inner assert cls + if isinstance(cls, hir.ParametricType): + raise hir.ParsingError( + span, f"please provide type arguments for {cls.body}") + init = cls.method("__init__") - tmp = self.cur_bb().append(hir.Alloca(cls, span=hir.Span.from_ast(expr))) - assert init is not None + tmp = self.cur_bb().append(hir.Alloca(cls, span)) + if init is None: + raise hir.ParsingError( + span, f"__init__ method not found for type {cls}") call = self.parse_call_impl( - hir.Span.from_ast(expr), init, [tmp]+collect_args()) + span, init, [tmp]+collect_args()) if isinstance(call, hir.TemplateMatchingError): raise hir.ParsingError(expr, call.message) assert isinstance(call, hir.Call) @@ -529,7 +591,7 @@ def collect_args() -> List[hir.Value | hir.Ref]: else: func_like = func.value ret = self.parse_call_impl( - hir.Span.from_ast(expr), func_like, collect_args()) + span, func_like, collect_args()) if isinstance(ret, hir.TemplateMatchingError): raise hir.ParsingError(expr, ret.message) return ret @@ -956,29 +1018,6 @@ def check_return_type(ty: hir.Type) -> None: check_return_type(hir.UnitType()) self.cur_bb().append(hir.Return(None)) case ast.Assign(): - # if len(stmt.targets) != 1: - # raise hir.ParsingError(stmt, f"expected single target") - # target = stmt.targets[0] - # var = self.parse_ref(target, new_var_hint='dsl') - # value = self.parse_expr(stmt.value) - # if isinstance(var, ComptimeValue): - # if not isinstance(value, ComptimeValue): - # raise hir.ParsingError( - # stmt, f"comptime value cannot be assigned with DSL value") - # var.update(value.value) - # return None - # value = self.convert_to_value(value, span) - # assert value.type - # if var.type: - # if not hir.is_type_compatible_to(value.type, var.type): - # raise hir.ParsingError( - # stmt, f"expected {var.type}, got {value.type}") - # else: - # if not value.type.is_concrete(): - # raise hir.ParsingError( - # stmt, "only concrete type can be assigned, please annotate the variable with type hint") - # var.type = value.type - # self.cur_bb().append(hir.Assign(var, value, span)) assert len(stmt.targets) == 1 target = stmt.targets[0] if isinstance(target, ast.Tuple):